Improve parser for ftz pass

This commit is contained in:
Andrzej Janik 2025-02-17 01:12:11 +00:00
parent 241cf43a52
commit 17529f951d
2 changed files with 76 additions and 45 deletions

View file

@ -1019,9 +1019,16 @@ pub struct ArithInteger {
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub type_: ScalarType,
pub rounding: Option<RoundingMode>,
pub rounding: RoundingMode,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
// From PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions-add
// Note that an add instruction with an explicit rounding modifier is treated conservatively by
// the code optimizer. An add instruction with no rounding modifier defaults to
// round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular,
// mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add
// instructions on the target device.
pub is_fusable: bool
}
#[derive(Copy, Clone, PartialEq, Eq)]

View file

@ -1906,9 +1906,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1921,9 +1922,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1940,9 +1942,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1955,9 +1958,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1970,9 +1974,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -1985,9 +1990,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: AddArgs {
@ -2032,9 +2038,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2045,9 +2052,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2061,9 +2069,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2074,9 +2083,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2087,9 +2097,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2100,9 +2111,10 @@ derive_parser!(
data: ast::MulDetails::Float (
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: MulArgs { dst: d, src1: a, src2: b }
@ -2386,9 +2398,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: None,
rounding: ast::RoundingMode::NearestEven,
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2399,9 +2412,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2412,9 +2426,10 @@ derive_parser!(
data: ast::MadDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
@ -2429,9 +2444,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2440,9 +2456,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2454,9 +2471,10 @@ derive_parser!(
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f16,
rounding: Some(rnd.into()),
rounding: rnd.into(),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
@ -2504,9 +2522,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2517,9 +2536,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2533,9 +2553,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2546,9 +2567,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz),
saturate: sat
saturate: sat,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2559,9 +2581,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2572,9 +2595,10 @@ derive_parser!(
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
rounding: rnd.map(Into::into).unwrap_or(ast::RoundingMode::NearestEven),
flush_to_zero: None,
saturate: false
saturate: false,
is_fusable: rnd.is_none()
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
@ -2877,7 +2901,7 @@ derive_parser!(
rsqrt.approx.f64 d, a => {
ast::Instruction::Rsqrt {
data: ast::TypeFtz {
flush_to_zero: None,
flush_to_zero: Some(false),
type_: f64
},
arguments: RsqrtArgs { dst: d, src: a }
@ -2886,7 +2910,7 @@ derive_parser!(
rsqrt.approx.ftz.f64 d, a => {
ast::Instruction::Rsqrt {
data: ast::TypeFtz {
flush_to_zero: None,
flush_to_zero: Some(true),
type_: f64
},
arguments: RsqrtArgs { dst: d, src: a }