diff --git a/ARMeilleure/Decoders/Block.cs b/ARMeilleure/Decoders/Block.cs index 3d13c2d5e4..918fa0a77e 100644 --- a/ARMeilleure/Decoders/Block.cs +++ b/ARMeilleure/Decoders/Block.cs @@ -10,6 +10,7 @@ namespace ARMeilleure.Decoders public Block Next { get; set; } public Block Branch { get; set; } + public bool TailCall { get; set; } public List OpCodes { get; private set; } diff --git a/ARMeilleure/Decoders/Decoder.cs b/ARMeilleure/Decoders/Decoder.cs index c232760912..68c61ffa0d 100644 --- a/ARMeilleure/Decoders/Decoder.cs +++ b/ARMeilleure/Decoders/Decoder.cs @@ -140,9 +140,77 @@ namespace ARMeilleure.Decoders } } + RemoveTailCalls(address, blocks); + return blocks.ToArray(); } + private static void RemoveTailCalls(ulong entryAddress, List blocks) + { + // Detect tail calls: + // - Assume this function spans the space covered by contiguous code blocks surrounding the entry address. + // - Unconditional jump to an area outside this contiguous region will be treated as a tail call. + // - Include a small allowance for jumps outside the contiguous range. + + if (!BinarySearch(blocks, entryAddress, out int entryBlockId)) + { + throw new InvalidOperationException("Function entry point is not contained in a block."); + } + + ulong allowance = 4; + Block entryBlock = blocks[entryBlockId]; + int startBlockIndex = entryBlockId; + Block startBlock = entryBlock; + int endBlockIndex = entryBlockId; + Block endBlock = entryBlock; + + for (int i = entryBlockId + 1; i < blocks.Count; i++) // Search forwards. + { + Block block = blocks[i]; + if (endBlock.EndAddress < block.Address - allowance) + { + break; // End of contiguous function. + } + + endBlock = block; + endBlockIndex = i; + } + + for (int i = entryBlockId - 1; i >= 0; i--) // Search backwards. + { + Block block = blocks[i]; + if (startBlock.Address > block.EndAddress + allowance) + { + break; // End of contiguous function. 
+ } + + startBlock = block; + startBlockIndex = i; + } + + if (startBlockIndex == 0 && endBlockIndex == blocks.Count - 1) + { + return; // Nothing to do here. + } + + // Replace all branches to blocks outside the range with null, and force a tail call. + + for (int i = startBlockIndex; i <= endBlockIndex; i++) + { + Block block = blocks[i]; + if (block.Branch != null && (block.Branch.Address > endBlock.EndAddress || block.Branch.EndAddress < startBlock.Address)) + { + block.Branch = null; + block.TailCall = true; + } + } + + // Finally, delete all blocks outside the contiguous range. + + blocks.RemoveRange(endBlockIndex + 1, (blocks.Count - endBlockIndex) - 1); + blocks.RemoveRange(0, startBlockIndex); + } + private static bool BinarySearch(List blocks, ulong address, out int index) { index = 0; diff --git a/ARMeilleure/Instructions/InstEmitException.cs b/ARMeilleure/Instructions/InstEmitException.cs index 6f7b6fd51f..f0bde242a6 100644 --- a/ARMeilleure/Instructions/InstEmitException.cs +++ b/ARMeilleure/Instructions/InstEmitException.cs @@ -2,6 +2,7 @@ using ARMeilleure.Decoders; using ARMeilleure.Translation; using System; +using static ARMeilleure.Instructions.InstEmitFlowHelper; using static ARMeilleure.IntermediateRepresentation.OperandHelper; namespace ARMeilleure.Instructions @@ -30,7 +31,7 @@ namespace ARMeilleure.Instructions if (context.CurrBlock.Next == null) { - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); } } @@ -48,7 +49,7 @@ namespace ARMeilleure.Instructions if (context.CurrBlock.Next == null) { - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); } } } diff --git a/ARMeilleure/Instructions/InstEmitFlow.cs b/ARMeilleure/Instructions/InstEmitFlow.cs index 918dace8bb..bac9ec588c 100644 --- a/ARMeilleure/Instructions/InstEmitFlow.cs +++ b/ARMeilleure/Instructions/InstEmitFlow.cs @@ -21,7 +21,7 @@ namespace ARMeilleure.Instructions } else { - 
context.Return(Const(op.Immediate)); + EmitTailContinue(context, Const(op.Immediate), context.CurrBlock.TailCall); } } @@ -96,7 +96,7 @@ namespace ARMeilleure.Instructions if (context.CurrBlock.Next == null) { - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); } } else @@ -105,11 +105,11 @@ namespace ARMeilleure.Instructions EmitCondBranch(context, lblTaken, cond); - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); context.MarkLabel(lblTaken); - context.Return(Const(op.Immediate)); + EmitTailContinue(context, Const(op.Immediate)); } } @@ -132,7 +132,7 @@ namespace ARMeilleure.Instructions if (context.CurrBlock.Next == null) { - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); } } else @@ -148,11 +148,11 @@ namespace ARMeilleure.Instructions context.BranchIfFalse(lblTaken, value); } - context.Return(Const(op.Address + 4)); + EmitTailContinue(context, Const(op.Address + 4)); context.MarkLabel(lblTaken); - context.Return(Const(op.Immediate)); + EmitTailContinue(context, Const(op.Immediate)); } } } diff --git a/ARMeilleure/Instructions/InstEmitFlowHelper.cs b/ARMeilleure/Instructions/InstEmitFlowHelper.cs index 054bf90154..b068a3fc35 100644 --- a/ARMeilleure/Instructions/InstEmitFlowHelper.cs +++ b/ARMeilleure/Instructions/InstEmitFlowHelper.cs @@ -142,7 +142,7 @@ namespace ARMeilleure.Instructions public static void EmitCall(ArmEmitterContext context, ulong immediate) { - EmitJumpTableCall(context, Const(immediate)); + EmitJumpTableBranch(context, Const(immediate)); } private static void EmitNativeCall(ArmEmitterContext context, Operand funcAddr, bool isJump = false) @@ -158,16 +158,7 @@ namespace ARMeilleure.Instructions returnAddress = context.Call(funcAddr, OperandType.I64, context.LoadArgument(OperandType.I64, 0)); context.LoadFromContext(); - // InstEmitFlowHelper.EmitContinueOrReturnCheck(context, returnAddress); - - // If the return 
address isn't to our next instruction, we need to return to the JIT can figure out what to do. - Operand continueLabel = Label(); - Operand next = Const(GetNextOpAddress(context.CurrOp)); - context.BranchIfTrue(continueLabel, context.ICompareEqual(context.BitwiseAnd(returnAddress, Const(~1L)), next)); - - context.Return(returnAddress); - - context.MarkLabel(continueLabel); + EmitContinueOrReturnCheck(context, returnAddress); } } @@ -189,31 +180,34 @@ namespace ARMeilleure.Instructions } else { - EmitJumpTableCall(context, target, isJump); + EmitJumpTableBranch(context, target, isJump); } } - public static void EmitContinueOrReturnCheck(ArmEmitterContext context, Operand retVal) + public static void EmitContinueOrReturnCheck(ArmEmitterContext context, Operand returnAddress) { // Note: The return value of the called method will be placed // at the Stack, the return value is always a Int64 with the // return address of the function. We check if the address is // correct, if it isn't we keep returning until we reach the dispatcher. - ulong nextAddr = GetNextOpAddress(context.CurrOp); + Operand nextAddr = Const(GetNextOpAddress(context.CurrOp)); if (context.CurrBlock.Next != null) { + // Try to continue within this block. + // If the return address isn't to our next instruction, we need to return so the JIT can figure out what to do. Operand lblContinue = Label(); - context.BranchIfTrue(lblContinue, context.ICompareEqual(retVal, Const(nextAddr))); + context.BranchIfTrue(lblContinue, context.ICompareEqual(context.BitwiseAnd(returnAddress, Const(~1L)), nextAddr)); - context.Return(Const(nextAddr)); + context.Return(returnAddress); context.MarkLabel(lblContinue); } else { - context.Return(Const(nextAddr)); + // No code following this instruction, ask the translator with the return address and jump to it. 
+ EmitTailContinue(context, nextAddr); } } @@ -222,23 +216,37 @@ namespace ARMeilleure.Instructions return op.Address + (ulong)op.OpCodeSizeInBytes; } + public static void EmitTailContinue(ArmEmitterContext context, Operand address, bool allowRejit = false) + { + bool complexShit = true; + if (complexShit) + { + if (allowRejit) + { + address = context.BitwiseOr(address, Const(1L)); + } + + Operand fallbackAddr = context.Call(new _U64_U64(NativeInterface.GetFunctionAddress), address); + + EmitNativeCall(context, fallbackAddr, true); + } + else + { + context.Return(address); + } + } + public static void EmitDynamicTableCall(ArmEmitterContext context, Operand tableAddress, Operand address, bool isJump) { - if (address.Type == OperandType.I32) - { - address = context.ZeroExtend32(OperandType.I64, address); - } - // Loop over elements of the dynamic table. Unrolled loop. - // TODO: different reserved size for jumps? Need to do some testing to see what is reasonable. Operand endLabel = Label(); Operand fallbackLabel = Label(); + // Currently this uses a size of 1, as higher values inflate code size for no real benefit. for (int i = 0; i < JumpTable.DynamicTableElems; i++) { // TODO: USE COMPARE AND SWAP I64 TO ENSURE ATOMIC OPERATIONS - Operand nextLabel = Label(); // Load this entry from the table. @@ -265,6 +273,7 @@ namespace ARMeilleure.Instructions Operand missingFunctionLabel = Label(); Operand targetFunctionPtr = context.Add(tableAddress, Const(8)); Operand targetFunction = context.Load(OperandType.I64, targetFunctionPtr); + context.BranchIfFalse(missingFunctionLabel, targetFunction); // Call the function. @@ -273,7 +282,7 @@ namespace ARMeilleure.Instructions // We need to find the missing function. This can only be from a list of HighCq functions, which the JumpTable maintains. context.MarkLabel(missingFunctionLabel); - Operand goodCallAddr = context.Call(new _U64_U64(context.JumpTable.TryGetFunction), address); // TODO: NativeInterface call to it? 
(easier to AOT) + Operand goodCallAddr = context.Call(new _U64_U64(NativeInterface.GetHighCqFunctionAddress), address); context.BranchIfFalse(fallbackLabel, goodCallAddr); // Fallback if it doesn't exist yet. @@ -296,31 +305,36 @@ namespace ARMeilleure.Instructions context.MarkLabel(endLabel); } - public static void EmitJumpTableCall(ArmEmitterContext context, Operand address, bool isJump = false) + public static void EmitJumpTableBranch(ArmEmitterContext context, Operand address, bool isJump = false) { - // Does the call have a constant value, or can it be folded to one? + if (address.Type == OperandType.I32) + { + address = context.ZeroExtend32(OperandType.I64, address); + } + // TODO: Constant folding. Indirect calls are slower in the best case and emit more code so we want to avoid them when possible. bool isConst = address.Kind == OperandKind.Constant; long constAddr = (long)address.Value; - if (isJump || !isConst || !context.HighCq) + if (!context.HighCq) { - if (context.HighCq) - { - // Virtual branch/call - store first used addresses on a small table for fast lookup. - int entry = context.JumpTable.ReserveDynamicEntry(); + // Don't emit indirect calls or jumps if we're compiling in lowCq mode. + // This avoids wasting space on the jump and indirect tables. + // Just ask the translator for the function address. - int jumpOffset = entry * JumpTable.JumpTableStride * JumpTable.DynamicTableElems; - Operand dynTablePtr = Const(context.JumpTable.DynamicPointer.ToInt64() + jumpOffset); + address = context.BitwiseOr(address, Const(address.Type, 1)); // Set call flag. + Operand fallbackAddr = context.Call(new _U64_U64(NativeInterface.GetFunctionAddress), address); + EmitNativeCall(context, fallbackAddr, isJump); + } + else if (!isConst) + { + // Virtual branch/call - store first used addresses on a small table for fast lookup. 
+ int entry = context.JumpTable.ReserveDynamicEntry(); - EmitDynamicTableCall(context, dynTablePtr, address, isJump); - } - else - { - // Don't emit indirect calls or jumps if we're compiling in lowCq mode. - // This avoids wasting space on the jump and indirect tables. - context.Return(context.BitwiseOr(address, Const(address.Type, 1))); // Set call flag. - } + int jumpOffset = entry * JumpTable.JumpTableStride * JumpTable.DynamicTableElems; + Operand dynTablePtr = Const(context.JumpTable.DynamicPointer.ToInt64() + jumpOffset); + + EmitDynamicTableCall(context, dynTablePtr, address, isJump); } else { @@ -336,19 +350,19 @@ namespace ARMeilleure.Instructions Operand directCallLabel = Label(); Operand endLabel = Label(); - // Host address in the table is 0 until the function is rejit. + // Host address in the table is 0 until the function is rejit. Use fallback until then. context.BranchIfTrue(directCallLabel, funcAddr); // Call the function through the translator until it is rejit. address = context.BitwiseOr(address, Const(address.Type, 1)); // Set call flag. Operand fallbackAddr = context.Call(new _U64_U64(NativeInterface.GetFunctionAddress), address); - EmitNativeCall(context, fallbackAddr); + EmitNativeCall(context, fallbackAddr, isJump); context.Branch(endLabel); context.MarkLabel(directCallLabel); - EmitNativeCall(context, funcAddr); // Call the function directly if it is present in the entry. + EmitNativeCall(context, funcAddr, isJump); // Call the function directly if it is present in the entry. 
context.MarkLabel(endLabel); } diff --git a/ARMeilleure/Instructions/NativeInterface.cs b/ARMeilleure/Instructions/NativeInterface.cs index 233315c4e0..c32f5ca6ae 100644 --- a/ARMeilleure/Instructions/NativeInterface.cs +++ b/ARMeilleure/Instructions/NativeInterface.cs @@ -390,6 +390,12 @@ namespace ARMeilleure.Instructions return (ulong)function.GetPointer().ToInt64(); } + public static ulong GetHighCqFunctionAddress(ulong address) + { + TranslatedFunction function = _context.Translator.TryGetHighCqFunction(address); + return (function != null) ? (ulong)function.GetPointer().ToInt64() : 0; + } + public static void ClearExclusive() { _context.ExclusiveAddress = ulong.MaxValue; diff --git a/ARMeilleure/Translation/JumpTable.cs b/ARMeilleure/Translation/JumpTable.cs index e07e09fb79..133a21ff51 100644 --- a/ARMeilleure/Translation/JumpTable.cs +++ b/ARMeilleure/Translation/JumpTable.cs @@ -36,7 +36,7 @@ namespace ARMeilleure.Translation private const int DynamicTableSize = 1048576; - public const int DynamicTableElems = 10; + public const int DynamicTableElems = 1; private const int DynamicTableByteSize = DynamicTableSize * JumpTableStride * DynamicTableElems; @@ -77,14 +77,14 @@ namespace ARMeilleure.Translation } } - public ulong TryGetFunction(ulong address) + public TranslatedFunction TryGetFunction(ulong address) { TranslatedFunction result; if (_targets.TryGetValue(address, out result)) { - return (ulong)result.GetPointer().ToInt64(); + return result; } - return 0; + return null; } public int ReserveDynamicEntry() diff --git a/ARMeilleure/Translation/Translator.cs b/ARMeilleure/Translation/Translator.cs index 99312b7bf4..51ddbbbc3c 100644 --- a/ARMeilleure/Translation/Translator.cs +++ b/ARMeilleure/Translation/Translator.cs @@ -127,13 +127,20 @@ namespace ARMeilleure.Translation return func; } + internal TranslatedFunction TryGetHighCqFunction(ulong address) + { + return _jumpTable.TryGetFunction(address); + } + private TranslatedFunction Translate(ulong 
address, ExecutionMode mode, bool highCq) { ArmEmitterContext context = new ArmEmitterContext(_memory, _jumpTable, (long)address, highCq, Aarch32Mode.User); Logger.StartPass(PassName.Decoding); - Block[] blocks = highCq + bool alwaysFunctions = true; + + Block[] blocks = alwaysFunctions ? Decoder.DecodeFunction (_memory, address, mode) : Decoder.DecodeBasicBlock(_memory, address, mode); @@ -221,7 +228,7 @@ namespace ARMeilleure.Translation // with some kind of branch). if (isLastOp && block.Next == null) { - context.Return(Const(opCode.Address + (ulong)opCode.OpCodeSizeInBytes)); + InstEmitFlowHelper.EmitTailContinue(context, Const(opCode.Address + (ulong)opCode.OpCodeSizeInBytes)); } } }