From 7bdd20f0dd4c145e257c4640bafab3cd4ec49717 Mon Sep 17 00:00:00 2001 From: Violet Date: Wed, 2 Jul 2025 18:11:36 -0700 Subject: [PATCH] Add warp-wide tests (#400) --- ptx/src/test/ll/tid.ll | 39 ++++++++ ptx/src/test/spirv_run/mod.rs | 171 +++++++++++++++++++++------------ ptx/src/test/spirv_run/tid.ptx | 25 +++++ 3 files changed, 172 insertions(+), 63 deletions(-) create mode 100644 ptx/src/test/ll/tid.ll create mode 100644 ptx/src/test/spirv_run/tid.ptx diff --git a/ptx/src/test/ll/tid.ll b/ptx/src/test/ll/tid.ll new file mode 100644 index 0000000..b15f372 --- /dev/null +++ b/ptx/src/test/ll/tid.ll @@ -0,0 +1,39 @@ +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @tid(ptr addrspace(4) byref(i64) %"34") #1 { + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i32, align 4, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + %"38" = alloca i8, align 1, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"31" + +"31": ; preds = %1 + %"30" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"32" + +"32": ; preds = %"31" + store i32 %"30", ptr addrspace(5) %"36", align 4 + %"41" = load i32, ptr addrspace(5) %"36", align 4 + %"40" = zext i32 %"41" to i64 + store i64 %"40", ptr addrspace(5) %"37", align 4 + %"43" = load i32, ptr addrspace(5) %"36", align 4 + %"42" = trunc i32 %"43" to i8 + store i8 %"42", ptr addrspace(5) %"38", align 1 + %"44" = load i64, ptr addrspace(4) %"34", align 4 + store i64 %"44", ptr addrspace(5) %"35", align 4 + %"46" = load i64, ptr addrspace(5) %"35", align 4 + %"47" = load i64, ptr addrspace(5) %"37", align 4 + %"45" = add i64 %"46", %"47" + store i64 %"45", ptr addrspace(5) %"35", align 4 + %"48" = load i64, ptr addrspace(5) %"35", align 4 + %"49" = load i8, ptr addrspace(5) %"38", align 1 + %"50" = inttoptr i64 %"48" to ptr + store i8 %"49", ptr %"50", align 1 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 1ac10ea..84e0731 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -36,28 +36,8 @@ macro_rules! read_test_file { }; } -macro_rules! test_ptx { - ($fn_name:ident, $input:expr, $output:expr) => { - paste::item! { - #[test] - fn [<$fn_name _hip>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let input = $input; - let mut output = $output; - test_hip_assert(stringify!($fn_name), &ptx, &input, &mut output) - } - } - - paste::item! { - #[test] - fn [<$fn_name _cuda>]() -> Result<(), Box> { - let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); - let input = $input; - let mut output = $output; - test_cuda_assert(stringify!($fn_name), &ptx, &input, &mut output) - } - } - +macro_rules! test_ptx_llvm { + ($fn_name:ident) => { paste::item! { #[test] fn [<$fn_name _llvm>]() -> Result<(), Box> { @@ -66,17 +46,60 @@ macro_rules! test_ptx { test_llvm_assert(stringify!($fn_name), &ptx, ll.trim()) } } + } +} + +macro_rules! test_ptx { + ($fn_name:ident, $input:expr, $output:expr) => { + paste::item! { + #[test] + fn [<$fn_name _hip>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let input = $input; + let output = $output; + test_hip_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) + } + } + + paste::item! { + #[test] + fn [<$fn_name _cuda>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let input = $input; + let output = $output; + test_cuda_assert(stringify!($fn_name), &ptx, Some(&input), &output, 1) + } + } + + test_ptx_llvm!($fn_name); }; ($fn_name:ident) => { + test_ptx_llvm!($fn_name); + }; +} + +macro_rules! test_ptx_warp { + ($fn_name:ident, $output:expr) => { paste::item! { #[test] - fn [<$fn_name _llvm>]() -> Result<(), Box> { - let ptx = include_str!(concat!(stringify!($fn_name), ".ptx")); - let ll = include_str!(concat!("../ll/", stringify!($fn_name), ".ll")).trim(); - test_llvm_assert(stringify!($fn_name), ptx, &ll) + fn [<$fn_name _hip>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let mut output = $output; + test_hip_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) } } + + paste::item! { + #[test] + fn [<$fn_name _cuda>]() -> Result<(), Box> { + let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx")); + let mut output = $output; + test_cuda_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &mut output, 64) + } + } + + test_ptx_llvm!($fn_name); }; } @@ -278,6 +301,12 @@ test_ptx!(assertfail); test_ptx!(lanemask_lt); test_ptx!(extern_func); +test_ptx_warp!(tid, [ + 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8, + 16u8, 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8, + 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8, + 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8, +]); struct DisplayError { err: T, } @@ -302,14 +331,15 @@ fn test_hip_assert< >( name: &str, ptx_text: &str, - input: &[Input], - output: &mut [Output], + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); let llvm_ir = pass::to_llvm_module(ast).unwrap(); let name = CString::new(name)?; let result = - run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?; + run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?; assert_eq!(result.as_slice(), output); Ok(()) } @@ -344,11 +374,12 @@ fn test_cuda_assert< >( name: &str, ptx_text: &str, - input: &[Input], - output: &mut [Output], + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, ) -> Result<(), Box> { let name = CString::new(name)?; - let result = run_cuda(name.as_c_str(), ptx_text, input, output); + let result = run_cuda(name.as_c_str(), ptx_text, input, output, block_dim_x); assert_eq!(result.as_slice(), output); Ok(()) } @@ -356,8 +387,9 @@ fn test_cuda_assert< fn run_cuda + Copy + Debug, Output: From + Copy + Debug + Default>( name: &CStr, ptx_module: &str, - input: &[Input], - output: &mut [Output], + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, ) -> Vec { unsafe { CUDA.cuInit(0) }.unwrap().unwrap(); let ptx_module = CString::new(ptx_module).unwrap(); @@ -375,34 +407,40 @@ fn run_cuda + Copy + Debug, Output: From + Copy + Debug + De unsafe { CUDA.cuModuleGetFunction(&mut kernel, module, name.as_ptr()) } .unwrap() .unwrap(); - let mut inp_b = unsafe { mem::zeroed() }; - unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::()) } - .unwrap() - .unwrap(); let mut out_b = unsafe { mem::zeroed() }; unsafe { CUDA.cuMemAlloc_v2(&mut out_b, output.len() * mem::size_of::()) } .unwrap() .unwrap(); - unsafe { - CUDA.cuMemcpyHtoD_v2( - inp_b, - input.as_ptr() as _, - input.len() * mem::size_of::(), - ) + let mut inp_b = unsafe { mem::zeroed() }; + if let Some(input) = input { + unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::()) } + .unwrap() + .unwrap(); + unsafe { + CUDA.cuMemcpyHtoD_v2( + inp_b, + input.as_ptr() as _, + input.len() * mem::size_of::(), + ) + } + .unwrap() + .unwrap(); } - .unwrap() - .unwrap(); unsafe { CUDA.cuMemsetD8_v2(out_b, 0, output.len() * mem::size_of::()) } .unwrap() .unwrap(); - let mut args = [&inp_b, &out_b]; + let mut args = if input.is_some() { + [&inp_b, &out_b] + } else { + [&out_b, &out_b] + }; unsafe { CUDA.cuLaunchKernel( kernel, 1, 1, 1, - 1, + block_dim_x, 1, 1, 1024, @@ -472,8 +510,9 @@ static CUDA: std::sync::LazyLock = fn run_hip + Copy + Debug, Output: From + Copy + Debug + Default>( name: &CStr, module: pass::Module, - input: &[Input], - output: &mut [Output], + input: Option<&[Input]>, + output: &[Output], + block_dim_x: u32, ) -> Result, hipError_t> { use hip_runtime_sys::*; unsafe { hipInit(0) }.unwrap(); @@ -496,29 +535,35 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap(); let mut kernel = unsafe { mem::zeroed() }; unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap(); - let mut inp_b = ptr::null_mut(); - unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); let mut out_b = ptr::null_mut(); unsafe { hipMalloc(&mut out_b, output.len() * mem::size_of::()) }.unwrap(); - unsafe { - hipMemcpyWithStream( - inp_b, - input.as_ptr() as _, - input.len() * mem::size_of::(), - hipMemcpyKind::hipMemcpyHostToDevice, - stream, - ) + let mut inp_b = ptr::null_mut(); + if let Some(input) = input { + unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); + unsafe { + hipMemcpyWithStream( + inp_b, + input.as_ptr() as _, + input.len() * mem::size_of::(), + hipMemcpyKind::hipMemcpyHostToDevice, + stream, + ) + } + .unwrap(); } - .unwrap(); unsafe { hipMemset(out_b, 0, output.len() * mem::size_of::()) }.unwrap(); - let mut args = [&inp_b, &out_b]; + let mut args = if input.is_some() { + [&inp_b, &out_b] + } else { + [&out_b, &out_b] + }; unsafe { hipModuleLaunchKernel( kernel, 1, 1, 1, - 1, + block_dim_x, 1, 1, 1024, diff --git a/ptx/src/test/spirv_run/tid.ptx b/ptx/src/test/spirv_run/tid.ptx new file mode 100644 index 0000000..014d0b6 --- /dev/null +++ b/ptx/src/test/spirv_run/tid.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry tid( + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u32 thread_id; + .reg .u64 thread_id_u64; + + .reg .u8 thread_id_u8; + + mov.u32 thread_id, %tid.x; + cvt.u64.u32 thread_id_u64, thread_id; + cvt.u8.u32 thread_id_u8, thread_id; + + ld.param.u64 out_addr, [output]; + + add.u64 out_addr, out_addr, thread_id_u64; + st.u8 [out_addr], thread_id_u8; + + ret; +}