mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add mov and call support
This commit is contained in:
parent
0417df3015
commit
02cf83ebb9
1 changed files with 93 additions and 65 deletions
|
@ -225,11 +225,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
|
||||
});
|
||||
let name = CString::new(name).map_err(|_| error_unreachable())?;
|
||||
let fn_type = self.function_type(
|
||||
let fn_type = get_function_type(
|
||||
self.context,
|
||||
func_decl.return_arguments.iter().map(|v| &v.v_type),
|
||||
func_decl.input_arguments.iter().map(|v| &v.v_type),
|
||||
);
|
||||
)?;
|
||||
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
|
||||
if let ast::MethodName::Func(name) = func_decl.name {
|
||||
self.resolver.register(name, fn_);
|
||||
}
|
||||
for (i, param) in func_decl.input_arguments.iter().enumerate() {
|
||||
let value = unsafe { LLVMGetParam(fn_, i as u32) };
|
||||
let name = self.resolver.get_or_add(param.name);
|
||||
|
@ -258,67 +262,6 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn function_type(
|
||||
&self,
|
||||
return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
) -> LLVMTypeRef {
|
||||
if return_args.len() == 0 {
|
||||
let mut input_args = input_args
|
||||
.map(|type_| match type_ {
|
||||
ast::Type::Scalar(scalar) => match scalar {
|
||||
ast::ScalarType::Pred => {
|
||||
unsafe { LLVMInt1TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
||||
unsafe { LLVMInt8TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
|
||||
unsafe { LLVMInt16TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
||||
unsafe { LLVMInt32TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
|
||||
unsafe { LLVMInt64TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::B128 => {
|
||||
unsafe { LLVMInt128TypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F16 => {
|
||||
unsafe { LLVMHalfTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F32 => {
|
||||
unsafe { LLVMFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::F64 => {
|
||||
unsafe { LLVMDoubleTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::BF16 => {
|
||||
unsafe { LLVMBFloatTypeInContext(self.context) }
|
||||
}
|
||||
ast::ScalarType::U16x2 => todo!(),
|
||||
ast::ScalarType::S16x2 => todo!(),
|
||||
ast::ScalarType::F16x2 => todo!(),
|
||||
ast::ScalarType::BF16x2 => todo!(),
|
||||
},
|
||||
ast::Type::Vector(_, _) => todo!(),
|
||||
ast::Type::Array(_, _, _) => todo!(),
|
||||
ast::Type::Pointer(_, _) => todo!(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
return unsafe {
|
||||
LLVMFunctionType(
|
||||
LLVMVoidTypeInContext(self.context),
|
||||
input_args.as_mut_ptr(),
|
||||
input_args.len() as u32,
|
||||
0,
|
||||
)
|
||||
};
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
struct MethodEmitContext<'a, 'input> {
|
||||
|
@ -414,7 +357,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
inst: ast::Instruction<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match inst {
|
||||
ast::Instruction::Mov { data, arguments } => todo!(),
|
||||
ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
|
||||
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
|
||||
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
|
||||
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
|
||||
|
@ -425,7 +368,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
ast::Instruction::Or { data, arguments } => todo!(),
|
||||
ast::Instruction::And { data, arguments } => todo!(),
|
||||
ast::Instruction::Bra { arguments } => todo!(),
|
||||
ast::Instruction::Call { data, arguments } => todo!(),
|
||||
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
|
||||
ast::Instruction::Cvt { data, arguments } => todo!(),
|
||||
ast::Instruction::Shr { data, arguments } => todo!(),
|
||||
ast::Instruction::Shl { data, arguments } => todo!(),
|
||||
|
@ -563,6 +506,68 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
|
|||
fn emit_ret(&self, _data: ptx_parser::RetData) {
|
||||
unsafe { LLVMBuildRetVoid(self.builder) };
|
||||
}
|
||||
|
||||
fn emit_call(
|
||||
&mut self,
|
||||
data: ptx_parser::CallDetails,
|
||||
arguments: ptx_parser::CallArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if cfg!(debug_assertions) {
|
||||
for (_, space) in data.return_arguments.iter() {
|
||||
if *space != ast::StateSpace::Reg {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
for (_, space) in data.input_arguments.iter() {
|
||||
if *space != ast::StateSpace::Reg {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
let name = match (&*data.return_arguments, &*arguments.return_arguments) {
|
||||
([], []) => LLVM_UNNAMED.as_ptr(),
|
||||
([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
|
||||
_ => todo!(),
|
||||
};
|
||||
let type_ = get_function_type(
|
||||
self.context,
|
||||
data.return_arguments.iter().map(|(type_, space)| type_),
|
||||
data.input_arguments.iter().map(|(type_, space)| type_),
|
||||
)?;
|
||||
let mut input_arguments = arguments
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|arg| self.resolver.value(*arg))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let llvm_fn = unsafe {
|
||||
LLVMBuildCall2(
|
||||
self.builder,
|
||||
type_,
|
||||
self.resolver.value(arguments.func)?,
|
||||
input_arguments.as_mut_ptr(),
|
||||
input_arguments.len() as u32,
|
||||
name,
|
||||
)
|
||||
};
|
||||
match &*arguments.return_arguments {
|
||||
[] => {}
|
||||
[name] => {
|
||||
self.resolver.register(*name, llvm_fn);
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mov(
|
||||
&mut self,
|
||||
_data: ptx_parser::MovDetails,
|
||||
arguments: ptx_parser::MovArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
self.resolver
|
||||
.register(arguments.dst, self.resolver.value(arguments.src)?);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pointer_type<'ctx>(
|
||||
|
@ -624,6 +629,29 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
|
|||
}
|
||||
}
|
||||
|
||||
fn get_function_type<'a>(
|
||||
context: LLVMContextRef,
|
||||
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
|
||||
.map(|type_| get_type(context, type_))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let return_type = match return_args.len() {
|
||||
0 => unsafe { LLVMVoidTypeInContext(context) },
|
||||
1 => get_type(context, return_args.next().unwrap())?,
|
||||
_ => todo!(),
|
||||
};
|
||||
Ok(unsafe {
|
||||
LLVMFunctionType(
|
||||
return_type,
|
||||
input_args.as_mut_ptr(),
|
||||
input_args.len() as u32,
|
||||
0,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
|
||||
match space {
|
||||
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|
||||
|
|
Loading…
Add table
Reference in a new issue