Refactor winapi calls to surface errors

This commit is contained in:
Andrzej Janik 2020-01-06 23:15:00 +01:00
commit 6bd033c369
2 changed files with 238 additions and 179 deletions

View file

@ -1,76 +1,89 @@
extern crate clap; extern crate clap;
extern crate detours_sys; extern crate detours_sys;
use std::error::Error; use std::error::Error;
use std::ffi::OsStr; use std::ffi::OsStr;
use std::mem; use std::mem;
use std::os::windows::ffi::OsStrExt; use std::os::windows::ffi::OsStrExt;
use std::ptr; use std::ptr;
use clap::{App, AppSettings, Arg}; use clap::{App, AppSettings, Arg};
mod win_err; #[macro_use]
mod win;
fn main() -> Result<(), Box<dyn Error>> {
let matches = App::new("notCUDA injector") fn main() -> Result<(), Box<dyn Error>> {
.setting(AppSettings::TrailingVarArg) let matches = App::new("notCUDA injector")
.arg( .setting(AppSettings::TrailingVarArg)
Arg::with_name("EXE") .arg(
.help("Path to the executable to be injected with notCUDA") Arg::with_name("EXE")
.required(true), .help("Path to the executable to be injected with notCUDA")
) .required(true),
.arg( )
Arg::with_name("ARGS") .arg(
.multiple(true) Arg::with_name("ARGS")
.help("Arguments that will be passed to <EXE>"), .multiple(true)
) .help("Arguments that will be passed to <EXE>"),
.get_matches(); )
let exe = matches.value_of_os("EXE").unwrap(); .get_matches();
let args: Vec<&OsStr> = matches let exe = matches.value_of_os("EXE").unwrap();
.values_of_os("ARGS") let args: Vec<&OsStr> = matches
.map(|x| x.collect()) .values_of_os("ARGS")
.unwrap_or_else(|| Vec::new()); .map(|x| x.collect())
let mut cmd_line = Vec::<u16>::with_capacity(exe.len() + 2); .unwrap_or_else(|| Vec::new());
cmd_line.push('\"' as u16); let mut cmd_line = Vec::<u16>::with_capacity(exe.len() + 2);
copy_to(exe, &mut cmd_line); cmd_line.push('\"' as u16);
cmd_line.push('\"' as u16); copy_to(exe, &mut cmd_line);
cmd_line.push(' ' as u16); cmd_line.push('\"' as u16);
args.split_last().map(|(last_arg, args)| { cmd_line.push(' ' as u16);
for arg in args { args.split_last().map(|(last_arg, args)| {
cmd_line.reserve(arg.len()); for arg in args {
copy_to(arg, &mut cmd_line); cmd_line.reserve(arg.len());
cmd_line.push(' ' as u16); copy_to(arg, &mut cmd_line);
} cmd_line.push(' ' as u16);
copy_to(last_arg, &mut cmd_line); }
}); copy_to(last_arg, &mut cmd_line);
});
cmd_line.push(0);
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() }; cmd_line.push(0);
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() }; let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let process_success = unsafe { let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
detours_sys::DetourCreateProcessWithDllExW( os_call!(
ptr::null(), detours_sys::DetourCreateProcessWithDllExW(
cmd_line.as_mut_ptr(), ptr::null(),
ptr::null_mut(), cmd_line.as_mut_ptr(),
ptr::null_mut(), ptr::null_mut(),
0, ptr::null_mut(),
0x10, 0,
ptr::null_mut(), 0x10,
ptr::null(), ptr::null_mut(),
&mut startup_info as *mut _, ptr::null(),
&mut proc_info as *mut _, &mut startup_info as *mut _,
"nvcuda_redirect.dll".as_ptr() as *const i8, &mut proc_info as *mut _,
Option::None, "nvcuda_redirect.dll".as_ptr() as *const i8,
) Option::None
}; ),
if process_success == 0 { 0
return Err(win_err::error_string(win_err::errno()))?; );
} Ok(())
Ok(()) /*
}
cmd_line.as_mut_ptr(),
fn copy_to(from: &OsStr, to: &mut Vec<u16>) { ptr::null_mut(),
for x in from.encode_wide() { ptr::null_mut(),
to.push(x); 0,
} 0x10,
} ptr::null_mut(),
ptr::null(),
&mut startup_info as *mut _,
&mut proc_info as *mut _,
"nvcuda_redirect.dll".as_ptr() as *const i8,
Option::None,
*/
}
fn copy_to(from: &OsStr, to: &mut Vec<u16>) {
for x in from.encode_wide() {
to.push(x);
}
}

View file

@ -1,103 +1,149 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::ptr; use std::error;
use std::error::Error;
mod c { use std::fmt;
use std::ffi::c_void; use std::ptr;
use std::os::raw::{c_ulong};
mod c {
pub type DWORD = c_ulong; use std::ffi::c_void;
pub type HANDLE = LPVOID; use std::os::raw::c_ulong;
pub type LPVOID = *mut c_void;
pub type HINSTANCE = HANDLE; pub type DWORD = c_ulong;
pub type HMODULE = HINSTANCE; pub type HANDLE = LPVOID;
pub type WCHAR = u16; pub type LPVOID = *mut c_void;
pub type LPCWSTR = *const WCHAR; pub type HINSTANCE = HANDLE;
pub type LPWSTR = *mut WCHAR; pub type HMODULE = HINSTANCE;
pub type WCHAR = u16;
pub const FACILITY_NT_BIT: DWORD = 0x1000_0000; pub type LPCWSTR = *const WCHAR;
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800; pub type LPWSTR = *mut WCHAR;
pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200; pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;
pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800;
extern "system" { pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000;
pub fn GetLastError() -> DWORD; pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200;
pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
pub fn FormatMessageW( extern "system" {
flags: DWORD, pub fn GetLastError() -> DWORD;
lpSrc: LPVOID, pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
msgId: DWORD, pub fn FormatMessageW(
langId: DWORD, flags: DWORD,
buf: LPWSTR, lpSrc: LPVOID,
nsize: DWORD, msgId: DWORD,
args: *const c_void, langId: DWORD,
) -> DWORD; buf: LPWSTR,
} nsize: DWORD,
} args: *const c_void,
) -> DWORD;
pub fn errno() -> i32 { }
unsafe { c::GetLastError() as i32 } }
}
macro_rules! last_ident {
/// Gets a detailed string description for the given error number. ($i:ident) => {
pub fn error_string(mut errnum: i32) -> String { stringify!($i)
// This value is calculated from the macro };
// MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) ($start:ident, $($cont:ident),+) => {
let langId = 0x0800 as c::DWORD; last_ident!($($cont),+)
};
let mut buf = [0 as c::WCHAR; 2048]; }
unsafe { macro_rules! os_call {
let mut module = ptr::null_mut(); ($($path:ident)::+ ($($args:expr),*), $success:expr) => {
let mut flags = 0; let result = unsafe{ $($path)::+ ($($args),+) };
if result != $success {
// NTSTATUS errors may be encoded as HRESULT, which may returned from let name = last_ident!($($path),+);
// GetLastError. For more information about Windows error codes, see let err_code = $crate::win::errno();
// `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx Err($crate::win::OsError{
if (errnum & c::FACILITY_NT_BIT as i32) != 0 { function: name,
// format according to https://support.microsoft.com/en-us/help/259693 error_code: err_code as u32,
const NTDLL_DLL: &[u16] = &[ message: $crate::win::error_string(err_code)
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, })?;
'L' as _, 0, }
]; };
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); }
if module != ptr::null_mut() { #[derive(Debug)]
errnum ^= c::FACILITY_NT_BIT as i32; pub struct OsError {
flags = c::FORMAT_MESSAGE_FROM_HMODULE; pub function: &'static str,
} pub error_code: u32,
} pub message: String,
}
let res = c::FormatMessageW(
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, impl fmt::Display for OsError {
module, fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
errnum as c::DWORD, write!(f, "{:?}", self)
langId, }
buf.as_mut_ptr(), }
buf.len() as c::DWORD,
ptr::null(), impl error::Error for OsError {
) as usize; fn source(&self) -> Option<&(dyn error::Error + 'static)> {
if res == 0 { None
// Sometimes FormatMessageW can fail e.g., system doesn't like langId, }
let fm_err = errno(); }
return format!(
"OS Error {} (FormatMessageW() returned error {})", pub fn errno() -> i32 {
errnum, fm_err unsafe { c::GetLastError() as i32 }
); }
}
/// Gets a detailed string description for the given error number.
match String::from_utf16(&buf[..res]) { pub fn error_string(mut errnum: i32) -> String {
Ok(mut msg) => { // This value is calculated from the macro
// Trim trailing CRLF inserted by FormatMessageW // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT)
let len = msg.trim_end().len(); let langId = 0x0800 as c::DWORD;
msg.truncate(len);
msg let mut buf = [0 as c::WCHAR; 2048];
}
Err(..) => format!( unsafe {
"OS Error {} (FormatMessageW() returned \ let mut module = ptr::null_mut();
invalid UTF-16)", let mut flags = 0;
errnum
), // NTSTATUS errors may be encoded as HRESULT, which may returned from
} // GetLastError. For more information about Windows error codes, see
} // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
} if (errnum & c::FACILITY_NT_BIT as i32) != 0 {
// format according to https://support.microsoft.com/en-us/help/259693
const NTDLL_DLL: &[u16] = &[
'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
'L' as _, 0,
];
module = c::GetModuleHandleW(NTDLL_DLL.as_ptr());
if module != ptr::null_mut() {
errnum ^= c::FACILITY_NT_BIT as i32;
flags = c::FORMAT_MESSAGE_FROM_HMODULE;
}
}
let res = c::FormatMessageW(
flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
module,
errnum as c::DWORD,
langId,
buf.as_mut_ptr(),
buf.len() as c::DWORD,
ptr::null(),
) as usize;
if res == 0 {
// Sometimes FormatMessageW can fail e.g., system doesn't like langId,
let fm_err = errno();
return format!(
"OS Error {} (FormatMessageW() returned error {})",
errnum, fm_err
);
}
match String::from_utf16(&buf[..res]) {
Ok(mut msg) => {
// Trim trailing CRLF inserted by FormatMessageW
let len = msg.trim_end().len();
msg.truncate(len);
msg
}
Err(..) => format!(
"OS Error {} (FormatMessageW() returned \
invalid UTF-16)",
errnum
),
}
}
}