Add prmt, membar, fix some of cvt

Andrzej Janik 2024-10-15 18:05:32 +02:00
parent 6f2944d9be
commit 3105674618
3 changed files with 204 additions and 87 deletions


@@ -183,4 +183,13 @@ void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF)
cast<Instruction>(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF));
}
void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering,
char *scope, const char *Name)
{
auto builder = llvm::unwrap(B);
LLVMContext &context = builder->getContext();
builder->CreateFence(mapFromLLVMOrdering(Ordering),
context.getOrInsertSyncScopeID(scope));
}
LLVM_C_EXTERN_C_END


@@ -71,4 +71,11 @@ extern "C" {
) -> LLVMValueRef;
pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags);
pub fn LLVMZludaBuildFence(
B: LLVMBuilderRef,
ordering: LLVMAtomicOrdering,
scope: *const i8,
Name: *const i8,
) -> LLVMValueRef;
}
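
A minimal usage sketch for the new binding (the helper name, the import paths, and the choice of seq_cst ordering are assumptions for illustration; only LLVMZludaBuildFence itself comes from this diff, and c"agent" is the scope string get_scope_membar below uses for membar.gl):

use llvm_sys::prelude::LLVMBuilderRef;
use llvm_sys::LLVMAtomicOrdering;

// Hypothetical helper: emit a device-wide fence through the new binding.
unsafe fn emit_device_fence(builder: LLVMBuilderRef) {
    LLVMZludaBuildFence(
        builder,
        LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
        c"agent".as_ptr(), // AMDGPU sync scope for device-wide ordering
        c"".as_ptr(),      // unnamed result
    );
}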


@@ -385,20 +385,22 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
| ptx_parser::ScalarType::U16 => unsafe {
LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::S32
| ptx_parser::ScalarType::B32
| ptx_parser::ScalarType::U32 => unsafe {
LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
},
ptx_parser::ScalarType::F16 => todo!(),
ptx_parser::ScalarType::BF16 => todo!(),
ptx_parser::ScalarType::S32 => todo!(),
ptx_parser::ScalarType::U64 => todo!(),
ptx_parser::ScalarType::S64 => todo!(),
ptx_parser::ScalarType::S16x2 => todo!(),
ptx_parser::ScalarType::B32 => todo!(),
ptx_parser::ScalarType::F32 => todo!(),
ptx_parser::ScalarType::B64 => todo!(),
ptx_parser::ScalarType::F64 => todo!(),
ptx_parser::ScalarType::B128 => todo!(),
ptx_parser::ScalarType::U16x2 => todo!(),
ptx_parser::ScalarType::F16x2 => todo!(),
ptx_parser::ScalarType::U32 => todo!(),
ptx_parser::ScalarType::BF16x2 => todo!(),
})
}
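
The consolidated S32/B32/U32 arm leans on one pattern: read the scalar's little-endian bytes at its own width, then widen to the u64 that LLVMConstInt takes. A standalone sketch of just that decoding step (no LLVM involved; the function name is illustrative):

// Decode a 4-byte little-endian initializer the way the S32/B32/U32 arm
// does, erroring if the byte slice has the wrong length.
fn const_bits_u32(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> {
    Ok(u32::from_le_bytes(bytes.try_into()?) as u64)
}

fn main() {
    assert_eq!(const_bits_u32(&[0x78, 0x56, 0x34, 0x12]).unwrap(), 0x1234_5678);
}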
@@ -552,8 +554,8 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
ast::Instruction::PrmtSlow { .. } => todo!(),
ast::Instruction::Prmt { .. } => todo!(),
ast::Instruction::Membar { .. } => todo!(),
ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
ast::Instruction::Membar { data } => self.emit_membar(data),
ast::Instruction::Trap {} => todo!(),
// replaced by a function call
ast::Instruction::Bfe { .. }
@@ -582,88 +584,14 @@ impl<'a> MethodEmitContext<'a> {
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
let builder = self.builder;
match conversion.kind {
ConversionKind::Default => {
match (&conversion.from_type, &conversion.to_type) {
(ast::Type::Scalar(from_type), ast::Type::Scalar(to_type)) => {
let from_layout = conversion.from_type.layout();
let to_layout = conversion.to_type.layout();
if from_layout.size() == to_layout.size() {
let dst_type = get_type(self.context, &conversion.to_type)?;
if from_type.kind() != ast::ScalarKind::Float
&& to_type.kind() != ast::ScalarKind::Float
{
// It is a no-op, but another instruction expects the result of this conversion
self.resolver
.register(conversion.dst, self.resolver.value(conversion.src)?);
} else {
let src = self.resolver.value(conversion.src)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildBitCast(builder, src, dst_type, dst)
});
}
Ok(())
} else {
let src = self.resolver.value(conversion.src)?;
// This block is safe because it's illegal to implicitly convert between floating point values
let same_width_bit_type = unsafe {
LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
};
let same_width_bit_value = unsafe {
LLVMBuildBitCast(
builder,
src,
same_width_bit_type,
LLVM_UNNAMED.as_ptr(),
)
};
let wide_bit_type = unsafe {
LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
};
if to_type.kind() == ast::ScalarKind::Unsigned
|| to_type.kind() == ast::ScalarKind::Bit
{
let llvm_fn = if to_type.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
};
self.resolver.with_result(conversion.dst, |dst| unsafe {
llvm_fn(builder, same_width_bit_value, wide_bit_type, dst)
});
Ok(())
} else {
let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
&& to_type.kind() == ast::ScalarKind::Signed
{
if to_type.size_of() >= from_type.size_of() {
LLVMBuildSExtOrBitCast
} else {
LLVMBuildTrunc
}
} else {
if to_type.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
}
};
todo!()
}
}
}
(ast::Type::Vector(..), ast::Type::Scalar(..))
| (ast::Type::Scalar(..), ast::Type::Array(..))
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
let src = self.resolver.value(conversion.src)?;
let dst_type = get_type(self.context, &conversion.to_type)?;
self.resolver.with_result(conversion.dst, |dst| unsafe {
LLVMBuildBitCast(builder, src, dst_type, dst)
});
Ok(())
}
_ => todo!(),
}
}
ConversionKind::Default => self.emit_conversion_default(
self.resolver.value(conversion.src)?,
conversion.dst,
&conversion.from_type,
conversion.from_space,
&conversion.to_type,
conversion.to_space,
),
ConversionKind::SignExtend => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_type(self.context, &conversion.to_type)?;
@@ -699,6 +627,115 @@ impl<'a> MethodEmitContext<'a> {
}
}
fn emit_conversion_default(
&mut self,
src: LLVMValueRef,
dst: SpirvWord,
from_type: &ast::Type,
from_space: ast::StateSpace,
to_type: &ast::Type,
to_space: ast::StateSpace,
) -> Result<(), TranslateError> {
match (from_type, to_type) {
(ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
let from_layout = from_type.layout();
let to_layout = to_type.layout();
if from_layout.size() == to_layout.size() {
let dst_type = get_type(self.context, &to_type)?;
if from_type.kind() != ast::ScalarKind::Float
&& to_type_scalar.kind() != ast::ScalarKind::Float
{
// It is a no-op, but another instruction expects the result of this conversion
self.resolver.register(dst, src);
} else {
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildBitCast(self.builder, src, dst_type, dst)
});
}
Ok(())
} else {
// This block is safe because it's illegal to implicitly convert between floating point values
let same_width_bit_type = unsafe {
LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
};
let same_width_bit_value = unsafe {
LLVMBuildBitCast(
self.builder,
src,
same_width_bit_type,
LLVM_UNNAMED.as_ptr(),
)
};
let wide_bit_type = match to_type_scalar.layout().size() {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => return Err(error_unreachable()),
};
let wide_bit_type_llvm = unsafe {
LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
};
if to_type_scalar.kind() == ast::ScalarKind::Unsigned
|| to_type_scalar.kind() == ast::ScalarKind::Bit
{
let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
};
self.resolver.with_result(dst, |dst| unsafe {
llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
});
Ok(())
} else {
let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
&& to_type_scalar.kind() == ast::ScalarKind::Signed
{
if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildSExtOrBitCast
} else {
LLVMBuildTrunc
}
} else {
if to_type_scalar.size_of() >= from_type.size_of() {
LLVMBuildZExtOrBitCast
} else {
LLVMBuildTrunc
}
};
let wide_bit_value = unsafe {
conversion_fn(
self.builder,
same_width_bit_value,
wide_bit_type_llvm,
LLVM_UNNAMED.as_ptr(),
)
};
self.emit_conversion_default(
wide_bit_value,
dst,
&wide_bit_type.into(),
from_space,
to_type,
to_space,
)
}
}
}
(ast::Type::Vector(..), ast::Type::Scalar(..))
| (ast::Type::Scalar(..), ast::Type::Array(..))
| (ast::Type::Array(..), ast::Type::Scalar(..)) => {
let dst_type = get_type(self.context, to_type)?;
self.resolver.with_result(dst, |dst| unsafe {
LLVMBuildBitCast(self.builder, src, dst_type, dst)
});
Ok(())
}
_ => todo!(),
}
}
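
A standalone model of the cast selection above (a simplification under stated assumptions: it skips the bitcast-through-integer step and the recursion that finishes float destinations, and all names are illustrative):

#[derive(PartialEq, Clone, Copy)]
enum Kind { Bit, Unsigned, Signed, Float }

#[derive(Debug, PartialEq)]
enum Cast { Trunc, ZExtOrBitCast, SExtOrBitCast }

// Mirror of the branch structure in emit_conversion_default: narrowing
// always truncates; widening sign-extends only when both sides are signed
// and zero-extends otherwise (including bit- and float-sourced values).
fn pick_cast(from: (Kind, usize), to: (Kind, usize)) -> Cast {
    if to.1 < from.1 {
        Cast::Trunc
    } else if from.0 == Kind::Signed && to.0 == Kind::Signed {
        Cast::SExtOrBitCast
    } else {
        Cast::ZExtOrBitCast
    }
}

fn main() {
    // s16 -> s32 widens with sign extension; u16 -> s32 does not.
    assert_eq!(pick_cast((Kind::Signed, 2), (Kind::Signed, 4)), Cast::SExtOrBitCast);
    assert_eq!(pick_cast((Kind::Unsigned, 2), (Kind::Signed, 4)), Cast::ZExtOrBitCast);
    // f32 -> b16 narrows through integer truncation.
    assert_eq!(pick_cast((Kind::Float, 4), (Kind::Bit, 2)), Cast::Trunc);
}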
fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
let type_ = get_scalar_type(self.context, constant.typ);
let value = match constant.value {
@@ -1879,6 +1916,60 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
unsafe {
LLVMZludaBuildFence(
self.builder,
LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
get_scope_membar(data)?,
LLVM_UNNAMED.as_ptr(),
)
};
Ok(())
}
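
For intuition: every membar level lowers to a sequentially consistent fence, and only the sync scope varies with the level (see get_scope_membar below). With the default (empty) scope, membar.sys is roughly analogous to an ordinary full fence; a host-side sketch of that analogy, not what the emitter generates:

use std::sync::atomic::{fence, Ordering};

// Rough CPU analogue of membar.sys: a full sequentially consistent fence
// ordering all surrounding memory operations.
fn membar_sys_analogue() {
    fence(Ordering::SeqCst);
}

fn main() {
    membar_sys_analogue();
}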
fn emit_prmt(
&mut self,
control: u16,
arguments: ptx_parser::PrmtArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let components = [
(control >> 0) & 0b1111,
(control >> 4) & 0b1111,
(control >> 8) & 0b1111,
(control >> 12) & 0b1111,
];
if components.iter().any(|&c| c > 7) {
return Err(TranslateError::Todo);
}
let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
let mut components = [
unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
];
let components_indices =
unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
let src1 = self.resolver.value(arguments.src1)?;
let src1_vector =
unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
let src2 = self.resolver.value(arguments.src2)?;
let src2_vector =
unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildShuffleVector(
self.builder,
src1_vector,
src2_vector,
components_indices,
dst,
)
});
Ok(())
}
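
To make the shufflevector lowering concrete, a hedged CPU model of the prmt default mode it implements (the function name is illustrative): the two u32 sources form an 8-byte pool and each nibble of the control word selects one result byte; selectors 8..15, where PTX instead replicates the selected byte's sign bit, are exactly the cases emit_prmt rejects above.

// CPU model of prmt's default mode: result byte i is chosen by control
// nibble i from the pool {src1 bytes 0..3, src2 bytes 4..7}.
fn prmt(src1: u32, src2: u32, control: u16) -> Option<u32> {
    let mut pool = [0u8; 8];
    pool[..4].copy_from_slice(&src1.to_le_bytes());
    pool[4..].copy_from_slice(&src2.to_le_bytes());
    let mut out = [0u8; 4];
    for i in 0..4 {
        let sel = ((control >> (4 * i)) & 0b1111) as usize;
        if sel > 7 {
            return None; // sign-replication mode, unsupported by emit_prmt
        }
        out[i] = pool[sel];
    }
    Some(u32::from_le_bytes(out))
}

fn main() {
    // Selectors 3,2,1,0 reverse the bytes of src1.
    assert_eq!(prmt(0x4433_2211, 0, 0x0123), Some(0x1122_3344));
}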
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@@ -1964,6 +2055,16 @@ fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
.as_ptr())
}
fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
Ok(match scope {
ast::MemScope::Cta => c"workgroup",
ast::MemScope::Gpu => c"agent",
ast::MemScope::Sys => c"",
ast::MemScope::Cluster => todo!(),
}
.as_ptr())
}
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
match semantics {
ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,