Support lists of variables to be declared (#516)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled
ZLUDA / Build AMD GPU unit tests (push) Has been cancelled
ZLUDA / Run AMD GPU unit tests (push) Has been cancelled

For example,

```
.reg .u32 a, b;
```
This commit is contained in:
Violet 2025-09-19 13:36:48 -07:00 committed by GitHub
commit 875ac13be2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 315 additions and 187 deletions

View file

@ -29,22 +29,24 @@ fn run_method<'input>(
let mut remap_returns = Vec::new(); let mut remap_returns = Vec::new();
if !method.is_kernel { if !method.is_kernel {
for arg in method.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
match arg.state_space { match arg.info.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.info.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name; let old_name = arg.name;
arg.name = arg.name = resolver
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space)));
if is_declaration { if is_declaration {
continue; continue;
} }
remap_returns.push((old_name, arg.name, arg.v_type.clone())); remap_returns.push((old_name, arg.name, arg.info.v_type.clone()));
body.push(Statement::Variable(ast::Variable { body.push(Statement::Variable(ast::Variable {
info: ast::VariableInfo {
align: None, align: None,
name: old_name, v_type: arg.info.v_type.clone(),
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param, state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(), array_init: Vec::new(),
},
name: old_name,
})); }));
} }
ptx_parser::StateSpace::Reg => {} ptx_parser::StateSpace::Reg => {}
@ -52,28 +54,30 @@ fn run_method<'input>(
} }
} }
for arg in method.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
match arg.state_space { match arg.info.state_space {
ptx_parser::StateSpace::Param => { ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg; arg.info.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name; let old_name = arg.name;
arg.name = arg.name = resolver
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space)));
if is_declaration { if is_declaration {
continue; continue;
} }
body.push(Statement::Variable(ast::Variable { body.push(Statement::Variable(ast::Variable {
info: ast::VariableInfo {
align: None, align: None,
name: old_name, v_type: arg.info.v_type.clone(),
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param, state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(), array_init: Vec::new(),
},
name: old_name,
})); }));
body.push(Statement::Instruction(ast::Instruction::St { body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData { data: ast::StData {
qualifier: ast::LdStQualifier::Weak, qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param, state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough, caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(), typ: arg.info.v_type.clone(),
}, },
arguments: ast::StArgs { arguments: ast::StArgs {
src1: old_name, src1: old_name,

View file

@ -30,11 +30,19 @@ fn run_function<'input>(
statements statements
.into_iter() .into_iter()
.filter_map(|statement| match statement { .filter_map(|statement| match statement {
Statement::Variable(var @ ast::Variable { Statement::Variable(
var @ ast::Variable {
info:
ast::VariableInfo {
state_space: state_space:
ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Shared,
.. ..
}) => { },
..
},
) => {
result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
None None
} }

View file

@ -40,14 +40,14 @@ fn run_method<'a, 'input>(
if is_kernel { if is_kernel {
for arg in method.input_arguments.iter_mut() { for arg in method.input_arguments.iter_mut() {
let old_name = arg.name; let old_name = arg.name;
let old_space = arg.state_space; let old_space = arg.info.state_space;
let new_space = ast::StateSpace::ParamEntry; let new_space = ast::StateSpace::ParamEntry;
let new_name = visitor let new_name = visitor
.resolver .resolver
.register_unnamed(Some((arg.v_type.clone(), new_space))); .register_unnamed(Some((arg.info.v_type.clone(), new_space)));
visitor.input_argument(old_name, new_name, old_space)?; visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name; arg.name = new_name;
arg.state_space = new_space; arg.info.state_space = new_space;
} }
}; };
for arg in method.return_arguments.iter_mut() { for arg in method.return_arguments.iter_mut() {
@ -83,10 +83,10 @@ fn run_statement<'a, 'input>(
return_arguments return_arguments
.iter() .iter()
.map(|arg| { .map(|arg| {
if arg.state_space != ast::StateSpace::Local { if arg.info.state_space != ast::StateSpace::Local {
return Err(error_unreachable()); return Err(error_unreachable());
} }
Ok((arg.name, arg.v_type.clone())) Ok((arg.name, arg.info.v_type.clone()))
}) })
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
) )
@ -332,7 +332,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
} }
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> { fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space { let old_space = match var.info.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
// Do nothing // Do nothing
ptx_parser::StateSpace::Local => return Ok(()), ptx_parser::StateSpace::Local => return Ok(()),
@ -350,10 +350,10 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
let new_space = ast::StateSpace::Local; let new_space = ast::StateSpace::Local;
let new_name = self let new_name = self
.resolver .resolver
.register_unnamed(Some((var.v_type.clone(), new_space))); .register_unnamed(Some((var.info.v_type.clone(), new_space)));
self.variable(&var.v_type, old_name, new_name, old_space)?; self.variable(&var.info.v_type, old_name, new_name, old_space)?;
var.name = new_name; var.name = new_name;
var.state_space = new_space; var.info.state_space = new_space;
Ok(()) Ok(())
} }
} }

View file

@ -195,7 +195,7 @@ fn compile_methods(ptx: &str) -> Vec<Function2<ast::Instruction<SpirvWord>, Spir
let module = ptx_parser::parse_module_checked(ptx).unwrap(); let module = ptx_parser::parse_module_checked(ptx).unwrap();
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap(); let directives = normalize_identifiers::run(&mut scoped_resolver, module.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap();

View file

@ -122,11 +122,10 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
if fn_ == ptr::null_mut() { if fn_ == ptr::null_mut() {
let fn_type = get_function_type( let fn_type = get_function_type(
self.context, self.context,
method.return_arguments.iter().map(|v| &v.v_type), method.return_arguments.iter().map(|v| &v.info.v_type),
method method.input_arguments.iter().map(|v| {
.input_arguments get_input_argument_type(self.context, &v.info.v_type, v.info.state_space)
.iter() }),
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?; )?;
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true"); self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true");
@ -153,7 +152,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
for (i, param) in method.input_arguments.iter().enumerate() { for (i, param) in method.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) }; let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name); let name = self.resolver.get_or_add(param.name);
if let Some(align) = param.align { if let Some(align) = param.info.align {
unsafe { LLVMSetParamAlignment(value, align) }; unsafe { LLVMSetParamAlignment(value, align) };
} }
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
@ -166,7 +165,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCreateTypeAttribute( LLVMCreateTypeAttribute(
self.context, self.context,
attr_kind, attr_kind,
get_type(self.context, &param.v_type)?, get_type(self.context, &param.info.v_type)?,
) )
}; };
unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
@ -241,17 +240,17 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
let global = unsafe { let global = unsafe {
LLVMAddGlobalInAddressSpace( LLVMAddGlobalInAddressSpace(
self.module, self.module,
get_type(self.context, &var.v_type)?, get_type(self.context, &var.info.v_type)?,
name.as_ptr(), name.as_ptr(),
get_state_space(var.state_space)?, get_state_space(var.info.state_space)?,
) )
}; };
self.resolver.register(var.name, global); self.resolver.register(var.name, global);
if let Some(align) = var.align { if let Some(align) = var.info.align {
unsafe { LLVMSetAlignment(global, align) }; unsafe { LLVMSetAlignment(global, align) };
} }
if !var.array_init.is_empty() { if !var.info.array_init.is_empty() {
let initializer = self.get_array_init(&var.v_type, &*var.array_init)?; let initializer = self.get_array_init(&var.info.v_type, &*var.info.array_init)?;
unsafe { LLVMSetInitializer(global, initializer) }; unsafe { LLVMSetInitializer(global, initializer) };
} }
Ok(()) Ok(())
@ -422,16 +421,16 @@ impl<'a> MethodEmitContext<'a> {
let alloca = unsafe { let alloca = unsafe {
LLVMZludaBuildAlloca( LLVMZludaBuildAlloca(
self.variables_builder.get(), self.variables_builder.get(),
get_type(self.context, &var.v_type)?, get_type(self.context, &var.info.v_type)?,
get_state_space(var.state_space)?, get_state_space(var.info.state_space)?,
self.resolver.get_or_add_raw(var.name), self.resolver.get_or_add_raw(var.name),
) )
}; };
self.resolver.register(var.name, alloca); self.resolver.register(var.name, alloca);
if let Some(align) = var.align { if let Some(align) = var.info.align {
unsafe { LLVMSetAlignment(alloca, align) }; unsafe { LLVMSetAlignment(alloca, align) };
} }
if !var.array_init.is_empty() { if !var.info.array_init.is_empty() {
return Err(error_unreachable()); return Err(error_unreachable());
} }
Ok(()) Ok(())

View file

@ -21,7 +21,7 @@ mod insert_post_saturation;
mod instruction_mode_to_global_mode; mod instruction_mode_to_global_mode;
pub mod llvm; pub mod llvm;
mod normalize_basic_blocks; mod normalize_basic_blocks;
mod normalize_identifiers2; mod normalize_identifiers;
mod normalize_predicates2; mod normalize_predicates2;
mod remove_unreachable_basic_blocks; mod remove_unreachable_basic_blocks;
mod replace_instructions_with_functions; mod replace_instructions_with_functions;
@ -65,8 +65,8 @@ pub fn to_llvm_module<'input>(
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = normalize_identifiers::run(&mut scoped_resolver, ast.directives)?;
on_pass_end("normalize_identifiers2"); on_pass_end("normalize_identifiers");
let directives = replace_known_functions::run(&mut flat_resolver, directives); let directives = replace_known_functions::run(&mut flat_resolver, directives);
on_pass_end("replace_known_functions"); on_pass_end("replace_known_functions");
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
@ -308,16 +308,18 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
Statement::Variable(var) => { Statement::Variable(var) => {
let name = visitor.visit_ident( let name = visitor.visit_ident(
var.name, var.name,
Some((&var.v_type, var.state_space)), Some((&var.info.v_type, var.info.state_space)),
true, true,
false, false,
)?; )?;
Statement::Variable(ast::Variable { Statement::Variable(ast::Variable {
align: var.align, info: ast::VariableInfo {
v_type: var.v_type, align: var.info.align,
state_space: var.state_space, v_type: var.info.v_type,
state_space: var.info.state_space,
array_init: var.info.array_init,
},
name, name,
array_init: var.array_init,
}) })
} }
Statement::Conditional(conditional) => { Statement::Conditional(conditional) => {
@ -978,20 +980,24 @@ impl SpecialRegistersMap {
let return_type = sreg.get_function_return_type(); let return_type = sreg.get_function_return_type();
let input_type = sreg.get_function_input_type(); let input_type = sreg.get_function_input_type();
let return_arguments = vec![ast::Variable { let return_arguments = vec![ast::Variable {
info: ast::VariableInfo {
align: None, align: None,
v_type: return_type.into(), v_type: return_type.into(),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
array_init: Vec::new(), array_init: Vec::new(),
},
name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
}]; }];
let input_arguments = input_type let input_arguments = input_type
.into_iter() .into_iter()
.map(|type_| ast::Variable { .map(|type_| ast::Variable {
info: ast::VariableInfo {
align: None, align: None,
v_type: type_.into(), v_type: type_.into(),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
array_init: Vec::new(), array_init: Vec::new(),
},
name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
fn_(sreg, (return_arguments, name, input_arguments)); fn_(sreg, (return_arguments, name, input_arguments));

View file

@ -80,19 +80,28 @@ fn run_function_decl<'input, 'b>(
Ok((return_arguments, input_arguments)) Ok((return_arguments, input_arguments))
} }
fn run_variable_info<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
info: ast::VariableInfo<&'input str>,
) -> Result<ast::VariableInfo<SpirvWord>, TranslateError> {
Ok(ast::VariableInfo {
align: info.align,
v_type: info.v_type,
state_space: info.state_space,
array_init: run_array_init(resolver, &info.array_init)?,
})
}
fn run_variable<'input, 'b>( fn run_variable<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>, resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>, variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> { ) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable { Ok(ast::Variable {
info: run_variable_info(resolver, variable.info.clone())?,
name: resolver.add( name: resolver.add(
Cow::Borrowed(variable.name), Cow::Borrowed(variable.name),
Some((variable.v_type.clone(), variable.state_space)), Some((variable.info.v_type.clone(), variable.info.state_space)),
)?, )?,
align: variable.align,
v_type: variable.v_type,
state_space: variable.state_space,
array_init: run_array_init(resolver, &variable.array_init)?,
}) })
} }
@ -158,38 +167,28 @@ fn run_multivariable<'input, 'b>(
result: &mut Vec<NormalizedStatement>, result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>, variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
match variable.count { match variable {
Some(count) => { ptx_parser::MultiVariable::Parameterized { info, name, count } => {
for i in 0..count { for i in 0..count {
let name = Cow::Owned(format!("{}{}", variable.var.name, i)); let name = Cow::Owned(format!("{}{}", name, i));
let ident = resolver.add( let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?;
name,
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: variable.var.align, info: run_variable_info(resolver, info.clone())?,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident, name: ident,
array_init: run_array_init(resolver, &variable.var.array_init)?,
})); }));
} }
} }
None => { ptx_parser::MultiVariable::Names { info, names } => {
let name = Cow::Borrowed(variable.var.name); for name in names {
let ident = resolver.add( let name = Cow::Borrowed(name);
name, let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?;
Some((variable.var.v_type.clone(), variable.var.state_space)),
)?;
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: variable.var.align, info: run_variable_info(resolver, info.clone())?,
v_type: variable.var.v_type.clone(),
state_space: variable.var.state_space,
name: ident, name: ident,
array_init: run_array_init(resolver, &variable.var.array_init)?,
})); }));
} }
} }
}
Ok(()) Ok(())
} }

