Allow ftz and saturated conversions

This commit is contained in:
Andrzej Janik 2024-09-03 18:11:09 +02:00
commit 3f31069e1b
3 changed files with 33 additions and 22 deletions

View file

@ -2163,9 +2163,6 @@ fn emit_cvt(
builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
} }
ptx_parser::CvtMode::FPExtend { flush_to_zero } => { ptx_parser::CvtMode::FPExtend { flush_to_zero } => {
if flush_to_zero == Some(true) {
todo!()
}
let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
} }
@ -2173,9 +2170,6 @@ fn emit_cvt(
rounding, rounding,
flush_to_zero, flush_to_zero,
} => { } => {
if flush_to_zero == Some(true) {
todo!()
}
let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
emit_rounding_decoration(builder, arg.dst, Some(rounding)); emit_rounding_decoration(builder, arg.dst, Some(rounding));
@ -2234,9 +2228,6 @@ fn emit_cvt(
rounding, rounding,
flush_to_zero, flush_to_zero,
} => { } => {
if flush_to_zero == Some(true) {
todo!()
}
let dest_t: ast::ScalarType = dets.to.into(); let dest_t: ast::ScalarType = dets.to.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
@ -2246,9 +2237,6 @@ fn emit_cvt(
rounding, rounding,
flush_to_zero, flush_to_zero,
} => { } => {
if flush_to_zero == Some(true) {
todo!()
}
let dest_t: ast::ScalarType = dets.to.into(); let dest_t: ast::ScalarType = dets.to.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;

View file

@ -1384,8 +1384,8 @@ impl CvtDetails {
dst: ScalarType, dst: ScalarType,
src: ScalarType, src: ScalarType,
) -> Self { ) -> Self {
if saturate { if saturate && dst.kind() == ScalarKind::Float {
errors.push(PtxError::Todo); errors.push(PtxError::SyntaxError);
} }
// Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
let flush_to_zero = match (dst, src) { let flush_to_zero = match (dst, src) {
@ -1432,6 +1432,18 @@ impl CvtDetails {
}, },
(ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
(ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
(ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
CvtMode::SaturateUnsignedToSigned
}
(ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
CvtMode::SaturateSignedToUnsigned
}
(ScalarKind::Unsigned, ScalarKind::Signed)
| (ScalarKind::Signed, ScalarKind::Unsigned)
if dst.size_of() == src.size_of() =>
{
CvtMode::Bitcast
}
(ScalarKind::Unsigned, ScalarKind::Unsigned) (ScalarKind::Unsigned, ScalarKind::Unsigned)
| (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
Ordering::Less => CvtMode::Truncate, Ordering::Less => CvtMode::Truncate,
@ -1444,6 +1456,7 @@ impl CvtDetails {
} }
} }
}, },
(ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
(_, _) => { (_, _) => {
errors.push(PtxError::SyntaxError); errors.push(PtxError::SyntaxError);
CvtMode::Bitcast CvtMode::Bitcast

View file

@ -289,7 +289,12 @@ pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'
state, state,
input: &input[..], input: &input[..],
}; };
module.parse(parser).ok() let parsing_result = module.parse(parser).ok();
if !errors.is_empty() {
None
} else {
parsing_result
}
} }
pub fn parse_module_checked<'input>( pub fn parse_module_checked<'input>(
@ -314,19 +319,24 @@ pub fn parse_module_checked<'input>(
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors); return Err(errors);
} }
let parse_error = { let parse_result = {
let state = PtxParserState::new(&mut errors); let state = PtxParserState::new(&mut errors);
let parser = PtxParser { let parser = PtxParser {
state, state,
input: &tokens[..], input: &tokens[..],
}; };
match module.parse(parser) { module
Ok(ast) => return Ok(ast), .parse(parser)
Err(err) => PtxError::Parser(err.into_inner()), .map_err(|err| PtxError::Parser(err.into_inner()))
}
}; };
errors.push(parse_error); match parse_result {
Err(errors) Ok(result) if errors.is_empty() => Ok(result),
Ok(_) => Err(errors),
Err(err) => {
errors.push(err);
Err(errors)
}
}
} }
fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> { fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {