Random fixes (#504)

This is a collection of assorted changes that came out of the workload I'm currently working on. The most important one is better support for `.param`: PTX uses the `.param` state space both for some local variables and for kernel arguments. This is a problem for us because those live in different address spaces on AMDGPU. So far we have made an effort to convert `.param` to the local or const state space whenever possible; this commit additionally handles patterns that are impossible to track precisely by converting them to the generic address space.
This commit is contained in:
Andrzej Janik 2025-09-12 22:52:33 +02:00 committed by GitHub
commit 6c811a55d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 282 additions and 64 deletions

View file

@ -6,6 +6,7 @@ use std::io::{self, Write};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::ExitCode; use std::process::ExitCode;
use std::str; use std::str;
use std::time::Instant;
use std::{env, mem}; use std::{env, mem};
mod error; mod error;
@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> {
let arch: String = match opts.arch { let arch: String = match opts.arch {
Some(s) => s, Some(s) => s,
None => { None => (|| {
(|| {
let runtime = hip::Runtime::load()?; let runtime = hip::Runtime::load()?;
runtime.init()?; runtime.init()?;
get_gpu_arch(&runtime) get_gpu_arch(&runtime)
})() })()
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned()) .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()),
/*
get_gpu_arch(&mut dev_props)
.map(String::from)
.unwrap_or(DEFAULT_ARCH.to_owned())
*/
}
}; };
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> {
write_to_file(bytes, &output_path).unwrap(); write_to_file(bytes, &output_path).unwrap();
}; };
let mut start = Instant::now();
comgr::compile_bitcode( comgr::compile_bitcode(
&comgr, &comgr,
&arch, &arch,
@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> {
Some(&comgr_hook), Some(&comgr_hook),
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
report_pass_time("compile_bitcode", &mut start);
Ok(()) Ok(())
} }
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> { fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
let mut start = Instant::now();
let module = ptx::to_llvm_module( let module = ptx::to_llvm_module(
ast, ast,
ptx::Attributes { ptx::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|pass| {
report_pass_time(pass, &mut start);
},
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
}) })
} }
/// Prints how long the pass named `pass` took, measured from `*start`, then
/// resets `*start` to the current instant so the next call times only the
/// work done after this one returns.
fn report_pass_time(pass: &str, start: &mut Instant) {
    // The elapsed time is read before printing, so the cost of the print
    // itself is not attributed to the following pass.
    println!("Pass {:?} took {:?}", pass, start.elapsed());
    *start = Instant::now();
}
#[derive(Debug)] #[derive(Debug)]
struct LLVMArtifacts { struct LLVMArtifacts {
bitcode: Vec<u8>, bitcode: Vec<u8>,

View file

@ -313,7 +313,7 @@ fn join(
pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream {
let fn_ = parse_macro_input!(item as syn::ItemFn); let fn_ = parse_macro_input!(item as syn::ItemFn);
let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia"); let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia");
let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_amdgpu"); let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda");
let fn_name = fn_.sig.ident.clone(); let fn_name = fn_.sig.ident.clone();
quote! { quote! {
#[test] #[test]

7
ext/rocm_smi-sys/build.rs vendored Normal file
View file

@ -0,0 +1,7 @@
use std::env::VarError;
/// Build script entry point: prints the Cargo directives that link the crate
/// dynamically against `rocm_smi64` and add `/opt/rocm/lib/` to the native
/// library search path. Always returns `Ok(())`.
fn main() -> Result<(), VarError> {
    // Emitted in this order: first the library, then where to look for it.
    let directives = [
        "cargo:rustc-link-lib=dylib=rocm_smi64",
        "cargo:rustc-link-search=native=/opt/rocm/lib/",
    ];
    for directive in directives {
        println!("{}", directive);
    }
    Ok(())
}

View file

@ -114,6 +114,13 @@ fn run_statement<'a, 'input>(
result.push(Statement::Instruction(instruction)); result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction));
} }
Statement::Instruction(ast::Instruction::Mov { data, arguments }) => {
let instruction = visitor.visit_mov(data, arguments);
let instruction = ast::visit_map(instruction, visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::PtrAccess(ptr_access) => { Statement::PtrAccess(ptr_access) => {
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
let statement = statement.visit_map(visitor)?; let statement = statement.visit_map(visitor)?;
@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}) })
} }
/// Rewrites a `mov` whose source operand is a variable recorded in
/// `self.variables` with a `RemapAction::LDStSpaceChange`.
///
/// In that case a fresh unnamed register of the `mov`'s type is allocated, a
/// `cvta` (explicit state space -> generic) from the remapped variable into
/// that register is queued in `self.pre`, and the `mov` is redirected to read
/// from the new register. For `RemapAction::PreLdPostSt` sources, or sources
/// with no recorded remap, the instruction is returned unchanged.
fn visit_mov(
    &mut self,
    data: ptx_parser::MovDetails,
    mut arguments: ptx_parser::MovArgs<SpirvWord>,
) -> ast::Instruction<SpirvWord> {
    // Decide first, copying out the two values needed, so the lookup borrow
    // has ended before the resolver and `pre` are touched.
    let space_change = match self.variables.get(&arguments.src) {
        Some(RemapAction::LDStSpaceChange {
            name, new_space, ..
        }) => Some((*name, *new_space)),
        Some(RemapAction::PreLdPostSt { .. }) | None => None,
    };
    if let Some((remapped_name, remapped_space)) = space_change {
        let generic_var = self
            .resolver
            .register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg)));
        let to_generic = ast::Instruction::Cvta {
            data: ast::CvtaDetails {
                state_space: remapped_space,
                direction: ast::CvtaDirection::ExplicitToGeneric,
            },
            arguments: ast::CvtaArgs {
                dst: generic_var,
                src: remapped_name,
            },
        };
        self.pre.push(to_generic);
        arguments.src = generic_var;
    }
    ast::Instruction::Mov { data, arguments }
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> { fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space { let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,

View file

@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool {
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Global | ast::StateSpace::Global
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared
| ast::StateSpace::ParamEntry => true,
ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(), | ast::StateSpace::ParamFunc => todo!(),
} }
} }
@ -180,13 +180,14 @@ fn default_implicit_conversion_space(
| ast::StateSpace::Generic | ast::StateSpace::Generic
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), | ast::StateSpace::Shared
| ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
}, },
ast::Type::Scalar(ast::ScalarType::B32) ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr)) Ok(Some(ConversionKind::BitToPtr))
} }
_ => Err(error_mismatched_type()), _ => Err(error_mismatched_type()),
@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
ast::StateSpace::Global ast::StateSpace::Global
| ast::StateSpace::Const | ast::StateSpace::Const
| ast::StateSpace::Local | ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta | ast::StateSpace::SharedCta
| ast::StateSpace::SharedCluster | ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true, | ast::StateSpace::Shared => true,
ast::StateSpace::Reg ast::StateSpace::Reg

