Simplify compilation of globals in initalizers, fix bfind.u64

This commit is contained in:
Andrzej Janik 2024-03-03 17:26:23 +01:00
parent 4b4f33e29e
commit 383dde6b35
4 changed files with 19 additions and 185 deletions

View file

@ -1383,19 +1383,12 @@ pub enum TextureGeometry {
#[derive(Clone)]
pub enum Initializer<ID> {
Constant(ImmediateValue),
Global(ID, InitializerType),
GenericGlobal(ID, InitializerType),
Global(ID),
GenericGlobal(ID),
Add(Box<(Initializer<ID>, Initializer<ID>)>),
Array(Vec<Initializer<ID>>),
}
#[derive(Clone)]
pub enum InitializerType {
Unknown,
Value(Type),
Function(Vec<Type>, Vec<Type>),
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -402,27 +402,20 @@ unsafe fn get_llvm_const(
let const2 = get_llvm_const(ctx, type_, Some(init2))?;
LLVMConstAdd(const1, const2)
}
(_, Some(ast::Initializer::Global(id, type_))) => {
(_, Some(ast::Initializer::Global(id))) => {
let name = ctx.names.value(id)?;
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
let mut zero = LLVMConstInt(b64, 0, 0);
let src_type = get_initializer_llvm_type(ctx, type_)?;
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
LLVMConstPtrToInt(global_ptr, b64)
LLVMConstPtrToInt(name, b64)
}
(_, Some(ast::Initializer::GenericGlobal(id, type_))) => {
(_, Some(ast::Initializer::GenericGlobal(id))) => {
let name = ctx.names.value(id)?;
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
let mut zero = LLVMConstInt(b64, 0, 0);
let src_type = get_initializer_llvm_type(ctx, type_)?;
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
// void pointers are illegal in LLVM IR
let b8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?;
let b8_generic_ptr = LLVMPointerType(
b8,
get_llvm_address_space(&ctx.constants, ast::StateSpace::Generic)?,
);
let generic_ptr = LLVMConstAddrSpaceCast(global_ptr, b8_generic_ptr);
let generic_ptr = LLVMConstAddrSpaceCast(name, b8_generic_ptr);
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
LLVMConstPtrToInt(generic_ptr, b64)
}
_ => return Err(TranslateError::todo()),
@ -430,28 +423,6 @@ unsafe fn get_llvm_const(
Ok(const_value)
}
fn get_initializer_llvm_type(
ctx: &mut EmitContext,
type_: ast::InitializerType,
) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::InitializerType::Unknown => return Err(TranslateError::unreachable()),
ast::InitializerType::Value(type_) => get_llvm_type(ctx, &type_)?,
ast::InitializerType::Function(return_args, input_args) => {
let return_type = match &*return_args {
[] => llvm::void_type(&ctx.context),
[type_] => get_llvm_type(ctx, type_)?,
[..] => get_llvm_type_struct(ctx, return_args.into_iter().map(Cow::Owned))?,
};
get_llvm_function_type(
ctx,
return_type,
input_args.iter().map(|type_| (type_, ast::StateSpace::Reg)),
)?
}
})
}
unsafe fn get_llvm_const_scalar(
ctx: &mut EmitContext,
scalar_type: ast::ScalarType,
@ -1305,7 +1276,8 @@ fn emit_inst_bfind(
let builder = ctx.builder.get();
let src = arg.src.get_llvm_value(&mut ctx.names)?;
let llvm_dst_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
let const_0 = unsafe { LLVMConstInt(llvm_dst_type, 0, 0) };
let llvm_src_type = get_llvm_type(ctx, &ast::Type::Scalar(details.type_))?;
let const_0 = unsafe { LLVMConstInt(llvm_src_type, 0, 0) };
let const_int_max = unsafe { LLVMConstInt(llvm_dst_type, u64::MAX, 0) };
let is_zero = unsafe {
LLVMBuildICmp(
@ -1316,7 +1288,7 @@ fn emit_inst_bfind(
LLVM_UNNAMED,
)
};
let mut clz_result = emit_inst_clz_impl(ctx, ast::ScalarType::U32, None, arg.src, true)?;
let mut clz_result = emit_inst_clz_impl(ctx, details.type_, None, arg.src, true)?;
if !details.shift {
let bits = unsafe {
LLVMConstInt(

View file

@ -652,8 +652,8 @@ Initializer: ast::Initializer<&'input str> = {
InitializerNoAdd: ast::Initializer<&'input str> = {
<val:ImmediateValue> => ast::Initializer::Constant(val),
<id:ExtendedID> => ast::Initializer::Global(id, ast::InitializerType::Unknown),
"generic" "(" <id:ExtendedID> ")" => ast::Initializer::GenericGlobal(id, ast::InitializerType::Unknown),
<id:ExtendedID> => ast::Initializer::Global(id),
"generic" "(" <id:ExtendedID> ")" => ast::Initializer::GenericGlobal(id),
"{" <array_init:Comma<Initializer>> "}" => ast::Initializer::Array(array_init)
}

View file

@ -1031,10 +1031,8 @@ fn normalize_method<'a, 'b, 'input>(
normalize_method_params(&mut fn_scope, &*method.func_directive.return_arguments)?;
let input_arguments =
normalize_method_params(&mut fn_scope, &*method.func_directive.input_arguments)?;
if !is_kernel {
if let hash_map::Entry::Vacant(entry) = function_decls.entry(name) {
entry.insert((return_arguments.clone(), input_arguments.clone()));
}
if let hash_map::Entry::Vacant(entry) = function_decls.entry(name) {
entry.insert((return_arguments.clone(), input_arguments.clone()));
}
let source_name = if has_global_name {
Some(Cow::Borrowed(method.func_directive.name()))
@ -1188,11 +1186,9 @@ fn expand_initializer2<'a, 'b, 'input>(
) -> Result<ast::Initializer<Id>, TranslateError> {
Ok(match init {
ast::Initializer::Constant(c) => ast::Initializer::Constant(c),
ast::Initializer::Global(g, type_) => {
ast::Initializer::Global(scope.get_id_in_module_scope(g)?, type_)
}
ast::Initializer::GenericGlobal(g, type_) => {
ast::Initializer::GenericGlobal(scope.get_id_in_module_scope(g)?, type_)
ast::Initializer::Global(g) => ast::Initializer::Global(scope.get_id_in_module_scope(g)?),
ast::Initializer::GenericGlobal(g) => {
ast::Initializer::GenericGlobal(scope.get_id_in_module_scope(g)?)
}
ast::Initializer::Add(add) => {
let (init1, init2) = *add;
@ -1285,11 +1281,7 @@ fn resolve_instruction_types<'input>(
.map(|directive| {
Ok(match directive {
TranslationDirective::Variable(linking, compiled_name, var) => {
TranslationDirective::Variable(
linking,
compiled_name,
resolve_initializers(id_defs, var)?,
)
TranslationDirective::Variable(linking, compiled_name, var)
}
TranslationDirective::Method(method) => {
let body = match method.body {
@ -1461,9 +1453,7 @@ fn resolve_instruction_types_method<'input>(
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => {
result.push(Statement::Variable(resolve_initializers(id_defs, v)?))
}
Statement::Variable(v) => result.push(Statement::Variable(v)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
_ => return Err(TranslateError::unreachable()),
}
@ -1471,42 +1461,6 @@ fn resolve_instruction_types_method<'input>(
Ok(result)
}
fn resolve_initializers<'input>(
id_defs: &mut IdNameMapBuilder<'input>,
mut v: Variable,
) -> Result<Variable, TranslateError> {
fn resolve_initializer_impl<'input>(
id_defs: &mut IdNameMapBuilder<'input>,
init: &mut ast::Initializer<Id>,
) -> Result<(), TranslateError> {
match init {
ast::Initializer::Constant(_) => {}
ast::Initializer::Global(name, type_)
| ast::Initializer::GenericGlobal(name, type_) => {
*type_ = if let Some((src_type, _, _, _)) = id_defs.try_get_typed(*name)? {
ast::InitializerType::Value(src_type)
} else {
ast::InitializerType::Unknown
};
}
ast::Initializer::Add(subinit) => {
resolve_initializer_impl(id_defs, &mut (*subinit).0)?;
resolve_initializer_impl(id_defs, &mut (*subinit).1)?;
}
ast::Initializer::Array(inits) => {
for init in inits.iter_mut() {
resolve_initializer_impl(id_defs, init)?;
}
}
}
Ok(())
}
if let Some(ref mut init) = v.initializer {
resolve_initializer_impl(id_defs, init)?;
}
Ok(v)
}
// TODO: All this garbage should be replaced with proper constant propagation or
// at least ability to visit statements without moving them
struct KernelConstantsVisitor {
@ -3370,7 +3324,6 @@ fn to_llvm_module_impl2<'a, 'input>(
// raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions
translation_module = raytracing::postprocess(translation_module, raytracing_state)?;
}
let translation_module = resolve_type_of_global_fnptrs(translation_module)?;
let translation_module = insert_implicit_conversions(translation_module)?;
let translation_module = insert_compilation_mode_prologue(translation_module);
let translation_module = normalize_labels(translation_module)?;
@ -3402,76 +3355,6 @@ fn to_llvm_module_impl2<'a, 'input>(
})
}
fn resolve_type_of_global_fnptrs(
mut translation_module: TranslationModule<ExpandedArgParams>,
) -> Result<TranslationModule<ExpandedArgParams>, TranslateError> {
let mut functions: FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)> = FxHashMap::default();
for directive in translation_module.directives.iter_mut() {
match directive {
TranslationDirective::Variable(_, _, variable) => {
if let Some(ref mut initializer) = variable.initializer {
set_iniitalizer_type(&mut functions, initializer);
}
}
TranslationDirective::Method(method) => {
if method.is_kernel {
continue;
}
match functions.entry(method.name) {
hash_map::Entry::Occupied(_) => {}
hash_map::Entry::Vacant(entry) => {
entry.insert((
extract_argument_types(&method.return_arguments)?,
extract_argument_types(&method.input_arguments)?,
));
}
}
}
}
}
Ok(translation_module)
}
fn extract_argument_types(
args: &[ast::VariableDeclaration<Id>],
) -> Result<Vec<ast::Type>, TranslateError> {
args.iter()
.map(|var| {
if var.state_space != ast::StateSpace::Reg {
return Err(TranslateError::unreachable());
}
Ok(var.type_.clone())
})
.collect()
}
fn set_iniitalizer_type(
functions: &mut FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)>,
initializer: &mut ast::Initializer<Id>,
) {
match initializer {
ast::Initializer::Constant(_) => {}
ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => {
if let Some((return_arguments, input_arguments)) = functions.get(name) {
*type_ = ast::InitializerType::Function(
return_arguments.clone(),
input_arguments.clone(),
);
}
}
ast::Initializer::Add(add) => {
let (add1, add2) = &mut **add;
set_iniitalizer_type(functions, add1);
set_iniitalizer_type(functions, add2);
}
ast::Initializer::Array(array) => {
for initializer in array.iter_mut() {
set_iniitalizer_type(functions, initializer);
}
}
}
}
// In PTX it's legal to have a function like this:
// .func noreturn(.param .b64 noreturn_0)
// .noreturn
@ -5281,20 +5164,6 @@ impl<'input> IdNameMapBuilder<'input> {
}
}
pub(crate) fn try_get_typed(
&self,
id: Id,
) -> Result<Option<(ast::Type, ast::StateSpace, Option<u32>, bool)>, TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(Some(x.clone())),
Some(None) => Ok(None),
None => match self.globals.special_registers.get(id) {
Some(x) => Ok(Some((x.get_type(), ast::StateSpace::Sreg, None, true))),
None => Err(TranslateError::untyped_symbol()),
},
}
}
pub(crate) fn get_typed(
&self,
id: Id,