Skip to content

Commit

Permalink
Merge pull request #2429 from KhronosGroup/fix-2411
Browse files Browse the repository at this point in the history
HLSL: Fix lowering of arrayed clip/cull distance in mesh shaders.
  • Loading branch information
HansKristian-Work authored Dec 12, 2024
2 parents d2478b2 + de515fe commit ebe2aa0
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT

groupshared float shared_float[16];

void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
{
SetMeshOutputCounts(24u, 22u);
float3 _173 = float3(gl_GlobalInvocationID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT

groupshared float shared_float[16];

void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
{
SetMeshOutputCounts(24u, 22u);
float3 _29 = float3(gl_GlobalInvocationID);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
static uint gl_LocalInvocationIndex;
struct SPIRV_Cross_Input
{
uint gl_LocalInvocationIndex : SV_GroupIndex;
};

struct gl_MeshPerVertexEXT
{
float4 gl_ClipDistance : SV_ClipDistance;
};

struct gl_MeshPerPrimitiveEXT
{
};

void write_clip_distance(inout float v[4])
{
v[0] += 1.0f;
v[1] += 2.0f;
v[2] += 3.0f;
v[3] += 4.0f;
}

void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[3])
{
SetMeshOutputCounts(3u, 1u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0f;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1] = 4.0f;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2] = 4.0f;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] = 4.0f;
float _62[4] = { gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] };
float param[4] = _62;
write_clip_distance(param);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance = float4(param[0], param[1], param[2], param[3]);
}

[outputtopology("triangle")]
[numthreads(1, 1, 1)]
void main(SPIRV_Cross_Input stage_input, out vertices gl_MeshPerVertexEXT gl_MeshVerticesEXT[3])
{
gl_LocalInvocationIndex = stage_input.gl_LocalInvocationIndex;
mesh_main(gl_MeshVerticesEXT);
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ float spvFlipVertY(float v)
return -v;
}

void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
{
SetMeshOutputCounts(24u, 22u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = spvFlipVertY(float4(float3(gl_GlobalInvocationID), 1.0f));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT

groupshared float shared_float[16];

void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22])
void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22])
{
gl_PrimitiveLineIndicesEXT[gl_LocalInvocationIndex] = uint2(0u, 1u) + gl_LocalInvocationIndex.xx;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = int(gl_GlobalInvocationID.x);
Expand All @@ -61,7 +61,7 @@ void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], inout gl_MeshPerPrimitive
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveShadingRateEXT = int(gl_GlobalInvocationID.x) + 3;
}

void main2(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
void main2(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
{
SetMeshOutputCounts(24u, 22u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0f);
Expand All @@ -81,7 +81,7 @@ void main2(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPri
}
}

void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
{
main2(gl_MeshVerticesEXT, gl_MeshPrimitivesEXT, _payload, gl_PrimitiveLineIndicesEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT

groupshared float shared_float[16];

void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
{
SetMeshOutputCounts(24u, 22u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0f);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(triangles, max_vertices = 3, max_primitives = 1) out;

out gl_MeshPerVertexEXT
{
float gl_ClipDistance[4];
} gl_MeshVerticesEXT[];

void write_clip_distance(inout float v[4])
{
v[0] += 1.0;
v[1] += 2.0;
v[2] += 3.0;
v[3] += 4.0;
}

void main()
{
SetMeshOutputsEXT(3, 1);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1] = 4.0;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2] = 4.0;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] = 4.0;
write_clip_distance(gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance);
}
2 changes: 2 additions & 0 deletions spirv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,8 @@ struct AccessChainMeta
bool flattened_struct = false;
bool relaxed_precision = false;
bool access_meshlet_position_y = false;
bool chain_is_builtin = false;
spv::BuiltIn builtin = {};
};

enum ExtendedDecorations
Expand Down
23 changes: 22 additions & 1 deletion spirv_glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10210,6 +10210,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
bool pending_array_enclose = false;
bool dimension_flatten = false;
bool access_meshlet_position_y = false;
bool chain_is_builtin = false;
spv::BuiltIn chained_builtin = {};

if (auto *base_expr = maybe_get<SPIRExpression>(base))
{
Expand Down Expand Up @@ -10367,6 +10369,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
auto builtin = ir.meta[base].decoration.builtin_type;
bool mesh_shader = get_execution_model() == ExecutionModelMeshEXT;

chain_is_builtin = true;
chained_builtin = builtin;

switch (builtin)
{
case BuiltInCullDistance:
Expand Down Expand Up @@ -10502,6 +10507,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
{
access_meshlet_position_y = true;
}

chain_is_builtin = true;
chained_builtin = builtin;
}
else
{
Expand Down Expand Up @@ -10721,6 +10729,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
meta->storage_physical_type = physical_type;
meta->relaxed_precision = relaxed_precision;
meta->access_meshlet_position_y = access_meshlet_position_y;
meta->chain_is_builtin = chain_is_builtin;
meta->builtin = chained_builtin;
}

return expr;
Expand Down Expand Up @@ -12336,6 +12346,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
flattened_structs[ops[1]] = true;
if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis)
set_decoration(ops[1], DecorationRelaxedPrecision);
if (meta.chain_is_builtin)
set_decoration(ops[1], DecorationBuiltIn, meta.builtin);

// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
// temporary which could be subject to invalidation.
Expand Down Expand Up @@ -15679,7 +15691,16 @@ string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg)

