mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 22:30:41 +00:00
Add shifts, cvt, rsqrt, sqrt, rcp, more sregs
This commit is contained in:
parent
d9c33ca505
commit
ae42eac925
5 changed files with 230 additions and 37 deletions
Binary file not shown.
|
@ -13,12 +13,30 @@ extern "C"
|
||||||
return __builtin_amdgcn_read_exec_lo();
|
return __builtin_amdgcn_read_exec_lo();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t __ockl_get_local_id(uint32_t) __device__;
|
||||||
|
uint32_t FUNC(sreg_tid)(uint8_t member)
|
||||||
|
{
|
||||||
|
return (uint32_t)__ockl_get_local_id(member);
|
||||||
|
}
|
||||||
|
|
||||||
size_t __ockl_get_local_size(uint32_t) __device__;
|
size_t __ockl_get_local_size(uint32_t) __device__;
|
||||||
uint32_t FUNC(sreg_ntid)(uint8_t member)
|
uint32_t FUNC(sreg_ntid)(uint8_t member)
|
||||||
{
|
{
|
||||||
return (uint32_t)__ockl_get_local_size(member);
|
return (uint32_t)__ockl_get_local_size(member);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t __ockl_get_global_id(uint32_t) __device__;
|
||||||
|
uint32_t FUNC(sreg_ctaid)(uint8_t member)
|
||||||
|
{
|
||||||
|
return (uint32_t)__ockl_get_global_id(member);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t __ockl_get_global_size(uint32_t) __device__;
|
||||||
|
uint32_t FUNC(sreg_nctaid)(uint8_t member)
|
||||||
|
{
|
||||||
|
return (uint32_t)__ockl_get_global_size(member);
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device));
|
uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device));
|
||||||
uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
|
uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
|
||||||
{
|
{
|
||||||
|
|
|
@ -522,9 +522,9 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
|
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
|
||||||
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
|
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
|
||||||
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
||||||
ast::Instruction::Cvt { .. } => todo!(),
|
ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
|
||||||
ast::Instruction::Shr { .. } => todo!(),
|
ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
|
||||||
ast::Instruction::Shl { .. } => todo!(),
|
ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
|
||||||
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
|
||||||
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
|
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
|
||||||
ast::Instruction::Abs { .. } => todo!(),
|
ast::Instruction::Abs { .. } => todo!(),
|
||||||
|
@ -533,9 +533,9 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
|
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
|
||||||
ast::Instruction::Min { .. } => todo!(),
|
ast::Instruction::Min { .. } => todo!(),
|
||||||
ast::Instruction::Max { .. } => todo!(),
|
ast::Instruction::Max { .. } => todo!(),
|
||||||
ast::Instruction::Rcp { .. } => todo!(),
|
ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
|
||||||
ast::Instruction::Sqrt { .. } => todo!(),
|
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
|
||||||
ast::Instruction::Rsqrt { .. } => todo!(),
|
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
|
||||||
ast::Instruction::Selp { .. } => todo!(),
|
ast::Instruction::Selp { .. } => todo!(),
|
||||||
ast::Instruction::Bar { .. } => todo!(),
|
ast::Instruction::Bar { .. } => todo!(),
|
||||||
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
|
||||||
|
@ -1406,6 +1406,212 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_cvt(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::CvtDetails,
|
||||||
|
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let dst_type = get_scalar_type(self.context, data.to);
|
||||||
|
let llvm_fn = match data.mode {
|
||||||
|
ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
|
||||||
|
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
|
||||||
|
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
|
||||||
|
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
||||||
|
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
|
||||||
|
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
|
||||||
|
ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
|
||||||
|
ptx_parser::CvtMode::FPTruncate {
|
||||||
|
rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
} => todo!(),
|
||||||
|
ptx_parser::CvtMode::FPRound {
|
||||||
|
integer_rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
} => todo!(),
|
||||||
|
ptx_parser::CvtMode::SignedFromFP {
|
||||||
|
rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
} => todo!(),
|
||||||
|
ptx_parser::CvtMode::UnsignedFromFP {
|
||||||
|
rounding,
|
||||||
|
flush_to_zero,
|
||||||
|
} => todo!(),
|
||||||
|
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
|
||||||
|
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
|
||||||
|
};
|
||||||
|
let src = self.resolver.value(arguments.src)?;
|
||||||
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||||
|
llvm_fn(self.builder, src, dst_type, dst)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_rsqrt(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::TypeFtz,
|
||||||
|
arguments: ptx_parser::RsqrtArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let type_ = get_scalar_type(self.context, data.type_);
|
||||||
|
let intrinsic = match data.type_ {
|
||||||
|
ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
|
||||||
|
ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
self.emit_intrinsic(
|
||||||
|
intrinsic,
|
||||||
|
Some(arguments.dst),
|
||||||
|
&data.type_.into(),
|
||||||
|
vec![(arguments.src, type_)],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_sqrt(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::RcpData,
|
||||||
|
arguments: ptx_parser::SqrtArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let type_ = get_scalar_type(self.context, data.type_);
|
||||||
|
let intrinsic = match (data.type_, data.kind) {
|
||||||
|
(ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
|
||||||
|
(ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
|
||||||
|
(ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
self.emit_intrinsic(
|
||||||
|
intrinsic,
|
||||||
|
Some(arguments.dst),
|
||||||
|
&data.type_.into(),
|
||||||
|
vec![(arguments.src, type_)],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_rcp(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::RcpData,
|
||||||
|
arguments: ptx_parser::RcpArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let type_ = get_scalar_type(self.context, data.type_);
|
||||||
|
let intrinsic = match (data.type_, data.kind) {
|
||||||
|
(ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
|
||||||
|
(_, ast::RcpKind::Compliant(rnd)) => {
|
||||||
|
return self.emit_rcp_compliant(data, arguments, rnd)
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
self.emit_intrinsic(
|
||||||
|
intrinsic,
|
||||||
|
Some(arguments.dst),
|
||||||
|
&data.type_.into(),
|
||||||
|
vec![(arguments.src, type_)],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_rcp_compliant(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::RcpData,
|
||||||
|
arguments: ptx_parser::RcpArgs<SpirvWord>,
|
||||||
|
_rnd: ast::RoundingMode,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let type_ = get_scalar_type(self.context, data.type_);
|
||||||
|
let one = unsafe { LLVMConstReal(type_, 1.0) };
|
||||||
|
let src = self.resolver.value(arguments.src)?;
|
||||||
|
let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||||
|
LLVMBuildFDiv(self.builder, one, src, dst)
|
||||||
|
});
|
||||||
|
unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) };
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_shr(
|
||||||
|
&mut self,
|
||||||
|
data: ptx_parser::ShrData,
|
||||||
|
arguments: ptx_parser::ShrArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let shift_fn = match data.kind {
|
||||||
|
ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
|
||||||
|
ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
|
||||||
|
};
|
||||||
|
self.emit_shift(
|
||||||
|
data.type_,
|
||||||
|
arguments.dst,
|
||||||
|
arguments.src1,
|
||||||
|
arguments.src2,
|
||||||
|
shift_fn,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_shl(
|
||||||
|
&mut self,
|
||||||
|
type_: ptx_parser::ScalarType,
|
||||||
|
arguments: ptx_parser::ShlArgs<SpirvWord>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
self.emit_shift(
|
||||||
|
type_,
|
||||||
|
arguments.dst,
|
||||||
|
arguments.src1,
|
||||||
|
arguments.src2,
|
||||||
|
LLVMBuildShl,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_shift(
|
||||||
|
&mut self,
|
||||||
|
type_: ast::ScalarType,
|
||||||
|
dst: SpirvWord,
|
||||||
|
src1: SpirvWord,
|
||||||
|
src2: SpirvWord,
|
||||||
|
llvm_fn: unsafe extern "C" fn(
|
||||||
|
LLVMBuilderRef,
|
||||||
|
LLVMValueRef,
|
||||||
|
LLVMValueRef,
|
||||||
|
*const i8,
|
||||||
|
) -> LLVMValueRef,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let src1 = self.resolver.value(src1)?;
|
||||||
|
let shift_size = self.resolver.value(src2)?;
|
||||||
|
let integer_bits = type_.layout().size() * 8;
|
||||||
|
let integer_bits_constant = unsafe {
|
||||||
|
LLVMConstInt(
|
||||||
|
get_scalar_type(self.context, ast::ScalarType::U32),
|
||||||
|
integer_bits as u64,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let should_clamp = unsafe {
|
||||||
|
LLVMBuildICmp(
|
||||||
|
self.builder,
|
||||||
|
LLVMIntPredicate::LLVMIntUGE,
|
||||||
|
shift_size,
|
||||||
|
integer_bits_constant,
|
||||||
|
LLVM_UNNAMED.as_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let llvm_type = get_scalar_type(self.context, type_);
|
||||||
|
let zero = unsafe { LLVMConstNull(llvm_type) };
|
||||||
|
let normalized_shift_size = if type_.layout().size() >= 4 {
|
||||||
|
unsafe {
|
||||||
|
LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
|
||||||
|
};
|
||||||
|
let shifted = unsafe {
|
||||||
|
llvm_fn(
|
||||||
|
self.builder,
|
||||||
|
src1,
|
||||||
|
normalized_shift_size,
|
||||||
|
LLVM_UNNAMED.as_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.resolver.with_result(dst, |dst| unsafe {
|
||||||
|
LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||||
// Should be available in LLVM 19
|
// Should be available in LLVM 19
|
||||||
|
|
|
@ -45,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]);
|
||||||
test_ptx!(bra, [10u64], [11u64]);
|
test_ptx!(bra, [10u64], [11u64]);
|
||||||
test_ptx!(not, [0u64], [u64::max_value()]);
|
test_ptx!(not, [0u64], [u64::max_value()]);
|
||||||
test_ptx!(shl, [11u64], [44u64]);
|
test_ptx!(shl, [11u64], [44u64]);
|
||||||
test_ptx!(shl_link_hack, [11u64], [44u64]);
|
|
||||||
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
|
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
|
||||||
test_ptx!(cvta, [3.0f32], [3.0f32]);
|
test_ptx!(cvta, [3.0f32], [3.0f32]);
|
||||||
test_ptx!(block, [1u64], [2u64]);
|
test_ptx!(block, [1u64], [2u64]);
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
// HACK ALERT
|
|
||||||
// This test is for testing workaround for a bug in IGC where linking fails
|
|
||||||
// if there is shl/shr with different width of value and shift
|
|
||||||
|
|
||||||
.version 6.5
|
|
||||||
.target sm_30
|
|
||||||
.address_size 64
|
|
||||||
|
|
||||||
.visible .entry shl_link_hack(
|
|
||||||
.param .u64 input,
|
|
||||||
.param .u64 output
|
|
||||||
)
|
|
||||||
{
|
|
||||||
.reg .u64 in_addr;
|
|
||||||
.reg .u64 out_addr;
|
|
||||||
.reg .u64 temp;
|
|
||||||
.reg .u64 temp2;
|
|
||||||
|
|
||||||
ld.param.u64 in_addr, [input];
|
|
||||||
ld.param.u64 out_addr, [output];
|
|
||||||
|
|
||||||
// Here only to trigger linking
|
|
||||||
.reg .u32 unused;
|
|
||||||
atom.inc.u32 unused, [out_addr], 2000000;
|
|
||||||
|
|
||||||
ld.u64 temp, [in_addr];
|
|
||||||
shl.b64 temp2, temp, 2;
|
|
||||||
st.u64 [out_addr], temp2;
|
|
||||||
ret;
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue