Optimize some parts of LSRA

- BitMap now operates on 64-bit int rather than 32-bit
- BitMap is now pooled in a ThreadStatic pool (within lrsa)
- BitMap now is now its own iterator. Marginally speeds up iterating
through the bits.
- A few cases where enumerators were generated have been converted to
forms that generate less garbage.
- New data structure for sorting _usePositions in LiveIntervals. Much
faster split, NextUseAfter, initial insertion. Random insertion is
slightly slower.
- That last one is WIP since you need to insert the values backwards. It
would be ideal if it just flipped it for you, uncomplicating things on
the caller side.
This commit is contained in:
riperiperi 2020-02-11 23:54:50 +00:00
commit 7421a6b3e5
7 changed files with 293 additions and 84 deletions

View file

@ -32,7 +32,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
private int _operationsCount; private int _operationsCount;
private class AllocationContext private class AllocationContext : IDisposable
{ {
public RegisterMasks Masks { get; } public RegisterMasks Masks { get; }
@ -49,8 +49,8 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
StackAlloc = stackAlloc; StackAlloc = stackAlloc;
Masks = masks; Masks = masks;
Active = new BitMap(intervalsCount); Active = BitMapPool.Allocate(intervalsCount);
Inactive = new BitMap(intervalsCount); Inactive = BitMapPool.Allocate(intervalsCount);
} }
public void MoveActiveToInactive(int bit) public void MoveActiveToInactive(int bit)
@ -69,6 +69,11 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
dest.Set(bit); dest.Set(bit);
} }
public void Dispose()
{
BitMapPool.Release();
}
} }
public AllocationResult RunPass( public AllocationResult RunPass(
@ -121,10 +126,14 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
InsertSplitCopies(); InsertSplitCopies();
InsertSplitCopiesAtEdges(cfg); InsertSplitCopiesAtEdges(cfg);
return new AllocationResult( AllocationResult result = new AllocationResult(
context.IntUsedRegisters, context.IntUsedRegisters,
context.VecUsedRegisters, context.VecUsedRegisters,
context.StackAlloc.TotalSize); context.StackAlloc.TotalSize);
context.Dispose();
return result;
} }
private void AllocateInterval(AllocationContext context, LiveInterval current, int cIndex) private void AllocateInterval(AllocationContext context, LiveInterval current, int cIndex)
@ -618,15 +627,22 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
bool hasSingleOrNoSuccessor = block.Next == null || block.Branch == null; bool hasSingleOrNoSuccessor = block.Next == null || block.Branch == null;
foreach (BasicBlock successor in Successors(block)) for (int i = 0; i < 2; i++)
{ {
// This used to use an enumerable, but it ended up generating a lot of garbage, so now it is a loop.
BasicBlock successor = (i == 0) ? block.Next : block.Branch;
if (successor == null)
{
continue;
}
int succIndex = successor.Index; int succIndex = successor.Index;
// If the current node is a split node, then the actual successor node // If the current node is a split node, then the actual successor node
// (the successor before the split) should be right after it. // (the successor before the split) should be right after it.
if (IsSplitEdgeBlock(successor)) if (IsSplitEdgeBlock(successor))
{ {
succIndex = Successors(successor).First().Index; succIndex = FirstSuccessor(successor).Index;
} }
CopyResolver copyResolver = new CopyResolver(); CopyResolver copyResolver = new CopyResolver();
@ -699,8 +715,10 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{ {
Operand register = GetRegister(current); Operand register = GetRegister(current);
foreach (int usePosition in current.UsePositions()) IList<int> usePositions = current.UsePositions();
for (int i = usePositions.Count - 1; i >= 0; i--)
{ {
int usePosition = -usePositions[i];
(_, Node operation) = GetOperationNode(usePosition); (_, Node operation) = GetOperationNode(usePosition);
for (int index = 0; index < operation.SourcesCount; index++) for (int index = 0; index < operation.SourcesCount; index++)
@ -778,8 +796,9 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{ {
_operationNodes.Add((block.Operations, node)); _operationNodes.Add((block.Operations, node));
foreach (Operand dest in Destinations(node)) for (int i = 0; i < node.DestinationsCount; i++)
{ {
Operand dest = node.GetDestination(i);
if (dest.Kind == OperandKind.LocalVariable && visited.Add(dest)) if (dest.Kind == OperandKind.LocalVariable && visited.Add(dest))
{ {
dest.NumberLocal(_intervals.Count); dest.NumberLocal(_intervals.Count);
@ -815,12 +834,12 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
// Compute local live sets. // Compute local live sets.
for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext) for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext)
{ {
BitMap liveGen = new BitMap(mapSize); BitMap liveGen = BitMapPool.Allocate(mapSize);
BitMap liveKill = new BitMap(mapSize); BitMap liveKill = BitMapPool.Allocate(mapSize);
for (Node node = block.Operations.First; node != null; node = node.ListNext) for (Node node = block.Operations.First; node != null; node = node.ListNext)
{ {
foreach (Operand source in Sources(node)) Sources(node, (source) =>
{ {
int id = GetOperandId(source); int id = GetOperandId(source);
@ -828,10 +847,11 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{ {
liveGen.Set(id); liveGen.Set(id);
} }
} });
foreach (Operand dest in Destinations(node)) for (int i = 0; i < node.DestinationsCount; i++)
{ {
Operand dest = node.GetDestination(i);
liveKill.Set(GetOperandId(dest)); liveKill.Set(GetOperandId(dest));
} }
} }
@ -846,8 +866,8 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
for (int index = 0; index < cfg.Blocks.Count; index++) for (int index = 0; index < cfg.Blocks.Count; index++)
{ {
blkLiveIn [index] = new BitMap(mapSize); blkLiveIn [index] = BitMapPool.Allocate(mapSize);
blkLiveOut[index] = new BitMap(mapSize); blkLiveOut[index] = BitMapPool.Allocate(mapSize);
} }
bool modified; bool modified;
@ -862,13 +882,10 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
BitMap liveOut = blkLiveOut[block.Index]; BitMap liveOut = blkLiveOut[block.Index];
foreach (BasicBlock successor in Successors(block)) if ((block.Next != null && liveOut.Set(blkLiveIn[block.Next.Index])) || (block.Branch != null && liveOut.Set(blkLiveIn[block.Branch.Index])))
{
if (liveOut.Set(blkLiveIn[successor.Index]))
{ {
modified = true; modified = true;
} }
}
BitMap liveIn = blkLiveIn[block.Index]; BitMap liveIn = blkLiveIn[block.Index];
@ -920,21 +937,22 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{ {
operationPos -= InstructionGap; operationPos -= InstructionGap;
foreach (Operand dest in Destinations(node)) for (int i = 0; i < node.DestinationsCount; i++)
{ {
Operand dest = node.GetDestination(i);
LiveInterval interval = _intervals[GetOperandId(dest)]; LiveInterval interval = _intervals[GetOperandId(dest)];
interval.SetStart(operationPos + 1); interval.SetStart(operationPos + 1);
interval.AddUsePosition(operationPos + 1); interval.AddUsePosition(operationPos + 1);
} }
foreach (Operand source in Sources(node)) Sources(node, (source) =>
{ {
LiveInterval interval = _intervals[GetOperandId(source)]; LiveInterval interval = _intervals[GetOperandId(source)];
interval.AddRange(blockStart, operationPos + 1); interval.AddRange(blockStart, operationPos + 1);
interval.AddUsePosition(operationPos); interval.AddUsePosition(operationPos);
} });
if (node is Operation operation && operation.Instruction == Instruction.Call) if (node is Operation operation && operation.Instruction == Instruction.Call)
{ {
@ -982,17 +1000,9 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
return (register.Index << 1) | (register.Type == RegisterType.Vector ? 1 : 0); return (register.Index << 1) | (register.Type == RegisterType.Vector ? 1 : 0);
} }
private static IEnumerable<BasicBlock> Successors(BasicBlock block) private static BasicBlock FirstSuccessor(BasicBlock block)
{ {
if (block.Next != null) return block.Next ?? block.Branch;
{
yield return block.Next;
}
if (block.Branch != null)
{
yield return block.Branch;
}
} }
private static IEnumerable<Node> BottomOperations(BasicBlock block) private static IEnumerable<Node> BottomOperations(BasicBlock block)
@ -1007,15 +1017,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
} }
} }
private static IEnumerable<Operand> Destinations(Node node) private static void Sources(Node node, Action<Operand> action)
{
for (int index = 0; index < node.DestinationsCount; index++)
{
yield return node.GetDestination(index);
}
}
private static IEnumerable<Operand> Sources(Node node)
{ {
for (int index = 0; index < node.SourcesCount; index++) for (int index = 0; index < node.SourcesCount; index++)
{ {
@ -1023,7 +1025,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
if (IsLocalOrRegister(source.Kind)) if (IsLocalOrRegister(source.Kind))
{ {
yield return source; action(source);
} }
else if (source.Kind == OperandKind.Memory) else if (source.Kind == OperandKind.Memory)
{ {

View file

@ -1,3 +1,4 @@
using ARMeilleure.Common;
using ARMeilleure.IntermediateRepresentation; using ARMeilleure.IntermediateRepresentation;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
@ -12,7 +13,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
private LiveInterval _parent; private LiveInterval _parent;
private SortedSet<int> _usePositions; private SortedIntegerList _usePositions;
public int UsesCount => _usePositions.Count; public int UsesCount => _usePositions.Count;
@ -38,7 +39,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
Local = local; Local = local;
_parent = parent ?? this; _parent = parent ?? this;
_usePositions = new SortedSet<int>(); _usePositions = new SortedIntegerList();
_ranges = new List<LiveRange>(); _ranges = new List<LiveRange>();
@ -196,7 +197,9 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
public void AddUsePosition(int position) public void AddUsePosition(int position)
{ {
_usePositions.Add(position); // Inserts are in descending order, but ascending is faster for SortedList<>.
// We flip the ordering, then iterate backwards when using the final list.
_usePositions.Add(-position);
} }
public bool Overlaps(int position) public bool Overlaps(int position)
@ -247,9 +250,9 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
return _childs.Values; return _childs.Values;
} }
public IEnumerable<int> UsePositions() public IList<int> UsePositions()
{ {
return _usePositions; return _usePositions.GetList();
} }
public int FirstUse() public int FirstUse()
@ -259,20 +262,19 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
return NotFound; return NotFound;
} }
return _usePositions.First(); return -_usePositions.Last();
} }
public int NextUseAfter(int position) public int NextUseAfter(int position)
{ {
foreach (int usePosition in _usePositions) int index = _usePositions.FindLessEqualIndex(-position);
{ return (index >= 0) ? -_usePositions[index] : NotFound;
if (usePosition >= position)
{
return usePosition;
}
} }
return NotFound; public void RemoveAfter(int position)
{
int index = _usePositions.FindLessEqualIndex(-position);
_usePositions.RemoveRange(0, index + 1);
} }
public LiveInterval Split(int position) public LiveInterval Split(int position)
@ -311,12 +313,14 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
_ranges.RemoveRange(splitIndex, count); _ranges.RemoveRange(splitIndex, count);
} }
foreach (int usePosition in _usePositions.Where(x => x >= position)) int addAfter = _usePositions.FindLessEqualIndex(-position);
for (int index = addAfter; index >= 0; index--)
{ {
int usePosition = _usePositions[index];
right._usePositions.Add(usePosition); right._usePositions.Add(usePosition);
} }
_usePositions.RemoveWhere(x => x >= position); RemoveAfter(position);
Debug.Assert(_ranges.Count != 0, "Left interval is empty after split."); Debug.Assert(_ranges.Count != 0, "Left interval is empty after split.");

View file

@ -3,18 +3,47 @@ using System.Collections.Generic;
namespace ARMeilleure.Common namespace ARMeilleure.Common
{ {
class BitMap : IEnumerable<int> class BitMap : IEnumerator<int>
{ {
private const int IntSize = 32; private const int IntSize = 64;
private const int IntMask = IntSize - 1; private const int IntMask = IntSize - 1;
private List<int> _masks; private List<long> _masks;
private int _enumIndex;
private long _enumMask;
private int _enumBit;
public int Current => _enumIndex * IntSize + _enumBit;
object IEnumerator.Current => Current;
public BitMap()
{
_masks = new List<long>(0);
}
public BitMap(int initialCapacity) public BitMap(int initialCapacity)
{ {
int count = (initialCapacity + IntMask) / IntSize; int count = (initialCapacity + IntMask) / IntSize;
_masks = new List<int>(count); _masks = new List<long>(count);
while (count-- > 0)
{
_masks.Add(0);
}
}
public void Reset(int initialCapacity)
{
int count = (initialCapacity + IntMask) / IntSize;
if (count > _masks.Capacity)
{
_masks.Capacity = count;
}
_masks.Clear();
while (count-- > 0) while (count-- > 0)
{ {
@ -29,7 +58,7 @@ namespace ARMeilleure.Common
int wordIndex = bit / IntSize; int wordIndex = bit / IntSize;
int wordBit = bit & IntMask; int wordBit = bit & IntMask;
int wordMask = 1 << wordBit; long wordMask = 1L << wordBit;
if ((_masks[wordIndex] & wordMask) != 0) if ((_masks[wordIndex] & wordMask) != 0)
{ {
@ -48,7 +77,7 @@ namespace ARMeilleure.Common
int wordIndex = bit / IntSize; int wordIndex = bit / IntSize;
int wordBit = bit & IntMask; int wordBit = bit & IntMask;
int wordMask = 1 << wordBit; long wordMask = 1L << wordBit;
_masks[wordIndex] &= ~wordMask; _masks[wordIndex] &= ~wordMask;
} }
@ -60,7 +89,7 @@ namespace ARMeilleure.Common
int wordIndex = bit / IntSize; int wordIndex = bit / IntSize;
int wordBit = bit & IntMask; int wordBit = bit & IntMask;
return (_masks[wordIndex] & (1 << wordBit)) != 0; return (_masks[wordIndex] & (1L << wordBit)) != 0;
} }
public bool Set(BitMap map) public bool Set(BitMap map)
@ -71,7 +100,7 @@ namespace ARMeilleure.Common
for (int index = 0; index < _masks.Count; index++) for (int index = 0; index < _masks.Count; index++)
{ {
int newValue = _masks[index] | map._masks[index]; long newValue = _masks[index] | map._masks[index];
if (_masks[index] != newValue) if (_masks[index] != newValue)
{ {
@ -92,7 +121,7 @@ namespace ARMeilleure.Common
for (int index = 0; index < _masks.Count; index++) for (int index = 0; index < _masks.Count; index++)
{ {
int newValue = _masks[index] & ~map._masks[index]; long newValue = _masks[index] & ~map._masks[index];
if (_masks[index] != newValue) if (_masks[index] != newValue)
{ {
@ -105,6 +134,10 @@ namespace ARMeilleure.Common
return modified; return modified;
} }
#region IEnumerable<long> Methods
// Note: The bit enumerator is embedded in this class to avoid creating garbage when enumerating.
private void EnsureCapacity(int size) private void EnsureCapacity(int size)
{ {
while (_masks.Count * IntSize < size) while (_masks.Count * IntSize < size)
@ -115,24 +148,38 @@ namespace ARMeilleure.Common
public IEnumerator<int> GetEnumerator() public IEnumerator<int> GetEnumerator()
{ {
for (int index = 0; index < _masks.Count; index++) Reset();
return this;
}
public bool MoveNext()
{ {
int mask = _masks[index]; if (_enumMask != 0)
while (mask != 0)
{ {
int bit = BitUtils.LowestBitSet(mask); _enumMask &= ~(1L << _enumBit);
mask &= ~(1 << bit);
yield return index * IntSize + bit;
} }
} while (_enumMask == 0)
}
IEnumerator IEnumerable.GetEnumerator()
{ {
return GetEnumerator(); if (++_enumIndex >= _masks.Count)
{
return false;
} }
_enumMask = _masks[_enumIndex];
}
_enumBit = BitUtils.LowestBitSet(_enumMask);
return true;
}
public void Reset()
{
_enumIndex = -1;
_enumMask = 0;
_enumBit = 0;
}
public void Dispose() { }
#endregion
} }
} }

View file

@ -0,0 +1,19 @@
using System;
namespace ARMeilleure.Common
{
static class BitMapPool
{
public static BitMap Allocate(int initialCapacity)
{
BitMap result = ThreadStaticPool<BitMap>.Instance.Allocate();
result.Reset(initialCapacity);
return result;
}
public static void Release()
{
ThreadStaticPool<BitMap>.Instance.Clear();
}
}
}

View file

@ -1,3 +1,5 @@
using System.Runtime.CompilerServices;
namespace ARMeilleure.Common namespace ARMeilleure.Common
{ {
static class BitUtils static class BitUtils
@ -6,11 +8,16 @@ namespace ARMeilleure.Common
private static readonly int[] DeBrujinLbsLut; private static readonly int[] DeBrujinLbsLut;
private const long DeBrujinSequence64 = 0x37e84a99dae458f;
private static readonly int[] DeBrujinLbsLut64;
private static readonly sbyte[] HbsNibbleLut; private static readonly sbyte[] HbsNibbleLut;
static BitUtils() static BitUtils()
{ {
DeBrujinLbsLut = new int[32]; DeBrujinLbsLut = new int[32];
DeBrujinLbsLut64 = new int[64];
for (int index = 0; index < DeBrujinLbsLut.Length; index++) for (int index = 0; index < DeBrujinLbsLut.Length; index++)
{ {
@ -19,6 +26,13 @@ namespace ARMeilleure.Common
DeBrujinLbsLut[lutIndex] = index; DeBrujinLbsLut[lutIndex] = index;
} }
for (int index = 0; index < DeBrujinLbsLut64.Length; index++)
{
ulong lutIndex = (ulong)(DeBrujinSequence64 * (1L << index)) >> 58;
DeBrujinLbsLut64[lutIndex] = index;
}
HbsNibbleLut = new sbyte[] { -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 }; HbsNibbleLut = new sbyte[] { -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 };
} }
@ -64,6 +78,7 @@ namespace ARMeilleure.Common
return HbsNibbleLut[value]; return HbsNibbleLut[value];
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int LowestBitSet(int value) public static int LowestBitSet(int value)
{ {
if (value == 0) if (value == 0)
@ -76,6 +91,19 @@ namespace ARMeilleure.Common
return DeBrujinLbsLut[(uint)(DeBrujinSequence * lsb) >> 27]; return DeBrujinLbsLut[(uint)(DeBrujinSequence * lsb) >> 27];
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int LowestBitSet(long value)
{
if (value == 0)
{
return -1;
}
long lsb = value & -value;
return DeBrujinLbsLut64[(ulong)(DeBrujinSequence64 * lsb) >> 58];
}
public static long Replicate(long bits, int size) public static long Replicate(long bits, int size)
{ {
long output = 0; long output = 0;

View file

@ -0,0 +1,109 @@
using System.Collections.Generic;
namespace ARMeilleure.Common
{
public class SortedIntegerList
{
private List<int> _items;
public int Count => _items.Count;
public int this[int index]
{
get
{
return _items[index];
}
set
{
_items[index] = value;
}
}
public SortedIntegerList()
{
_items = new List<int>();
}
public bool Add(int value)
{
if (_items.Count > 0 && value > Last())
{
_items.Add(value);
return true;
}
else
{
// Binary search for the location to insert.
int min = 0;
int max = Count - 1;
while (min <= max)
{
int mid = min + (max - min) / 2;
int existing = _items[mid];
if (value > existing)
{
min = mid + 1;
}
else if (value < existing)
{
max = mid - 1;
}
else
{
// This value already exists in the list. Return false.
return false;
}
}
_items.Insert(min, value);
return true;
}
}
public int FindLessEqualIndex(int value)
{
int min = 0;
int max = Count - 1;
while (min <= max)
{
int mid = min + (max - min) / 2;
int existing = _items[mid];
if (value > existing)
{
min = mid + 1;
}
else if (value < existing)
{
max = mid - 1;
}
else
{
return mid;
}
}
return max;
}
public void RemoveRange(int index, int count)
{
if (count > 0)
{
_items.RemoveRange(index, count);
}
}
public int Last()
{
return _items[Count - 1];
}
public List<int> GetList()
{
return _items;
}
}
}

View file

@ -3,7 +3,7 @@ using System.Threading;
namespace ARMeilleure namespace ARMeilleure
{ {
public class ThreadStaticPool<T> where T : class, new() internal class ThreadStaticPool<T> where T : class, new()
{ {
[ThreadStatic] [ThreadStatic]
private static ThreadStaticPool<T> _instance; private static ThreadStaticPool<T> _instance;