diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 4a184d2..94d18e4 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -173,6 +173,7 @@ fn run_instruction<'input>( | ast::Instruction::Setp { .. } | ast::Instruction::SetpBool { .. } | ast::Instruction::ShflSync { .. } + | ast::Instruction::Shf { .. } | ast::Instruction::Shl { .. } | ast::Instruction::Shr { .. } | ast::Instruction::Sin { .. } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 9202ad4..4e1ca5c 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1823,6 +1823,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::CpAsyncCommitGroup { .. } | ast::Instruction::CpAsyncWaitGroup { .. } | ast::Instruction::CpAsyncWaitAll { .. } + | ast::Instruction::Shf { .. } | ast::Instruction::Shl { .. } | ast::Instruction::Selp { .. } | ast::Instruction::Ret { .. } diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index f234507..5d9516f 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -34,7 +34,7 @@ use crate::pass::*; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; -use ptx_parser::{CpAsyncArgs, CpAsyncDetails, Mul24Control}; +use ptx_parser::{CpAsyncArgs, CpAsyncDetails, FunnelShiftMode, Mul24Control, ShfArgs}; struct Builder(LLVMBuilderRef); @@ -485,6 +485,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), + ast::Instruction::Shf { data, arguments } => self.emit_shf(data, arguments), ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), @@ -1990,6 +1991,61 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn emit_shf( + &mut self, + data: ptx_parser::ShfDetails, + arguments: ShfArgs, + ) -> Result<(), TranslateError> { + let lsb = self.resolver.value(arguments.src_a)?; + let msb = self.resolver.value(arguments.src_b)?; + let shift_amount = self.resolver.value(arguments.src_c)?; + + let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); + let const_32 = unsafe { LLVMConstInt(llvm_i32, 32, 0) }; + + let intrinsic = match data.direction { + ptx_parser::ShiftDirection::R => c"llvm.fshr.i32", + ptx_parser::ShiftDirection::L => c"llvm.fshl.i32", + }; + + let shifted = self.emit_intrinsic( + intrinsic, + None, + Some(&ast::Type::Scalar(ptx_parser::ScalarType::B32)), + vec![(msb, llvm_i32), (lsb, llvm_i32), (shift_amount, llvm_i32)], + )?; + + if data.mode == FunnelShiftMode::Clamp { + // `llvm.fsh*` acts like `shf.*.wrap`. To implement clamp, we must conditionally return + // the left or right-most 32 bits if `shift_amount` is greater than or equal to 32. + + let should_clamp = unsafe { + LLVMBuildICmp( + self.builder, + LLVMIntPredicate::LLVMIntUGE, + shift_amount, + const_32, + LLVM_UNNAMED.as_ptr(), + ) + }; + + let max_shift = match data.direction { + ptx_parser::ShiftDirection::R => msb, + ptx_parser::ShiftDirection::L => lsb, + }; + + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, max_shift, shifted, dst) + }); + } else { + let name = self.resolver.get_or_add(arguments.dst); + unsafe { LLVMSetValueName2(shifted, name.as_ptr().cast(), name.len()) }; + self.resolver.register(arguments.dst, shifted); + } + + Ok(()) + } + fn emit_shr( &mut self, data: ptx_parser::ShrData, diff --git a/ptx/src/test/ll/shf_l.ll b/ptx/src/test/ll/shf_l.ll new file mode 100644 index 0000000..5d6a64b --- /dev/null +++ b/ptx/src/test/ll/shf_l.ll @@ -0,0 +1,50 @@ +define amdgpu_kernel void @shf_l(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %2 = call i32 @llvm.fshl.i32(i32 %"55", i32 %"54", i32 %"56") + %3 = icmp uge i32 %"56", 32 + %"62" = select i1 %3, i32 %"54", i32 %2 + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshl.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/shf_l_clamp.ll b/ptx/src/test/ll/shf_l_clamp.ll new file mode 100644 index 0000000..a395db1 --- /dev/null +++ b/ptx/src/test/ll/shf_l_clamp.ll @@ -0,0 +1,50 @@ +define amdgpu_kernel void @shf_l_clamp(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %2 = call i32 @llvm.fshl.i32(i32 %"55", i32 %"54", i32 %"56") + %3 = icmp uge i32 %"56", 32 + %"62" = select i1 %3, i32 %"54", i32 %2 + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshl.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/shf_l_wrap.ll b/ptx/src/test/ll/shf_l_wrap.ll new file mode 100644 index 0000000..fc5cfc1 --- /dev/null +++ b/ptx/src/test/ll/shf_l_wrap.ll @@ -0,0 +1,48 @@ +define amdgpu_kernel void @shf_l_wrap(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %"62" = call i32 @llvm.fshl.i32(i32 %"55", i32 %"54", i32 %"56") + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshl.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } \ No newline at end of file diff --git a/ptx/src/test/ll/shf_r.ll b/ptx/src/test/ll/shf_r.ll new file mode 100644 index 0000000..0bce4a0 --- /dev/null +++ b/ptx/src/test/ll/shf_r.ll @@ -0,0 +1,50 @@ +define amdgpu_kernel void @shf_r(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %2 = call i32 @llvm.fshr.i32(i32 %"55", i32 %"54", i32 %"56") + %3 = icmp uge i32 %"56", 32 + %"62" = select i1 %3, i32 %"55", i32 %2 + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshr.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/shf_r_clamp.ll b/ptx/src/test/ll/shf_r_clamp.ll new file mode 100644 index 0000000..e7e2b8c --- /dev/null +++ b/ptx/src/test/ll/shf_r_clamp.ll @@ -0,0 +1,50 @@ +define amdgpu_kernel void @shf_r_clamp(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %2 = call i32 @llvm.fshr.i32(i32 %"55", i32 %"54", i32 %"56") + %3 = icmp uge i32 %"56", 32 + %"62" = select i1 %3, i32 %"55", i32 %2 + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshr.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/shf_r_wrap.ll b/ptx/src/test/ll/shf_r_wrap.ll new file mode 100644 index 0000000..86c255d --- /dev/null +++ b/ptx/src/test/ll/shf_r_wrap.ll @@ -0,0 +1,48 @@ +define amdgpu_kernel void @shf_r_wrap(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i32, align 4, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"45" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"45", ptr addrspace(5) %"39", align 8 + %"46" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"46", ptr addrspace(5) %"40", align 8 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"59" = inttoptr i64 %"48" to ptr + %"47" = load i32, ptr %"59", align 4 + store i32 %"47", ptr addrspace(5) %"41", align 4 + %"49" = load i64, ptr addrspace(5) %"39", align 8 + %"60" = inttoptr i64 %"49" to ptr + %"33" = getelementptr inbounds i8, ptr %"60", i64 4 + %"50" = load i32, ptr %"33", align 4 + store i32 %"50", ptr addrspace(5) %"42", align 4 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"61" = inttoptr i64 %"51" to ptr + %"35" = getelementptr inbounds i8, ptr %"61", i64 8 + %"52" = load i32, ptr %"35", align 4 + store i32 %"52", ptr addrspace(5) %"43", align 4 + %"54" = load i32, ptr addrspace(5) %"41", align 4 + %"55" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = load i32, ptr addrspace(5) %"43", align 4 + %"62" = call i32 @llvm.fshr.i32(i32 %"55", i32 %"54", i32 %"56") + store i32 %"62", ptr addrspace(5) %"44", align 4 + %"57" = load i64, ptr addrspace(5) %"40", align 8 + %"58" = load i32, ptr addrspace(5) %"44", align 4 + %"63" = inttoptr i64 %"57" to ptr + store i32 %"58", ptr %"63", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.fshr.i32(i32, i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 7f7a424..095a3d3 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -302,6 +302,28 @@ test_ptx!(tanh, [f32::INFINITY], [1.0f32]); test_ptx!(cp_async, [0u32], [1u32, 2u32, 3u32, 0u32]); test_ptx!(nanosleep, [0u64], [0u64]); +test_ptx!(shf_l, [0x12345678u32, 0x9abcdef0u32, 12], [0xcdef0123u32]); +test_ptx!(shf_r, [0x12345678u32, 0x9abcdef0u32, 12], [0xef012345u32]); +test_ptx!( + shf_l_clamp, + [0x12345678u32, 0x9abcdef0u32, 44], + [0x12345678u32] +); +test_ptx!( + shf_r_clamp, + [0x12345678u32, 0x9abcdef0u32, 44], + [0x9abcdef0u32] +); +test_ptx!( + shf_l_wrap, + [0x12345678u32, 0x9abcdef0u32, 44], + [0xcdef0123u32] +); +test_ptx!( + shf_r_wrap, + [0x12345678u32, 0x9abcdef0u32, 44], + [0xef012345u32] +); test_ptx!(assertfail); // TODO: not yet supported diff --git a/ptx/src/test/spirv_run/shf_l.ptx b/ptx/src/test/spirv_run/shf_l.ptx new file mode 100644 index 0000000..16fc854 --- /dev/null +++ b/ptx/src/test/spirv_run/shf_l.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_l( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.l.clamp.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/shf_l_clamp.ptx b/ptx/src/test/spirv_run/shf_l_clamp.ptx new file mode 100644 index 0000000..d940ce2 --- /dev/null +++ b/ptx/src/test/spirv_run/shf_l_clamp.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_l_clamp( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.l.clamp.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/shf_l_wrap.ptx b/ptx/src/test/spirv_run/shf_l_wrap.ptx new file mode 100644 index 0000000..6c3c8cc --- /dev/null +++ b/ptx/src/test/spirv_run/shf_l_wrap.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_l_wrap( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.l.wrap.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/shf_r.ptx b/ptx/src/test/spirv_run/shf_r.ptx new file mode 100644 index 0000000..4e0a17b --- /dev/null +++ b/ptx/src/test/spirv_run/shf_r.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_r( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.r.clamp.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/shf_r_clamp.ptx b/ptx/src/test/spirv_run/shf_r_clamp.ptx new file mode 100644 index 0000000..9e59d74 --- /dev/null +++ b/ptx/src/test/spirv_run/shf_r_clamp.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_r_clamp( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.r.clamp.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/shf_r_wrap.ptx b/ptx/src/test/spirv_run/shf_r_wrap.ptx new file mode 100644 index 0000000..abc2fbd --- /dev/null +++ b/ptx/src/test/spirv_run/shf_r_wrap.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_32 +.address_size 64 + +.visible .entry shf_r_wrap( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 in_a; + .reg .b32 in_b; + .reg .b32 in_c; + .reg .u32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 in_a, [in_addr]; + ld.b32 in_b, [in_addr+4]; + ld.b32 in_c, [in_addr+8]; + + shf.r.wrap.b32 result, in_a, in_b, in_c; + + st.b32 [out_addr], result; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index d140adb..464423f 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,9 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{Mul24Control, PtxError, PtxParserState, Reduction, ShuffleMode}; +use crate::{ + FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection, ShuffleMode, +}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -549,6 +551,16 @@ ptx_parser_macros::generate_instruction_type!( src_membermask: T } }, + Shf { + data: ShfDetails, + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T, + src_a: T, + src_b: T, + src_c: T + } + }, Shl { data: ScalarType, type: { Type::Scalar(data.clone()) }, @@ -1103,6 +1115,11 @@ pub struct CpAsyncDetails { pub src_size: Option, } +pub struct ShfDetails { + pub direction: ShiftDirection, + pub mode: FunnelShiftMode, +} + #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 4b79d47..886836a 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1739,6 +1739,12 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ShuffleMode { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ShiftDirection { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum FunnelShiftMode { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3625,6 +3631,17 @@ derive_parser!( cp.async.wait_all => { Instruction::CpAsyncWaitAll {} } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf + shf.dir.mode.b32 d, a, b, c => { + Instruction::Shf { + data: ShfDetails { direction: dir, mode: mode }, + arguments: ShfArgs { dst: d, src_a: a, src_b: b, src_c: c } + } + } + + .dir: ShiftDirection = { .l, .r }; + .mode: FunnelShiftMode = { .clamp, .wrap }; ); #[cfg(test)]