Add mov and call support

This commit is contained in:
Andrzej Janik 2024-09-13 19:40:58 +02:00
parent 0417df3015
commit 02cf83ebb9

View file

@ -225,11 +225,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
});
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = self.function_type(
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl.input_arguments.iter().map(|v| &v.v_type),
);
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
@ -258,67 +262,6 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
Ok(())
}
fn function_type(
&self,
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
) -> LLVMTypeRef {
if return_args.len() == 0 {
let mut input_args = input_args
.map(|type_| match type_ {
ast::Type::Scalar(scalar) => match scalar {
ast::ScalarType::Pred => {
unsafe { LLVMInt1TypeInContext(self.context) }
}
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
unsafe { LLVMInt8TypeInContext(self.context) }
}
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
unsafe { LLVMInt16TypeInContext(self.context) }
}
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
unsafe { LLVMInt32TypeInContext(self.context) }
}
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
unsafe { LLVMInt64TypeInContext(self.context) }
}
ast::ScalarType::B128 => {
unsafe { LLVMInt128TypeInContext(self.context) }
}
ast::ScalarType::F16 => {
unsafe { LLVMHalfTypeInContext(self.context) }
}
ast::ScalarType::F32 => {
unsafe { LLVMFloatTypeInContext(self.context) }
}
ast::ScalarType::F64 => {
unsafe { LLVMDoubleTypeInContext(self.context) }
}
ast::ScalarType::BF16 => {
unsafe { LLVMBFloatTypeInContext(self.context) }
}
ast::ScalarType::U16x2 => todo!(),
ast::ScalarType::S16x2 => todo!(),
ast::ScalarType::F16x2 => todo!(),
ast::ScalarType::BF16x2 => todo!(),
},
ast::Type::Vector(_, _) => todo!(),
ast::Type::Array(_, _, _) => todo!(),
ast::Type::Pointer(_, _) => todo!(),
})
.collect::<Vec<_>>();
return unsafe {
LLVMFunctionType(
LLVMVoidTypeInContext(self.context),
input_args.as_mut_ptr(),
input_args.len() as u32,
0,
)
};
}
todo!()
}
}
struct MethodEmitContext<'a, 'input> {
@ -414,7 +357,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
ast::Instruction::Mov { data, arguments } => todo!(),
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@ -425,7 +368,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Or { data, arguments } => todo!(),
ast::Instruction::And { data, arguments } => todo!(),
ast::Instruction::Bra { arguments } => todo!(),
ast::Instruction::Call { data, arguments } => todo!(),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(),
@ -563,6 +506,68 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
fn emit_call(
&mut self,
data: ptx_parser::CallDetails,
arguments: ptx_parser::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if cfg!(debug_assertions) {
for (_, space) in data.return_arguments.iter() {
if *space != ast::StateSpace::Reg {
panic!()
}
}
for (_, space) in data.input_arguments.iter() {
if *space != ast::StateSpace::Reg {
panic!()
}
}
}
let name = match (&*data.return_arguments, &*arguments.return_arguments) {
([], []) => LLVM_UNNAMED.as_ptr(),
([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
_ => todo!(),
};
let type_ = get_function_type(
self.context,
data.return_arguments.iter().map(|(type_, space)| type_),
data.input_arguments.iter().map(|(type_, space)| type_),
)?;
let mut input_arguments = arguments
.input_arguments
.iter()
.map(|arg| self.resolver.value(*arg))
.collect::<Result<Vec<_>, _>>()?;
let llvm_fn = unsafe {
LLVMBuildCall2(
self.builder,
type_,
self.resolver.value(arguments.func)?,
input_arguments.as_mut_ptr(),
input_arguments.len() as u32,
name,
)
};
match &*arguments.return_arguments {
[] => {}
[name] => {
self.resolver.register(*name, llvm_fn);
}
_ => todo!(),
}
Ok(())
}
fn emit_mov(
&mut self,
_data: ptx_parser::MovDetails,
arguments: ptx_parser::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
Ok(())
}
}
fn get_pointer_type<'ctx>(
@ -624,6 +629,29 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
}
}
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
) -> Result<LLVMTypeRef, TranslateError> {
let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
.map(|type_| get_type(context, type_))
.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
_ => todo!(),
};
Ok(unsafe {
LLVMFunctionType(
return_type,
input_args.as_mut_ptr(),
input_args.len() as u32,
0,
)
})
}
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),