diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 321e492..253ba4b 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -257,9 +257,15 @@ impl Module { ctx: &mut Context, d: &Device, binaries: &[&'a [u8]], + opts: Option<&CStr>, ) -> (Result, Option) { - let ocl_program = match Self::build_link_spirv_impl(binaries) { - Err(_) => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None), + let ocl_program = match Self::build_link_spirv_impl(binaries, opts) { + Err(_) => { + return ( + Err(sys::ze_result_t::ZE_RESULT_ERROR_MODULE_LINK_FAILURE), + None, + ) + } Ok(prog) => prog, }; match ocl_core::get_program_info(&ocl_program, ocl_core::ProgramInfo::Binaries) { @@ -271,7 +277,10 @@ impl Module { } } - fn build_link_spirv_impl<'a>(binaries: &[&'a [u8]]) -> ocl_core::Result { + fn build_link_spirv_impl<'a>( + binaries: &[&'a [u8]], + opts: Option<&CStr>, + ) -> ocl_core::Result { let platforms = ocl_core::get_platform_ids()?; let (platform, device) = platforms .iter() @@ -305,7 +314,22 @@ impl Module { for binary in binaries { programs.push(ocl_core::create_program_with_il(&ocl_ctx, binary, None)?); } - let options = CString::default(); + let options = match opts { + Some(o) => o.to_owned(), + None => CString::default(), + }; + for program in programs.iter() { + ocl_core::compile_program( + program, + Some(&[device]), + &options, + &[], + &[], + None, + None, + None, + )?; + } ocl_core::link_program::( &ocl_ctx, Some(&[device]), diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index fa46bf4..cba030e 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -83,8 +83,21 @@ impl SpirvModule { self.binaries.len() * mem::size_of::(), ) }; - let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?; - Ok(l0_module) + let l0_module = match self.should_link_ptx_impl { + None => { + l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())).0 + } + Some(ptx_impl) => { + l0::Module::build_link_spirv( + ctx, + &dev, + &[ptx_impl, byte_il], + Some(self.build_options.as_c_str()), + ) + .0 + } + }; + Ok(l0_module?) } } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 027e891..c70ab5c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -52,6 +52,7 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); +test_ptx!(shl_link_hack, [11u64], [44u64]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); @@ -202,7 +203,12 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let dev = devices.drain(0..1).next().unwrap(); let queue = ze::CommandQueue::new(&mut ctx, &dev)?; let (module, maybe_log) = match module.should_link_ptx_impl { - Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]), + Some(ptx_impl) => ze::Module::build_link_spirv( + &mut ctx, + &dev, + &[ptx_impl, byte_il], + Some(module.build_options.as_c_str()), + ), None => { let (module, log) = ze::Module::build_spirv( &mut ctx, @@ -262,7 +268,6 @@ fn test_spvtxt_assert<'a>( let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; assert!(errors.len() == 0); let spirv_module = translate::to_spirv_module(ast)?; - eprintln!("{}", rspirv::binary::Disassemble::disassemble(&spirv_module.spirv)); let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut()); diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20c3edb..2b14bd7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2920,15 +2920,31 @@ fn emit_function_body_ops( }?; } ast::Instruction::Shl(t, a) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); - builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?; + let full_type = t.to_type(); + let size_of = full_type.size_of(); + let result_type = map.get_or_add(builder, SpirvType::from(full_type)); + let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; + builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } ast::Instruction::Shr(t, a) => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + let full_type = ast::ScalarType::from(*t); + let size_of = full_type.size_of(); + let result_type = map.get_or_add_scalar(builder, full_type); + let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; if t.signed() { - builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?; + builder.shift_right_arithmetic( + result_type, + Some(a.dst), + a.src1, + offset_src, + )?; } else { - builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?; + builder.shift_right_logical( + result_type, + Some(a.dst), + a.src1, + offset_src, + )?; } } ast::Instruction::Cvt(dets, arg) => { @@ -3225,6 +3241,23 @@ fn emit_function_body_ops( Ok(()) } +// HACK ALERT +// For some reason IGC fails linking if the value and shift size are of different type +fn insert_shift_hack( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + offset_var: spirv::Word, + size_of: usize, +) -> Result { + let result_type = match size_of { + 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), + 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), + 4 => return Ok(offset_var), + _ => return Err(TranslateError::Unreachable), + }; + Ok(builder.u_convert(result_type, None, offset_var)?) +} + // TODO: check what kind of assembly do we emit fn emit_logical_xor_spirv( builder: &mut dr::Builder,