View file

@ -580,11 +580,13 @@ fn to_variables<'input>(
arguments arguments
.iter() .iter()
.map(|(type_, space)| ast::Variable { .map(|(type_, space)| ast::Variable {
info: ast::VariableInfo {
align: None, align: None,
v_type: type_.clone(), v_type: type_.clone(),
state_space: *space, state_space: *space,
name: resolver.register_unnamed(Some((type_.clone(), *space))),
array_init: Vec::new(), array_init: Vec::new(),
},
name: resolver.register_unnamed(Some((type_.clone(), *space))),
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }

View file

@ -33,41 +33,49 @@ pub(crate) fn run<'input>(
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U8), ast::Type::Scalar(ast::ScalarType::U8),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U8), v_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
], ],
name: imports.part1, name: imports.part1,
input_arguments: vec![ input_arguments: vec![
@ -76,21 +84,25 @@ pub(crate) fn run<'input>(
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
], ],
body: None, body: None,
import_as: None, import_as: None,
@ -108,10 +120,12 @@ pub(crate) fn run<'input>(
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
},
}], }],
name: imports.part2, name: imports.part2,
input_arguments: vec![ input_arguments: vec![
@ -120,61 +134,73 @@ pub(crate) fn run<'input>(
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::F32), ast::Type::Scalar(ast::ScalarType::F32),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::F32), v_type: ast::Type::Scalar(ast::ScalarType::F32),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
ast::Variable { ast::Variable {
name: resolver.register_unnamed(Some(( name: resolver.register_unnamed(Some((
ast::Type::Scalar(ast::ScalarType::U8), ast::Type::Scalar(ast::ScalarType::U8),
ast::StateSpace::Reg, ast::StateSpace::Reg,
))), ))),
info: ast::VariableInfo {
align: None, align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U8), v_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Reg, state_space: ast::StateSpace::Reg,
array_init: Vec::new(), array_init: Vec::new(),
}, },
},
], ],
body: None, body: None,
import_as: None, import_as: None,

View file

@ -12,7 +12,7 @@ fn run_expand_operands(ptx: ptx_parser::Module) -> String {
// We run the minimal number of passes required to produce the input expected by expand_operands // We run the minimal number of passes required to produce the input expected by expand_operands
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
directive2_vec_to_string(&flat_resolver, directives) directive2_vec_to_string(&flat_resolver, directives)

View file

@ -12,7 +12,7 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
// We run the minimal number of passes required to produce the input expected by insert_implicit_conversions // We run the minimal number of passes required to produce the input expected by insert_implicit_conversions
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap(); let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap();

View file

@ -0,0 +1,37 @@
define amdgpu_kernel void @reg_multi(ptr addrspace(4) byref(i64) %"38", ptr addrspace(4) byref(i64) %"39") #0 {
%"40" = alloca i64, align 8, addrspace(5)
%"41" = alloca i64, align 8, addrspace(5)
%"42" = alloca i32, align 4, addrspace(5)
%"43" = alloca i32, align 4, addrspace(5)
br label %1
1: ; preds = %0
br label %"37"
"37": ; preds = %1
%"44" = load i64, ptr addrspace(4) %"38", align 8
store i64 %"44", ptr addrspace(5) %"40", align 8
%"45" = load i64, ptr addrspace(4) %"39", align 8
store i64 %"45", ptr addrspace(5) %"41", align 8
%"47" = load i64, ptr addrspace(5) %"40", align 8
%"54" = inttoptr i64 %"47" to ptr
%"46" = load i32, ptr %"54", align 4
store i32 %"46", ptr addrspace(5) %"42", align 4
%"48" = load i64, ptr addrspace(5) %"40", align 8
%"55" = inttoptr i64 %"48" to ptr
%"34" = getelementptr inbounds i8, ptr %"55", i64 4
%"49" = load i32, ptr %"34", align 4
store i32 %"49", ptr addrspace(5) %"43", align 4
%"50" = load i64, ptr addrspace(5) %"41", align 8
%"51" = load i32, ptr addrspace(5) %"42", align 4
%"56" = inttoptr i64 %"50" to ptr
store i32 %"51", ptr %"56", align 4
%"52" = load i64, ptr addrspace(5) %"41", align 8
%"57" = inttoptr i64 %"52" to ptr
%"36" = getelementptr inbounds i8, ptr %"57", i64 4
%"53" = load i32, ptr addrspace(5) %"43", align 4
store i32 %"53", ptr %"36", align 4
ret void
}
attributes #0 = { "amdgpu-ieee"="false" "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -124,6 +124,7 @@ test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]); test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]); test_ptx!(reg_local, [12u64], [13u64]);
test_ptx!(reg_multi, [123u32, 456u32], [123u32, 456u32]);
test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(mov_address, [0xDEADu64], [0u64]);
test_ptx!(b64tof64, [111u64], [111u64]); test_ptx!(b64tof64, [111u64], [111u64]);
// This segfaults NV compiler // This segfaults NV compiler

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
// Test kernel for declaring several registers in one statement
// (`.reg .u32 a, b;`).
// Reads two consecutive 32-bit values starting at the address passed in
// `input` and writes them, in the same order, starting at the address passed
// in `output`.
.visible .entry reg_multi(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
// The construct under test: one declaration introducing two registers.
.reg .u32 a, b;
// Fetch the two addresses from the kernel parameters.
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
// Use both declared registers independently: a <- in[0], b <- in[1].
ld.u32 a, [in_addr];
ld.u32 b, [in_addr+4];
// Write each register back at the matching offset in the output.
st.u32 [out_addr], a;
st.u32 [out_addr+4], b;
ret;
}

