Add mode-setting wrappers to functions

This commit is contained in:
Andrzej Janik 2025-03-13 16:49:06 +00:00
parent a0d4b7eeb2
commit 87fe601494
3 changed files with 323 additions and 94 deletions

View file

@ -734,7 +734,9 @@ pub(crate) fn run<'input>(
}
Statement::RetValue(..)
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
bb_state.record_ret(*name)?;
if !is_kernel {
bb_state.record_ret(*name)?;
}
}
Statement::Label(label) => {
bb_state.start(*label);
@ -808,7 +810,7 @@ pub(crate) fn run<'input>(
)?;
let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?;
*/
let directives = insert_mode_control(directives, temp)?;
let directives = insert_mode_control(flat_resolver, directives, temp)?;
Ok(directives)
}
@ -1018,47 +1020,48 @@ struct TwinMode<T> {
}
fn insert_mode_control(
flat_resolver: &mut super::GlobalStringIdentResolver2,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
global_modes: FullModeInsertion2,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let directives_len = directives.len();
directives
.into_iter()
.map(|mut directive| {
.map(|directive| {
let mut new_directives = SmallVec::<[_; 4]>::new();
let (fn_name, initial_mode, body_ptr) = match directive {
let (mut method, initial_mode) = match directive {
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => {
new_directives.push(directive);
return Ok(new_directives);
}
Directive2::Method(Function2 {
name,
body: Some(ref mut body),
ref mut flush_to_zero_f32,
ref mut flush_to_zero_f16f64,
ref mut rounding_mode_f32,
ref mut rounding_mode_f16f64,
..
}) => {
Directive2::Method(
mut method @ Function2 {
name,
body: Some(_),
..
},
) => {
let initial_mode = global_modes
.basic_blocks
.get(&name)
.ok_or_else(error_unreachable)?;
let denormal_mode = initial_mode.denormal.twin_mode;
let rounding_mode = initial_mode.rounding.twin_mode;
*flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
*flush_to_zero_f16f64 =
method.flush_to_zero_f32 =
denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
method.flush_to_zero_f16f64 =
denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz();
*rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
*rounding_mode_f16f64 =
method.rounding_mode_f32 =
rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
method.rounding_mode_f16f64 =
rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast();
(name, initial_mode, body)
(method, initial_mode)
}
};
emit_mode_prelude(fn_name, &mut new_directives);
let old_body = mem::replace(body_ptr, Vec::new());
emit_mode_prelude(flat_resolver, &method, &global_modes, &mut new_directives)?;
let old_body = method.body.take().unwrap();
let mut result = Vec::with_capacity(old_body.len());
let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode);
let mut bb_state = BasicBlockControlState::new(&global_modes, initial_mode);
let mut old_body = old_body.into_iter();
while let Some(mut statement) = old_body.next() {
let mut call_target = None;
@ -1115,8 +1118,8 @@ fn insert_mode_control(
}
}
}
*body_ptr = result;
new_directives.push(directive);
method.body = Some(result);
new_directives.push(Directive2::Method(method));
Ok(new_directives)
})
.try_fold(Vec::with_capacity(directives_len), |mut acc, d| {
@ -1126,35 +1129,219 @@ fn insert_mode_control(
}
fn emit_mode_prelude(
fn_name: SpirvWord,
global_modes: FullModeInsertion2,
flat_resolver: &mut super::GlobalStringIdentResolver2,
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
global_modes: &FullModeInsertion2,
new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>,
) -> Result<(), TranslateError> {
let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?;
let fn_mode_state = global_modes
.basic_blocks
.get(&method.name)
.ok_or_else(error_unreachable)?;
if let Some(dual_prologue) = fn_mode_state.dual_prologue {
new_directives.push(Directive2::Method(
Function2 {
return_arguments: todo!(),
name: dual_prologue,
input_arguments: todo!(),
body: todo!(),
is_kernel: false,
import_as: None,
tuning: Vec::new(),
linkage: ast::LinkingDirective::NONE,
flush_to_zero_f32: false,
flush_to_zero_f16f64: false,
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
}
new_directives.push(create_fn_wrapper(
flat_resolver,
method,
dual_prologue,
[
ModeRegister::Denormal {
f32: fn_mode_state
.denormal
.twin_mode
.f32
.unwrap_of_default()
.to_ftz(),
f16f64: fn_mode_state
.denormal
.twin_mode
.f16f64
.unwrap_of_default()
.to_ftz(),
},
ModeRegister::Rounding {
f32: fn_mode_state
.rounding
.twin_mode
.f32
.unwrap_of_default()
.to_ast(),
f16f64: fn_mode_state
.rounding
.twin_mode
.f16f64
.unwrap_of_default()
.to_ast(),
},
]
.into_iter(),
));
}
if let Some(prologue) = fn_mode_state.denormal.prologue {
todo!()
new_directives.push(create_fn_wrapper(
flat_resolver,
method,
prologue,
[ModeRegister::Denormal {
f32: fn_mode_state
.denormal
.twin_mode
.f32
.unwrap_of_default()
.to_ftz(),
f16f64: fn_mode_state
.denormal
.twin_mode
.f16f64
.unwrap_of_default()
.to_ftz(),
}]
.into_iter(),
));
}
if let Some(prologue) = fn_mode_state.rounding.prologue {
todo!()
new_directives.push(create_fn_wrapper(
flat_resolver,
method,
prologue,
[ModeRegister::Rounding {
f32: fn_mode_state
.rounding
.twin_mode
.f32
.unwrap_of_default()
.to_ast(),
f16f64: fn_mode_state
.rounding
.twin_mode
.f16f64
.unwrap_of_default()
.to_ast(),
}]
.into_iter(),
));
}
Ok(())
}
/// Builds a non-kernel wrapper method called `name` that applies `modes` and
/// then forwards to `method`.
///
/// The wrapper's parameters are copies of `method`'s return and input
/// arguments under fresh identifiers obtained from `flat_resolver`. Its body
/// is emitted in this order:
///
/// 1. an entry label,
/// 2. one register declaration per return argument,
/// 3. one register declaration per input argument,
/// 4. one `ld` per input argument, copying the wrapper parameter into its
///    register,
/// 5. one `Statement::SetMode` per item of `modes`,
/// 6. a call to `method.name` using those registers,
/// 7. a `ret` (a `RetValue` carrying the return registers when `method` has
///    return arguments).
///
/// The returned directive has no tuning directives, `LinkingDirective::NONE`
/// linkage, both flush-to-zero flags cleared and both rounding modes set to
/// `NearestEven`.
fn create_fn_wrapper(
    flat_resolver: &mut super::GlobalStringIdentResolver2,
    method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
    name: SpirvWord,
    modes: impl ExactSizeIterator<Item = ModeRegister>,
) -> Directive2<ast::Instruction<SpirvWord>, SpirvWord> {
    let wrapper_returns = rename_variables(flat_resolver, &method.return_arguments);
    let wrapper_inputs = rename_variables(flat_resolver, &method.input_arguments);
    // label + return registers + (register + load) per input + modes + call + ret
    let statement_count =
        1 + (wrapper_inputs.len() * 2) + wrapper_returns.len() + modes.len() + 2;
    let mut body = Vec::with_capacity(statement_count);
    body.push(Statement::Label(flat_resolver.register_unnamed(None)));
    let return_regs = append_variables(flat_resolver, &mut body, &wrapper_returns);
    let input_regs = append_variables(flat_resolver, &mut body, &wrapper_inputs);
    // `append_variables` yields exactly one register per parameter, so the
    // two sequences line up pairwise.
    for (param, reg) in wrapper_inputs.iter().zip(input_regs.iter()) {
        body.push(Statement::Instruction(ast::Instruction::Ld {
            data: ast::LdDetails {
                qualifier: ast::LdStQualifier::Weak,
                state_space: param.state_space,
                caching: ast::LdCacheOperator::Cached,
                typ: param.v_type.clone(),
                non_coherent: false,
            },
            arguments: ast::LdArgs {
                src: param.name,
                dst: *reg,
            },
        }));
    }
    body.extend(modes.map(Statement::SetMode));
    // The return statement is prepared ahead of the call because the call
    // below takes ownership of `return_regs`.
    let ret_statement = if wrapper_returns.is_empty() {
        Statement::Instruction(ast::Instruction::Ret {
            data: ast::RetData { uniform: false },
        })
    } else {
        Statement::RetValue(
            ast::RetData { uniform: false },
            return_regs
                .iter()
                .zip(method.return_arguments.iter())
                .map(|(reg, original)| (*reg, original.v_type.clone()))
                .collect(),
        )
    };
    let call_details = ast::CallDetails {
        uniform: false,
        return_arguments: wrapper_returns
            .iter()
            .map(|param| (param.v_type.clone(), param.state_space))
            .collect(),
        input_arguments: wrapper_inputs
            .iter()
            .map(|param| (param.v_type.clone(), param.state_space))
            .collect(),
    };
    let call_statement = Statement::Instruction(ast::Instruction::Call {
        data: call_details,
        arguments: ast::CallArgs {
            return_arguments: return_regs,
            func: method.name,
            input_arguments: input_regs,
        },
    });
    body.push(call_statement);
    body.push(ret_statement);
    Directive2::Method(Function2 {
        return_arguments: wrapper_returns,
        name,
        input_arguments: wrapper_inputs,
        body: Some(body),
        is_kernel: false,
        import_as: None,
        tuning: Vec::new(),
        linkage: ast::LinkingDirective::NONE,
        flush_to_zero_f32: false,
        flush_to_zero_f16f64: false,
        rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
        rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
    })
}
/// Returns a copy of `variables` in which every variable carries a fresh
/// identifier.
///
/// Each new identifier is obtained from `flat_resolver.register_unnamed`,
/// registered with the variable's own type and state space. All other fields
/// are cloned unchanged, and the input order is preserved.
///
/// Takes a slice rather than `&Vec` so any contiguous sequence of variables
/// can be passed; existing `&Vec<_>` call sites coerce automatically.
fn rename_variables(
    flat_resolver: &mut super::GlobalStringIdentResolver2,
    variables: &[ast::Variable<SpirvWord>],
) -> Vec<ast::Variable<SpirvWord>> {
    variables
        .iter()
        .map(|variable| ast::Variable {
            name: flat_resolver
                .register_unnamed(Some((variable.v_type.clone(), variable.state_space))),
            ..variable.clone()
        })
        .collect()
}
/// Declares one `.reg` variable per entry of `arguments` and returns their
/// identifiers.
///
/// For every argument, a fresh identifier is registered with `flat_resolver`
/// (using the argument's type and the `Reg` state space) and a matching
/// `Statement::Variable` is pushed onto `body`. The declaration has no
/// alignment and no array initializer. The returned vector holds the new
/// identifiers in the same order as `arguments`.
///
/// The result borrows from none of the inputs, so no explicit lifetimes are
/// needed. `arguments` is a slice; existing `&Vec<_>` call sites coerce
/// automatically.
fn append_variables(
    flat_resolver: &mut super::GlobalStringIdentResolver2,
    body: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
    arguments: &[ast::Variable<SpirvWord>],
) -> Vec<SpirvWord> {
    arguments
        .iter()
        .map(|arg| {
            let name =
                flat_resolver.register_unnamed(Some((arg.v_type.clone(), ast::StateSpace::Reg)));
            body.push(Statement::Variable(ast::Variable {
                align: None,
                v_type: arg.v_type.clone(),
                state_space: ast::StateSpace::Reg,
                name,
                array_init: Vec::new(),
            }));
            name
        })
        .collect()
}
struct BasicBlockControlState<'a> {
@ -1163,7 +1350,6 @@ struct BasicBlockControlState<'a> {
denormal_f16f64: RegisterState<bool>,
rounding_f32: RegisterState<ast::RoundingMode>,
rounding_f16f64: RegisterState<ast::RoundingMode>,
current_bb: SpirvWord,
}
#[derive(Clone, Copy)]
@ -1175,20 +1361,6 @@ struct RegisterState<T> {
}
impl<T> RegisterState<T> {
fn single(t: T) -> Self {
RegisterState {
last_foldable: None,
current_value: Resolved::Value(t),
}
}
fn conflict() -> Self {
RegisterState {
last_foldable: None,
current_value: Resolved::Conflict,
}
}
fn new<U>(value: Resolved<U>) -> RegisterState<T>
where
U: Into<T>,
@ -1201,11 +1373,7 @@ impl<T> RegisterState<T> {
}
impl<'a> BasicBlockControlState<'a> {
fn new(
global_modes: &'a FullModeInsertion2,
current_bb: SpirvWord,
initial_mode: &FullBasicBlockEntryState,
) -> Self {
fn new(global_modes: &'a FullModeInsertion2, initial_mode: &FullBasicBlockEntryState) -> Self {
let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32);
let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64);
let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32);
@ -1216,7 +1384,6 @@ impl<'a> BasicBlockControlState<'a> {
denormal_f16f64,
rounding_f32,
rounding_f16f64,
current_bb,
}
}
@ -1405,12 +1572,6 @@ fn redirect_jump_impl(
Ok(())
}
struct ModeJumpTargets {
dual_prologue: Option<SpirvWord>,
denormal: Option<SpirvWord>,
rounding: Option<SpirvWord>,
}
#[derive(Copy, Clone)]
enum Resolved<T> {
Conflict,
@ -1457,13 +1618,6 @@ impl<T> Resolved<T> {
Resolved::Conflict => Err(err()),
}
}
fn has_value(&self) -> bool {
match self {
Resolved::Value(_) => true,
Resolved::Conflict => false,
}
}
}
trait ModeView {

View file

@ -258,30 +258,83 @@ fn call_with_mode() {
));
let [to_fn0] = calls(method_1);
let [_, dual_prelude, _, _, add] = labels(method_1);
let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1);
let [post_call, post_prelude_dual, post_prelude_denormal, post_prelude_rounding] =
branches(method_1);
assert_eq!(methods[0].name, to_fn0);
assert_eq!(post_call, dual_prelude);
assert_eq!(post_prelude_0, add);
assert_eq!(post_prelude_1, add);
assert_eq!(post_prelude_2, add);
assert_eq!(post_prelude_dual, add);
assert_eq!(post_prelude_denormal, add);
assert_eq!(post_prelude_rounding, add);
let method_2 = methods[2].body.as_ref().unwrap();
assert!(matches!(
&**method_2,
[
Statement::Label(..),
Statement::Variable(..),
Statement::Variable(..),
Statement::Conditional(..),
Statement::Label(..),
Statement::Conditional(..),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
// Dual prelude
Statement::SetMode(ModeRegister::Denormal {
f32: true,
f32: false,
f16f64: true
}),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::PositiveInf,
f32: ast::RoundingMode::NegativeInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Call { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
// Denormal prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: true
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
// Rounding prelude
Statement::Label(..),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::NegativeInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: false,
f16f64: true
}),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Add { .. }),
Statement::Instruction(ast::Instruction::Bra { .. }),
Statement::Label(..),
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));
let [(if_rm_true, if_rm_false), (if_rz_true, if_rz_false)] = conditionals(method_2);
let [_, conditional2, post_conditional2, prelude_dual, _, _, add1, add2_set_denormal, add2, ret] =
labels(method_2);
let [post_conditional2_jump, post_prelude_dual, post_prelude_denormal, post_prelude_rounding, post_add1, post_add2_set_denormal, post_add2] =
branches(method_2);
assert_eq!(if_rm_true, prelude_dual);
assert_eq!(if_rm_false, conditional2);
assert_eq!(if_rz_true, post_conditional2);
assert_eq!(if_rz_false, add2_set_denormal);
assert_eq!(post_conditional2_jump, prelude_dual);
assert_eq!(post_prelude_dual, add1);
assert_eq!(post_prelude_denormal, add1);
assert_eq!(post_prelude_rounding, add1);
assert_eq!(post_add1, ret);
assert_eq!(post_add2_set_denormal, add2);
assert_eq!(post_add2, ret);
}
fn branches<const N: usize>(
@ -303,10 +356,12 @@ fn labels<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [SpirvWord; N] {
fn_.iter()
.filter_map(|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
Statement::Label(label) => Some(*label),
_ => None,
})
.filter_map(
|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
Statement::Label(label) => Some(*label),
_ => None,
},
)
.collect::<Vec<_>>()
.try_into()
.unwrap()
@ -317,7 +372,25 @@ fn calls<const N: usize>(
) -> [SpirvWord; N] {
fn_.iter()
.filter_map(|s| match s {
Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func),
Statement::Instruction(ast::Instruction::Call {
arguments: ast::CallArgs { func, .. },
..
}) => Some(*func),
_ => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
fn conditionals<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [(SpirvWord, SpirvWord); N] {
fn_.iter()
.filter_map(|s| match s {
Statement::Conditional(BrachCondition {
if_true, if_false, ..
}) => Some((*if_true, *if_false)),
_ => None,
})
.collect::<Vec<_>>()

View file

@ -7,7 +7,7 @@ use super::*;
// represent kernels as separate nodes with their own separate entry/exit mode
// * Inserts label at the start of every basic block
// * Insert explicit jumps before labels
// * Functions get a single `ret;` exit point - this is because mode computation
// * Non-.entry methods get a single `ret;` exit point - this is because mode computation
// logic requires it. Control flow graph constructed by mode computation
// models function calls as jumps into and then from another function.
// If this cfg allowed multiple return basic blocks then there would be cases
@ -19,10 +19,10 @@ pub(crate) fn run(
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
for directive in directives.iter_mut() {
let body_ref = match directive {
let (body_ref, is_kernel) = match directive {
Directive2::Method(Function2 {
body: Some(body), ..
}) => body,
body: Some(body), is_kernel, ..
}) => (body, *is_kernel),
_ => continue,
};
let body = std::mem::replace(body_ref, Vec::new());
@ -74,7 +74,9 @@ pub(crate) fn run(
return Err(error_unreachable());
}
Statement::Instruction(ast::Instruction::Ret { .. }) => {
return_statements.push(result.len())
if !is_kernel {
return_statements.push(result.len());
}
}
_ => {}
}