From 17529f951d362a10b1b25bd76f3942b07b8a0084 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 17 Feb 2025 01:12:11 +0000 Subject: [PATCH] Improve parser for ftz pass --- ptx_parser/src/ast.rs | 9 +++- ptx_parser/src/lib.rs | 112 +++++++++++++++++++++++++----------------- 2 files changed, 76 insertions(+), 45 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index c5e8e79..19a2897 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1019,9 +1019,16 @@ pub struct ArithInteger { #[derive(Copy, Clone)] pub struct ArithFloat { pub type_: ScalarType, - pub rounding: Option, + pub rounding: RoundingMode, pub flush_to_zero: Option, pub saturate: bool, + // From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add + // Note that an add instruction with an explicit rounding modifier is treated conservatively by + // the code optimizer. An add instruction with no rounding modifier defaults to + // round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, + // mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add + // instructions on the target device. + pub is_fusable: bool } #[derive(Copy, Clone, PartialEq, Eq)] diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index f2c376d..ca40f63 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1906,9 +1906,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f32, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -1921,9 +1922,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f64, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -1940,9 +1942,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -1955,9 +1958,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -1970,9 +1974,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: bf16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -1985,9 +1990,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: bf16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: AddArgs { @@ -2032,9 +2038,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: f32, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), saturate: sat, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2045,9 +2052,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: f64, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, saturate: false, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2061,9 +2069,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: f16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), saturate: sat, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2074,9 +2083,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: f16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), saturate: sat, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2087,9 +2097,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: bf16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, saturate: false, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2100,9 +2111,10 @@ derive_parser!( data: ast::MulDetails::Float ( ast::ArithFloat { type_: bf16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, saturate: false, + is_fusable: rnd.is_none() } ), arguments: MulArgs { dst: d, src1: a, src2: b } @@ -2386,9 +2398,10 @@ derive_parser!( data: ast::MadDetails::Float( ast::ArithFloat { type_: f32, - rounding: None, + rounding: ast::RoundingMode::NearestEven, flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: false } ), arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } @@ -2399,9 +2412,10 @@ derive_parser!( data: ast::MadDetails::Float( ast::ArithFloat { type_: f32, - rounding: Some(rnd.into()), + rounding: rnd.into(), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: false } ), arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } @@ -2412,9 +2426,10 @@ derive_parser!( data: ast::MadDetails::Float( ast::ArithFloat { type_: f64, - rounding: Some(rnd.into()), + rounding: rnd.into(), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: false } ), arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } @@ -2429,9 +2444,10 @@ derive_parser!( ast::Instruction::Fma { data: ast::ArithFloat { type_: f32, - rounding: Some(rnd.into()), + rounding: rnd.into(), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: false }, arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } } @@ -2440,9 +2456,10 @@ derive_parser!( ast::Instruction::Fma { data: ast::ArithFloat { type_: f64, - rounding: Some(rnd.into()), + rounding: rnd.into(), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: false }, arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } } @@ -2454,9 +2471,10 @@ derive_parser!( ast::Instruction::Fma { data: ast::ArithFloat { type_: f16, - rounding: Some(rnd.into()), + rounding: rnd.into(), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: false }, arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } } @@ -2504,9 +2522,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f32, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2517,9 +2536,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f64, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2533,9 +2553,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2546,9 +2567,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: f16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz), - saturate: sat + saturate: sat, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2559,9 +2581,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: bf16, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2572,9 +2595,10 @@ derive_parser!( data: ast::ArithDetails::Float( ast::ArithFloat { type_: bf16x2, - rounding: rnd.map(Into::into), + rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven), flush_to_zero: None, - saturate: false + saturate: false, + is_fusable: rnd.is_none() } ), arguments: SubArgs { dst: d, src1: a, src2: b } @@ -2877,7 +2901,7 @@ derive_parser!( rsqrt.approx.f64 d, a => { ast::Instruction::Rsqrt { data: ast::TypeFtz { - flush_to_zero: None, + flush_to_zero: Some(false), type_: f64 }, arguments: RsqrtArgs { dst: d, src: a } @@ -2886,7 +2910,7 @@ derive_parser!( rsqrt.approx.ftz.f64 d, a => { ast::Instruction::Rsqrt { data: ast::TypeFtz { - flush_to_zero: None, + flush_to_zero: Some(true), type_: f64 }, arguments: RsqrtArgs { dst: d, src: a }