Skip to content

Commit

Permalink
Merge pull request #2418 from Try/msl-mesh-shader-patch1
Browse files Browse the repository at this point in the history
Fixes for Metal mesh-shader
  • Loading branch information
HansKristian-Work authored Dec 6, 2024
2 parents 9040e0d + f7783b5 commit 42c2136
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void _4(threadgroup spvUnsafeArray<uint2, 22>& gl_PrimitiveLineIndicesEXT, threa
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
Expand All @@ -177,6 +177,7 @@ void _4(threadgroup spvUnsafeArray<uint2, 22>& gl_PrimitiveLineIndicesEXT, threa
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.vOut = vOut[spvVI];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT,
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
Expand All @@ -177,6 +177,7 @@ void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT,
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.vOut = vOut[spvVI];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}

object_data T& operator [] (size_t pos) object_data
{
return elements[pos];
}
constexpr const object_data T& operator [] (size_t pos) const object_data
{
return elements[pos];
}
};

void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
{
if (gl_LocalInvocationIndex == 0)
{
spvMeshSizes.x = vertexCount;
spvMeshSizes.y = primitiveCount;
}
}

struct gl_MeshPerVertexEXT
{
float4 gl_Position [[position]];
float gl_ClipDistance [[clip_distance]] [1];
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u);

struct spvPerVertex
{
float4 gl_Position [[position]];
float gl_ClipDistance [[clip_distance]] [1];
float gl_ClipDistance_0 [[user(clip0)]];
float fOut [[user(locn0)]];
uint uiOut [[user(locn1)]];
};

using spvMesh_t = mesh<spvPerVertex, void, 24, 22, topology::triangle>;

static inline __attribute__((always_inline))
void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray<float, 24>& fOut, threadgroup spvUnsafeArray<uint, 24>& uiOut, threadgroup uint2& spvMeshSizes)
{
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
fOut[gl_LocalInvocationIndex] = float(gl_GlobalInvocationID.x);
uiOut[gl_LocalInvocationIndex] = gl_GlobalInvocationID.y;
}

[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], spvMesh_t spvMesh)
{
threadgroup uint2 spvMeshSizes;
threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24> gl_MeshVerticesEXT;
threadgroup spvUnsafeArray<float, 24> fOut;
threadgroup spvUnsafeArray<uint, 24> uiOut;
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
_4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, fOut, uiOut, spvMeshSizes);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (spvMeshSizes.y == 0)
{
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
spvPerVertex spvV = {};
spvV.gl_Position = gl_MeshVerticesEXT[spvVI].gl_Position;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.fOut = fOut[spvVI];
spvV.uiOut = uiOut[spvVI];
spvMesh.set_vertex(spvVI, spvV);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void _4(threadgroup spvUnsafeArray<uint2, 22>& gl_PrimitiveLineIndicesEXT, threa
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
Expand All @@ -186,6 +186,7 @@ void _4(threadgroup spvUnsafeArray<uint2, 22>& gl_PrimitiveLineIndicesEXT, threa
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.vOut = vOut[spvVI];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT,
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
Expand All @@ -174,6 +174,7 @@ void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT,
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.vOut = vOut[spvVI];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}

object_data T& operator [] (size_t pos) object_data
{
return elements[pos];
}
constexpr const object_data T& operator [] (size_t pos) const object_data
{
return elements[pos];
}
};

void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
{
if (gl_LocalInvocationIndex == 0)
{
spvMeshSizes.x = vertexCount;
spvMeshSizes.y = primitiveCount;
}
}

struct gl_MeshPerVertexEXT
{
float4 gl_Position [[position]];
float gl_ClipDistance [[clip_distance]] [1];
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u);

struct spvPerVertex
{
float4 gl_Position [[position]];
float gl_ClipDistance [[clip_distance]] [1];
float gl_ClipDistance_0 [[user(clip0)]];
float fOut [[user(locn0)]];
uint uiOut [[user(locn1)]];
};

using spvMesh_t = mesh<spvPerVertex, void, 24, 22, topology::triangle>;

static inline __attribute__((always_inline))
void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray<float, 24>& fOut, threadgroup spvUnsafeArray<uint, 24>& uiOut, threadgroup uint2& spvMeshSizes)
{
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
fOut[gl_LocalInvocationIndex] = float(gl_GlobalInvocationID.x);
uiOut[gl_LocalInvocationIndex] = gl_GlobalInvocationID.y;
}

[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], spvMesh_t spvMesh)
{
threadgroup uint2 spvMeshSizes;
threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24> gl_MeshVerticesEXT;
threadgroup spvUnsafeArray<float, 24> fOut;
threadgroup spvUnsafeArray<uint, 24> uiOut;
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
_4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, fOut, uiOut, spvMeshSizes);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (spvMeshSizes.y == 0)
{
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
spvPerVertex spvV = {};
spvV.gl_Position = gl_MeshVerticesEXT[spvVI].gl_Position;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.fOut = fOut[spvVI];
spvV.uiOut = uiOut[spvVI];
spvMesh.set_vertex(spvVI, spvV);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#version 450
#extension GL_EXT_mesh_shader : require
#extension GL_EXT_fragment_shading_rate : require
layout(local_size_x = 2, local_size_y = 3, local_size_z = 4) in;
layout(triangles, max_vertices = 24, max_primitives = 22) out;

out gl_MeshPerVertexEXT
{
vec4 gl_Position;
float gl_ClipDistance[1];
} gl_MeshVerticesEXT[];

layout(location = 0) out float fOut[];
layout(location = 1) out flat uint uiOut[];

void main()
{
SetMeshOutputsEXT(24, 22);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = vec4(gl_GlobalInvocationID, 1.0);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
fOut[gl_LocalInvocationIndex] = float(gl_GlobalInvocationID.x);
uiOut[gl_LocalInvocationIndex] = gl_GlobalInvocationID.y;
}
4 changes: 2 additions & 2 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19215,7 +19215,7 @@ void CompilerMSL::emit_mesh_outputs()
end_scope();
statement("spvMesh.set_primitive_count(spvMeshSizes.y);");

statement("const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);");
statement("const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);");

if (mesh_out_per_vertex != 0)
{
Expand All @@ -19240,7 +19240,7 @@ void CompilerMSL::emit_mesh_outputs()
uint32_t orig_id = get_extended_member_decoration(type_vert.self, index, SPIRVCrossDecorationInterfaceMemberIndex);

// Clip/cull distances are special-case
if (orig_id == (~0u))
if (orig_var == 0 && orig_id == (~0u))
continue;

auto &orig = get<SPIRVariable>(orig_var);
Expand Down

0 comments on commit 42c2136

Please sign in to comment.