How mesh shaders are implemented in an AMD driver
In the previous post I gave a brief introduction to what mesh and task shaders are from the perspective of application developers. Now it’s time to dive deeper and talk about how mesh shaders are implemented in a Vulkan driver on AMD HW. Note that everything I discuss here is based on my experience and understanding from working on mesh shader support in RADV, and is already public information in the open source driver code. The goal of this blog post is to elaborate on how mesh shaders are implemented on the NGG hardware in AMD RDNA2 GPUs, and to show how these details affect shader performance. Hopefully, this helps the reader better understand how the concepts in the API are translated to the HW and what pitfalls to avoid to get good perf.
Short intro to NGG
NGG (Next Generation Geometry) is the technology that is responsible for any vertex and geometry processing in RDNA GPUs (with some caveats). Also known as “primitive shader”, the main innovations of NGG are:
- Shaders are aware not only of vertices, but also of primitives (this is why they are called primitive shaders).
- The output topology is entirely up to the shader, meaning that it can create output vertices and primitives with an arbitrary topology regardless of its input.
- On RDNA2 and newer, per-primitive output attributes are also supported.
This flexibility allows the driver to implement every vertex/geometry processing stage using NGG. Vertex, tess eval and geometry shaders can all be compiled to NGG “primitive shaders”. The only major limiting factor is that each thread (SIMD lane) can only output up to 1 vertex and 1 primitive (with caveats).
The driver is also capable of extending the application shaders with sweet stuff such as per-triangle culling, but this is not the main focus of this blog post. I also won’t cover the caveats here, but I may write more about NGG in the future.
Mapping the mesh shader API to NGG
The draw commands as executed on the GPU only understand a number of input vertices, but the mesh shader API draw calls specify a number of workgroups instead. To make this work, we configure the shader such that the number of input vertices per workgroup is 1, and the output is set to what you passed into the API. This way, the FW (firmware) can figure out how many workgroups it really needs to launch.
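A tiny model helps to see why “1 input vertex per workgroup” is the key trick. The sketch below is a hypothetical illustration, not RADV code: `ms_draw_vertex_count` and `fw_workgroups_to_launch` are made-up names, and the round-up behaviour of the firmware is an assumption made only for the sake of the example.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical driver side: the draw carries one "input vertex" per
 * API workgroup, so its vertex count is simply the workgroup count. */
static uint32_t ms_draw_vertex_count(uint32_t api_workgroup_count)
{
    return api_workgroup_count;
}

/* Hypothetical firmware side (assumed behaviour): group the draw's input
 * vertices into workgroups, rounding up so a partial group still launches. */
static uint32_t fw_workgroups_to_launch(uint32_t draw_vertex_count,
                                        uint32_t input_vertices_per_workgroup)
{
    assert(input_vertices_per_workgroup > 0);
    return (draw_vertex_count + input_vertices_per_workgroup - 1) /
           input_vertices_per_workgroup;
}
```

With 1 input vertex per workgroup, `fw_workgroups_to_launch(ms_draw_vertex_count(N), 1)` is exactly N, so the firmware launches as many workgroups as the application asked for. With any other value the two counts would diverge (10 API workgroups at 3 vertices per group would launch only 4).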
The driver has to accommodate the HW limitation above, so we must ensure that in the compiled shader, each thread only outputs up to 1 vertex and 1 primitive. Reminder: the API programming model allows any shader invocation to write any vertex and/or primitive. So, there is a fundamental mismatch between the programming model and what the HW can do.
This raises a few interesting questions.
How do we allow any thread to write any vertex/primitive? The driver allocates some LDS (shared memory) space, and writes all mesh shader outputs there. At the very end of the shader, each thread reads the attributes of the vertex and primitive that matches the thread ID and outputs that single vertex and primitive. This roundtrip to the LDS can be omitted if an output is only written by the thread with matching thread ID.
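The LDS roundtrip can be pictured as a plain array shared by all invocations. This is a CPU-side sketch with hypothetical names (`lds_outputs`, `api_store_vertex`, `hw_load_own_vertex`, `lds_roundtrip_avoidable`) and an assumed bound of 256 output vertices; it models a single position attribute and leaves out the barrier that separates the stores from the final loads. The real driver decides whether the roundtrip can be skipped at compile time by analysing the shader’s output stores, whereas here a recorded list of stores is checked at runtime purely for illustration.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define MS_MAX_OUTPUTS 256 /* assumed bound on output vertices for this sketch */

/* Hypothetical per-vertex output: just a position. */
typedef struct {
    float pos[4];
} vertex_out;

/* Stand-in for the LDS area the driver reserves for mesh shader outputs. */
typedef struct {
    vertex_out vtx[MS_MAX_OUTPUTS];
} lds_outputs;

/* API shader part: any invocation may store any output vertex index. */
static void api_store_vertex(lds_outputs *lds, uint32_t index, vertex_out v)
{
    assert(index < MS_MAX_OUTPUTS);
    lds->vtx[index] = v;
}

/* Epilogue at the very end of the shader: thread `tid` loads the vertex
 * whose index equals its own ID, which it would then export itself. */
static vertex_out hw_load_own_vertex(const lds_outputs *lds, uint32_t tid)
{
    assert(tid < MS_MAX_OUTPUTS);
    return lds->vtx[tid];
}

/* One output store: which invocation wrote which output index. */
typedef struct {
    uint32_t writer_tid;
    uint32_t index;
} output_store;

/* If every store targets the writer's own index, the value is already in
 * the right thread, so the LDS roundtrip can be omitted. */
static bool lds_roundtrip_avoidable(const output_store *stores, size_t count)
{
    for (size_t i = 0; i < count; i++) {
        if (stores[i].writer_tid != stores[i].index)
            return false;
    }
    return true;
}
```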
What if the MS workgroup size is less than the max number of output vertices or primitives?
Each HW thread can create up to 1 vertex and 1 primitive, so the driver has to set the real workgroup size accordingly:
hw workgroup size = max(api workgroup size, max vertex count, max primitive count)
The result is that the HW will get a workgroup in which some threads (invocations) execute the code you wrote (the “API shader”), while the rest don’t do anything but wait until the very end to output their up to 1 vertex and 1 primitive. This can result in poor occupancy (low HW utilization = bad performance).
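The workgroup size formula above translates directly into code. The helper names below are hypothetical; the sketch only shows the arithmetic, including how many threads end up idle while the API shader runs.

```c
#include <assert.h>
#include <stdint.h>

static uint32_t max_u32(uint32_t a, uint32_t b)
{
    return a > b ? a : b;
}

/* hw workgroup size = max(api workgroup size, max vertex count, max primitive count) */
static uint32_t ms_hw_workgroup_size(uint32_t api_wg_size,
                                     uint32_t max_vertices,
                                     uint32_t max_primitives)
{
    assert(api_wg_size > 0);
    return max_u32(api_wg_size, max_u32(max_vertices, max_primitives));
}

/* Threads that skip the API shader and only wait to export at the end. */
static uint32_t ms_idle_threads(uint32_t api_wg_size,
                                uint32_t max_vertices,
                                uint32_t max_primitives)
{
    return ms_hw_workgroup_size(api_wg_size, max_vertices, max_primitives) -
           api_wg_size;
}
```

For example, a shader declaring a workgroup size of 32 with up to 64 vertices and 126 primitives gets a 126-thread HW workgroup, in which 94 threads do nothing until the final export. Declaring a workgroup size of at least 126 removes those idle threads.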
What if the shader also has barriers in it? This is now turning into a headache. The driver has to ensure that the threads that “do nothing” also execute the same number of barriers as those that run your API shader. If the HW workgroup has the same number of waves as the API workgroup, this is trivial. Otherwise, we have to emit some extra code that keeps the extra waves running in a loop executing barriers. This is the worst.
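Whether the barrier loop is needed depends on wave counts rather than thread counts. RDNA GPUs can run waves of either 32 or 64 threads, so the same pair of workgroup sizes may or may not need extra waves. The sketch below (hypothetical helper names) only counts waves; it does not model the barrier loop itself.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

static uint32_t wave_count(uint32_t threads, uint32_t wave_size)
{
    assert(wave_size == 32 || wave_size == 64); /* Wave32 or Wave64 */
    return (threads + wave_size - 1) / wave_size;
}

/* Waves that don't run the API shader but must still match its barriers. */
static uint32_t ms_extra_waves(uint32_t api_wg_size, uint32_t hw_wg_size,
                               uint32_t wave_size)
{
    assert(hw_wg_size >= api_wg_size);
    return wave_count(hw_wg_size, wave_size) -
           wave_count(api_wg_size, wave_size);
}

/* The trivial case: API and HW workgroups have the same number of waves. */
static bool ms_needs_barrier_loop(uint32_t api_wg_size, uint32_t hw_wg_size,
                                  uint32_t wave_size)
{
    return ms_extra_waves(api_wg_size, hw_wg_size, wave_size) > 0;
}
```

With an API workgroup of 32 and a HW workgroup of 126, Wave64 needs 1 extra wave and Wave32 needs 3. An API workgroup of 100 in a 126-thread HW workgroup fits in 2 Wave64 waves either way, so the trivial case applies even though the thread counts differ.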
What if the API shader also uses shared memory, or not all outputs fit in LDS? The D3D12 spec requires the driver to make at least 28K of shared memory (LDS) available to the shader. However, graphics shaders can only access up to 32K of LDS, which leaves as little as 4K for mesh shader outputs when an application uses all 28K. How do we make this work, given that the driver has to write mesh shader outputs to LDS? It gets really ugly here: in that case, the driver is forced to write MS outputs to VRAM instead of LDS.
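The LDS vs. VRAM decision boils down to a budget check. This is a simplified sketch, not the driver’s actual layout logic: the `ms_lds_layout` fields, the per-vertex and per-primitive byte sizes, and the assumption that 32K means 32 × 1024 bytes are all illustrative.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

#define GFX_MAX_LDS_BYTES (32u * 1024u) /* 32K LDS limit for graphics shaders */

/* Hypothetical summary of what a mesh shader needs from LDS. */
typedef struct {
    uint32_t api_shared_bytes;    /* shared memory declared by the API shader */
    uint32_t max_vertices;
    uint32_t max_primitives;
    uint32_t bytes_per_vertex;    /* assumed: total size of per-vertex outputs */
    uint32_t bytes_per_primitive; /* assumed: indices + per-primitive outputs */
} ms_lds_layout;

static uint32_t ms_output_bytes(const ms_lds_layout *l)
{
    return l->max_vertices * l->bytes_per_vertex +
           l->max_primitives * l->bytes_per_primitive;
}

/* true: outputs can live in LDS next to API shared memory;
 * false: the driver has to fall back to writing outputs to VRAM. */
static bool ms_outputs_fit_in_lds(const ms_lds_layout *l)
{
    assert(l->api_shared_bytes <= GFX_MAX_LDS_BYTES);
    return l->api_shared_bytes + ms_output_bytes(l) <= GFX_MAX_LDS_BYTES;
}
```

With 64 vertices of 64 bytes each (four vec4 attributes) and 126 primitives of 16 bytes each, the outputs take 6112 bytes. That fits alongside 16K of shared memory, but not alongside the full 28K, so in the latter case the outputs would go to VRAM.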
How do you deal with the compute-like stuff, e.g. workgroup ID, subgroup ID, etc.? Fortunately, most of these were already available to the shader, just not exposed in the traditional VS, TES, GS programming model. The only pain point is the workgroup ID, which needs trickery. As mentioned above, the HW is tricked into thinking that each MS workgroup has 1 input vertex, so we can simply use the register that contains the vertex ID to get the workgroup ID.
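Because each workgroup is one “vertex”, the vertex ID register effectively holds a linear workgroup index. If the API also exposes a three-dimensional workgroup ID (VK_EXT_mesh_shader draws take `groupCountX/Y/Z`), that index has to be expanded. The sketch below assumes the vertex ID starts at 0 for the draw and that X varies fastest; both are assumptions for illustration, not necessarily what RADV emits, and the function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

typedef struct {
    uint32_t x, y, z;
} uvec3;

/* With 1 input vertex per workgroup, the vertex ID (assumed to start at 0
 * for this draw) is already a linear workgroup index. */
static uint32_t ms_linear_workgroup_id(uint32_t vertex_id_reg)
{
    return vertex_id_reg;
}

/* Assumed ordering: X fastest, then Y, then Z. */
static uvec3 ms_workgroup_id_3d(uint32_t linear, uvec3 group_count)
{
    assert(group_count.x > 0 && group_count.y > 0 && group_count.z > 0);
    uvec3 id;
    id.x = linear % group_count.x;
    id.y = (linear / group_count.x) % group_count.y;
    id.z = linear / (group_count.x * group_count.y);
    return id;
}
```

For a 4 × 3 × 2 grid, linear index 5 maps to (1, 1, 0) and the last workgroup, index 23, maps to (3, 2, 1).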
Conclusion, performance considerations
The above implementation details can be turned into performance recommendations.
Specify an MS workgroup size that matches the maximum number of output vertices and primitives. Also, distribute the work among the full workgroup rather than leaving some threads doing nothing. If you do this, you ensure that the hardware is optimally utilized. This is the most important recommendation here today.
Try to write each mesh output array index only from the thread with the matching index. If you do this, you hit an optimal code path in the driver, so it won’t have to write those outputs to LDS and read them back at the end.
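The two recommendations above work together. A common way to distribute output writes is a strided loop over the output indices. The C function below simulates that loop for one invocation (the name `strided_vertex_writes` is hypothetical) and counts how many of its writes land on an index other than its own, which is exactly what forces the LDS roundtrip.

```c
#include <assert.h>
#include <stdint.h>

typedef struct {
    uint32_t writes;         /* output vertices this invocation stores */
    uint32_t foreign_writes; /* stores whose index != invocation ID */
} write_stats;

/* Simulates: for (v = invocation_index; v < vertex_count; v += wg_size) */
static write_stats strided_vertex_writes(uint32_t tid, uint32_t wg_size,
                                         uint32_t vertex_count)
{
    assert(wg_size > 0 && tid < wg_size);
    write_stats s = {0, 0};
    for (uint32_t v = tid; v < vertex_count; v += wg_size) {
        s.writes++;
        if (v != tid)
            s.foreign_writes++; /* needs the LDS roundtrip */
    }
    return s;
}
```

With a workgroup of 32 and 64 vertices, thread 0 writes twice, and its second write (index 32) belongs to another thread. With a workgroup of 64, every thread writes exactly its own index once, so the whole workgroup is busy and no LDS roundtrip is needed.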
Use shared memory, but not excessively. Implementing any nice algorithm in your mesh shader will likely need you to share data between threads. Don’t be afraid to use shared memory, but prefer to use subgroup functionality instead when possible.
What if you don’t want to do any of the above?
That is perfectly fine. Don’t use mesh shaders then.
The main takeaway about mesh shading is that it’s a very low-level tool. The driver can implement the full programming model, but it can’t hold your hand as well as it could for traditional vertex processing. You may have to implement things (e.g. vertex inputs, culling, etc.) that the driver previously did for you. Essentially, if you write a mesh shader you are trying to beat the driver at its own game.
Wait, aren’t we forgetting something?
I think this post is already dense enough with technical detail. Brace yourself for the next post, where I’m going to blow your mind even more and talk about how task shaders are implemented.
Comments
The blog doesn't have comments, but feel free to reach out to me on IRC (Venemo on OFTC) or Discord (sunrise_sky) to discuss.