Add warp-wide tests (#400)
Some checks are pending
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run

This commit is contained in:
Violet 2025-07-02 18:11:36 -07:00 committed by GitHub
commit 7bdd20f0dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 172 additions and 63 deletions

39
ptx/src/test/ll/tid.ll Normal file
View file

@ -0,0 +1,39 @@
declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @tid(ptr addrspace(4) byref(i64) %"34") #1 {
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i32, align 4, addrspace(5)
%"37" = alloca i64, align 8, addrspace(5)
%"38" = alloca i8, align 1, addrspace(5)
br label %1
1: ; preds = %0
br label %"31"
"31": ; preds = %1
%"30" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"32"
"32": ; preds = %"31"
store i32 %"30", ptr addrspace(5) %"36", align 4
%"41" = load i32, ptr addrspace(5) %"36", align 4
%"40" = zext i32 %"41" to i64
store i64 %"40", ptr addrspace(5) %"37", align 4
%"43" = load i32, ptr addrspace(5) %"36", align 4
%"42" = trunc i32 %"43" to i8
store i8 %"42", ptr addrspace(5) %"38", align 1
%"44" = load i64, ptr addrspace(4) %"34", align 4
store i64 %"44", ptr addrspace(5) %"35", align 4
%"46" = load i64, ptr addrspace(5) %"35", align 4
%"47" = load i64, ptr addrspace(5) %"37", align 4
%"45" = add i64 %"46", %"47"
store i64 %"45", ptr addrspace(5) %"35", align 4
%"48" = load i64, ptr addrspace(5) %"35", align 4
%"49" = load i8, ptr addrspace(5) %"38", align 1
%"50" = inttoptr i64 %"48" to ptr
store i8 %"49", ptr %"50", align 1
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -36,28 +36,8 @@ macro_rules! read_test_file {
};
}
macro_rules! test_ptx {
($fn_name:ident, $input:expr, $output:expr) => {
paste::item! {
#[test]
fn [<$fn_name _hip>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
let input = $input;
let mut output = $output;
test_hip_assert(stringify!($fn_name), &ptx, &input, &mut output)
}
}
paste::item! {
#[test]
fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
let input = $input;
let mut output = $output;
test_cuda_assert(stringify!($fn_name), &ptx, &input, &mut output)
}
}
macro_rules! test_ptx_llvm {
($fn_name:ident) => {
paste::item! {
#[test]
fn [<$fn_name _llvm>]() -> Result<(), Box<dyn std::error::Error>> {
@ -66,17 +46,60 @@ macro_rules! test_ptx {
test_llvm_assert(stringify!($fn_name), &ptx, ll.trim())
}
}
}
}
// Generates the `#[test]` functions for the PTX kernel named `$fn_name`.
//
// * `test_ptx!(name)` only emits the LLVM test, by delegating to
//   `test_ptx_llvm!`.
// * `test_ptx!(name, input, output)` emits the LLVM test plus `<name>_hip`
//   and `<name>_cuda`. Each of those reads `<name>.ptx` through
//   `read_test_file!` and hands the source, `Some(&input)`, the expected
//   `output` and a block x-dimension of 1 to the matching `test_*_assert`
//   helper.
macro_rules! test_ptx {
    ($fn_name:ident) => {
        test_ptx_llvm!($fn_name);
    };
    ($fn_name:ident, $input:expr, $output:expr) => {
        test_ptx_llvm!($fn_name);

        paste::item! {
            #[test]
            fn [<$fn_name _hip>]() -> Result<(), Box<dyn std::error::Error>> {
                let source = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
                let args_in = $input;
                let expected = $output;
                test_hip_assert(
                    stringify!($fn_name),
                    &source,
                    Some(&args_in),
                    &expected,
                    1,
                )
            }

            #[test]
            fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
                let source = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
                let args_in = $input;
                let expected = $output;
                test_cuda_assert(
                    stringify!($fn_name),
                    &source,
                    Some(&args_in),
                    &expected,
                    1,
                )
            }
        }
    };
}
// Generates the `#[test]` functions for a kernel that takes no input buffer
// and is launched with 64 threads in the x dimension.
//
// `test_ptx_warp!(name, output)` emits:
// * `<name>_hip` and `<name>_cuda`: read `<name>.ptx` through
//   `read_test_file!` and call the matching `test_*_assert` helper with
//   `None` as the input, `output` as the expected result and 64 as
//   `block_dim_x`.
// * the LLVM test, by delegating to `test_ptx_llvm!`.
//
// NOTE(review): 64 presumably covers a full warp/wavefront on every backend
// under test — confirm before reusing this macro for other kernels.
macro_rules! test_ptx_warp {
    ($fn_name:ident, $output:expr) => {
        paste::item! {
            #[test]
            fn [<$fn_name _hip>]() -> Result<(), Box<dyn std::error::Error>> {
                let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
                // The assert helpers only read the expected output, so an
                // immutable binding and a shared borrow are enough.
                let output = $output;
                test_hip_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &output, 64)
            }
        }
        paste::item! {
            #[test]
            fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
                let ptx = read_test_file!(concat!(stringify!($fn_name), ".ptx"));
                let output = $output;
                test_cuda_assert(stringify!($fn_name), &ptx, None::<&[u8]>, &output, 64)
            }
        }
        // The `<name>_llvm` test comes from the shared macro; no inline copy
        // of it is kept here.
        test_ptx_llvm!($fn_name);
    };
}
@ -278,6 +301,12 @@ test_ptx!(assertfail);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);
// Expected output of the `tid` kernel: 64 bytes where byte `i` holds the
// value `i`. `test_ptx_warp!` launches the kernel with a block x-dimension
// of 64, so one byte is expected per launched thread.
test_ptx_warp!(tid, [
0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8,
16u8, 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8, 28u8, 29u8, 30u8, 31u8,
32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8,
48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8,
]);
struct DisplayError<T: Debug> {
err: T,
}
@ -302,14 +331,15 @@ fn test_hip_assert<
>(
name: &str,
ptx_text: &str,
input: &[Input],
output: &mut [Output],
input: Option<&[Input]>,
output: &[Output],
block_dim_x: u32,
) -> Result<(), Box<dyn error::Error>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;
run_hip(name.as_c_str(), llvm_ir, input, output, block_dim_x).map_err(|err| DisplayError { err })?;
assert_eq!(result.as_slice(), output);
Ok(())
}
@ -344,11 +374,12 @@ fn test_cuda_assert<
>(
name: &str,
ptx_text: &str,
input: &[Input],
output: &mut [Output],
input: Option<&[Input]>,
output: &[Output],
block_dim_x: u32,
) -> Result<(), Box<dyn error::Error>> {
let name = CString::new(name)?;
let result = run_cuda(name.as_c_str(), ptx_text, input, output);
let result = run_cuda(name.as_c_str(), ptx_text, input, output, block_dim_x);
assert_eq!(result.as_slice(), output);
Ok(())
}
@ -356,8 +387,9 @@ fn test_cuda_assert<
fn run_cuda<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
name: &CStr,
ptx_module: &str,
input: &[Input],
output: &mut [Output],
input: Option<&[Input]>,
output: &[Output],
block_dim_x: u32,
) -> Vec<Output> {
unsafe { CUDA.cuInit(0) }.unwrap().unwrap();
let ptx_module = CString::new(ptx_module).unwrap();
@ -375,14 +407,15 @@ fn run_cuda<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + De
unsafe { CUDA.cuModuleGetFunction(&mut kernel, module, name.as_ptr()) }
.unwrap()
.unwrap();
let mut inp_b = unsafe { mem::zeroed() };
unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::<Input>()) }
.unwrap()
.unwrap();
let mut out_b = unsafe { mem::zeroed() };
unsafe { CUDA.cuMemAlloc_v2(&mut out_b, output.len() * mem::size_of::<Output>()) }
.unwrap()
.unwrap();
let mut inp_b = unsafe { mem::zeroed() };
if let Some(input) = input {
unsafe { CUDA.cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::<Input>()) }
.unwrap()
.unwrap();
unsafe {
CUDA.cuMemcpyHtoD_v2(
inp_b,
@ -392,17 +425,22 @@ fn run_cuda<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + De
}
.unwrap()
.unwrap();
}
unsafe { CUDA.cuMemsetD8_v2(out_b, 0, output.len() * mem::size_of::<Output>()) }
.unwrap()
.unwrap();
let mut args = [&inp_b, &out_b];
let mut args = if input.is_some() {
[&inp_b, &out_b]
} else {
[&out_b, &out_b]
};
unsafe {
CUDA.cuLaunchKernel(
kernel,
1,
1,
1,
1,
block_dim_x,
1,
1,
1024,
@ -472,8 +510,9 @@ static CUDA: std::sync::LazyLock<DynamicCuda> =
fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
name: &CStr,
module: pass::Module,
input: &[Input],
output: &mut [Output],
input: Option<&[Input]>,
output: &[Output],
block_dim_x: u32,
) -> Result<Vec<Output>, hipError_t> {
use hip_runtime_sys::*;
unsafe { hipInit(0) }.unwrap();
@ -496,10 +535,11 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap();
let mut kernel = unsafe { mem::zeroed() };
unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap();
let mut inp_b = ptr::null_mut();
unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::<Input>()) }.unwrap();
let mut out_b = ptr::null_mut();
unsafe { hipMalloc(&mut out_b, output.len() * mem::size_of::<Output>()) }.unwrap();
let mut inp_b = ptr::null_mut();
if let Some(input) = input {
unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::<Input>()) }.unwrap();
unsafe {
hipMemcpyWithStream(
inp_b,
@ -510,15 +550,20 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
)
}
.unwrap();
}
unsafe { hipMemset(out_b, 0, output.len() * mem::size_of::<Output>()) }.unwrap();
let mut args = [&inp_b, &out_b];
let mut args = if input.is_some() {
[&inp_b, &out_b]
} else {
[&out_b, &out_b]
};
unsafe {
hipModuleLaunchKernel(
kernel,
1,
1,
1,
1,
block_dim_x,
1,
1,
1024,

View file

@ -0,0 +1,25 @@
// Test kernel `tid`: every thread stores the low 8 bits of its %tid.x at
// byte offset %tid.x from the address passed in `output`.
.version 6.5
.target sm_30
.address_size 64
.visible .entry tid(
.param .u64 output
)
{
// Working registers: destination address plus the thread id in the three
// widths used below.
.reg .u64 out_addr;
.reg .u32 thread_id;
.reg .u64 thread_id_u64;
.reg .u8 thread_id_u8;
// Read the x component of the thread index.
mov.u32 thread_id, %tid.x;
// 64-bit copy of the id, used as the byte offset.
cvt.u64.u32 thread_id_u64, thread_id;
// 8-bit copy of the id, used as the value to store.
cvt.u8.u32 thread_id_u8, thread_id;
// out_addr = output + tid.x
ld.param.u64 out_addr, [output];
add.u64 out_addr, out_addr, thread_id_u64;
// Write one byte per thread.
st.u8 [out_addr], thread_id_u8;
ret;
}