diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs index 4729f694..c4568995 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs @@ -10,6 +10,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 public readonly List Instructions; public readonly bool EndsWithBranch; public readonly bool HasHostCall; + public readonly bool HasHostCallSkipContext; public readonly bool IsTruncated; public readonly bool IsLoopEnd; public readonly bool IsThumb; @@ -20,6 +21,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 List instructions, bool endsWithBranch, bool hasHostCall, + bool hasHostCallSkipContext, bool isTruncated, bool isLoopEnd, bool isThumb) @@ -31,6 +33,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 Instructions = instructions; EndsWithBranch = endsWithBranch; HasHostCall = hasHostCall; + HasHostCallSkipContext = hasHostCallSkipContext; IsTruncated = isTruncated; IsLoopEnd = isLoopEnd; IsThumb = isThumb; @@ -57,6 +60,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 Instructions.GetRange(0, splitIndex), false, HasHostCall, + HasHostCallSkipContext, false, false, IsThumb); @@ -67,6 +71,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 Instructions.GetRange(splitIndex, splitCount), EndsWithBranch, HasHostCall, + HasHostCallSkipContext, IsTruncated, IsLoopEnd, IsThumb); diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs index e0a18e66..8a2b389a 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs @@ -208,6 +208,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 InstMeta meta; InstFlags extraFlags = InstFlags.None; bool hasHostCall = false; + bool hasHostCallSkipContext = false; bool isTruncated = false; do @@ -246,9 +247,17 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 meta = InstTableA32.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features); } - if (meta.Name.IsSystemOrCall() && !hasHostCall) + if (meta.Name.IsSystemOrCall()) { - hasHostCall = meta.Name.IsCall() || InstEmitSystem.NeedsCall(meta.Name); + if (!hasHostCall) + { + hasHostCall = InstEmitSystem.NeedsCall(meta.Name); + } + + if (!hasHostCallSkipContext) + { + hasHostCallSkipContext = meta.Name.IsCall() || InstEmitSystem.NeedsCallSkipContext(meta.Name); + } } insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags)); @@ -259,8 +268,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 if (!isTruncated && IsBackwardsBranch(meta.Name, encoding)) { - hasHostCall = true; isLoopEnd = true; + hasHostCallSkipContext = true; } return new( @@ -269,6 +278,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 insts, !isTruncated, hasHostCall, + hasHostCallSkipContext, isTruncated, isLoopEnd, isThumb); diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs index a213c222..ca25057f 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs @@ -6,6 +6,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 { public readonly List Blocks; public readonly bool HasHostCall; + public readonly bool HasHostCallSkipContext; public readonly bool IsTruncated; public MultiBlock(List blocks) @@ -15,12 +16,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 Block block = blocks[0]; HasHostCall = block.HasHostCall; + HasHostCallSkipContext = block.HasHostCallSkipContext; for (int index = 1; index < blocks.Count; index++) { block = blocks[index]; HasHostCall |= block.HasHostCall; + HasHostCallSkipContext |= block.HasHostCallSkipContext; } block = blocks[^1]; diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs index 6c705722..4a3f03b8 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs @@ -106,6 +106,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32 if ((regMask & AbiConstants.ReservedRegsMask) == 0) { _gprMask |= regMask; + UsedGprsMask |= regMask; return firstCalleeSaved; } diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs index 1e8a8915..a668b577 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs @@ -305,12 +305,23 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp); } + int reservedStackSize = 0; + + if (multiBlock.HasHostCall) + { + reservedStackSize = CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask); + } + else if (multiBlock.HasHostCallSkipContext) + { + reservedStackSize = 2 * sizeof(ulong); // Context and page table pointers. + } + RegisterSaveRestore rsr = new( regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask, regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask, OperandType.FP64, - multiBlock.HasHostCall, - multiBlock.HasHostCall ? CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask) : 0); + multiBlock.HasHostCall || multiBlock.HasHostCallSkipContext, + reservedStackSize); TailMerger tailMerger = new(); @@ -596,7 +607,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 name == InstName.Ldm || name == InstName.Ldmda || name == InstName.Ldmdb || - name == InstName.Ldmib) + name == InstName.Ldmib || + name == InstName.Pop) { // Arm32 does not have a return instruction, instead returns are implemented // either using BX LR (for leaf functions), or POP { ... PC }. @@ -711,7 +723,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 switch (type) { case BranchType.SyncPoint: - InstEmitSystem.WriteSyncPoint(context.Writer, context.RegisterAllocator, context.TailMerger, context.GetReservedStackOffset()); + InstEmitSystem.WriteSyncPoint( + context.Writer, + ref asm, + context.RegisterAllocator, + context.TailMerger, + context.GetReservedStackOffset(), + context.StoreToContext, + context.LoadFromContext); break; case BranchType.SoftwareInterrupt: context.StoreToContext(); diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs index 81e44ba0..3b1ff5a2 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs @@ -199,12 +199,12 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 } } - private static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) + public static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) { WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true); } - private static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) + public static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset) { WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false); } diff --git a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs index be0976fd..07f9f86a 100644 --- a/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs +++ b/src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs @@ -354,11 +354,18 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 // All instructions that might do a host call should be included here. // That is required to reserve space on the stack for caller saved registers. + return name == InstName.Mrrc; + } + + public static bool NeedsCallSkipContext(InstName name) + { + // All instructions that might do a host call should be included here. + // That is required to reserve space on the stack for caller saved registers. + switch (name) { case InstName.Mcr: case InstName.Mrc: - case InstName.Mrrc: case InstName.Svc: case InstName.Udf: return true; @@ -372,7 +379,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 Assembler asm = new(writer); WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm); - WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); + WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset); } public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId) @@ -380,7 +387,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 Assembler asm = new(writer); WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId); - WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); + WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset); } public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm) @@ -388,7 +395,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 Assembler asm = new(writer); WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm); - WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset); + WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset); } public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2) @@ -422,14 +429,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister); } - public static void WriteSyncPoint(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset) - { - Assembler asm = new(writer); - - WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: false, spillBaseOffset); - } - - private static void WriteSyncPoint(CodeWriter writer, ref Assembler asm, RegisterAllocator regAlloc, TailMerger tailMerger, bool skipContext, int spillBaseOffset) + public static void WriteSyncPoint( + CodeWriter writer, + ref Assembler asm, + RegisterAllocator regAlloc, + TailMerger tailMerger, + int spillBaseOffset, + Action storeToContext = null, + Action loadFromContext = null) { int tempRegister = regAlloc.AllocateTempGprRegister(); @@ -440,7 +447,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 int branchIndex = writer.InstructionPointer; asm.Cbnz(rt, 0); - WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister); + storeToContext?.Invoke(); + WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister); Operand rn = Register(tempRegister == 0 ? 1 : 0); @@ -449,7 +457,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32)); - WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister); + WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister); + loadFromContext?.Invoke(); asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset); @@ -514,18 +523,31 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister) { - WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: true); + if (skipContext) + { + InstEmitFlow.WriteSpillSkipContext(ref asm, regAlloc, spillOffset); + } + else + { + WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: true); + } } private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister) { - WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: false); + if (skipContext) + { + InstEmitFlow.WriteFillSkipContext(ref asm, regAlloc, spillOffset); + } + else + { + WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: false); + } } private static void WriteSpillOrFill( ref Assembler asm, RegisterAllocator regAlloc, - bool skipContext, uint exceptMask, int spillOffset, int tempRegister, @@ -533,11 +555,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 { uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask); - if (skipContext) - { - gprMask &= ~Compiler.UsableGprsMask; - } - if (!spill) { // We must reload the status register before reloading the GPRs, @@ -600,11 +617,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64 uint fpSimdMask = regAlloc.UsedFpSimdMask; - if (skipContext) - { - fpSimdMask &= ~Compiler.UsableFpSimdMask; - } - while (fpSimdMask != 0) { int reg = BitOperations.TrailingZeroCount(fpSimdMask);