diff --git a/ptx/src/pass/insert_ftz_control/mod.rs b/ptx/src/pass/insert_ftz_control/mod.rs index 5aab249..00c2f86 100644 --- a/ptx/src/pass/insert_ftz_control/mod.rs +++ b/ptx/src/pass/insert_ftz_control/mod.rs @@ -734,7 +734,9 @@ pub(crate) fn run<'input>( } Statement::RetValue(..) | Statement::Instruction(ast::Instruction::Ret { .. }) => { - bb_state.record_ret(*name)?; + if !is_kernel { + bb_state.record_ret(*name)?; + } } Statement::Label(label) => { bb_state.start(*label); @@ -808,7 +810,7 @@ pub(crate) fn run<'input>( )?; let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?; */ - let directives = insert_mode_control(directives, temp)?; + let directives = insert_mode_control(flat_resolver, directives, temp)?; Ok(directives) } @@ -1018,47 +1020,48 @@ struct TwinMode { } fn insert_mode_control( + flat_resolver: &mut super::GlobalStringIdentResolver2, directives: Vec, SpirvWord>>, global_modes: FullModeInsertion2, ) -> Result, SpirvWord>>, TranslateError> { let directives_len = directives.len(); directives .into_iter() - .map(|mut directive| { + .map(|directive| { let mut new_directives = SmallVec::<[_; 4]>::new(); - let (fn_name, initial_mode, body_ptr) = match directive { + let (mut method, initial_mode) = match directive { Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => { new_directives.push(directive); return Ok(new_directives); } - Directive2::Method(Function2 { - name, - body: Some(ref mut body), - ref mut flush_to_zero_f32, - ref mut flush_to_zero_f16f64, - ref mut rounding_mode_f32, - ref mut rounding_mode_f16f64, - .. - }) => { + Directive2::Method( + mut method @ Function2 { + name, + body: Some(_), + .. + }, + ) => { let initial_mode = global_modes .basic_blocks .get(&name) .ok_or_else(error_unreachable)?; let denormal_mode = initial_mode.denormal.twin_mode; let rounding_mode = initial_mode.rounding.twin_mode; - *flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz(); - *flush_to_zero_f16f64 = + method.flush_to_zero_f32 = + denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz(); + method.flush_to_zero_f16f64 = denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz(); - *rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast(); - *rounding_mode_f16f64 = + method.rounding_mode_f32 = + rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast(); + method.rounding_mode_f16f64 = rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast(); - (name, initial_mode, body) + (method, initial_mode) } }; - emit_mode_prelude(fn_name, &mut new_directives); - let old_body = mem::replace(body_ptr, Vec::new()); + emit_mode_prelude(flat_resolver, &method, &global_modes, &mut new_directives)?; + let old_body = method.body.take().unwrap(); let mut result = Vec::with_capacity(old_body.len()); - let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode); + let mut bb_state = BasicBlockControlState::new(&global_modes, initial_mode); let mut old_body = old_body.into_iter(); while let Some(mut statement) = old_body.next() { let mut call_target = None; @@ -1115,8 +1118,8 @@ fn insert_mode_control( } } } - *body_ptr = result; - new_directives.push(directive); + method.body = Some(result); + new_directives.push(Directive2::Method(method)); Ok(new_directives) }) .try_fold(Vec::with_capacity(directives_len), |mut acc, d| { @@ -1126,35 +1129,219 @@ fn insert_mode_control( } fn emit_mode_prelude( - fn_name: SpirvWord, - global_modes: FullModeInsertion2, + flat_resolver: &mut super::GlobalStringIdentResolver2, + method: &Function2, SpirvWord>, + global_modes: &FullModeInsertion2, new_directives: &mut SmallVec<[Directive2, SpirvWord>; 4]>, ) -> Result<(), TranslateError> { - let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?; + let fn_mode_state = global_modes + .basic_blocks + .get(&method.name) + .ok_or_else(error_unreachable)?; if let Some(dual_prologue) = fn_mode_state.dual_prologue { - new_directives.push(Directive2::Method( - Function2 { - return_arguments: todo!(), - name: dual_prologue, - input_arguments: todo!(), - body: todo!(), - is_kernel: false, - import_as: None, - tuning: Vec::new(), - linkage: ast::LinkingDirective::NONE, - flush_to_zero_f32: false, - flush_to_zero_f16f64: false, - rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, - rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, - } + new_directives.push(create_fn_wrapper( + flat_resolver, + method, + dual_prologue, + [ + ModeRegister::Denormal { + f32: fn_mode_state + .denormal + .twin_mode + .f32 + .unwrap_of_default() + .to_ftz(), + f16f64: fn_mode_state + .denormal + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ftz(), + }, + ModeRegister::Rounding { + f32: fn_mode_state + .rounding + .twin_mode + .f32 + .unwrap_of_default() + .to_ast(), + f16f64: fn_mode_state + .rounding + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ast(), + }, + ] + .into_iter(), )); } if let Some(prologue) = fn_mode_state.denormal.prologue { - todo!() + new_directives.push(create_fn_wrapper( + flat_resolver, + method, + prologue, + [ModeRegister::Denormal { + f32: fn_mode_state + .denormal + .twin_mode + .f32 + .unwrap_of_default() + .to_ftz(), + f16f64: fn_mode_state + .denormal + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ftz(), + }] + .into_iter(), + )); } if let Some(prologue) = fn_mode_state.rounding.prologue { - todo!() + new_directives.push(create_fn_wrapper( + flat_resolver, + method, + prologue, + [ModeRegister::Rounding { + f32: fn_mode_state + .rounding + .twin_mode + .f32 + .unwrap_of_default() + .to_ast(), + f16f64: fn_mode_state + .rounding + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ast(), + }] + .into_iter(), + )); } + Ok(()) +} + +fn create_fn_wrapper( + flat_resolver: &mut super::GlobalStringIdentResolver2, + method: &Function2, SpirvWord>, + name: SpirvWord, + modes: impl ExactSizeIterator, +) -> Directive2, SpirvWord> { + // * Label + // * return argument registers + // * input argument registers + // * Load input arguments + // * set modes + // * call + // * return with value + let return_arguments = rename_variables(flat_resolver, &method.return_arguments); + let input_arguments = rename_variables(flat_resolver, &method.input_arguments); + let mut body = Vec::with_capacity( + 1 + (input_arguments.len() * 2) + return_arguments.len() + modes.len() + 2, + ); + body.push(Statement::Label(flat_resolver.register_unnamed(None))); + let return_variables = append_variables(flat_resolver, &mut body, &return_arguments); + let input_variables = append_variables(flat_resolver, &mut body, &input_arguments); + for (index, input_reg) in input_variables.iter().enumerate() { + body.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: input_arguments[index].state_space, + caching: ast::LdCacheOperator::Cached, + typ: input_arguments[index].v_type.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + src: input_arguments[index].name, + dst: *input_reg, + }, + })); + } + body.extend(modes.map(|mode_set| Statement::SetMode(mode_set))); + // Out of order because we want to use return_variables before they are moved + let ret_statement = if return_arguments.is_empty() { + Statement::Instruction(ast::Instruction::Ret { + data: ast::RetData { uniform: false }, + }) + } else { + Statement::RetValue( + ast::RetData { uniform: false }, + return_variables + .iter() + .enumerate() + .map(|(index, var)| (*var, method.return_arguments[index].v_type.clone())) + .collect(), + ) + }; + body.push(Statement::Instruction(ast::Instruction::Call { + data: ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|arg| (arg.v_type.clone(), arg.state_space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|arg| (arg.v_type.clone(), arg.state_space)) + .collect(), + }, + arguments: ast::CallArgs { + return_arguments: return_variables, + func: method.name, + input_arguments: input_variables, + }, + })); + body.push(ret_statement); + Directive2::Method(Function2 { + return_arguments, + name, + input_arguments, + body: Some(body), + is_kernel: false, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::NONE, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + }) +} + +fn rename_variables( + flat_resolver: &mut super::GlobalStringIdentResolver2, + variables: &Vec>, +) -> Vec> { + variables + .iter() + .cloned() + .map(|arg| ast::Variable { + name: flat_resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))), + ..arg + }) + .collect() +} + +fn append_variables<'a, 'input: 'a>( + flat_resolver: &'a mut super::GlobalStringIdentResolver2<'input>, + body: &mut Vec, SpirvWord>>, + arguments: &'a Vec>, +) -> Vec { + let mut result = Vec::with_capacity(arguments.len()); + for arg in arguments { + let name = flat_resolver.register_unnamed(Some((arg.v_type.clone(), ast::StateSpace::Reg))); + body.push(Statement::Variable(ast::Variable { + align: None, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name, + array_init: Vec::new(), + })); + result.push(name); + } + result } struct BasicBlockControlState<'a> { @@ -1163,7 +1350,6 @@ struct BasicBlockControlState<'a> { denormal_f16f64: RegisterState, rounding_f32: RegisterState, rounding_f16f64: RegisterState, - current_bb: SpirvWord, } #[derive(Clone, Copy)] @@ -1175,20 +1361,6 @@ struct RegisterState { } impl RegisterState { - fn single(t: T) -> Self { - RegisterState { - last_foldable: None, - current_value: Resolved::Value(t), - } - } - - fn conflict() -> Self { - RegisterState { - last_foldable: None, - current_value: Resolved::Conflict, - } - } - fn new(value: Resolved) -> RegisterState where U: Into, @@ -1201,11 +1373,7 @@ impl RegisterState { } impl<'a> BasicBlockControlState<'a> { - fn new( - global_modes: &'a FullModeInsertion2, - current_bb: SpirvWord, - initial_mode: &FullBasicBlockEntryState, - ) -> Self { + fn new(global_modes: &'a FullModeInsertion2, initial_mode: &FullBasicBlockEntryState) -> Self { let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32); let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64); let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32); @@ -1216,7 +1384,6 @@ impl<'a> BasicBlockControlState<'a> { denormal_f16f64, rounding_f32, rounding_f16f64, - current_bb, } } @@ -1405,12 +1572,6 @@ fn redirect_jump_impl( Ok(()) } -struct ModeJumpTargets { - dual_prologue: Option, - denormal: Option, - rounding: Option, -} - #[derive(Copy, Clone)] enum Resolved { Conflict, @@ -1457,13 +1618,6 @@ impl Resolved { Resolved::Conflict => Err(err()), } } - - fn has_value(&self) -> bool { - match self { - Resolved::Value(_) => true, - Resolved::Conflict => false, - } - } } trait ModeView { diff --git a/ptx/src/pass/insert_ftz_control/test.rs b/ptx/src/pass/insert_ftz_control/test.rs index 05c1fc8..ef59495 100644 --- a/ptx/src/pass/insert_ftz_control/test.rs +++ b/ptx/src/pass/insert_ftz_control/test.rs @@ -258,30 +258,83 @@ fn call_with_mode() { )); let [to_fn0] = calls(method_1); let [_, dual_prelude, _, _, add] = labels(method_1); - let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1); + let [post_call, post_prelude_dual, post_prelude_denormal, post_prelude_rounding] = + branches(method_1); assert_eq!(methods[0].name, to_fn0); assert_eq!(post_call, dual_prelude); - assert_eq!(post_prelude_0, add); - assert_eq!(post_prelude_1, add); - assert_eq!(post_prelude_2, add); + assert_eq!(post_prelude_dual, add); + assert_eq!(post_prelude_denormal, add); + assert_eq!(post_prelude_rounding, add); let method_2 = methods[2].body.as_ref().unwrap(); assert!(matches!( &**method_2, [ Statement::Label(..), + Statement::Variable(..), + Statement::Variable(..), + Statement::Conditional(..), + Statement::Label(..), + Statement::Conditional(..), + Statement::Label(..), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + // Dual prelude Statement::SetMode(ModeRegister::Denormal { - f32: true, + f32: false, f16f64: true }), Statement::SetMode(ModeRegister::Rounding { - f32: ast::RoundingMode::PositiveInf, + f32: ast::RoundingMode::NegativeInf, f16f64: ast::RoundingMode::NearestEven }), - Statement::Instruction(ast::Instruction::Call { .. }), + Statement::Instruction(ast::Instruction::Bra { .. }), + // Denormal prelude + Statement::Label(..), + Statement::SetMode(ModeRegister::Denormal { + f32: false, + f16f64: true + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + // Rounding prelude + Statement::Label(..), + Statement::SetMode(ModeRegister::Rounding { + f32: ast::RoundingMode::NegativeInf, + f16f64: ast::RoundingMode::NearestEven + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + Statement::SetMode(ModeRegister::Denormal { + f32: false, + f16f64: true + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), Statement::Instruction(ast::Instruction::Ret { .. }), ] )); + let [(if_rm_true, if_rm_false), (if_rz_true, if_rz_false)] = conditionals(method_2); + let [_, conditional2, post_conditional2, prelude_dual, _, _, add1, add2_set_denormal, add2, ret] = + labels(method_2); + let [post_conditional2_jump, post_prelude_dual, post_prelude_denormal, post_prelude_rounding, post_add1, post_add2_set_denormal, post_add2] = + branches(method_2); + assert_eq!(if_rm_true, prelude_dual); + assert_eq!(if_rm_false, conditional2); + assert_eq!(if_rz_true, post_conditional2); + assert_eq!(if_rz_false, add2_set_denormal); + assert_eq!(post_conditional2_jump, prelude_dual); + assert_eq!(post_prelude_dual, add1); + assert_eq!(post_prelude_denormal, add1); + assert_eq!(post_prelude_rounding, add1); + assert_eq!(post_add1, ret); + assert_eq!(post_add2_set_denormal, add2); + assert_eq!(post_add2, ret); } fn branches( @@ -303,10 +356,12 @@ fn labels( fn_: &Vec, SpirvWord>>, ) -> [SpirvWord; N] { fn_.iter() - .filter_map(|s: &Statement, SpirvWord>| match s { - Statement::Label(label) => Some(*label), - _ => None, - }) + .filter_map( + |s: &Statement, SpirvWord>| match s { + Statement::Label(label) => Some(*label), + _ => None, + }, + ) .collect::>() .try_into() .unwrap() @@ -317,7 +372,25 @@ fn calls( ) -> [SpirvWord; N] { fn_.iter() .filter_map(|s| match s { - Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func), + Statement::Instruction(ast::Instruction::Call { + arguments: ast::CallArgs { func, .. }, + .. + }) => Some(*func), + _ => None, + }) + .collect::>() + .try_into() + .unwrap() +} + +fn conditionals( + fn_: &Vec, SpirvWord>>, +) -> [(SpirvWord, SpirvWord); N] { + fn_.iter() + .filter_map(|s| match s { + Statement::Conditional(BrachCondition { + if_true, if_false, .. + }) => Some((*if_true, *if_false)), _ => None, }) .collect::>() diff --git a/ptx/src/pass/normalize_basic_blocks.rs b/ptx/src/pass/normalize_basic_blocks.rs index f48e659..b81e9b3 100644 --- a/ptx/src/pass/normalize_basic_blocks.rs +++ b/ptx/src/pass/normalize_basic_blocks.rs @@ -7,7 +7,7 @@ use super::*; // represent kernels as separate nodes with its own separate entry/exit mode // * Inserts label at the start of every basic block // * Insert explicit jumps before labels -// * Functions get a single `ret;` exit point - this is because mode computation +// * Non-.entry methods get a single `ret;` exit point - this is because mode computation // logic requires it. Control flow graph constructed by mode computation // models function calls as jumps into and then from another function. // If this cfg allowed multiple return basic blocks then there would be cases @@ -19,10 +19,10 @@ pub(crate) fn run( mut directives: Vec, SpirvWord>>, ) -> Result, SpirvWord>>, TranslateError> { for directive in directives.iter_mut() { - let body_ref = match directive { + let (body_ref, is_kernel) = match directive { Directive2::Method(Function2 { - body: Some(body), .. - }) => body, + body: Some(body), is_kernel, .. + }) => (body, *is_kernel), _ => continue, }; let body = std::mem::replace(body_ref, Vec::new()); @@ -74,7 +74,9 @@ pub(crate) fn run( return Err(error_unreachable()); } Statement::Instruction(ast::Instruction::Ret { .. }) => { - return_statements.push(result.len()) + if !is_kernel { + return_statements.push(result.len()); + } } _ => {} }