diff --git a/ChocolArm64/Instruction/AInstEmitSimdShift.cs b/ChocolArm64/Instruction/AInstEmitSimdShift.cs index a5ecd893fb..8918c0e1ba 100644 --- a/ChocolArm64/Instruction/AInstEmitSimdShift.cs +++ b/ChocolArm64/Instruction/AInstEmitSimdShift.cs @@ -221,7 +221,33 @@ namespace ChocolArm64.Instruction public static void Ssra_V(AILEmitterCtx Context) { - EmitVectorShrImmOpSx(Context, ShrImmFlags.Accumulate); + AOpCodeSimdShImm Op = (AOpCodeSimdShImm)Context.CurrOp; + + if (AOptimizations.UseSse2 && Op.Size > 0 + && Op.Size < 3) + { + Type[] TypesSra = new Type[] { VectorIntTypesPerSizeLog2[Op.Size], typeof(byte) }; + Type[] TypesAdd = new Type[] { VectorIntTypesPerSizeLog2[Op.Size], VectorIntTypesPerSizeLog2[Op.Size] }; + + EmitLdvecWithSignedCast(Context, Op.Rd, Op.Size); + EmitLdvecWithSignedCast(Context, Op.Rn, Op.Size); + + Context.EmitLdc_I4(GetImmShr(Op)); + + Context.EmitCall(typeof(Sse2).GetMethod(nameof(Sse2.ShiftRightArithmetic), TypesSra)); + Context.EmitCall(typeof(Sse2).GetMethod(nameof(Sse2.Add), TypesAdd)); + + EmitStvecWithSignedCast(Context, Op.Rd, Op.Size); + + if (Op.RegisterSize == ARegisterSize.SIMD64) + { + EmitVectorZeroUpper(Context, Op.Rd); + } + } + else + { + EmitVectorShrImmOpSx(Context, ShrImmFlags.Accumulate); + } } public static void Uqrshrn_S(AILEmitterCtx Context) @@ -315,7 +341,32 @@ namespace ChocolArm64.Instruction public static void Usra_V(AILEmitterCtx Context) { - EmitVectorShrImmOpZx(Context, ShrImmFlags.Accumulate); + AOpCodeSimdShImm Op = (AOpCodeSimdShImm)Context.CurrOp; + + if (AOptimizations.UseSse2 && Op.Size > 0) + { + Type[] TypesSrl = new Type[] { VectorUIntTypesPerSizeLog2[Op.Size], typeof(byte) }; + Type[] TypesAdd = new Type[] { VectorUIntTypesPerSizeLog2[Op.Size], VectorUIntTypesPerSizeLog2[Op.Size] }; + + EmitLdvecWithUnsignedCast(Context, Op.Rd, Op.Size); + EmitLdvecWithUnsignedCast(Context, Op.Rn, Op.Size); + + Context.EmitLdc_I4(GetImmShr(Op)); + + Context.EmitCall(typeof(Sse2).GetMethod(nameof(Sse2.ShiftRightLogical), TypesSrl)); + Context.EmitCall(typeof(Sse2).GetMethod(nameof(Sse2.Add), TypesAdd)); + + EmitStvecWithUnsignedCast(Context, Op.Rd, Op.Size); + + if (Op.RegisterSize == ARegisterSize.SIMD64) + { + EmitVectorZeroUpper(Context, Op.Rd); + } + } + else + { + EmitVectorShrImmOpZx(Context, ShrImmFlags.Accumulate); + } } private static void EmitVectorShl(AILEmitterCtx Context, bool Signed)