Improve parser for ftz pass

This commit is contained in:
Andrzej Janik 2025-02-17 01:12:11 +00:00
commit 17529f951d
2 changed files with 76 additions and 45 deletions

View file

@ -1019,9 +1019,16 @@ pub struct ArithInteger {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct ArithFloat { pub struct ArithFloat {
pub type_: ScalarType, pub type_: ScalarType,
pub rounding: Option<RoundingMode>, pub rounding: RoundingMode,
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub saturate: bool, pub saturate: bool,
// From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add
// Note that an add instruction with an explicit rounding modifier is treated conservatively by
// the code optimizer. An add instruction with no rounding modifier defaults to
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
// instructions on the target device.
pub is_fusable: bool
} }
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]

View file

@ -1906,9 +1906,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f32, type_: f32,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -1921,9 +1922,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f64, type_: f64,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -1940,9 +1942,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f16, type_: f16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -1955,9 +1958,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f16x2, type_: f16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -1970,9 +1974,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: bf16, type_: bf16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -1985,9 +1990,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: bf16x2, type_: bf16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: AddArgs { arguments: AddArgs {
@ -2032,9 +2038,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: f32, type_: f32,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat, saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2045,9 +2052,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: f64, type_: f64,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false, saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2061,9 +2069,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: f16, type_: f16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat, saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2074,9 +2083,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: f16x2, type_: f16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat, saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2087,9 +2097,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: bf16, type_: bf16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false, saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2100,9 +2111,10 @@ derive_parser!(
data: ast::MulDetails::Float ( data: ast::MulDetails::Float (
ast::ArithFloat { ast::ArithFloat {
type_: bf16x2, type_: bf16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false, saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: MulArgs { dst: d, src1: a, src2: b } arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2386,9 +2398,10 @@ derive_parser!(
data: ast::MadDetails::Float( data: ast::MadDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f32, type_: f32,
rounding: None, rounding: ast::RoundingMode::NearestEven,
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: false
} }
), ),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2399,9 +2412,10 @@ derive_parser!(
data: ast::MadDetails::Float( data: ast::MadDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f32, type_: f32,
rounding: Some(rnd.into()), rounding: rnd.into(),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: false
} }
), ),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2412,9 +2426,10 @@ derive_parser!(
data: ast::MadDetails::Float( data: ast::MadDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f64, type_: f64,
rounding: Some(rnd.into()), rounding: rnd.into(),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: false
} }
), ),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2429,9 +2444,10 @@ derive_parser!(
ast::Instruction::Fma { ast::Instruction::Fma {
data: ast::ArithFloat { data: ast::ArithFloat {
type_: f32, type_: f32,
rounding: Some(rnd.into()), rounding: rnd.into(),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: false
}, },
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
} }
@ -2440,9 +2456,10 @@ derive_parser!(
ast::Instruction::Fma { ast::Instruction::Fma {
data: ast::ArithFloat { data: ast::ArithFloat {
type_: f64, type_: f64,
rounding: Some(rnd.into()), rounding: rnd.into(),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: false
}, },
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
} }
@ -2454,9 +2471,10 @@ derive_parser!(
ast::Instruction::Fma { ast::Instruction::Fma {
data: ast::ArithFloat { data: ast::ArithFloat {
type_: f16, type_: f16,
rounding: Some(rnd.into()), rounding: rnd.into(),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: false
}, },
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
} }
@ -2504,9 +2522,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f32, type_: f32,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2517,9 +2536,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f64, type_: f64,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2533,9 +2553,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f16, type_: f16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2546,9 +2567,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: f16x2, type_: f16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz), flush_to_zero: Some(ftz),
saturate: sat saturate: sat,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2559,9 +2581,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: bf16, type_: bf16,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2572,9 +2595,10 @@ derive_parser!(
data: ast::ArithDetails::Float( data: ast::ArithDetails::Float(
ast::ArithFloat { ast::ArithFloat {
type_: bf16x2, type_: bf16x2,
rounding: rnd.map(Into::into), rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None, flush_to_zero: None,
saturate: false saturate: false,
is_fusable: rnd.is_none()
} }
), ),
arguments: SubArgs { dst: d, src1: a, src2: b } arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2877,7 +2901,7 @@ derive_parser!(
rsqrt.approx.f64 d, a => { rsqrt.approx.f64 d, a => {
ast::Instruction::Rsqrt { ast::Instruction::Rsqrt {
data: ast::TypeFtz { data: ast::TypeFtz {
flush_to_zero: None, flush_to_zero: Some(false),
type_: f64 type_: f64
}, },
arguments: RsqrtArgs { dst: d, src: a } arguments: RsqrtArgs { dst: d, src: a }
@ -2886,7 +2910,7 @@ derive_parser!(
rsqrt.approx.ftz.f64 d, a => { rsqrt.approx.ftz.f64 d, a => {
ast::Instruction::Rsqrt { ast::Instruction::Rsqrt {
data: ast::TypeFtz { data: ast::TypeFtz {
flush_to_zero: None, flush_to_zero: Some(true),
type_: f64 type_: f64
}, },
arguments: RsqrtArgs { dst: d, src: a } arguments: RsqrtArgs { dst: d, src: a }