diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5432207..36e7191 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -261,6 +261,7 @@ pub enum Instruction { Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), + Fma(ArithFloat, Arg4

), Or(ScalarType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 18ec4fb..b20a30a 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -743,6 +743,7 @@ Instruction: ast::Instruction> = { InstCall, InstAbs, InstMad, + InstFma, InstOr, InstAnd, InstSub, @@ -1345,7 +1346,11 @@ InstAbs: ast::Instruction> = { InstMad: ast::Instruction> = { "mad" => ast::Instruction::Mad(d, a), "mad" ".hi" ".sat" ".s32" => todo!(), - "fma" => ast::Instruction::Mad(ast::MulDetails::Float(f), a), +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma +InstFma: ast::Instruction> = { + "fma" => ast::Instruction::Fma(f, a), }; SignedIntType: ast::ScalarType = { diff --git a/ptx/src/test/spirv_run/cos.spvtxt b/ptx/src/test/spirv_run/cos.spvtxt index 6fafcb5..8d6a0ca 100644 --- a/ptx/src/test/spirv_run/cos.spvtxt +++ b/ptx/src/test/spirv_run/cos.spvtxt @@ -37,7 +37,7 @@ %11 = OpLoad %float %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 cos %14 + %13 = OpExtInst %float %21 native_cos %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %float %6 diff --git a/ptx/src/test/spirv_run/ex2.spvtxt b/ptx/src/test/spirv_run/ex2.spvtxt index 62c44b8..3d7b58d 100644 --- a/ptx/src/test/spirv_run/ex2.spvtxt +++ b/ptx/src/test/spirv_run/ex2.spvtxt @@ -37,7 +37,7 @@ %11 = OpLoad %float %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 exp2 %14 + %13 = OpExtInst %float %21 native_exp2 %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %float %6 diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 8cc0e16..91a2159 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -59,7 +59,7 @@ %20 = OpLoad %float %6 %21 = OpLoad %float %7 %22 = OpLoad %float %8 - %19 = OpExtInst %float %35 mad %20 %21 %22 + %19 = OpExtInst %float %35 fma %20 %21 %22 OpStore %6 %19 %23 = OpLoad %ulong %5 %24 = OpLoad %float %6 diff --git a/ptx/src/test/spirv_run/lg2.spvtxt b/ptx/src/test/spirv_run/lg2.spvtxt index 3c7ca77..c30eeff 100644 --- a/ptx/src/test/spirv_run/lg2.spvtxt +++ b/ptx/src/test/spirv_run/lg2.spvtxt @@ -37,7 +37,7 @@ %11 = OpLoad %float %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 log2 %14 + %13 = OpExtInst %float %21 native_log2 %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %float %6 diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt index 2d56ee8..09fa0d9 100644 --- a/ptx/src/test/spirv_run/rcp.spvtxt +++ b/ptx/src/test/spirv_run/rcp.spvtxt @@ -10,7 +10,7 @@ %21 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "rcp" - OpDecorate %13 FPFastMathMode AllowRecip + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %24 = OpTypeFunction %void %ulong %ulong @@ -18,7 +18,6 @@ %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float - %float_1 = OpConstant %float 1 %1 = OpFunction %void None %24 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong @@ -39,11 +38,11 @@ %11 = OpLoad %float %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %float %6 - %13 = OpFDiv %float %float_1 %14 + %13 = OpExtInst %float %21 native_recip %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %float %6 %18 = OpConvertUToPtr %_ptr_Generic_float %15 OpStore %18 %16 Aligned 4 OpReturn - OpFunctionEnd + OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/sin.spvtxt b/ptx/src/test/spirv_run/sin.spvtxt index 618d5f2..02eba40 100644 --- a/ptx/src/test/spirv_run/sin.spvtxt +++ b/ptx/src/test/spirv_run/sin.spvtxt @@ -37,7 +37,7 @@ %11 = OpLoad %float %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 sin %14 + %13 = OpExtInst %float %21 native_sin %14 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %float %6 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c236438..91e4237 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -559,25 +559,29 @@ fn emit_directives<'input>( &directives, kernel_info, )?; - for t in f.tuning.iter() { - match *t { - ast::TuningDirective::MaxNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, - [nx, ny, nz], - ); + if func_decl.name.is_kernel() { + // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) + builder.execution_mode(fn_id, spirv_headers::ExecutionMode::ContractionOff, []); + for t in f.tuning.iter() { + match *t { + ast::TuningDirective::MaxNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, + [nx, ny, nz], + ); + } + ast::TuningDirective::ReqNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::LocalSize, + [nx, ny, nz], + ); + } + // Too architecture specific + ast::TuningDirective::MaxNReg(..) + | ast::TuningDirective::MinNCtaPerSm(..) => {} } - ast::TuningDirective::ReqNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::LocalSize, - [nx, ny, nz], - ); - } - // Too architecture specific - ast::TuningDirective::MaxNReg(..) - | ast::TuningDirective::MinNCtaPerSm(..) => {} } } emit_function_body_ops(builder, map, opencl_id, &f_body)?; @@ -2772,6 +2776,7 @@ fn emit_function_body_ops( emit_mad_float(builder, map, opencl, desc, arg)? } }, + ast::Instruction::Fma(fma, arg) => emit_fma_float(builder, map, opencl, fma, arg)?, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); if *t == ast::ScalarType::Pred { @@ -2798,7 +2803,7 @@ fn emit_function_body_ops( emit_max(builder, map, opencl, d, a)?; } ast::Instruction::Rcp(d, a) => { - emit_rcp(builder, map, d, a)?; + emit_rcp(builder, map, opencl, d, a)?; } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); @@ -2901,7 +2906,7 @@ fn emit_function_body_ops( result_type, Some(arg.dst), opencl, - spirv::CLOp::sin as u32, + spirv::CLOp::native_sin as u32, [dr::Operand::IdRef(arg.src)].iter().cloned(), )?; } @@ -2911,7 +2916,7 @@ fn emit_function_body_ops( result_type, Some(arg.dst), opencl, - spirv::CLOp::cos as u32, + spirv::CLOp::native_cos as u32, [dr::Operand::IdRef(arg.src)].iter().cloned(), )?; } @@ -2921,7 +2926,7 @@ fn emit_function_body_ops( result_type, Some(arg.dst), opencl, - spirv::CLOp::log2 as u32, + spirv::CLOp::native_log2 as u32, [dr::Operand::IdRef(arg.src)].iter().cloned(), )?; } @@ -2931,7 +2936,7 @@ fn emit_function_body_ops( result_type, Some(arg.dst), opencl, - spirv::CLOp::exp2 as u32, + spirv::CLOp::native_exp2 as u32, [dr::Operand::IdRef(arg.src)].iter().cloned(), )?; } @@ -3237,20 +3242,31 @@ fn emit_mul_float( fn emit_rcp( builder: &mut dr::Builder, map: &mut TypeWordMap, + opencl: spirv::Word, desc: &ast::RcpDetails, - a: &ast::Arg2, + arg: &ast::Arg2, ) -> Result<(), TranslateError> { let (instr_type, constant) = if desc.is_f64 { (ast::ScalarType::F64, vec_repr(1.0f64)) } else { (ast::ScalarType::F32, vec_repr(1.0f32)) }; - let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; let result_type = map.get_or_add_scalar(builder, instr_type); - builder.f_div(result_type, Some(a.dst), one, a.src)?; - emit_rounding_decoration(builder, a.dst, desc.rounding); + if !desc.is_f64 && desc.rounding.is_none() { + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::native_recip as u32, + [dr::Operand::IdRef(arg.src)].iter().cloned(), + )?; + return Ok(()); + } + let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; + builder.f_div(result_type, Some(arg.dst), one, arg.src)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); builder.decorate( - a.dst, + arg.dst, spirv::Decoration::FPFastMathMode, [dr::Operand::FPFastMathMode( spirv::FPFastMathMode::ALLOW_RECIP, @@ -3372,6 +3388,30 @@ fn emit_mad_sint( Ok(()) } +fn emit_fma_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::ArithFloat, + arg: &ast::Arg4, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + builder.ext_inst( + inst_type, + Some(arg.dst), + opencl, + spirv::CLOp::fma as spirv::Word, + [ + dr::Operand::IdRef(arg.src1), + dr::Operand::IdRef(arg.src2), + dr::Operand::IdRef(arg.src3), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + fn emit_mad_float( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -5713,6 +5753,10 @@ impl ast::Instruction { let is_wide = d.is_wide(); ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?) } + ast::Instruction::Fma(d, a) => { + let inst_type = ast::Type::Scalar(d.typ); + ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?) + } ast::Instruction::Or(t, a) => ast::Instruction::Or( t, a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, @@ -6106,6 +6150,7 @@ impl ast::Instruction { | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control .flush_to_zero .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), + ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())), ast::Instruction::Setp(details, _) => details .flush_to_zero .map(|ftz| (ftz, details.typ.size_of())), diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 3b43c49..e886eb9 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -494,7 +494,7 @@ pub fn get_attribute( l0::sys::ze_result_t::ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, )) */ - return Ok(()); + 0 } }; unsafe { *pi = value }; diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 05f864b..548936f 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -51,6 +51,37 @@ impl LegacyArguments { } } +unsafe fn set_arg( + kernel: &ocl_core::Kernel, + arg_index: usize, + arg_size: usize, + arg_value: *const c_void, + is_mem: bool, +) -> Result<(), CUresult> { + if is_mem { + let error = 0; + unsafe { + ocl_core::ffi::clSetKernelArgSVMPointer( + kernel.as_ptr(), + arg_index as u32, + *(arg_value as *const _), + ) + }; + if error != 0 { + panic!("clSetKernelArgSVMPointer"); + } + } else { + unsafe { + ocl_core::set_kernel_arg( + kernel, + arg_index as u32, + ocl_core::ArgVal::from_raw(arg_size, arg_value, is_mem), + )?; + }; + } + Ok(()) +} + pub fn launch_kernel( f: *mut Function, grid_dim_x: c_uint, @@ -74,27 +105,7 @@ pub fn launch_kernel( let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; if kernel_params != ptr::null_mut() { for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() { - if is_mem { - let error = 0; - unsafe { - ocl_core::ffi::clSetKernelArgSVMPointer( - func.base.as_ptr(), - i as u32, - *(*kernel_params.add(i) as *const _), - ) - }; - if error != 0 { - panic!("clSetKernelArgSVMPointer"); - } - } else { - unsafe { - ocl_core::set_kernel_arg( - &func.base, - i as u32, - ocl_core::ArgVal::from_raw(arg_size, *kernel_params.add(i), is_mem), - )?; - }; - } + unsafe { set_arg(&func.base, i, arg_size, *kernel_params.add(i), is_mem)? }; } } else { let mut offset = 0; @@ -126,15 +137,13 @@ pub fn launch_kernel( for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() { let buffer_offset = round_up_to_multiple(offset, arg_size); unsafe { - ocl_core::set_kernel_arg( + set_arg( &func.base, - i as u32, - ocl_core::ArgVal::from_raw( - arg_size, - buffer_ptr.add(buffer_offset) as *const _, - is_mem, - ), - )?; + i, + arg_size, + buffer_ptr.add(buffer_offset) as *const _, + is_mem, + )? }; offset = buffer_offset + arg_size; } @@ -144,11 +153,13 @@ pub fn launch_kernel( } if func.use_shared_mem { unsafe { - ocl_core::set_kernel_arg( + set_arg( &func.base, - func.arg_size.len() as u32, - ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false), - )?; + func.arg_size.len(), + shared_mem_bytes as usize, + ptr::null(), + false, + )? }; } let global_dims = [ @@ -192,9 +203,9 @@ pub(crate) fn get_attribute( CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { let max_threads = GlobalState::lock_function(func, |func| { if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) = - ocl_core::get_kernel_work_group_info::( + ocl_core::get_kernel_work_group_info::<()>( &func.base, - unsafe { ocl_core::DeviceId::null() }, + (), ocl_core::KernelWorkGroupInfo::WorkGroupSize, )? { diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 3e96a8c..7293ca6 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -1,16 +1,32 @@ -use super::{stream, CUresult, GlobalState}; +use super::{ + stream::{self, CU_STREAM_LEGACY}, + CUresult, GlobalState, +}; use std::{ ffi::c_void, mem::{self, size_of}, }; pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { - let ptr = GlobalState::lock_current_context(|ctx| { - let dev = unsafe { &mut *ctx.device }; - Ok::<_, CUresult>(unsafe { + let ptr = GlobalState::lock_stream(CU_STREAM_LEGACY, |stream_data| { + let dev = unsafe { &*(*stream_data.context).device }; + let queue = stream_data.cmd_list.as_ref().unwrap(); + let ptr = unsafe { dev.ocl_ext .device_mem_alloc(&dev.ocl_context, &dev.ocl_base, bytesize, 0)? - }) + }; + // CUDA does the same thing and e.g. GeekBench relies on this behavior + let event = unsafe { + dev.ocl_ext.enqueue_memfill( + queue, + ptr, + &0u8 as *const u8 as *const c_void, + 1, + bytesize, + )? + }; + ocl_core::wait_for_event(&event)?; + Ok::<_, CUresult>(ptr) })??; unsafe { *dptr = ptr }; Ok(()) diff --git a/zluda_dump/src/debug.ptx b/zluda_dump/src/debug.ptx new file mode 100644 index 0000000..29104f8 --- /dev/null +++ b/zluda_dump/src/debug.ptx @@ -0,0 +1,55 @@ +/* + This collection of functions is here to assist with debugging + You use it by manually pasting into a module.ptx that was generated by zluda_dump + and inspecting content of additional debug buffer in replay.py +*/ + +.func debug_dump_from_thread_16(.reg.b64 debug_addr, .reg.u32 global_id_0, .reg.b16 value) +{ + .reg.u32 local_id; + mov.u32 local_id, %tid.x; + .reg.u32 local_size; + mov.u32 local_size, %ntid.x; + .reg.u32 group_id; + mov.u32 group_id, %ctaid.x; + .reg.b32 global_id; + mad.lo.u32 global_id, group_id, local_size, local_id; + .reg.pred should_exit; + setp.ne.u32 should_exit, global_id, global_id_0; + @should_exit bra END; + .reg.b32 index; + ld.global.u32 index, [debug_addr]; + st.global.u32 [debug_addr], index+1; + .reg.u64 st_offset; + cvt.u64.u32 st_offset, index; + mad.lo.u64 st_offset, st_offset, 2, 4; // sizeof(b16), sizeof(32) + add.u64 debug_addr, debug_addr, st_offset; + st.global.u16 [debug_addr], value; +END: + ret; +} + +.func debug_dump_from_thread_32(.reg.b64 debug_addr, .reg.u32 global_id_0, .reg.b32 value) +{ + .reg.u32 local_id; + mov.u32 local_id, %tid.x; + .reg.u32 local_size; + mov.u32 local_size, %ntid.x; + .reg.u32 group_id; + mov.u32 group_id, %ctaid.x; + .reg.b32 global_id; + mad.lo.u32 global_id, group_id, local_size, local_id; + .reg.pred should_exit; + setp.ne.u32 should_exit, global_id, global_id_0; + @should_exit bra END; + .reg.b32 index; + ld.global.u32 index, [debug_addr]; + st.global.u32 [debug_addr], index+1; + .reg.u64 st_offset; + cvt.u64.u32 st_offset, index; + mad.lo.u64 st_offset, st_offset, 4, 4; // sizeof(b32), sizeof(32) + add.u64 debug_addr, debug_addr, st_offset; + st.global.u32 [debug_addr], value; +END: + ret; +} diff --git a/zluda_dump/src/replay.py b/zluda_dump/src/replay.py index 723d954..c331d53 100644 --- a/zluda_dump/src/replay.py +++ b/zluda_dump/src/replay.py @@ -53,7 +53,7 @@ def parse_arguments(dump_path, prefix): def append_debug_buffer(args, grid, block): args = list(args) - items = block[0] * block[1] * block[2] * block[0] * block[1] * block[2] + items = grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2] debug_buff = np.zeros(items, dtype=np.uint32) args.append((drv.InOut(debug_buff), debug_buff)) return args