Implement redux.sync for u32 and s32 (#500)

This commit is contained in:
Violet 2025-09-08 16:13:28 -07:00 committed by GitHub
commit d342e1a06e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 394 additions and 3 deletions

Binary file not shown.

View file

@ -562,4 +562,19 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
{ {
return ballot(value, true); return ballot(value, true);
} }
// Helpers backing PTX `redux.sync.<reducer>.<type>`.
//
// REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type):
//   - forward-declares the device function `__ockl_wfred_<reducer>_<amd_type>`,
//   - defines `FUNC(redux_sync_<reducer>_<ptx_type>)`, which forwards `src`
//     to that function and returns its result.
// `membermask` is accepted to match the PTX operand list but is not used
// (it is explicitly marked `unused`).
// NOTE(review): this presumably relies on the `__ockl_wfred_*` reduction
// covering exactly the currently active lanes, i.e. on callers passing a
// membermask equal to the active set - confirm against the OCKL implementation.
#define REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type) \
cpp_type __ockl_wfred_##reducer##_##amd_type(cpp_type) __device__; \
cpp_type FUNC(redux_sync_##reducer##_##ptx_type)(cpp_type src, uint32_t membermask __attribute__((unused))) \
{ \
return __ockl_wfred_##reducer##_##amd_type(src); \
}
// Instantiates one reducer for both supported operand types:
// PTX `u32` -> `__ockl_wfred_<reducer>_u32` on uint32_t,
// PTX `s32` -> `__ockl_wfred_<reducer>_i32` on int32_t.
#define REDUX_SYNC_IMPL(reducer) \
REDUX_SYNC_TYPE_IMPL(reducer, u32, u32, uint32_t) \
REDUX_SYNC_TYPE_IMPL(reducer, s32, i32, int32_t)
// Reducers currently provided: add, min, max (six functions in total).
REDUX_SYNC_IMPL(add);
REDUX_SYNC_IMPL(min);
REDUX_SYNC_IMPL(max);
} }

View file