View file

@ -998,29 +998,41 @@ impl<T: Operand, Err> MapOperand<Err> for Option<T> {
} }
} }
pub struct MultiVariable<ID> { pub enum MultiVariable<ID> {
pub var: Variable<ID>, Parameterized {
pub count: Option<u32>, info: VariableInfo<ID>,
name: ID,
count: u32,
},
Names {
info: VariableInfo<ID>,
names: Vec<ID>,
},
}
#[derive(Clone)]
pub struct VariableInfo<ID> {
pub align: Option<u32>,
pub v_type: Type,
pub state_space: StateSpace,
pub array_init: Vec<RegOrImmediate<ID>>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Variable<ID> { pub struct Variable<ID> {
pub align: Option<u32>, pub info: VariableInfo<ID>,
pub v_type: Type,
pub state_space: StateSpace,
pub name: ID, pub name: ID,
pub array_init: Vec<RegOrImmediate<ID>>,
} }
impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> { impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.state_space)?; write!(f, "{}", self.info.state_space)?;
if let Some(align) = self.align { if let Some(align) = self.info.align {
write!(f, " .align {}", align)?; write!(f, " .align {}", align)?;
} }
let (vector_size, scalar_type, array_dims) = match &self.v_type { let (vector_size, scalar_type, array_dims) = match &self.info.v_type {
Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]), Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]),
Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]), Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]),
Type::Array(vector_size, scalar_type, array_dims) => { Type::Array(vector_size, scalar_type, array_dims) => {
@ -1038,7 +1050,7 @@ impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> {
write!(f, "[{}]", dim)?; write!(f, "[{}]", dim)?;
} }
if self.array_init.len() > 0 { if self.info.array_init.len() > 0 {
todo!("Need to interpret the array initializer data as the appropriate type"); todo!("Need to interpret the array initializer data as the appropriate type");
} }

View file

@ -135,7 +135,7 @@ impl<'a, 'input> PtxParserState<'a, 'input> {
fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> {
input_arguments input_arguments
.iter() .iter()
.map(|var| (var.v_type.clone(), var.state_space)) .map(|var| (var.info.v_type.clone(), var.info.state_space))
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
} }
@ -552,7 +552,13 @@ fn module_variable<'a, 'input>(
let var = global_space let var = global_space
.flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space))
// TODO: support multi var in globals // TODO: support multi var in globals
.map(|multi_var| multi_var.var) .verify_map(|multi_var| match multi_var {
MultiVariable::Names { info, names } if names.len() == 1 => Some(ast::Variable {
info,
name: names[0],
}),
_ => None,
})
.parse_next(stream)?; .parse_next(stream)?;
Ok((linking, var)) Ok((linking, var))
} }
@ -886,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>(
) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> { ) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> {
fn nvptx_kernel_declaration<'a, 'input>( fn nvptx_kernel_declaration<'a, 'input>(
stream: &mut PtxParser<'a, 'input>, stream: &mut PtxParser<'a, 'input>,
) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> { ) -> PResult<((Option<u32>, Option<NonZeroU8>, ScalarType), &'input str)> {
trace( trace(
"nvptx_kernel_declaration", "nvptx_kernel_declaration",
( (
@ -897,15 +903,15 @@ fn method_parameter<'a, 'input: 'a>(
ident, ident,
), ),
) )
.map(|(vector, type_, _, align, name)| (align, vector, type_, name)) .map(|(vector, type_, _, align, name)| ((align, vector, type_), name))
.parse_next(stream) .parse_next(stream)
} }
trace( trace(
"method_parameter", "method_parameter",
move |stream: &mut PtxParser<'a, 'input>| { move |stream: &mut PtxParser<'a, 'input>| {
if kernel_decl_rules {} if kernel_decl_rules {}
let (align, vector, type_, name) = let ((align, vector, type_), name) =
alt((variable_declaration, nvptx_kernel_declaration)).parse_next(stream)?; alt(((variable_info, ident), nvptx_kernel_declaration)).parse_next(stream)?;
let array_dimensions = if state_space != StateSpace::Reg { let array_dimensions = if state_space != StateSpace::Reg {
opt(array_dimensions).parse_next(stream)? opt(array_dimensions).parse_next(stream)?
} else { } else {
@ -918,27 +924,28 @@ fn method_parameter<'a, 'input: 'a>(
} }
} }
Ok(Variable { Ok(Variable {
info: VariableInfo {
align, align,
v_type: Type::maybe_array(vector, type_, array_dimensions), v_type: Type::maybe_array(vector, type_, array_dimensions),
state_space, state_space,
name,
array_init: Vec::new(), array_init: Vec::new(),
},
name,
}) })
}, },
) )
} }
// TODO: split to a separate type // TODO: split to a separate type
fn variable_declaration<'a, 'input>( fn variable_info<'a, 'input>(
stream: &mut PtxParser<'a, 'input>, stream: &mut PtxParser<'a, 'input>,
) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> { ) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType)> {
trace( trace(
"variable_declaration", "variable_info",
( (
opt(align.verify(|x| x.count_ones() == 1)), opt(align.verify(|x| x.count_ones() == 1)),
vector_prefix, vector_prefix,
scalar_type, scalar_type,
ident,
), ),
) )
.parse_next(stream) .parse_next(stream)
@ -951,21 +958,27 @@ fn multi_variable<'a, 'input: 'a>(
trace( trace(
"multi_variable", "multi_variable",
move |stream: &mut PtxParser<'a, 'input>| { move |stream: &mut PtxParser<'a, 'input>| {
let ((align, vector, type_, name), count) = ( let ((align, vector, type_), names, count): (_, Vec<_>, _) = (
variable_declaration, variable_info,
separated(1.., ident, Token::Comma),
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)),
) )
.parse_next(stream)?; .parse_next(stream)?;
if count.is_some() { if let Some(count) = count {
return Ok(MultiVariable { if names.len() > 1 {
var: Variable { // nvcc does not support parameterized variable names in comma-separated lists of names.
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
}
let name = names[0];
return Ok(MultiVariable::Parameterized {
info: VariableInfo {
align, align,
v_type: Type::maybe_vector_parsed(vector, type_), v_type: Type::maybe_vector_parsed(vector, type_),
state_space, state_space,
name,
array_init: Vec::new(), array_init: Vec::new(),
}, },
name,
count, count,
}); });
} }
@ -988,15 +1001,14 @@ fn multi_variable<'a, 'input: 'a>(
return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
} }
} }
Ok(MultiVariable { Ok(MultiVariable::Names {
var: Variable { info: VariableInfo {
align, align,
v_type: Type::maybe_array(vector, type_, array_dimensions), v_type: Type::maybe_array(vector, type_, array_dimensions),
state_space, state_space,
name,
array_init: initializer.unwrap_or(Vec::new()), array_init: initializer.unwrap_or(Vec::new()),
}, },
count, names,
}) })
}, },
) )