mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-02 14:19:27 +00:00
Merge commit '3da39364e0
' into compile_more
This commit is contained in:
commit
92b7316a87
18 changed files with 540 additions and 8 deletions
1
.git-blame-ignore-revs
Normal file
1
.git-blame-ignore-revs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
21ef5f60a3a5efa17855a30f6b5c7d1968cd46ba
|
Binary file not shown.
|
@ -562,4 +562,19 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
||||||
{
|
{
|
||||||
return ballot(value, true);
|
return ballot(value, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type) \
|
||||||
|
cpp_type __ockl_wfred_##reducer##_##amd_type(cpp_type) __device__; \
|
||||||
|
cpp_type FUNC(redux_sync_##reducer##_##ptx_type)(cpp_type src, uint32_t membermask __attribute__((unused))) \
|
||||||
|
{ \
|
||||||
|
return __ockl_wfred_##reducer##_##amd_type(src); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define REDUX_SYNC_IMPL(reducer) \
|
||||||
|
REDUX_SYNC_TYPE_IMPL(reducer, u32, u32, uint32_t) \
|
||||||
|
REDUX_SYNC_TYPE_IMPL(reducer, s32, i32, int32_t)
|
||||||
|
|
||||||
|
REDUX_SYNC_IMPL(add);
|
||||||
|
REDUX_SYNC_IMPL(min);
|
||||||
|
REDUX_SYNC_IMPL(max);
|
||||||
}
|
}
|
||||||
|
|
|
@ -195,7 +195,8 @@ fn run_instruction<'input>(
|
||||||
| ast::Instruction::Tanh { .. }
|
| ast::Instruction::Tanh { .. }
|
||||||
| ast::Instruction::Trap {}
|
| ast::Instruction::Trap {}
|
||||||
| ast::Instruction::Xor { .. }
|
| ast::Instruction::Xor { .. }
|
||||||
| ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)),
|
| ast::Instruction::Vote { .. }
|
||||||
|
| ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)),
|
||||||
ast::Instruction::Add {
|
ast::Instruction::Add {
|
||||||
data:
|
data:
|
||||||
ast::ArithDetails::Float(ast::ArithFloat {
|
ast::ArithDetails::Float(ast::ArithFloat {
|
||||||
|
|
|
@ -1853,7 +1853,8 @@ fn get_modes<T: ast::Operand>(inst: &ast::Instruction<T>) -> InstructionModes {
|
||||||
| ast::Instruction::Mul24 { .. }
|
| ast::Instruction::Mul24 { .. }
|
||||||
| ast::Instruction::Nanosleep { .. }
|
| ast::Instruction::Nanosleep { .. }
|
||||||
| ast::Instruction::AtomCas { .. }
|
| ast::Instruction::AtomCas { .. }
|
||||||
| ast::Instruction::Vote { .. } => InstructionModes::none(),
|
| ast::Instruction::Vote { .. }
|
||||||
|
| ast::Instruction::ReduxSync { .. } => InstructionModes::none(),
|
||||||
ast::Instruction::Add {
|
ast::Instruction::Add {
|
||||||
data: ast::ArithDetails::Integer(_),
|
data: ast::ArithDetails::Integer(_),
|
||||||
..
|
..
|
||||||
|
|
|
@ -526,7 +526,8 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
| ast::Instruction::Activemask { .. }
|
| ast::Instruction::Activemask { .. }
|
||||||
| ast::Instruction::ShflSync { .. }
|
| ast::Instruction::ShflSync { .. }
|
||||||
| ast::Instruction::Vote { .. }
|
| ast::Instruction::Vote { .. }
|
||||||
| ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()),
|
| ast::Instruction::Nanosleep { .. }
|
||||||
|
| ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1645,9 +1646,39 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let src = self.resolver.value(arguments.src)?;
|
let src = self.resolver.value(arguments.src)?;
|
||||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
if let Some(src2) = arguments.src2 {
|
||||||
llvm_fn(self.builder, src, dst_type, dst)
|
let packed_type = get_scalar_type(
|
||||||
});
|
self.context,
|
||||||
|
data.to
|
||||||
|
.packed_type()
|
||||||
|
.ok_or_else(|| error_mismatched_type())?,
|
||||||
|
);
|
||||||
|
let src2 = self.resolver.value(src2)?;
|
||||||
|
self.resolver.with_result(arguments.dst, |dst| {
|
||||||
|
let vec = unsafe {
|
||||||
|
LLVMBuildInsertElement(
|
||||||
|
self.builder,
|
||||||
|
LLVMGetPoison(dst_type),
|
||||||
|
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
|
||||||
|
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
|
||||||
|
LLVM_UNNAMED.as_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
unsafe {
|
||||||
|
LLVMBuildInsertElement(
|
||||||
|
self.builder,
|
||||||
|
vec,
|
||||||
|
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
|
||||||
|
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||||
|
llvm_fn(self.builder, src, dst_type, dst)
|
||||||
|
})
|
||||||
|
};
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -377,6 +377,7 @@ fn run_instruction<'input>(
|
||||||
let name = match data.pred_reduction {
|
let name = match data.pred_reduction {
|
||||||
ptx_parser::Reduction::And => "bar_red_and_pred",
|
ptx_parser::Reduction::And => "bar_red_and_pred",
|
||||||
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
ptx_parser::Reduction::Or => "bar_red_or_pred",
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
};
|
};
|
||||||
to_call(
|
to_call(
|
||||||
resolver,
|
resolver,
|
||||||
|
@ -400,6 +401,25 @@ fn run_instruction<'input>(
|
||||||
ptx_parser::Instruction::Vote { data, arguments },
|
ptx_parser::Instruction::Vote { data, arguments },
|
||||||
)?
|
)?
|
||||||
}
|
}
|
||||||
|
ptx_parser::Instruction::ReduxSync { data, arguments } => {
|
||||||
|
let op = match data.reduction {
|
||||||
|
ptx_parser::Reduction::Add => "add",
|
||||||
|
ptx_parser::Reduction::Min => "min",
|
||||||
|
ptx_parser::Reduction::Max => "max",
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let name = format!(
|
||||||
|
"redux_sync_{}_{}",
|
||||||
|
op,
|
||||||
|
data.type_.to_string().replace(".", "")
|
||||||
|
);
|
||||||
|
to_call(
|
||||||
|
resolver,
|
||||||
|
fn_declarations,
|
||||||
|
name.into(),
|
||||||
|
ptx_parser::Instruction::ReduxSync { data, arguments },
|
||||||
|
)?
|
||||||
|
}
|
||||||
ptx_parser::Instruction::ShflSync {
|
ptx_parser::Instruction::ShflSync {
|
||||||
data,
|
data,
|
||||||
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },
|
arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. },
|
||||||
|
|
41
ptx/src/test/ll/cvt_rn_bf16x2_f32.ll
Normal file
41
ptx/src/test/ll/cvt_rn_bf16x2_f32.ll
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
define amdgpu_kernel void @cvt_rn_bf16x2_f32(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 {
|
||||||
|
%"39" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"40" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"41" = alloca float, align 4, addrspace(5)
|
||||||
|
%"42" = alloca float, align 4, addrspace(5)
|
||||||
|
%"43" = alloca i32, align 4, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"36"
|
||||||
|
|
||||||
|
"36": ; preds = %1
|
||||||
|
%"44" = load i64, ptr addrspace(4) %"37", align 8
|
||||||
|
store i64 %"44", ptr addrspace(5) %"39", align 8
|
||||||
|
%"45" = load i64, ptr addrspace(4) %"38", align 8
|
||||||
|
store i64 %"45", ptr addrspace(5) %"40", align 8
|
||||||
|
%"47" = load i64, ptr addrspace(5) %"39", align 8
|
||||||
|
%"55" = inttoptr i64 %"47" to ptr
|
||||||
|
%"46" = load float, ptr %"55", align 4
|
||||||
|
store float %"46", ptr addrspace(5) %"41", align 4
|
||||||
|
%"48" = load i64, ptr addrspace(5) %"39", align 8
|
||||||
|
%"56" = inttoptr i64 %"48" to ptr
|
||||||
|
%"35" = getelementptr inbounds i8, ptr %"56", i64 4
|
||||||
|
%"49" = load float, ptr %"35", align 4
|
||||||
|
store float %"49", ptr addrspace(5) %"42", align 4
|
||||||
|
%"51" = load float, ptr addrspace(5) %"41", align 4
|
||||||
|
%"52" = load float, ptr addrspace(5) %"42", align 4
|
||||||
|
%2 = fptrunc float %"51" to bfloat
|
||||||
|
%3 = insertelement <2 x bfloat> poison, bfloat %2, i32 1
|
||||||
|
%4 = fptrunc float %"52" to bfloat
|
||||||
|
%"57" = insertelement <2 x bfloat> %3, bfloat %4, i32 0
|
||||||
|
%"50" = bitcast <2 x bfloat> %"57" to i32
|
||||||
|
store i32 %"50", ptr addrspace(5) %"43", align 4
|
||||||
|
%"53" = load i64, ptr addrspace(5) %"40", align 8
|
||||||
|
%"54" = load i32, ptr addrspace(5) %"43", align 4
|
||||||
|
%"58" = inttoptr i64 %"53" to ptr
|
||||||
|
store i32 %"54", ptr %"58", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
58
ptx/src/test/ll/redux_sync_add_u32_partial.ll
Normal file
58
ptx/src/test/ll/redux_sync_add_u32_partial.ll
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @redux_sync_add_u32_partial(ptr addrspace(4) byref(i64) %"46") #1 {
|
||||||
|
%"47" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"48" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"49" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"50" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"51" = alloca i1, align 1, addrspace(5)
|
||||||
|
%"62" = alloca i64, align 8, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"43"
|
||||||
|
|
||||||
|
"43": ; preds = %1
|
||||||
|
%"52" = load i64, ptr addrspace(4) %"46", align 8
|
||||||
|
store i64 %"52", ptr addrspace(5) %"49", align 8
|
||||||
|
%"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
|
||||||
|
br label %"44"
|
||||||
|
|
||||||
|
"44": ; preds = %"43"
|
||||||
|
store i32 %"37", ptr addrspace(5) %"47", align 4
|
||||||
|
%"55" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%"54" = urem i32 %"55", 2
|
||||||
|
store i32 %"54", ptr addrspace(5) %"50", align 4
|
||||||
|
%"57" = load i32, ptr addrspace(5) %"50", align 4
|
||||||
|
%2 = icmp eq i32 %"57", 0
|
||||||
|
store i1 %2, ptr addrspace(5) %"51", align 1
|
||||||
|
store i32 0, ptr addrspace(5) %"48", align 4
|
||||||
|
%"59" = load i1, ptr addrspace(5) %"51", align 1
|
||||||
|
br i1 %"59", label %"16", label %"17"
|
||||||
|
|
||||||
|
"16": ; preds = %"44"
|
||||||
|
%"61" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%"60" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"61", i32 1431655765)
|
||||||
|
store i32 %"60", ptr addrspace(5) %"48", align 4
|
||||||
|
br label %"17"
|
||||||
|
|
||||||
|
"17": ; preds = %"16", %"44"
|
||||||
|
%"64" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%3 = zext i32 %"64" to i64
|
||||||
|
%"63" = mul i64 %3, 4
|
||||||
|
store i64 %"63", ptr addrspace(5) %"62", align 8
|
||||||
|
%"66" = load i64, ptr addrspace(5) %"49", align 8
|
||||||
|
%"67" = load i64, ptr addrspace(5) %"62", align 8
|
||||||
|
%"65" = add i64 %"66", %"67"
|
||||||
|
store i64 %"65", ptr addrspace(5) %"49", align 8
|
||||||
|
%"68" = load i64, ptr addrspace(5) %"49", align 8
|
||||||
|
%"69" = load i32, ptr addrspace(5) %"48", align 4
|
||||||
|
%"70" = inttoptr i64 %"68" to ptr
|
||||||
|
store i32 %"69", ptr %"70", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
67
ptx/src/test/ll/redux_sync_op_s32.ll
Normal file
67
ptx/src/test/ll/redux_sync_op_s32.ll
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_min_s32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_max_s32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_s32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @redux_sync_op_s32(ptr addrspace(4) byref(i64) %"46") #1 {
|
||||||
|
%"47" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"48" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"49" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"50" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"51" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"52" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"53" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"70" = alloca i64, align 8, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"43"
|
||||||
|
|
||||||
|
"43": ; preds = %1
|
||||||
|
%"54" = load i64, ptr addrspace(4) %"46", align 8
|
||||||
|
store i64 %"54", ptr addrspace(5) %"53", align 8
|
||||||
|
%"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
|
||||||
|
br label %"44"
|
||||||
|
|
||||||
|
"44": ; preds = %"43"
|
||||||
|
store i32 %"37", ptr addrspace(5) %"47", align 4
|
||||||
|
%"57" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%"56" = sub i32 %"57", 5
|
||||||
|
store i32 %"56", ptr addrspace(5) %"48", align 4
|
||||||
|
%"59" = load i32, ptr addrspace(5) %"48", align 4
|
||||||
|
%"58" = call i32 @__zluda_ptx_impl_redux_sync_add_s32(i32 %"59", i32 -1)
|
||||||
|
store i32 %"58", ptr addrspace(5) %"49", align 4
|
||||||
|
%"61" = load i32, ptr addrspace(5) %"48", align 4
|
||||||
|
%"60" = call i32 @__zluda_ptx_impl_redux_sync_min_s32(i32 %"61", i32 -1)
|
||||||
|
store i32 %"60", ptr addrspace(5) %"50", align 4
|
||||||
|
%"63" = load i32, ptr addrspace(5) %"48", align 4
|
||||||
|
%"62" = call i32 @__zluda_ptx_impl_redux_sync_max_s32(i32 %"63", i32 -1)
|
||||||
|
store i32 %"62", ptr addrspace(5) %"51", align 4
|
||||||
|
%"65" = load i32, ptr addrspace(5) %"49", align 4
|
||||||
|
%"66" = load i32, ptr addrspace(5) %"50", align 4
|
||||||
|
%"64" = add i32 %"65", %"66"
|
||||||
|
store i32 %"64", ptr addrspace(5) %"52", align 4
|
||||||
|
%"68" = load i32, ptr addrspace(5) %"52", align 4
|
||||||
|
%"69" = load i32, ptr addrspace(5) %"51", align 4
|
||||||
|
%"67" = add i32 %"68", %"69"
|
||||||
|
store i32 %"67", ptr addrspace(5) %"52", align 4
|
||||||
|
%"72" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%2 = zext i32 %"72" to i64
|
||||||
|
%"71" = mul i64 %2, 4
|
||||||
|
store i64 %"71", ptr addrspace(5) %"70", align 8
|
||||||
|
%"74" = load i64, ptr addrspace(5) %"53", align 8
|
||||||
|
%"75" = load i64, ptr addrspace(5) %"70", align 8
|
||||||
|
%"73" = add i64 %"74", %"75"
|
||||||
|
store i64 %"73", ptr addrspace(5) %"53", align 8
|
||||||
|
%"76" = load i64, ptr addrspace(5) %"53", align 8
|
||||||
|
%"77" = load i32, ptr addrspace(5) %"52", align 4
|
||||||
|
%"79" = inttoptr i64 %"76" to ptr
|
||||||
|
store i32 %"77", ptr %"79", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
63
ptx/src/test/ll/redux_sync_op_u32.ll
Normal file
63
ptx/src/test/ll/redux_sync_op_u32.ll
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_max_u32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_redux_sync_min_u32(i32, i32) #0
|
||||||
|
|
||||||
|
declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0
|
||||||
|
|
||||||
|
define amdgpu_kernel void @redux_sync_op_u32(ptr addrspace(4) byref(i64) %"44") #1 {
|
||||||
|
%"45" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"46" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"47" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"48" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"49" = alloca i32, align 4, addrspace(5)
|
||||||
|
%"50" = alloca i64, align 8, addrspace(5)
|
||||||
|
%"65" = alloca i64, align 8, addrspace(5)
|
||||||
|
br label %1
|
||||||
|
|
||||||
|
1: ; preds = %0
|
||||||
|
br label %"41"
|
||||||
|
|
||||||
|
"41": ; preds = %1
|
||||||
|
%"51" = load i64, ptr addrspace(4) %"44", align 8
|
||||||
|
store i64 %"51", ptr addrspace(5) %"50", align 8
|
||||||
|
%"36" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0)
|
||||||
|
br label %"42"
|
||||||
|
|
||||||
|
"42": ; preds = %"41"
|
||||||
|
store i32 %"36", ptr addrspace(5) %"45", align 4
|
||||||
|
%"54" = load i32, ptr addrspace(5) %"45", align 4
|
||||||
|
%"53" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"54", i32 -1)
|
||||||
|
store i32 %"53", ptr addrspace(5) %"46", align 4
|
||||||
|
%"56" = load i32, ptr addrspace(5) %"45", align 4
|
||||||
|
%"55" = call i32 @__zluda_ptx_impl_redux_sync_min_u32(i32 %"56", i32 -1)
|
||||||
|
store i32 %"55", ptr addrspace(5) %"47", align 4
|
||||||
|
%"58" = load i32, ptr addrspace(5) %"45", align 4
|
||||||
|
%"57" = call i32 @__zluda_ptx_impl_redux_sync_max_u32(i32 %"58", i32 -1)
|
||||||
|
store i32 %"57", ptr addrspace(5) %"48", align 4
|
||||||
|
%"60" = load i32, ptr addrspace(5) %"46", align 4
|
||||||
|
%"61" = load i32, ptr addrspace(5) %"47", align 4
|
||||||
|
%"59" = add i32 %"60", %"61"
|
||||||
|
store i32 %"59", ptr addrspace(5) %"49", align 4
|
||||||
|
%"63" = load i32, ptr addrspace(5) %"49", align 4
|
||||||
|
%"64" = load i32, ptr addrspace(5) %"48", align 4
|
||||||
|
%"62" = add i32 %"63", %"64"
|
||||||
|
store i32 %"62", ptr addrspace(5) %"49", align 4
|
||||||
|
%"67" = load i32, ptr addrspace(5) %"45", align 4
|
||||||
|
%2 = zext i32 %"67" to i64
|
||||||
|
%"66" = mul i64 %2, 4
|
||||||
|
store i64 %"66", ptr addrspace(5) %"65", align 8
|
||||||
|
%"69" = load i64, ptr addrspace(5) %"50", align 8
|
||||||
|
%"70" = load i64, ptr addrspace(5) %"65", align 8
|
||||||
|
%"68" = add i64 %"69", %"70"
|
||||||
|
store i64 %"68", ptr addrspace(5) %"50", align 8
|
||||||
|
%"71" = load i64, ptr addrspace(5) %"50", align 8
|
||||||
|
%"72" = load i32, ptr addrspace(5) %"49", align 4
|
||||||
|
%"73" = inttoptr i64 %"71" to ptr
|
||||||
|
store i32 %"72", ptr %"73", align 4
|
||||||
|
ret void
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
||||||
|
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
25
ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx
Normal file
25
ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
.version 7.8
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry cvt_rn_bf16x2_f32(
|
||||||
|
.param .u64 input,
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .f32 in_a;
|
||||||
|
.reg .f32 in_b;
|
||||||
|
.reg .b32 result;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
ld.param.u64 out_addr, [output];
|
||||||
|
|
||||||
|
ld.f32 in_a, [in_addr];
|
||||||
|
ld.f32 in_b, [in_addr + 4];
|
||||||
|
|
||||||
|
cvt.rn.bf16x2.f32 result, in_a, in_b;
|
||||||
|
st.b32 [out_addr], result;
|
||||||
|
ret;
|
||||||
|
}
|
|
@ -200,6 +200,7 @@ test_ptx!(
|
||||||
);
|
);
|
||||||
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
|
test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]);
|
||||||
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
|
test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]);
|
||||||
|
test_ptx!(cvt_rn_bf16x2_f32, [0.40625, 12.9f32], [0x3ED0414Eu32]);
|
||||||
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
|
||||||
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
|
||||||
test_ptx!(
|
test_ptx!(
|
||||||
|
@ -452,6 +453,40 @@ test_ptx_warp!(
|
||||||
4294967292, 4294967292, 4294967292, 4294967292, 4294967292
|
4294967292, 4294967292, 4294967292, 4294967292, 4294967292
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
test_ptx_warp!(
|
||||||
|
redux_sync_op_s32,
|
||||||
|
[
|
||||||
|
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32,
|
||||||
|
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32,
|
||||||
|
357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 1445i32,
|
||||||
|
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
|
||||||
|
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
|
||||||
|
1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32,
|
||||||
|
1445i32,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
test_ptx_warp!(
|
||||||
|
redux_sync_op_u32,
|
||||||
|
[
|
||||||
|
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32,
|
||||||
|
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32,
|
||||||
|
527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 1615u32,
|
||||||
|
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
|
||||||
|
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
|
||||||
|
1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32,
|
||||||
|
1615u32,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
test_ptx_warp!(
|
||||||
|
redux_sync_add_u32_partial,
|
||||||
|
[
|
||||||
|
240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32,
|
||||||
|
0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32,
|
||||||
|
240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32,
|
||||||
|
0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32,
|
||||||
|
752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
struct DisplayError<T: Debug> {
|
struct DisplayError<T: Debug> {
|
||||||
err: T,
|
err: T,
|
||||||
|
|
31
ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx
Normal file
31
ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
.version 7.0
|
||||||
|
.target sm_80
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry redux_sync_add_u32_partial(
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u32 tid;
|
||||||
|
.reg .u32 result;
|
||||||
|
.reg .u64 out_ptr;
|
||||||
|
|
||||||
|
.reg .u32 tid_rem_2;
|
||||||
|
.reg .pred p;
|
||||||
|
|
||||||
|
ld.param.u64 out_ptr, [output];
|
||||||
|
mov.u32 tid, %tid.x;
|
||||||
|
|
||||||
|
rem.u32 tid_rem_2, tid, 2;
|
||||||
|
setp.eq.u32 p, tid_rem_2, 0;
|
||||||
|
|
||||||
|
mov.u32 result, 0;
|
||||||
|
@p redux.sync.add.u32 result, tid, 0x55555555;
|
||||||
|
|
||||||
|
.reg .u64 out_offset;
|
||||||
|
mul.wide.u32 out_offset, tid, 4;
|
||||||
|
add.u64 out_ptr, out_ptr, out_offset;
|
||||||
|
st.u32 [out_ptr], result;
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
34
ptx/src/test/spirv_run/redux_sync_op_s32.ptx
Normal file
34
ptx/src/test/spirv_run/redux_sync_op_s32.ptx
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
.version 7.0
|
||||||
|
.target sm_80
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry redux_sync_op_s32(
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u32 tid;
|
||||||
|
.reg .s32 in;
|
||||||
|
.reg .s32 add_out;
|
||||||
|
.reg .s32 min_out;
|
||||||
|
.reg .s32 max_out;
|
||||||
|
.reg .s32 result;
|
||||||
|
.reg .u64 out_ptr;
|
||||||
|
|
||||||
|
ld.param.u64 out_ptr, [output];
|
||||||
|
mov.u32 tid, %tid.x;
|
||||||
|
sub.s32 in, tid, 5;
|
||||||
|
|
||||||
|
redux.sync.add.s32 add_out, in, 0xFFFFFFFF;
|
||||||
|
redux.sync.min.s32 min_out, in, 0xFFFFFFFF;
|
||||||
|
redux.sync.max.s32 max_out, in, 0xFFFFFFFF;
|
||||||
|
|
||||||
|
add.s32 result, add_out, min_out;
|
||||||
|
add.s32 result, result, max_out;
|
||||||
|
|
||||||
|
.reg .u64 out_offset;
|
||||||
|
mul.wide.u32 out_offset, tid, 4;
|
||||||
|
add.u64 out_ptr, out_ptr, out_offset;
|
||||||
|
st.s32 [out_ptr], result;
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
32
ptx/src/test/spirv_run/redux_sync_op_u32.ptx
Normal file
32
ptx/src/test/spirv_run/redux_sync_op_u32.ptx
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
.version 7.0
|
||||||
|
.target sm_80
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry redux_sync_op_u32(
|
||||||
|
.param .u64 output
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u32 tid;
|
||||||
|
.reg .u32 add_out;
|
||||||
|
.reg .u32 min_out;
|
||||||
|
.reg .u32 max_out;
|
||||||
|
.reg .u32 result;
|
||||||
|
.reg .u64 out_ptr;
|
||||||
|
|
||||||
|
ld.param.u64 out_ptr, [output];
|
||||||
|
mov.u32 tid, %tid.x;
|
||||||
|
|
||||||
|
redux.sync.add.u32 add_out, tid, 0xFFFFFFFF;
|
||||||
|
redux.sync.min.u32 min_out, tid, 0xFFFFFFFF;
|
||||||
|
redux.sync.max.u32 max_out, tid, 0xFFFFFFFF;
|
||||||
|
|
||||||
|
add.u32 result, add_out, min_out;
|
||||||
|
add.u32 result, result, max_out;
|
||||||
|
|
||||||
|
.reg .u64 out_offset;
|
||||||
|
mul.wide.u32 out_offset, tid, 4;
|
||||||
|
add.u64 out_ptr, out_ptr, out_offset;
|
||||||
|
st.u32 [out_ptr], result;
|
||||||
|
|
||||||
|
ret;
|
||||||
|
}
|
|
@ -695,6 +695,18 @@ ptx_parser_macros::generate_instruction_type!(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
},
|
||||||
|
ReduxSync {
|
||||||
|
type: Type::Scalar(data.type_),
|
||||||
|
data: ReduxSyncData,
|
||||||
|
arguments<T>: {
|
||||||
|
dst: T,
|
||||||
|
src: T,
|
||||||
|
src_membermask: {
|
||||||
|
repr: T,
|
||||||
|
type: { Type::Scalar(ScalarType::U32) },
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -1162,6 +1174,35 @@ impl ScalarType {
|
||||||
ScalarType::Pred => ScalarKind::Pred,
|
ScalarType::Pred => ScalarKind::Pred,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn packed_type(&self) -> Option<ScalarType> {
|
||||||
|
match self {
|
||||||
|
ScalarType::E4m3x2 => Some(ScalarType::B8),
|
||||||
|
ScalarType::E5m2x2 => Some(ScalarType::B8),
|
||||||
|
ScalarType::F16x2 => Some(ScalarType::F16),
|
||||||
|
ScalarType::BF16x2 => Some(ScalarType::BF16),
|
||||||
|
ScalarType::U16x2 => Some(ScalarType::U16),
|
||||||
|
ScalarType::S16x2 => Some(ScalarType::S16),
|
||||||
|
ScalarType::S16
|
||||||
|
| ScalarType::BF16
|
||||||
|
| ScalarType::U32
|
||||||
|
| ScalarType::S8
|
||||||
|
| ScalarType::S32
|
||||||
|
| ScalarType::Pred
|
||||||
|
| ScalarType::B8
|
||||||
|
| ScalarType::U64
|
||||||
|
| ScalarType::B16
|
||||||
|
| ScalarType::S64
|
||||||
|
| ScalarType::B32
|
||||||
|
| ScalarType::U8
|
||||||
|
| ScalarType::F32
|
||||||
|
| ScalarType::B64
|
||||||
|
| ScalarType::B128
|
||||||
|
| ScalarType::U16
|
||||||
|
| ScalarType::F64
|
||||||
|
| ScalarType::F16 => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
|
@ -1933,8 +1974,13 @@ impl CvtDetails {
|
||||||
(RoundingMode::NearestEven, false)
|
(RoundingMode::NearestEven, false)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let dst_size = if dst.packed_type().is_some() {
|
||||||
|
dst.size_of() / 2
|
||||||
|
} else {
|
||||||
|
dst.size_of()
|
||||||
|
};
|
||||||
let mode = match (dst.kind(), src.kind()) {
|
let mode = match (dst.kind(), src.kind()) {
|
||||||
(ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
|
(ScalarKind::Float, ScalarKind::Float) => match dst_size.cmp(&src.size_of()) {
|
||||||
Ordering::Less => {
|
Ordering::Less => {
|
||||||
let (rounding, is_integer_rounding) = unwrap_rounding();
|
let (rounding, is_integer_rounding) = unwrap_rounding();
|
||||||
CvtMode::FPTruncate {
|
CvtMode::FPTruncate {
|
||||||
|
@ -2272,3 +2318,8 @@ impl VoteMode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ReduxSyncData {
|
||||||
|
pub type_: ScalarType,
|
||||||
|
pub reduction: Reduction,
|
||||||
|
}
|
||||||
|
|
|
@ -2442,7 +2442,16 @@ derive_parser!(
|
||||||
// cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
|
// cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
|
||||||
// cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
|
// cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
|
||||||
// cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
|
// cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
|
||||||
// cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
|
cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b => {
|
||||||
|
if relu || satfinite {
|
||||||
|
state.errors.push(PtxError::Todo);
|
||||||
|
}
|
||||||
|
let data = ast::CvtDetails::new(&mut state.errors, Some(frnd2), false, false, ScalarType::BF16x2, ScalarType::F32);
|
||||||
|
ast::Instruction::Cvt {
|
||||||
|
data,
|
||||||
|
arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) }
|
||||||
|
}
|
||||||
|
}
|
||||||
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
// cvt.rna{.satfinite}.tf32.f32 d, a;
|
||||||
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
// cvt.frnd2{.relu}.tf32.f32 d, a;
|
||||||
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {
|
cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => {
|
||||||
|
@ -3844,6 +3853,23 @@ derive_parser!(
|
||||||
// .mode: VoteMode = { .all, .any, .uni };
|
// .mode: VoteMode = { .all, .any, .uni };
|
||||||
.mode: VoteMode = { .all, .any };
|
.mode: VoteMode = { .all, .any };
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync
|
||||||
|
|
||||||
|
redux.sync.op.type dst, src, membermask => {
|
||||||
|
Instruction::ReduxSync {
|
||||||
|
data: ReduxSyncData { type_, reduction: op },
|
||||||
|
arguments: ReduxSyncArgs { dst, src, src_membermask: membermask }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.op: Reduction = {.add, .min, .max};
|
||||||
|
.type: ScalarType = {.u32, .s32};
|
||||||
|
|
||||||
|
// redux.sync.op.b32 dst, src, membermask;
|
||||||
|
// .op = {.and, .or, .xor}
|
||||||
|
|
||||||
|
// redux.sync.op{.abs.}{.NaN}.f32 dst, src, membermask;
|
||||||
|
// .op = { .min, .max }
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue