diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 8aa4576..5147b79 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -2163,9 +2163,6 @@ fn emit_cvt( builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; } ptx_parser::CvtMode::FPExtend { flush_to_zero } => { - if flush_to_zero == Some(true) { - todo!() - } let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; } @@ -2173,9 +2170,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); @@ -2234,9 +2228,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; @@ -2246,9 +2237,6 @@ fn emit_cvt( rounding, flush_to_zero, } => { - if flush_to_zero == Some(true) { - todo!() - } let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f0d7f9f..f5e65b4 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1384,8 +1384,8 @@ impl CvtDetails { dst: ScalarType, src: ScalarType, ) -> Self { - if saturate { - errors.push(PtxError::Todo); + if saturate && dst.kind() == ScalarKind::Float { + errors.push(PtxError::SyntaxError); } // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. let flush_to_zero = match (dst, src) { @@ -1432,6 +1432,18 @@ impl CvtDetails { }, (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), + (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => { + CvtMode::SaturateUnsignedToSigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => { + CvtMode::SaturateSignedToUnsigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) + if dst.size_of() == src.size_of() => + { + CvtMode::Bitcast + } (ScalarKind::Unsigned, ScalarKind::Unsigned) | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { Ordering::Less => CvtMode::Truncate, @@ -1444,6 +1456,7 @@ impl CvtDetails { } } }, + (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned, (_, _) => { errors.push(PtxError::SyntaxError); CvtMode::Bitcast diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 357304b..b81d826 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -289,7 +289,12 @@ pub fn parse_module_unchecked<'input>(text: &'input str) -> Option( @@ -314,19 +319,24 @@ pub fn parse_module_checked<'input>( if !errors.is_empty() { return Err(errors); } - let parse_error = { + let parse_result = { let state = PtxParserState::new(&mut errors); let parser = PtxParser { state, input: &tokens[..], }; - match module.parse(parser) { - Ok(ast) => return Ok(ast), - Err(err) => PtxError::Parser(err.into_inner()), - } + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) }; - errors.push(parse_error); - Err(errors) + match parse_result { + Ok(result) if errors.is_empty() => Ok(result), + Ok(_) => Err(errors), + Err(err) => { + errors.push(err); + Err(errors) + } + } } fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> {