Merge branch 'master' into nan

This commit is contained in:
Violet 2025-09-12 23:16:28 +00:00
commit 25bd3861da
22 changed files with 282 additions and 64 deletions

View file

@ -6,6 +6,7 @@ use std::io::{self, Write};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::ExitCode; use std::process::ExitCode;
use std::str; use std::str;
use std::time::Instant;
use std::{env, mem}; use std::{env, mem};
mod error; mod error;
@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> {
let arch: String = match opts.arch { let arch: String = match opts.arch {
Some(s) => s, Some(s) => s,
None => { None => (|| {
(|| { let runtime = hip::Runtime::load()?;
let runtime = hip::Runtime::load()?; runtime.init()?;
runtime.init()?; get_gpu_arch(&runtime)
get_gpu_arch(&runtime) })()
})() .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()),
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned())
/*
get_gpu_arch(&mut dev_props)
.map(String::from)
.unwrap_or(DEFAULT_ARCH.to_owned())
*/
}
}; };
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> {
write_to_file(bytes, &output_path).unwrap(); write_to_file(bytes, &output_path).unwrap();
}; };
let mut start = Instant::now();
comgr::compile_bitcode( comgr::compile_bitcode(
&comgr, &comgr,
&arch, &arch,
@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> {
Some(&comgr_hook), Some(&comgr_hook),
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
report_pass_time("compile_bitcode", &mut start);
Ok(()) Ok(())
} }
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> { fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
let mut start = Instant::now();
let module = ptx::to_llvm_module( let module = ptx::to_llvm_module(
ast, ast,
ptx::Attributes { ptx::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|pass| {
report_pass_time(pass, &mut start);
},
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
}) })
} }
fn report_pass_time(pass: &str, start: &mut Instant) {
let duration = start.elapsed();
println!("Pass {:?} took {:?}", pass, duration);
*start = Instant::now();
}
#[derive(Debug)] #[derive(Debug)]
struct LLVMArtifacts { struct LLVMArtifacts {
bitcode: Vec<u8>, bitcode: Vec<u8>,

View file

@ -313,7 +313,7 @@ fn join(
pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream {
let fn_ = parse_macro_input!(item as syn::ItemFn); let fn_ = parse_macro_input!(item as syn::ItemFn);
let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia"); let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia");
let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_amdgpu"); let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda");
let fn_name = fn_.sig.ident.clone(); let fn_name = fn_.sig.ident.clone();
quote! { quote! {
#[test] #[test]

7
ext/rocm_smi-sys/build.rs vendored Normal file
View file

@ -0,0 +1,7 @@
use std::env::VarError;
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=rocm_smi64");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
Ok(())
}

View file

@ -114,6 +114,13 @@ fn run_statement<'a, 'input>(
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
Statement::Instruction(ast::Instruction::Mov { data, arguments }) => {
let instruction = visitor.visit_mov(data, arguments);
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::PtrAccess(ptr_access) => { Statement::PtrAccess(ptr_access) => {
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
let statement = statement.visit_map(visitor)?; let statement = statement.visit_map(visitor)?;
@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}) })
} }
fn visit_mov(
&mut self,
data: ptx_parser::MovDetails,
mut arguments: ptx_parser::MovArgs<SpirvWord>,
) -> ast::Instruction<SpirvWord> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
name,
new_space,
old_space,
} => {
let generic_var = self
.resolver
.register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg)));
self.pre.push(ast::Instruction::Cvta {
data: ast::CvtaDetails {
state_space: *new_space,
direction: ast::CvtaDirection::ExplicitToGeneric,
},
arguments: ast::CvtaArgs {
dst: generic_var,
src: *name,
},
});
arguments.src = generic_var;
}
}
}
ast::Instruction::Mov { data, arguments }
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> { fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space { let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,

View file

@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool {
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Global | ast::StateSpace::Global
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared
| ast::StateSpace::ParamEntry => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(), | ast::StateSpace::ParamFunc => todo!(),
} }
} }
@ -180,13 +180,14 @@ fn default_implicit_conversion_space(
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), | ast::StateSpace::Shared
| ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
}, },
ast::Type::Scalar(ast::ScalarType::B32) ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr)) Ok(Some(ConversionKind::BitToPtr))
} }
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
ast::StateSpace::Global ast::StateSpace::Global
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared => true,
ast::StateSpace::Reg ast::StateSpace::Reg

