Implement vote instruction and add support for %laneid (#484)

This commit is contained in:
Andrzej Janik 2025-08-29 03:23:09 +02:00 committed by GitHub
commit ea99dcc0b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 538 additions and 15 deletions

Binary file not shown.

View file

@ -48,6 +48,11 @@ extern "C"
return (uint32_t)__ockl_get_num_groups(member);
}
uint32_t FUNC(sreg_laneid)()
{
return __lane_id();
}
uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __device__;
uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
{
@ -519,4 +524,42 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
{
return in;
}
__device__ static inline uint32_t ballot(bool value, bool negate)
{
__builtin_amdgcn_wave_barrier();
return __builtin_amdgcn_ballot_w32(negate ? !value : value);
}
bool FUNC(vote_sync_any_pred)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, false) != 0;
}
bool FUNC(vote_sync_any_pred_negate)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, true) != 0;
}
// IMPORTANT: exec mask must be a subset of membermask, the behavior is undefined otherwise
bool FUNC(vote_sync_all_pred)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, false) == __builtin_amdgcn_read_exec_lo();
}
// also known as "none"
bool FUNC(vote_sync_all_pred_negate)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, false) == 0;
}
uint32_t FUNC(vote_sync_ballot_b32)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, false);
}
uint32_t FUNC(vote_sync_ballot_b32_negate)(bool value, uint32_t membermask __attribute__((unused)))
{
return ballot(value, true);
}
}

View file

@ -2,13 +2,13 @@ use super::*;
pub(super) fn run<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
special_registers: &'a SpecialRegistersMap,
directives: Vec<UnconditionalDirective>,
) -> Result<Vec<UnconditionalDirective>, TranslateError> {
let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len());
let mut result = Vec::with_capacity(SpecialRegistersMap::len() + directives.len());
let mut sreg_to_function =
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default());
SpecialRegistersMap2::foreach_declaration(
FxHashMap::with_capacity_and_hasher(SpecialRegistersMap::len(), Default::default());
SpecialRegistersMap::foreach_declaration(
resolver,
|sreg, (return_arguments, name, input_arguments)| {
result.push(UnconditionalDirective::Method(UnconditionalFunction {
@ -80,7 +80,7 @@ fn run_statement<'a, 'input>(
struct SpecialRegisterResolver<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
special_registers: &'a SpecialRegistersMap2,
special_registers: &'a SpecialRegistersMap,
sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
result: Vec<UnconditionalStatement>,
}

View file

@ -194,7 +194,8 @@ fn run_instruction<'input>(
}
| ast::Instruction::Tanh { .. }
| ast::Instruction::Trap {}
| ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)),
| ast::Instruction::Xor { .. }
| ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)),
ast::Instruction::Add {
data:
ast::ArithDetails::Float(ast::ArithFloat {

View file

@ -1852,7 +1852,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
| ast::Instruction::Atom { .. }
| ast::Instruction::Mul24 { .. }
| ast::Instruction::Nanosleep { .. }
| ast::Instruction::AtomCas { .. } => InstructionModes::none(),
| ast::Instruction::AtomCas { .. }
| ast::Instruction::Vote { .. } => InstructionModes::none(),
ast::Instruction::Add {
data: ast::ArithDetails::Integer(_),
..

View file

@ -542,6 +542,7 @@ impl<'a> MethodEmitContext<'a> {
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. }
| ast::Instruction::ShflSync { .. }
| ast::Instruction::Vote { .. }
| ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()),
}
}

View file

@ -13,7 +13,7 @@ use strum_macros::EnumIter;
mod deparamize_functions;
mod expand_operands;
mod fix_special_registers2;
mod fix_special_registers;
mod hoist_globals;
mod insert_explicit_load_store;
mod insert_implicit_conversions2;
@ -63,12 +63,12 @@ pub fn to_llvm_module<'input>(
) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
let directives = replace_known_functions::run(&mut flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
@ -119,6 +119,7 @@ enum PtxSpecialRegister {
Nctaid,
Clock,
LanemaskLt,
Laneid,
}
impl PtxSpecialRegister {
@ -130,6 +131,7 @@ impl PtxSpecialRegister {
Self::Nctaid => "%nctaid",
Self::Clock => "%clock",
Self::LanemaskLt => "%lanemask_lt",
Self::Laneid => "%laneid",
}
}
@ -151,6 +153,7 @@ impl PtxSpecialRegister {
PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
PtxSpecialRegister::Clock => ast::ScalarType::U32,
PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
PtxSpecialRegister::Laneid => ast::ScalarType::U32,
}
}
@ -160,7 +163,9 @@ impl PtxSpecialRegister {
| PtxSpecialRegister::Ntid
| PtxSpecialRegister::Ctaid
| PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
PtxSpecialRegister::Clock
| PtxSpecialRegister::LanemaskLt
| PtxSpecialRegister::Laneid => None,
}
}
@ -172,6 +177,7 @@ impl PtxSpecialRegister {
PtxSpecialRegister::Nctaid => "sreg_nctaid",
PtxSpecialRegister::Clock => "sreg_clock",
PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
PtxSpecialRegister::Laneid => "sreg_laneid",
}
}
}
@ -885,14 +891,14 @@ impl<'input> ScopeMarker<'input> {
}
}
struct SpecialRegistersMap2 {
struct SpecialRegistersMap {
reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
}
impl SpecialRegistersMap2 {
impl SpecialRegistersMap {
fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
let mut result = SpecialRegistersMap2 {
let mut result = SpecialRegistersMap {
reg_to_id: FxHashMap::default(),
id_to_reg: FxHashMap::default(),
};

View file

@ -385,6 +385,21 @@ fn run_instruction<'input>(
ptx_parser::Instruction::BarRed { data, arguments },
)?
}
ptx_parser::Instruction::Vote { data, arguments } => {
let mode = match data.mode {
ptx_parser::VoteMode::Any => "any_pred",
ptx_parser::VoteMode::All => "all_pred",
ptx_parser::VoteMode::Ballot => "ballot_b32",
};
let negate = if data.negate { "_negate" } else { "" };
let name = format!("vote_sync_{mode}{negate}");
to_call(
resolver,
fn_declarations,
name.into(),
ptx_parser::Instruction::Vote { data, arguments },
)?
}
ptx_parser::Instruction::ShflSync {
data,
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },

View file

@ -0,0 +1,66 @@
declare hidden i1 @__zluda_ptx_impl_vote_sync_all_pred(i1, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
declare hidden i32 @__zluda_ptx_impl_sreg_laneid() #0
define amdgpu_kernel void @vote_all(ptr addrspace(4) byref(i64) %"51") #1 {
%"52" = alloca i32, align 4, addrspace(5)
%"53" = alloca i32, align 4, addrspace(5)
%"54" = alloca i1, align 1, addrspace(5)
%"55" = alloca i1, align 1, addrspace(5)
%"56" = alloca i32, align 4, addrspace(5)
%"57" = alloca i64, align 8, addrspace(5)
%"69" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"46"
"46": ; preds = %1
%"58" = load i64, ptr addrspace(4) %"51", align 8
store i64 %"58", ptr addrspace(5) %"57", align 8
%"37" = call i32 @__zluda_ptx_impl_sreg_laneid()
br label %"47"
"47": ; preds = %"46"
store i32 %"37", ptr addrspace(5) %"52", align 4
%"39" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"48"
"48": ; preds = %"47"
store i32 %"39", ptr addrspace(5) %"53", align 4
%"62" = load i32, ptr addrspace(5) %"52", align 4
%2 = icmp ne i32 %"62", 0
store i1 %2, ptr addrspace(5) %"54", align 1
store i1 false, ptr addrspace(5) %"55", align 1
%"64" = load i1, ptr addrspace(5) %"54", align 1
br i1 %"64", label %"17", label %"18"
"17": ; preds = %"48"
%"66" = load i1, ptr addrspace(5) %"54", align 1
%"65" = call i1 @__zluda_ptx_impl_vote_sync_all_pred(i1 %"66", i32 -2)
store i1 %"65", ptr addrspace(5) %"55", align 1
br label %"18"
"18": ; preds = %"17", %"48"
%"68" = load i1, ptr addrspace(5) %"55", align 1
%"67" = select i1 %"68", i32 1, i32 0
store i32 %"67", ptr addrspace(5) %"56", align 4
%"71" = load i32, ptr addrspace(5) %"53", align 4
%3 = zext i32 %"71" to i64
%"70" = mul i64 %3, 4
store i64 %"70", ptr addrspace(5) %"69", align 8
%"73" = load i64, ptr addrspace(5) %"57", align 8
%"74" = load i64, ptr addrspace(5) %"69", align 8
%"72" = add i64 %"73", %"74"
store i64 %"72", ptr addrspace(5) %"57", align 8
%"75" = load i64, ptr addrspace(5) %"57", align 8
%"76" = load i32, ptr addrspace(5) %"56", align 4
%"77" = inttoptr i64 %"75" to ptr
store i32 %"76", ptr %"77", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,65 @@
declare hidden i1 @__zluda_ptx_impl_vote_sync_all_pred(i1, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
declare hidden i32 @__zluda_ptx_impl_sreg_laneid() #0
define amdgpu_kernel void @vote_all_sub(ptr addrspace(4) byref(i64) %"53") #1 {
%"54" = alloca i32, align 4, addrspace(5)
%"55" = alloca i32, align 4, addrspace(5)
%"56" = alloca i1, align 1, addrspace(5)
%"57" = alloca i1, align 1, addrspace(5)
%"58" = alloca i32, align 4, addrspace(5)
%"59" = alloca i64, align 8, addrspace(5)
%"70" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"48"
"48": ; preds = %1
%"60" = load i64, ptr addrspace(4) %"53", align 8
store i64 %"60", ptr addrspace(5) %"59", align 8
%"38" = call i32 @__zluda_ptx_impl_sreg_laneid()
br label %"49"
"49": ; preds = %"48"
store i32 %"38", ptr addrspace(5) %"54", align 4
%"40" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"50"
"50": ; preds = %"49"
store i32 %"40", ptr addrspace(5) %"55", align 4
%"64" = load i32, ptr addrspace(5) %"54", align 4
%2 = icmp eq i32 %"64", 0
store i1 %2, ptr addrspace(5) %"56", align 1
store i1 false, ptr addrspace(5) %"57", align 1
%"66" = load i1, ptr addrspace(5) %"56", align 1
br i1 %"66", label %"10", label %"19"
"19": ; preds = %"50"
%"67" = call i1 @__zluda_ptx_impl_vote_sync_all_pred(i1 true, i32 -1)
store i1 %"67", ptr addrspace(5) %"57", align 1
br label %"10"
"10": ; preds = %"19", %"50"
%"69" = load i1, ptr addrspace(5) %"57", align 1
%"68" = select i1 %"69", i32 1, i32 0
store i32 %"68", ptr addrspace(5) %"58", align 4
%"72" = load i32, ptr addrspace(5) %"55", align 4
%3 = zext i32 %"72" to i64
%"71" = mul i64 %3, 4
store i64 %"71", ptr addrspace(5) %"70", align 8
%"74" = load i64, ptr addrspace(5) %"59", align 8
%"75" = load i64, ptr addrspace(5) %"70", align 8
%"73" = add i64 %"74", %"75"
store i64 %"73", ptr addrspace(5) %"59", align 8
%"76" = load i64, ptr addrspace(5) %"59", align 8
%"77" = load i32, ptr addrspace(5) %"58", align 4
%"78" = inttoptr i64 %"76" to ptr
store i32 %"77", ptr %"78", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,50 @@
declare hidden i1 @__zluda_ptx_impl_vote_sync_any_pred_negate(i1, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @vote_any(ptr addrspace(4) byref(i64) %"44") #1 {
%"45" = alloca i32, align 4, addrspace(5)
%"46" = alloca i1, align 1, addrspace(5)
%"47" = alloca i1, align 1, addrspace(5)
%"48" = alloca i32, align 4, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"58" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"41"
"41": ; preds = %1
%"50" = load i64, ptr addrspace(4) %"44", align 8
store i64 %"50", ptr addrspace(5) %"49", align 8
%"35" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"42"
"42": ; preds = %"41"
store i32 %"35", ptr addrspace(5) %"45", align 4
%"53" = load i32, ptr addrspace(5) %"45", align 4
%2 = icmp uge i32 %"53", 32
store i1 %2, ptr addrspace(5) %"46", align 1
%"55" = load i1, ptr addrspace(5) %"46", align 1
%"54" = call i1 @__zluda_ptx_impl_vote_sync_any_pred_negate(i1 %"55", i32 -1)
store i1 %"54", ptr addrspace(5) %"47", align 1
%"57" = load i1, ptr addrspace(5) %"47", align 1
%"56" = select i1 %"57", i32 1, i32 0
store i32 %"56", ptr addrspace(5) %"48", align 4
%"60" = load i32, ptr addrspace(5) %"45", align 4
%3 = zext i32 %"60" to i64
%"59" = mul i64 %3, 4
store i64 %"59", ptr addrspace(5) %"58", align 8
%"62" = load i64, ptr addrspace(5) %"49", align 8
%"63" = load i64, ptr addrspace(5) %"58", align 8
%"61" = add i64 %"62", %"63"
store i64 %"61", ptr addrspace(5) %"49", align 8
%"64" = load i64, ptr addrspace(5) %"49", align 8
%"65" = load i32, ptr addrspace(5) %"48", align 4
%"66" = inttoptr i64 %"64" to ptr
store i32 %"65", ptr %"66", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,46 @@
declare hidden i32 @__zluda_ptx_impl_vote_sync_ballot_b32(i1, i32) #0
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
define amdgpu_kernel void @vote_ballot(ptr addrspace(4) byref(i64) %"41") #1 {
%"42" = alloca i32, align 4, addrspace(5)
%"43" = alloca i1, align 1, addrspace(5)
%"44" = alloca i32, align 4, addrspace(5)
%"45" = alloca i64, align 8, addrspace(5)
%"52" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"38"
"38": ; preds = %1
%"46" = load i64, ptr addrspace(4) %"41", align 8
store i64 %"46", ptr addrspace(5) %"45", align 8
%"34" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
br label %"39"
"39": ; preds = %"38"
store i32 %"34", ptr addrspace(5) %"42", align 4
%"49" = load i32, ptr addrspace(5) %"42", align 4
%2 = icmp uge i32 %"49", 34
store i1 %2, ptr addrspace(5) %"43", align 1
%"51" = load i1, ptr addrspace(5) %"43", align 1
%"60" = call i32 @__zluda_ptx_impl_vote_sync_ballot_b32(i1 %"51", i32 -1)
store i32 %"60", ptr addrspace(5) %"44", align 4
%"54" = load i32, ptr addrspace(5) %"42", align 4
%3 = zext i32 %"54" to i64
%"53" = mul i64 %3, 4
store i64 %"53", ptr addrspace(5) %"52", align 8
%"56" = load i64, ptr addrspace(5) %"45", align 8
%"57" = load i64, ptr addrspace(5) %"52", align 8
%"55" = add i64 %"56", %"57"
store i64 %"55", ptr addrspace(5) %"45", align 8
%"58" = load i64, ptr addrspace(5) %"45", align 8
%"59" = load i32, ptr addrspace(5) %"44", align 4
%"61" = inttoptr i64 %"58" to ptr
store i32 %"59", ptr %"61", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -409,6 +409,41 @@ test_ptx_warp!(
225u32, 237u32, 235u32, 236u32, 237u32,
]
);
test_ptx_warp!(
vote_all,
[
0u32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1
]
);
test_ptx_warp!(
vote_all_sub,
[
0u32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1
]
);
test_ptx_warp!(
vote_any,
[
1u32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0
]
);
test_ptx_warp!(
vote_ballot,
[
0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292,
4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292,
4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292,
4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292,
4294967292, 4294967292, 4294967292, 4294967292, 4294967292
]
);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,40 @@
.version 7.0
.target sm_70
.address_size 64
.visible .entry vote_all(
.param .u64 output
)
{
.reg .u32 laneid;
.reg .u32 tid;
.reg .pred not_first_lane;
.reg .pred result_pred;
.reg .u32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 laneid, %laneid;
mov.u32 tid, %tid.x;
setp.ne.u32 not_first_lane, laneid, 0;
mov.pred result_pred, 0;
// IMPORTANT:
// PTX documentation states:
// "The behavior of vote.sync is undefined if the executing thread is not in the membermask."
// You might think that means:
// "The value produced by vote.sync is undefined if the if the executing thread is not in the membermask."
// But it actually means:
// "The instruction `vote.sync` is _undefined behavior_ (in C/C++ sense) for _all threads in the warp_ if the executing thread is not in the membermask."
// Compiler _can_ and _does_ skip vote.sync entirely if it can prove that the membermask does not match execution mask
@not_first_lane vote.sync.all.pred result_pred, not_first_lane, 0xFFFFFFFE;
selp.u32 result, 1, 0, result_pred;
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -0,0 +1,35 @@
.version 7.0
.target sm_70
.address_size 64
.visible .entry vote_all_sub(
.param .u64 output
)
{
.reg .u32 laneid;
.reg .u32 tid;
.reg .pred first_lane;
.reg .pred result_pred;
.reg .u32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 laneid, %laneid;
mov.u32 tid, %tid.x;
setp.eq.u32 first_lane, laneid, 0;
mov.pred result_pred, 0;
@first_lane bra EXIT;
// IMPORTANT: it is legal for membermask to be bigger than the execution mask
vote.sync.all.pred result_pred, 1, 0xFFFFFFFF;
EXIT:
selp.u32 result, 1, 0, result_pred;
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -0,0 +1,29 @@
.version 7.0
.target sm_70
.address_size 64
.visible .entry vote_any(
.param .u64 output
)
{
.reg .u32 tid;
.reg .pred tid_is_greater_equal_32;
.reg .pred result_pred;
.reg .u32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 tid, %tid.x;
setp.ge.u32 tid_is_greater_equal_32, tid, 32;
vote.sync.any.pred result_pred, !tid_is_greater_equal_32, 0xFFFFFFFF;
selp.u32 result, 1, 0, result_pred;
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -0,0 +1,27 @@
.version 7.0
.target sm_70
.address_size 64
.visible .entry vote_ballot(
.param .u64 output
)
{
.reg .u32 tid;
.reg .pred tid_is_greater_equal_34;
.reg .u32 result;
.reg .u64 out_ptr;
ld.param.u64 out_ptr, [output];
mov.u32 tid, %tid.x;
setp.ge.u32 tid_is_greater_equal_34, tid, 34;
vote.sync.ballot.b32 result, tid_is_greater_equal_34, 0xFFFFFFFF;
.reg .u64 out_offset;
mul.wide.u32 out_offset, tid, 4;
add.u64 out_ptr, out_ptr, out_offset;
st.u32 [out_ptr], result;
ret;
}

View file

@ -3,7 +3,8 @@ use super::{
StateSpace, VectorPrefix,
};
use crate::{
FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection, ShuffleMode,
FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection,
ShuffleMode, VoteMode,
};
use bitflags::bitflags;
use std::{alloc::Layout, cmp::Ordering, fmt::Write, num::NonZeroU8};
@ -678,6 +679,22 @@ ptx_parser_macros::generate_instruction_type!(
src: T
}
},
Vote {
type: Type::Scalar(data.mode.type_()),
data: VoteDetails,
arguments<T>: {
dst: T,
src1: {
repr: T,
type: { Type::Scalar(ScalarType::Pred) },
},
src2: {
repr: T,
type: { Type::Scalar(ScalarType::U32) },
}
}
}
}
);
@ -2203,3 +2220,17 @@ pub enum DivFloatKind {
pub struct FlushToZero {
pub flush_to_zero: bool,
}
pub struct VoteDetails {
pub mode: VoteMode,
pub negate: bool,
}
impl VoteMode {
fn type_(self) -> ScalarType {
match self {
VoteMode::All | VoteMode::Any => ScalarType::Pred,
VoteMode::Ballot => ScalarType::B32,
}
}
}

View file

@ -1792,6 +1792,11 @@ derive_parser!(
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum FunnelShiftMode { }
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum VoteMode {
Ballot
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
mov{.vec}.type d, a => {
Instruction::Mov {
@ -3737,6 +3742,33 @@ derive_parser!(
.atype: ScalarType = { .u32, .s32 };
.btype: ScalarType = { .u32, .s32 };
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync
vote.sync.mode.pred d, {!}a, membermask => {
let (negate, a) = a;
Instruction::Vote {
data: VoteDetails {
mode,
negate
},
arguments: VoteArgs { dst: d, src1: a, src2: membermask }
}
}
vote.sync.ballot.b32 d, {!}a, membermask => {
let (negate, a) = a;
Instruction::Vote {
data: VoteDetails {
mode: VoteMode::Ballot,
negate
},
arguments: VoteArgs { dst: d, src1: a, src2: membermask }
}
}
// .mode: VoteMode = { .all, .any, .uni };
.mode: VoteMode = { .all, .any };
);
#[cfg(test)]