mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-04 15:19:49 +00:00
Add float-to-int cvt
This commit is contained in:
parent
3105674618
commit
002a19354a
1 changed files with 97 additions and 27 deletions
|
@ -1118,7 +1118,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
c"llvm.cos.f32",
|
c"llvm.cos.f32",
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
&ast::ScalarType::F32.into(),
|
||||||
vec![(arguments.src, llvm_f32)],
|
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||||
)?;
|
)?;
|
||||||
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1371,7 +1371,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
c"llvm.sin.f32",
|
c"llvm.sin.f32",
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
&ast::ScalarType::F32.into(),
|
||||||
vec![(arguments.src, llvm_f32)],
|
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
|
||||||
)?;
|
)?;
|
||||||
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1382,7 +1382,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
name: &CStr,
|
name: &CStr,
|
||||||
dst: Option<SpirvWord>,
|
dst: Option<SpirvWord>,
|
||||||
return_type: &ast::Type,
|
return_type: &ast::Type,
|
||||||
arguments: Vec<(SpirvWord, LLVMTypeRef)>,
|
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
|
||||||
) -> Result<LLVMValueRef, TranslateError> {
|
) -> Result<LLVMValueRef, TranslateError> {
|
||||||
let fn_type = get_function_type(
|
let fn_type = get_function_type(
|
||||||
self.context,
|
self.context,
|
||||||
|
@ -1393,10 +1393,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
if fn_ == ptr::null_mut() {
|
if fn_ == ptr::null_mut() {
|
||||||
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||||
}
|
}
|
||||||
let mut arguments = arguments
|
let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
|
||||||
.iter()
|
|
||||||
.map(|(arg, _)| self.resolver.value(*arg))
|
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
|
||||||
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
|
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
|
||||||
LLVMBuildCall2(
|
LLVMBuildCall2(
|
||||||
self.builder,
|
self.builder,
|
||||||
|
@ -1538,11 +1535,11 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
|
||||||
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
|
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
|
||||||
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
|
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
|
||||||
ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
|
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
|
||||||
ptx_parser::CvtMode::FPTruncate {
|
ptx_parser::CvtMode::FPTruncate {
|
||||||
rounding,
|
rounding,
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
} => todo!(),
|
} => LLVMBuildFPTrunc,
|
||||||
ptx_parser::CvtMode::FPRound {
|
ptx_parser::CvtMode::FPRound {
|
||||||
integer_rounding,
|
integer_rounding,
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
|
@ -1550,11 +1547,27 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
ptx_parser::CvtMode::SignedFromFP {
|
ptx_parser::CvtMode::SignedFromFP {
|
||||||
rounding,
|
rounding,
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
} => todo!(),
|
} => {
|
||||||
|
return self.emit_cvt_float_to_int(
|
||||||
|
data.from,
|
||||||
|
data.to,
|
||||||
|
rounding,
|
||||||
|
arguments,
|
||||||
|
"llvm.fptosi.sat",
|
||||||
|
)
|
||||||
|
}
|
||||||
ptx_parser::CvtMode::UnsignedFromFP {
|
ptx_parser::CvtMode::UnsignedFromFP {
|
||||||
rounding,
|
rounding,
|
||||||
flush_to_zero,
|
flush_to_zero,
|
||||||
} => todo!(),
|
} => {
|
||||||
|
return self.emit_cvt_float_to_int(
|
||||||
|
data.from,
|
||||||
|
data.to,
|
||||||
|
rounding,
|
||||||
|
arguments,
|
||||||
|
"llvm.fptoui.sat",
|
||||||
|
)
|
||||||
|
}
|
||||||
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
|
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
|
||||||
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
|
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
|
||||||
};
|
};
|
||||||
|
@ -1565,6 +1578,45 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn emit_cvt_float_to_int(
|
||||||
|
&mut self,
|
||||||
|
from: ast::ScalarType,
|
||||||
|
to: ast::ScalarType,
|
||||||
|
rounding: ast::RoundingMode,
|
||||||
|
arguments: ptx_parser::CvtArgs<SpirvWord>,
|
||||||
|
llvm_cast: &str,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let prefix = match rounding {
|
||||||
|
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
|
||||||
|
ptx_parser::RoundingMode::Zero => "llvm.trunc",
|
||||||
|
ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
|
||||||
|
ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
|
||||||
|
};
|
||||||
|
let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
|
||||||
|
let rounded_float = self.emit_intrinsic(
|
||||||
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
|
None,
|
||||||
|
&from.into(),
|
||||||
|
vec![(
|
||||||
|
self.resolver.value(arguments.src)?,
|
||||||
|
get_scalar_type(self.context, from),
|
||||||
|
)],
|
||||||
|
)?;
|
||||||
|
let cast_intrinsic = format!(
|
||||||
|
"{}.{}.{}\0",
|
||||||
|
llvm_cast,
|
||||||
|
LLVMTypeDisplay(to),
|
||||||
|
LLVMTypeDisplay(from)
|
||||||
|
);
|
||||||
|
self.emit_intrinsic(
|
||||||
|
unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
|
||||||
|
Some(arguments.dst),
|
||||||
|
&to.into(),
|
||||||
|
vec![(rounded_float, get_scalar_type(self.context, from))],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn emit_rsqrt(
|
fn emit_rsqrt(
|
||||||
&mut self,
|
&mut self,
|
||||||
data: ptx_parser::TypeFtz,
|
data: ptx_parser::TypeFtz,
|
||||||
|
@ -1580,7 +1632,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
&data.type_.into(),
|
||||||
vec![(arguments.src, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1601,7 +1653,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
&data.type_.into(),
|
||||||
vec![(arguments.src, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1623,7 +1675,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
&data.type_.into(),
|
||||||
vec![(arguments.src, type_)],
|
vec![(self.resolver.value(arguments.src)?, type_)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1745,7 +1797,10 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
&data.type_.into(),
|
||||||
vec![(arguments.src, get_scalar_type(self.context, data.type_))],
|
vec![(
|
||||||
|
self.resolver.value(arguments.src)?,
|
||||||
|
get_scalar_type(self.context, data.type_),
|
||||||
|
)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1760,7 +1815,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&ast::ScalarType::F32.into(),
|
&ast::ScalarType::F32.into(),
|
||||||
vec![(
|
vec![(
|
||||||
arguments.src,
|
self.resolver.value(arguments.src)?,
|
||||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
||||||
)],
|
)],
|
||||||
)?;
|
)?;
|
||||||
|
@ -1814,7 +1869,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
intrinsic,
|
intrinsic,
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&type_.into(),
|
&type_.into(),
|
||||||
vec![(arguments.src, llvm_type)],
|
vec![(self.resolver.value(arguments.src)?, llvm_type)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1832,13 +1887,16 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
}
|
}
|
||||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
|
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
|
||||||
};
|
};
|
||||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
|
||||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_().into(),
|
&data.type_().into(),
|
||||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
vec![
|
||||||
|
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||||
|
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||||
|
],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1856,13 +1914,16 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
}
|
}
|
||||||
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
|
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
|
||||||
};
|
};
|
||||||
let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
|
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
|
||||||
let llvm_type = get_scalar_type(self.context, data.type_());
|
let llvm_type = get_scalar_type(self.context, data.type_());
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_().into(),
|
&data.type_().into(),
|
||||||
vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
|
vec![
|
||||||
|
(self.resolver.value(arguments.src1)?, llvm_type),
|
||||||
|
(self.resolver.value(arguments.src2)?, llvm_type),
|
||||||
|
],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1872,15 +1933,24 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
data: ptx_parser::ArithFloat,
|
data: ptx_parser::ArithFloat,
|
||||||
arguments: ptx_parser::FmaArgs<SpirvWord>,
|
arguments: ptx_parser::FmaArgs<SpirvWord>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
|
let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
|
||||||
self.emit_intrinsic(
|
self.emit_intrinsic(
|
||||||
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
|
||||||
Some(arguments.dst),
|
Some(arguments.dst),
|
||||||
&data.type_.into(),
|
&data.type_.into(),
|
||||||
vec![
|
vec![
|
||||||
(arguments.src1, get_scalar_type(self.context, data.type_)),
|
(
|
||||||
(arguments.src2, get_scalar_type(self.context, data.type_)),
|
self.resolver.value(arguments.src1)?,
|
||||||
(arguments.src3, get_scalar_type(self.context, data.type_)),
|
get_scalar_type(self.context, data.type_),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
self.resolver.value(arguments.src2)?,
|
||||||
|
get_scalar_type(self.context, data.type_),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
self.resolver.value(arguments.src3)?,
|
||||||
|
get_scalar_type(self.context, data.type_),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -2238,9 +2308,9 @@ impl ResolveIdent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ScalarTypeInLLVM(ast::ScalarType);
|
struct LLVMTypeDisplay(ast::ScalarType);
|
||||||
|
|
||||||
impl std::fmt::Display for ScalarTypeInLLVM {
|
impl std::fmt::Display for LLVMTypeDisplay {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self.0 {
|
match self.0 {
|
||||||
ast::ScalarType::Pred => write!(f, "i1"),
|
ast::ScalarType::Pred => write!(f, "i1"),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue