diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index cbbf2dc..50f9d3d 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 937bda1..553070e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -4,15 +4,50 @@ #include #include -#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_ ## NAME +#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME -extern "C" { - uint32_t FUNC(activemask)() { +extern "C" +{ + uint32_t FUNC(activemask)() + { return __builtin_amdgcn_read_exec_lo(); } size_t __ockl_get_local_size(uint32_t) __device__; - uint32_t FUNC(sreg_ntid)(uint8_t member) { + uint32_t FUNC(sreg_ntid)(uint8_t member) + { return (uint32_t)__ockl_get_local_size(member); } + + int32_t __ockl_bfe_i32(int32_t, uint32_t, uint32_t) __attribute__((device)); + int32_t FUNC(bfe_s32)(int32_t base, uint32_t pos, uint32_t len) + { + return __ockl_bfe_i32(base, pos, len); + } + + uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device)); + uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos, uint32_t len) + { + return __ockl_bfe_u32(base, pos, len); + } + + // LLVM contains mentions of llvm.amdgcn.ubfe.i64 and llvm.amdgcn.sbfe.i64, + // but using it only leads to LLVM crashes on RDNA2 + uint64_t FUNC(bfe_u64)(uint64_t base, uint32_t b, uint32_t c) + { + uint8_t pos = uint8_t(b); + uint8_t len = uint8_t(c); + if (len == 0) + return 0; + return (base >> pos) & ((1U << len) - 1U); + } + + int64_t FUNC(bfe_s64)(int64_t base, uint32_t b, uint32_t c) + { + uint8_t pos = uint8_t(b); + uint8_t len = uint8_t(c); + if (len == 0) + return 0; + return (base >> pos) & ((1U << len) - 1U); + } } diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 36a9623..cbb1570 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -472,14 +472,15 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Popc { data, arguments } => todo!(), ast::Instruction::Xor { data, arguments } => todo!(), ast::Instruction::Rem { data, arguments } => todo!(), - ast::Instruction::Bfe { data, arguments } => todo!(), ast::Instruction::Bfi { data, arguments } => todo!(), ast::Instruction::PrmtSlow { arguments } => todo!(), ast::Instruction::Prmt { data, arguments } => todo!(), ast::Instruction::Membar { data } => todo!(), ast::Instruction::Trap {} => todo!(), // replaced by a function call - ast::Instruction::Activemask { arguments } => return Err(error_unreachable()), + ast::Instruction::Bfe { .. } | ast::Instruction::Activemask { .. } => { + return Err(error_unreachable()) + } } } diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs index 2912366..37e477f 100644 --- a/ptx/src/pass/extract_globals.rs +++ b/ptx/src/pass/extract_globals.rs @@ -219,33 +219,6 @@ fn instruction_to_fn_call( })) } -fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { - match this { - ast::ScalarType::B8 => "b8", - ast::ScalarType::B16 => "b16", - ast::ScalarType::B32 => "b32", - ast::ScalarType::B64 => "b64", - ast::ScalarType::B128 => "b128", - ast::ScalarType::U8 => "u8", - ast::ScalarType::U16 => "u16", - ast::ScalarType::U16x2 => "u16x2", - ast::ScalarType::U32 => "u32", - ast::ScalarType::U64 => "u64", - ast::ScalarType::S8 => "s8", - ast::ScalarType::S16 => "s16", - ast::ScalarType::S16x2 => "s16x2", - ast::ScalarType::S32 => "s32", - ast::ScalarType::S64 => "s64", - ast::ScalarType::F16 => "f16", - ast::ScalarType::F16x2 => "f16x2", - ast::ScalarType::F32 => "f32", - ast::ScalarType::F64 => "f64", - ast::ScalarType::BF16 => "bf16", - ast::ScalarType::BF16x2 => "bf16x2", - ast::ScalarType::Pred => "pred", - } -} - fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str { match this { ast::AtomSemantics::Relaxed => "relaxed", diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ead747a..df0af8f 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -2019,3 +2019,30 @@ pub struct VectorAccess { src: SpirvWord, member: u8, } + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} \ No newline at end of file diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 30ff75d..75ee676 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -96,6 +96,10 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Activemask { .. } => { to_call(resolver, fn_declarations, "activemask".into(), i)? } + i @ ptx_parser::Instruction::Bfe { data, .. } => { + let name = ["bfe_", scalar_to_ptx_name(data)].concat(); + to_call(resolver, fn_declarations, name.into(), i)? + } i => i, }) } diff --git a/ptx/src/test/spirv_run/bfe.ptx b/ptx/src/test/spirv_run/bfe.ptx index 60ee8a6..a01a14a 100644 --- a/ptx/src/test/spirv_run/bfe.ptx +++ b/ptx/src/test/spirv_run/bfe.ptx @@ -10,14 +10,28 @@ .reg .u64 in_addr; .reg .u64 out_addr; .reg .u32 temp<3>; + .reg .b32 result<2>; + .reg .b64 temp64_0; + .reg .b32 temp64_1; + .reg .b32 temp64_2; + .reg .b64 result64_<2>; ld.param.u64 in_addr, [input]; ld.param.u64 out_addr, [output]; - ld.u32 temp0, [in_addr]; - ld.u32 temp1, [in_addr+4]; - ld.u32 temp2, [in_addr+8]; - bfe.u32 temp0, temp0, temp1, temp2; - st.u32 [out_addr], temp0; + ld.b64 temp64_0, [in_addr]; + ld.b32 temp64_1, [in_addr+8]; + ld.b32 temp64_2, [in_addr+16]; + ld.u32 temp0, [in_addr+24]; + ld.u32 temp1, [in_addr+28]; + ld.u32 temp2, [in_addr+32]; + //bfe.u64 result64_0, temp64_0, temp64_1, temp64_2; + bfe.s64 result64_1, temp64_0, temp64_1, temp64_2; + bfe.u32 result0, temp0, temp1, temp2; + bfe.s32 result1, temp0, temp1, temp2; + st.b64 [out_addr], result64_0; + st.b64 [out_addr], result64_1; + st.b32 [out_addr], result0; + st.b32 [out_addr], result1; ret; }