Fix vector extract/insert

This commit is contained in:
Andrzej Janik 2024-10-06 18:00:48 +02:00
parent 6490519885
commit aa6a8ed4c4
2 changed files with 68 additions and 11 deletions

View file

@ -22,7 +22,7 @@ use std::array::TryFromSliceError;
use std::convert::{TryFrom, TryInto};
use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::ptr;
use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
@ -454,7 +454,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
Statement::RepackVector(_) => todo!(),
Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
Statement::FunctionPointer(_) => todo!(),
Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
@ -610,8 +610,22 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
LLVMBuildBitCast(builder, src, type_, dst)
});
Ok(())
} else if to_layout.size() > from_layout.size() {
// TODO: not entirely correct
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildZExt(builder, src, type_, dst)
});
Ok(())
} else {
todo!()
// TODO: not entirely correct
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildTrunc(builder, src, type_, dst)
});
Ok(())
}
}
ConversionKind::SignExtend => todo!(),
@ -1020,6 +1034,38 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
if repack.is_extract {
let src = self.resolver.value(repack.packed)?;
for (index, dst) in repack.unpacked.iter().enumerate() {
let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
self.resolver.with_result(*dst, |dst| unsafe {
LLVMBuildExtractElement(self.builder, src, index, dst)
});
}
} else {
let vector_type = get_type(
self.context,
&ast::Type::Vector(repack.unpacked.len() as u8, repack.typ),
)?;
let mut temp_vec = unsafe { LLVMGetUndef(vector_type) };
for (index, src_id) in repack.unpacked.iter().enumerate() {
let dst = if index == repack.unpacked.len() - 1 {
Some(repack.packed)
} else {
None
};
let scalar_src = self.resolver.value(*src_id)?;
let index = unsafe { LLVMConstInt(i8_type, index as _, 0) };
temp_vec = self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst)
});
}
}
Ok(())
}
}
fn get_pointer_type<'ctx>(
@ -1201,4 +1247,15 @@ impl ResolveIdent {
self.register(word, t);
t
}
fn with_result_option(
&mut self,
word: Option<SpirvWord>,
fn_: impl FnOnce(*const i8) -> LLVMValueRef,
) -> LLVMValueRef {
match word {
Some(word) => self.with_result(word, fn_),
None => fn_(LLVM_UNNAMED.as_ptr()),
}
}
}

View file

@ -226,23 +226,23 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
fn vec_pack(
&mut self,
vecs: Vec<SpirvWord>,
vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
let (width, scalar_t, state_space) = match type_space {
Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()),
};
let temp_vec = self
let temporary_vector = self
.resolver
.register_unnamed(Some((scalar_t.into(), state_space)));
.register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: vecs,
packed: temporary_vector,
unpacked: vector_elements,
relaxed_type_check,
});
if is_dst {
@ -250,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> {
} else {
self.result.push(statement);
}
Ok(temp_vec)
Ok(temporary_vector)
}
}