Skip to content

Commit

Permalink
[Backport to 8] Rename ConvertFToTF32INTEL to RoundFToTF32INTEL (#1945)
Browse files Browse the repository at this point in the history
Extension name will be preserved for a while for binary compatibility.

Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Apr 4, 2023
1 parent dd4638b commit 26db0ab
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
3 changes: 2 additions & 1 deletion include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ EXT(SPV_INTEL_fast_composite)
EXT(SPV_INTEL_optnone)
EXT(SPV_INTEL_masked_gather_scatter)
EXT(SPV_INTEL_bfloat16_conversion)
EXT(SPV_INTEL_tensor_float32_conversion)
EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
EXT(SPV_INTEL_tensor_float32_rounding)
EXT(SPV_INTEL_hw_thread_queries)
EXT(SPV_EXT_relaxed_printf_string_address_space)
EXT(SPV_INTEL_split_barrier)
12 changes: 6 additions & 6 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2763,10 +2763,10 @@ _SPIRV_OP(ConvertBF16ToFINTEL)
#undef _SPIRV_OP

template <Op OC>
class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(CapabilityTensorFloat32ConversionINTEL);
return getVec(CapabilityTensorFloat32RoundingINTEL);
}

SPIRVExtSet getRequiredExtensions() const override {
Expand All @@ -2787,8 +2787,8 @@ class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
// because it may call a method of class Module that may modify LiteralMap
// of Module field. That modification is not impacting validate method for
// these instructions, so const_cast is safe here.
using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32ConvTy *>(this)->getOperand(0);
using SPVTF32RoundTy = SPIRVTensorFloat32RoundingINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32RoundTy *>(this)->getOperand(0);

SPIRVType *InCompTy = Input->getType();
SPIRVWord InCompCount = 1;
Expand Down Expand Up @@ -2816,8 +2816,8 @@ class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
};

#define _SPIRV_OP(x) \
typedef SPIRVTensorFloat32ConversionINTELInstBase<Op##x> SPIRV##x;
_SPIRV_OP(ConvertFToTF32INTEL)
typedef SPIRVTensorFloat32RoundingINTELInstBase<Op##x> SPIRV##x;
_SPIRV_OP(RoundFToTF32INTEL)
#undef _SPIRV_OP

class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase {
Expand Down
3 changes: 1 addition & 2 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityOptNoneINTEL, "OptNoneINTEL");
add(CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
add(CapabilityBfloat16ConversionINTEL, "Bfloat16ConversionINTEL");
add(CapabilityTensorFloat32ConversionINTEL,
"TensorFloat32ConversionINTEL");
add(CapabilityTensorFloat32RoundingINTEL, "TensorFloat32RoundingINTEL");
add(CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");

Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ _SPIRV_OP(SubgroupAvcSicGetInterRawSadsINTEL, 5816)
_SPIRV_OP(TypeBufferSurfaceINTEL, 6086)
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
_SPIRV_OP(ConvertFToTF32INTEL, 6426)
_SPIRV_OP(RoundFToTF32INTEL, 6426)
_SPIRV_OP(MaskedGatherINTEL, 6428)
_SPIRV_OP(MaskedScatterINTEL, 6429)
_SPIRV_OP(ControlBarrierArriveINTEL, 6142)
Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ enum Capability {
CapabilityOptNoneINTEL = 6094,
CapabilityBfloat16ConversionINTEL = 6115,
CapabilityHWThreadQueryINTEL = 6134,
CapabilityTensorFloat32ConversionINTEL = 6425,
CapabilityTensorFloat32RoundingINTEL = 6425,
CapabilityMaskedGatherScatterINTEL = 6427,
CapabilitySplitBarrierINTEL = 6141,
CapabilityMax = 0x7fffffff,
Expand Down Expand Up @@ -1424,7 +1424,7 @@ enum Op {
OpTypeBufferSurfaceINTEL = 6086,
OpConvertFToBF16INTEL = 6116,
OpConvertBF16ToFINTEL = 6117,
OpConvertFToTF32INTEL = 6426,
OpRoundFToTF32INTEL = 6426,
OpMaskedGatherINTEL = 6428,
OpMaskedScatterINTEL = 6429,
OpControlBarrierArriveINTEL = 6142,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL
; CHECK-SPIRV: Capability TensorFloat32RoundingINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion"
; CHECK-SPIRV: TypeFloat [[FP32Ty:[0-9]+]] 32
; CHECK-SPIRV: TypeVector [[FP32v8Ty:[0-9]+]] [[FP32Ty]] 8
Expand All @@ -18,24 +18,24 @@ target triple = "spir64-unknown-unknown"
; CHECK-SPIRV: FunctionParameter [[FP32Ty]] [[FP32ValId:[0-9]+]]
; CHECK-SPIRV: FunctionParameter [[FP32v8Ty]] [[FP32v8ValId:[0-9]+]]

; CHECK-SPIRV: ConvertFToTF32INTEL [[FP32Ty]] [[IGNORE0:[0-9]+]] [[FP32ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[FP32v8Ty]] [[IGNORE1:[0-9]+]] [[FP32v8ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[FP32Ty]] [[IGNORE2:[0-9]+]] [[CONST]]
; CHECK-SPIRV: RoundFToTF32INTEL [[FP32Ty]] [[IGNORE0:[0-9]+]] [[FP32ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[FP32v8Ty]] [[IGNORE1:[0-9]+]] [[FP32v8ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[FP32Ty]] [[IGNORE2:[0-9]+]] [[CONST]]

; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)

define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
%1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
%1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)
ret void
}

declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float)
declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)

declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>)
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)

!opencl.spir.version = !{!0}
!spirv.Source = !{!1}
Expand Down

0 comments on commit 26db0ab

Please sign in to comment.