mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add prmt, membar, fix some of cvt
This commit is contained in:
parent
6f2944d9be
commit
3105674618
3 changed files with 204 additions and 87 deletions
|
@ -183,4 +183,13 @@ void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF)
|
|||
cast<Instruction>(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
|
||||
}
|
||||
|
||||
void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering,
|
||||
char *scope, const char *Name)
|
||||
{
|
||||
auto builder = llvm::unwrap(B);
|
||||
LLVMContext &context = builder->getContext();
|
||||
builder->CreateFence(mapFromLLVMOrdering(Ordering),
|
||||
context.getOrInsertSyncScopeID(scope));
|
||||
}
|
||||
|
||||
LLVM_C_EXTERN_C_END
|
|
@ -71,4 +71,11 @@ extern "C" {
|
|||
) -> LLVMValueRef;
|
||||
|
||||
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
|
||||
|
||||
pub fn LLVMZludaBuildFence(
|
||||
B: LLVMBuilderRef,
|
||||
ordering: LLVMAtomicOrdering,
|
||||
scope: *const i8,
|
||||
Name: *const i8,
|
||||
) -> LLVMValueRef;
|
||||
}
|
||||
|
|
|
@ -385,20 +385,22 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
|
|||
| ptx_parser::ScalarType::U16 => unsafe {
|
||||
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::S32
|
||||
| ptx_parser::ScalarType::B32
|
||||
| ptx_parser::ScalarType::U32 => unsafe {
|
||||
LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
|
||||
},
|
||||
ptx_parser::ScalarType::F16 => todo!(),
|
||||
ptx_parser::ScalarType::BF16 => todo!(),
|
||||
ptx_parser::ScalarType::S32 => todo!(),
|
||||
ptx_parser::ScalarType::U64 => todo!(),
|
||||
ptx_parser::ScalarType::S64 => todo!(),
|
||||
ptx_parser::ScalarType::S16x2 => todo!(),
|
||||
ptx_parser::ScalarType::B32 => todo!(),
|
||||
ptx_parser::ScalarType::F32 => todo!(),
|
||||
ptx_parser::ScalarType::B64 => todo!(),
|
||||
ptx_parser::ScalarType::F64 => todo!(),
|
||||
ptx_parser::ScalarType::B128 => todo!(),
|
||||
ptx_parser::ScalarType::U16x2 => todo!(),
|
||||
ptx_parser::ScalarType::F16x2 => todo!(),
|
||||
ptx_parser::ScalarType::U32 => todo!(),
|
||||
ptx_parser::ScalarType::BF16x2 => todo!(),
|
||||
})
|
||||
}
|
||||
|
@ -552,8 +554,8 @@ impl<'a> MethodEmitContext<'a> {
|
|||
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
|
||||
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
|
||||
ast::Instruction::PrmtSlow { .. } => todo!(),
|
||||
ast::Instruction::Prmt { .. } => todo!(),
|
||||
ast::Instruction::Membar { .. } => todo!(),
|
||||
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
|
||||
ast::Instruction::Membar { data } => self.emit_membar(data),
|
||||
ast::Instruction::Trap {} => todo!(),
|
||||
// replaced by a function call
|
||||
ast::Instruction::Bfe { .. }
|
||||
|
@ -582,88 +584,14 @@ impl<'a> MethodEmitContext<'a> {
|
|||
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
|
||||
let builder = self.builder;
|
||||
match conversion.kind {
|
||||
ConversionKind::Default => {
|
||||
match (&conversion.from_type, &conversion.to_type) {
|
||||
(ast::Type::Scalar(from_type), ast::Type::Scalar(to_type)) => {
|
||||
let from_layout = conversion.from_type.layout();
|
||||
let to_layout = conversion.to_type.layout();
|
||||
if from_layout.size() == to_layout.size() {
|
||||
let dst_type = get_type(self.context, &conversion.to_type)?;
|
||||
if from_type.kind() != ast::ScalarKind::Float
|
||||
&& to_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
// It is noop, but another instruction expects result of this conversion
|
||||
self.resolver
|
||||
.register(conversion.dst, self.resolver.value(conversion.src)?);
|
||||
} else {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(builder, src, dst_type, dst)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
// This block is safe because it's illegal to implictly convert between floating point values
|
||||
let same_width_bit_type = unsafe {
|
||||
LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
|
||||
};
|
||||
let same_width_bit_value = unsafe {
|
||||
LLVMBuildBitCast(
|
||||
builder,
|
||||
src,
|
||||
same_width_bit_type,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let wide_bit_type = unsafe {
|
||||
LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
|
||||
};
|
||||
if to_type.kind() == ast::ScalarKind::Unsigned
|
||||
|| to_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
let llvm_fn = if to_type.size_of() >= from_type.size_of() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
};
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
llvm_fn(builder, same_width_bit_value, wide_bit_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
|
||||
&& to_type.kind() == ast::ScalarKind::Signed
|
||||
{
|
||||
if to_type.size_of() >= from_type.size_of() {
|
||||
LLVMBuildSExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
}
|
||||
} else {
|
||||
if to_type.size_of() >= from_type.size_of() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
}
|
||||
};
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(..), ast::Type::Scalar(..))
|
||||
| (ast::Type::Scalar(..), ast::Type::Array(..))
|
||||
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let dst_type = get_type(self.context, &conversion.to_type)?;
|
||||
self.resolver.with_result(conversion.dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(builder, src, dst_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
ConversionKind::Default => self.emit_conversion_default(
|
||||
self.resolver.value(conversion.src)?,
|
||||
conversion.dst,
|
||||
&conversion.from_type,
|
||||
conversion.from_space,
|
||||
&conversion.to_type,
|
||||
conversion.to_space,
|
||||
),
|
||||
ConversionKind::SignExtend => {
|
||||
let src = self.resolver.value(conversion.src)?;
|
||||
let type_ = get_type(self.context, &conversion.to_type)?;
|
||||
|
@ -699,6 +627,115 @@ impl<'a> MethodEmitContext<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn emit_conversion_default(
|
||||
&mut self,
|
||||
src: LLVMValueRef,
|
||||
dst: SpirvWord,
|
||||
from_type: &ast::Type,
|
||||
from_space: ast::StateSpace,
|
||||
to_type: &ast::Type,
|
||||
to_space: ast::StateSpace,
|
||||
) -> Result<(), TranslateError> {
|
||||
match (from_type, to_type) {
|
||||
(ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
|
||||
let from_layout = from_type.layout();
|
||||
let to_layout = to_type.layout();
|
||||
if from_layout.size() == to_layout.size() {
|
||||
let dst_type = get_type(self.context, &to_type)?;
|
||||
if from_type.kind() != ast::ScalarKind::Float
|
||||
&& to_type_scalar.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
// It is noop, but another instruction expects result of this conversion
|
||||
self.resolver.register(dst, src);
|
||||
} else {
|
||||
self.resolver.with_result(dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(self.builder, src, dst_type, dst)
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
// This block is safe because it's illegal to implictly convert between floating point values
|
||||
let same_width_bit_type = unsafe {
|
||||
LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
|
||||
};
|
||||
let same_width_bit_value = unsafe {
|
||||
LLVMBuildBitCast(
|
||||
self.builder,
|
||||
src,
|
||||
same_width_bit_type,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
let wide_bit_type = match to_type_scalar.layout().size() {
|
||||
1 => ast::ScalarType::B8,
|
||||
2 => ast::ScalarType::B16,
|
||||
4 => ast::ScalarType::B32,
|
||||
8 => ast::ScalarType::B64,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let wide_bit_type_llvm = unsafe {
|
||||
LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
|
||||
};
|
||||
if to_type_scalar.kind() == ast::ScalarKind::Unsigned
|
||||
|| to_type_scalar.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
};
|
||||
self.resolver.with_result(dst, |dst| unsafe {
|
||||
llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
|
||||
&& to_type_scalar.kind() == ast::ScalarKind::Signed
|
||||
{
|
||||
if to_type_scalar.size_of() >= from_type.size_of() {
|
||||
LLVMBuildSExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
}
|
||||
} else {
|
||||
if to_type_scalar.size_of() >= from_type.size_of() {
|
||||
LLVMBuildZExtOrBitCast
|
||||
} else {
|
||||
LLVMBuildTrunc
|
||||
}
|
||||
};
|
||||
let wide_bit_value = unsafe {
|
||||
conversion_fn(
|
||||
self.builder,
|
||||
same_width_bit_value,
|
||||
wide_bit_type_llvm,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
self.emit_conversion_default(
|
||||
wide_bit_value,
|
||||
dst,
|
||||
&wide_bit_type.into(),
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(..), ast::Type::Scalar(..))
|
||||
| (ast::Type::Scalar(..), ast::Type::Array(..))
|
||||
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
|
||||
let dst_type = get_type(self.context, to_type)?;
|
||||
self.resolver.with_result(dst, |dst| unsafe {
|
||||
LLVMBuildBitCast(self.builder, src, dst_type, dst)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
|
||||
let type_ = get_scalar_type(self.context, constant.typ);
|
||||
let value = match constant.value {
|
||||
|
@ -1879,6 +1916,60 @@ impl<'a> MethodEmitContext<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
|
||||
unsafe {
|
||||
LLVMZludaBuildFence(
|
||||
self.builder,
|
||||
LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
|
||||
get_scope_membar(data)?,
|
||||
LLVM_UNNAMED.as_ptr(),
|
||||
)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_prmt(
|
||||
&mut self,
|
||||
control: u16,
|
||||
arguments: ptx_parser::PrmtArgs<SpirvWord>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let components = [
|
||||
(control >> 0) & 0b1111,
|
||||
(control >> 4) & 0b1111,
|
||||
(control >> 8) & 0b1111,
|
||||
(control >> 12) & 0b1111,
|
||||
];
|
||||
if components.iter().any(|&c| c > 7) {
|
||||
return Err(TranslateError::Todo);
|
||||
}
|
||||
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
|
||||
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
|
||||
let mut components = [
|
||||
unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
|
||||
unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
|
||||
unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
|
||||
unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
|
||||
];
|
||||
let components_indices =
|
||||
unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
|
||||
let src1 = self.resolver.value(arguments.src1)?;
|
||||
let src1_vector =
|
||||
unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
|
||||
let src2 = self.resolver.value(arguments.src2)?;
|
||||
let src2_vector =
|
||||
unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
|
||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||
LLVMBuildShuffleVector(
|
||||
self.builder,
|
||||
src1_vector,
|
||||
src2_vector,
|
||||
components_indices,
|
||||
dst,
|
||||
)
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
|
||||
// Should be available in LLVM 19
|
||||
|
@ -1964,6 +2055,16 @@ fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
|
|||
.as_ptr())
|
||||
}
|
||||
|
||||
fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
|
||||
Ok(match scope {
|
||||
ast::MemScope::Cta => c"workgroup",
|
||||
ast::MemScope::Gpu => c"agent",
|
||||
ast::MemScope::Sys => c"",
|
||||
ast::MemScope::Cluster => todo!(),
|
||||
}
|
||||
.as_ptr())
|
||||
}
|
||||
|
||||
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
|
||||
match semantics {
|
||||
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
|
||||
|
|
Loading…
Add table
Reference in a new issue