mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-27 11:49:04 +00:00
Merge branch 'master' into nan
This commit is contained in:
commit
25bd3861da
22 changed files with 282 additions and 64 deletions
|
@ -6,6 +6,7 @@ use std::io::{self, Write};
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::process::ExitCode;
|
||||
use std::str;
|
||||
use std::time::Instant;
|
||||
use std::{env, mem};
|
||||
|
||||
mod error;
|
||||
|
@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
|
||||
let arch: String = match opts.arch {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
(|| {
|
||||
let runtime = hip::Runtime::load()?;
|
||||
runtime.init()?;
|
||||
get_gpu_arch(&runtime)
|
||||
})()
|
||||
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned())
|
||||
/*
|
||||
get_gpu_arch(&mut dev_props)
|
||||
.map(String::from)
|
||||
.unwrap_or(DEFAULT_ARCH.to_owned())
|
||||
*/
|
||||
}
|
||||
None => (|| {
|
||||
let runtime = hip::Runtime::load()?;
|
||||
runtime.init()?;
|
||||
get_gpu_arch(&runtime)
|
||||
})()
|
||||
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned()),
|
||||
};
|
||||
|
||||
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
|
||||
|
@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
write_to_file(bytes, &output_path).unwrap();
|
||||
};
|
||||
|
||||
let mut start = Instant::now();
|
||||
comgr::compile_bitcode(
|
||||
&comgr,
|
||||
&arch,
|
||||
|
@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> {
|
|||
Some(&comgr_hook),
|
||||
)
|
||||
.map_err(CompilerError::from)?;
|
||||
report_pass_time("compile_bitcode", &mut start);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
|
||||
let mut start = Instant::now();
|
||||
let module = ptx::to_llvm_module(
|
||||
ast,
|
||||
ptx::Attributes {
|
||||
clock_rate: 2124000,
|
||||
},
|
||||
|pass| {
|
||||
report_pass_time(pass, &mut start);
|
||||
},
|
||||
)
|
||||
.map_err(CompilerError::from)?;
|
||||
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
|
||||
|
@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Prints how long the pass named `pass` took and restarts the timer.
///
/// The duration is measured from `*start` up to the moment of the call.
/// `*start` is then overwritten with a fresh `Instant`, so calling this after
/// every pass reports each pass separately.
fn report_pass_time(pass: &str, start: &mut Instant) {
    let duration = start.elapsed();
    println!("Pass {:?} took {:?}", pass, duration);
    // Reset only after printing so the time spent in `println!` is not
    // charged to the next pass.
    *start = Instant::now();
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LLVMArtifacts {
|
||||
bitcode: Vec<u8>,
|
||||
|
|
|
@ -313,7 +313,7 @@ fn join(
|
|||
pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let fn_ = parse_macro_input!(item as syn::ItemFn);
|
||||
let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia");
|
||||
let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_amdgpu");
|
||||
let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda");
|
||||
let fn_name = fn_.sig.ident.clone();
|
||||
quote! {
|
||||
#[test]
|
||||
|
|
7
ext/rocm_smi-sys/build.rs
vendored
Normal file
7
ext/rocm_smi-sys/build.rs
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
use std::env::VarError;
|
||||
|
||||
/// Build-script entry point for the `rocm_smi-sys` crate.
///
/// Emits two Cargo directives on stdout: link dynamically against
/// `rocm_smi64`, and add `/opt/rocm/lib/` to the native library search path.
/// Nothing in the body can fail, so the result is always `Ok(())`.
fn main() -> Result<(), VarError> {
    println!("cargo:rustc-link-lib=dylib=rocm_smi64");
    println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
    Ok(())
}
|
|
@ -114,6 +114,13 @@ fn run_statement<'a, 'input>(
|
|||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Mov { data, arguments }) => {
|
||||
let instruction = visitor.visit_mov(data, arguments);
|
||||
let instruction = ast::visit_map(instruction, visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::PtrAccess(ptr_access) => {
|
||||
let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
|
||||
let statement = statement.visit_map(visitor)?;
|
||||
|
@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
|||
})
|
||||
}
|
||||
|
||||
fn visit_mov(
|
||||
&mut self,
|
||||
data: ptx_parser::MovDetails,
|
||||
mut arguments: ptx_parser::MovArgs<SpirvWord>,
|
||||
) -> ast::Instruction<SpirvWord> {
|
||||
if let Some(remap) = self.variables.get(&arguments.src) {
|
||||
match remap {
|
||||
RemapAction::PreLdPostSt { .. } => {}
|
||||
RemapAction::LDStSpaceChange {
|
||||
name,
|
||||
new_space,
|
||||
old_space,
|
||||
} => {
|
||||
let generic_var = self
|
||||
.resolver
|
||||
.register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg)));
|
||||
self.pre.push(ast::Instruction::Cvta {
|
||||
data: ast::CvtaDetails {
|
||||
state_space: *new_space,
|
||||
direction: ast::CvtaDirection::ExplicitToGeneric,
|
||||
},
|
||||
arguments: ast::CvtaArgs {
|
||||
dst: generic_var,
|
||||
src: *name,
|
||||
},
|
||||
});
|
||||
arguments.src = generic_var;
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::Instruction::Mov { data, arguments }
|
||||
}
|
||||
|
||||
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
|
||||
let old_space = match var.state_space {
|
||||
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
|
||||
|
|
|
@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool {
|
|||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
| ast::StateSpace::Shared
|
||||
| ast::StateSpace::ParamEntry => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => todo!(),
|
||||
}
|
||||
}
|
||||
|
@ -180,13 +180,14 @@ fn default_implicit_conversion_space(
|
|||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
| ast::StateSpace::Shared
|
||||
| ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
|
@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
|||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Reg
|
||||
|
|
|
@ -704,7 +704,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
});
|
||||
Ok(())
|
||||
}
|
||||
_ => todo!(),
|
||||
_ => return Err(error_todo()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2455,7 +2455,7 @@ impl<'a> MethodEmitContext<'a> {
|
|||
(control >> 12) & 0b1111,
|
||||
];
|
||||
if components.iter().any(|&c| c > 7) {
|
||||
return Err(TranslateError::Todo("".to_string()));
|
||||
return Err(error_todo());
|
||||
}
|
||||
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
|
||||
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
|
||||
|
|
|
@ -169,8 +169,9 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
|
||||
ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
|
||||
ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::U16x2 | ast::ScalarType::S16x2 => unsafe {
|
||||
LLVMVectorType(LLVMInt16TypeInContext(context), 2)
|
||||
},
|
||||
ast::ScalarType::F16x2 => unsafe { LLVMVectorType(LLVMHalfTypeInContext(context), 2) },
|
||||
ast::ScalarType::BF16x2 => unsafe { LLVMVectorType(LLVMBFloatTypeInContext(context), 2) },
|
||||
}
|
||||
|
@ -180,14 +181,17 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
|||
match space {
|
||||
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
|
||||
ast::StateSpace::Param => Err(TranslateError::Todo("".to_string())),
|
||||
// This is dodgy, we try our best to convert all .param into either
|
||||
// .param::entry or .local, but we can't always succeed.
|
||||
// In those cases we convert .param into generic address space
|
||||
ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE),
|
||||
ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
|
||||
ast::StateSpace::ParamFunc => Err(TranslateError::Todo("".to_string())),
|
||||
ast::StateSpace::ParamFunc => Err(error_todo()),
|
||||
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
|
||||
ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
|
||||
ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
|
||||
ast::StateSpace::SharedCta => Err(TranslateError::Todo("".to_string())),
|
||||
ast::StateSpace::SharedCluster => Err(TranslateError::Todo("".to_string())),
|
||||
ast::StateSpace::SharedCta => Err(error_todo()),
|
||||
ast::StateSpace::SharedCluster => Err(error_todo()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,31 +60,48 @@ pub struct Attributes {
|
|||
pub fn to_llvm_module<'input>(
|
||||
ast: ast::Module<'input>,
|
||||
attributes: Attributes,
|
||||
mut on_pass_end: impl FnMut(&str),
|
||||
) -> Result<Module, TranslateError> {
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
|
||||
on_pass_end("normalize_identifiers2");
|
||||
let directives = replace_known_functions::run(&mut flat_resolver, directives);
|
||||
on_pass_end("replace_known_functions");
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("normalize_predicates2");
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
on_pass_end("resolve_function_pointers");
|
||||
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
on_pass_end("fix_special_registers");
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("expand_operands");
|
||||
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("insert_post_saturation");
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("deparamize_functions");
|
||||
let directives =
|
||||
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("replace_instructions_with_functions_fp_required");
|
||||
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("normalize_basic_blocks");
|
||||
let directives = remove_unreachable_basic_blocks::run(directives)?;
|
||||
on_pass_end("remove_unreachable_basic_blocks");
|
||||
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("instruction_mode_to_global_mode");
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("insert_explicit_load_store");
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("insert_implicit_conversions2");
|
||||
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
|
||||
on_pass_end("replace_instructions_with_functions");
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
|
||||
on_pass_end("hoist_globals");
|
||||
let context = llvm::Context::new();
|
||||
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
|
||||
let attributes_ir = llvm::attributes::run(&context, attributes)?;
|
||||
on_pass_end("emit_llvm");
|
||||
Ok(Module {
|
||||
llvm_ir,
|
||||
attributes_ir,
|
||||
|
|
34
ptx/src/test/ll/param_is_addressable.ll
Normal file
34
ptx/src/test/ll/param_is_addressable.ll
Normal file
|
@ -0,0 +1,34 @@
|
|||
define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 {
|
||||
%"35" = alloca i64, align 8, addrspace(5)
|
||||
%"36" = alloca i64, align 8, addrspace(5)
|
||||
%"37" = alloca i64, align 8, addrspace(5)
|
||||
br label %1
|
||||
|
||||
1: ; preds = %0
|
||||
br label %"32"
|
||||
|
||||
"32": ; preds = %1
|
||||
%"38" = load i64, ptr addrspace(4) %"33", align 8
|
||||
store i64 %"38", ptr addrspace(5) %"35", align 8
|
||||
%"39" = load i64, ptr addrspace(4) %"34", align 8
|
||||
store i64 %"39", ptr addrspace(5) %"36", align 8
|
||||
%"49" = ptrtoint ptr addrspace(4) %"33" to i64
|
||||
%2 = inttoptr i64 %"49" to ptr addrspace(4)
|
||||
%"40" = addrspacecast ptr addrspace(4) %2 to ptr
|
||||
store ptr %"40", ptr addrspace(5) %"37", align 8
|
||||
%"43" = load i64, ptr addrspace(5) %"37", align 8
|
||||
%"50" = inttoptr i64 %"43" to ptr
|
||||
%"42" = load i64, ptr %"50", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"37", align 8
|
||||
%"45" = load i64, ptr addrspace(5) %"37", align 8
|
||||
%"46" = load i64, ptr addrspace(5) %"35", align 8
|
||||
%"51" = sub i64 %"45", %"46"
|
||||
store i64 %"51", ptr addrspace(5) %"37", align 8
|
||||
%"47" = load i64, ptr addrspace(5) %"36", align 8
|
||||
%"48" = load i64, ptr addrspace(5) %"37", align 8
|
||||
%"53" = inttoptr i64 %"47" to ptr
|
||||
store i64 %"48", ptr %"53", align 8
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
|
@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
|
|||
let attributes = pass::Attributes {
|
||||
clock_rate: 2124000,
|
||||
};
|
||||
crate::to_llvm_module(ast, attributes)?;
|
||||
crate::to_llvm_module(ast, attributes, |_| {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -345,6 +345,7 @@ test_ptx!(
|
|||
[0x8e2da590u32, 0xedeaee14, 0x248a9f70],
|
||||
[613065134u32]
|
||||
);
|
||||
test_ptx!(param_is_addressable, [0xDEAD], [0u64]);
|
||||
|
||||
test_ptx!(assertfail);
|
||||
// TODO: not yet supported
|
||||
|
@ -562,6 +563,7 @@ fn test_hip_assert<
|
|||
pass::Attributes {
|
||||
clock_rate: 2124000,
|
||||
},
|
||||
|_| {},
|
||||
)
|
||||
.unwrap();
|
||||
let name = CString::new(name)?;
|
||||
|
@ -582,6 +584,7 @@ fn test_llvm_assert(
|
|||
pass::Attributes {
|
||||
clock_rate: 2124000,
|
||||
},
|
||||
|_| {},
|
||||
)
|
||||
.unwrap();
|
||||
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
|
||||
|
|
22
ptx/src/test/spirv_run/param_is_addressable.ptx
Normal file
22
ptx/src/test/spirv_run/param_is_addressable.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
// Checks that the address of a kernel `.param` can be taken with `mov` and
// then loaded through: the value read via that address must equal the value
// read directly with `ld.param`, so the kernel stores their difference (0).
.version 6.5
.target sm_30
.address_size 64

.visible .entry param_is_addressable(
	.param .u64 input,
	.param .u64 output
)
{
	.reg .u64 	in_addr;
	.reg .u64 	out_addr;
	.reg .b64 	temp;

	// Direct reads of both parameters.
	ld.param.u64 	in_addr, [input];
	ld.param.u64 	out_addr, [output];

	// Take the address of `input` itself, then read the parameter through it.
	mov.b64 		temp, input;
	ld.param.b64 	temp, [temp];
	// Both reads should agree, leaving zero in `temp`.
	sub.u64 		temp, temp, in_addr;
	st.u64 			[out_addr], temp;
	ret;
}
|
|
@ -892,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>(
|
|||
(
|
||||
vector_prefix,
|
||||
scalar_type,
|
||||
opt((Token::DotPtr, Token::DotGlobal)),
|
||||
opt((Token::DotPtr, opt(Token::DotGlobal))),
|
||||
opt(align.verify(|x| x.count_ones() == 1)),
|
||||
ident,
|
||||
),
|
||||
|
|
|
@ -42,6 +42,7 @@ fn main() {
|
|||
ptx::Attributes {
|
||||
clock_rate: clock_rate as u32,
|
||||
},
|
||||
|_| {},
|
||||
)
|
||||
.unwrap();
|
||||
let elf_binary = comgr::compile_bitcode(
|
||||
|
|
Binary file not shown.
|
@ -494,6 +494,10 @@ pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> CUresult {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn primary_context_release_v2(hip_dev: hipDevice_t) -> CUresult {
|
||||
primary_context_release(hip_dev)
|
||||
}
|
||||
|
||||
pub(crate) fn primary_context_reset(hip_dev: hipDevice_t) -> CUresult {
|
||||
let (ctx, _) = get_primary_context(hip_dev)?;
|
||||
ctx.with_state_mut(|state| state.reset())?;
|
||||
|
|
|
@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
|||
None => return CUresult::ERROR_INVALID_VALUE,
|
||||
};
|
||||
|
||||
device::primary_context_retain(pctx, hip_dev)
|
||||
let (_, cu_ctx) = device::get_primary_context(hip_dev)?;
|
||||
*pctx = cu_ctx;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe extern "system" fn get_module_from_cubin_ext1(
|
||||
|
@ -362,8 +364,8 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
|||
|
||||
(0..device_count)
|
||||
.map(|dev| {
|
||||
let mut guid = CUuuid_st { bytes: [0; 16] };
|
||||
unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? };
|
||||
let mut guid = unsafe { mem::zeroed() };
|
||||
device::get_uuid_v2(&mut guid, dev)?;
|
||||
|
||||
let mut pci_domain = 0;
|
||||
device::get_attribute(
|
||||
|
@ -387,7 +389,7 @@ fn get_device_hash_info() -> Result<Vec<::dark_api::DeviceHashinfo>, CUerror> {
|
|||
)?;
|
||||
|
||||
Ok(::dark_api::DeviceHashinfo {
|
||||
guid,
|
||||
guid: unsafe { mem::transmute(guid) },
|
||||
pci_domain,
|
||||
pci_bus,
|
||||
pci_device,
|
||||
|
@ -525,8 +527,57 @@ pub(crate) unsafe fn launch_kernel_ex(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_error_string(
|
||||
error: cuda_types::cuda::CUresult,
|
||||
error_string: &mut *const ::core::ffi::c_char,
|
||||
) -> CUresult {
|
||||
*error_string = match error {
|
||||
CUresult::SUCCESS => c"no error".as_ptr(),
|
||||
CUresult::ERROR_INVALID_VALUE => c"invalid value".as_ptr(),
|
||||
CUresult::ERROR_OUT_OF_MEMORY => c"out of memory".as_ptr(),
|
||||
CUresult::ERROR_NOT_INITIALIZED => c"driver not initialized".as_ptr(),
|
||||
CUresult::ERROR_DEINITIALIZED => c"driver deinitialized".as_ptr(),
|
||||
CUresult::ERROR_NO_DEVICE => c"no CUDA-capable device is detected".as_ptr(),
|
||||
CUresult::ERROR_INVALID_DEVICE => c"invalid device".as_ptr(),
|
||||
CUresult::ERROR_INVALID_IMAGE => c"invalid kernel image".as_ptr(),
|
||||
CUresult::ERROR_INVALID_CONTEXT => c"invalid context".as_ptr(),
|
||||
CUresult::ERROR_CONTEXT_ALREADY_CURRENT => c"context already current".as_ptr(),
|
||||
CUresult::ERROR_MAP_FAILED => c"map failed".as_ptr(),
|
||||
CUresult::ERROR_UNMAP_FAILED => c"unmap failed".as_ptr(),
|
||||
CUresult::ERROR_ARRAY_IS_MAPPED => c"array is mapped".as_ptr(),
|
||||
CUresult::ERROR_ALREADY_MAPPED => c"already mapped".as_ptr(),
|
||||
CUresult::ERROR_NO_BINARY_FOR_GPU => c"no binary for GPU".as_ptr(),
|
||||
CUresult::ERROR_ALREADY_ACQUIRED => c"already acquired".as_ptr(),
|
||||
CUresult::ERROR_NOT_MAPPED => c"not mapped".as_ptr(),
|
||||
CUresult::ERROR_NOT_SUPPORTED => c"operation not supported".as_ptr(),
|
||||
CUresult::ERROR_INVALID_SOURCE => c"invalid source".as_ptr(),
|
||||
CUresult::ERROR_FILE_NOT_FOUND => c"file not found".as_ptr(),
|
||||
CUresult::ERROR_INVALID_HANDLE => c"invalid handle".as_ptr(),
|
||||
CUresult::ERROR_NOT_READY => c"not ready".as_ptr(),
|
||||
CUresult::ERROR_ILLEGAL_ADDRESS => c"illegal address".as_ptr(),
|
||||
CUresult::ERROR_LAUNCH_OUT_OF_RESOURCES => c"launch out of resources".as_ptr(),
|
||||
CUresult::ERROR_LAUNCH_TIMEOUT => c"launch timeout".as_ptr(),
|
||||
CUresult::ERROR_LAUNCH_INCOMPATIBLE_TEXTURING => c"launch incompatible texturing".as_ptr(),
|
||||
CUresult::ERROR_PEER_ACCESS_ALREADY_ENABLED => c"peer access already enabled".as_ptr(),
|
||||
CUresult::ERROR_PEER_ACCESS_NOT_ENABLED => c"peer access not enabled".as_ptr(),
|
||||
CUresult::ERROR_PRIMARY_CONTEXT_ACTIVE => c"primary context active".as_ptr(),
|
||||
CUresult::ERROR_CONTEXT_IS_DESTROYED => c"context is destroyed".as_ptr(),
|
||||
CUresult::ERROR_ASSERT => c"device-side assert triggered".as_ptr(),
|
||||
CUresult::ERROR_TOO_MANY_PEERS => c"too many peers".as_ptr(),
|
||||
CUresult::ERROR_HOST_MEMORY_ALREADY_REGISTERED => {
|
||||
c"host memory already registered".as_ptr()
|
||||
}
|
||||
CUresult::ERROR_HOST_MEMORY_NOT_REGISTERED => c"host memory not registered".as_ptr(),
|
||||
CUresult::ERROR_UNKNOWN => c"unknown error".as_ptr(),
|
||||
_ => c"error".as_ptr(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::i32;
|
||||
|
||||
use crate::r#impl::driver::AllocationInfo;
|
||||
use crate::tests::CudaApi;
|
||||
use cuda_macros::test_cuda;
|
||||
|
@ -571,4 +622,34 @@ mod tests {
|
|||
}
|
||||
assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None);
|
||||
}
|
||||
|
||||
#[test_cuda]
|
||||
fn primary_context_is_inactive_on_init(api: impl CudaApi) {
|
||||
api.cuInit(0);
|
||||
let mut flags = u32::MAX;
|
||||
let mut active = i32::MAX;
|
||||
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
|
||||
assert_eq!(flags, 0);
|
||||
assert_eq!(active, 0);
|
||||
}
|
||||
|
||||
#[test_cuda]
|
||||
unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) {
|
||||
api.cuInit(0);
|
||||
let mut table_ptr = std::ptr::null();
|
||||
api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID);
|
||||
let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr);
|
||||
let mut dark_ctx = std::mem::zeroed();
|
||||
cuda_rt_iface
|
||||
.cudart_interface_fn2(&mut dark_ctx, 0)
|
||||
.unwrap();
|
||||
let mut flags = u32::MAX;
|
||||
let mut active = i32::MAX;
|
||||
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
|
||||
assert_eq!(flags, 0);
|
||||
assert_eq!(active, 0);
|
||||
let mut primary_ctx = std::mem::zeroed();
|
||||
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
|
||||
assert_eq!(dark_ctx.0, primary_ctx.0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,12 +55,11 @@ pub(crate) unsafe fn load_data(
|
|||
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||
_num_jit_options: ::core::ffi::c_uint,
|
||||
library_options: Option<&mut CUlibraryOption>,
|
||||
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||
_library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||
num_library_options: ::core::ffi::c_uint,
|
||||
) -> CUresult {
|
||||
let global_state = driver::global_state()?;
|
||||
let options =
|
||||
LibraryOptions::load(library_options, library_option_values, num_library_options)?;
|
||||
let options = LibraryOptions::load(library_options, num_library_options)?;
|
||||
let library = Library {
|
||||
data: LibraryData::new(code as *mut c_void, options.preserve_binary)?,
|
||||
modules: vec![OnceLock::new(); global_state.devices.len()],
|
||||
|
@ -76,7 +75,6 @@ struct LibraryOptions {
|
|||
impl LibraryOptions {
|
||||
unsafe fn load(
|
||||
library_options: Option<&mut CUlibraryOption>,
|
||||
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||
num_library_options: ::core::ffi::c_uint,
|
||||
) -> Result<Self, CUerror> {
|
||||
if num_library_options == 0 {
|
||||
|
@ -84,28 +82,17 @@ impl LibraryOptions {
|
|||
preserve_binary: false,
|
||||
});
|
||||
}
|
||||
let (library_options, library_option_values) =
|
||||
match (library_options, library_option_values) {
|
||||
(Some(library_options), Some(library_option_values)) => {
|
||||
let library_options =
|
||||
std::slice::from_raw_parts(library_options, num_library_options as usize);
|
||||
let library_option_values = std::slice::from_raw_parts(
|
||||
library_option_values,
|
||||
num_library_options as usize,
|
||||
);
|
||||
(library_options, library_option_values)
|
||||
}
|
||||
_ => return Err(CUerror::INVALID_VALUE),
|
||||
};
|
||||
let library_options = match library_options {
|
||||
Some(library_options) => {
|
||||
std::slice::from_raw_parts(library_options, num_library_options as usize)
|
||||
}
|
||||
_ => return Err(CUerror::INVALID_VALUE),
|
||||
};
|
||||
let mut preserve_binary = false;
|
||||
for (option, value) in library_options
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(library_option_values.iter())
|
||||
{
|
||||
for option in library_options.iter().copied() {
|
||||
match option {
|
||||
CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => {
|
||||
preserve_binary = *(value.cast::<bool>());
|
||||
preserve_binary = true;
|
||||
}
|
||||
_ => return Err(CUerror::NOT_SUPPORTED),
|
||||
}
|
||||
|
@ -156,10 +143,7 @@ mod tests {
|
|||
use crate::tests::CudaApi;
|
||||
use cuda_macros::test_cuda;
|
||||
use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts};
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
mem, ptr,
|
||||
};
|
||||
use std::{ffi::CStr, mem, ptr};
|
||||
|
||||
#[test_cuda]
|
||||
unsafe fn library_loads_without_context(api: impl CudaApi) {
|
||||
|
@ -191,7 +175,7 @@ mod tests {
|
|||
ptr::null_mut(),
|
||||
0,
|
||||
[CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(),
|
||||
[(&true as *const bool) as *mut c_void].as_mut_ptr(),
|
||||
ptr::null_mut(),
|
||||
1,
|
||||
);
|
||||
assert_eq!(
|
||||
|
|
|
@ -153,3 +153,11 @@ pub(crate) unsafe fn set_d8_async(
|
|||
) -> hipError_t {
|
||||
hipMemsetD8Async(dst_device, uc, n, stream)
|
||||
}
|
||||
|
||||
pub(crate) fn get_allocation_granularity(
|
||||
_granularity: &mut usize,
|
||||
_property: &cuda_types::cuda::CUmemAllocationProp,
|
||||
_option: cuda_types::cuda::CUmemAllocationGranularity_flags,
|
||||
) -> CUresult {
|
||||
CUresult::ERROR_NOT_SUPPORTED
|
||||
}
|
||||
|
|
|
@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache(
|
|||
} else {
|
||||
ptx_parser::parse_module_unchecked(text)
|
||||
};
|
||||
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let elf_module = comgr::compile_bitcode(
|
||||
comgr,
|
||||
gcn_arch,
|
||||
|
|
|
@ -93,6 +93,7 @@ cuda_macros::cuda_function_declarations!(
|
|||
cuDeviceGetUuid_v2,
|
||||
cuDevicePrimaryCtxGetState,
|
||||
cuDevicePrimaryCtxRelease,
|
||||
cuDevicePrimaryCtxRelease_v2,
|
||||
cuDevicePrimaryCtxReset,
|
||||
cuDevicePrimaryCtxRetain,
|
||||
cuDeviceTotalMem_v2,
|
||||
|
@ -104,6 +105,7 @@ cuda_macros::cuda_function_declarations!(
|
|||
cuEventSynchronize,
|
||||
cuFuncGetAttribute,
|
||||
cuFuncSetAttribute,
|
||||
cuGetErrorString,
|
||||
cuGetExportTable,
|
||||
cuGetProcAddress,
|
||||
cuGetProcAddress_v2,
|
||||
|
@ -126,6 +128,7 @@ cuda_macros::cuda_function_declarations!(
|
|||
cuMemFreeHost,
|
||||
cuMemFree_v2,
|
||||
cuMemGetAddressRange_v2,
|
||||
cuMemGetAllocationGranularity,
|
||||
cuMemGetInfo_v2,
|
||||
cuMemHostAlloc,
|
||||
cuMemRetainAllocationHandle,
|
||||
|
|
|
@ -165,7 +165,10 @@ from_cuda_nop!(
|
|||
nvmlDevice_t,
|
||||
nvmlFieldValue_t,
|
||||
nvmlGpuFabricInfo_t,
|
||||
cublasLtHandle_t
|
||||
cublasLtHandle_t,
|
||||
CUmemAllocationGranularity_flags,
|
||||
CUmemAllocationProp,
|
||||
CUresult
|
||||
);
|
||||
from_cuda_transmute!(
|
||||
CUuuid => hipUUID,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue