Skip to content

Commit

Permalink
[Backport to 14] Missing SPIR-V 1.4 features/changes (#2590)
Browse files Browse the repository at this point in the history
* [SPIR-V 1.4] Support OpPtrEqual, OpPtrNotEqual and OpPtrDiff to compare pointers

* Translate icmp eq and icmp ne to PtrEqual and PtrNotEqual respectively

* CounterBuffer decoration is supported from 1.4 without extension

Do nothing for now, as it's not used in translator.

Co-authored-by: Viktoria Maximova <[email protected]>
  • Loading branch information
MrSidims and vmaksimo authored Jun 18, 2024
1 parent 62f5b09 commit 8a6bccd
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 9 deletions.
22 changes: 22 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,28 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, V);
}

case OpPtrEqual:
case OpPtrNotEqual: {
auto *BC = static_cast<SPIRVBinary *>(BV);
auto Ops = transValue(BC->getOperands(), F, BB);

IRBuilder<> Builder(BB);
Value *Op1 = Builder.CreatePtrToInt(Ops[0], Type::getInt64Ty(*Context));
Value *Op2 = Builder.CreatePtrToInt(Ops[1], Type::getInt64Ty(*Context));
CmpInst::Predicate P =
OC == OpPtrEqual ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
Value *V = Builder.CreateICmp(P, Op1, Op2);
return mapValue(BV, V);
}

case OpPtrDiff: {
auto *BC = static_cast<SPIRVBinary *>(BV);
auto Ops = transValue(BC->getOperands(), F, BB);
IRBuilder<> Builder(BB);
Value *V = Builder.CreatePtrDiff(transType(BC->getType()), Ops[0], Ops[1]);
return mapValue(BV, V);
}

case OpCompositeConstruct: {
auto CC = static_cast<SPIRVCompositeConstruct *>(BV);
auto Constituents = transValue(CC->getOperands(), F, BB);
Expand Down
9 changes: 7 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,9 +1328,14 @@ SPIRVInstruction *LLVMToSPIRVBase::transCmpInst(CmpInst *Cmp,
auto *Op0 = Cmp->getOperand(0);
SPIRVValue *TOp0 = transValue(Op0, BB);
SPIRVValue *TOp1 = transValue(Cmp->getOperand(1), BB);
// TODO: once the translator supports SPIR-V 1.4, update the condition below:
// if (/* */->isPointerTy() && /* it is not allowed to use SPIR-V 1.4 */)
if (Op0->getType()->isPointerTy()) {
auto P = Cmp->getPredicate();
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4) &&
(P == ICmpInst::ICMP_EQ || P == ICmpInst::ICMP_NE) &&
Cmp->getOperand(1)->getType()->isPointerTy()) {
Op OC = P == ICmpInst::ICMP_EQ ? OpPtrEqual : OpPtrNotEqual;
return BM->addBinaryInst(OC, transType(Cmp->getType()), TOp0, TOp1, BB);
}
unsigned AS = cast<PointerType>(Op0->getType())->getAddressSpace();
SPIRVType *Ty = transType(getSizetType(AS));
TOp0 = BM->addUnaryInst(OpConvertPtrToU, Ty, TOp0, BB);
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVDecorate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class SPIRVDecorateGeneric : public SPIRVAnnotationGeneric {
case DecorationMaxByteOffset:
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_1);
case DecorationUserSemantic:
case DecorationCounterBuffer:
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);

default:
Expand Down
14 changes: 14 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,21 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
"Invalid type for bitwise instruction");
assert((Op1Ty->getIntegerBitWidth() == Op2Ty->getIntegerBitWidth()) &&
"Inconsistent BitWidth");
} else if (isBinaryPtrOpCode(OpCode)) {
assert((Op1Ty->isTypePointer() && Op2Ty->isTypePointer()) &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
assert(static_cast<SPIRVTypePointer *>(Op1Ty)->getElementType() ==
static_cast<SPIRVTypePointer *>(Op2Ty)->getElementType() &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
} else {
assert(0 && "Invalid op code!");
}
}
SPIRVWord getRequiredSPIRVVersion() const override {
if (isBinaryPtrOpCode(OpCode))
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_0);
}
};

template <Op OC>
Expand Down Expand Up @@ -716,6 +727,9 @@ _SPIRV_OP(BitwiseAnd)
_SPIRV_OP(BitwiseOr)
_SPIRV_OP(BitwiseXor)
_SPIRV_OP(Dot)
_SPIRV_OP(PtrEqual)
_SPIRV_OP(PtrNotEqual)
_SPIRV_OP(PtrDiff)
#undef _SPIRV_OP

