diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index d6af00d..5613cb0 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -22,7 +22,7 @@ use std::array::TryFromSliceError; use std::convert::{TryFrom, TryInto}; use std::ffi::{CStr, NulError}; use std::ops::Deref; -use std::ptr; +use std::{i8, ptr}; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; @@ -454,7 +454,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::Constant(constant) => self.emit_constant(constant)?, Statement::RetValue(_, values) => self.emit_ret_value(values)?, Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, - Statement::RepackVector(_) => todo!(), + Statement::RepackVector(repack) => self.emit_vector_repack(repack)?, Statement::FunctionPointer(_) => todo!(), Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, @@ -610,8 +610,22 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { LLVMBuildBitCast(builder, src, type_, dst) }); Ok(()) + } else if to_layout.size() > from_layout.size() { + // TODO: not entirely correct + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildZExt(builder, src, type_, dst) + }); + Ok(()) } else { - todo!() + // TODO: not entirely correct + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildTrunc(builder, src, type_, dst) + }); + Ok(()) } } ConversionKind::SignExtend => todo!(), @@ -1020,6 +1034,38 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } + + fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> { + let i8_type = get_scalar_type(self.context, ast::ScalarType::B8); + if repack.is_extract { + let src = self.resolver.value(repack.packed)?; + for (index, dst) in repack.unpacked.iter().enumerate() { + let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + self.resolver.with_result(*dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + } + } else { + let vector_type = get_type( + self.context, + &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ), + )?; + let mut temp_vec = unsafe { LLVMGetUndef(vector_type) }; + for (index, src_id) in repack.unpacked.iter().enumerate() { + let dst = if index == repack.unpacked.len() - 1 { + Some(repack.packed) + } else { + None + }; + let scalar_src = self.resolver.value(*src_id)?; + let index = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + temp_vec = self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst) + }); + } + } + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -1201,4 +1247,15 @@ impl ResolveIdent { self.register(word, t); t } + + fn with_result_option( + &mut self, + word: Option, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { + match word { + Some(word) => self.with_result(word, fn_), + None => fn_(LLVM_UNNAMED.as_ptr()), + } + } } diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index e9768e0..1125d39 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -226,23 +226,23 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_pack( &mut self, - vecs: Vec, + vector_elements: Vec, type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { - let (scalar_t, state_space) = match type_space { - Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space), + let (width, scalar_t, state_space) = match type_space { + Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), _ => return Err(error_mismatched_type()), }; - let temp_vec = self + let temporary_vector = self .resolver - .register_unnamed(Some((scalar_t.into(), state_space))); + .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, - packed: temp_vec, - unpacked: vecs, + packed: temporary_vector, + unpacked: vector_elements, relaxed_type_check, }); if is_dst { @@ -250,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { } else { self.result.push(statement); } - Ok(temp_vec) + Ok(temporary_vector) } }