Apply computed denormal modes to basic blocks

This commit is contained in:
Andrzej Janik 2025-02-22 01:31:32 +00:00
parent aaa31da026
commit 82ca92c5c3
14 changed files with 626 additions and 123 deletions

View file

@ -95,16 +95,7 @@ fn run_method<'input>(
Ok::<_, TranslateError>(body)
})
.transpose()?;
Ok(Function2 {
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(

View file

@ -243,6 +243,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
if !method.is_kernel {
self.resolver.register(method.name, fn_);
self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic");
self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic");
} else {
self.emit_fn_attribute(
fn_,
"denormal-fp-math-f32",
llvm_ftz(method.flush_to_zero_f32),
);
self.emit_fn_attribute(
fn_,
"denormal-fp-math",
llvm_ftz(method.flush_to_zero_f16f64),
);
}
for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
@ -413,6 +426,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
}
fn llvm_ftz(ftz: bool) -> &'static str {
if ftz {
"preserve-sign"
} else {
"ieee"
}
}
fn get_input_argument_type(
context: LLVMContextRef,
v_type: &ast::Type,
@ -469,6 +490,7 @@ impl<'a> MethodEmitContext<'a> {
Statement::FunctionPointer(_) => todo!(),
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?,
})
}
@ -1124,7 +1146,7 @@ impl<'a> MethodEmitContext<'a> {
let cos = self.emit_intrinsic(
c"llvm.cos.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
Some(&ast::ScalarType::F32.into()),
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
@ -1377,7 +1399,7 @@ impl<'a> MethodEmitContext<'a> {
let sin = self.emit_intrinsic(
c"llvm.sin.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
Some(&ast::ScalarType::F32.into()),
vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
@ -1388,12 +1410,12 @@ impl<'a> MethodEmitContext<'a> {
&mut self,
name: &CStr,
dst: Option<SpirvWord>,
return_type: &ast::Type,
return_type: Option<&ast::Type>,
arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
) -> Result<LLVMValueRef, TranslateError> {
let fn_type = get_function_type(
self.context,
iter::once(return_type),
return_type.into_iter(),
arguments.iter().map(|(_, type_)| Ok(*type_)),
)?;
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
@ -1612,7 +1634,7 @@ impl<'a> MethodEmitContext<'a> {
let clamped = self.emit_intrinsic(
c"llvm.umin",
None,
&from.into(),
Some(&from.into()),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(max, from_llvm),
@ -1642,7 +1664,7 @@ impl<'a> MethodEmitContext<'a> {
let zero_clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
None,
&from.into(),
Some(&from.into()),
vec![
(self.resolver.value(arguments.src)?, from_llvm),
(zero, from_llvm),
@ -1661,7 +1683,7 @@ impl<'a> MethodEmitContext<'a> {
let fully_clamped = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
None,
&from.into(),
Some(&from.into()),
vec![(zero_clamped, from_llvm), (max, from_llvm)],
)?;
let resize_fn = if to.layout().size() >= from.layout().size() {
@ -1701,7 +1723,7 @@ impl<'a> MethodEmitContext<'a> {
let rounded_float = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
None,
&from.into(),
Some(&from.into()),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, from),
@ -1770,7 +1792,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
@ -1791,7 +1813,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
@ -1813,7 +1835,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
@ -1935,7 +1957,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, data.type_),
@ -1952,7 +1974,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
c"llvm.amdgcn.log.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
Some(&ast::ScalarType::F32.into()),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32.into()),
@ -2007,7 +2029,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
intrinsic,
Some(arguments.dst),
&type_.into(),
Some(&type_.into()),
vec![(self.resolver.value(arguments.src)?, llvm_type)],
)?;
Ok(())
@ -2031,7 +2053,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
Some(&data.type_().into()),
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
@ -2058,7 +2080,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
Some(&data.type_().into()),
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
@ -2076,7 +2098,7 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
vec![
(
self.resolver.value(arguments.src1)?,
@ -2197,12 +2219,49 @@ impl<'a> MethodEmitContext<'a> {
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
Some(&data.type_.into()),
intrinsic_arguments,
)?;
Ok(())
}
fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> {
let intrinsic = c"llvm.amdgcn.s.setreg";
let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
let (hwreg, value) = match mode_reg {
ModeRegister::DenormalF32(ftz) => {
let (reg, offset, size) = (1, 4, 2u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
(hwreg, if ftz { 0u32 } else { 3 })
}
ModeRegister::DenormalF16F64(ftz) => {
let (reg, offset, size) = (1, 6, 2u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
(hwreg, if ftz { 0 } else { 3 })
}
ModeRegister::DenormalBoth { f32, f16f64 } => {
let (reg, offset, size) = (1, 4, 4u32);
let hwreg = reg | (offset << 6) | ((size - 1) << 11);
let f32 = if f32 { 0 } else { 3 };
let f16f64 = if f16f64 { 0 } else { 3 };
let value = f32 | f16f64 << 2;
(hwreg, value)
}
ModeRegister::RoundingF32(rounding_mode) => todo!(),
ModeRegister::RoundingF16F64(rounding_mode) => todo!(),
ModeRegister::RoundingBoth { f32, f16f64 } => todo!(),
};
let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };
self.emit_intrinsic(
intrinsic,
None,
None,
vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)],
)?;
Ok(())
}
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19

View file

@ -41,14 +41,18 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
roundind_mode_f32: method.roundind_mode_f32,
roundind_mode_f16f64: method.roundind_mode_f16f64,
})
}

View file

@ -20,6 +20,10 @@ pub(super) fn run<'a, 'input>(
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}));
sreg_to_function.insert(sreg, name);
},
@ -60,16 +64,7 @@ fn run_method<'a, 'input>(
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(

View file

@ -64,16 +64,7 @@ fn run_method<'a, 'input>(
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
Ok(Function2 { body, ..method })
}
fn run_statement<'a, 'input>(

View file

@ -1,6 +1,7 @@
use super::BrachCondition;
use super::Directive2;
use super::Function2;
use super::ModeRegister;
use super::SpirvWord;
use super::Statement;
use super::TranslateError;
@ -18,6 +19,7 @@ use rustc_hash::FxHashSet;
use std::hash::Hash;
use std::iter;
use std::mem;
use std::u32;
use strum::EnumCount;
use strum_macros::{EnumCount, VariantArray};
@ -36,6 +38,13 @@ impl DenormalMode {
DenormalMode::Preserve
}
}
fn to_ftz(self) -> bool {
match self {
DenormalMode::FlushToZero => true,
DenormalMode::Preserve => false,
}
}
}
impl Into<usize> for DenormalMode {
@ -94,20 +103,19 @@ impl InstructionModes {
_ => {}
}
}
fn set_if_some<T: Copy>(source: &mut Option<T>, value: Option<T>) {
match (source, value) {
(Some(ref mut x), Some(y)) => *x = y,
_ => {}
fn set_if_any<T: Copy>(source: &mut Option<T>, value: Option<T>) {
if let Some(x) = value {
*source = Some(x);
}
}
set_if_none(&mut entry.denormal_f32, self.denormal_f32);
set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64);
set_if_none(&mut entry.rounding_f32, self.rounding_f32);
set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64);
set_if_some(&mut exit.denormal_f32, self.denormal_f32);
set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64);
set_if_some(&mut exit.rounding_f32, self.rounding_f32);
set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64);
set_if_any(&mut exit.denormal_f32, self.denormal_f32);
set_if_any(&mut exit.denormal_f16f64, self.denormal_f16f64);
set_if_any(&mut exit.rounding_f32, self.rounding_f32);
set_if_any(&mut exit.rounding_f16f64, self.rounding_f16f64);
}
fn none() -> Self {
@ -209,18 +217,12 @@ impl InstructionModes {
flush_to_zero.map(DenormalMode::from_ftz),
Some(RoundingMode::from_ast(rounding)),
),
ast::CvtMode::SignedFromFP {
flush_to_zero,
rounding,
// float to int contains rounding field, but it's not a rounding
// mode but rather round-to-int operation that will be applied
ast::CvtMode::SignedFromFP { flush_to_zero, .. }
| ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => {
Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None)
}
| ast::CvtMode::UnsignedFromFP {
flush_to_zero,
rounding,
} => Self::new(
cvt.from,
flush_to_zero.map(DenormalMode::from_ftz),
Some(RoundingMode::from_ast(rounding)),
),
ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => {
Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd)))
}
@ -263,22 +265,15 @@ impl ControlFlowGraph {
}
fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) {
self.graph[node].denormal_f32 = Mode {
entry: entry.denormal_f32.map(ExtendedMode::BasicBlock),
exit: exit.denormal_f32.map(ExtendedMode::BasicBlock),
};
self.graph[node].denormal_f16f64 = Mode {
entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock),
exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock),
};
self.graph[node].rounding_f32 = Mode {
entry: entry.rounding_f32.map(ExtendedMode::BasicBlock),
exit: exit.rounding_f32.map(ExtendedMode::BasicBlock),
};
self.graph[node].rounding_f16f64 = Mode {
entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock),
exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock),
};
let node = &mut self.graph[node];
node.denormal_f32.entry = entry.denormal_f32.map(ExtendedMode::BasicBlock);
node.denormal_f16f64.entry = entry.denormal_f16f64.map(ExtendedMode::BasicBlock);
node.rounding_f32.entry = entry.rounding_f32.map(ExtendedMode::BasicBlock);
node.rounding_f16f64.entry = entry.rounding_f16f64.map(ExtendedMode::BasicBlock);
node.denormal_f32.exit = exit.denormal_f32.map(ExtendedMode::BasicBlock);
node.denormal_f16f64.exit = exit.denormal_f16f64.map(ExtendedMode::BasicBlock);
node.rounding_f32.exit = exit.rounding_f32.map(ExtendedMode::BasicBlock);
node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock);
}
}
@ -343,7 +338,7 @@ trait EnumTuple {
pub(crate) fn run<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
mut directives: Vec<super::Directive2<ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut cfg = ControlFlowGraph::new();
for directive in directives.iter() {
@ -351,42 +346,39 @@ pub(crate) fn run<'input>(
super::Directive2::Method(Function2 {
name,
body: Some(body),
is_kernel,
..
}) => {
let mut basic_block = Some(cfg.add_entry_basic_block(*name));
let mut entry = InstructionModes::none();
let mut exit = InstructionModes::none();
// TODO: implement for non-kernels
if !*is_kernel {
todo!()
}
let entry_index = cfg.add_entry_basic_block(*name);
let mut bb_state = BasicBlockState::new(&mut cfg);
let mut body_iter = body.iter();
match body_iter.next() {
Some(Statement::Label(label)) => {
bb_state.cfg.add_jump(entry_index, *label);
bb_state.start(*label);
}
_ => return Err(error_unreachable()),
};
for statement in body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Bra { arguments }) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, arguments.src);
cfg.set_modes(
bb_index,
mem::replace(&mut entry, InstructionModes::none()),
mem::replace(&mut exit, InstructionModes::none()),
);
basic_block = None;
bb_state.end(&[arguments.src]);
}
Statement::Label(label) => {
basic_block = Some(cfg.get_or_add_basic_block(*label));
bb_state.start(*label);
}
Statement::Conditional(BrachCondition {
if_true, if_false, ..
}) => {
let bb_index = basic_block.ok_or_else(error_unreachable)?;
cfg.add_jump(bb_index, *if_true);
cfg.add_jump(bb_index, *if_false);
cfg.set_modes(
bb_index,
mem::replace(&mut entry, InstructionModes::none()),
mem::replace(&mut exit, InstructionModes::none()),
);
basic_block = None;
bb_state.end(&[*if_true, *if_false]);
}
Statement::Instruction(instruction) => {
let modes = get_modes(instruction);
modes.fold_into(&mut entry, &mut exit);
bb_state.append(modes);
}
_ => {}
}
@ -395,7 +387,370 @@ pub(crate) fn run<'input>(
_ => {}
}
}
todo!()
let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32);
let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64);
let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32);
let rounding_f16f64 = compute_single_mode(&cfg, |node| node.rounding_f16f64);
let denormal_f32 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f32);
let denormal_f16f64 = optimize::<DenormalMode, { DenormalMode::COUNT }>(denormal_f16f64);
let rounding_f32 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f32);
let rounding_f16f64 = optimize::<RoundingMode, { RoundingMode::COUNT }>(rounding_f16f64);
insert_mode_control(
flat_resolver,
&mut directives,
&cfg,
denormal_f32,
denormal_f16f64,
rounding_f32,
rounding_f16f64,
)?;
Ok(directives)
}
fn insert_mode_control<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
directives: &mut [Directive2<ast::Instruction<SpirvWord>, SpirvWord>],
cfg: &ControlFlowGraph,
denormal_f32: ModeInsertions<DenormalMode>,
denormal_f16f64: ModeInsertions<DenormalMode>,
rounding_f32: ModeInsertions<RoundingMode>,
rounding_f16f64: ModeInsertions<RoundingMode>,
) -> Result<(), TranslateError> {
for directive in directives.iter_mut() {
let body_ptr = match directive {
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => continue,
Directive2::Method(Function2 {
name,
body: Some(body),
flush_to_zero_f32,
flush_to_zero_f16f64,
roundind_mode_f32: rounding_mode_f32,
roundind_mode_f16f64: rounding_mode_f16f64,
..
}) => {
*flush_to_zero_f32 = denormal_f32
.kernels
.get(name)
.copied()
.unwrap_or(DenormalMode::default())
.to_ftz();
*flush_to_zero_f16f64 = denormal_f16f64
.kernels
.get(name)
.copied()
.unwrap_or(DenormalMode::default())
.to_ftz();
*rounding_mode_f32 = rounding_f32
.kernels
.get(name)
.copied()
.unwrap_or(RoundingMode::default())
.to_ast();
*rounding_mode_f16f64 = rounding_f16f64
.kernels
.get(name)
.copied()
.unwrap_or(RoundingMode::default())
.to_ast();
body
}
};
let mut old_body = mem::replace(body_ptr, Vec::new());
let mut result = Vec::with_capacity(old_body.len());
let mut bb_state = BasicBlockControlState::new(
&denormal_f32,
&denormal_f16f64,
&rounding_f32,
&rounding_f16f64,
);
for statement in old_body.into_iter() {
match &statement {
Statement::Label(label) => {
bb_state.start(*label);
}
Statement::Instruction(instruction) => {
let modes = get_modes(&instruction);
bb_state.insert(&mut result, modes)?;
}
_ => {}
}
result.push(statement);
}
*body_ptr = result;
}
Ok(())
}
struct BasicBlockControlState<'a> {
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
basic_block: SpirvWord,
denormal_f32: RegisterState<bool>,
denormal_f16f64: RegisterState<bool>,
foldable_rounding_f32: Option<usize>,
foldable_rounding_f16f64: Option<usize>,
}
#[derive(Clone, Copy)]
enum RegisterState<T> {
Inherited,
Unknown,
Value(Option<usize>, T),
}
impl<T> RegisterState<T> {
fn empty() -> Self {
Self::Unknown
}
fn new(must_insert: bool) -> Self {
if must_insert {
Self::Unknown
} else {
Self::Inherited
}
}
}
impl<'a> BasicBlockControlState<'a> {
fn new(
global_denormal_f32: &'a ModeInsertions<DenormalMode>,
global_denormal_f16f64: &'a ModeInsertions<DenormalMode>,
global_rounding_f32: &'a ModeInsertions<RoundingMode>,
global_rounding_f16f64: &'a ModeInsertions<RoundingMode>,
) -> Self {
BasicBlockControlState {
global_denormal_f32,
global_denormal_f16f64,
global_rounding_f32,
global_rounding_f16f64,
basic_block: SpirvWord(u32::MAX),
denormal_f32: RegisterState::empty(),
denormal_f16f64: RegisterState::empty(),
foldable_rounding_f32: None,
foldable_rounding_f16f64: None,
}
}
fn start(&mut self, label: SpirvWord) {
self.denormal_f32 =
RegisterState::new(self.global_denormal_f32.basic_blocks.contains(&label));
self.denormal_f32 =
RegisterState::new(self.global_denormal_f16f64.basic_blocks.contains(&label));
self.foldable_rounding_f32 = None;
self.foldable_rounding_f16f64 = None;
self.basic_block = label;
}
fn add_or_fold_mode_set(
&mut self,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
new_mode: bool,
) -> Option<usize> {
// try and fold into the other mode set
if let RegisterState::Value(Some(other_index), other_value) = self.denormal_f16f64 {
if let Some(Statement::SetMode(ModeRegister::DenormalF16F64(_))) =
result.get_mut(other_index)
{
result[other_index] = Statement::SetMode(ModeRegister::DenormalBoth {
f32: new_mode,
f16f64: other_value,
});
return None;
}
}
result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode)));
Some(result.len() - 1)
}
fn insert(
&mut self,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
modes: InstructionModes,
) -> Result<(), TranslateError> {
self.insert_one::<DenormalF32View>(result, modes.denormal_f32.map(DenormalMode::to_ftz))?;
self.insert_one::<DenormalF16F64View>(
result,
modes.denormal_f16f64.map(DenormalMode::to_ftz),
)?;
Ok(())
}
fn insert_one<View: ModeView>(
&mut self,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
mode: Option<View::Value>,
) -> Result<(), TranslateError> {
if let Some(new_mode) = mode {
let register_state = View::get_register(self);
match register_state {
RegisterState::Inherited => {
View::set_register(self, RegisterState::Value(None, new_mode));
}
RegisterState::Unknown => {
View::set_register(
self,
RegisterState::Value(
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
new_mode,
),
);
}
RegisterState::Value(_, old_value) => {
if new_mode == old_value {
return Ok(());
}
View::set_register(
self,
RegisterState::Value(
Some(self.add_or_fold_mode_set2::<View>(result, new_mode)),
new_mode,
),
);
}
}
}
Ok(())
}
// Return the index of the last insertion of SetMode with this mode
fn add_or_fold_mode_set2<View: ModeView>(
&self,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
new_mode: View::Value,
) -> usize {
// try and fold into the other mode set in struction
if let RegisterState::Value(Some(twin_index), _) = View::TwinView::get_register(self) {
if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) {
if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) {
*register_mode = View::new_mode(new_mode, Some(twin_mode));
return twin_index;
}
}
}
result.push(Statement::SetMode(View::new_mode(new_mode, None)));
result.len() - 1
}
}
trait ModeView {
type Value: PartialEq + Eq + Copy + Clone;
type TwinView: ModeView<Value = Self::Value>;
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value>;
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>);
fn new_mode(t: Self::Value, other: Option<Self::Value>) -> ModeRegister;
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value>;
}
struct DenormalF32View;
impl ModeView for DenormalF32View {
type Value = bool;
type TwinView = DenormalF16F64View;
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
bb.denormal_f32
}
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
bb.denormal_f32 = reg;
}
fn new_mode(f32: Self::Value, f16f64: Option<Self::Value>) -> ModeRegister {
match f16f64 {
Some(f16f64) => ModeRegister::DenormalBoth { f32, f16f64 },
None => ModeRegister::DenormalF32(f32),
}
}
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
match reg {
ModeRegister::DenormalF32(value) => Some(*value),
_ => None,
}
}
}
struct DenormalF16F64View;
impl ModeView for DenormalF16F64View {
type Value = bool;
type TwinView = DenormalF32View;
fn get_register(bb: &BasicBlockControlState) -> RegisterState<Self::Value> {
bb.denormal_f16f64
}
fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState<Self::Value>) {
bb.denormal_f16f64 = reg;
}
fn new_mode(f16f64: Self::Value, f32: Option<Self::Value>) -> ModeRegister {
match f32 {
Some(f32) => ModeRegister::DenormalBoth { f16f64, f32 },
None => ModeRegister::DenormalF16F64(f16f64),
}
}
fn get_single_mode(reg: &ModeRegister) -> Option<Self::Value> {
match reg {
ModeRegister::DenormalF16F64(value) => Some(*value),
_ => None,
}
}
}
struct BasicBlockState<'a> {
cfg: &'a mut ControlFlowGraph,
node_index: Option<NodeIndex>,
// If it's a kernel basic block then we don't track entry instruction mode
entry: InstructionModes,
exit: InstructionModes,
}
impl<'a> BasicBlockState<'a> {
fn new(cfg: &'a mut ControlFlowGraph) -> BasicBlockState<'a> {
Self {
cfg,
node_index: None,
entry: InstructionModes::none(),
exit: InstructionModes::none(),
}
}
fn start(&mut self, label: SpirvWord) {
self.end(&[]);
self.node_index = Some(self.cfg.get_or_add_basic_block(label));
}
fn end(&mut self, jumps: &[SpirvWord]) {
let node_index = self.node_index.take();
let node_index = match node_index {
Some(x) => x,
None => return,
};
for target in jumps {
self.cfg.add_jump(node_index, *target);
}
self.cfg.set_modes(
node_index,
mem::replace(&mut self.entry, InstructionModes::none()),
mem::replace(&mut self.exit, InstructionModes::none()),
);
}
fn append(&mut self, modes: InstructionModes) {
modes.fold_into(&mut self.entry, &mut self.exit);
}
}
impl<'a> Drop for BasicBlockState<'a> {
fn drop(&mut self) {
self.end(&[]);
}
}
fn compute_single_mode<T: Copy + Eq>(
@ -424,10 +779,9 @@ fn compute_single_mode<T: Copy + Eq>(
UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming));
let mut visited = FxHashSet::default();
while let Some(current) = to_visit.pop() {
if visited.contains(&current) {
if !visited.insert(current) {
continue;
}
visited.insert(current);
let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit;
match exit_mode {
None => {
@ -462,6 +816,7 @@ fn compute_single_mode<T: Copy + Eq>(
}
}
#[derive(Debug)]
struct PartialModeInsertion<T> {
bb_must_insert_mode: FxHashSet<SpirvWord>,
bb_maybe_insert_mode: FxHashMap<SpirvWord, (T, FxHashSet<SpirvWord>)>,
@ -498,10 +853,11 @@ fn optimize<T: Copy + Into<usize> + strum::VariantArray + std::fmt::Debug, const
}
}
let mut kernels = FxHashMap::default();
for (kernel, modes) in kernel_modes {
'iterate_kernels: for (kernel, modes) in kernel_modes {
for (mode, var) in modes.into_iter().enumerate() {
if solution[var] > 0.5 {
kernels.insert(kernel, T::VARIANTS[mode]);
continue 'iterate_kernels;
}
}
}

View file

@ -19,6 +19,7 @@ mod hoist_globals;
mod insert_explicit_load_store;
mod insert_ftz_control;
mod insert_implicit_conversions2;
mod normalize_basic_blocks;
mod normalize_identifiers2;
mod normalize_predicates2;
mod replace_instructions_with_function_calls;
@ -52,6 +53,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives);
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
@ -197,6 +199,22 @@ enum Statement<I, P: ast::Operand> {
FunctionPointer(FunctionPointerDetails),
VectorRead(VectorRead),
VectorWrite(VectorWrite),
SetMode(ModeRegister),
}
enum ModeRegister {
DenormalF32(bool),
DenormalF16F64(bool),
DenormalBoth {
f32: bool,
f16f64: bool,
},
RoundingF32(ast::RoundingMode),
RoundingF16F64(ast::RoundingMode),
RoundingBoth {
f32: ast::RoundingMode,
f16f64: ast::RoundingMode,
},
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@ -469,6 +487,7 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
let src = visitor.visit_ident(src, None, false, false)?;
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
}
Statement::SetMode(mode_register) => Statement::SetMode(mode_register),
})
}
}
@ -573,6 +592,10 @@ struct Function2<Instruction, Operand: ast::Operand> {
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
linkage: ast::LinkingDirective,
flush_to_zero_f32: bool,
flush_to_zero_f16f64: bool,
roundind_mode_f32: ast::RoundingMode,
roundind_mode_f16f64: ast::RoundingMode,
}
type NormalizedDirective2 = Directive2<

