From 0ccd5dec5e4a0b161c8d868f3db77b62d08a8a7b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 6 May 2024 00:17:27 +0200 Subject: [PATCH] Eliminate i1 --- ptx/src/emit.rs | 67 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index 7388203..670446c 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -510,7 +510,7 @@ fn emit_function_variable( llvm_type, get_llvm_address_space(&ctx.constants, variable.state_space)?, Some(variable.name), - ); + )?; match variable.initializer { None => {} Some(init) => { @@ -1066,14 +1066,10 @@ fn emit_value_copy( src: LLVMValueRef, dst: Id, ) -> Result<(), TranslateError> { - let builder = ctx.builder.get(); - let type_ = get_llvm_type(ctx, type_)?; - let temp_value = emit_alloca(ctx, type_, ctx.constants.private_space, None); - unsafe { LLVMBuildStore(builder, src, temp_value) }; - ctx.names.register_result(dst, |dst| unsafe { - LLVMBuildLoad2(builder, type_, temp_value, dst) - }); - Ok(()) + let llvm_type = get_llvm_type(ctx, type_)?; + let temp_value = emit_alloca(ctx, llvm_type, ctx.constants.private_space, None)?; + emit_alloca_store(ctx, type_, src, temp_value)?; + emit_alloca_load(ctx, type_, dst, temp_value) } // From "Performance Tips for Frontend Authors" (https://llvm.org/docs/Frontend/PerformanceTips.html): @@ -1086,16 +1082,21 @@ fn emit_alloca( type_: LLVMTypeRef, addr_space: u32, name: Option, -) -> LLVMValueRef { +) -> Result { let builder = ctx.builder.get(); let current_bb = unsafe { LLVMGetInsertBlock(builder) }; let variables_bb = unsafe { LLVMGetFirstBasicBlock(LLVMGetBasicBlockParent(current_bb)) }; unsafe { LLVMPositionBuilderAtEnd(builder, variables_bb) }; + let type_ = if type_ == get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::Pred))? { + get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))? + } else { + type_ + }; let result = ctx.names.register_result_option(name, |name| unsafe { LLVMZludaBuildAlloca(builder, type_, addr_space, name) }); unsafe { LLVMPositionBuilderAtEnd(builder, current_bb) }; - result + Ok(result) } fn emit_instruction( @@ -3475,11 +3476,33 @@ fn emit_load_var( ) }; } - let llvm_type = get_llvm_type(ctx, &load.typ)?; - ctx.names.register_result(load.arg.dst, |dst| unsafe { + emit_alloca_load(ctx, &load.typ, load.arg.dst, src) +} + +fn emit_alloca_load( + ctx: &mut EmitContext, + type_: &ast::Type, + dst_id: Id, + src: LLVMValueRef, +) -> Result<(), TranslateError> { + let builder = ctx.builder.get(); + let (dst, llvm_type) = if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { + ( + None, + get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?, + ) + } else { + (Some(dst_id), get_llvm_type(ctx, &type_)?) + }; + let ld_result = ctx.names.register_result_option(dst, |dst| unsafe { LLVMBuildLoad2(builder, llvm_type, src, dst) }); - Ok(()) + Ok(if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { + let pred_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::Pred))?; + ctx.names.register_result(dst_id, |dst| unsafe { + LLVMBuildTrunc(builder, ld_result, pred_type, dst) + }); + }) } fn emit_store_var( @@ -3505,6 +3528,22 @@ fn emit_store_var( }; }; let val = ctx.names.value(store.arg.src2)?; + emit_alloca_store(ctx, &store.type_, val, ptr)?; + Ok(()) +} + +fn emit_alloca_store( + ctx: &mut EmitContext, + type_: &ast::Type, + val: LLVMValueRef, + ptr: LLVMValueRef, +) -> Result<(), TranslateError> { + let val = if type_ == &ast::Type::Scalar(ast::ScalarType::Pred) { + let b8_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?; + unsafe { LLVMBuildZExt(ctx.builder.get(), val, b8_type, LLVM_UNNAMED) } + } else { + val + }; unsafe { LLVMBuildStore(ctx.builder.get(), val, ptr) }; Ok(()) }