Fix some bugs in mode pass

This commit is contained in:
Andrzej Janik 2025-03-12 22:11:52 +00:00
parent c86473b396
commit a0d4b7eeb2
2 changed files with 804 additions and 270 deletions

File diff suppressed because it is too large Load diff

View file

@ -214,10 +214,12 @@ static CALL_WITH_MODE_PTX: &'static str = include_str!("call_with_mode.ptx");
#[test]
fn call_with_mode() {
let methods = compile_methods(CALL_WITH_MODE_PTX);
assert!(matches!(methods[0].body, None));
let method_1 = methods[1].body.as_ref().unwrap();
assert!(matches!(
&**methods[1].body.as_ref().unwrap(),
&**method_1,
[
Statement::Label(..),
Statement::Variable(..),
@ -254,4 +256,71 @@ fn call_with_mode() {
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));
let [to_fn0] = calls(method_1);
let [_, dual_prelude, _, _, add] = labels(method_1);
let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1);
assert_eq!(methods[0].name, to_fn0);
assert_eq!(post_call, dual_prelude);
assert_eq!(post_prelude_0, add);
assert_eq!(post_prelude_1, add);
assert_eq!(post_prelude_2, add);
let method_2 = methods[2].body.as_ref().unwrap();
assert!(matches!(
&**method_2,
[
Statement::Label(..),
Statement::SetMode(ModeRegister::Denormal {
f32: true,
f16f64: true
}),
Statement::SetMode(ModeRegister::Rounding {
f32: ast::RoundingMode::PositiveInf,
f16f64: ast::RoundingMode::NearestEven
}),
Statement::Instruction(ast::Instruction::Call { .. }),
Statement::Instruction(ast::Instruction::Ret { .. }),
]
));
}
fn branches<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [SpirvWord; N] {
fn_.iter()
.filter_map(|s| match s {
Statement::Instruction(ast::Instruction::Bra {
arguments: ast::BraArgs { src },
}) => Some(*src),
_ => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
fn labels<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [SpirvWord; N] {
fn_.iter()
.filter_map(|s: &Statement<ptx_parser::Instruction<SpirvWord>, SpirvWord>| match s {
Statement::Label(label) => Some(*label),
_ => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
fn calls<const N: usize>(
fn_: &Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> [SpirvWord; N] {
fn_.iter()
.filter_map(|s| match s {
Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func),
_ => None,
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}