Add saturated integer conversions

This commit is contained in:
Andrzej Janik 2024-10-16 03:12:54 +02:00
parent 002a19354a
commit 73eb31fec5

View file

@ -1533,8 +1533,12 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
ptx_parser::CvtMode::SaturateUnsignedToSigned => {
return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
}
ptx_parser::CvtMode::SaturateSignedToUnsigned => {
return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
}
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate {
rounding,
@ -1543,7 +1547,15 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::CvtMode::FPRound {
integer_rounding,
flush_to_zero,
} => todo!(),
} => {
return self.emit_cvt_float_to_int(
data.from,
data.to,
integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
arguments,
Some(LLVMBuildFPToSI),
)
}
ptx_parser::CvtMode::SignedFromFP {
rounding,
flush_to_zero,
@ -1553,7 +1565,7 @@ impl<'a> MethodEmitContext<'a> {
data.to,
rounding,
arguments,
"llvm.fptosi.sat",
Some(LLVMBuildFPToSI),
)
}
ptx_parser::CvtMode::UnsignedFromFP {
@ -1565,7 +1577,7 @@ impl<'a> MethodEmitContext<'a> {
data.to,
rounding,
arguments,
"llvm.fptoui.sat",
Some(LLVMBuildFPToUI),
)
}
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
@ -1578,13 +1590,105 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_cvt_unsigned_to_signed_sat(
&mut self,
from: ptx_parser::ScalarType,
to: ptx_parser::ScalarType,
arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
// This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
// so if it's downcast to a smaller type, it will be the maximum value
// of the smaller type
let max_value = match to {
ptx_parser::ScalarType::S8 => i8::MAX as u64,
ptx_parser::ScalarType::S16 => i16::MAX as u64,
ptx_parser::ScalarType::S32 => i32::MAX as u64,
ptx_parser::ScalarType::S64 => i64::MAX as u64,
_ => return Err(error_unreachable()),
};
let from_llvm = get_scalar_type(self.context, from);
let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
let clamped = self.emit_intrinsic(
c"llvm.umin",
None,
&from.into(),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(max, from_llvm),
],
)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
LLVMBuildSExtOrBitCast
} else {
LLVMBuildTrunc
};
let to_llvm = get_scalar_type(self.context, to);
self.resolver.with_result(arguments.dst, |dst| unsafe {
resize_fn(self.builder, clamped, to_llvm, dst)
});
Ok(())
}
/// Emits a saturating conversion from the signed integer type `from` to the
/// unsigned integer type `to`.
///
/// The conversion is built in three steps:
/// 1. a signed maximum with zero, so negative inputs become zero,
/// 2. an unsigned minimum with the maximum of `to` (materialised as a
///    constant of the source type),
/// 3. a zero-extension/bitcast or truncation to the width of `to`.
///
/// The final value is registered as the result for `arguments.dst`.
///
/// # Errors
///
/// Returns `error_unreachable()` when `to` is not one of `U8`, `U16`, `U32`
/// or `U64`, and propagates any error from resolving `arguments.src` or from
/// emitting either intrinsic.
fn emit_cvt_signed_to_unsigned_sat(
    &mut self,
    from: ptx_parser::ScalarType,
    to: ptx_parser::ScalarType,
    arguments: ptx_parser::CvtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
    let src_type = get_scalar_type(self.context, from);

    // Step 1: clamp from below at zero.
    let lower_bound = unsafe { LLVMConstInt(src_type, 0, 0) };
    let smax_name = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
    let smax_operands = vec![
        (self.resolver.value(arguments.src)?, src_type),
        (lower_bound, src_type),
    ];
    let non_negative = self.emit_intrinsic(
        unsafe { CStr::from_bytes_with_nul_unchecked(smax_name.as_bytes()) },
        None,
        &from.into(),
        smax_operands,
    )?;

    // Step 2: the value is non-negative from here on, so it can be compared
    // as unsigned against the destination's maximum.
    let upper_bound_bits = match to {
        ptx_parser::ScalarType::U8 => u64::from(u8::MAX),
        ptx_parser::ScalarType::U16 => u64::from(u16::MAX),
        ptx_parser::ScalarType::U32 => u64::from(u32::MAX),
        ptx_parser::ScalarType::U64 => u64::MAX,
        _ => return Err(error_unreachable()),
    };
    let upper_bound = unsafe { LLVMConstInt(src_type, upper_bound_bits, 0) };
    let umin_name = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
    let umin_operands = vec![(non_negative, src_type), (upper_bound, src_type)];
    let saturated = self.emit_intrinsic(
        unsafe { CStr::from_bytes_with_nul_unchecked(umin_name.as_bytes()) },
        None,
        &from.into(),
        umin_operands,
    )?;

    // Step 3: bring the clamped value to the destination width.
    let narrowing = to.layout().size() < from.layout().size();
    let resize = if narrowing {
        LLVMBuildTrunc
    } else {
        LLVMBuildZExtOrBitCast
    };
    let dst_type = get_scalar_type(self.context, to);
    self.resolver.with_result(arguments.dst, |name| unsafe {
        resize(self.builder, saturated, dst_type, name)
    });
    Ok(())
}
fn emit_cvt_float_to_int(
&mut self,
from: ast::ScalarType,
to: ast::ScalarType,
rounding: ast::RoundingMode,
arguments: ptx_parser::CvtArgs<SpirvWord>,
llvm_cast: &str,
llvm_cast: Option<
unsafe extern "C" fn(
arg1: LLVMBuilderRef,
Val: LLVMValueRef,
DestTy: LLVMTypeRef,
Name: *const i8,
) -> LLVMValueRef,
>,
) -> Result<(), TranslateError> {
let prefix = match rounding {
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
@ -1602,6 +1706,20 @@ impl<'a> MethodEmitContext<'a> {
get_scalar_type(self.context, from),
)],
)?;
if let Some(llvm_cast) = llvm_cast {
let to = get_scalar_type(self.context, to);
let poisoned_dst =
unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildFreeze(self.builder, poisoned_dst, dst)
});
} else {
self.resolver.register(arguments.dst, rounded_float);
}
// Using explicit saturation gives us worse codegen: it explicitly checks for out-of-bounds
// values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM>, which
// saturates by default, and we don't care about NaNs anyway.
/*
let cast_intrinsic = format!(
"{}.{}.{}\0",
llvm_cast,
@ -1614,6 +1732,7 @@ impl<'a> MethodEmitContext<'a> {
&to.into(),
vec![(rounded_float, get_scalar_type(self.context, from))],
)?;
*/
Ok(())
}