mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add mode-setting wrappers to functions
This commit is contained in:
parent
a0d4b7eeb2
commit
87fe601494
3 changed files with 323 additions and 94 deletions
|
@ -734,7 +734,9 @@ pub(crate) fn run<'input>(
|
|||
}
|
||||
Statement::RetValue(..)
|
||||
| Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
bb_state.record_ret(*name)?;
|
||||
if !is_kernel {
|
||||
bb_state.record_ret(*name)?;
|
||||
}
|
||||
}
|
||||
Statement::Label(label) => {
|
||||
bb_state.start(*label);
|
||||
|
@ -808,7 +810,7 @@ pub(crate) fn run<'input>(
|
|||
)?;
|
||||
let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?;
|
||||
*/
|
||||
let directives = insert_mode_control(directives, temp)?;
|
||||
let directives = insert_mode_control(flat_resolver, directives, temp)?;
|
||||
Ok(directives)
|
||||
}
|
||||
|
||||
|
@ -1018,47 +1020,48 @@ struct TwinMode<T> {
|
|||
}
|
||||
|
||||
fn insert_mode_control(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
global_modes: FullModeInsertion2,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let directives_len = directives.len();
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|mut directive| {
|
||||
.map(|directive| {
|
||||
let mut new_directives = SmallVec::<[_; 4]>::new();
|
||||
let (fn_name, initial_mode, body_ptr) = match directive {
|
||||
let (mut method, initial_mode) = match directive {
|
||||
Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => {
|
||||
new_directives.push(directive);
|
||||
return Ok(new_directives);
|
||||
}
|
||||
Directive2::Method(Function2 {
|
||||
name,
|
||||
body: Some(ref mut body),
|
||||
ref mut flush_to_zero_f32,
|
||||
ref mut flush_to_zero_f16f64,
|
||||
ref mut rounding_mode_f32,
|
||||
ref mut rounding_mode_f16f64,
|
||||
..
|
||||
}) => {
|
||||
Directive2::Method(
|
||||
mut method @ Function2 {
|
||||
name,
|
||||
body: Some(_),
|
||||
..
|
||||
},
|
||||
) => {
|
||||
let initial_mode = global_modes
|
||||
.basic_blocks
|
||||
.get(&name)
|
||||
.ok_or_else(error_unreachable)?;
|
||||
let denormal_mode = initial_mode.denormal.twin_mode;
|
||||
let rounding_mode = initial_mode.rounding.twin_mode;
|
||||
*flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
|
||||
*flush_to_zero_f16f64 =
|
||||
method.flush_to_zero_f32 =
|
||||
denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz();
|
||||
method.flush_to_zero_f16f64 =
|
||||
denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz();
|
||||
*rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
|
||||
*rounding_mode_f16f64 =
|
||||
method.rounding_mode_f32 =
|
||||
rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast();
|
||||
method.rounding_mode_f16f64 =
|
||||
rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast();
|
||||
(name, initial_mode, body)
|
||||
(method, initial_mode)
|
||||
}
|
||||
};
|
||||
emit_mode_prelude(fn_name, &mut new_directives);
|
||||
let old_body = mem::replace(body_ptr, Vec::new());
|
||||
emit_mode_prelude(flat_resolver, &method, &global_modes, &mut new_directives)?;
|
||||
let old_body = method.body.take().unwrap();
|
||||
let mut result = Vec::with_capacity(old_body.len());
|
||||
let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode);
|
||||
let mut bb_state = BasicBlockControlState::new(&global_modes, initial_mode);
|
||||
let mut old_body = old_body.into_iter();
|
||||
while let Some(mut statement) = old_body.next() {
|
||||
let mut call_target = None;
|
||||
|
@ -1115,8 +1118,8 @@ fn insert_mode_control(
|
|||
}
|
||||
}
|
||||
}
|
||||
*body_ptr = result;
|
||||
new_directives.push(directive);
|
||||
method.body = Some(result);
|
||||
new_directives.push(Directive2::Method(method));
|
||||
Ok(new_directives)
|
||||
})
|
||||
.try_fold(Vec::with_capacity(directives_len), |mut acc, d| {
|
||||
|
@ -1126,35 +1129,219 @@ fn insert_mode_control(
|
|||
}
|
||||
|
||||
fn emit_mode_prelude(
|
||||
fn_name: SpirvWord,
|
||||
global_modes: FullModeInsertion2,
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
global_modes: &FullModeInsertion2,
|
||||
new_directives: &mut SmallVec<[Directive2<ptx_parser::Instruction<SpirvWord>, SpirvWord>; 4]>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?;
|
||||
let fn_mode_state = global_modes
|
||||
.basic_blocks
|
||||
.get(&method.name)
|
||||
.ok_or_else(error_unreachable)?;
|
||||
if let Some(dual_prologue) = fn_mode_state.dual_prologue {
|
||||
new_directives.push(Directive2::Method(
|
||||
Function2 {
|
||||
return_arguments: todo!(),
|
||||
name: dual_prologue,
|
||||
input_arguments: todo!(),
|
||||
body: todo!(),
|
||||
is_kernel: false,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::NONE,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
}
|
||||
new_directives.push(create_fn_wrapper(
|
||||
flat_resolver,
|
||||
method,
|
||||
dual_prologue,
|
||||
[
|
||||
ModeRegister::Denormal {
|
||||
f32: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.to_ftz(),
|
||||
f16f64: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.to_ftz(),
|
||||
},
|
||||
ModeRegister::Rounding {
|
||||
f32: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.to_ast(),
|
||||
f16f64: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.to_ast(),
|
||||
},
|
||||
]
|
||||
.into_iter(),
|
||||
));
|
||||
}
|
||||
if let Some(prologue) = fn_mode_state.denormal.prologue {
|
||||
todo!()
|
||||
new_directives.push(create_fn_wrapper(
|
||||
flat_resolver,
|
||||
method,
|
||||
prologue,
|
||||
[ModeRegister::Denormal {
|
||||
f32: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.to_ftz(),
|
||||
f16f64: fn_mode_state
|
||||
.denormal
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.to_ftz(),
|
||||
}]
|
||||
.into_iter(),
|
||||
));
|
||||
}
|
||||
if let Some(prologue) = fn_mode_state.rounding.prologue {
|
||||
todo!()
|
||||
new_directives.push(create_fn_wrapper(
|
||||
flat_resolver,
|
||||
method,
|
||||
prologue,
|
||||
[ModeRegister::Rounding {
|
||||
f32: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f32
|
||||
.unwrap_of_default()
|
||||
.to_ast(),
|
||||
f16f64: fn_mode_state
|
||||
.rounding
|
||||
.twin_mode
|
||||
.f16f64
|
||||
.unwrap_of_default()
|
||||
.to_ast(),
|
||||
}]
|
||||
.into_iter(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_fn_wrapper(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||
method: &Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
name: SpirvWord,
|
||||
modes: impl ExactSizeIterator<Item = ModeRegister>,
|
||||
) -> Directive2<ast::Instruction<SpirvWord>, SpirvWord> {
|
||||
// * Label
|
||||
// * return argument registers
|
||||
// * input argument registers
|
||||
// * Load input arguments
|
||||
// * set modes
|
||||
// * call
|
||||
// * return with value
|
||||
let return_arguments = rename_variables(flat_resolver, &method.return_arguments);
|
||||
let input_arguments = rename_variables(flat_resolver, &method.input_arguments);
|
||||
let mut body = Vec::with_capacity(
|
||||
1 + (input_arguments.len() * 2) + return_arguments.len() + modes.len() + 2,
|
||||
);
|
||||
body.push(Statement::Label(flat_resolver.register_unnamed(None)));
|
||||
let return_variables = append_variables(flat_resolver, &mut body, &return_arguments);
|
||||
let input_variables = append_variables(flat_resolver, &mut body, &input_arguments);
|
||||
for (index, input_reg) in input_variables.iter().enumerate() {
|
||||
body.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: input_arguments[index].state_space,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: input_arguments[index].v_type.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
src: input_arguments[index].name,
|
||||
dst: *input_reg,
|
||||
},
|
||||
}));
|
||||
}
|
||||
body.extend(modes.map(|mode_set| Statement::SetMode(mode_set)));
|
||||
// Out of order because we want to use return_variables before they are moved
|
||||
let ret_statement = if return_arguments.is_empty() {
|
||||
Statement::Instruction(ast::Instruction::Ret {
|
||||
data: ast::RetData { uniform: false },
|
||||
})
|
||||
} else {
|
||||
Statement::RetValue(
|
||||
ast::RetData { uniform: false },
|
||||
return_variables
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, var)| (*var, method.return_arguments[index].v_type.clone()))
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
body.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data: ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|arg| (arg.v_type.clone(), arg.state_space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|arg| (arg.v_type.clone(), arg.state_space))
|
||||
.collect(),
|
||||
},
|
||||
arguments: ast::CallArgs {
|
||||
return_arguments: return_variables,
|
||||
func: method.name,
|
||||
input_arguments: input_variables,
|
||||
},
|
||||
}));
|
||||
body.push(ret_statement);
|
||||
Directive2::Method(Function2 {
|
||||
return_arguments,
|
||||
name,
|
||||
input_arguments,
|
||||
body: Some(body),
|
||||
is_kernel: false,
|
||||
import_as: None,
|
||||
tuning: Vec::new(),
|
||||
linkage: ast::LinkingDirective::NONE,
|
||||
flush_to_zero_f32: false,
|
||||
flush_to_zero_f16f64: false,
|
||||
rounding_mode_f32: ptx_parser::RoundingMode::NearestEven,
|
||||
rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven,
|
||||
})
|
||||
}
|
||||
|
||||
fn rename_variables(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2,
|
||||
variables: &Vec<ast::Variable<SpirvWord>>,
|
||||
) -> Vec<ast::Variable<SpirvWord>> {
|
||||
variables
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|arg| ast::Variable {
|
||||
name: flat_resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))),
|
||||
..arg
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn append_variables<'a, 'input: 'a>(
|
||||
flat_resolver: &'a mut super::GlobalStringIdentResolver2<'input>,
|
||||
body: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
arguments: &'a Vec<ast::Variable<SpirvWord>>,
|
||||
) -> Vec<SpirvWord> {
|
||||
let mut result = Vec::with_capacity(arguments.len());
|
||||
for arg in arguments {
|
||||
let name = flat_resolver.register_unnamed(Some((arg.v_type.clone(), ast::StateSpace::Reg)));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
result.push(name);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
struct BasicBlockControlState<'a> {
|
||||
|
@ -1163,7 +1350,6 @@ struct BasicBlockControlState<'a> {
|
|||
denormal_f16f64: RegisterState<bool>,
|
||||
rounding_f32: RegisterState<ast::RoundingMode>,
|
||||
rounding_f16f64: RegisterState<ast::RoundingMode>,
|
||||
current_bb: SpirvWord,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
|
@ -1175,20 +1361,6 @@ struct RegisterState<T> {
|
|||
}
|
||||
|
||||
impl<T> RegisterState<T> {
|
||||
fn single(t: T) -> Self {
|
||||
RegisterState {
|
||||
last_foldable: None,
|
||||
current_value: Resolved::Value(t),
|
||||
}
|
||||
}
|
||||
|
||||
fn conflict() -> Self {
|
||||
RegisterState {
|
||||
last_foldable: None,
|
||||
current_value: Resolved::Conflict,
|
||||
}
|
||||
}
|
||||
|
||||
fn new<U>(value: Resolved<U>) -> RegisterState<T>
|
||||
where
|
||||
U: Into<T>,
|
||||
|
@ -1201,11 +1373,7 @@ impl<T> RegisterState<T> {
|
|||
}
|
||||
|
||||
impl<'a> BasicBlockControlState<'a> {
|
||||
fn new(
|
||||
global_modes: &'a FullModeInsertion2,
|
||||
current_bb: SpirvWord,
|
||||
initial_mode: &FullBasicBlockEntryState,
|
||||
) -> Self {
|
||||
fn new(global_modes: &'a FullModeInsertion2, initial_mode: &FullBasicBlockEntryState) -> Self {
|
||||
let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32);
|
||||
let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64);
|
||||
let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32);
|
||||
|
@ -1216,7 +1384,6 @@ impl<'a> BasicBlockControlState<'a> {
|
|||
denormal_f16f64,
|
||||
rounding_f32,
|
||||
rounding_f16f64,
|
||||
current_bb,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1405,12 +1572,6 @@ fn redirect_jump_impl(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct ModeJumpTargets {
|
||||
dual_prologue: Option<SpirvWord>,
|
||||
denormal: Option<SpirvWord>,
|
||||
rounding: Option<SpirvWord>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum Resolved<T> {
|
||||
Conflict,
|
||||
|
@ -1457,13 +1618,6 @@ impl<T> Resolved<T> {
|
|||
Resolved::Conflict => Err(err()),
|
||||
}
|
||||
}
|
||||
|
||||
fn has_value(&self) -> bool {
|
||||
match self {
|
||||
Resolved::Value(_) => true,
|
||||
Resolved::Conflict => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ModeView {
|
||||
|
|
|
@ -258,30 +258,83 @@ fn call_with_mode() {
|
|||
));
|
||||
let [to_fn0] = calls(method_1);
|
||||
let [_, dual_prelude, _, _, add] = labels(method_1);
|
||||
let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1);
|
||||
let [post_call, post_prelude_dual, post_prelude_denormal, post_prelude_rounding] =
|
||||
branches(method_1);
|
||||
assert_eq!(methods[0].name, to_fn0);
|
||||
assert_eq!(post_call, dual_prelude);
|
||||
assert_eq!(post_prelude_0, add);
|
||||
assert_eq!(post_prelude_1, add);
|
||||
assert_eq!(post_prelude_2, add);
|
||||
assert_eq!(post_prelude_dual, add);
|
||||
assert_eq!(post_prelude_denormal, add);
|
||||
assert_eq!(post_prelude_rounding, add);
|
||||
|
||||
let method_2 = methods[2].body.as_ref().unwrap();
|
||||
assert!(matches!(
|
||||
&**method_2,
|
||||
[
|
||||
Statement::Label(..),
|
||||
Statement::Variable(..),
|
||||
Statement::Variable(..),
|
||||
Statement::Conditional(..),
|
||||
Statement::Label(..),
|
||||
Statement::Conditional(..),
|
||||
Statement::Label(..),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
Statement::Label(..),
|
||||
// Dual prelude
|
||||
Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: true,
|
||||
f32: false,
|
||||
f16f64: true
|
||||
}),
|
||||
Statement::SetMode(ModeRegister::Rounding {
|
||||
f32: ast::RoundingMode::PositiveInf,
|
||||
f32: ast::RoundingMode::NegativeInf,
|
||||
f16f64: ast::RoundingMode::NearestEven
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Call { .. }),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
// Denormal prelude
|
||||
Statement::Label(..),
|
||||
Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: false,
|
||||
f16f64: true
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
// Rounding prelude
|
||||
Statement::Label(..),
|
||||
Statement::SetMode(ModeRegister::Rounding {
|
||||
f32: ast::RoundingMode::NegativeInf,
|
||||
f16f64: ast::RoundingMode::NearestEven
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
Statement::Label(..),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
Statement::Label(..),
|
||||
Statement::SetMode(ModeRegister::Denormal {
|
||||
f32: false,
|
||||
f16f64: true
|
||||
}),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
Statement::Label(..),
|
||||
Statement::Instruction(ast::Instruction::Add { .. }),
|
||||
Statement::Instruction(ast::Instruction::Bra { .. }),
|
||||
Statement::Label(..),
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }),
|
||||
]
|
||||
));
|
||||
let [(if_rm_true, if_rm_false), (if_rz_true, if_rz_false)] = conditionals(method_2);
|
||||
let [_, conditional2, post_conditional2, prelude_dual, _, _, add1, add2_set_denormal, add2, ret] =
|
||||
labels(method_2);
|
||||
let [post_conditional2_jump, post_prelude_dual, post_prelude_denormal, post_prelude_rounding, post_add1, post_add2_set_denormal, post_add2] =
|
||||
branches(method_2);
|
||||
assert_eq!(if_rm_true, prelude_dual);
|
||||
assert_eq!(if_rm_false, conditional2);
|
||||
assert_eq!(if_rz_true, post_conditional2);
|
||||
assert_eq!(if_rz_false, add2_set_denormal);
|
||||
assert_eq!(post_conditional2_jump, prelude_dual);
|
||||
assert_eq!(post_prelude_dual, add1);
|
||||
assert_eq!(post_prelude_denormal, add1);
|
||||
assert_eq!(post_prelude_rounding, add1);
|
||||
assert_eq!(post_add1, ret);
|
||||
assert_eq!(post_add2_set_denormal, add2);
|
||||
assert_eq!(post_add2, ret);
|
||||
}
|
||||
|
||||
fn branches<const N: usize>(
|
||||
|
@ -303,10 +356,12 @@ fn labels<const N: usize>(
|
|||
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> [SpirvWord; N] {
|
||||
fn_.iter()
|
||||
.filter_map(|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
|
||||
Statement::Label(label) => Some(*label),
|
||||
_ => None,
|
||||
})
|
||||
.filter_map(
|
||||
|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
|
||||
Statement::Label(label) => Some(*label),
|
||||
_ => None,
|
||||
},
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
|
@ -317,7 +372,25 @@ fn calls<const N: usize>(
|
|||
) -> [SpirvWord; N] {
|
||||
fn_.iter()
|
||||
.filter_map(|s| match s {
|
||||
Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func),
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
arguments: ast::CallArgs { func, .. },
|
||||
..
|
||||
}) => Some(*func),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn conditionals<const N: usize>(
|
||||
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> [(SpirvWord, SpirvWord); N] {
|
||||
fn_.iter()
|
||||
.filter_map(|s| match s {
|
||||
Statement::Conditional(BrachCondition {
|
||||
if_true, if_false, ..
|
||||
}) => Some((*if_true, *if_false)),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
|
|
|
@ -7,7 +7,7 @@ use super::*;
|
|||
// represent kernels as separate nodes with its own separate entry/exit mode
|
||||
// * Inserts label at the start of every basic block
|
||||
// * Insert explicit jumps before labels
|
||||
// * Functions get a single `ret;` exit point - this is because mode computation
|
||||
// * Non-.entry methods get a single `ret;` exit point - this is because mode computation
|
||||
// logic requires it. Control flow graph constructed by mode computation
|
||||
// models function calls as jumps into and then from another function.
|
||||
// If this cfg allowed multiple return basic blocks then there would be cases
|
||||
|
@ -19,10 +19,10 @@ pub(crate) fn run(
|
|||
mut directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
for directive in directives.iter_mut() {
|
||||
let body_ref = match directive {
|
||||
let (body_ref, is_kernel) = match directive {
|
||||
Directive2::Method(Function2 {
|
||||
body: Some(body), ..
|
||||
}) => body,
|
||||
body: Some(body), is_kernel, ..
|
||||
}) => (body, *is_kernel),
|
||||
_ => continue,
|
||||
};
|
||||
let body = std::mem::replace(body_ref, Vec::new());
|
||||
|
@ -74,7 +74,9 @@ pub(crate) fn run(
|
|||
return Err(error_unreachable());
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
return_statements.push(result.len())
|
||||
if !is_kernel {
|
||||
return_statements.push(result.len());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue