diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 6da96ed..2f39678 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -2305,11 +2305,9 @@ impl<'a> MethodEmitContext<'a> { ) }); } else { - self.resolver.with_result(arguments.dst, |dst| unsafe { - let dst = CStr::from_ptr(dst); - LLVMSetValueName2(min, dst.as_ptr(), dst.count_bytes()); - min - }); + let name = self.resolver.get_or_add(arguments.dst); + unsafe { LLVMSetValueName2(min, name.as_ptr().cast(), name.len()) }; + self.resolver.register(arguments.dst, min); } Ok(()) } @@ -2322,22 +2320,48 @@ impl<'a> MethodEmitContext<'a> { let llvm_prefix = match data { ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { - "llvm.maximum" - } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", }; let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); - self.emit_intrinsic( + + let a = self.resolver.value(arguments.src1)?; + let b = self.resolver.value(arguments.src2)?; + + let min = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), + None, Some(&data.type_().into()), - vec![ - (self.resolver.value(arguments.src1)?, llvm_type), - (self.resolver.value(arguments.src2)?, llvm_type), - ], + vec![(a, llvm_type), (b, llvm_type)], )?; + + if let ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { + nan: true, type_, .. + }) = data + { + let is_nan = unsafe { + LLVMBuildFCmp( + self.builder, + LLVMRealPredicate::LLVMRealUNO, + a, + b, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSelect( + self.builder, + is_nan, + LLVMConstReal(get_scalar_type(self.context, type_), f64::NAN), + min, + dst, + ) + }); + } else { + let name = self.resolver.get_or_add(arguments.dst); + unsafe { LLVMSetValueName2(min, name.as_ptr().cast(), name.len()) }; + self.resolver.register(arguments.dst, min); + } Ok(()) }