Implement Jump Table for Native Calls

NOTE: this slows down rejit considerably! Not recommended to be used
without codegen optimisation or AOT.

- Does not work on Linux
- A32 needs an additional commit.
This commit is contained in:
riperiperi 2020-02-05 01:51:57 +00:00
parent 08c0e3829b
commit cc9ab3471b
15 changed files with 424 additions and 23 deletions

View file

@ -136,7 +136,7 @@ namespace ARMeilleure.CodeGen.Optimizations
private static bool HasSideEffects(Node node)
{
return (node is Operation operation) && operation.Instruction == Instruction.Call;
return (node is Operation operation) && (operation.Instruction == Instruction.Call || operation.Instruction == Instruction.Tailcall);
}
private static bool IsPropagableCopy(Operation operation)

View file

@ -117,6 +117,7 @@ namespace ARMeilleure.CodeGen.X86
Add(X86Instruction.Imul, new InstructionInfo(BadOp, 0x0000006b, 0x00000069, BadOp, 0x00000faf, InstructionFlags.None));
Add(X86Instruction.Imul128, new InstructionInfo(BadOp, BadOp, BadOp, BadOp, 0x050000f7, InstructionFlags.None));
Add(X86Instruction.Insertps, new InstructionInfo(BadOp, BadOp, BadOp, BadOp, 0x000f3a21, InstructionFlags.Vex | InstructionFlags.Prefix66));
Add(X86Instruction.Jmp, new InstructionInfo(0x040000ff, BadOp, BadOp, BadOp, BadOp, InstructionFlags.None));
Add(X86Instruction.Lea, new InstructionInfo(BadOp, BadOp, BadOp, BadOp, 0x0000008d, InstructionFlags.None));
Add(X86Instruction.Maxpd, new InstructionInfo(BadOp, BadOp, BadOp, BadOp, 0x00000f5f, InstructionFlags.Vex | InstructionFlags.Prefix66));
Add(X86Instruction.Maxps, new InstructionInfo(BadOp, BadOp, BadOp, BadOp, 0x00000f5f, InstructionFlags.Vex));
@ -480,6 +481,11 @@ namespace ARMeilleure.CodeGen.X86
}
}
/// <summary>
/// Emits an unconditional jump (JMP) to <paramref name="dest"/>.
/// The destination may be a register or memory operand; per the opcode
/// table entry (0x040000ff) this appears to encode JMP r/m (0xFF /4).
/// </summary>
public void Jmp(Operand dest)
{
// No operand type is required: the encoding is derived from the destination alone.
WriteInstruction(dest, null, OperandType.None, X86Instruction.Jmp);
}
public void Lea(Operand dest, Operand source, OperandType type)
{
WriteInstruction(dest, source, type, X86Instruction.Lea);

View file

@ -76,6 +76,7 @@ namespace ARMeilleure.CodeGen.X86
Add(Instruction.Store16, GenerateStore16);
Add(Instruction.Store8, GenerateStore8);
Add(Instruction.Subtract, GenerateSubtract);
Add(Instruction.Tailcall, GenerateTailcall);
Add(Instruction.VectorCreateScalar, GenerateVectorCreateScalar);
Add(Instruction.VectorExtract, GenerateVectorExtract);
Add(Instruction.VectorExtract16, GenerateVectorExtract16);
@ -1083,6 +1084,13 @@ namespace ARMeilleure.CodeGen.X86
}
}
/// <summary>
/// Code generation for <c>Instruction.Tailcall</c>: tears down the current
/// stack frame, then jumps (rather than calls) to the target address in
/// source 0, so the callee returns directly to our caller.
/// </summary>
private static void GenerateTailcall(CodeGenContext context, Operation operation)
{
// Restore callee-saved registers and the stack pointer first; after the
// epilogue the frame is gone, so JMP is the only safe transfer.
WriteEpilogue(context);
context.Assembler.Jmp(operation.GetSource(0));
}
private static void GenerateVectorCreateScalar(CodeGenContext context, Operation operation)
{
Operand dest = operation.Destination;

View file

@ -101,6 +101,10 @@ namespace ARMeilleure.CodeGen.X86
}
break;
case Instruction.Tailcall:
HandleTailcallWindowsAbi(stackAlloc, node, operation);
break;
case Instruction.VectorInsert8:
if (!HardwareCapabilities.SupportsSse41)
{
@ -829,6 +833,53 @@ namespace ARMeilleure.CodeGen.X86
return node;
}
/// <summary>
/// Pre-allocation lowering for a tailcall under the Windows x64 ABI:
/// copies the call arguments into their fixed ABI argument registers and
/// moves the target address into the integer return register, rewriting
/// the operation's sources to reference those physical registers.
/// </summary>
/// <param name="stackAlloc">Stack allocator (unused here; kept for signature parity with the other ABI handlers).</param>
/// <param name="node">Linked-list node of the tailcall operation, used as the insertion point for the copies.</param>
/// <param name="operation">The tailcall operation; source 0 is the target address, sources 1..n are arguments.</param>
private static void HandleTailcallWindowsAbi(StackAllocator stackAlloc, LLNode node, Operation operation)
{
// NOTE(review): dest is unused — a tailcall produces no destination. TODO confirm and drop.
Operand dest = operation.Destination;
LinkedList<Node> nodes = node.List;
// Source 0 is the target address, so the argument count excludes it.
int argsCount = operation.SourcesCount - 1;
int maxArgs = CallingConvention.GetArgumentsOnRegsCount();
// Only register-passed arguments are handled; a tailcall cannot use stack
// arguments because the frame is destroyed before the jump.
if (argsCount > maxArgs)
{
argsCount = maxArgs;
}
Operand[] sources = new Operand[1 + argsCount];
// Handle arguments passed on registers: insert a copy into the ABI
// register for each argument, just before the tailcall.
for (int index = 0; index < argsCount; index++)
{
Operand source = operation.GetSource(1 + index);
Operand argReg = source.Type.IsInteger()
? Gpr(CallingConvention.GetIntArgumentRegister(index), source.Type)
: Xmm(CallingConvention.GetVecArgumentRegister(index), source.Type);
Operation copyOp = new Operation(Instruction.Copy, argReg, source);
// Constants may need materialization; HandleConstantCopy takes care of that.
HandleConstantCopy(nodes.AddBefore(node, copyOp), copyOp);
sources[1 + index] = argReg;
}
// The target address must be on the return registers, since we
// don't return anything and it is guaranteed to not be a
// callee saved register (which would be trashed on the epilogue).
Operand retReg = Gpr(CallingConvention.GetIntReturnRegister(), OperandType.I64);
Operation addrCopyOp = new Operation(Instruction.Copy, retReg, operation.GetSource(0));
nodes.AddBefore(node, addrCopyOp);
sources[0] = retReg;
// Rewrite the tailcall to reference the physical registers set up above.
operation.SetSources(sources);
}
private static void HandleLoadArgumentWindowsAbi(
CompilerContext cctx,
IntrusiveList<Node> nodes,

View file

@ -50,6 +50,7 @@ namespace ARMeilleure.CodeGen.X86
Imul,
Imul128,
Insertps,
Jmp,
Lea,
Maxpd,
Maxps,

View file

@ -121,7 +121,7 @@ namespace ARMeilleure.Decoders
currBlock.Branch = GetBlock((ulong)op.Immediate);
}
if (!IsUnconditionalBranch(lastOp) /*|| isCall*/)
if (!IsUnconditionalBranch(lastOp) || isCall)
{
currBlock.Next = GetBlock(currBlock.EndAddress);
}

View file

@ -56,7 +56,7 @@ namespace ARMeilleure.Instructions
{
OpCodeBReg op = (OpCodeBReg)context.CurrOp;
EmitVirtualJump(context, GetIntOrZR(context, op.Rn));
EmitVirtualJump(context, GetIntOrZR(context, op.Rn), op.Rn == RegisterAlias.Lr);
}
public static void Cbnz(ArmEmitterContext context) => EmitCb(context, onNotZero: true);
@ -71,7 +71,7 @@ namespace ARMeilleure.Instructions
public static void Ret(ArmEmitterContext context)
{
context.Return(context.BitwiseOr(GetIntOrZR(context, RegisterAlias.Lr), Const(CallFlag)));
context.Return(GetIntOrZR(context, RegisterAlias.Lr));
}
public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true);

View file

@ -142,7 +142,33 @@ namespace ARMeilleure.Instructions
public static void EmitCall(ArmEmitterContext context, ulong immediate)
{
context.Return(Const(immediate | CallFlag));
EmitJumpTableCall(context, Const(immediate));
}
/// <summary>
/// Emits a call (or tailcall, when <paramref name="isJump"/> is true) to a
/// native translated function at <paramref name="funcAddr"/>, passing the
/// native context pointer (argument 0) through.
/// </summary>
/// <param name="context">Emitter context.</param>
/// <param name="funcAddr">Host address of the translated function to invoke.</param>
/// <param name="isJump">True to emit a tailcall (no return here); false for a regular call.</param>
private static void EmitNativeCall(ArmEmitterContext context, Operand funcAddr, bool isJump = false)
{
// Flush guest register state to memory before transferring control.
context.StoreToContext();
Operand returnAddress;
if (isJump)
{
// Tailcall: control never comes back to this function, so no
// context reload or return-address check is emitted.
context.Tailcall(funcAddr, context.LoadArgument(OperandType.I64, 0));
}
else
{
returnAddress = context.Call(funcAddr, OperandType.I64, context.LoadArgument(OperandType.I64, 0));
// Reload guest state the callee may have modified.
context.LoadFromContext();
// InstEmitFlowHelper.EmitContinueOrReturnCheck(context, returnAddress);
// If the return address isn't our next instruction, return so the JIT can figure out what to do.
Operand continueLabel = Label();
Operand next = Const(GetNextOpAddress(context.CurrOp));
// Mask off bit 0 (the call flag) before comparing against the next opcode address.
context.BranchIfTrue(continueLabel, context.ICompareEqual(context.BitwiseAnd(returnAddress, Const(~1L)), next));
context.Return(returnAddress);
context.MarkLabel(continueLabel);
}
}
public static void EmitVirtualCall(ArmEmitterContext context, Operand target)
@ -150,17 +176,24 @@ namespace ARMeilleure.Instructions
EmitVirtualCallOrJump(context, target, isJump: false);
}
public static void EmitVirtualJump(ArmEmitterContext context, Operand target)
public static void EmitVirtualJump(ArmEmitterContext context, Operand target, bool isReturn)
{
EmitVirtualCallOrJump(context, target, isJump: true);
EmitVirtualCallOrJump(context, target, isJump: true, isReturn: isReturn);
}
private static void EmitVirtualCallOrJump(ArmEmitterContext context, Operand target, bool isJump)
private static void EmitVirtualCallOrJump(ArmEmitterContext context, Operand target, bool isJump, bool isReturn = false)
{
context.Return(context.BitwiseOr(target, Const(target.Type, (long)CallFlag)));
if (isReturn)
{
context.Return(target);
}
else
{
EmitJumpTableCall(context, target, isJump);
}
}
private static void EmitContinueOrReturnCheck(ArmEmitterContext context, Operand retVal)
public static void EmitContinueOrReturnCheck(ArmEmitterContext context, Operand retVal)
{
// Note: The return value of the called method will be placed
// at the Stack, the return value is always a Int64 with the
@ -188,5 +221,135 @@ namespace ARMeilleure.Instructions
{
return op.Address + (ulong)op.OpCodeSizeInBytes;
}
/// <summary>
/// Emits an indirect call/jump through a small per-site dynamic table of
/// (guestAddress, hostAddress) pairs. The table entries are probed in order
/// (unrolled, <c>JumpTable.DynamicTableElems</c> iterations): an empty slot is
/// claimed for <paramref name="address"/>; a matching slot yields its cached
/// host address; if all slots miss, control falls back to the translator.
/// </summary>
/// <param name="context">Emitter context.</param>
/// <param name="tableAddress">Host pointer to the first entry of this site's dynamic table region.</param>
/// <param name="address">Guest target address (zero-extended to I64 if 32-bit).</param>
/// <param name="isJump">True to tailcall the resolved function instead of calling it.</param>
public static void EmitDynamicTableCall(ArmEmitterContext context, Operand tableAddress, Operand address, bool isJump)
{
if (address.Type == OperandType.I32)
{
address = context.ZeroExtend32(OperandType.I64, address);
}
// Loop over elements of the dynamic table. Unrolled loop.
// TODO: different reserved size for jumps? Need to do some testing to see what is reasonable.
Operand endLabel = Label();
Operand fallbackLabel = Label();
for (int i = 0; i < JumpTable.DynamicTableElems; i++)
{
// TODO: USE COMPARE AND SWAP I64 TO ENSURE ATOMIC OPERATIONS
// NOTE(review): the load/store below is not atomic — two threads could
// both claim the same empty slot; confirm whether this race is benign.
Operand nextLabel = Label();
// Load this entry's guest address from the table.
Operand entry = context.Load(OperandType.I64, tableAddress);
// If it's 0, we can take this entry in the table.
// (TODO: compare and exchange with our address _first_ when implemented, then just check if the entry is ours)
Operand hasAddressLabel = Label();
Operand gotTableLabel = Label();
context.BranchIfTrue(hasAddressLabel, entry);
// Empty slot: claim it with our guest address.
context.Store(tableAddress, address);
context.Branch(gotTableLabel);
context.MarkLabel(hasAddressLabel);
// Occupied slot: is the stored guest address ours?
context.BranchIfFalse(nextLabel, context.ICompareEqual(entry, address));
context.MarkLabel(gotTableLabel);
// Entry belongs to us; read the cached host function pointer (offset 8).
Operand missingFunctionLabel = Label();
Operand targetFunction = context.Load(OperandType.I64, context.Add(tableAddress, Const(8)));
context.BranchIfFalse(missingFunctionLabel, targetFunction);
// Cached host address is valid — call/jump to it directly.
EmitNativeCall(context, targetFunction, isJump);
context.Branch(endLabel);
// Host address is 0: ask the JumpTable for a registered HighCq function.
context.MarkLabel(missingFunctionLabel);
Operand goodCallAddr = context.Call(new _U64_U64(context.JumpTable.TryGetFunction), address); // TODO: NativeInterface call to it? (easier to AOT)
context.BranchIfFalse(fallbackLabel, goodCallAddr); // Fallback if it doesn't exist yet.
// NOTE(review): the resolved address is not written back into the table
// here, so this lookup repeats until RegisterFunction fills the entry — TODO confirm intended.
EmitNativeCall(context, goodCallAddr, isJump);
context.Branch(endLabel);
context.MarkLabel(nextLabel);
tableAddress = context.Add(tableAddress, Const(16)); // Move to the next table entry (JumpTableStride bytes).
}
// All slots missed: set the call flag (bit 0) and resolve via the translator.
context.MarkLabel(fallbackLabel);
address = context.BitwiseOr(address, Const(address.Type, 1)); // Set call flag.
Operand fallbackAddr = context.Call(new _U64_U64(NativeInterface.GetFunctionAddress), address);
EmitNativeCall(context, fallbackAddr, isJump);
context.MarkLabel(endLabel);
}
/// <summary>
/// Emits a guest call or jump to <paramref name="address"/>, choosing the
/// cheapest available mechanism: a direct jump-table entry for constant
/// targets in HighCq code, the dynamic table for indirect targets in HighCq
/// code, or a plain return-to-dispatcher (with the call flag set) in LowCq code.
/// </summary>
/// <param name="context">Emitter context.</param>
/// <param name="address">Guest target address; may be a constant or a runtime value.</param>
/// <param name="isJump">True for a branch (tailcall), false for a call.</param>
public static void EmitJumpTableCall(ArmEmitterContext context, Operand address, bool isJump = false)
{
// Does the call have a constant value, or can it be folded to one?
// TODO: Constant folding. Indirect calls are slower in the best case and emit more code so we want to avoid them when possible.
bool isConst = address.Kind == OperandKind.Constant;
long constAddr = (long)address.Value;
if (isJump || !isConst || !context.HighCq)
{
if (context.HighCq)
{
// Virtual branch/call - store first used addresses on a small table for fast lookup.
int entry = context.JumpTable.ReserveDynamicEntry();
// Each dynamic entry spans DynamicTableElems slots of JumpTableStride bytes.
int jumpOffset = entry * JumpTable.JumpTableStride * JumpTable.DynamicTableElems;
Operand dynTablePtr = Const(context.JumpTable.DynamicPointer.ToInt64() + jumpOffset);
EmitDynamicTableCall(context, dynTablePtr, address, isJump);
}
else
{
// Don't emit indirect calls or jumps if we're compiling in lowCq mode.
// This avoids wasting space on the jump and indirect tables.
// Returning with bit 0 set tells the dispatcher this is a call target.
context.Return(context.BitwiseOr(address, Const(address.Type, 1))); // Set call flag.
}
}
else
{
// Constant target in HighCq code: reserve a direct jump table entry,
// keyed by this function's (aligned) base address and the target.
int entry = context.JumpTable.ReserveTableEntry(context.BaseAddress & (~3L), constAddr);
int jumpOffset = entry * JumpTable.JumpTableStride + 8; // Offset directly to the host address.
// TODO: Portable jump table ptr with IR adding of the offset. Will be easy to relocate for things like AOT.
Operand tableEntryPtr = Const(context.JumpTable.BasePointer.ToInt64() + jumpOffset);
Operand funcAddr = context.Load(OperandType.I64, tableEntryPtr);
Operand directCallLabel = Label();
Operand endLabel = Label();
// Host address in the table is 0 until the function is rejit.
context.BranchIfTrue(directCallLabel, funcAddr);
// Call the function through the translator until it is rejit.
address = context.BitwiseOr(address, Const(address.Type, 1)); // Set call flag.
Operand fallbackAddr = context.Call(new _U64_U64(NativeInterface.GetFunctionAddress), address);
EmitNativeCall(context, fallbackAddr);
context.Branch(endLabel);
context.MarkLabel(directCallLabel);
EmitNativeCall(context, funcAddr); // Call the function directly if it is present in the entry.
context.MarkLabel(endLabel);
}
}
}
}

View file

@ -1,5 +1,6 @@
using ARMeilleure.Memory;
using ARMeilleure.State;
using ARMeilleure.Translation;
using System;
namespace ARMeilleure.Instructions
@ -10,17 +11,19 @@ namespace ARMeilleure.Instructions
private class ThreadContext
{
public ExecutionContext Context { get; }
public MemoryManager Memory { get; }
public ExecutionContext Context { get; }
public MemoryManager Memory { get; }
public Translator Translator { get; }
public ulong ExclusiveAddress { get; set; }
public ulong ExclusiveValueLow { get; set; }
public ulong ExclusiveValueHigh { get; set; }
public ThreadContext(ExecutionContext context, MemoryManager memory)
public ThreadContext(ExecutionContext context, MemoryManager memory, Translator translator)
{
Context = context;
Memory = memory;
Context = context;
Memory = memory;
Translator = translator;
ExclusiveAddress = ulong.MaxValue;
}
@ -29,9 +32,9 @@ namespace ARMeilleure.Instructions
[ThreadStatic]
private static ThreadContext _context;
public static void RegisterThread(ExecutionContext context, MemoryManager memory)
public static void RegisterThread(ExecutionContext context, MemoryManager memory, Translator translator)
{
_context = new ThreadContext(context, memory);
_context = new ThreadContext(context, memory, translator);
}
public static void UnregisterThread()
@ -381,6 +384,12 @@ namespace ARMeilleure.Instructions
return address & ~((4UL << ErgSizeLog2) - 1);
}
/// <summary>
/// Resolves the host entry point for the guest function at
/// <paramref name="address"/>, translating it first if necessary.
/// Called from JIT-compiled code as the slow-path fallback of the jump tables.
/// </summary>
/// <param name="address">Guest address of the target function.</param>
/// <returns>Host address of the translated function, as an unsigned 64-bit value.</returns>
public static ulong GetFunctionAddress(ulong address)
{
TranslatedFunction function = _context.Translator.GetOrTranslate(address, GetContext().ExecutionMode);
return (ulong)function.GetPointer().ToInt64();
}
public static void ClearExclusive()
{
_context.ExclusiveAddress = ulong.MaxValue;

View file

@ -52,6 +52,7 @@ namespace ARMeilleure.IntermediateRepresentation
Store16,
Store8,
Subtract,
Tailcall,
VectorCreateScalar,
VectorExtract,
VectorExtract16,

View file

@ -41,10 +41,19 @@ namespace ARMeilleure.Translation
public Aarch32Mode Mode { get; }
public ArmEmitterContext(MemoryManager memory, Aarch32Mode mode)
public JumpTable JumpTable { get; }
public long BaseAddress { get; }
public bool HighCq { get; }
public ArmEmitterContext(MemoryManager memory, JumpTable jumpTable, long baseAddress, bool highCq, Aarch32Mode mode)
{
Memory = memory;
Mode = mode;
Memory = memory;
JumpTable = jumpTable;
BaseAddress = baseAddress;
HighCq = highCq;
Mode = mode;
_labels = new Dictionary<ulong, Operand>();
}

View file

@ -143,6 +143,19 @@ namespace ARMeilleure.Translation
}
}
/// <summary>
/// Emits a tailcall to <paramref name="address"/> with the given arguments.
/// The target address becomes source 0 of the Tailcall operation, followed
/// by the arguments. A tailcall terminates the block (control never returns),
/// so a new block is forced afterwards.
/// </summary>
/// <param name="address">Host address to jump to.</param>
/// <param name="callArgs">Arguments to pass to the target.</param>
public void Tailcall(Operand address, params Operand[] callArgs)
{
Operand[] args = new Operand[callArgs.Length + 1];
args[0] = address;
Array.Copy(callArgs, 0, args, 1, callArgs.Length);
// No destination: a tailcall never yields a value to this function.
Add(Instruction.Tailcall, null, args);
_needsNewBlock = true;
}
public Operand CompareAndSwap128(Operand address, Operand expected, Operand desired)
{
return Add(Instruction.CompareAndSwap128, Local(OperandType.V128), address, expected, desired);

View file

@ -0,0 +1,128 @@
using ARMeilleure.Memory;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Threading;
namespace ARMeilleure.Translation
{
/// <summary>
/// Owns the native memory backing the direct and dynamic jump tables used by
/// JIT-compiled code for fast guest-to-host calls, and tracks which table
/// entries depend on which guest functions so entries can be patched when a
/// function is (re)translated.
/// </summary>
class JumpTable
{
// The jump table is a block of (guestAddress, hostAddress) function mappings.
// Each entry corresponds to one branch in a JIT compiled function. The entries are
// reserved specifically for each call.
// The _dependants dictionary can be used to update the hostAddress for any functions that change.
public const int JumpTableStride = 16; // 8 byte guest address, 8 byte host address
private const int JumpTableSize = 1048576;
private const int JumpTableByteSize = JumpTableSize * JumpTableStride;
// The dynamic table is also a block of (guestAddress, hostAddress) function mappings.
// The main difference is that indirect calls and jumps reserve _multiple_ entries on the table.
// These start out as all 0. When an indirect call is made, it tries to find the guest address on the table.
// If we get to an empty address, the guestAddress is set to the call that we want.
// If we get to a guestAddress that matches our own (or we just claimed it), the hostAddress is read.
// If it is non-zero, we immediately branch or call the host function.
// If it is 0, NativeInterface is called to find the rejited address of the call.
// If none is found, the hostAddress entry stays at 0. Otherwise, the new address is placed in the entry.
// If the table size is exhausted and we didn't find our desired address, we fall back to
// requesting the function from the JIT.
private const int DynamicTableSize = 1048576;
public const int DynamicTableElems = 10;
private const int DynamicTableByteSize = DynamicTableSize * JumpTableStride * DynamicTableElems;
// Next free entry index for each table (incremented via Interlocked).
// NOTE(review): both counters are pre-incremented, so entry 0 of each table is never used — TODO confirm intended.
private int _tableEnd = 0;
private int _dynTableEnd = 0;
// Guest address -> translated (HighCq) function, for fast lookup from JIT code.
private ConcurrentDictionary<ulong, TranslatedFunction> _targets;
// Guest address -> direct-table entry indices that must be patched when that address is (re)translated.
private ConcurrentDictionary<ulong, LinkedList<int>> _dependants; // TODO: Attach to TranslatedFunction or a wrapper class.
// Native base pointer of the direct jump table.
public IntPtr BasePointer { get; }
// Native base pointer of the dynamic jump table.
public IntPtr DynamicPointer { get; }
public JumpTable()
{
BasePointer = MemoryManagement.Allocate(JumpTableByteSize);
DynamicPointer = MemoryManagement.Allocate(DynamicTableByteSize);
_targets = new ConcurrentDictionary<ulong, TranslatedFunction>();
_dependants = new ConcurrentDictionary<ulong, LinkedList<int>>();
}
// Records a translated function for the given guest address and patches the
// host-address half of every direct-table entry that targets it.
public void RegisterFunction(ulong address, TranslatedFunction func) {
// Low 2 bits of guest addresses are used for flags elsewhere; mask them off.
address &= ~3UL;
_targets.AddOrUpdate(address, func, (key, oldFunc) => func);
long funcPtr = func.GetPointer().ToInt64();
// Update all jump table entries that target this address.
LinkedList<int> myDependants;
if (_dependants.TryGetValue(address, out myDependants)) {
lock (myDependants)
{
foreach (var entry in myDependants)
{
IntPtr addr = BasePointer + entry * JumpTableStride;
// Host address lives at offset 8 within the 16-byte entry.
Marshal.WriteInt64(addr, 8, funcPtr);
}
}
}
}
// Returns the host entry point for a registered guest address, or 0 if the
// function has not been registered (i.e. not rejitted as HighCq) yet.
public ulong TryGetFunction(ulong address)
{
TranslatedFunction result;
if (_targets.TryGetValue(address, out result))
{
return (ulong)result.GetPointer().ToInt64();
}
return 0;
}
// Reserves one dynamic-table entry (a group of DynamicTableElems slots) for
// an indirect call/jump site. Throws once the table is exhausted.
public int ReserveDynamicEntry()
{
int entry = Interlocked.Increment(ref _dynTableEnd);
if (entry >= DynamicTableSize)
{
throw new OutOfMemoryException("JIT Dynamic Jump Table Exhausted.");
}
return entry;
}
// Reserves one direct-table entry for a constant-target call from the
// function at ownerAddress, pre-filling the host address when the target is
// already registered, and records the entry as a dependant of the target.
public int ReserveTableEntry(long ownerAddress, long address)
{
int entry = Interlocked.Increment(ref _tableEnd);
if (entry >= JumpTableSize)
{
throw new OutOfMemoryException("JIT Direct Jump Table Exhausted.");
}
// Is the address we have already registered? If so, put the function address in the jump table.
long value = 0;
TranslatedFunction func;
if (_targets.TryGetValue((ulong)address, out func))
{
value = func.GetPointer().ToInt64();
}
// Make sure changes to the function at the target address update this jump table entry.
// NOTE(review): AddLast is done without taking the lock that RegisterFunction
// holds while iterating this list — possible race; confirm.
LinkedList<int> targetDependants = _dependants.GetOrAdd((ulong)address, (addr) => new LinkedList<int>());
targetDependants.AddLast(entry);
IntPtr addr = BasePointer + entry * JumpTableStride;
Marshal.WriteInt64(addr, 0, address);
Marshal.WriteInt64(addr, 8, value);
return entry;
}
}

View file

@ -1,3 +1,5 @@
using System;
using System.Runtime.InteropServices;
using System.Threading;
namespace ARMeilleure.Translation
@ -26,5 +28,10 @@ namespace ARMeilleure.Translation
{
return _rejit && Interlocked.Increment(ref _callCount) == MinCallsForRejit;
}
/// <summary>
/// Returns a native function pointer for this translated function's delegate,
/// suitable for storing in the jump tables and calling from JIT-compiled code.
/// The _func delegate field keeps the delegate alive, which is required for
/// the returned pointer to remain valid.
/// </summary>
public IntPtr GetPointer()
{
return Marshal.GetFunctionPointerForDelegate(_func);
}
}
}

View file

@ -20,6 +20,8 @@ namespace ARMeilleure.Translation
private ConcurrentDictionary<ulong, TranslatedFunction> _funcs;
private JumpTable _jumpTable;
private PriorityQueue<RejitRequest> _backgroundQueue;
private AutoResetEvent _backgroundTranslatorEvent;
@ -32,6 +34,8 @@ namespace ARMeilleure.Translation
_funcs = new ConcurrentDictionary<ulong, TranslatedFunction>();
_jumpTable = new JumpTable();
_backgroundQueue = new PriorityQueue<RejitRequest>(2);
_backgroundTranslatorEvent = new AutoResetEvent(false);
@ -46,6 +50,7 @@ namespace ARMeilleure.Translation
TranslatedFunction func = Translate(request.Address, request.Mode, highCq: true);
_funcs.AddOrUpdate(request.Address, func, (key, oldFunc) => func);
_jumpTable.RegisterFunction(request.Address, func);
}
else
{
@ -69,7 +74,7 @@ namespace ARMeilleure.Translation
Statistics.InitializeTimer();
NativeInterface.RegisterThread(context, _memory);
NativeInterface.RegisterThread(context, _memory, this);
do
{
@ -98,7 +103,7 @@ namespace ARMeilleure.Translation
return nextAddr;
}
private TranslatedFunction GetOrTranslate(ulong address, ExecutionMode mode)
internal TranslatedFunction GetOrTranslate(ulong address, ExecutionMode mode)
{
// TODO: Investigate how we should handle code at unaligned addresses.
// Currently, those low bits are used to store special flags.
@ -124,7 +129,7 @@ namespace ARMeilleure.Translation
private TranslatedFunction Translate(ulong address, ExecutionMode mode, bool highCq)
{
ArmEmitterContext context = new ArmEmitterContext(_memory, Aarch32Mode.User);
ArmEmitterContext context = new ArmEmitterContext(_memory, _jumpTable, (long)address, highCq, Aarch32Mode.User);
Logger.StartPass(PassName.Decoding);