Unity3D ComputeShader Example

So I’ve been messing around with compute shaders in Unity3D for the past day, and I thought I’d share some code to help others get started. Unfortunately the documentation for this capability is abysmal, so there are a few simple things worth noting to get it up and running. This is a DirectX 11 only capability, so you’ll need to be building for Windows if you want to get anything out of this. In this example I’ll demonstrate how to set up your compute shader so that you get a square grid of points that move back and forth along the z axis. I got this working from some code that a Unity dev posted, which I’ve unfortunately lost the link for. It will look something like the following:

[Screenshot: Unity ComputeShader moving points]

A static screenshot is of course not nearly as interesting as seeing the points move live.

Anyway, here’s how you can achieve this pretty simple compute shader. There are three files needed to show procedural meshes with compute shaders: a .compute file, a normal shader file, and a C# script. Below you can see the code for the compute shader. Hopefully the comments explain the different parts, but the idea is that we arrange points in a grid and offset their z position with a sine wave using the current time from Unity.

#pragma kernel CSMain

//We define the size of a group in the x and y directions, the z direction will just be one
#define thread_group_size_x 4
#define thread_group_size_y 4

//A struct that simply holds a position
struct PositionStruct
{
	float3 pos;
};

//A struct containing an offset for use by the Wave function
struct OffsetStruct
{
	float offset;
};

//A constant buffer struct that holds a time variable sent from Unity
struct CBufferStruct
{
	float t;
};

//We keep three buffers accessed by the kernel: a constant buffer that is the same for every computation,
//an offset buffer with a value to offset the wave, and an output buffer that is written to by the kernel
RWStructuredBuffer<CBufferStruct> cBuffer;
RWStructuredBuffer<OffsetStruct> offsets;
RWStructuredBuffer<PositionStruct> output;

//A simple sine modulation of the z coordinate, offset by a random value between 0 and 2*PI
float3 Wave(float3 p, int idx)
{
	p.z = sin(cBuffer[0].t + offsets[idx].offset);
	return p;
}

//The kernel for this compute shader, each thread group contains a number of threads specified by numthreads(x,y,z)
//We look up the index into the flat array by using x + y * x_stride
//The position is calculated from the thread index and then the z component is shifted by the Wave function
[numthreads(thread_group_size_x, thread_group_size_y, 1)]
void CSMain (uint3 id : SV_DispatchThreadID)
{
	int idx = id.x + id.y * thread_group_size_x * 32;
	float spacing = 1.0;

	float3 pos = float3(id.x * spacing, id.y * spacing, id.z * spacing);

	pos = Wave(pos, idx);

	output[idx].pos = pos;
}
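The hard-coded stride in that index calculation (thread_group_size_x * 32, i.e. 128) is easy to get wrong, so here is a quick sanity check, in plain Python rather than HLSL, that the mapping covers every point exactly once. The 32 comes from the group counts we pass to Dispatch later on.

```python
# Sanity check (outside Unity) of the kernel's flat-index math.
# The dispatch uses 32x32 groups of 4x4 threads, i.e. a 128x128 thread grid.
threads_x = 4 * 32  # thread_group_size_x * groups dispatched in x
threads_y = 4 * 32  # thread_group_size_y * groups dispatched in y

seen = set()
for y in range(threads_y):
    for x in range(threads_x):
        idx = x + y * threads_x  # same formula as in CSMain
        seen.add(idx)

# Every thread gets a unique index covering 0 .. 16383 exactly once.
assert len(seen) == threads_x * threads_y == 16384
assert min(seen) == 0 and max(seen) == 16383
```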

If you create a new ComputeShader in Unity and stick this in there, you’re ready for the next part: the C# script that dispatches this compute file. Again the comments should explain how this works, but the idea is that you create your buffers and then dispatch a number of thread groups (each of which runs the number of threads defined in the compute shader) on the GPU. When the object is disabled we free the buffers. One key point here: this script goes on a camera (OnPostRender is only called for scripts attached to a camera), which took me a bit to figure out.

using UnityEngine;

//This script invokes the PlaneComputeShader (attached via drag'n'drop in the editor) using the PlaneBufferShader (also attached in the editor)
//to display a grid of points moving back and forth along the z axis.
public class CreatePlane : MonoBehaviour
{
    public Shader shader;
    public ComputeShader computeShader;

    private ComputeBuffer offsetBuffer;
    private ComputeBuffer outputBuffer;
    private ComputeBuffer constantBuffer;
    private int _kernel;
    private Material material;

    public const int VertCount = 16384; //32*32*4*4 (Groups * ThreadsPerGroup)

    //We initialize the buffers and the material used to draw.
    void Start ()
    {
        _kernel = computeShader.FindKernel("CSMain");

        CreateBuffers();
        CreateMaterial();
    }

    //When this GameObject is disabled we must release the buffers or else Unity complains.
    private void OnDisable()
    {
        ReleaseBuffer();
    }

    //After all rendering is complete we dispatch the compute shader and then set the material before drawing with DrawProcedural
    //this just draws the "mesh" as a set of points
    void OnPostRender ()
    {
        Dispatch();

        material.SetBuffer("buf_Points", outputBuffer);
        material.SetPass(0);
        Graphics.DrawProcedural(MeshTopology.Points, VertCount);
    }

    //To set up a ComputeBuffer we pass in the array length, as well as the size in bytes of a single element.
    //We fill the offset buffer with random numbers between 0 and 2*PI.
    void CreateBuffers()
    {
        offsetBuffer = new ComputeBuffer(VertCount, 4); //Contains a single float value (OffsetStruct)

        float[] values = new float[VertCount];

        for (int i = 0; i < VertCount; i++)
            values[i] = Random.value * 2 * Mathf.PI;

        offsetBuffer.SetData(values);

        constantBuffer = new ComputeBuffer(1, 4); //Contains a single element (time) which is a float

        outputBuffer = new ComputeBuffer(VertCount, 12); //Output buffer contains vertices (float3 = Vector3 -> 12 bytes)
    }

    //For some reason I made this method to create a material from the attached shader.
    void CreateMaterial()
    {
        material = new Material(shader);
    }

    //Remember to release buffers and destroy the material when play has been stopped.
    void ReleaseBuffer()
    {
        if (offsetBuffer != null) offsetBuffer.Release();
        if (constantBuffer != null) constantBuffer.Release();
        if (outputBuffer != null) outputBuffer.Release();
        if (material != null) Destroy(material);
    }

    //The meat of this script: it sets the constant buffer (current time) and then sets all of the buffers for the compute shader.
    //We then dispatch 32x32x1 groups of threads of our CSMain kernel.
    void Dispatch()
    {
        constantBuffer.SetData(new[] { Time.time });

        computeShader.SetBuffer(_kernel, "cBuffer", constantBuffer);
        computeShader.SetBuffer(_kernel, "offsets", offsetBuffer);
        computeShader.SetBuffer(_kernel, "output", outputBuffer);

        computeShader.Dispatch(_kernel, 32, 32, 1);
    }
}
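One detail worth double-checking is the second constructor argument of ComputeBuffer: it is the stride in bytes of one element, and it has to match the corresponding HLSL struct exactly. Here is the arithmetic behind the literals used above, as a small plain-Python sketch:

```python
FLOAT_BYTES = 4  # a 32-bit float on the GPU

offset_stride = 1 * FLOAT_BYTES   # OffsetStruct holds one float
cbuffer_stride = 1 * FLOAT_BYTES  # CBufferStruct holds one float (time)
output_stride = 3 * FLOAT_BYTES   # PositionStruct holds a float3

# These match the literals passed to the ComputeBuffer constructors.
assert (offset_stride, cbuffer_stride, output_stride) == (4, 4, 12)

# And VertCount matches the dispatch: 32x32 groups of 4x4 threads each.
assert 32 * 32 * 4 * 4 == 16384
```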

The final part is to create a shader file that will be used for rendering. This shader contains a StructuredBuffer just like the compute shader, into which we feed the compute shader’s output; this gives it the set of points that need to be rendered. The shader is very simple: it just fetches the position of a vertex, transforms it into screen space, and assigns it a solid color.

Shader "DX11/PlaneBufferShader"
{
	SubShader
	{
		Pass
		{
			CGPROGRAM

			#pragma target 5.0

			#pragma vertex vert
			#pragma fragment frag

			#include "UnityCG.cginc"

			//The buffer containing the points we want to draw.
			StructuredBuffer<float3> buf_Points;

			//A simple input struct for our pixel shader step containing a position.
			struct ps_input {
				float4 pos : SV_POSITION;
			};

			//Our vertex function simply fetches a point from the buffer corresponding to the vertex index
			//which we transform with the view-projection matrix before passing to the pixel program.
			ps_input vert (uint id : SV_VertexID)
			{
				ps_input o;
				float3 worldPos = buf_Points[id];
				o.pos = mul (UNITY_MATRIX_VP, float4(worldPos, 1.0f));
				return o;
			}

			//Pixel function returns a solid color for each point.
			float4 frag (ps_input i) : COLOR
			{
				return float4(1, 0.5f, 0.0f, 1);
			}

			ENDCG
		}
	}

	Fallback Off
}
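To make it clear what the GPU ends up computing each frame, here is a rough CPU re-implementation of the kernel as a hypothetical Python sketch (not part of the Unity project), which can be handy for eyeballing the expected output:

```python
import math
import random

def wave_positions(t, width, height, spacing=1.0, offsets=None):
    """CPU version of CSMain: a grid of points whose z oscillates with sine."""
    if offsets is None:
        # Mirrors CreateBuffers(): a random phase in [0, 2*pi) per point.
        offsets = [random.uniform(0.0, 2.0 * math.pi) for _ in range(width * height)]
    points = []
    for y in range(height):
        for x in range(width):
            idx = x + y * width  # same flat-index scheme as the kernel
            points.append((x * spacing, y * spacing, math.sin(t + offsets[idx])))
    return points

# With all phases zeroed out, every point sits at z = sin(t); at t = 0 that is 0.
pts = wave_positions(t=0.0, width=4, height=4, offsets=[0.0] * 16)
assert all(abs(z) < 1e-9 for (_x, _y, z) in pts)
```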

When the C# script is placed on a camera and you hook up the normal shader and the compute shader to it, you can see the fancy points move! There are probably numerous other ways to achieve the same thing, and no doubt I did something “bad” but this does work and hopefully someone out there can see how to get compute shaders to do some procedural mesh generation in Unity.

One comment on Unity3D ComputeShader Example

  • Mike Dardis  says:

Thanks for the code — I had to make sure the code files were named correctly and set Unity to DX11 mode. A clean and well described foundation, excellent.
