mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Implement call
This commit is contained in:
parent
6456f0d1a1
commit
053c41fbb9
7 changed files with 198 additions and 97 deletions
|
@ -26,70 +26,68 @@ fn run_method<'input>(
|
|||
resolver: &mut GlobalStringIdentResolver2,
|
||||
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
if method.func_decl.name.is_kernel() {
|
||||
return Ok(method);
|
||||
}
|
||||
let is_declaration = method.body.is_none();
|
||||
let mut body = Vec::new();
|
||||
let mut remap_returns = Vec::new();
|
||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
if !method.func_decl.name.is_kernel() {
|
||||
for arg in method.func_decl.return_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
}
|
||||
remap_returns.push((old_name, arg.name, arg.v_type.clone()));
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
for arg in method.func_decl.input_arguments.iter_mut() {
|
||||
match arg.state_space {
|
||||
ptx_parser::StateSpace::Param => {
|
||||
arg.state_space = ptx_parser::StateSpace::Reg;
|
||||
let old_name = arg.name;
|
||||
arg.name =
|
||||
resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
|
||||
if is_declaration {
|
||||
continue;
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
body.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: old_name,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ptx_parser::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
body.push(Statement::Instruction(ast::Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::StCacheOperator::Writethrough,
|
||||
typ: arg.v_type.clone(),
|
||||
},
|
||||
arguments: ast::StArgs {
|
||||
src1: old_name,
|
||||
src2: arg.name,
|
||||
},
|
||||
}));
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
ptx_parser::StateSpace::Reg => {}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
if remap_returns.is_empty() {
|
||||
return Ok(method);
|
||||
}
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
|
@ -116,24 +114,6 @@ fn run_statement<'input>(
|
|||
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { .. }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter().cloned() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Reg,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_,
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: new_name,
|
||||
src: old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(statement);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
mut data,
|
||||
mut arguments,
|
||||
|
@ -194,6 +174,24 @@ fn run_statement<'input>(
|
|||
}));
|
||||
result.extend(post_st.into_iter());
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
for (old_name, new_name, type_) in remap_returns.iter() {
|
||||
result.push(Statement::Instruction(ast::Instruction::Ld {
|
||||
data: ast::LdDetails {
|
||||
qualifier: ast::LdStQualifier::Weak,
|
||||
state_space: ast::StateSpace::Param,
|
||||
caching: ast::LdCacheOperator::Cached,
|
||||
typ: type_.clone(),
|
||||
non_coherent: false,
|
||||
},
|
||||
arguments: ast::LdArgs {
|
||||
dst: *new_name,
|
||||
src: *old_name,
|
||||
},
|
||||
}));
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Ret { data }));
|
||||
}
|
||||
statement => {
|
||||
result.push(statement);
|
||||
}
|
||||
|
|
|
@ -231,15 +231,18 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
})
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||
)?;
|
||||
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
|
||||
if fn_ == ptr::null_mut() {
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
|
||||
)?;
|
||||
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
}
|
||||
if let ast::MethodName::Func(name) = func_decl.name {
|
||||
self.resolver.register(name, fn_);
|
||||
}
|
||||
|
@ -277,6 +280,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
|
||||
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
|
||||
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
|
||||
for var in func_decl.return_arguments {
|
||||
method_emitter.emit_variable(var)?;
|
||||
}
|
||||
for statement in statements.iter() {
|
||||
if let Statement::Label(label) = statement {
|
||||
method_emitter.emit_label_initial(*label);
|
||||
|
@ -382,7 +388,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
Statement::StoreVar(store) => self.emit_store_var(store)?,
|
||||
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
|
||||
Statement::Constant(constant) => self.emit_constant(constant)?,
|
||||
Statement::RetValue(_, _) => todo!(),
|
||||
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
|
||||
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
|
||||
Statement::RepackVector(_) => todo!(),
|
||||
Statement::FunctionPointer(_) => todo!(),
|
||||
|
@ -560,7 +566,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
});
|
||||
Ok(())
|
||||
}
|
||||
ConversionKind::AddressOf => todo!(),
|
||||
ConversionKind::AddressOf => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let dst_type = get_type(self.context, &conversion.to_type)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildPtrToInt(self.builder, src, dst_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -797,6 +810,24 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_ret_value(
|
||||
&mut self,
|
||||
values: Vec<(SpirvWord, ptx_parser::Type)>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match &*values {
|
||||
[] => unsafe { LLVMBuildRetVoid(self.builder) },
|
||||
[(value, type_)] => {
|
||||
let value = self.resolver.value(*value)?;
|
||||
let type_ = get_type(self.context, type_)?;
|
||||
let value =
|
||||
unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
|
||||
unsafe { LLVMBuildRet(self.builder, value) }
|
||||
}
|
||||
_ => todo!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
|
|
|
@ -1502,7 +1502,8 @@ fn emit_function_body_ops<'input>(
|
|||
builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?;
|
||||
}
|
||||
Statement::RetValue(_, id) => {
|
||||
builder.ret_value(id.0)?;
|
||||
todo!()
|
||||
//builder.ret_value(id.0)?;
|
||||
}
|
||||
Statement::PtrAccess(PtrAccess {
|
||||
underlying_type,
|
||||
|
|
|
@ -40,9 +40,6 @@ fn run_method<'a, 'input>(
|
|||
method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
|
||||
let mut func_decl = method.func_decl;
|
||||
for arg in func_decl.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let is_kernel = func_decl.name.is_kernel();
|
||||
if is_kernel {
|
||||
for arg in func_decl.input_arguments.iter_mut() {
|
||||
|
@ -57,12 +54,16 @@ fn run_method<'a, 'input>(
|
|||
arg.state_space = new_space;
|
||||
}
|
||||
};
|
||||
for arg in func_decl.return_arguments.iter_mut() {
|
||||
visitor.visit_variable(arg)?;
|
||||
}
|
||||
let return_arguments = &func_decl.return_arguments[..];
|
||||
let body = method
|
||||
.body
|
||||
.map(move |statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(&mut visitor, &mut result, statement)?;
|
||||
run_statement(&mut visitor, return_arguments, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
|
@ -79,10 +80,33 @@ fn run_method<'a, 'input>(
|
|||
|
||||
fn run_statement<'a, 'input>(
|
||||
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
|
||||
return_arguments: &[ast::Variable<SpirvWord>],
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
statement: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Ret { data }) => {
|
||||
let statement = if return_arguments.is_empty() {
|
||||
Statement::Instruction(ast::Instruction::Ret { data })
|
||||
} else {
|
||||
Statement::RetValue(
|
||||
data,
|
||||
return_arguments
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
if arg.state_space != ast::StateSpace::Local {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
Ok((arg.name, arg.v_type.clone()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
};
|
||||
let new_statement = statement.visit_map(visitor)?;
|
||||
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
|
||||
result.push(new_statement);
|
||||
result.extend(visitor.post.drain(..).map(Statement::Instruction));
|
||||
}
|
||||
Statement::Variable(mut var) => {
|
||||
visitor.visit_variable(&mut var)?;
|
||||
result.push(Statement::Variable(var));
|
||||
|
@ -271,9 +295,9 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
|
|||
fn visit(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(remap) = self.variables.get(&ident) {
|
||||
match remap {
|
||||
|
|
|
@ -72,7 +72,8 @@ pub(super) fn run<'a, 'b>(
|
|||
typ: return_reg.v_type.clone(),
|
||||
member_index: None,
|
||||
}));
|
||||
result.push(Statement::RetValue(data, new_id));
|
||||
unimplemented!()
|
||||
//result.push(Statement::RetValue(data, new_id));
|
||||
}
|
||||
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
|
||||
_ => unimplemented!(),
|
||||
|
|
|
@ -770,7 +770,7 @@ enum Statement<I, P: ast::Operand> {
|
|||
StoreVar(StoreVarDetails),
|
||||
Conversion(ImplicitConversion),
|
||||
Constant(ConstantDefinition),
|
||||
RetValue(ast::RetData, SpirvWord),
|
||||
RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>),
|
||||
PtrAccess(PtrAccess<P>),
|
||||
RepackVector(RepackVectorDetails),
|
||||
FunctionPointer(FunctionPointerDetails),
|
||||
|
@ -906,9 +906,20 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
|
|||
Statement::Constant(ConstantDefinition { dst, typ, value })
|
||||
}
|
||||
Statement::RetValue(data, value) => {
|
||||
// TODO:
|
||||
// We should report type here
|
||||
let value = visitor.visit_ident(value, None, false, false)?;
|
||||
let value = value
|
||||
.into_iter()
|
||||
.map(|(ident, type_)| {
|
||||
Ok((
|
||||
visitor.visit_ident(
|
||||
ident,
|
||||
Some((&type_, ast::StateSpace::Local)),
|
||||
false,
|
||||
false,
|
||||
)?,
|
||||
type_,
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Statement::RetValue(data, value)
|
||||
}
|
||||
Statement::PtrAccess(PtrAccess {
|
||||
|
@ -1867,6 +1878,41 @@ impl<'input, 'b> ScopedResolver<'input, 'b> {
|
|||
scope.flush(self.flat_resolver);
|
||||
}
|
||||
|
||||
fn add_or_get_in_current_scope_untyped(
|
||||
&mut self,
|
||||
name: &'input str,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let current_scope = self.scopes.last_mut().unwrap();
|
||||
Ok(
|
||||
match current_scope.name_to_ident.entry(Cow::Borrowed(name)) {
|
||||
hash_map::Entry::Occupied(occupied_entry) => {
|
||||
let ident = *occupied_entry.get();
|
||||
let entry = current_scope
|
||||
.ident_map
|
||||
.get(&ident)
|
||||
.ok_or_else(|| error_unreachable())?;
|
||||
if entry.type_space.is_some() {
|
||||
return Err(error_unknown_symbol());
|
||||
}
|
||||
ident
|
||||
}
|
||||
hash_map::Entry::Vacant(vacant_entry) => {
|
||||
let new_id = self.flat_resolver.current_id;
|
||||
self.flat_resolver.current_id.0 += 1;
|
||||
vacant_entry.insert(new_id);
|
||||
current_scope.ident_map.insert(
|
||||
new_id,
|
||||
IdentEntry {
|
||||
name: Some(Cow::Borrowed(name)),
|
||||
type_space: None,
|
||||
},
|
||||
);
|
||||
new_id
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn add(
|
||||
&mut self,
|
||||
name: Cow<'input, str>,
|
||||
|
@ -2045,4 +2091,4 @@ fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
|||
ast::ScalarType::BF16x2 => "bf16x2",
|
||||
ast::ScalarType::Pred => "pred",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ fn run_method<'input, 'b>(
|
|||
let name = match method.func_directive.name {
|
||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||
ast::MethodName::Func(text) => {
|
||||
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
|
||||
ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
|
||||
}
|
||||
};
|
||||
resolver.start_scope();
|
||||
|
|
Loading…
Add table
Reference in a new issue