diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index 670446c..7388203 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -510,7 +510,7 @@ fn emit_function_variable( llvm_type, get_llvm_address_space(&ctx.constants, variable.state_space)?, Some(variable.name), - )?; + ); match variable.initializer { None => {} Some(init) => { @@ -1066,10 +1066,14 @@ fn emit_value_copy( src: LLVMValueRef, dst: Id, ) -> Result<(), TranslateError> { - let llvm_type = get_llvm_type(ctx, type_)?; - let temp_value = emit_alloca(ctx, llvm_type, ctx.constants.private_space, None)?; - emit_alloca_store(ctx, type_, src, temp_value)?; - emit_alloca_load(ctx, type_, dst, temp_value) + let builder = ctx.builder.get(); + let type_ = get_llvm_type(ctx, type_)?; + let temp_value = emit_alloca(ctx, type_, ctx.constants.private_space, None); + unsafe { LLVMBuildStore(builder, src, temp_value) }; + ctx.names.register_result(dst, |dst| unsafe { + LLVMBuildLoad2(builder, type_, temp_value, dst) + }); + Ok(()) } // From "Performance Tips for Frontend Authors" (https://llvm.org/docs/Frontend/PerformanceTips.html): @@ -1082,21 +1086,16 @@ fn emit_alloca( type_: LLVMTypeRef, addr_space: u32, name: Option, -) -> Result { +) -> LLVMValueRef { let builder = ctx.builder.get(); let current_bb = unsafe { LLVMGetInsertBlock(builder) }; let variables_bb = unsafe { LLVMGetFirstBasicBlock(LLVMGetBasicBlockParent(current_bb)) }; unsafe { LLVMPositionBuilderAtEnd(builder, variables_bb) }; - let type_ = if type_ == get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::Pred))? { - get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))? - } else { - type_ - }; let result = ctx.names.register_result_option(name, |name| unsafe { LLVMZludaBuildAlloca(builder, type_, addr_space, name) }); unsafe { LLVMPositionBuilderAtEnd(builder, current_bb) }; - Ok(result) + result } fn emit_instruction( @@ -3476,33 +3475,11 @@ fn emit_load_var( ) }; } - emit_alloca_load(ctx, &load.typ, load.arg.dst, src) -} - -fn emit_alloca_load( - ctx: &mut EmitContext, - type_: &ast::Type, - dst_id: Id, - src: LLVMValueRef, -) -> Result<(), TranslateError> { - let builder = ctx.builder.get(); - let (dst, llvm_type) = if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { - ( - None, - get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?, - ) - } else { - (Some(dst_id), get_llvm_type(ctx, &type_)?) - }; - let ld_result = ctx.names.register_result_option(dst, |dst| unsafe { + let llvm_type = get_llvm_type(ctx, &load.typ)?; + ctx.names.register_result(load.arg.dst, |dst| unsafe { LLVMBuildLoad2(builder, llvm_type, src, dst) }); - Ok(if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { - let pred_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::Pred))?; - ctx.names.register_result(dst_id, |dst| unsafe { - LLVMBuildTrunc(builder, ld_result, pred_type, dst) - }); - }) + Ok(()) } fn emit_store_var( @@ -3528,22 +3505,6 @@ fn emit_store_var( }; }; let val = ctx.names.value(store.arg.src2)?; - emit_alloca_store(ctx, &store.type_, val, ptr)?; - Ok(()) -} - -fn emit_alloca_store( - ctx: &mut EmitContext, - type_: &ast::Type, - val: LLVMValueRef, - ptr: LLVMValueRef, -) -> Result<(), TranslateError> { - let val = if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { - let b8_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?; - unsafe { LLVMBuildZExt(ctx.builder.get(), val, b8_type, LLVM_UNNAMED) } - } else { - val - }; unsafe { LLVMBuildStore(ctx.builder.get(), val, ptr) }; Ok(()) }