Skip to content

Commit

Permalink
Translate specialization constants from SPIRV
Browse files Browse the repository at this point in the history
Add support for OpSpecConstantTrue, OpSpecConstantFalse, OpSpecConstant and
OpSpecConstantComposite instructions.

Add -spec-const command line option which allows to overwrite the default
values of specialization constants while translating SPIRV to LLVM.

Add getSpecConstInfo API. This API returns specialization sonstant IDs
available in the SPIR-V module and their size. This information also can be
printed out with -spec-const-info command line option.

Signed-off-by: Alexey Sotkin <[email protected]>
  • Loading branch information
AlexeySotkin committed Jan 30, 2020
1 parent 5e51c7e commit a4b2532
Show file tree
Hide file tree
Showing 11 changed files with 466 additions and 30 deletions.
4 changes: 4 additions & 0 deletions include/LLVMSPIRVLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ bool writeSpirv(Module *M, const SPIRV::TranslatorOpts &Opts, std::ostream &OS,
bool readSpirv(LLVMContext &C, const SPIRV::TranslatorOpts &Opts,
std::istream &IS, Module *&M, std::string &ErrMsg);

using SpecConstInfoTy = std::pair<uint32_t, uint32_t>;
void getSpecConstInfo(std::istream &IS,
std::vector<SpecConstInfoTy> &SpecConstInfo);

/// \brief Convert a SPIRVModule into LLVM IR.
/// \returns null on failure.
std::unique_ptr<Module>
Expand Down
14 changes: 14 additions & 0 deletions include/LLVMSPIRVOpts.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cassert>
#include <cstdint>
#include <map>
#include <unordered_map>

namespace SPIRV {

Expand Down Expand Up @@ -99,12 +100,25 @@ class TranslatorOpts {

void enableGenArgNameMD() { GenKernelArgNameMD = true; }

void setSpecConst(uint32_t SpecId, uint64_t SpecValue) {
ExternalSpecialization[SpecId] = SpecValue;
}

bool getSpecializationConstant(uint32_t SpecId, uint64_t &Value) const {
auto It = ExternalSpecialization.find(SpecId);
if (It == ExternalSpecialization.end())
return false;
Value = It->second;
return true;
}

private:
// Common translation options
VersionNumber MaxVersion = VersionNumber::MaximumVersion;
ExtensionsStatusMap ExtStatusMap;
// SPIR-V to LLVM translation options
bool GenKernelArgNameMD;
std::unordered_map<uint32_t, uint64_t> ExternalSpecialization;
};

} // namespace SPIRV
Expand Down
87 changes: 79 additions & 8 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,15 +1191,28 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,

// Translation of non-instruction values
switch (OC) {
case OpConstant: {
case OpConstant:
case OpSpecConstant: {
SPIRVConstant *BConst = static_cast<SPIRVConstant *>(BV);
SPIRVType *BT = BV->getType();
Type *LT = transType(BT);
uint64_t ConstValue = BConst->getZExtIntValue();
SPIRVWord SpecId = 0;
if (OC == OpSpecConstant && BV->hasDecorate(DecorationSpecId, 0, &SpecId)) {
// Update the value with possibly provided external specialization.
if (BM->getSpecializationConstant(SpecId, ConstValue)) {
assert(
(BT->getBitWidth() == 64 ||
(ConstValue >> BT->getBitWidth()) == 0) &&
"Size of externally provided specialization constant value doesn't"
"fit into the specialization constant type");
}
}
switch (BT->getOpCode()) {
case OpTypeBool:
case OpTypeInt:
return mapValue(
BV, ConstantInt::get(LT, BConst->getZExtIntValue(),
BV, ConstantInt::get(LT, ConstValue,
static_cast<SPIRVTypeInt *>(BT)->isSigned()));
case OpTypeFloat: {
const llvm::fltSemantics *FS = nullptr;
Expand All @@ -1214,12 +1227,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
FS = &APFloat::IEEEdouble();
break;
default:
llvm_unreachable("invalid float type");
llvm_unreachable("invalid floating-point type");
}
return mapValue(
BV, ConstantFP::get(*Context,
APFloat(*FS, APInt(BT->getFloatBitWidth(),
BConst->getZExtIntValue()))));
APFloat FPConstValue(*FS, APInt(BT->getFloatBitWidth(), ConstValue));
return mapValue(BV, ConstantFP::get(*Context, FPConstValue));
}
default:
llvm_unreachable("Not implemented");
Expand All @@ -1233,12 +1244,27 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpConstantFalse:
return mapValue(BV, ConstantInt::getFalse(*Context));

case OpSpecConstantTrue:
case OpSpecConstantFalse: {
bool IsTrue = OC == OpSpecConstantTrue;
SPIRVWord SpecId = 0;
if (BV->hasDecorate(DecorationSpecId, 0, &SpecId)) {
uint64_t ConstValue = 0;
if (BM->getSpecializationConstant(SpecId, ConstValue)) {
IsTrue = ConstValue;
}
}
return mapValue(BV, IsTrue ? ConstantInt::getTrue(*Context)
: ConstantInt::getFalse(*Context));
}

case OpConstantNull: {
auto LT = transType(BV->getType());
return mapValue(BV, Constant::getNullValue(LT));
}

case OpConstantComposite: {
case OpConstantComposite:
case OpSpecConstantComposite: {
auto BCC = static_cast<SPIRVConstantComposite *>(BV);
std::vector<Constant *> CV;
for (auto &I : BCC->getElements())
Expand Down Expand Up @@ -3196,3 +3222,48 @@ bool llvm::readSpirv(LLVMContext &C, const SPIRV::TranslatorOpts &Opts,

return true;
}

void llvm::getSpecConstInfo(std::istream &IS,
std::vector<SpecConstInfoTy> &SpecConstInfo) {
std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule());
BM->setAutoAddExtensions(false);
SPIRVDecoder D(IS, *BM);
SPIRVWord Magic;
D >> Magic;
if (!BM->getErrorLog().checkError(Magic == MagicNumber, SPIRVEC_InvalidModule,
"invalid magic number")) {
return;
}
// Skip the rest of the header
D.ignore(4);

// According to the logical layout of SPIRV module (p2.4 of the spec),
// all constant instructions must appear before function declarations.
while (D.OpCode != OpFunction && D.getWordCountAndOpCode()) {
switch (D.OpCode) {
case OpDecorate:
// The decoration is added to the module in scope of SPIRVDecorate::decode
D.getEntry();
break;
case OpTypeBool:
case OpTypeInt:
case OpTypeFloat:
BM->addEntry(D.getEntry());
break;
case OpSpecConstant:
case OpSpecConstantTrue:
case OpSpecConstantFalse: {
auto *C = BM->addConstant(static_cast<SPIRVValue *>(D.getEntry()));
SPIRVWord SpecConstIdLiteral = 0;
if (C->hasDecorate(DecorationSpecId, 0, &SpecConstIdLiteral)) {
SPIRVType *Ty = C->getType();
uint32_t SpecConstSize = Ty->isTypeBool() ? 1 : Ty->getBitWidth() / 8;
SpecConstInfo.emplace_back(SpecConstIdLiteral, SpecConstSize);
}
break;
}
default:
D.ignoreInstruction();
}
}
}
4 changes: 0 additions & 4 deletions lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,6 @@ _SPIRV_OP(Nop)
_SPIRV_OP(SourceContinued)
_SPIRV_OP(TypeMatrix)
_SPIRV_OP(TypeRuntimeArray)
_SPIRV_OP(SpecConstantTrue)
_SPIRV_OP(SpecConstantFalse)
_SPIRV_OP(SpecConstant)
_SPIRV_OP(SpecConstantComposite)
_SPIRV_OP(Image)
_SPIRV_OP(ImageTexelPointer)
_SPIRV_OP(ImageSampleDrefImplicitLod)
Expand Down
8 changes: 7 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@

namespace SPIRV {

template <Op> class SPIRVConstantBase;
using SPIRVConstant = SPIRVConstantBase<OpConstant>;

class SPIRVBasicBlock;
class SPIRVConstant;
class SPIRVEntry;
class SPIRVFunction;
class SPIRVInstruction;
Expand Down Expand Up @@ -431,6 +433,10 @@ class SPIRVModule {
return TranslationOpts.isGenArgNameMDEnabled();
}

bool getSpecializationConstant(SPIRVWord SpecId, uint64_t &ConstValue) {
return TranslationOpts.getSpecializationConstant(SpecId, ConstValue);
}

// I/O functions
friend spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M);
friend std::istream &operator>>(std::istream &I, SPIRVModule &M);
Expand Down
22 changes: 22 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include "SPIRVNameMapEnum.h"
#include "SPIRVOpCode.h"

#include <limits> // std::numeric_limits

namespace SPIRV {

/// Write string with quote. Replace " with \".
Expand Down Expand Up @@ -256,6 +258,12 @@ SPIRVEntry *SPIRVDecoder::getEntry() {
}
}

if (!M.getErrorLog().checkError(Entry->isImplemented(),
SPIRVEC_UnimplementedOpCode,
std::to_string(Entry->getOpCode()))) {
M.setInvalid();
}

assert(!IS.bad() && !IS.fail() && "SPIRV stream fails");
return Entry;
}
Expand All @@ -266,6 +274,20 @@ void SPIRVDecoder::validate() const {
assert(!IS.bad() && "Bad iInput stream");
}

// Skip \param n words in SPIR-V binary stream.
// In case of SPIR-V text format always skip until the end of the line.
void SPIRVDecoder::ignore(size_t N) {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat) {
IS.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
return;
}
#endif
IS.ignore(N * sizeof(SPIRVWord));
}

void SPIRVDecoder::ignoreInstruction() { ignore(WordCount - 1); }

spv_ostream &operator<<(spv_ostream &O, const SPIRVNL &E) {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat)
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class SPIRVDecoder {
bool getWordCountAndOpCode();
SPIRVEntry *getEntry();
void validate() const;
void ignore(size_t N);
void ignoreInstruction();

std::istream &IS;
SPIRVModule &M;
Expand Down
1 change: 0 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ class SPIRVTypeVector : public SPIRVType {
SPIRVWord CompCount; // Component Count
};

class SPIRVConstant;
class SPIRVTypeArray : public SPIRVType {
public:
// Complete constructor
Expand Down
39 changes: 24 additions & 15 deletions lib/SPIRV/libSPIRV/SPIRVValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,34 +136,34 @@ class SPIRVValue : public SPIRVEntry {
SPIRVType *Type; // Value Type
};

class SPIRVConstant : public SPIRVValue {
template <spv::Op OC> class SPIRVConstantBase : public SPIRVValue {
public:
// Complete constructor for integer constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
uint64_t TheValue)
: SPIRVValue(M, 0, OpConstant, TheType, TheId) {
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
uint64_t TheValue)
: SPIRVValue(M, 0, OC, TheType, TheId) {
Union.UInt64Val = TheValue;
recalculateWordCount();
validate();
}
// Complete constructor for float constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
float TheValue)
: SPIRVValue(M, 0, OpConstant, TheType, TheId) {
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
float TheValue)
: SPIRVValue(M, 0, OC, TheType, TheId) {
Union.FloatVal = TheValue;
recalculateWordCount();
validate();
}
// Complete constructor for double constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
double TheValue)
: SPIRVValue(M, 0, OpConstant, TheType, TheId) {
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
double TheValue)
: SPIRVValue(M, 0, OC, TheType, TheId) {
Union.DoubleVal = TheValue;
recalculateWordCount();
validate();
}
// Incomplete constructor
SPIRVConstant() : SPIRVValue(OpConstant), NumWords(0) {}
SPIRVConstantBase() : SPIRVValue(OC), NumWords(0) {}
uint64_t getZExtIntValue() const { return Union.UInt64Val; }
float getFloatValue() const { return Union.FloatVal; }
double getDoubleValue() const { return Union.DoubleVal; }
Expand Down Expand Up @@ -204,6 +204,9 @@ class SPIRVConstant : public SPIRVValue {
} Union;
};

using SPIRVConstant = SPIRVConstantBase<OpConstant>;
using SPIRVSpecConstant = SPIRVConstantBase<OpSpecConstant>;

template <Op OC> class SPIRVConstantEmpty : public SPIRVValue {
public:
// Complete constructor
Expand Down Expand Up @@ -236,6 +239,8 @@ template <Op OC> class SPIRVConstantBool : public SPIRVConstantEmpty<OC> {

typedef SPIRVConstantBool<OpConstantTrue> SPIRVConstantTrue;
typedef SPIRVConstantBool<OpConstantFalse> SPIRVConstantFalse;
typedef SPIRVConstantBool<OpSpecConstantTrue> SPIRVSpecConstantTrue;
typedef SPIRVConstantBool<OpSpecConstantFalse> SPIRVSpecConstantFalse;

class SPIRVConstantNull : public SPIRVConstantEmpty<OpConstantNull> {
public:
Expand Down Expand Up @@ -273,18 +278,18 @@ class SPIRVUndef : public SPIRVConstantEmpty<OpUndef> {
void validate() const override { SPIRVConstantEmpty::validate(); }
};

class SPIRVConstantComposite : public SPIRVValue {
template <spv::Op OC> class SPIRVConstantCompositeBase : public SPIRVValue {
public:
// Complete constructor for composite constant
SPIRVConstantComposite(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
const std::vector<SPIRVValue *> TheElements)
SPIRVConstantCompositeBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
const std::vector<SPIRVValue *> TheElements)
: SPIRVValue(M, TheElements.size() + 3, OpConstantComposite, TheType,
TheId) {
Elements = getIds(TheElements);
validate();
}
// Incomplete constructor
SPIRVConstantComposite() : SPIRVValue(OpConstantComposite) {}
SPIRVConstantCompositeBase() : SPIRVValue(OpConstantComposite) {}
std::vector<SPIRVValue *> getElements() const { return getValues(Elements); }
std::vector<SPIRVEntry *> getNonLiteralOperands() const override {
std::vector<SPIRVValue *> Elements = getElements();
Expand All @@ -306,6 +311,10 @@ class SPIRVConstantComposite : public SPIRVValue {
std::vector<SPIRVId> Elements;
};

using SPIRVConstantComposite = SPIRVConstantCompositeBase<OpConstantComposite>;
using SPIRVSpecConstantComposite =
SPIRVConstantCompositeBase<OpSpecConstantComposite>;

class SPIRVConstantSampler : public SPIRVValue {
public:
const static Op OC = OpConstantSampler;
Expand Down
Loading

0 comments on commit a4b2532

Please sign in to comment.