Implement call

This commit is contained in:
Andrzej Janik 2024-10-06 02:05:16 +02:00
parent 6456f0d1a1
commit 053c41fbb9
7 changed files with 198 additions and 97 deletions

View file

@ -26,70 +26,68 @@ fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
if method.func_decl.name.is_kernel() {
return Ok(method);
}
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
for arg in method.func_decl.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
if !method.func_decl.name.is_kernel() {
for arg in method.func_decl.return_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
}
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
for arg in method.func_decl.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
for arg in method.func_decl.input_arguments.iter_mut() {
match arg.state_space {
ptx_parser::StateSpace::Param => {
arg.state_space = ptx_parser::StateSpace::Reg;
let old_name = arg.name;
arg.name =
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
if is_declaration {
continue;
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
}
body.push(Statement::Variable(ast::Variable {
align: None,
name: old_name,
v_type: arg.v_type.clone(),
state_space: ptx_parser::StateSpace::Param,
array_init: Vec::new(),
}));
body.push(Statement::Instruction(ast::Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::StCacheOperator::Writethrough,
typ: arg.v_type.clone(),
},
arguments: ast::StArgs {
src1: old_name,
src2: arg.name,
},
}));
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
ptx_parser::StateSpace::Reg => {}
_ => return Err(error_unreachable()),
}
}
if remap_returns.is_empty() {
return Ok(method);
}
let body = method
.body
.map(|statements| {
@ -116,24 +114,6 @@ fn run_statement<'input>(
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Ret { .. }) => {
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Reg,
caching: ast::LdCacheOperator::Cached,
typ: type_,
non_coherent: false,
},
arguments: ast::LdArgs {
dst: new_name,
src: old_name,
},
}));
}
result.push(statement);
}
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
@ -194,6 +174,24 @@ fn run_statement<'input>(
}));
result.extend(post_st.into_iter());
}
Statement::Instruction(ast::Instruction::Ret { data }) => {
for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
dst: *new_name,
src: *old_name,
},
}));
}
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
}
statement => {
result.push(statement);
}

View file

@ -231,15 +231,18 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
})
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl
.input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
if fn_ == ptr::null_mut() {
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl
.input_arguments
.iter()
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
}
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
@ -277,6 +280,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
for var in func_decl.return_arguments {
method_emitter.emit_variable(var)?;
}
for statement in statements.iter() {
if let Statement::Label(label) = statement {
method_emitter.emit_label_initial(*label);
@ -382,7 +388,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::StoreVar(store) => self.emit_store_var(store)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, _) => todo!(),
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
@ -560,7 +566,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
ConversionKind::AddressOf => todo!(),
ConversionKind::AddressOf => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildPtrToInt(self.builder, src, dst_type, dst)
});
Ok(())
}
}
}
@ -797,6 +810,24 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
fn emit_ret_value(
&mut self,
values: Vec<(SpirvWord, ptx_parser::Type)>,
) -> Result<(), TranslateError> {
match &*values {
[] => unsafe { LLVMBuildRetVoid(self.builder) },
[(value, type_)] => {
let value = self.resolver.value(*value)?;
let type_ = get_type(self.context, type_)?;
let value =
unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMBuildRet(self.builder, value) }
}
_ => todo!(),
};
Ok(())
}
}
fn get_pointer_type<'ctx>(

View file

@ -1502,7 +1502,8 @@ fn emit_function_body_ops<'input>(
builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?;
}
Statement::RetValue(_, id) => {
builder.ret_value(id.0)?;
todo!()
//builder.ret_value(id.0)?;
}
Statement::PtrAccess(PtrAccess {
underlying_type,

View file

@ -40,9 +40,6 @@ fn run_method<'a, 'input>(
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
for arg in func_decl.return_arguments.iter_mut() {
visitor.visit_variable(arg)?;
}
let is_kernel = func_decl.name.is_kernel();
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
@ -57,12 +54,16 @@ fn run_method<'a, 'input>(
arg.state_space = new_space;
}
};
for arg in func_decl.return_arguments.iter_mut() {
visitor.visit_variable(arg)?;
}
let return_arguments = &func_decl.return_arguments[..];
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(&mut visitor, &mut result, statement)?;
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
@ -79,10 +80,33 @@ fn run_method<'a, 'input>(
fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
Statement::Instruction(ast::Instruction::Ret { data }) => {
let statement = if return_arguments.is_empty() {
Statement::Instruction(ast::Instruction::Ret { data })
} else {
Statement::RetValue(
data,
return_arguments
.iter()
.map(|arg| {
if arg.state_space != ast::StateSpace::Local {
return Err(error_unreachable());
}
Ok((arg.name, arg.v_type.clone()))
})
.collect::<Result<Vec<_>, _>>()?,
)
};
let new_statement = statement.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(new_statement);
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
@ -271,9 +295,9 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
fn visit(
&mut self,
ident: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {

View file

@ -72,7 +72,8 @@ pub(super) fn run<'a, 'b>(
typ: return_reg.v_type.clone(),
member_index: None,
}));
result.push(Statement::RetValue(data, new_id));
unimplemented!()
//result.push(Statement::RetValue(data, new_id));
}
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
_ => unimplemented!(),

View file

@ -770,7 +770,7 @@ enum Statement<I, P: ast::Operand> {
StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, SpirvWord),
RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
@ -906,9 +906,20 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
Statement::Constant(ConstantDefinition { dst, typ, value })
}
Statement::RetValue(data, value) => {
// TODO:
// We should report type here
let value = visitor.visit_ident(value, None, false, false)?;
let value = value
.into_iter()
.map(|(ident, type_)| {
Ok((
visitor.visit_ident(
ident,
Some((&type_, ast::StateSpace::Local)),
false,
false,
)?,
type_,
))
})
.collect::<Result<Vec<_>, _>>()?;
Statement::RetValue(data, value)
}
Statement::PtrAccess(PtrAccess {
@ -1867,6 +1878,41 @@ impl<'input, 'b> ScopedResolver<'input, 'b> {
scope.flush(self.flat_resolver);
}
fn add_or_get_in_current_scope_untyped(
&mut self,
name: &'input str,
) -> Result<SpirvWord, TranslateError> {
let current_scope = self.scopes.last_mut().unwrap();
Ok(
match current_scope.name_to_ident.entry(Cow::Borrowed(name)) {
hash_map::Entry::Occupied(occupied_entry) => {
let ident = *occupied_entry.get();
let entry = current_scope
.ident_map
.get(&ident)
.ok_or_else(|| error_unreachable())?;
if entry.type_space.is_some() {
return Err(error_unknown_symbol());
}
ident
}
hash_map::Entry::Vacant(vacant_entry) => {
let new_id = self.flat_resolver.current_id;
self.flat_resolver.current_id.0 += 1;
vacant_entry.insert(new_id);
current_scope.ident_map.insert(
new_id,
IdentEntry {
name: Some(Cow::Borrowed(name)),
type_space: None,
},
);
new_id
}
},
)
}
fn add(
&mut self,
name: Cow<'input, str>,
@ -2045,4 +2091,4 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
ast::ScalarType::BF16x2 => "bf16x2",
ast::ScalarType::Pred => "pred",
}
}
}

View file

@ -37,7 +37,7 @@ fn run_method<'input, 'b>(
let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(text) => {
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
}
};
resolver.start_scope();