diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 9c5671b..40b2447 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -138,11 +138,23 @@ impl<'a> DataSet<'a> { Ok(()) } + pub fn from_data(comgr: &'a Comgr, data: impl Iterator) -> Result { + let dataset = Self::new(comgr)?; + for data in data { + dataset.add(&data)?; + } + Ok(dataset) + } + fn get_data(&self, kind: DataKind, index: usize) -> Result { let mut handle = 0u64; call_dispatch!(self.comgr => amd_comgr_action_data_get_data(self, kind, { index }, { std::ptr::from_mut(&mut handle).cast() })); Ok(Data(handle)) } + + fn get_content(&self, comgr: &Comgr, kind: DataKind, index: usize) -> Result, Error> { + self.get_data(kind, index).map(|data| data.copy_content(comgr))? + } } struct Data(u64); @@ -196,28 +208,24 @@ pub fn compile_bitcode( attributes_buffer: &[u8], compiler_hook: Option<&dyn Fn(&Vec, String)>, ) -> Result, Error> { - let bitcode_data_set = DataSet::new(comgr)?; - let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?; - bitcode_data_set.add(&main_bitcode_data)?; - let stdlib_bitcode_data = Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?; - bitcode_data_set.add(&stdlib_bitcode_data)?; - let attributes_bitcode_data = - Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?; - bitcode_data_set.add(&attributes_bitcode_data)?; + let bitcode_data_set = DataSet::from_data( + comgr, + [ + Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?, + Data::new(comgr, DataKind::Bc, c"ptx_impl.bc", ptx_impl)?, + Data::new(comgr, DataKind::Bc, c"attributes.bc", attributes_buffer)?, + ].into_iter(), + )?; let linking_info = ActionInfo::new(comgr)?; let linked_data_set = comgr.do_action(ActionKind::LinkBcToBc, &linking_info, &bitcode_data_set)?; if let Some(hook) = compiler_hook { // Run compiler hook on human-readable LLVM IR - let data = linked_data_set.get_data(DataKind::Bc, 0)?; - let data = data.copy_content(comgr)?; + let data = linked_data_set.get_content(comgr, DataKind::Bc, 0)?; let data = ptx::bitcode_to_ir(data); hook(&data, String::from("linked.ll")); } - let compile_to_exec = ActionInfo::new(comgr)?; - compile_to_exec.set_isa_name(gcn_arch)?; - compile_to_exec.set_language(Language::LlvmIr)?; let common_options = [ // This makes no sense, but it makes ockl linking work c"-Xclang", @@ -249,14 +257,16 @@ pub fn compile_bitcode( c"-inlinehint-threshold=3250", ] }; + let compile_to_exec = ActionInfo::new(comgr)?; + compile_to_exec.set_isa_name(gcn_arch)?; + compile_to_exec.set_language(Language::LlvmIr)?; compile_to_exec.set_options(common_options.chain(opt_options))?; let exec_data_set = comgr.do_action( ActionKind::CompileSourceToExecutable, &compile_to_exec, &linked_data_set, )?; - let executable = exec_data_set.get_data(DataKind::Executable, 0)?; - let executable = executable.copy_content(comgr); + let executable = exec_data_set.get_content(comgr, DataKind::Executable, 0); if let Some(hook) = compiler_hook { // Run compiler hook for executable hook( @@ -272,9 +282,8 @@ pub fn compile_bitcode( &action_info, &exec_data_set, )?; - let disassembly = disassembly.get_data(DataKind::Source, 0)?; - let disassembly = disassembly.copy_content(comgr); - hook(&disassembly.unwrap_or(Vec::new()), String::from("asm")) + let disassembly = disassembly.get_content(comgr, DataKind::Source, 0)?; + hook(&disassembly, String::from("asm")) } executable } @@ -296,14 +305,15 @@ pub fn get_symbols(comgr: &Comgr, elf: &[u8]) -> Result, Erro } pub fn get_clang_version(comgr: &Comgr) -> Result { - let version_string_set = DataSet::new(comgr)?; - let version_string = Data::new( + let version_string_set = DataSet::from_data( comgr, - DataKind::Source, - c"version.cpp", - b"__clang_version__", + iter::once(Data::new( + comgr, + DataKind::Source, + c"version.cpp", + b"__clang_version__", + )?), )?; - version_string_set.add(&version_string)?; let preprocessor_info = ActionInfo::new(comgr)?; preprocessor_info.set_language(Language::Hip)?; preprocessor_info.set_options(iter::once(c"-P"))?; @@ -312,8 +322,8 @@ pub fn get_clang_version(comgr: &Comgr) -> Result { &preprocessor_info, &version_string_set, )?; - let data = preprocessed.get_data(DataKind::Source, 0)?; - String::from_utf8(trim_whitespace_and_quotes(data.copy_content(comgr)?)?) + let data = preprocessed.get_content(comgr, DataKind::Source, 0)?; + String::from_utf8(trim_whitespace_and_quotes(data)?) .map_err(|_| Error::UNKNOWN) }