diff --git a/ptx/src/pass/filter_for_demo.rs b/ptx/src/pass/filter_for_demo.rs new file mode 100644 index 0000000..816aab7 --- /dev/null +++ b/ptx/src/pass/filter_for_demo.rs @@ -0,0 +1,30 @@ +pub(crate) fn run<'input>( + directives: Vec>>, +) -> Vec>> { + let demo_kernels_path = std::env::var("ZLUDA_DEMO_KERNELS").unwrap(); + let demo_kernels_file = std::fs::read_to_string(demo_kernels_path).unwrap(); + let demo_kernels = demo_kernels_file + .lines() + .map(|line| line.trim()) + .filter(|line| !line.is_empty()) + .collect::>(); + let result = directives + .into_iter() + .filter(|directive| match directive { + ptx_parser::Directive::Method(_, method) => { + !method.func_directive.name.is_kernel() + || demo_kernels.contains(method.func_directive.name()) + } + _ => true, + }) + .collect::>(); + for directive in result.iter() { + match directive { + ptx_parser::Directive::Method(_, method) => { + eprintln!("{}", method.func_directive.name()); + } + _ => {} + } + } + result +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index e4b5b27..94fdad5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,6 +13,7 @@ use strum_macros::EnumIter; mod deparamize_functions; mod expand_operands; +mod filter_for_demo; mod fix_special_registers; mod hoist_globals; mod insert_explicit_load_store; @@ -65,8 +66,10 @@ pub fn to_llvm_module<'input>( let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; - let directives = normalize_identifiers::run(&mut scoped_resolver, ast.directives)?; - on_pass_end("normalize_identifiers"); + let directives = filter_for_demo::run(ast.directives); + on_pass_end("filter_for_demo"); + let directives = normalize_identifiers2::run(&mut scoped_resolver, directives)?; + on_pass_end("normalize_identifiers2"); let directives = replace_known_functions::run(&mut flat_resolver, directives); on_pass_end("replace_known_functions"); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;