diff --git a/ptx/src/pass/filter_for_demo.rs b/ptx/src/pass/filter_for_demo.rs new file mode 100644 index 0000000..816aab7 --- /dev/null +++ b/ptx/src/pass/filter_for_demo.rs @@ -0,0 +1,30 @@ +pub(crate) fn run<'input>( + directives: Vec>>, +) -> Vec>> { + let demo_kernels_path = std::env::var("ZLUDA_DEMO_KERNELS").unwrap(); + let demo_kernels_file = std::fs::read_to_string(demo_kernels_path).unwrap(); + let demo_kernels = demo_kernels_file + .lines() + .map(|line| line.trim()) + .filter(|line| !line.is_empty()) + .collect::>(); + let result = directives + .into_iter() + .filter(|directive| match directive { + ptx_parser::Directive::Method(_, method) => { + !method.func_directive.name.is_kernel() + || demo_kernels.contains(method.func_directive.name()) + } + _ => true, + }) + .collect::>(); + for directive in result.iter() { + match directive { + ptx_parser::Directive::Method(_, method) => { + eprintln!("{}", method.func_directive.name()); + } + _ => {} + } + } + result +} diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 904bf37..d3e0b7b 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -197,6 +197,7 @@ fn run_instruction<'input>( | ast::Instruction::Xor { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } + | ast::Instruction::GridDepControl { .. } | ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index a4c2dc4..229e179 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1855,6 +1855,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::AtomCas { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } + | ast::Instruction::GridDepControl { .. } | ast::Instruction::LdMatrix { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index baaff6a..0677345 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -522,6 +522,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop + ast::Instruction::GridDepControl { .. } => Ok(()), // nop // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 4f87dc3..e4a22c7 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,6 +13,7 @@ use strum_macros::EnumIter; mod deparamize_functions; mod expand_operands; +mod filter_for_demo; mod fix_special_registers; mod hoist_globals; mod insert_explicit_load_store; @@ -65,7 +66,9 @@ pub fn to_llvm_module<'input>( let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; - let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; + let directives = filter_for_demo::run(ast.directives); + on_pass_end("filter_for_demo"); + let directives = normalize_identifiers2::run(&mut scoped_resolver, directives)?; on_pass_end("normalize_identifiers2"); let directives = replace_known_functions::run(&mut flat_resolver, directives); on_pass_end("replace_known_functions"); diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 9fecba3..04e1b6a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -721,6 +721,9 @@ ptx_parser_macros::generate_instruction_type!( space: { data.state_space }, } } + }, + GridDepControl { + data: crate::GridDepControlAction, } } ); diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 4253ae6..5389118 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3897,6 +3897,14 @@ derive_parser!( .type: ScalarType = {.b16, .b8}; // .dst_fmt = { .b8x16 }; // .src_fmt = { .b6x16_p32, .b4x16_p64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol + griddepcontrol.action => { + Instruction::GridDepControl { + data: action + } + } + .action: GridDepControlAction = { .launch_dependents, .wait }; ); #[cfg(test)]