Wire up AMD compilation

This commit is contained in:
Andrzej Janik 2021-08-06 13:19:55 +02:00
parent 5bfc2a56b9
commit 479014a783
2 changed files with 84 additions and 49 deletions

View file

@ -27,6 +27,7 @@ pub struct Device {
pub primary_context: context::Context,
pub allocations: HashSet<*mut c_void>,
pub is_amd: bool,
pub name: String,
}
unsafe impl Send for Device {}
@ -44,6 +45,12 @@ impl Device {
let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?;
let primary_context =
context::Context::new(context::ContextData::new(0, true, ptr::null_mut())?);
let props = ocl_core::get_device_info(ocl_dev, ocl_core::DeviceInfo::Name)?;
let name = if let ocl_core::DeviceInfoResult::Name(name) = props {
Ok(name)
} else {
Err(CUresult::CUDA_ERROR_UNKNOWN)
}?;
Ok(Self {
index: Index(idx as c_int),
ocl_base: ocl_dev,
@ -52,6 +59,7 @@ impl Device {
primary_context,
allocations: HashSet::new(),
is_amd,
name,
})
}
@ -83,14 +91,7 @@ pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUres
if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let name_string = GlobalState::lock_device(dev_idx, |dev| {
let props = ocl_core::get_device_info(dev.ocl_base, ocl_core::DeviceInfo::Name)?;
if let ocl_core::DeviceInfoResult::Name(name) = props {
Ok(name)
} else {
Err(CUresult::CUDA_ERROR_UNKNOWN)
}
})??;
let name_string = GlobalState::lock_device(dev_idx, |dev| dev.name.clone())?;
let mut dst_null_pos = cmp::min((len - 1) as usize, name_string.len());
unsafe { std::ptr::copy_nonoverlapping(name_string.as_ptr() as *const _, name, dst_null_pos) };
if name_string.len() + PROJECT_URL_SUFFIX_LONG.len() < (len as usize) {
@ -179,7 +180,7 @@ pub fn get_attribute(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR => {
GlobalState::lock_device(dev_idx, |dev| {
if !dev.is_amd {
8i32 * 7 // correct for GEN9
7 // correct for GEN9
} else {
4i32 * 32 // probably correct for RDNA
}

View file

@ -1,10 +1,11 @@
use std::{
borrow::Cow,
collections::hash_map,
collections::HashMap,
ffi::c_void,
ffi::CStr,
ffi::CString,
io::{self, Write},
io::{self, Read, Write},
mem,
os::raw::{c_char, c_int, c_uint},
path::PathBuf,
@ -106,9 +107,8 @@ impl SpirvModule {
"oclc_wavefrontsize64_off.bc",
];
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
const AMDGPU_DEVICE: &'static str = "gfx1010";
fn get_bitcode_paths() -> impl Iterator<Item = PathBuf> {
fn get_bitcode_paths(device_name: &str) -> impl Iterator<Item = PathBuf> {
let generic_paths = Self::AMDGPU_BITCODE.iter().map(|x| {
let mut path = PathBuf::from(Self::AMDGPU);
path.push("amdgcn");
@ -122,19 +122,27 @@ impl SpirvModule {
additional_path.push(format!(
"{}{}{}",
Self::AMDGPU_BITCODE_DEVICE_PREFIX,
&Self::AMDGPU_DEVICE[3..],
&device_name[3..],
".bc"
));
generic_paths.chain(std::iter::once(additional_path))
}
#[cfg(not(target_os = "linux"))]
fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> {
Ok(())
fn compile_amd(
device_name: &str,
spirv_il: &[u8],
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> {
unimplemented!()
}
#[cfg(target_os = "linux")]
fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> {
fn compile_amd(
device_name: &str,
spirv_il: &[u8],
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> {
let dir = tempfile::tempdir()?;
let mut spirv = NamedTempFile::new_in(&dir)?;
let llvm = NamedTempFile::new_in(&dir)?;
@ -156,20 +164,20 @@ impl SpirvModule {
.arg("-o")
.arg(linked_binary.path())
.arg(llvm.path())
.args(Self::get_bitcode_paths());
.args(Self::get_bitcode_paths(device_name));
if cfg!(debug_assertions) {
linker_cmd.arg("-v");
}
let status = linker_cmd.status()?;
assert!(status.success());
let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?;
let compiled_binary = NamedTempFile::new_in(&dir)?;
let mut compiled_binary = NamedTempFile::new_in(&dir)?;
let mut cland_exe = PathBuf::from(Self::AMDGPU);
cland_exe.push("bin");
cland_exe.push("clang");
let mut compiler_cmd = Command::new(&cland_exe);
compiler_cmd
.arg(format!("-mcpu={}", Self::AMDGPU_DEVICE))
.arg(format!("-mcpu={}", device_name))
.arg("-O3")
.arg("-Xlinker")
.arg("--no-undefined")
@ -178,7 +186,7 @@ impl SpirvModule {
.arg("-o")
.arg(compiled_binary.path())
.arg(linked_binary.path());
if let Some(bitcode) = ptx_lib {
if let Some((_, bitcode)) = ptx_lib {
ptx_lib_bitcode.write_all(bitcode)?;
compiler_cmd.arg(ptx_lib_bitcode.path());
};
@ -187,40 +195,30 @@ impl SpirvModule {
}
let status = compiler_cmd.status()?;
assert!(status.success());
Ok(())
let mut result = Vec::new();
compiled_binary.read_to_end(&mut result)?;
Ok(result)
}
pub fn compile<'a>(
&self,
fn compile_intel<'a>(
ctx: &ocl_core::Context,
dev: &ocl_core::DeviceId,
) -> Result<ocl_core::Program, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
self.binaries.len() * mem::size_of::<u32>(),
)
};
byte_il: &'a [u8],
build_options: &CString,
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> ocl_core::Result<ocl_core::Program> {
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
let main_module = match self.should_link_ptx_impl {
Ok(match ptx_lib {
None => {
Self::compile_amd(byte_il, None).unwrap();
ocl_core::build_program(
&main_module,
Some(&[dev]),
&self.build_options,
None,
None,
)?;
ocl_core::build_program(&main_module, Some(&[dev]), build_options, None, None)?;
main_module
}
Some((ptx_impl_intel, ptx_impl_amd)) => {
Self::compile_amd(byte_il, Some(ptx_impl_amd)).unwrap();
Some((ptx_impl_intel, _)) => {
let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl_intel, None)?;
ocl_core::compile_program(
&main_module,
Some(&[dev]),
&self.build_options,
build_options,
&[],
&[],
None,
@ -230,7 +228,7 @@ impl SpirvModule {
ocl_core::compile_program(
&ptx_impl_prog,
Some(&[dev]),
&self.build_options,
build_options,
&[],
&[],
None,
@ -240,15 +238,43 @@ impl SpirvModule {
ocl_core::link_program(
ctx,
Some(&[dev]),
&self.build_options,
build_options,
&[&main_module, &ptx_impl_prog],
None,
None,
None,
)?
}
})
}
pub fn compile<'a>(
&self,
ctx: &ocl_core::Context,
dev: &ocl_core::DeviceId,
device_name: &str,
is_amd: bool,
) -> Result<ocl_core::Program, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
self.binaries.len() * mem::size_of::<u32>(),
)
};
Ok(main_module)
let ocl_program = if is_amd {
let binary_prog =
Self::compile_amd(device_name, byte_il, self.should_link_ptx_impl).unwrap();
ocl_core::create_program_with_binary(ctx, &[dev], &[&binary_prog[..]])?
} else {
Self::compile_intel(
ctx,
dev,
byte_il,
&self.build_options,
self.should_link_ptx_impl,
)?
};
Ok(ocl_program)
}
}
@ -268,9 +294,12 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
base: module
.spirv
.compile(&device.ocl_context, &device.ocl_base)?,
base: module.spirv.compile(
&device.ocl_context,
&device.ocl_base,
&device.name,
device.is_amd,
)?,
kernels: HashMap::new(),
};
entry.insert(new_module)
@ -340,7 +369,12 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?;
let l0_module = spirv_data.compile(
&device.ocl_context,
&device.ocl_base,
&device.name,
device.is_amd,
)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,