View file

@ -703,7 +703,7 @@ impl<'a> MethodEmitContext<'a> {
}); });
Ok(()) Ok(())
} }
_ => todo!(), _ => return Err(error_todo()),
} }
} }
@ -2406,7 +2406,7 @@ impl<'a> MethodEmitContext<'a> {
(control >> 12) & 0b1111, (control >> 12) & 0b1111,
]; ];
if components.iter().any(|&c| c > 7) { if components.iter().any(|&c| c > 7) {
return Err(TranslateError::Todo("".to_string())); return Err(error_todo());
} }
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;

View file

@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
ast::ScalarType::U16x2 => todo!(), ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe {
ast::ScalarType::S16x2 => todo!(), LLVMVectorType(LLVMInt16TypeInContext(context), 2)
},
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) }, ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) }, ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) },
} }
@ -180,14 +181,17 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space { match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())), // This is dodgy, we try our best to convert all .param into either
// .param::entry or .local, but we can't always succeed.
// In those cases we convert .param into generic address space
ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())), ast::StateSpace::ParamFunc => Err(error_todo()),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())), ast::StateSpace::SharedCta => Err(error_todo()),
ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())), ast::StateSpace::SharedCluster => Err(error_todo()),
} }
} }

View file

@ -60,31 +60,48 @@ pub struct Attributes {
pub fn to_llvm_module<'input>( pub fn to_llvm_module<'input>(
ast: ast::Module<'input>, ast: ast::Module<'input>,
attributes: Attributes, attributes: Attributes,
mut on_pass_end: impl FnMut(&str),
) -> Result<Module, TranslateError> { ) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
on_pass_end("normalize_identifiers2");
let directives = replace_known_functions::run(&mut flat_resolver, directives); let directives = replace_known_functions::run(&mut flat_resolver, directives);
on_pass_end("replace_known_functions");
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_predicates2");
let directives = resolve_function_pointers::run(directives)?; let directives = resolve_function_pointers::run(directives)?;
on_pass_end("resolve_function_pointers");
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
on_pass_end("fix_special_registers");
let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?;
on_pass_end("expand_operands");
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
on_pass_end("insert_post_saturation");
let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
on_pass_end("deparamize_functions");
let directives = let directives =
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?; replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions_fp_required");
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_basic_blocks");
let directives = remove_unreachable_basic_blocks::run(directives)?; let directives = remove_unreachable_basic_blocks::run(directives)?;
on_pass_end("remove_unreachable_basic_blocks");
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
on_pass_end("instruction_mode_to_global_mode");
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
on_pass_end("insert_explicit_load_store");
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
on_pass_end("insert_implicit_conversions2");
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions");
let directives = hoist_globals::run(directives)?; let directives = hoist_globals::run(directives)?;
on_pass_end("hoist_globals");
let context = llvm::Context::new(); let context = llvm::Context::new();
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
let attributes_ir = llvm::attributes::run(&context, attributes)?; let attributes_ir = llvm::attributes::run(&context, attributes)?;
on_pass_end("emit_llvm");
Ok(Module { Ok(Module {
llvm_ir, llvm_ir,
attributes_ir, attributes_ir,

View file

@ -0,0 +1,34 @@
define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 {
%"35" = alloca i64, align 8, addrspace(5)
%"36" = alloca i64, align 8, addrspace(5)
%"37" = alloca i64, align 8, addrspace(5)
br label %1
1: ; preds = %0
br label %"32"
"32": ; preds = %1
%"38" = load i64, ptr addrspace(4) %"33", align 8
store i64 %"38", ptr addrspace(5) %"35", align 8
%"39" = load i64, ptr addrspace(4) %"34", align 8
store i64 %"39", ptr addrspace(5) %"36", align 8
%"49" = ptrtoint ptr addrspace(4) %"33" to i64
%2 = inttoptr i64 %"49" to ptr addrspace(4)
%"40" = addrspacecast ptr addrspace(4) %2 to ptr
store ptr %"40", ptr addrspace(5) %"37", align 8
%"43" = load i64, ptr addrspace(5) %"37", align 8
%"50" = inttoptr i64 %"43" to ptr
%"42" = load i64, ptr %"50", align 8
store i64 %"42", ptr addrspace(5) %"37", align 8
%"45" = load i64, ptr addrspace(5) %"37", align 8
%"46" = load i64, ptr addrspace(5) %"35", align 8
%"51" = sub i64 %"45", %"46"
store i64 %"51", ptr addrspace(5) %"37", align 8
%"47" = load i64, ptr addrspace(5) %"36", align 8
%"48" = load i64, ptr addrspace(5) %"37", align 8
%"53" = inttoptr i64 %"47" to ptr
store i64 %"48", ptr %"53", align 8
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let attributes = pass::Attributes { let attributes = pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}; };
crate::to_llvm_module(ast, attributes)?; crate::to_llvm_module(ast, attributes, |_| {})?;
Ok(()) Ok(())
} }

View file

@ -339,6 +339,7 @@ test_ptx!(
[0x8e2da590u32, 0xedeaee14, 0x248a9f70], [0x8e2da590u32, 0xedeaee14, 0x248a9f70],
[613065134u32] [613065134u32]
); );
test_ptx!(param_is_addressable, [0xDEAD], [0u64]);
test_ptx!(assertfail); test_ptx!(assertfail);
// TODO: not yet supported // TODO: not yet supported
@ -556,6 +557,7 @@ fn test_hip_assert<
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let name = CString::new(name)?; let name = CString::new(name)?;
@ -576,6 +578,7 @@ fn test_llvm_assert(
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string();

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel: takes the address of the kernel parameter `input` with `mov`
// and then reads the parameter back through that address with `ld.param`.
.visible .entry param_is_addressable(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b64 temp;
// Read both kernel arguments directly by parameter name.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// `temp` receives the address of the `input` parameter itself, not its value.
mov.b64 temp, input;
// Load the parameter again, this time through the address held in a register.
ld.param.b64 temp, [temp];
// Both loads read the same parameter, so the difference written out is
// expected to be zero.
sub.u64 temp, temp, in_addr;
st.u64 [out_addr], temp;
ret;
}

View file

@ -892,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>(
( (
vector_prefix, vector_prefix,
scalar_type, scalar_type,
opt((Token::DotPtr, Token::DotGlobal)), opt((Token::DotPtr, opt(Token::DotGlobal))),
opt(align.verify(|x| x.count_ones() == 1)), opt(align.verify(|x| x.count_ones() == 1)),
ident, ident,
), ),

View file

@ -42,6 +42,7 @@ fn main() {
ptx::Attributes { ptx::Attributes {
clock_rate: clock_rate as u32, clock_rate: clock_rate as u32,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let elf_binary = comgr::compile_bitcode( let elf_binary = comgr::compile_bitcode(

Binary file not shown.

View file

@ -494,6 +494,10 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult {
Ok(()) Ok(())
} }
/// `_v2` entry point for releasing a device's primary context.
///
/// Forwards `hip_dev` unchanged to [`primary_context_release`] and returns
/// its result; there is no behavior specific to this variant.
pub(crate) fn primary_context_release_v2(hip_dev: hipDevice_t) -> CUresult {
primary_context_release(hip_dev)
}
pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult { pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult {
let (ctx, _) = get_primary_context(hip_dev)?; let (ctx, _) = get_primary_context(hip_dev)?;
ctx.with_state_mut(|state| state.reset())?; ctx.with_state_mut(|state| state.reset())?;

View file

@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
None => return CUresult::ERROR_INVALID_VALUE, None => return CUresult::ERROR_INVALID_VALUE,
}; };
device::primary_context_retain(pctx, hip_dev) let (_, cu_ctx) = device::get_primary_context(hip_dev)?;
*pctx = cu_ctx;
Ok(())
} }
unsafe extern "system" fn get_module_from_cubin_ext1( unsafe extern "system" fn get_module_from_cubin_ext1(
@ -362,8 +364,8 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
(0..device_count) (0..device_count)
.map(|dev| { .map(|dev| {
let mut guid = CUuuid_st { bytes: [0; 16] }; let mut guid = unsafe { mem::zeroed() };
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; device::get_uuid_v2(&mut guid, dev)?;
let mut pci_domain = 0; let mut pci_domain = 0;
device::get_attribute( device::get_attribute(
@ -387,7 +389,7 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
)?; )?;
Ok(::dark_api::DeviceHashinfo { Ok(::dark_api::DeviceHashinfo {
guid, guid: unsafe { mem::transmute(guid) },
pci_domain, pci_domain,
pci_bus, pci_bus,
pci_device, pci_device,
@ -525,8 +527,57 @@ pub(crate) unsafe fn launch_kernel_ex(
Ok(()) Ok(())
} }
pub(crate) unsafe fn get_error_string(
error: cuda_types::cuda::CUresult,
error_string: &mut *const ::core::ffi::c_char,
) -> CUresult {
*error_string = match error {
CUresult::SUCCESS => c"no error".as_ptr(),
CUresult::ERROR_INVALID_VALUE => c"invalid value".as_ptr(),
CUresult::ERROR_OUT_OF_MEMORY => c"out of memory".as_ptr(),
CUresult::ERROR_NOT_INITIALIZED => c"driver not initialized".as_ptr(),
CUresult::ERROR_DEINITIALIZED => c"driver deinitialized".as_ptr(),
CUresult::ERROR_NO_DEVICE => c"no CUDA-capable device is detected".as_ptr(),
CUresult::ERROR_INVALID_DEVICE => c"invalid device".as_ptr(),
CUresult::ERROR_INVALID_IMAGE => c"invalid kernel image".as_ptr(),
CUresult::ERROR_INVALID_CONTEXT => c"invalid context".as_ptr(),
CUresult::ERROR_CONTEXT_ALREADY_CURRENT => c"context already current".as_ptr(),
CUresult::ERROR_MAP_FAILED => c"map failed".as_ptr(),
CUresult::ERROR_UNMAP_FAILED => c"unmap failed".as_ptr(),
CUresult::ERROR_ARRAY_IS_MAPPED => c"array is mapped".as_ptr(),
CUresult::ERROR_ALREADY_MAPPED => c"already mapped".as_ptr(),
CUresult::ERROR_NO_BINARY_FOR_GPU => c"no binary for GPU".as_ptr(),
CUresult::ERROR_ALREADY_ACQUIRED => c"already acquired".as_ptr(),
CUresult::ERROR_NOT_MAPPED => c"not mapped".as_ptr(),
CUresult::ERROR_NOT_SUPPORTED => c"operation not supported".as_ptr(),
CUresult::ERROR_INVALID_SOURCE => c"invalid source".as_ptr(),
CUresult::ERROR_FILE_NOT_FOUND => c"file not found".as_ptr(),
CUresult::ERROR_INVALID_HANDLE => c"invalid handle".as_ptr(),
CUresult::ERROR_NOT_READY => c"not ready".as_ptr(),
CUresult::ERROR_ILLEGAL_ADDRESS => c"illegal address".as_ptr(),
CUresult::ERROR_LAUNCH_OUT_OF_RESOURCES => c"launch out of resources".as_ptr(),
CUresult::ERROR_LAUNCH_TIMEOUT => c"launch timeout".as_ptr(),
CUresult::ERROR_LAUNCH_INCOMPATIBLE_TEXTURING => c"launch incompatible texturing".as_ptr(),
CUresult::ERROR_PEER_ACCESS_ALREADY_ENABLED => c"peer access already enabled".as_ptr(),
CUresult::ERROR_PEER_ACCESS_NOT_ENABLED => c"peer access not enabled".as_ptr(),
CUresult::ERROR_PRIMARY_CONTEXT_ACTIVE => c"primary context active".as_ptr(),
CUresult::ERROR_CONTEXT_IS_DESTROYED => c"context is destroyed".as_ptr(),
CUresult::ERROR_ASSERT => c"device-side assert triggered".as_ptr(),
CUresult::ERROR_TOO_MANY_PEERS => c"too many peers".as_ptr(),
CUresult::ERROR_HOST_MEMORY_ALREADY_REGISTERED => {
c"host memory already registered".as_ptr()
}
CUresult::ERROR_HOST_MEMORY_NOT_REGISTERED => c"host memory not registered".as_ptr(),
CUresult::ERROR_UNKNOWN => c"unknown error".as_ptr(),
_ => c"error".as_ptr(),
};
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::i32;
use crate::r#impl::driver::AllocationInfo; use crate::r#impl::driver::AllocationInfo;
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
@ -571,4 +622,34 @@ mod tests {
} }
assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None); assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None);
} }
// Checks that right after `cuInit` the primary context of device 0 is
// reported with zero flags and as not active.
#[test_cuda]
fn primary_context_is_inactive_on_init(api: impl CudaApi) {
api.cuInit(0);
// Start from sentinel values so the assertions below can only pass if the
// call actually wrote both outputs.
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
}
// Checks that obtaining a context through `cudart_interface_fn2` of the
// `CudartInterface` export table leaves the primary context of device 0
// reported as inactive, and that the handle it returns is the same one
// `cuDevicePrimaryCtxRetain` hands out afterwards.
#[test_cuda]
unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) {
api.cuInit(0);
// Look up the export table by its GUID and wrap the returned pointer.
let mut table_ptr = std::ptr::null();
api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID);
let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr);
let mut dark_ctx = std::mem::zeroed();
cuda_rt_iface
.cudart_interface_fn2(&mut dark_ctx, 0)
.unwrap();
// Sentinel values: the assertions only pass if the state query wrote both outputs.
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
// Retaining the primary context must yield the same handle as the export-table call.
let mut primary_ctx = std::mem::zeroed();
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
assert_eq!(dark_ctx.0, primary_ctx.0);
}
} }

View file

@ -55,12 +55,11 @@ pub(crate) unsafe fn load_data(
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>, _jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
_num_jit_options: ::core::ffi::c_uint, _num_jit_options: ::core::ffi::c_uint,
library_options: Option<&mut CUlibraryOption>, library_options: Option<&mut CUlibraryOption>,
library_option_values: Option<&mut *mut ::core::ffi::c_void>, _library_option_values: Option<&mut *mut ::core::ffi::c_void>,
num_library_options: ::core::ffi::c_uint, num_library_options: ::core::ffi::c_uint,
) -> CUresult { ) -> CUresult {
let global_state = driver::global_state()?; let global_state = driver::global_state()?;
let options = let options = LibraryOptions::load(library_options, num_library_options)?;
LibraryOptions::load(library_options, library_option_values, num_library_options)?;
let library = Library { let library = Library {
data: LibraryData::new(code as *mut c_void, options.preserve_binary)?, data: LibraryData::new(code as *mut c_void, options.preserve_binary)?,
modules: vec![OnceLock::new(); global_state.devices.len()], modules: vec![OnceLock::new(); global_state.devices.len()],
@ -76,7 +75,6 @@ struct LibraryOptions {
impl LibraryOptions { impl LibraryOptions {
unsafe fn load( unsafe fn load(
library_options: Option<&mut CUlibraryOption>, library_options: Option<&mut CUlibraryOption>,
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
num_library_options: ::core::ffi::c_uint, num_library_options: ::core::ffi::c_uint,
) -> Result<Self, CUerror> { ) -> Result<Self, CUerror> {
if num_library_options == 0 { if num_library_options == 0 {
@ -84,28 +82,17 @@ impl LibraryOptions {
preserve_binary: false, preserve_binary: false,
}); });
} }
let (library_options, library_option_values) = let library_options = match library_options {
match (library_options, library_option_values) { Some(library_options) => {
(Some(library_options), Some(library_option_values)) => { std::slice::from_raw_parts(library_options, num_library_options as usize)
let library_options =
std::slice::from_raw_parts(library_options, num_library_options as usize);
let library_option_values = std::slice::from_raw_parts(
library_option_values,
num_library_options as usize,
);
(library_options, library_option_values)
} }
_ => return Err(CUerror::INVALID_VALUE), _ => return Err(CUerror::INVALID_VALUE),
}; };
let mut preserve_binary = false; let mut preserve_binary = false;
for (option, value) in library_options for option in library_options.iter().copied() {
.iter()
.copied()
.zip(library_option_values.iter())
{
match option { match option {
CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => { CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => {
preserve_binary = *(value.cast::<bool>()); preserve_binary = true;
} }
_ => return Err(CUerror::NOT_SUPPORTED), _ => return Err(CUerror::NOT_SUPPORTED),
} }
@ -156,10 +143,7 @@ mod tests {
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts}; use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts};
use std::{ use std::{ffi::CStr, mem, ptr};
ffi::{c_void, CStr},
mem, ptr,
};
#[test_cuda] #[test_cuda]
unsafe fn library_loads_without_context(api: impl CudaApi) { unsafe fn library_loads_without_context(api: impl CudaApi) {
@ -191,7 +175,7 @@ mod tests {
ptr::null_mut(), ptr::null_mut(),
0, 0,
[CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(), [CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(),
[(&true as *const bool) as *mut c_void].as_mut_ptr(), ptr::null_mut(),
1, 1,
); );
assert_eq!( assert_eq!(

View file

@ -153,3 +153,11 @@ pub(crate) unsafe fn set_d8_async(
) -> hipError_t { ) -> hipError_t {
hipMemsetD8Async(dst_device, uc, n, stream) hipMemsetD8Async(dst_device, uc, n, stream)
} }
/// Stub for querying memory allocation granularity.
///
/// All three arguments are ignored: `_granularity` is never written and the
/// function unconditionally returns `CUresult::ERROR_NOT_SUPPORTED`.
pub(crate) fn get_allocation_granularity(
_granularity: &mut usize,
_property: &cuda_types::cuda::CUmemAllocationProp,
_option: cuda_types::cuda::CUmemAllocationGranularity_flags,
) -> CUresult {
CUresult::ERROR_NOT_SUPPORTED
}

View file

@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache(
} else { } else {
ptx_parser::parse_module_unchecked(text) ptx_parser::parse_module_unchecked(text)
}; };
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
comgr, comgr,
gcn_arch, gcn_arch,

View file

@ -93,6 +93,7 @@ cuda_macros::cuda_function_declarations!(
cuDeviceGetUuid_v2, cuDeviceGetUuid_v2,
cuDevicePrimaryCtxGetState, cuDevicePrimaryCtxGetState,
cuDevicePrimaryCtxRelease, cuDevicePrimaryCtxRelease,
cuDevicePrimaryCtxRelease_v2,
cuDevicePrimaryCtxReset, cuDevicePrimaryCtxReset,
cuDevicePrimaryCtxRetain, cuDevicePrimaryCtxRetain,
cuDeviceTotalMem_v2, cuDeviceTotalMem_v2,
@ -104,6 +105,7 @@ cuda_macros::cuda_function_declarations!(
cuEventSynchronize, cuEventSynchronize,
cuFuncGetAttribute, cuFuncGetAttribute,
cuFuncSetAttribute, cuFuncSetAttribute,
cuGetErrorString,
cuGetExportTable, cuGetExportTable,
cuGetProcAddress, cuGetProcAddress,
cuGetProcAddress_v2, cuGetProcAddress_v2,
@ -126,6 +128,7 @@ cuda_macros::cuda_function_declarations!(
cuMemFreeHost, cuMemFreeHost,
cuMemFree_v2, cuMemFree_v2,
cuMemGetAddressRange_v2, cuMemGetAddressRange_v2,
cuMemGetAllocationGranularity,
cuMemGetInfo_v2, cuMemGetInfo_v2,
cuMemHostAlloc, cuMemHostAlloc,
cuMemRetainAllocationHandle, cuMemRetainAllocationHandle,

View file

@ -165,7 +165,10 @@ from_cuda_nop!(
nvmlDevice_t, nvmlDevice_t,
nvmlFieldValue_t, nvmlFieldValue_t,
nvmlGpuFabricInfo_t, nvmlGpuFabricInfo_t,
cublasLtHandle_t cublasLtHandle_t,
CUmemAllocationGranularity_flags,
CUmemAllocationProp,
CUresult
); );
from_cuda_transmute!( from_cuda_transmute!(
CUuuid => hipUUID, CUuuid => hipUUID,