View file

@ -704,7 +704,7 @@ impl<'a> MethodEmitContext<'a> {
}); });
Ok(()) Ok(())
} }
_ => todo!(), _ => return Err(error_todo()),
} }
} }
@ -2455,7 +2455,7 @@ impl<'a> MethodEmitContext<'a> {
(control >> 12) & 0b1111, (control >> 12) & 0b1111,
]; ];
if components.iter().any(|&c| c > 7) { if components.iter().any(|&c| c > 7) {
return Err(TranslateError::Todo("".to_string())); return Err(error_todo());
} }
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;

View file

@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(), ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe {
ast::ScalarType::S16x2 => todo!(), LLVMVectorType(LLVMInt16TypeInContext(context), 2)
},
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) },
} }
@ -180,14 +181,17 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space { match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), // This is dodgy, we try our best to convert all .param into either
// .param::entry or .local, but we can't always succeed.
// In those cases we convert .param into generic address space
ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), ast::StateSpace::ParamFunc => Err(error_todo()),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), ast::StateSpace::SharedCta => Err(error_todo()),
ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), ast::StateSpace::SharedCluster => Err(error_todo()),
} }
} }

View file

@ -60,31 +60,48 @@ pub struct Attributes {
pub fn to_llvm_module<'input>( pub fn to_llvm_module<'input>(
ast: ast::Module<'input>, ast: ast::Module<'input>,
attributes: Attributes, attributes: Attributes,
mut on_pass_end: impl FnMut(&str),
) -> Result<Module, TranslateError> { ) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
on_pass_end("normalize_identifiers2");
let directives = replace_known_functions::run(&mut flat_resolver, directives); let directives = replace_known_functions::run(&mut flat_resolver, directives);
on_pass_end("replace_known_functions");
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_predicates2");
let directives = resolve_function_pointers::run(directives)?; let directives = resolve_function_pointers::run(directives)?;
on_pass_end("resolve_function_pointers");
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
on_pass_end("fix_special_registers");
let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?;
on_pass_end("expand_operands");
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
on_pass_end("insert_post_saturation");
let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
on_pass_end("deparamize_functions");
let directives = let directives =
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?; replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions_fp_required");
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_basic_blocks");
let directives = remove_unreachable_basic_blocks::run(directives)?; let directives = remove_unreachable_basic_blocks::run(directives)?;
on_pass_end("remove_unreachable_basic_blocks");
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
on_pass_end("instruction_mode_to_global_mode");
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
on_pass_end("insert_explicit_load_store");
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
on_pass_end("insert_implicit_conversions2");
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions");
let directives = hoist_globals::run(directives)?; let directives = hoist_globals::run(directives)?;
on_pass_end("hoist_globals");
let context = llvm::Context::new(); let context = llvm::Context::new();
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
let attributes_ir = llvm::attributes::run(&context, attributes)?; let attributes_ir = llvm::attributes::run(&context, attributes)?;
on_pass_end("emit_llvm");
Ok(Module { Ok(Module {
llvm_ir, llvm_ir,
attributes_ir, attributes_ir,

View file

@ -0,0 +1,34 @@
define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 {
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i64, align 8, addrspace(5)
%"37" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"32"
"32": ; preds = %1
%"38" = load i64, ptr addrspace(4) %"33", align 8
store i64 %"38", ptr addrspace(5) %"35", align 8
%"39" = load i64, ptr addrspace(4) %"34", align 8
store i64 %"39", ptr addrspace(5) %"36", align 8
%"49" = ptrtoint ptr addrspace(4) %"33" to i64
%2 = inttoptr i64 %"49" to ptr addrspace(4)
%"40" = addrspacecast ptr addrspace(4) %2 to ptr
store ptr %"40", ptr addrspace(5) %"37", align 8
%"43" = load i64, ptr addrspace(5) %"37", align 8
%"50" = inttoptr i64 %"43" to ptr
%"42" = load i64, ptr %"50", align 8
store i64 %"42", ptr addrspace(5) %"37", align 8
%"45" = load i64, ptr addrspace(5) %"37", align 8
%"46" = load i64, ptr addrspace(5) %"35", align 8
%"51" = sub i64 %"45", %"46"
store i64 %"51", ptr addrspace(5) %"37", align 8
%"47" = load i64, ptr addrspace(5) %"36", align 8
%"48" = load i64, ptr addrspace(5) %"37", align 8
%"53" = inttoptr i64 %"47" to ptr
store i64 %"48", ptr %"53", align 8
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let attributes = pass::Attributes { let attributes = pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}; };
crate::to_llvm_module(ast, attributes)?; crate::to_llvm_module(ast, attributes, |_| {})?;
Ok(()) Ok(())
} }

View file

@ -345,6 +345,7 @@ test_ptx!(
[0x8e2da590u32, 0xedeaee14, 0x248a9f70], [0x8e2da590u32, 0xedeaee14, 0x248a9f70],
[613065134u32] [613065134u32]
); );
test_ptx!(param_is_addressable, [0xDEAD], [0u64]);
test_ptx!(assertfail); test_ptx!(assertfail);
// TODO: not yet supported // TODO: not yet supported
@ -562,6 +563,7 @@ fn test_hip_assert<
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let name = CString::new(name)?; let name = CString::new(name)?;
@ -582,6 +584,7 @@ fn test_llvm_assert(
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string();

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry param_is_addressable(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b64 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
mov.b64 temp, input;
ld.param.b64 temp, [temp];
sub.u64 temp, temp, in_addr;
st.u64 [out_addr], temp;
ret;
}

View file

@ -892,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>(
( (
vector_prefix, vector_prefix,
scalar_type, scalar_type,
opt((Token::DotPtr, Token::DotGlobal)), opt((Token::DotPtr, opt(Token::DotGlobal))),
opt(align.verify(|x| x.count_ones() == 1)), opt(align.verify(|x| x.count_ones() == 1)),
ident, ident,
), ),

View file

@ -42,6 +42,7 @@ fn main() {
ptx::Attributes { ptx::Attributes {
clock_rate: clock_rate as u32, clock_rate: clock_rate as u32,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let elf_binary = comgr::compile_bitcode( let elf_binary = comgr::compile_bitcode(

Binary file not shown.

View file

@ -494,6 +494,10 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult {
Ok(()) Ok(())
} }
pub(crate) fn primary_context_release_v2(hip_dev: hipDevice_t) -> CUresult {
primary_context_release(hip_dev)
}
pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult { pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult {
let (ctx, _) = get_primary_context(hip_dev)?; let (ctx, _) = get_primary_context(hip_dev)?;
ctx.with_state_mut(|state| state.reset())?; ctx.with_state_mut(|state| state.reset())?;

View file

@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
None => return CUresult::ERROR_INVALID_VALUE, None => return CUresult::ERROR_INVALID_VALUE,
}; };
device::primary_context_retain(pctx, hip_dev) let (_, cu_ctx) = device::get_primary_context(hip_dev)?;
*pctx = cu_ctx;
Ok(())
} }
unsafe extern "system" fn get_module_from_cubin_ext1( unsafe extern "system" fn get_module_from_cubin_ext1(
@ -362,8 +364,8 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
(0..device_count) (0..device_count)
.map(|dev| { .map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] }; let mut guid = unsafe { mem::zeroed() };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; device::get_uuid_v2(&mut guid, dev)?;
let mut pci_domain = 0; let mut pci_domain = 0;
device::get_attribute( device::get_attribute(
@ -387,7 +389,7 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
)?; )?;
Ok(::dark_api::DeviceHashinfo { Ok(::dark_api::DeviceHashinfo {
guid, guid: unsafe { mem::transmute(guid) },
pci_domain, pci_domain,
pci_bus, pci_bus,
pci_device, pci_device,
@ -525,8 +527,57 @@ pub(crate) unsafe fn launch_kernel_ex(
Ok(()) Ok(())
} }
pub(crate) unsafe fn get_error_string(
error: cuda_types::cuda::CUresult,
error_string: &mut *const ::core::ffi::c_char,
) -> CUresult {
*error_string = match error {
CUresult::SUCCESS => c"no error".as_ptr(),
CUresult::ERROR_INVALID_VALUE => c"invalid value".as_ptr(),
CUresult::ERROR_OUT_OF_MEMORY => c"out of memory".as_ptr(),
CUresult::ERROR_NOT_INITIALIZED => c"driver not initialized".as_ptr(),
CUresult::ERROR_DEINITIALIZED => c"driver deinitialized".as_ptr(),
CUresult::ERROR_NO_DEVICE => c"no CUDA-capable device is detected".as_ptr(),
CUresult::ERROR_INVALID_DEVICE => c"invalid device".as_ptr(),
CUresult::ERROR_INVALID_IMAGE => c"invalid kernel image".as_ptr(),
CUresult::ERROR_INVALID_CONTEXT => c"invalid context".as_ptr(),
CUresult::ERROR_CONTEXT_ALREADY_CURRENT => c"context already current".as_ptr(),
CUresult::ERROR_MAP_FAILED => c"map failed".as_ptr(),
CUresult::ERROR_UNMAP_FAILED => c"unmap failed".as_ptr(),
CUresult::ERROR_ARRAY_IS_MAPPED => c"array is mapped".as_ptr(),
CUresult::ERROR_ALREADY_MAPPED => c"already mapped".as_ptr(),
CUresult::ERROR_NO_BINARY_FOR_GPU => c"no binary for GPU".as_ptr(),
CUresult::ERROR_ALREADY_ACQUIRED => c"already acquired".as_ptr(),
CUresult::ERROR_NOT_MAPPED => c"not mapped".as_ptr(),
CUresult::ERROR_NOT_SUPPORTED => c"operation not supported".as_ptr(),
CUresult::ERROR_INVALID_SOURCE => c"invalid source".as_ptr(),
CUresult::ERROR_FILE_NOT_FOUND => c"file not found".as_ptr(),
CUresult::ERROR_INVALID_HANDLE => c"invalid handle".as_ptr(),
CUresult::ERROR_NOT_READY => c"not ready".as_ptr(),
CUresult::ERROR_ILLEGAL_ADDRESS => c"illegal address".as_ptr(),
CUresult::ERROR_LAUNCH_OUT_OF_RESOURCES => c"launch out of resources".as_ptr(),
CUresult::ERROR_LAUNCH_TIMEOUT => c"launch timeout".as_ptr(),
CUresult::ERROR_LAUNCH_INCOMPATIBLE_TEXTURING => c"launch incompatible texturing".as_ptr(),
CUresult::ERROR_PEER_ACCESS_ALREADY_ENABLED => c"peer access already enabled".as_ptr(),
CUresult::ERROR_PEER_ACCESS_NOT_ENABLED => c"peer access not enabled".as_ptr(),
CUresult::ERROR_PRIMARY_CONTEXT_ACTIVE => c"primary context active".as_ptr(),
CUresult::ERROR_CONTEXT_IS_DESTROYED => c"context is destroyed".as_ptr(),
CUresult::ERROR_ASSERT => c"device-side assert triggered".as_ptr(),
CUresult::ERROR_TOO_MANY_PEERS => c"too many peers".as_ptr(),
CUresult::ERROR_HOST_MEMORY_ALREADY_REGISTERED => {
c"host memory already registered".as_ptr()
}
CUresult::ERROR_HOST_MEMORY_NOT_REGISTERED => c"host memory not registered".as_ptr(),
CUresult::ERROR_UNKNOWN => c"unknown error".as_ptr(),
_ => c"error".as_ptr(),
};
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::i32;
use crate::r#impl::driver::AllocationInfo; use crate::r#impl::driver::AllocationInfo;
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
@ -571,4 +622,34 @@ mod tests {
} }
assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None); assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None);
} }
#[test_cuda]
fn primary_context_is_inactive_on_init(api: impl CudaApi) {
api.cuInit(0);
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
}
#[test_cuda]
unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) {
api.cuInit(0);
let mut table_ptr = std::ptr::null();
api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID);
let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr);
let mut dark_ctx = std::mem::zeroed();
cuda_rt_iface
.cudart_interface_fn2(&mut dark_ctx, 0)
.unwrap();
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
let mut primary_ctx = std::mem::zeroed();
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
assert_eq!(dark_ctx.0, primary_ctx.0);
}
} }

View file

@ -55,12 +55,11 @@ pub(crate) unsafe fn load_data(
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
_num_jit_options: ::core::ffi::c_uint, _num_jit_options: ::core::ffi::c_uint,
library_options: Option<&mut CUlibraryOption>, library_options: Option<&mut CUlibraryOption>,
library_option_values: Option<&mut *mut ::core::ffi::c_void>, _library_option_values: Option<&mut *mut ::core::ffi::c_void>,
num_library_options: ::core::ffi::c_uint, num_library_options: ::core::ffi::c_uint,
) -> CUresult { ) -> CUresult {
let global_state = driver::global_state()?; let global_state = driver::global_state()?;
let options = let options = LibraryOptions::load(library_options, num_library_options)?;
LibraryOptions::load(library_options, library_option_values, num_library_options)?;
let library = Library { let library = Library {
data: LibraryData::new(code as *mut c_void, options.preserve_binary)?, data: LibraryData::new(code as *mut c_void, options.preserve_binary)?,
modules: vec![OnceLock::new(); global_state.devices.len()], modules: vec![OnceLock::new(); global_state.devices.len()],
@ -76,7 +75,6 @@ struct LibraryOptions {
impl LibraryOptions { impl LibraryOptions {
unsafe fn load( unsafe fn load(
library_options: Option<&mut CUlibraryOption>, library_options: Option<&mut CUlibraryOption>,
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
num_library_options: ::core::ffi::c_uint, num_library_options: ::core::ffi::c_uint,
) -> Result<Self, CUerror> { ) -> Result<Self, CUerror> {
if num_library_options == 0 { if num_library_options == 0 {
@ -84,28 +82,17 @@ impl LibraryOptions {
preserve_binary: false, preserve_binary: false,
}); });
} }
let (library_options, library_option_values) = let library_options = match library_options {
match (library_options, library_option_values) { Some(library_options) => {
(Some(library_options), Some(library_option_values)) => { std::slice::from_raw_parts(library_options, num_library_options as usize)
let library_options = }
std::slice::from_raw_parts(library_options, num_library_options as usize); _ => return Err(CUerror::INVALID_VALUE),
let library_option_values = std::slice::from_raw_parts( };
library_option_values,
num_library_options as usize,
);
(library_options, library_option_values)
}
_ => return Err(CUerror::INVALID_VALUE),
};
let mut preserve_binary = false; let mut preserve_binary = false;
for (option, value) in library_options for option in library_options.iter().copied() {
.iter()
.copied()
.zip(library_option_values.iter())
{
match option { match option {
CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => { CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => {
preserve_binary = *(value.cast::<bool>()); preserve_binary = true;
} }
_ => return Err(CUerror::NOT_SUPPORTED), _ => return Err(CUerror::NOT_SUPPORTED),
} }
@ -156,10 +143,7 @@ mod tests {
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts}; use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts};
use std::{ use std::{ffi::CStr, mem, ptr};
ffi::{c_void, CStr},
mem, ptr,
};
#[test_cuda] #[test_cuda]
unsafe fn library_loads_without_context(api: impl CudaApi) { unsafe fn library_loads_without_context(api: impl CudaApi) {
@ -191,7 +175,7 @@ mod tests {
ptr::null_mut(), ptr::null_mut(),
0, 0,
[CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(), [CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(),
[(&true as *const bool) as *mut c_void].as_mut_ptr(), ptr::null_mut(),
1, 1,
); );
assert_eq!( assert_eq!(

View file

@ -153,3 +153,11 @@ pub(crate) unsafe fn set_d8_async(
) -> hipError_t { ) -> hipError_t {
hipMemsetD8Async(dst_device, uc, n, stream) hipMemsetD8Async(dst_device, uc, n, stream)
} }
pub(crate) fn get_allocation_granularity(
_granularity: &mut usize,
_property: &cuda_types::cuda::CUmemAllocationProp,
_option: cuda_types::cuda::CUmemAllocationGranularity_flags,
) -> CUresult {
CUresult::ERROR_NOT_SUPPORTED
}

View file

@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache(
} else { } else {
ptx_parser::parse_module_unchecked(text) ptx_parser::parse_module_unchecked(text)
}; };
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
comgr, comgr,
gcn_arch, gcn_arch,

View file

@ -93,6 +93,7 @@ cuda_macros::cuda_function_declarations!(
cuDeviceGetUuid_v2, cuDeviceGetUuid_v2,
cuDevicePrimaryCtxGetState, cuDevicePrimaryCtxGetState,
cuDevicePrimaryCtxRelease, cuDevicePrimaryCtxRelease,
cuDevicePrimaryCtxRelease_v2,
cuDevicePrimaryCtxReset, cuDevicePrimaryCtxReset,
cuDevicePrimaryCtxRetain, cuDevicePrimaryCtxRetain,
cuDeviceTotalMem_v2, cuDeviceTotalMem_v2,
@ -104,6 +105,7 @@ cuda_macros::cuda_function_declarations!(
cuEventSynchronize, cuEventSynchronize,
cuFuncGetAttribute, cuFuncGetAttribute,
cuFuncSetAttribute, cuFuncSetAttribute,
cuGetErrorString,
cuGetExportTable, cuGetExportTable,
cuGetProcAddress, cuGetProcAddress,
cuGetProcAddress_v2, cuGetProcAddress_v2,
@ -126,6 +128,7 @@ cuda_macros::cuda_function_declarations!(
cuMemFreeHost, cuMemFreeHost,
cuMemFree_v2, cuMemFree_v2,
cuMemGetAddressRange_v2, cuMemGetAddressRange_v2,
cuMemGetAllocationGranularity,
cuMemGetInfo_v2, cuMemGetInfo_v2,
cuMemHostAlloc, cuMemHostAlloc,
cuMemRetainAllocationHandle, cuMemRetainAllocationHandle,

View file

@ -165,7 +165,10 @@ from_cuda_nop!(
nvmlDevice_t, nvmlDevice_t,
nvmlFieldValue_t, nvmlFieldValue_t,
nvmlGpuFabricInfo_t, nvmlGpuFabricInfo_t,
cublasLtHandle_t cublasLtHandle_t,
CUmemAllocationGranularity_flags,
CUmemAllocationProp,
CUresult
); );
from_cuda_transmute!( from_cuda_transmute!(
CUuuid => hipUUID, CUuuid => hipUUID,