@ -195,7 +195,8 @@ fn run_instruction<'input>(
| ast::Instruction::Tanh { .. } | ast::Instruction::Tanh { .. }
| ast::Instruction::Trap {} | ast::Instruction::Trap {}
| ast::Instruction::Xor { .. } | ast::Instruction::Xor { .. }
| ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)), | ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add { ast::Instruction::Add {
data: data:
ast::ArithDetails::Float(ast::ArithFloat { ast::ArithDetails::Float(ast::ArithFloat {

View file

@ -1853,7 +1853,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Mul24 { .. } | ast::Instruction::Mul24 { .. }
| ast::Instruction::Nanosleep { .. } | ast::Instruction::Nanosleep { .. }
| ast::Instruction::AtomCas { .. } | ast::Instruction::AtomCas { .. }
| ast::Instruction::Vote { .. } => InstructionModes::none(), | ast::Instruction::Vote { .. }
| ast::Instruction::ReduxSync { .. } => InstructionModes::none(),
ast::Instruction::Add { ast::Instruction::Add {
data: ast::ArithDetails::Integer(_), data: ast::ArithDetails::Integer(_),
.. ..

View file

@ -526,7 +526,8 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Activemask { .. } | ast::Instruction::Activemask { .. }
| ast::Instruction::ShflSync { .. } | ast::Instruction::ShflSync { .. }
| ast::Instruction::Vote { .. } | ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), | ast::Instruction::Nanosleep { .. }
| ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()),
} }
} }

View file

@ -377,6 +377,7 @@ fn run_instruction<'input>(
let name = match data.pred_reduction { let name = match data.pred_reduction {
ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::And => "bar_red_and_pred",
ptx_parser::Reduction::Or => "bar_red_or_pred", ptx_parser::Reduction::Or => "bar_red_or_pred",
_ => return Err(error_unreachable()),
}; };
to_call( to_call(
resolver, resolver,
@ -400,6 +401,25 @@ fn run_instruction<'input>(
ptx_parser::Instruction::Vote { data, arguments }, ptx_parser::Instruction::Vote { data, arguments },
)? )?
} }
// `redux.sync` is not emitted inline: it is rewritten into a call to a helper
// function whose name is derived from the reduction and the operand type.
ptx_parser::Instruction::ReduxSync { data, arguments } => {
// Only add/min/max have a helper name here; any other `Reduction` variant
// makes this function return `error_unreachable()`.
let op = match data.reduction {
ptx_parser::Reduction::Add => "add",
ptx_parser::Reduction::Min => "min",
ptx_parser::Reduction::Max => "max",
_ => return Err(error_unreachable()),
};
// Helper name has the shape `redux_sync_<op>_<type>`, e.g. `redux_sync_add_u32`.
// Every "." is removed from the type's string form (presumably the type
// prints as ".u32" - confirm against ScalarType's Display impl).
let name = format!(
"redux_sync_{}_{}",
op,
data.type_.to_string().replace(".", "")
);
// The original instruction (data and arguments unchanged) is handed to
// `to_call` together with the helper name.
to_call(
resolver,
fn_declarations,
name.into(),
ptx_parser::Instruction::ReduxSync { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync { ptx_parser::Instruction::ShflSync {
data, data,
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. }, arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },

View file

@ -0,0 +1,58 @@
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @redux_sync_add_u32_partial(ptr addrspace(4) byref(i64) %"46") #1 {
%"47" = alloca i32, align 4, addrspace(5)
%"48" = alloca i32, align 4, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"50" = alloca i32, align 4, addrspace(5)
%"51" = alloca i1, align 1, addrspace(5)
%"62" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"43"
"43": ; preds = %1
%"52" = load i64, ptr addrspace(4) %"46", align 8
store i64 %"52", ptr addrspace(5) %"49", align 8
%"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"44"
"44": ; preds = %"43"
store i32 %"37", ptr addrspace(5) %"47", align 4
%"55" = load i32, ptr addrspace(5) %"47", align 4
%"54" = urem i32 %"55", 2
store i32 %"54", ptr addrspace(5) %"50", align 4
%"57" = load i32, ptr addrspace(5) %"50", align 4
%2 = icmp eq i32 %"57", 0
store i1 %2, ptr addrspace(5) %"51", align 1
store i32 0, ptr addrspace(5) %"48", align 4
%"59" = load i1, ptr addrspace(5) %"51", align 1
br i1 %"59", label %"16", label %"17"
"16": ; preds = %"44"
%"61" = load i32, ptr addrspace(5) %"47", align 4
%"60" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"61", i32 1431655765)
store i32 %"60", ptr addrspace(5) %"48", align 4
br label %"17"
"17": ; preds = %"16", %"44"
%"64" = load i32, ptr addrspace(5) %"47", align 4
%3 = zext i32 %"64" to i64
%"63" = mul i64 %3, 4
store i64 %"63", ptr addrspace(5) %"62", align 8
%"66" = load i64, ptr addrspace(5) %"49", align 8
%"67" = load i64, ptr addrspace(5) %"62", align 8
%"65" = add i64 %"66", %"67"
store i64 %"65", ptr addrspace(5) %"49", align 8
%"68" = load i64, ptr addrspace(5) %"49", align 8
%"69" = load i32, ptr addrspace(5) %"48", align 4
%"70" = inttoptr i64 %"68" to ptr
store i32 %"69", ptr %"70", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,67 @@
declare hidden i32 @__zluda_ptx_impl_redux_sync_min_s32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_redux_sync_max_s32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_s32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @redux_sync_op_s32(ptr addrspace(4) byref(i64) %"46") #1 {
%"47" = alloca i32, align 4, addrspace(5)
%"48" = alloca i32, align 4, addrspace(5)
%"49" = alloca i32, align 4, addrspace(5)
%"50" = alloca i32, align 4, addrspace(5)
%"51" = alloca i32, align 4, addrspace(5)
%"52" = alloca i32, align 4, addrspace(5)
%"53" = alloca i64, align 8, addrspace(5)
%"70" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"43"
"43": ; preds = %1
%"54" = load i64, ptr addrspace(4) %"46", align 8
store i64 %"54", ptr addrspace(5) %"53", align 8
%"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"44"
"44": ; preds = %"43"
store i32 %"37", ptr addrspace(5) %"47", align 4
%"57" = load i32, ptr addrspace(5) %"47", align 4
%"56" = sub i32 %"57", 5
store i32 %"56", ptr addrspace(5) %"48", align 4
%"59" = load i32, ptr addrspace(5) %"48", align 4
%"58" = call i32 @__zluda_ptx_impl_redux_sync_add_s32(i32 %"59", i32 -1)
store i32 %"58", ptr addrspace(5) %"49", align 4
%"61" = load i32, ptr addrspace(5) %"48", align 4
%"60" = call i32 @__zluda_ptx_impl_redux_sync_min_s32(i32 %"61", i32 -1)
store i32 %"60", ptr addrspace(5) %"50", align 4
%"63" = load i32, ptr addrspace(5) %"48", align 4
%"62" = call i32 @__zluda_ptx_impl_redux_sync_max_s32(i32 %"63", i32 -1)
store i32 %"62", ptr addrspace(5) %"51", align 4
%"65" = load i32, ptr addrspace(5) %"49", align 4
%"66" = load i32, ptr addrspace(5) %"50", align 4
%"64" = add i32 %"65", %"66"
store i32 %"64", ptr addrspace(5) %"52", align 4
%"68" = load i32, ptr addrspace(5) %"52", align 4
%"69" = load i32, ptr addrspace(5) %"51", align 4
%"67" = add i32 %"68", %"69"
store i32 %"67", ptr addrspace(5) %"52", align 4
%"72" = load i32, ptr addrspace(5) %"47", align 4
%2 = zext i32 %"72" to i64
%"71" = mul i64 %2, 4
store i64 %"71", ptr addrspace(5) %"70", align 8
%"74" = load i64, ptr addrspace(5) %"53", align 8
%"75" = load i64, ptr addrspace(5) %"70", align 8
%"73" = add i64 %"74", %"75"
store i64 %"73", ptr addrspace(5) %"53", align 8
%"76" = load i64, ptr addrspace(5) %"53", align 8
%"77" = load i32, ptr addrspace(5) %"52", align 4
%"79" = inttoptr i64 %"76" to ptr
store i32 %"77", ptr %"79", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,63 @@
declare hidden i32 @__zluda_ptx_impl_redux_sync_max_u32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_redux_sync_min_u32(i32, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @redux_sync_op_u32(ptr addrspace(4) byref(i64) %"44") #1 {
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i32, align 4, addrspace(5)
%"47" = alloca i32, align 4, addrspace(5)
%"48" = alloca i32, align 4, addrspace(5)
%"49" = alloca i32, align 4, addrspace(5)
%"50" = alloca i64, align 8, addrspace(5)
%"65" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"41"
"41": ; preds = %1
%"51" = load i64, ptr addrspace(4) %"44", align 8
store i64 %"51", ptr addrspace(5) %"50", align 8
%"36" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"42"
"42": ; preds = %"41"
store i32 %"36", ptr addrspace(5) %"45", align 4
%"54" = load i32, ptr addrspace(5) %"45", align 4
%"53" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"54", i32 -1)
store i32 %"53", ptr addrspace(5) %"46", align 4
%"56" = load i32, ptr addrspace(5) %"45", align 4
%"55" = call i32 @__zluda_ptx_impl_redux_sync_min_u32(i32 %"56", i32 -1)
store i32 %"55", ptr addrspace(5) %"47", align 4
%"58" = load i32, ptr addrspace(5) %"45", align 4
%"57" = call i32 @__zluda_ptx_impl_redux_sync_max_u32(i32 %"58", i32 -1)
store i32 %"57", ptr addrspace(5) %"48", align 4
%"60" = load i32, ptr addrspace(5) %"46", align 4
%"61" = load i32, ptr addrspace(5) %"47", align 4
%"59" = add i32 %"60", %"61"
store i32 %"59", ptr addrspace(5) %"49", align 4
%"63" = load i32, ptr addrspace(5) %"49", align 4
%"64" = load i32, ptr addrspace(5) %"48", align 4
%"62" = add i32 %"63", %"64"
store i32 %"62", ptr addrspace(5) %"49", align 4
%"67" = load i32, ptr addrspace(5) %"45", align 4
%2 = zext i32 %"67" to i64
%"66" = mul i64 %2, 4
store i64 %"66", ptr addrspace(5) %"65", align 8
%"69" = load i64, ptr addrspace(5) %"50", align 8
%"70" = load i64, ptr addrspace(5) %"65", align 8
%"68" = add i64 %"69", %"70"
store i64 %"68", ptr addrspace(5) %"50", align 8
%"71" = load i64, ptr addrspace(5) %"50", align 8
%"72" = load i32, ptr addrspace(5) %"49", align 4
%"73" = inttoptr i64 %"71" to ptr
store i32 %"72", ptr %"73", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -452,6 +452,40 @@ test_ptx_warp!(
4294967292, 4294967292, 4294967292, 4294967292, 4294967292 4294967292, 4294967292, 4294967292, 4294967292, 4294967292
] ]
); );
// Expected per-thread output of the `redux_sync_op_s32` kernel (64 entries).
// Each thread contributes `tid - 5` and stores add + min + max of the
// reduction results. The values correspond to a reduction over each run of
// 32 consecutive threads:
//   threads  0..=31: inputs  -5..=26 -> 336 + (-5) + 26 = 357
//   threads 32..=63: inputs  27..=58 -> 1360 + 27 + 58 = 1445
test_ptx_warp!(
redux_sync_op_s32,
[
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32,
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32,
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 1445i32,
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
1445i32,
]
);
// Expected per-thread output of the `redux_sync_op_u32` kernel (64 entries).
// Each thread contributes its `tid` and stores add + min + max of the
// reduction results. The values correspond to a reduction over each run of
// 32 consecutive threads:
//   threads  0..=31: 496 + 0 + 31 = 527
//   threads 32..=63: 1520 + 32 + 63 = 1615
test_ptx_warp!(
redux_sync_op_u32,
[
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32,
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32,
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 1615u32,
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
1615u32,
]
);
// Expected per-thread output of the `redux_sync_add_u32_partial` kernel
// (64 entries). Only even-numbered threads execute the reduction; odd-numbered
// threads keep the initial 0. Even threads see the sum of the even `tid`s in
// their run of 32 consecutive threads:
//   threads  0..=31: 0 + 2 + ... + 30  = 240
//   threads 32..=63: 32 + 34 + ... + 62 = 752
test_ptx_warp!(
redux_sync_add_u32_partial,
[
240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32,
0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32,
240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32,
0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32,
752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32
]
);
struct DisplayError<T: Debug> { struct DisplayError<T: Debug> {
err: T, err: T,

View file

@ -0,0 +1,31 @@
.version 7.0
.target sm_80
.address_size 64
// Test kernel: `redux.sync.add.u32` executed by only part of the threads.
// Threads with an even tid reduce their tid using membermask 0x55555555
// (every even bit set); all other threads store the initial value 0.
// Each thread writes one u32 to output[tid].
.visible .entry redux_sync_add_u32_partial(
.param .u64 output
)
{
.reg .u32 tid;
.reg .u32 result;
.reg .u64 out_ptr;
.reg .u32 tid_rem_2;
.reg .pred p;
ld.param.u64 out_ptr, [output];
mov.u32 tid, %tid.x;
// p = (tid % 2 == 0)
rem.u32 tid_rem_2, tid, 2;
setp.eq.u32 p, tid_rem_2, 0;
// Default result for threads that skip the predicated reduction below.
mov.u32 result, 0;
@p redux.sync.add.u32 result, tid, 0x55555555;
// out_ptr += tid * 4 (one 4-byte slot per thread)
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -0,0 +1,34 @@
.version 7.0
.target sm_80
.address_size 64
.visible .entry redux_sync_op_s32(
.param .u64 output
)
{
.reg .u32 tid;
.reg .s32 in;
.reg .s32 add_out;
.reg .s32 min_out;
.reg .s32 max_out;
.reg .s32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 tid, %tid.x;
sub.s32 in, tid, 5;
redux.sync.add.s32 add_out, in, 0xFFFFFFFF;
redux.sync.min.s32 min_out, in, 0xFFFFFFFF;
redux.sync.max.s32 max_out, in, 0xFFFFFFFF;
add.s32 result, add_out, min_out;
add.s32 result, result, max_out;
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.s32 [out_ptr], result;
ret;
}

View file

@ -0,0 +1,32 @@
.version 7.0
.target sm_80
.address_size 64
// Test kernel: unsigned `redux.sync` with add, min and max over a full
// membermask (0xFFFFFFFF). Each thread contributes its tid and writes
// add + min + max to output[tid].
.visible .entry redux_sync_op_u32(
.param .u64 output
)
{
.reg .u32 tid;
.reg .u32 add_out;
.reg .u32 min_out;
.reg .u32 max_out;
.reg .u32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 tid, %tid.x;
redux.sync.add.u32 add_out, tid, 0xFFFFFFFF;
redux.sync.min.u32 min_out, tid, 0xFFFFFFFF;
redux.sync.max.u32 max_out, tid, 0xFFFFFFFF;
// Fold the three reduction results into a single checked value.
add.u32 result, add_out, min_out;
add.u32 result, result, max_out;
// out_ptr += tid * 4 (one 4-byte slot per thread)
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -695,6 +695,18 @@ ptx_parser_macros::generate_instruction_type!(
} }
} }
},
ReduxSync {
type: Type::Scalar(data.type_),
data: ReduxSyncData,
arguments<T>: {
dst: T,
src: T,
src_membermask: {
repr: T,
type: { Type::Scalar(ScalarType::U32) },
}
}
} }
} }
); );
@ -2272,3 +2284,8 @@ impl VoteMode {
} }
} }
} }
/// Modifiers of a parsed `redux.sync` instruction.
pub struct ReduxSyncData {
/// Scalar type named by the instruction's `.type` modifier.
pub type_: ScalarType,
/// Reduction named by the instruction's `.op` modifier.
pub reduction: Reduction,
}

View file

@ -3844,6 +3844,23 @@ derive_parser!(
// .mode: VoteMode = { .all, .any, .uni }; // .mode: VoteMode = { .all, .any, .uni };
.mode: VoteMode = { .all, .any }; .mode: VoteMode = { .all, .any };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync
redux.sync.op.type dst, src, membermask => {
Instruction::ReduxSync {
data: ReduxSyncData { type_, reduction: op },
arguments: ReduxSyncArgs { dst, src, src_membermask: membermask }
}
}
.op: Reduction = {.add, .min, .max};
.type: ScalarType = {.u32, .s32};
// Forms of redux.sync that the rule above does not cover:
// redux.sync.op.b32 dst, src, membermask;
// .op = {.and, .or, .xor}
// redux.sync.op{.abs}{.NaN}.f32 dst, src, membermask;
// .op = { .min, .max }
); );
#[cfg(test)] #[cfg(test)]