diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 225fc1d..bf5de1c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1316,6 +1316,7 @@ pub enum TuningDirective { MaxNtid(u32, u32, u32), ReqNtid(u32, u32, u32), MinNCtaPerSm(u32), + Noreturn } #[repr(u8)] diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index af72f89..ddfddae 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -575,6 +575,17 @@ fn emit_tuning_single<'a>( format!("{0},{0}", size).as_bytes(), ); } + ast::TuningDirective::Noreturn => { + let noreturn = b"noreturn"; + let attr_kind = unsafe { + LLVMGetEnumAttributeKindForName(noreturn.as_ptr().cast(), noreturn.len()) + }; + if attr_kind == 0 { + panic!(); + } + let noreturn = unsafe { LLVMCreateEnumAttribute(ctx.context.get(), attr_kind, 0) }; + unsafe { LLVMAddAttributeAtIndex(llvm_method, LLVMAttributeFunctionIndex, noreturn) }; + } } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 547810f..e1aaa3b 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -116,6 +116,7 @@ match { ".ne", ".neu", ".noftz", + ".noreturn", ".num", ".or", ".param", @@ -531,6 +532,8 @@ LinkingDirective: ast::LinkingDirective = { }; TuningDirective: ast::TuningDirective = { + // not a performance tuning directive but fits here in the grammar + ".noreturn" => ast::TuningDirective::Noreturn, ".maxnreg" => ast::TuningDirective::MaxNReg(ncta), ".maxntid" => ast::TuningDirective::MaxNtid(nx, 1, 1), ".maxntid" "," => ast::TuningDirective::MaxNtid(nx, ny, 1), diff --git a/ptx/src/test/spirv_build/noreturn.ll b/ptx/src/test/spirv_build/noreturn.ll new file mode 100644 index 0000000..286b289 --- /dev/null +++ b/ptx/src/test/spirv_build/noreturn.ll @@ -0,0 +1,19 @@ +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" +target triple = "amdgcn-amd-amdhsa" + +; Function Attrs: noreturn +define private void @noreturn(i64 %"6") #0 { +"9": + %"3" = alloca i64, align 8, addrspace(5) + %"4" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"4", align 1 + %"5" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"5", align 1 + %"8" = alloca i64, align 8, addrspace(5) + store i64 %"6", ptr addrspace(5) %"3", align 8 + %"7" = load i64, ptr addrspace(5) %"3", align 8 + store i64 %"7", ptr addrspace(5) %"8", align 8 + ret void +} + +attributes #0 = { noreturn "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_build/noreturn.ptx b/ptx/src/test/spirv_build/noreturn.ptx new file mode 100644 index 0000000..fd34bc6 --- /dev/null +++ b/ptx/src/test/spirv_build/noreturn.ptx @@ -0,0 +1,8 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.weak .func noreturn(.param .b64 noreturn_0) +.noreturn +{ +} \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index f7fd281..c63a258 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3345,6 +3345,7 @@ fn to_llvm_module_impl2<'a, 'input>( if let Some(ref mut raytracing_state) = raytracing { translation_module = raytracing::run_on_normalized(translation_module, raytracing_state)?; } + let translation_module = return_from_noreturn(translation_module); let translation_module = extract_builtin_functions(translation_module); let translation_module = resolve_instruction_types(translation_module, functions)?; let mut translation_module = restructure_function_return_types(translation_module)?; @@ -3392,6 +3393,32 @@ fn to_llvm_module_impl2<'a, 'input>( }) } +// In PTX it's legal to have a function like this: +// .func noreturn(.param .b64 noreturn_0) +// .noreturn +// { +// } +// Which trips up LLVM. We normalize this by inserting `ret;` +fn return_from_noreturn( + mut translation_module: TranslationModule, +) -> TranslationModule { + for directive in translation_module.directives.iter_mut() { + match directive { + TranslationDirective::Method(method) => { + if let Some(ref mut body) = method.body { + if body.is_empty() && method.tuning.contains(&ast::TuningDirective::Noreturn) { + body.push(Statement::Instruction(ast::Instruction::Ret( + ast::RetData { uniform: false }, + ))); + } + } + } + TranslationDirective::Variable(..) => {} + } + } + translation_module +} + // From "Performance Tips for Frontend Authors" (https://llvm.org/docs/Frontend/PerformanceTips.html): // "The SROA (Scalar Replacement Of Aggregates) and Mem2Reg passes only attempt to eliminate alloca // instructions that are in the entry basic block. Given SSA is the canonical form expected by much @@ -3586,7 +3613,8 @@ fn create_metadata<'input>( match tuning { // TODO: measure ast::TuningDirective::MaxNReg(_) - | ast::TuningDirective::MinNCtaPerSm(_) => {} + | ast::TuningDirective::MinNCtaPerSm(_) + | ast::TuningDirective::Noreturn => {} ast::TuningDirective::MaxNtid(x, y, z) => { let size = x as u64 * y as u64 * z as u64; kernel_metadata.push(( @@ -3632,7 +3660,8 @@ fn insert_compilation_mode_prologue<'input>( for t in tuning.iter_mut() { match t { ast::TuningDirective::MaxNReg(_) - | ast::TuningDirective::MinNCtaPerSm(_) => {} + | ast::TuningDirective::MinNCtaPerSm(_) + | ast::TuningDirective::Noreturn => {} ast::TuningDirective::MaxNtid(_, _, z) => { *z *= 2; }