Skip to content

Commit

Permalink
Add SPV_AMDX_shader_enqueue version 2 support
Browse files Browse the repository at this point in the history
Co-authored-by: Dan Brown <[email protected]>
Co-authored-by: Maciej Jesionowski <[email protected]>
  • Loading branch information
Dan Brown and yavn committed Oct 4, 2024
1 parent 522dfea commit 6e411f5
Show file tree
Hide file tree
Showing 18 changed files with 153 additions and 13 deletions.
6 changes: 6 additions & 0 deletions source/name_mapper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -241,6 +243,10 @@ spv_result_t FriendlyNameMapper::ParseInstruction(
SaveName(result_id,
std::string("_runtimearr_") + NameForId(inst.words[2]));
break;
case spv::Op::OpTypeNodePayloadArrayAMDX:
SaveName(result_id,
std::string("_payloadarr_") + NameForId(inst.words[2]));
break;
case spv::Op::OpTypePointer:
SaveName(result_id, std::string("_ptr_") +
NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
Expand Down
9 changes: 7 additions & 2 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2015-2022 The Khronos Group Inc.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
// reserved.
// Modifications Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All
// rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -243,12 +243,14 @@ int32_t spvOpcodeIsConstant(const spv::Op opcode) {
case spv::Op::OpConstantSampler:
case spv::Op::OpConstantNull:
case spv::Op::OpConstantFunctionPointerINTEL:
case spv::Op::OpConstantStringAMDX:
case spv::Op::OpSpecConstantTrue:
case spv::Op::OpSpecConstantFalse:
case spv::Op::OpSpecConstant:
case spv::Op::OpSpecConstantComposite:
case spv::Op::OpSpecConstantCompositeReplicateEXT:
case spv::Op::OpSpecConstantOp:
case spv::Op::OpSpecConstantStringAMDX:
return true;
default:
return false;
Expand Down Expand Up @@ -296,6 +298,7 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) {
case spv::Op::OpFunctionParameter:
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
case spv::Op::OpAllocateNodePayloadsAMDX:
case spv::Op::OpSelect:
case spv::Op::OpPhi:
case spv::Op::OpFunctionCall:
Expand All @@ -322,6 +325,7 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) {
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
case spv::Op::OpRawAccessChainNV:
case spv::Op::OpAllocateNodePayloadsAMDX:
return true;
default:
return false;
Expand Down Expand Up @@ -360,6 +364,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeRayQueryKHR:
case spv::Op::OpTypeHitObjectNV:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeNodePayloadArrayAMDX:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,
Expand Down
4 changes: 4 additions & 0 deletions source/opt/fix_storage_class.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2019 Google LLC
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -99,6 +101,7 @@ bool FixStorageClass::PropagateStorageClass(Instruction* inst,
case spv::Op::OpCopyMemorySized:
case spv::Op::OpVariable:
case spv::Op::OpBitcast:
case spv::Op::OpAllocateNodePayloadsAMDX:
// Nothing to change for these opcode. The result type is the same
// regardless of the storage class of the operand.
return false;
Expand Down Expand Up @@ -319,6 +322,7 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
switch (type_inst->opcode()) {
case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeNodePayloadArrayAMDX:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeCooperativeMatrixKHR:
Expand Down
3 changes: 3 additions & 0 deletions source/opt/ir_context.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -539,6 +541,7 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) {
(uint32_t)spv::Op::OpTypeHitObjectNV,
(uint32_t)spv::Op::OpTypeArray,
(uint32_t)spv::Op::OpTypeRuntimeArray,
(uint32_t)spv::Op::OpTypeNodePayloadArrayAMDX,
(uint32_t)spv::Op::OpTypeStruct,
(uint32_t)spv::Op::OpTypeOpaque,
(uint32_t)spv::Op::OpTypePointer,
Expand Down
5 changes: 4 additions & 1 deletion source/opt/local_access_chain_convert_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -430,7 +432,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch",
"SPV_AMDX_shader_enqueue"});
}

bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
Expand Down
3 changes: 3 additions & 0 deletions source/opt/local_single_block_elim_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -238,6 +240,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_AMD_gcn_shader",
"SPV_KHR_shader_ballot",
"SPV_AMD_shader_ballot",
"SPV_AMDX_shader_enqueue",
"SPV_AMD_gpu_shader_half_float",
"SPV_KHR_shader_draw_parameters",
"SPV_KHR_subgroup_vote",
Expand Down
5 changes: 4 additions & 1 deletion source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -144,7 +146,8 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
"SPV_KHR_ray_tracing_position_fetch",
"SPV_AMDX_shader_enqueue"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
std::vector<Instruction*> users;
Expand Down
5 changes: 4 additions & 1 deletion source/opt/scalar_replacement_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -671,7 +673,8 @@ bool ScalarReplacementPass::CheckTypeAnnotations(
for (auto inst :
get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
uint32_t decoration;
if (inst->opcode() == spv::Op::OpDecorate) {
if (inst->opcode() == spv::Op::OpDecorate ||
inst->opcode() == spv::Op::OpDecorateId) {
decoration = inst->GetSingleWordInOperand(1u);
} else {
assert(inst->opcode() == spv::Op::OpMemberDecorate);
Expand Down
13 changes: 12 additions & 1 deletion source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -803,6 +805,14 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
return type;
}
break;
case spv::Op::OpTypeNodePayloadArrayAMDX:
type = new NodePayloadArray(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
case spv::Op::OpTypeStruct: {
std::vector<const Type*> element_types;
bool incomplete_type = false;
Expand Down Expand Up @@ -940,7 +950,8 @@ void TypeManager::AttachDecoration(const Instruction& inst, Type* type) {
if (!IsAnnotationInst(opcode)) return;

switch (opcode) {
case spv::Op::OpDecorate: {
case spv::Op::OpDecorate:
case spv::Op::OpDecorateId: {
const auto count = inst.NumOperands();
std::vector<uint32_t> data;
for (uint32_t i = 1; i < count; ++i) {
Expand Down
30 changes: 30 additions & 0 deletions source/opt/types.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,6 +92,7 @@ bool Type::IsUniqueType() const {
case kStruct:
case kArray:
case kRuntimeArray:
case kNodePayloadArray:
return false;
default:
return true;
Expand Down Expand Up @@ -218,6 +221,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(SampledImage);
DeclareKindCase(Array);
DeclareKindCase(RuntimeArray);
DeclareKindCase(NodePayloadArray);
DeclareKindCase(Struct);
DeclareKindCase(Opaque);
DeclareKindCase(Pointer);
Expand Down Expand Up @@ -485,6 +489,32 @@ void RuntimeArray::ReplaceElementType(const Type* type) {
element_type_ = type;
}

NodePayloadArray::NodePayloadArray(const Type* type)
: Type(kNodePayloadArray), element_type_(type) {
assert(!type->AsVoid());
}

bool NodePayloadArray::IsSameImpl(const Type* that, IsSameCache* seen) const {
const NodePayloadArray* rat = that->AsNodePayloadArray();
if (!rat) return false;
return element_type_->IsSameImpl(rat->element_type_, seen) &&
HasSameDecorations(that);
}

std::string NodePayloadArray::str() const {
std::ostringstream oss;
oss << "[" << element_type_->str() << "]";
return oss.str();
}

size_t NodePayloadArray::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
return element_type_->ComputeHashValue(hash, seen);
}

void NodePayloadArray::ReplaceElementType(const Type* type) {
element_type_ = type;
}

Struct::Struct(const std::vector<const Type*>& types)
: Type(kStruct), element_types_(types) {
for (const auto* t : types) {
Expand Down
26 changes: 26 additions & 0 deletions source/opt/types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,6 +48,7 @@ class Sampler;
class SampledImage;
class Array;
class RuntimeArray;
class NodePayloadArray;
class Struct;
class Opaque;
class Pointer;
Expand Down Expand Up @@ -87,6 +90,7 @@ class Type {
kSampledImage,
kArray,
kRuntimeArray,
kNodePayloadArray,
kStruct,
kOpaque,
kPointer,
Expand Down Expand Up @@ -189,6 +193,7 @@ class Type {
DeclareCastMethod(SampledImage)
DeclareCastMethod(Array)
DeclareCastMethod(RuntimeArray)
DeclareCastMethod(NodePayloadArray)
DeclareCastMethod(Struct)
DeclareCastMethod(Opaque)
DeclareCastMethod(Pointer)
Expand Down Expand Up @@ -434,6 +439,27 @@ class RuntimeArray : public Type {
const Type* element_type_;
};

class NodePayloadArray : public Type {
public:
NodePayloadArray(const Type* element_type);
NodePayloadArray(const NodePayloadArray&) = default;

std::string str() const override;
const Type* element_type() const { return element_type_; }

NodePayloadArray* AsNodePayloadArray() override { return this; }
const NodePayloadArray* AsNodePayloadArray() const override { return this; }

size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;

void ReplaceElementType(const Type* element_type);

private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;

const Type* element_type_;
};

class Struct : public Type {
public:
Struct(const std::vector<const Type*>& element_types);
Expand Down
7 changes: 7 additions & 0 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2018 Google LLC.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +32,11 @@ bool DecorationTakesIdParameters(spv::Decoration type) {
case spv::Decoration::AlignmentId:
case spv::Decoration::MaxByteOffsetId:
case spv::Decoration::HlslCounterBufferGOOGLE:
case spv::Decoration::NodeMaxPayloadsAMDX:
case spv::Decoration::NodeSharesPayloadLimitsWithAMDX:
case spv::Decoration::PayloadNodeArraySizeAMDX:
case spv::Decoration::PayloadNodeNameAMDX:
case spv::Decoration::PayloadNodeBaseIndexAMDX:
return true;
default:
break;
Expand Down
5 changes: 4 additions & 1 deletion source/val/validate_composites.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -104,7 +106,8 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
}
break;
}
case spv::Op::OpTypeRuntimeArray: {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeNodePayloadArrayAMDX: {
*member_type = type_inst->word(2);
// Array size is unknown.
break;
Expand Down
3 changes: 3 additions & 0 deletions source/val/validate_function.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2018 Google LLC.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -302,6 +304,7 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _,
case spv::StorageClass::Private:
case spv::StorageClass::Workgroup:
case spv::StorageClass::AtomicCounter:
case spv::StorageClass::NodePayloadAMDX:
// These are always allowed.
break;
case spv::StorageClass::StorageBuffer:
Expand Down
Loading

0 comments on commit 6e411f5

Please sign in to comment.