mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Add saturated integer conversions
This commit is contained in:
parent
002a19354a
commit
73eb31fec5
1 changed files with 125 additions and 6 deletions
|
@ -1533,8 +1533,12 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
|
||||
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
|
||||
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
||||
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
|
||||
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
|
||||
ptx_parser::CvtMode::SaturateUnsignedToSigned => {
|
||||
return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
|
||||
}
|
||||
ptx_parser::CvtMode::SaturateSignedToUnsigned => {
|
||||
return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
|
||||
}
|
||||
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
|
||||
ptx_parser::CvtMode::FPTruncate {
|
||||
rounding,
|
||||
|
@ -1543,7 +1547,15 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ptx_parser::CvtMode::FPRound {
|
||||
integer_rounding,
|
||||
flush_to_zero,
|
||||
} => todo!(),
|
||||
} => {
|
||||
return self.emit_cvt_float_to_int(
|
||||
data.from,
|
||||
data.to,
|
||||
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
|
||||
arguments,
|
||||
Some(LLVMBuildFPToSI),
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::SignedFromFP {
|
||||
rounding,
|
||||
flush_to_zero,
|
||||
|
@ -1553,7 +1565,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
"llvm.fptosi.sat",
|
||||
Some(LLVMBuildFPToSI),
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::UnsignedFromFP {
|
||||
|
@ -1565,7 +1577,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
data.to,
|
||||
rounding,
|
||||
arguments,
|
||||
"llvm.fptoui.sat",
|
||||
Some(LLVMBuildFPToUI),
|
||||
)
|
||||
}
|
||||
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
|
||||
|
@ -1578,13 +1590,105 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvt_unsigned_to_signed_sat(
|
||||
&mut self,
|
||||
from: ptx_parser::ScalarType,
|
||||
to: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
// This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
|
||||
// so if it's downcast to a smaller type, it will be the maximum value
|
||||
// of the smaller type
|
||||
let max_value = match to {
|
||||
ptx_parser::ScalarType::S8 => i8::MAX as u64,
|
||||
ptx_parser::ScalarType::S16 => i16::MAX as u64,
|
||||
ptx_parser::ScalarType::S32 => i32::MAX as u64,
|
||||
ptx_parser::ScalarType::S64 => i64::MAX as u64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let from_llvm = get_scalar_type(self.context, from);
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
|
||||
let clamped = self.emit_intrinsic(
|
||||
c"llvm.umin",
|
||||
None,
|
||||
&from.into(),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(max, from_llvm),
|
||||
],
|
||||
)?;
|
||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||
LLVMBuildSExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
};
|
||||
let to_llvm = get_scalar_type(self.context, to);
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
resize_fn(self.builder, clamped, to_llvm, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvt_signed_to_unsigned_sat(
|
||||
&mut self,
|
||||
from: ptx_parser::ScalarType,
|
||||
to: ptx_parser::ScalarType,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let from_llvm = get_scalar_type(self.context, from);
|
||||
let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
|
||||
let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
|
||||
let zero_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
vec![
|
||||
(self.resolver.value(arguments.src)?, from_llvm),
|
||||
(zero, from_llvm),
|
||||
],
|
||||
)?;
|
||||
// zero_clamped is now unsigned
|
||||
let max_value = match to {
|
||||
ptx_parser::ScalarType::U8 => u8::MAX as u64,
|
||||
ptx_parser::ScalarType::U16 => u16::MAX as u64,
|
||||
ptx_parser::ScalarType::U32 => u32::MAX as u64,
|
||||
ptx_parser::ScalarType::U64 => u64::MAX as u64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
|
||||
let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
|
||||
let fully_clamped = self.emit_intrinsic(
|
||||
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
|
||||
None,
|
||||
&from.into(),
|
||||
vec![(zero_clamped, from_llvm), (max, from_llvm)],
|
||||
)?;
|
||||
let resize_fn = if to.layout().size() >= from.layout().size() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
};
|
||||
let to_llvm = get_scalar_type(self.context, to);
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
resize_fn(self.builder, fully_clamped, to_llvm, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_cvt_float_to_int(
|
||||
&mut self,
|
||||
from: ast::ScalarType,
|
||||
to: ast::ScalarType,
|
||||
rounding: ast::RoundingMode,
|
||||
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||
llvm_cast: &str,
|
||||
llvm_cast: Option<
|
||||
unsafe extern "C" fn(
|
||||
arg1: LLVMBuilderRef,
|
||||
Val: LLVMValueRef,
|
||||
DestTy: LLVMTypeRef,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef,
|
||||
>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let prefix = match rounding {
|
||||
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
|
||||
|
@ -1602,6 +1706,20 @@ impl<'a> MethodEmitContext<'a> {
|
|||
get_scalar_type(self.context, from),
|
||||
)],
|
||||
)?;
|
||||
if let Some(llvm_cast) = llvm_cast {
|
||||
let to = get_scalar_type(self.context, to);
|
||||
let poisoned_dst =
|
||||
unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
|
||||
});
|
||||
} else {
|
||||
self.resolver.register(arguments.dst, rounded_float);
|
||||
}
|
||||
// Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
|
||||
// values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
|
||||
// saturates by default and we don't care about NaNs anyway
|
||||
/*
|
||||
let cast_intrinsic = format!(
|
||||
"{}.{}.{}\0",
|
||||
llvm_cast,
|
||||
|
@ -1614,6 +1732,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
&to.into(),
|
||||
vec![(rounded_float, get_scalar_type(self.context, from))],
|
||||
)?;
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue