diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 8bbd1d7..22d378e 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -283,6 +283,7 @@ pub type KernelArgument = Variable; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, + pub tuning: Vec, pub body: Option>, } @@ -1369,6 +1370,14 @@ bitflags! { } } +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ce3e387..631d5ad 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -87,6 +87,9 @@ match { ".ltu", ".lu", ".max", + ".maxnreg", + ".maxntid", + ".minnctapersm", ".min", ".nan", ".NaN", @@ -100,6 +103,7 @@ match { ".reg", ".relaxed", ".release", + ".reqntid", ".rm", ".rmi", ".rn", @@ -356,15 +360,27 @@ AddressSize = { Function: ast::Function<'input, &'input str, ast::Statement>> = { LinkingDirectives + => ast::Function{<>} }; - + LinkingDirective: ast::LinkingDirective = { ".extern" => ast::LinkingDirective::EXTERN, ".visible" => ast::LinkingDirective::VISIBLE, ".weak" => ast::LinkingDirective::WEAK, }; +TuningDirective: ast::TuningDirective = { + ".maxnreg" => ast::TuningDirective::MaxNReg(ncta), + ".maxntid" => ast::TuningDirective::MaxNtid(nx, 1, 1), + ".maxntid" "," => ast::TuningDirective::MaxNtid(nx, ny, 1), + ".maxntid" "," "," => ast::TuningDirective::MaxNtid(nx, ny, nz), + ".reqntid" => ast::TuningDirective::ReqNtid(nx, 1, 1), + ".reqntid" "," => ast::TuningDirective::ReqNtid(nx, ny, 1), + ".reqntid" "," "," => ast::TuningDirective::ReqNtid(nx, ny, nz), + ".minnctapersm" => ast::TuningDirective::MinNCtaPerSm(ncta), +}; + LinkingDirectives: ast::LinkingDirective = { => { ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y) diff --git a/ptx/src/test/spirv_run/add_tuning.ptx b/ptx/src/test/spirv_run/add_tuning.ptx new file mode 100644 index 0000000..2a5dcf8 --- /dev/null +++ b/ptx/src/test/spirv_run/add_tuning.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry add_tuning( + .param .u64 input, + .param .u64 output +) +.maxntid 256, 1, 1 +.minnctapersm 4 +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/add_tuning.spvtxt b/ptx/src/test/spirv_run/add_tuning.spvtxt new file mode 100644 index 0000000..173e0d4 --- /dev/null +++ b/ptx/src/test/spirv_run/add_tuning.spvtxt @@ -0,0 +1,48 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %23 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add_tuning" + OpExecutionMode %1 MaxWorkgroupSizeINTEL 256 1 1 + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %26 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %26 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %21 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %10 + %11 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 + %12 = OpLoad %ulong %19 Aligned 8 + OpStore %6 %12 + %15 = OpLoad %ulong %6 + %14 = OpIAdd %ulong %15 %ulong_1 + OpStore %7 %14 + %16 = OpLoad %ulong %5 + %17 = OpLoad %ulong %7 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %20 %17 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 91e6113..4178e2f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -152,6 +152,7 @@ test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); // For now, we just make sure that it builds and links test_ptx!(assertfail, [716523871u64], [716523872u64]); test_ptx!(cvt_s64_s32, [-1i32], [-1i64]); +test_ptx!(add_tuning, [2u64], [3u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7efcaf6..da0cc07 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -589,7 +589,7 @@ fn emit_directives<'input>( for var in f.globals.iter() { emit_variable(builder, map, var)?; } - emit_function_header( + let fn_id = emit_function_header( builder, map, &id_defs, @@ -600,6 +600,27 @@ fn emit_directives<'input>( &directives, kernel_info, )?; + for t in f.tuning.iter() { + match *t { + ast::TuningDirective::MaxNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, + [nx, ny, nz], + ); + } + ast::TuningDirective::ReqNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::LocalSize, + [nx, ny, nz], + ); + } + // Too architecture specific + ast::TuningDirective::MaxNReg(..) + | ast::TuningDirective::MinNCtaPerSm(..) => {} + } + } emit_function_body_ops(builder, map, opencl_id, &f_body)?; builder.end_function()?; if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) = @@ -729,6 +750,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, spirv_decl, + tuning, }) => { let call_key = MethodName::new(&func_decl); let statements = statements @@ -752,6 +774,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, spirv_decl, + tuning, }) } directive => directive, @@ -770,6 +793,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, mut spirv_decl, + tuning, }) => { if !methods_using_extern_shared.contains(&spirv_decl.name) { return Directive::Method(Function { @@ -778,6 +802,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, spirv_decl, + tuning, }); } let shared_id_param = new_id(); @@ -827,6 +852,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(new_statements), import_as, spirv_decl, + tuning, }) } directive => directive, @@ -1044,9 +1070,7 @@ fn emit_builtins( builder.decorate( id, spirv::Decoration::BuiltIn, - [dr::Operand::BuiltIn(reg.get_builtin())] - .iter() - .cloned(), + [dr::Operand::BuiltIn(reg.get_builtin())].iter().cloned(), ); } } @@ -1061,7 +1085,7 @@ fn emit_function_header<'a>( call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, -) -> Result<(), TranslateError> { +) -> Result { if let MethodName::Kernel(name) = func_decl.name { let input_args = if !func_decl.uses_shared_mem { func_decl.input.as_slice() @@ -1143,7 +1167,7 @@ fn emit_function_header<'a>( let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone())); builder.function_parameter(Some(input.name), result_type)?; } - Ok(()) + Ok(fn_id) } fn emit_capabilities(builder: &mut dr::Builder) { @@ -1235,7 +1259,14 @@ fn translate_function<'a>( _ => None, }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; - let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?; + let mut func = to_ssa( + ptx_impl_imports, + str_resolver, + fn_resolver, + fn_decl, + f.body, + f.tuning, + )?; func.import_as = import_as; if func.import_as.is_some() { ptx_impl_imports.insert( @@ -1293,6 +1324,7 @@ fn to_ssa<'input, 'b>( fn_defs: GlobalFnDeclResolver<'input, 'b>, f_args: ast::MethodDecl<'input, spirv::Word>, f_body: Option>>>, + tuning: Vec, ) -> Result, TranslateError> { let mut spirv_decl = SpirvMethodDecl::new(&f_args); let f_body = match f_body { @@ -1304,6 +1336,7 @@ fn to_ssa<'input, 'b>( globals: Vec::new(), import_as: None, spirv_decl, + tuning, }) } }; @@ -1335,6 +1368,7 @@ fn to_ssa<'input, 'b>( body: Some(f_body), import_as: None, spirv_decl, + tuning, }) } @@ -1716,6 +1750,7 @@ fn to_ptx_impl_atomic_call( body: None, import_as: Some(entry.key().clone()), spirv_decl, + tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id @@ -1809,6 +1844,7 @@ fn to_ptx_impl_bfe_call( body: None, import_as: Some(entry.key().clone()), spirv_decl, + tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id @@ -1907,6 +1943,7 @@ fn to_ptx_impl_bfi_call( body: None, import_as: Some(entry.key().clone()), spirv_decl, + tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id @@ -4112,16 +4149,11 @@ fn struct2_bitcast_to_wide( dst_type_id: spirv::Word, src: spirv::Word, ) -> Result<(), dr::Error> { - let low_bits = - builder.composite_extract(instruction_type, None, src, [0].iter().copied())?; - let high_bits = - builder.composite_extract(instruction_type, None, src, [1].iter().copied())?; + let low_bits = builder.composite_extract(instruction_type, None, src, [0].iter().copied())?; + let high_bits = builder.composite_extract(instruction_type, None, src, [1].iter().copied())?; let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2)); - let vector = builder.composite_construct( - vector_type, - None, - [low_bits, high_bits].iter().copied(), - )?; + let vector = + builder.composite_construct(vector_type, None, [low_bits, high_bits].iter().copied())?; builder.bitcast(dst_type_id, Some(dst), vector)?; Ok(()) } @@ -5668,6 +5700,7 @@ struct Function<'input> { pub globals: Vec>, pub body: Option>, import_as: Option, + tuning: Vec, } pub trait ArgumentMapVisitor {