diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index cbb1570..4d4142c 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -277,6 +277,11 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); + for statement in statements.iter() { + if let Statement::Label(label) = statement { + method_emitter.emit_label_initial(*label); + } + } for statement in statements { method_emitter.emit_statement(statement)?; } @@ -370,7 +375,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ) -> Result<(), TranslateError> { Ok(match statement { Statement::Variable(var) => self.emit_variable(var)?, - Statement::Label(label) => self.emit_label(label), + Statement::Label(label) => self.emit_label_delayed(label)?, Statement::Instruction(inst) => self.emit_instruction(inst)?, Statement::Conditional(_) => todo!(), Statement::LoadVar(var) => self.emit_load_variable(var)?, @@ -404,7 +409,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_label(&mut self, label: SpirvWord) { + fn emit_label_initial(&mut self, label: SpirvWord) { let block = unsafe { LLVMAppendBasicBlockInContext( self.context, @@ -412,11 +417,19 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { self.resolver.get_or_add_raw(label), ) }; + self.resolver + .register(label, unsafe { LLVMBasicBlockAsValue(block) }); + } + + fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> { + let block = self.resolver.value(label)?; + let block = unsafe { LLVMValueAsBasicBlock(block) }; let last_block = unsafe { LLVMGetInsertBlock(self.builder) }; if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() { unsafe { LLVMBuildBr(self.builder, block) }; } unsafe { LLVMPositionBuilderAtEnd(self.builder, block) }; + Ok(()) } fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> { @@ -441,7 +454,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Not { data, arguments } => todo!(), ast::Instruction::Or { data, arguments } => todo!(), ast::Instruction::And { data, arguments } => self.emit_and(arguments), - ast::Instruction::Bra { arguments } => todo!(), + ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(), @@ -755,6 +768,13 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_bra(&self, arguments: ptx_parser::BraArgs) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let src = unsafe { LLVMValueAsBasicBlock(src) }; + unsafe { LLVMBuildBr(self.builder, src) }; + Ok(()) + } } fn get_pointer_type<'ctx>(