Fix vector extract/insert

This commit is contained in:
Andrzej Janik 2024-10-06 18:00:48 +02:00
commit aa6a8ed4c4
2 changed files with 68 additions and 11 deletions

View file

@ -22,7 +22,7 @@ use std::array::TryFromSliceError;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::ffi::{CStr, NulError}; use std::ffi::{CStr, NulError};
use std::ops::Deref; use std::ops::Deref;
use std::ptr; use std::{i8, ptr};
use super::*; use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
@ -454,7 +454,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::Constant(constant) => self.emit_constant(constant)?, Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, values) => self.emit_ret_value(values)?, Statement::RetValue(_, values) => self.emit_ret_value(values)?,
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
Statement::RepackVector(_) => todo!(), Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
Statement::FunctionPointer(_) => todo!(), Statement::FunctionPointer(_) => todo!(),
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
@ -610,8 +610,22 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
LLVMBuildBitCast(builder, src, type_, dst) LLVMBuildBitCast(builder, src, type_, dst)
}); });
Ok(()) Ok(())
} else if to_layout.size() > from_layout.size() {
// TODO: not entirely correct
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildZExt(builder, src, type_, dst)
});
Ok(())
} else { } else {
todo!() // TODO: not entirely correct
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildTrunc(builder, src, type_, dst)
});
Ok(())
} }
} }
ConversionKind::SignExtend => todo!(), ConversionKind::SignExtend => todo!(),
@ -1020,6 +1034,38 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
}); });
Ok(()) Ok(())
} }
fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
if repack.is_extract {
let src = self.resolver.value(repack.packed)?;
for (index, dst) in repack.unpacked.iter().enumerate() {
let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
self.resolver.with_result(*dst, |dst| unsafe {
LLVMBuildExtractElement(self.builder, src, index, dst)
});
}
} else {
let vector_type = get_type(
self.context,
&ast::Type::Vector(repack.unpacked.len() as u8, repack.typ),
)?;
let mut temp_vec = unsafe { LLVMGetUndef(vector_type) };
for (index, src_id) in repack.unpacked.iter().enumerate() {
let dst = if index == repack.unpacked.len() - 1 {
Some(repack.packed)
} else {
None
};
let scalar_src = self.resolver.value(*src_id)?;
let index = unsafe { LLVMConstInt(i8_type, index as _, 0) };
temp_vec = self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst)
});
}
}
Ok(())
}
} }
fn get_pointer_type<'ctx>( fn get_pointer_type<'ctx>(
@ -1201,4 +1247,15 @@ impl ResolveIdent {
self.register(word, t); self.register(word, t);
t t
} }
fn with_result_option(
&mut self,
word: Option<SpirvWord>,
fn_: impl FnOnce(*const i8) -> LLVMValueRef,
) -> LLVMValueRef {
match word {
Some(word) => self.with_result(word, fn_),
None => fn_(LLVM_UNNAMED.as_ptr()),
}
}
} }

View file

@ -226,23 +226,23 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
fn vec_pack( fn vec_pack(
&mut self, &mut self,
vecs: Vec<SpirvWord>, vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>, type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool, is_dst: bool,
relaxed_type_check: bool, relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> { ) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) = match type_space { let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space), Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()), _ => return Err(error_mismatched_type()),
}; };
let temp_vec = self let temporary_vector = self
.resolver .resolver
.register_unnamed(Some((scalar_t.into(), state_space))); .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails { let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst, is_extract: is_dst,
typ: scalar_t, typ: scalar_t,
packed: temp_vec, packed: temporary_vector,
unpacked: vecs, unpacked: vector_elements,
relaxed_type_check, relaxed_type_check,
}); });
if is_dst { if is_dst {
@ -250,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
} else { } else {
self.result.push(statement); self.result.push(statement);
} }
Ok(temp_vec) Ok(temporary_vector)
} }
} }