View file

@ -0,0 +1,52 @@
use super::*;
// This pass normalized ptx modules in two ways that makes mode computation pass
// and code emissions passes much simpler:
// * Inserts label at the start of every function
// This makes control flow graph simpler in mode computation block: we can
// represent kernels as separate nodes with its own separate entry/exit mode
// * Inserts label at the start of every basic block
pub(crate) fn run(
flat_resolver: &mut GlobalStringIdentResolver2<'_>,
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>> {
for directive in directives.iter_mut() {
let body_ref = match directive {
Directive2::Method(Function2 {
body: Some(body), ..
}) => body,
_ => continue,
};
let body = std::mem::replace(body_ref, Vec::new());
let mut result = Vec::with_capacity(body.len());
let mut needs_label = false;
let mut body_iterator = body.into_iter();
match body_iterator.next() {
Some(Statement::Label(_)) => {}
Some(statement) => {
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
result.push(statement);
}
None => {}
}
for statement in body_iterator {
if needs_label && !matches!(statement, Statement::Label(..)) {
result.push(Statement::Label(flat_resolver.register_unnamed(None)));
}
needs_label = is_block_terminator(&statement);
result.push(statement);
}
*body_ref = result;
}
directives
}
fn is_block_terminator(instruction: &Statement<ast::Instruction<SpirvWord>, SpirvWord>) -> bool {
match instruction {
Statement::Conditional(..)
| Statement::Instruction(ast::Instruction::Bra { .. })
| Statement::Instruction(ast::Instruction::Ret { .. }) => true,
_ => false,
}
}

View file

@ -52,9 +52,13 @@ fn run_method<'input, 'b>(
input_arguments,
body,
import_as: None,
tuning: method.tuning,
linkage,
is_kernel,
tuning: method.tuning,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
}

View file

@ -36,14 +36,18 @@ fn run_method<'input>(
})
.transpose()?;
Ok(Function2 {
body,
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
flush_to_zero_f32: method.flush_to_zero_f32,
flush_to_zero_f16f64: method.flush_to_zero_f16f64,
roundind_mode_f32: method.roundind_mode_f32,
roundind_mode_f16f64: method.roundind_mode_f16f64,
})
}

View file

@ -21,6 +21,10 @@ pub(super) fn run<'input>(
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
is_kernel: false,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
roundind_mode_f32: ptx_parser::RoundingMode::NearestEven,
roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
})
})
.collect::<Vec<_>>();

View file

@ -40,16 +40,7 @@ fn run_method<'input>(
.collect::<Result<Vec<_>, _>>()
})
.transpose()?;
Ok(Function2 {
return_arguments: method.return_arguments,
name: method.name,
input_arguments: method.input_arguments,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
is_kernel: method.is_kernel,
})
Ok(Function2 { body, ..method })
}
fn run_statement<'input>(

View file

@ -0,0 +1,27 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry malformed_label(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
bra BB0;
// this basic block does not start with a label
ld.u64 temp, [out_addr];
BB0:
ld.u64 temp, [in_addr];
add.u64 temp2, temp, 1;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -186,6 +186,8 @@ test_ptx!(
[0x800000u32, 0xFFFFFF]
);
test_ptx!(malformed_label, [2u64], [3u64]);
test_ptx!(assertfail);
test_ptx!(func_ptr);
test_ptx!(lanemask_lt);