if (type.pointer)
{
if (arg.write_count && arg.read_count)
// If we're passing around block types to function, we really mean reference in a pointer sense,
// but DXC does not like inout for mesh blocks, so workaround that. out is technically not correct,
// but it works in practice due to legalization. It's ... not great, but you gotta do what you gotta do.
// GLSL will never hit this case since it's not valid.
if (type.storage == StorageClassOutput && get_execution_model() == ExecutionModelMeshEXT &&
has_decoration(type.self, DecorationBlock) && is_builtin_type(type) && arg.write_count)
{
direction = "out ";
}
else if (arg.write_count && arg.read_count)
direction = "inout ";
else if (arg.write_count)
direction = "out ";
Expand Down
80 changes: 74 additions & 6 deletions spirv_hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4775,13 +4775,13 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
{
auto ops = stream(instruction);

auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t ptr = ops[2];

auto *chain = maybe_get<SPIRAccessChain>(ptr);
if (chain)
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t ptr = ops[2];

auto &type = get<SPIRType>(result_type);
bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;

Expand Down Expand Up @@ -4819,7 +4819,35 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
}
}
else
CompilerGLSL::emit_instruction(instruction);
{
// Very special case where we cannot rely on IO lowering.
// Mesh shader clip/cull arrays ... Cursed.
auto &res_type = get<SPIRType>(result_type);
if (get_execution_model() == ExecutionModelMeshEXT &&
has_decoration(ptr, DecorationBuiltIn) &&
(get_decoration(ptr, DecorationBuiltIn) == BuiltInClipDistance ||
get_decoration(ptr, DecorationBuiltIn) == BuiltInCullDistance) &&
is_array(res_type) && !is_array(get<SPIRType>(res_type.parent_type)))
{
track_expression_read(ptr);
string load_expr = "{ ";
uint32_t num_elements = to_array_size_literal(res_type);
for (uint32_t i = 0; i < num_elements; i++)
{
load_expr += join(to_expression(ptr), "[", i, "]");
if (i + 1 < num_elements)
load_expr += ", ";
}
load_expr += " }";
emit_op(result_type, id, load_expr, false);
register_read(id, ptr, false);
inherit_expression_dependencies(id, ptr);
}
else
{
CompilerGLSL::emit_instruction(instruction);
}
}
}

void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
Expand Down Expand Up @@ -6903,3 +6931,43 @@ bool CompilerHLSL::is_user_type_structured(uint32_t id) const
}
return false;
}

void CompilerHLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
// Loading a full array of ClipDistance needs special consideration in mesh shaders
// since we cannot lower them by wrapping the variables in global statics.
// Fortunately, clip/cull is a proper vector in HLSL so we can lower with simple rvalue casts.
if (get_execution_model() != ExecutionModelMeshEXT ||
!has_decoration(target_id, DecorationBuiltIn) ||
!is_array(expr_type))
{
CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
return;
}

auto builtin = BuiltIn(get_decoration(target_id, DecorationBuiltIn));
if (builtin != BuiltInClipDistance && builtin != BuiltInCullDistance)
{
CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
return;
}

// Array of array means one thread is storing clip distance for all vertices. Nonsensical?
if (is_array(get<SPIRType>(expr_type.parent_type)))
SPIRV_CROSS_THROW("Attempting to store all mesh vertices in one go. This is not supported.");

uint32_t num_clip = to_array_size_literal(expr_type);
if (num_clip > 4)
SPIRV_CROSS_THROW("Number of clip or cull distances exceeds 4, this will not work with mesh shaders.");

auto unrolled_expr = join("float", num_clip, "(");
for (uint32_t i = 0; i < num_clip; i++)
{
unrolled_expr += join(expr, "[", i, "]");
if (i + 1 < num_clip)
unrolled_expr += ", ";
}

unrolled_expr += ")";
expr = std::move(unrolled_expr);
}
2 changes: 2 additions & 0 deletions spirv_hlsl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ class CompilerHLSL : public CompilerGLSL
std::vector<TypeID> composite_selection_workaround_types;

std::string get_inner_entry_point_name() const;

void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
};
} // namespace SPIRV_CROSS_NAMESPACE

Expand Down

0 comments on commit ebe2aa0

Please sign in to comment.