From 383dde6b355b45be454064fc2d6f7d19319436e3 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 3 Mar 2024 17:26:23 +0100 Subject: [PATCH] Simplify compilation of globals in initalizers, fix bfind.u64 --- ptx/src/ast.rs | 11 +--- ptx/src/emit.rs | 44 +++---------- ptx/src/ptx.lalrpop | 4 +- ptx/src/translate.rs | 145 +++---------------------------------------- 4 files changed, 19 insertions(+), 185 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 994cf7c..d3b9403 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1383,19 +1383,12 @@ pub enum TextureGeometry { #[derive(Clone)] pub enum Initializer { Constant(ImmediateValue), - Global(ID, InitializerType), - GenericGlobal(ID, InitializerType), + Global(ID), + GenericGlobal(ID), Add(Box<(Initializer, Initializer)>), Array(Vec>), } -#[derive(Clone)] -pub enum InitializerType { - Unknown, - Value(Type), - Function(Vec, Vec), -} - #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index dbbed6e..346cc64 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -402,27 +402,20 @@ unsafe fn get_llvm_const( let const2 = get_llvm_const(ctx, type_, Some(init2))?; LLVMConstAdd(const1, const2) } - (_, Some(ast::Initializer::Global(id, type_))) => { + (_, Some(ast::Initializer::Global(id))) => { let name = ctx.names.value(id)?; let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?; - let mut zero = LLVMConstInt(b64, 0, 0); - let src_type = get_initializer_llvm_type(ctx, type_)?; - let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1); - LLVMConstPtrToInt(global_ptr, b64) + LLVMConstPtrToInt(name, b64) } - (_, Some(ast::Initializer::GenericGlobal(id, type_))) => { + (_, Some(ast::Initializer::GenericGlobal(id))) => { let name = ctx.names.value(id)?; - let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?; - let mut zero = LLVMConstInt(b64, 0, 0); - let src_type = get_initializer_llvm_type(ctx, type_)?; - let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1); - // void pointers are illegal in LLVM IR let b8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?; let b8_generic_ptr = LLVMPointerType( b8, get_llvm_address_space(&ctx.constants, ast::StateSpace::Generic)?, ); - let generic_ptr = LLVMConstAddrSpaceCast(global_ptr, b8_generic_ptr); + let generic_ptr = LLVMConstAddrSpaceCast(name, b8_generic_ptr); + let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?; LLVMConstPtrToInt(generic_ptr, b64) } _ => return Err(TranslateError::todo()), @@ -430,28 +423,6 @@ unsafe fn get_llvm_const( Ok(const_value) } -fn get_initializer_llvm_type( - ctx: &mut EmitContext, - type_: ast::InitializerType, -) -> Result { - Ok(match type_ { - ast::InitializerType::Unknown => return Err(TranslateError::unreachable()), - ast::InitializerType::Value(type_) => get_llvm_type(ctx, &type_)?, - ast::InitializerType::Function(return_args, input_args) => { - let return_type = match &*return_args { - [] => llvm::void_type(&ctx.context), - [type_] => get_llvm_type(ctx, type_)?, - [..] => get_llvm_type_struct(ctx, return_args.into_iter().map(Cow::Owned))?, - }; - get_llvm_function_type( - ctx, - return_type, - input_args.iter().map(|type_| (type_, ast::StateSpace::Reg)), - )? - } - }) -} - unsafe fn get_llvm_const_scalar( ctx: &mut EmitContext, scalar_type: ast::ScalarType, @@ -1305,7 +1276,8 @@ fn emit_inst_bfind( let builder = ctx.builder.get(); let src = arg.src.get_llvm_value(&mut ctx.names)?; let llvm_dst_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?; - let const_0 = unsafe { LLVMConstInt(llvm_dst_type, 0, 0) }; + let llvm_src_type = get_llvm_type(ctx, &ast::Type::Scalar(details.type_))?; + let const_0 = unsafe { LLVMConstInt(llvm_src_type, 0, 0) }; let const_int_max = unsafe { LLVMConstInt(llvm_dst_type, u64::MAX, 0) }; let is_zero = unsafe { LLVMBuildICmp( @@ -1316,7 +1288,7 @@ fn emit_inst_bfind( LLVM_UNNAMED, ) }; - let mut clz_result = emit_inst_clz_impl(ctx, ast::ScalarType::U32, None, arg.src, true)?; + let mut clz_result = emit_inst_clz_impl(ctx, details.type_, None, arg.src, true)?; if !details.shift { let bits = unsafe { LLVMConstInt( diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 5345066..daad23d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -652,8 +652,8 @@ Initializer: ast::Initializer<&'input str> = { InitializerNoAdd: ast::Initializer<&'input str> = { => ast::Initializer::Constant(val), - => ast::Initializer::Global(id, ast::InitializerType::Unknown), - "generic" "(" ")" => ast::Initializer::GenericGlobal(id, ast::InitializerType::Unknown), + => ast::Initializer::Global(id), + "generic" "(" ")" => ast::Initializer::GenericGlobal(id), "{" > "}" => ast::Initializer::Array(array_init) } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a907e5e..79c070b 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1031,10 +1031,8 @@ fn normalize_method<'a, 'b, 'input>( normalize_method_params(&mut fn_scope, &*method.func_directive.return_arguments)?; let input_arguments = normalize_method_params(&mut fn_scope, &*method.func_directive.input_arguments)?; - if !is_kernel { - if let hash_map::Entry::Vacant(entry) = function_decls.entry(name) { - entry.insert((return_arguments.clone(), input_arguments.clone())); - } + if let hash_map::Entry::Vacant(entry) = function_decls.entry(name) { + entry.insert((return_arguments.clone(), input_arguments.clone())); } let source_name = if has_global_name { Some(Cow::Borrowed(method.func_directive.name())) @@ -1188,11 +1186,9 @@ fn expand_initializer2<'a, 'b, 'input>( ) -> Result, TranslateError> { Ok(match init { ast::Initializer::Constant(c) => ast::Initializer::Constant(c), - ast::Initializer::Global(g, type_) => { - ast::Initializer::Global(scope.get_id_in_module_scope(g)?, type_) - } - ast::Initializer::GenericGlobal(g, type_) => { - ast::Initializer::GenericGlobal(scope.get_id_in_module_scope(g)?, type_) + ast::Initializer::Global(g) => ast::Initializer::Global(scope.get_id_in_module_scope(g)?), + ast::Initializer::GenericGlobal(g) => { + ast::Initializer::GenericGlobal(scope.get_id_in_module_scope(g)?) } ast::Initializer::Add(add) => { let (init1, init2) = *add; @@ -1285,11 +1281,7 @@ fn resolve_instruction_types<'input>( .map(|directive| { Ok(match directive { TranslationDirective::Variable(linking, compiled_name, var) => { - TranslationDirective::Variable( - linking, - compiled_name, - resolve_initializers(id_defs, var)?, - ) + TranslationDirective::Variable(linking, compiled_name, var) } TranslationDirective::Method(method) => { let body = match method.body { @@ -1461,9 +1453,7 @@ fn resolve_instruction_types_method<'input>( } }, Statement::Label(i) => result.push(Statement::Label(i)), - Statement::Variable(v) => { - result.push(Statement::Variable(resolve_initializers(id_defs, v)?)) - } + Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), _ => return Err(TranslateError::unreachable()), } @@ -1471,42 +1461,6 @@ fn resolve_instruction_types_method<'input>( Ok(result) } -fn resolve_initializers<'input>( - id_defs: &mut IdNameMapBuilder<'input>, - mut v: Variable, -) -> Result { - fn resolve_initializer_impl<'input>( - id_defs: &mut IdNameMapBuilder<'input>, - init: &mut ast::Initializer, - ) -> Result<(), TranslateError> { - match init { - ast::Initializer::Constant(_) => {} - ast::Initializer::Global(name, type_) - | ast::Initializer::GenericGlobal(name, type_) => { - *type_ = if let Some((src_type, _, _, _)) = id_defs.try_get_typed(*name)? { - ast::InitializerType::Value(src_type) - } else { - ast::InitializerType::Unknown - }; - } - ast::Initializer::Add(subinit) => { - resolve_initializer_impl(id_defs, &mut (*subinit).0)?; - resolve_initializer_impl(id_defs, &mut (*subinit).1)?; - } - ast::Initializer::Array(inits) => { - for init in inits.iter_mut() { - resolve_initializer_impl(id_defs, init)?; - } - } - } - Ok(()) - } - if let Some(ref mut init) = v.initializer { - resolve_initializer_impl(id_defs, init)?; - } - Ok(v) -} - // TODO: All this garbage should be replaced with proper constant propagation or // at least ability to visit statements without moving them struct KernelConstantsVisitor { @@ -3370,7 +3324,6 @@ fn to_llvm_module_impl2<'a, 'input>( // raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions translation_module = raytracing::postprocess(translation_module, raytracing_state)?; } - let translation_module = resolve_type_of_global_fnptrs(translation_module)?; let translation_module = insert_implicit_conversions(translation_module)?; let translation_module = insert_compilation_mode_prologue(translation_module); let translation_module = normalize_labels(translation_module)?; @@ -3402,76 +3355,6 @@ fn to_llvm_module_impl2<'a, 'input>( }) } -fn resolve_type_of_global_fnptrs( - mut translation_module: TranslationModule, -) -> Result, TranslateError> { - let mut functions: FxHashMap, Vec)> = FxHashMap::default(); - for directive in translation_module.directives.iter_mut() { - match directive { - TranslationDirective::Variable(_, _, variable) => { - if let Some(ref mut initializer) = variable.initializer { - set_iniitalizer_type(&mut functions, initializer); - } - } - TranslationDirective::Method(method) => { - if method.is_kernel { - continue; - } - match functions.entry(method.name) { - hash_map::Entry::Occupied(_) => {} - hash_map::Entry::Vacant(entry) => { - entry.insert(( - extract_argument_types(&method.return_arguments)?, - extract_argument_types(&method.input_arguments)?, - )); - } - } - } - } - } - Ok(translation_module) -} - -fn extract_argument_types( - args: &[ast::VariableDeclaration], -) -> Result, TranslateError> { - args.iter() - .map(|var| { - if var.state_space != ast::StateSpace::Reg { - return Err(TranslateError::unreachable()); - } - Ok(var.type_.clone()) - }) - .collect() -} - -fn set_iniitalizer_type( - functions: &mut FxHashMap, Vec)>, - initializer: &mut ast::Initializer, -) { - match initializer { - ast::Initializer::Constant(_) => {} - ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => { - if let Some((return_arguments, input_arguments)) = functions.get(name) { - *type_ = ast::InitializerType::Function( - return_arguments.clone(), - input_arguments.clone(), - ); - } - } - ast::Initializer::Add(add) => { - let (add1, add2) = &mut **add; - set_iniitalizer_type(functions, add1); - set_iniitalizer_type(functions, add2); - } - ast::Initializer::Array(array) => { - for initializer in array.iter_mut() { - set_iniitalizer_type(functions, initializer); - } - } - } -} - // In PTX it's legal to have a function like this: // .func noreturn(.param .b64 noreturn_0) // .noreturn @@ -5281,20 +5164,6 @@ impl<'input> IdNameMapBuilder<'input> { } } - pub(crate) fn try_get_typed( - &self, - id: Id, - ) -> Result, bool)>, TranslateError> { - match self.type_check.get(&id) { - Some(Some(x)) => Ok(Some(x.clone())), - Some(None) => Ok(None), - None => match self.globals.special_registers.get(id) { - Some(x) => Ok(Some((x.get_type(), ast::StateSpace::Sreg, None, true))), - None => Err(TranslateError::untyped_symbol()), - }, - } - } - pub(crate) fn get_typed( &self, id: Id,