template <Op TheOpCode> class SPIRVInstNoOperand : public SPIRVInstruction {
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ inline bool isBinaryOpCode(Op OpCode) {
OpCode == OpDot || OpCode == OpIAddCarry || OpCode == OpISubBorrow;
}

inline bool isBinaryPtrOpCode(Op OpCode) {
return (unsigned)OpCode >= OpPtrEqual && (unsigned)OpCode <= OpPtrDiff;
}

inline bool isShiftOpCode(Op OpCode) {
return (unsigned)OpCode >= OpShiftRightLogical &&
(unsigned)OpCode <= OpShiftLeftLogical;
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ _SPIRV_OP(GroupNonUniformBitwiseXor, 361)
_SPIRV_OP(GroupNonUniformLogicalAnd, 362)
_SPIRV_OP(GroupNonUniformLogicalOr, 363)
_SPIRV_OP(GroupNonUniformLogicalXor, 364)
_SPIRV_OP(PtrEqual, 401)
_SPIRV_OP(PtrNotEqual, 402)
_SPIRV_OP(PtrDiff, 403)
_SPIRV_OP(GroupNonUniformRotateKHR, 4431)
_SPIRV_OP(SDotKHR, 4450)
_SPIRV_OP(UDotKHR, 4451)
Expand Down
16 changes: 15 additions & 1 deletion test/ComparePointers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ kernel void test(int global *in, int global *in2) {
return;
}
// RUN: %clang_cc1 -triple spir64 -x cl -cl-std=CL2.0 -O0 -emit-llvm-bc %s -o %t.bc
// RUN: llvm-spirv %t.bc -spirv-text -o %t.spt
// RUN: llvm-spirv %t.bc --spirv-max-version=1.3 -o %t.spv
// RUN: spirv-val %t.spv
// RUN: llvm-spirv %t.bc -spirv-text --spirv-max-version=1.3 -o %t.spt
// RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

// RUN: llvm-spirv %t.bc -o %t.spv
// RUN: spirv-val %t.spv
// RUN: llvm-spirv %t.bc -spirv-text -o %t.spt
// RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV-14

// CHECK-SPIRV:ConvertPtrToU
// CHECK-SPIRV:ConvertPtrToU
Expand All @@ -26,3 +31,12 @@ kernel void test(int global *in, int global *in2) {
// CHECK-SPIRV:ConvertPtrToU
// CHECK-SPIRV:ConvertPtrToU
// CHECK-SPIRV:ULessThan

// CHECK-SPIRV-14: PtrNotEqual
// CHECK-SPIRV-14: PtrEqual
// CHECK-SPIRV-14:ConvertPtrToU
// CHECK-SPIRV-14:ConvertPtrToU
// CHECK-SPIRV-14:UGreaterThan
// CHECK-SPIRV-14:ConvertPtrToU
// CHECK-SPIRV-14:ConvertPtrToU
// CHECK-SPIRV-14:ULessThan
18 changes: 12 additions & 6 deletions test/complex-constexpr.ll
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis %t.bc -o %t.ll
; RUN: FileCheck %s --input-file %t.spt -check-prefix=CHECK-SPIRV
; RUN: FileCheck %s --input-file %t.ll -check-prefix=CHECK-LLVM
; RUN: llvm-spirv %t.bc --spirv-max-version=1.3 -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv %t.spv --spirv-max-version=1.3 -o %t.spt --to-text
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck %s --input-file %t.spt -check-prefix=CHECK-SPIRV
; RUN: FileCheck %s --input-file %t.rev.ll -check-prefix=CHECK-LLVM

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64"
Expand All @@ -16,6 +16,12 @@ target triple = "spir64"
; CHECK-SPIRV: TypePointer [[Ptr_Ty:[0-9]+]] 8 [[Int_Ty]]
; CHECK-SPIRV: TypeFunction [[Func_Ty2:[0-9]+]] [[Void_Ty]] [[Ptr_Ty]] [[Ptr_Ty]]

; CHECK-SPIRV-14: TypeInt [[Int_Ty:[0-9]+]] 8 0
; CHECK-SPIRV-14: TypeVoid [[Void_Ty:[0-9]+]]
; CHECK-SPIRV-14: TypeFunction [[Func_Ty1:[0-9]+]] [[Void_Ty]]
; CHECK-SPIRV-14: TypePointer [[Ptr_Ty:[0-9]+]] 8
; CHECK-SPIRV-14: TypeFunction [[Func_Ty2:[0-9]+]] [[Void_Ty]] [[Ptr_Ty]] [[Ptr_Ty]]

@.str.1 = private unnamed_addr addrspace(1) constant [1 x i8] zeroinitializer, align 1

define linkonce_odr hidden spir_func void @foo() {
Expand Down
52 changes: 52 additions & 0 deletions test/transcoding/ptr_diff.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; Check support of OpPtrDiff instruction that was added in SPIR-V 1.4

; RUN: llvm-as %s -o %t.bc
; RUN: not llvm-spirv --spirv-max-version=1.3 %t.bc 2>&1 | FileCheck --check-prefix=CHECK-ERROR %s

; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: spirv-val %t.spv

; RUN: llvm-spirv -r --opaque-pointers %t.spv -o %t.rev.bc
; RUN: llvm-dis -opaque-pointers=1 %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-ERROR: RequiresVersion: Cannot fulfill SPIR-V version restriction:
; CHECK-ERROR-NEXT: SPIR-V version was restricted to at most 1.3 (66304) but a construct from the input requires SPIR-V version 1.4 (66560) or above

; SPIR-V 1.4
; CHECK-SPIRV: 66560
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypePointer [[#TypePointer:]] [[#]] [[#TypeFloat]]

; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var:]]
; CHECK-SPIRV: PtrDiff [[#TypeInt]] [[#]] [[#Var]] [[#Var]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

; Function Attrs: nounwind
define spir_kernel void @test(float %a) local_unnamed_addr #0 {
entry:
%0 = alloca float, align 4
store float %a, float* %0, align 4
; CHECK-LLVM: %[[#Arg1:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Arg2:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Sub:]] = sub i64 %[[#Arg1]], %[[#Arg2]]
; CHECK-LLVM: sdiv exact i64 %[[#Sub]], ptrtoint (ptr getelementptr (i32, ptr null, i32 1) to i64)
%1 = call spir_func noundef i32 @_Z15__spirv_PtrDiff(float* %0, float* %0)
ret void
}

declare spir_func noundef i32 @_Z15__spirv_PtrDiff(float*, float*)

attributes #0 = { convergent nounwind writeonly }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
60 changes: 60 additions & 0 deletions test/transcoding/ptr_not_equal.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
; Check support of OpPtrEqual and OpPtrNotEqual instructions that were added in SPIR-V 1.4

; RUN: llvm-as %s -o %t.bc
; RUN: not llvm-spirv --spirv-max-version=1.3 %t.bc 2>&1 | FileCheck --check-prefix=CHECK-ERROR %s

; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: spirv-val %t.spv

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-ERROR: RequiresVersion: Cannot fulfill SPIR-V version restriction:
; CHECK-ERROR-NEXT: SPIR-V version was restricted to at most 1.3 (66304) but a construct from the input requires SPIR-V version 1.4 (66560) or above

; SPIR-V 1.4
; CHECK-SPIRV: 66560
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypePointer [[#TypePointer:]] [[#]] [[#TypeFloat]]
; CHECK-SPIRV: TypeBool [[#TypeBool:]]

; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var1:]]
; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var2:]]
; CHECK-SPIRV: PtrEqual [[#TypeBool]] [[#]] [[#Var1]] [[#Var2]]
; CHECK-SPIRV: PtrNotEqual [[#TypeBool]] [[#]] [[#Var1]] [[#Var2]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

; Function Attrs: nounwind
define spir_kernel void @test(float %a, float %b) local_unnamed_addr #0 {
entry:
%0 = alloca float, align 4
%1 = alloca float, align 4
store float %a, float* %0, align 4
store float %b, float* %1, align 4
; CHECK-LLVM: %[[#Arg1:]] = ptrtoint float* %[[#]] to i64
; CHECK-LLVM: %[[#Arg2:]] = ptrtoint float* %[[#]] to i64
; CHECK-LLVM: icmp eq i64 %[[#Arg1]], %[[#Arg2]]
%2 = call spir_func noundef i1 @_Z16__spirv_PtrEqual(float* %0, float* %1)
; CHECK-LLVM: %[[#Arg3:]] = ptrtoint float* %[[#]] to i64
; CHECK-LLVM: %[[#Arg4:]] = ptrtoint float* %[[#]] to i64
; CHECK-LLVM: icmp ne i64 %[[#Arg3]], %[[#Arg4]]
%3 = call spir_func noundef i1 @_Z19__spirv_PtrNotEqual(float* %0, float* %1)
ret void
}

declare spir_func noundef i1 @_Z16__spirv_PtrEqual(float*, float*)
declare spir_func noundef i1 @_Z19__spirv_PtrNotEqual(float*, float*)

attributes #0 = { convergent nounwind writeonly }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}

0 comments on commit 8a6bccd

Please sign in to comment.