diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index acadc4a..5d1b0cd 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -43,6 +43,10 @@ pub(super) fn run<'input>( Directive::Method(method) => emit_ctx.emit_method(method)?, } } + if let Err(err) = emit_ctx.module.verify() { + emit_ctx.module.print_to_stderr(); + panic!("{}", err); + } Ok(emit_ctx.module.write_bitcode_to_memory()) } @@ -107,15 +111,16 @@ impl<'ctx, 'input> ModuleEmitContext<'ctx, 'input> { Self::func_call_convention() }); if let Some(statements) = method.body { - let entry_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED); + let variables_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED); let variables_builder = self.context.create_builder(); - variables_builder.position_at_end(entry_bb); + variables_builder.position_at_end(variables_bb); let real_bb = self.context.append_basic_block(fn_, LLVM_UNNAMED); self.builder.position_at_end(real_bb); let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); for statement in statements { method_emitter.emit_statement(statement)?; } + method_emitter.variables_builder.build_unconditional_branch(real_bb); } Ok(()) } @@ -234,7 +239,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { fn emit_variable(&mut self, var: ast::Variable) -> Result<(), TranslateError> { let alloca = unsafe { PointerValue::new(LLVMZludaBuildAlloca( - self.builder.as_mut_ptr(), + self.variables_builder.as_mut_ptr(), get_type::(&self.context, &var.v_type)?.as_type_ref(), get_state_space(var.state_space)? as u32, self.resolver.get_or_add_raw(var.name), @@ -257,6 +262,15 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { let block = self .context .append_basic_block(self.method, self.resolver.get_or_add(label)); + if self + .builder + .get_insert_block() + .unwrap() + .get_terminator() + .is_none() + { + self.builder.build_unconditional_branch(block); + } self.builder.position_at_end(block); } @@ -277,7 +291,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { ast::Instruction::Mov { data, arguments } => todo!(), ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), - ast::Instruction::St { data, arguments } => todo!(), + ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), ast::Instruction::Mul { data, arguments } => todo!(), ast::Instruction::Setp { data, arguments } => todo!(), ast::Instruction::SetpBool { data, arguments } => todo!(), @@ -289,7 +303,7 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(), ast::Instruction::Shl { data, arguments } => todo!(), - ast::Instruction::Ret { data } => todo!(), + ast::Instruction::Ret { data } => self.emit_ret(data), ast::Instruction::Cvta { data, arguments } => todo!(), ast::Instruction::Abs { data, arguments } => todo!(), ast::Instruction::Mad { data, arguments } => todo!(), @@ -398,15 +412,39 @@ impl<'a, 'ctx, 'input> MethodEmitContext<'a, 'ctx, 'input> { arguments: ast::AddArgs, ) -> Result<(), TranslateError> { let builder = self.builder; - let src1 = self.resolver.value(arguments.src1)?; - let src2 = self.resolver.value(arguments.src2)?; + let src1 = self.resolver.value(arguments.src1)?.as_int()?; + let src2 = self.resolver.value(arguments.src2)?.as_int()?; let fn_ = match data { - ast::ArithDetails::Integer(integer) => Builder::build_int_add::>, + ast::ArithDetails::Integer(integer) => Builder::build_int_add, ast::ArithDetails::Float(float) => todo!(), }; self.resolver .with_result(arguments.dst, |dst| fn_(builder, src1, src2, dst)) } + + fn emit_st( + &self, + data: ptx_parser::StData, + arguments: ptx_parser::StArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?.as_pointer()?; + let src2 = self.resolver.value(arguments.src2)?.as_basic()?; + if data.qualifier != ast::LdStQualifier::Weak { + todo!() + } + self.builder + .build_store(src1, src2) + .map_err(|_| error_unreachable())?; + Ok(()) + } + + fn emit_ret(&self, _data: ptx_parser::RetData) -> Result<(), TranslateError> { + self.builder + .build_return(None) + .map_err(|_| error_unreachable())?; + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -583,37 +621,41 @@ impl<'ctx> ResolveIdent<'ctx> { Ok(()) } - fn build_int_math>( + fn build_int_math( &mut self, builder: &Builder<'ctx>, dst: SpirvWord, src1: SpirvWord, src2: SpirvWord, - fn_: impl FnOnce(&Builder<'ctx>, T, T, &str) -> Result, + fn_: impl IntMathOp<'ctx>, ) -> Result<(), TranslateError> { let src1 = self.value(src1)?; let src2 = self.value(src2)?; self.with_result(dst, |dst| { Ok(match (src1, src2) { (AnyValueEnum::IntValue(src1), AnyValueEnum::IntValue(src2)) => { - AnyValueEnum::from(fn_(builder, src1, src2, dst)) + AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) } (AnyValueEnum::PointerValue(src1), AnyValueEnum::PointerValue(src2)) => { - AnyValueEnum::from(fn_(builder, src1, src2, dst)) + AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) } (AnyValueEnum::VectorValue(src1), AnyValueEnum::VectorValue(src2)) => { - AnyValueEnum::from(fn_(builder, src1, src2, dst)) + AnyValueEnum::from(fn_.call(builder, src1, src2, dst)?) } _ => return todo!(), }) }) } +} - fn build_float_math>( - &mut self, - fn_: impl FnOnce(&Builder<'ctx>, T, T, &str) -> Result, - ) -> Result<(), TranslateError> { - } +trait IntMathOp<'ctx> { + fn call>( + self, + builder: &Builder<'ctx>, + src1: T, + src2: T, + dst: &str, + ) -> Result; } trait AnyValueEnumExt<'ctx> {