Wire up AMD compilation

This commit is contained in:
Andrzej Janik 2021-08-06 13:19:55 +02:00
commit 479014a783
2 changed files with 84 additions and 49 deletions

View file

@ -27,6 +27,7 @@ pub struct Device {
pub primary_context: context::Context, pub primary_context: context::Context,
pub allocations: HashSet<*mut c_void>, pub allocations: HashSet<*mut c_void>,
pub is_amd: bool, pub is_amd: bool,
pub name: String,
} }
unsafe impl Send for Device {} unsafe impl Send for Device {}
@ -44,6 +45,12 @@ impl Device {
let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?; let queue = ocl_core::create_command_queue(&ctx, ocl_dev, None)?;
let primary_context = let primary_context =
context::Context::new(context::ContextData::new(0, true, ptr::null_mut())?); context::Context::new(context::ContextData::new(0, true, ptr::null_mut())?);
let props = ocl_core::get_device_info(ocl_dev, ocl_core::DeviceInfo::Name)?;
let name = if let ocl_core::DeviceInfoResult::Name(name) = props {
Ok(name)
} else {
Err(CUresult::CUDA_ERROR_UNKNOWN)
}?;
Ok(Self { Ok(Self {
index: Index(idx as c_int), index: Index(idx as c_int),
ocl_base: ocl_dev, ocl_base: ocl_dev,
@ -52,6 +59,7 @@ impl Device {
primary_context, primary_context,
allocations: HashSet::new(), allocations: HashSet::new(),
is_amd, is_amd,
name,
}) })
} }
@ -83,14 +91,7 @@ pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUres
if name == ptr::null_mut() || len < 0 { if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE); return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
} }
let name_string = GlobalState::lock_device(dev_idx, |dev| { let name_string = GlobalState::lock_device(dev_idx, |dev| dev.name.clone())?;
let props = ocl_core::get_device_info(dev.ocl_base, ocl_core::DeviceInfo::Name)?;
if let ocl_core::DeviceInfoResult::Name(name) = props {
Ok(name)
} else {
Err(CUresult::CUDA_ERROR_UNKNOWN)
}
})??;
let mut dst_null_pos = cmp::min((len - 1) as usize, name_string.len()); let mut dst_null_pos = cmp::min((len - 1) as usize, name_string.len());
unsafe { std::ptr::copy_nonoverlapping(name_string.as_ptr() as *const _, name, dst_null_pos) }; unsafe { std::ptr::copy_nonoverlapping(name_string.as_ptr() as *const _, name, dst_null_pos) };
if name_string.len() + PROJECT_URL_SUFFIX_LONG.len() < (len as usize) { if name_string.len() + PROJECT_URL_SUFFIX_LONG.len() < (len as usize) {
@ -179,7 +180,7 @@ pub fn get_attribute(
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR => { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR => {
GlobalState::lock_device(dev_idx, |dev| { GlobalState::lock_device(dev_idx, |dev| {
if !dev.is_amd { if !dev.is_amd {
8i32 * 7 // correct for GEN9 7 // correct for GEN9
} else { } else {
4i32 * 32 // probably correct for RDNA 4i32 * 32 // probably correct for RDNA
} }

View file

@ -1,10 +1,11 @@
use std::{ use std::{
borrow::Cow,
collections::hash_map, collections::hash_map,
collections::HashMap, collections::HashMap,
ffi::c_void, ffi::c_void,
ffi::CStr, ffi::CStr,
ffi::CString, ffi::CString,
io::{self, Write}, io::{self, Read, Write},
mem, mem,
os::raw::{c_char, c_int, c_uint}, os::raw::{c_char, c_int, c_uint},
path::PathBuf, path::PathBuf,
@ -106,9 +107,8 @@ impl SpirvModule {
"oclc_wavefrontsize64_off.bc", "oclc_wavefrontsize64_off.bc",
]; ];
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
const AMDGPU_DEVICE: &'static str = "gfx1010";
fn get_bitcode_paths() -> impl Iterator<Item = PathBuf> { fn get_bitcode_paths(device_name: &str) -> impl Iterator<Item = PathBuf> {
let generic_paths = Self::AMDGPU_BITCODE.iter().map(|x| { let generic_paths = Self::AMDGPU_BITCODE.iter().map(|x| {
let mut path = PathBuf::from(Self::AMDGPU); let mut path = PathBuf::from(Self::AMDGPU);
path.push("amdgcn"); path.push("amdgcn");
@ -122,19 +122,27 @@ impl SpirvModule {
additional_path.push(format!( additional_path.push(format!(
"{}{}{}", "{}{}{}",
Self::AMDGPU_BITCODE_DEVICE_PREFIX, Self::AMDGPU_BITCODE_DEVICE_PREFIX,
&Self::AMDGPU_DEVICE[3..], &device_name[3..],
".bc" ".bc"
)); ));
generic_paths.chain(std::iter::once(additional_path)) generic_paths.chain(std::iter::once(additional_path))
} }
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> { fn compile_amd(
Ok(()) device_name: &str,
spirv_il: &[u8],
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> {
unimplemented!()
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
fn compile_amd(spirv_il: &[u8], ptx_lib: Option<&'static [u8]>) -> io::Result<()> { fn compile_amd(
device_name: &str,
spirv_il: &[u8],
ptx_lib: Option<(&'static [u8], &'static [u8])>,
) -> io::Result<Vec<u8>> {
let dir = tempfile::tempdir()?; let dir = tempfile::tempdir()?;
let mut spirv = NamedTempFile::new_in(&dir)?; let mut spirv = NamedTempFile::new_in(&dir)?;
let llvm = NamedTempFile::new_in(&dir)?; let llvm = NamedTempFile::new_in(&dir)?;
@ -156,20 +164,20 @@ impl SpirvModule {
.arg("-o") .arg("-o")
.arg(linked_binary.path()) .arg(linked_binary.path())
.arg(llvm.path()) .arg(llvm.path())
.args(Self::get_bitcode_paths()); .args(Self::get_bitcode_paths(device_name));
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
linker_cmd.arg("-v"); linker_cmd.arg("-v");
} }
let status = linker_cmd.status()?; let status = linker_cmd.status()?;
assert!(status.success()); assert!(status.success());
let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?; let mut ptx_lib_bitcode = NamedTempFile::new_in(&dir)?;
let compiled_binary = NamedTempFile::new_in(&dir)?; let mut compiled_binary = NamedTempFile::new_in(&dir)?;
let mut cland_exe = PathBuf::from(Self::AMDGPU); let mut cland_exe = PathBuf::from(Self::AMDGPU);
cland_exe.push("bin"); cland_exe.push("bin");
cland_exe.push("clang"); cland_exe.push("clang");
let mut compiler_cmd = Command::new(&cland_exe); let mut compiler_cmd = Command::new(&cland_exe);
compiler_cmd compiler_cmd
.arg(format!("-mcpu={}", Self::AMDGPU_DEVICE)) .arg(format!("-mcpu={}", device_name))
.arg("-O3") .arg("-O3")
.arg("-Xlinker") .arg("-Xlinker")
.arg("--no-undefined") .arg("--no-undefined")
@ -178,7 +186,7 @@ impl SpirvModule {
.arg("-o") .arg("-o")
.arg(compiled_binary.path()) .arg(compiled_binary.path())
.arg(linked_binary.path()); .arg(linked_binary.path());
if let Some(bitcode) = ptx_lib { if let Some((_, bitcode)) = ptx_lib {
ptx_lib_bitcode.write_all(bitcode)?; ptx_lib_bitcode.write_all(bitcode)?;
compiler_cmd.arg(ptx_lib_bitcode.path()); compiler_cmd.arg(ptx_lib_bitcode.path());
}; };
@ -187,40 +195,30 @@ impl SpirvModule {
} }
let status = compiler_cmd.status()?; let status = compiler_cmd.status()?;
assert!(status.success()); assert!(status.success());
Ok(()) let mut result = Vec::new();
compiled_binary.read_to_end(&mut result)?;
Ok(result)
} }
pub fn compile<'a>( fn compile_intel<'a>(
&self,
ctx: &ocl_core::Context, ctx: &ocl_core::Context,
dev: &ocl_core::DeviceId, dev: &ocl_core::DeviceId,
) -> Result<ocl_core::Program, CUresult> { byte_il: &'a [u8],
let byte_il = unsafe { build_options: &CString,
slice::from_raw_parts( ptx_lib: Option<(&'static [u8], &'static [u8])>,
self.binaries.as_ptr() as *const u8, ) -> ocl_core::Result<ocl_core::Program> {
self.binaries.len() * mem::size_of::<u32>(),
)
};
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?; let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
let main_module = match self.should_link_ptx_impl { Ok(match ptx_lib {
None => { None => {
Self::compile_amd(byte_il, None).unwrap(); ocl_core::build_program(&main_module, Some(&[dev]), build_options, None, None)?;
ocl_core::build_program(
&main_module,
Some(&[dev]),
&self.build_options,
None,
None,
)?;
main_module main_module
} }
Some((ptx_impl_intel, ptx_impl_amd)) => { Some((ptx_impl_intel, _)) => {
Self::compile_amd(byte_il, Some(ptx_impl_amd)).unwrap();
let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl_intel, None)?; let ptx_impl_prog = ocl_core::create_program_with_il(ctx, ptx_impl_intel, None)?;
ocl_core::compile_program( ocl_core::compile_program(
&main_module, &main_module,
Some(&[dev]), Some(&[dev]),
&self.build_options, build_options,
&[], &[],
&[], &[],
None, None,
@ -230,7 +228,7 @@ impl SpirvModule {
ocl_core::compile_program( ocl_core::compile_program(
&ptx_impl_prog, &ptx_impl_prog,
Some(&[dev]), Some(&[dev]),
&self.build_options, build_options,
&[], &[],
&[], &[],
None, None,
@ -240,15 +238,43 @@ impl SpirvModule {
ocl_core::link_program( ocl_core::link_program(
ctx, ctx,
Some(&[dev]), Some(&[dev]),
&self.build_options, build_options,
&[&main_module, &ptx_impl_prog], &[&main_module, &ptx_impl_prog],
None, None,
None, None,
None, None,
)? )?
} }
})
}
pub fn compile<'a>(
&self,
ctx: &ocl_core::Context,
dev: &ocl_core::DeviceId,
device_name: &str,
is_amd: bool,
) -> Result<ocl_core::Program, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
self.binaries.len() * mem::size_of::<u32>(),
)
}; };
Ok(main_module) let ocl_program = if is_amd {
let binary_prog =
Self::compile_amd(device_name, byte_il, self.should_link_ptx_impl).unwrap();
ocl_core::create_program_with_binary(ctx, &[dev], &[&binary_prog[..]])?
} else {
Self::compile_intel(
ctx,
dev,
byte_il,
&self.build_options,
self.should_link_ptx_impl,
)?
};
Ok(ocl_program)
} }
} }
@ -268,9 +294,12 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => { hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule { let new_module = CompiledModule {
base: module base: module.spirv.compile(
.spirv &device.ocl_context,
.compile(&device.ocl_context, &device.ocl_base)?, &device.ocl_base,
&device.name,
device.is_amd,
)?,
kernels: HashMap::new(), kernels: HashMap::new(),
}; };
entry.insert(new_module) entry.insert(new_module)
@ -340,7 +369,12 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| { let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device }; let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&device.ocl_context, &device.ocl_base)?; let l0_module = spirv_data.compile(
&device.ocl_context,
&device.ocl_base,
&device.name,
device.is_amd,
)?;
let mut device_binaries = HashMap::new(); let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule { let compiled_module = CompiledModule {
base: l0_module, base: l0_module,