diff --git a/src/bin.rs b/src/bin.rs index d0e9230..875f596 100644 --- a/src/bin.rs +++ b/src/bin.rs @@ -1,76 +1,89 @@ -extern crate clap; -extern crate detours_sys; - -use std::error::Error; -use std::ffi::OsStr; -use std::mem; -use std::os::windows::ffi::OsStrExt; -use std::ptr; - -use clap::{App, AppSettings, Arg}; - -mod win_err; - -fn main() -> Result<(), Box> { - let matches = App::new("notCUDA injector") - .setting(AppSettings::TrailingVarArg) - .arg( - Arg::with_name("EXE") - .help("Path to the executable to be injected with notCUDA") - .required(true), - ) - .arg( - Arg::with_name("ARGS") - .multiple(true) - .help("Arguments that will be passed to "), - ) - .get_matches(); - let exe = matches.value_of_os("EXE").unwrap(); - let args: Vec<&OsStr> = matches - .values_of_os("ARGS") - .map(|x| x.collect()) - .unwrap_or_else(|| Vec::new()); - let mut cmd_line = Vec::::with_capacity(exe.len() + 2); - cmd_line.push('\"' as u16); - copy_to(exe, &mut cmd_line); - cmd_line.push('\"' as u16); - cmd_line.push(' ' as u16); - args.split_last().map(|(last_arg, args)| { - for arg in args { - cmd_line.reserve(arg.len()); - copy_to(arg, &mut cmd_line); - cmd_line.push(' ' as u16); - } - copy_to(last_arg, &mut cmd_line); - }); - - cmd_line.push(0); - let mut startup_info = unsafe { mem::zeroed::() }; - let mut proc_info = unsafe { mem::zeroed::() }; - let process_success = unsafe { - detours_sys::DetourCreateProcessWithDllExW( - ptr::null(), - cmd_line.as_mut_ptr(), - ptr::null_mut(), - ptr::null_mut(), - 0, - 0x10, - ptr::null_mut(), - ptr::null(), - &mut startup_info as *mut _, - &mut proc_info as *mut _, - "nvcuda_redirect.dll".as_ptr() as *const i8, - Option::None, - ) - }; - if process_success == 0 { - return Err(win_err::error_string(win_err::errno()))?; - } - Ok(()) -} - -fn copy_to(from: &OsStr, to: &mut Vec) { - for x in from.encode_wide() { - to.push(x); - } -} +extern crate clap; +extern crate detours_sys; + +use std::error::Error; +use std::ffi::OsStr; +use std::mem; +use std::os::windows::ffi::OsStrExt; +use std::ptr; + +use clap::{App, AppSettings, Arg}; + +#[macro_use] +mod win; + +fn main() -> Result<(), Box> { + let matches = App::new("notCUDA injector") + .setting(AppSettings::TrailingVarArg) + .arg( + Arg::with_name("EXE") + .help("Path to the executable to be injected with notCUDA") + .required(true), + ) + .arg( + Arg::with_name("ARGS") + .multiple(true) + .help("Arguments that will be passed to "), + ) + .get_matches(); + let exe = matches.value_of_os("EXE").unwrap(); + let args: Vec<&OsStr> = matches + .values_of_os("ARGS") + .map(|x| x.collect()) + .unwrap_or_else(|| Vec::new()); + let mut cmd_line = Vec::::with_capacity(exe.len() + 2); + cmd_line.push('\"' as u16); + copy_to(exe, &mut cmd_line); + cmd_line.push('\"' as u16); + cmd_line.push(' ' as u16); + args.split_last().map(|(last_arg, args)| { + for arg in args { + cmd_line.reserve(arg.len()); + copy_to(arg, &mut cmd_line); + cmd_line.push(' ' as u16); + } + copy_to(last_arg, &mut cmd_line); + }); + + cmd_line.push(0); + let mut startup_info = unsafe { mem::zeroed::() }; + let mut proc_info = unsafe { mem::zeroed::() }; + os_call!( + detours_sys::DetourCreateProcessWithDllExW( + ptr::null(), + cmd_line.as_mut_ptr(), + ptr::null_mut(), + ptr::null_mut(), + 0, + 0x10, + ptr::null_mut(), + ptr::null(), + &mut startup_info as *mut _, + &mut proc_info as *mut _, + "nvcuda_redirect.dll".as_ptr() as *const i8, + Option::None + ), + 0 + ); + Ok(()) + /* + + cmd_line.as_mut_ptr(), + ptr::null_mut(), + ptr::null_mut(), + 0, + 0x10, + ptr::null_mut(), + ptr::null(), + &mut startup_info as *mut _, + &mut proc_info as *mut _, + "nvcuda_redirect.dll".as_ptr() as *const i8, + Option::None, + */ +} + +fn copy_to(from: &OsStr, to: &mut Vec) { + for x in from.encode_wide() { + to.push(x); + } +} diff --git a/src/win_err.rs b/src/win.rs similarity index 72% rename from src/win_err.rs rename to src/win.rs index f3a675b..3055202 100644 --- a/src/win_err.rs +++ b/src/win.rs @@ -1,103 +1,149 @@ -#![allow(non_snake_case)] - -use std::ptr; - -mod c { - use std::ffi::c_void; - use std::os::raw::{c_ulong}; - - pub type DWORD = c_ulong; - pub type HANDLE = LPVOID; - pub type LPVOID = *mut c_void; - pub type HINSTANCE = HANDLE; - pub type HMODULE = HINSTANCE; - pub type WCHAR = u16; - pub type LPCWSTR = *const WCHAR; - pub type LPWSTR = *mut WCHAR; - - pub const FACILITY_NT_BIT: DWORD = 0x1000_0000; - pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800; - pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000; - pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200; - - extern "system" { - pub fn GetLastError() -> DWORD; - pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE; - pub fn FormatMessageW( - flags: DWORD, - lpSrc: LPVOID, - msgId: DWORD, - langId: DWORD, - buf: LPWSTR, - nsize: DWORD, - args: *const c_void, - ) -> DWORD; - } -} - -pub fn errno() -> i32 { - unsafe { c::GetLastError() as i32 } -} - -/// Gets a detailed string description for the given error number. -pub fn error_string(mut errnum: i32) -> String { - // This value is calculated from the macro - // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) - let langId = 0x0800 as c::DWORD; - - let mut buf = [0 as c::WCHAR; 2048]; - - unsafe { - let mut module = ptr::null_mut(); - let mut flags = 0; - - // NTSTATUS errors may be encoded as HRESULT, which may returned from - // GetLastError. For more information about Windows error codes, see - // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx - if (errnum & c::FACILITY_NT_BIT as i32) != 0 { - // format according to https://support.microsoft.com/en-us/help/259693 - const NTDLL_DLL: &[u16] = &[ - 'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, - 'L' as _, 0, - ]; - module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); - - if module != ptr::null_mut() { - errnum ^= c::FACILITY_NT_BIT as i32; - flags = c::FORMAT_MESSAGE_FROM_HMODULE; - } - } - - let res = c::FormatMessageW( - flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, - module, - errnum as c::DWORD, - langId, - buf.as_mut_ptr(), - buf.len() as c::DWORD, - ptr::null(), - ) as usize; - if res == 0 { - // Sometimes FormatMessageW can fail e.g., system doesn't like langId, - let fm_err = errno(); - return format!( - "OS Error {} (FormatMessageW() returned error {})", - errnum, fm_err - ); - } - - match String::from_utf16(&buf[..res]) { - Ok(mut msg) => { - // Trim trailing CRLF inserted by FormatMessageW - let len = msg.trim_end().len(); - msg.truncate(len); - msg - } - Err(..) => format!( - "OS Error {} (FormatMessageW() returned \ - invalid UTF-16)", - errnum - ), - } - } -} +#![allow(non_snake_case)] + +use std::error; +use std::error::Error; +use std::fmt; +use std::ptr; + +mod c { + use std::ffi::c_void; + use std::os::raw::c_ulong; + + pub type DWORD = c_ulong; + pub type HANDLE = LPVOID; + pub type LPVOID = *mut c_void; + pub type HINSTANCE = HANDLE; + pub type HMODULE = HINSTANCE; + pub type WCHAR = u16; + pub type LPCWSTR = *const WCHAR; + pub type LPWSTR = *mut WCHAR; + + pub const FACILITY_NT_BIT: DWORD = 0x1000_0000; + pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x00000800; + pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x00001000; + pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x00000200; + + extern "system" { + pub fn GetLastError() -> DWORD; + pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE; + pub fn FormatMessageW( + flags: DWORD, + lpSrc: LPVOID, + msgId: DWORD, + langId: DWORD, + buf: LPWSTR, + nsize: DWORD, + args: *const c_void, + ) -> DWORD; + } +} + +macro_rules! last_ident { + ($i:ident) => { + stringify!($i) + }; + ($start:ident, $($cont:ident),+) => { + last_ident!($($cont),+) + }; +} + +macro_rules! os_call { + ($($path:ident)::+ ($($args:expr),*), $success:expr) => { + let result = unsafe{ $($path)::+ ($($args),+) }; + if result != $success { + let name = last_ident!($($path),+); + let err_code = $crate::win::errno(); + Err($crate::win::OsError{ + function: name, + error_code: err_code as u32, + message: $crate::win::error_string(err_code) + })?; + } + }; +} + +#[derive(Debug)] +pub struct OsError { + pub function: &'static str, + pub error_code: u32, + pub message: String, +} + +impl fmt::Display for OsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl error::Error for OsError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + None + } +} + +pub fn errno() -> i32 { + unsafe { c::GetLastError() as i32 } +} + +/// Gets a detailed string description for the given error number. +pub fn error_string(mut errnum: i32) -> String { + // This value is calculated from the macro + // MAKELANGID(LANG_SYSTEM_DEFAULT, SUBLANG_SYS_DEFAULT) + let langId = 0x0800 as c::DWORD; + + let mut buf = [0 as c::WCHAR; 2048]; + + unsafe { + let mut module = ptr::null_mut(); + let mut flags = 0; + + // NTSTATUS errors may be encoded as HRESULT, which may returned from + // GetLastError. For more information about Windows error codes, see + // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx + if (errnum & c::FACILITY_NT_BIT as i32) != 0 { + // format according to https://support.microsoft.com/en-us/help/259693 + const NTDLL_DLL: &[u16] = &[ + 'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _, + 'L' as _, 0, + ]; + module = c::GetModuleHandleW(NTDLL_DLL.as_ptr()); + + if module != ptr::null_mut() { + errnum ^= c::FACILITY_NT_BIT as i32; + flags = c::FORMAT_MESSAGE_FROM_HMODULE; + } + } + + let res = c::FormatMessageW( + flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS, + module, + errnum as c::DWORD, + langId, + buf.as_mut_ptr(), + buf.len() as c::DWORD, + ptr::null(), + ) as usize; + if res == 0 { + // Sometimes FormatMessageW can fail e.g., system doesn't like langId, + let fm_err = errno(); + return format!( + "OS Error {} (FormatMessageW() returned error {})", + errnum, fm_err + ); + } + + match String::from_utf16(&buf[..res]) { + Ok(mut msg) => { + // Trim trailing CRLF inserted by FormatMessageW + let len = msg.trim_end().len(); + msg.truncate(len); + msg + } + Err(..) => format!( + "OS Error {} (FormatMessageW() returned \ + invalid UTF-16)", + errnum + ), + } + } +}