diff --git a/zluda_trace/src/log.rs b/zluda_trace/src/log.rs index b3f9716..9cbb9cc 100644 --- a/zluda_trace/src/log.rs +++ b/zluda_trace/src/log.rs @@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry { }, NullPointer(&'static str), UnknownLibrary(CUlibrary), + SavedModule(String), } unsafe impl Send for ErrorEntry {} @@ -344,93 +345,94 @@ impl Display for ErrorEntry { match self { ErrorEntry::IoError(e) => e.fmt(f), ErrorEntry::CreatedDumpDirectory(dir) => { - write!( - f, - "Created trace directory {} ", - dir.as_os_str().to_string_lossy() - ) - } + write!( + f, + "Created trace directory {} ", + dir.as_os_str().to_string_lossy() + ) + } ErrorEntry::ErrorBox(e) => e.fmt(f), ErrorEntry::UnsupportedModule { - module, - raw_image, - kind, - } => { - write!( - f, - "Unsupported {} module {:?} loaded from module image {:?}", - kind, module, raw_image - ) - } + module, + raw_image, + kind, + } => { + write!( + f, + "Unsupported {} module {:?} loaded from module image {:?}", + kind, module, raw_image + ) + } ErrorEntry::MalformedModulePath(e) => e.fmt(f), ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), ErrorEntry::ModuleParsingError(file_name) => { - write!( - f, - "Error parsing module, log has been written to {}", - file_name - ) - } + write!( + f, + "Error parsing module, log has been written to {}", + file_name + ) + } ErrorEntry::NulInsideModuleText(e) => e.fmt(f), ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), ErrorEntry::UnexpectedBinaryField { - field_name, - expected, - observed, - } => write!( - f, - "Unexpected field {}. Expected one of: [{}], observed: {}", - field_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + field_name, + expected, + observed, + } => write!( + f, + "Unexpected field {}. Expected one of: [{}], observed: {}", + field_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::UnexpectedArgument { - arg_name, - expected, - observed, - } => write!( - f, - "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", - arg_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + arg_name, + expected, + observed, + } => write!( + f, + "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", + arg_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::InvalidEnvVar { - var, - pattern, - value, - } => write!( - f, - "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" - ), + var, + pattern, + value, + } => write!( + f, + "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" + ), ErrorEntry::FunctionNotFound(cuda_function_name) => write!( - f, - "No function {cuda_function_name} in the underlying library" - ), + f, + "No function {cuda_function_name} in the underlying library" + ), ErrorEntry::UnexpectedExportTableSize { expected, computed } => { - write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") - } + write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") + } ErrorEntry::IntegrityCheck { original, overriden } => { - write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") - } + write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") + } ErrorEntry::NullPointer(type_) => { - write!(f, "Null pointer of type {type_} encountered") - } + write!(f, "Null pointer of type {type_} encountered") + } ErrorEntry::UnknownLibrary(culibrary) => { - write!(f, "Unknown library: ")?; - let mut temp_buffer = Vec::new(); - CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); - f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) - } + write!(f, "Unknown library: ")?; + let mut temp_buffer = Vec::new(); + CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); + f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) + } + ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"), } } } diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index e71aacd..f397d34 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -128,12 +128,11 @@ impl StateTracker { fn_logger: &mut FnCallLog, type_: &'static str, ) { - fn_logger.log_io_error(self.writer.save_module( - self.library_counter, - index, - submodule, - type_, - )); + fn_logger.try_(|fn_logger| { + self.writer + .save_module(fn_logger, self.library_counter, index, submodule, type_) + .map_err(ErrorEntry::IoError) + }); if type_ == "ptx" { match CString::new(submodule) { Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), @@ -323,6 +322,7 @@ impl DumpWriter { fn save_module( &self, + fn_logger: &mut FnCallLog, module_index: usize, submodule_index: Option<(usize, Option)>, buffer: &[u8], @@ -332,9 +332,13 @@ impl DumpWriter { None => return Ok(()), Some(d) => d.clone(), }; - dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create_new(dump_file)?; - file.write_all(buffer)?; + let file_name = Self::get_file_name(module_index, submodule_index, kind); + dump_file.push(&file_name); + { + let mut file = File::create_new(dump_file)?; + file.write_all(buffer)?; + } + fn_logger.log(ErrorEntry::SavedModule(file_name)); Ok(()) } @@ -349,7 +353,7 @@ impl DumpWriter { Some(d) => d.clone(), }; log_file.push(Self::get_file_name(module_index, submodule_index, "log")); - let mut file = File::create(log_file)?; + let mut file = File::create_new(log_file)?; for error in errors { writeln!(file, "{}", error)?; }