More Efficient Matrix Multiplication

In this example, we show how to write a more efficient matrix multiplication kernel for NVIDIA GPUs that uses shared memory. For simplicity, we omit some optimizations such as software pipelining (see our paper for more details).

Feel free to skip this example if you are not familiar with CUDA programming.
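Before diving into Hidet Script, it may help to see the blocking scheme in plain PyTorch. The sketch below is purely illustrative (it is not part of the kernel, and the helper name tiled_matmul_reference is ours): each (i, j) tile of C plays the role of one thread block, and each k tile plays the role of one shared-memory staging step in the kernel that follows.

import torch


def tiled_matmul_reference(a: torch.Tensor, b: torch.Tensor,
                           tile_m: int = 64, tile_n: int = 256, tile_k: int = 8):
    # Illustrative reference of the blocking scheme: each (i, j) tile of C
    # corresponds to one thread block, and each k tile corresponds to one
    # shared-memory staging step of the kernel below.
    m, k = a.shape
    _, n = b.shape
    c = torch.zeros(m, n, dtype=a.dtype, device=a.device)
    for i in range(0, m, tile_m):
        for j in range(0, n, tile_n):
            acc = torch.zeros(min(tile_m, m - i), min(tile_n, n - j),
                              dtype=a.dtype, device=a.device)
            for p in range(0, k, tile_k):
                a_tile = a[i:i + tile_m, p:p + tile_k]  # the data staged in smem_a
                b_tile = b[p:p + tile_k, j:j + tile_n]  # the data staged in smem_b
                acc += a_tile @ b_tile                  # the per-thread accumulation
            c[i:i + tile_m, j:j + tile_n] = acc
    return c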

import torch
import hidet
from hidet.lang import attrs
from hidet.lang import float32, int32
from hidet.lang import as_tensor_pointer, register_tensor, shared_tensor
from hidet.lang.cuda import threadIdx, blockIdx, syncthreads
from hidet.lang.mapping import spatial, auto_map
from hidet.lang.layout import row_major, local_layout

# the hyperparameters of the kernel
warps_m, warps_n = 4, 2  # we use 4x2 warps
warp_m, warp_n = 2, 2  # each warp repeats 2x2 times
warp_map_m, warp_map_n = 2, 16  # each warp has 2x16 threads
thread_m, thread_n = 4, 4  # each thread repeats 4x4 times

# block_size = (64, 256, 8)
block_m_size, block_n_size = (
    warps_m * warp_m * warp_map_m * thread_m,
    warps_n * warp_n * warp_map_n * thread_n,
)
block_k_size = 8
num_warps = warps_m * warps_n  # 8
num_threads = num_warps * 32  # 256
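As a quick sanity check (these asserts are only for exposition and can be removed), the hyperparameters above multiply out to a 64x256 output tile per thread block, 32 threads per warp, and 256 threads per block:

assert block_m_size == 4 * 2 * 2 * 4 == 64    # warps_m * warp_m * warp_map_m * thread_m
assert block_n_size == 2 * 2 * 16 * 4 == 256  # warps_n * warp_n * warp_map_n * thread_n
assert warp_map_m * warp_map_n == 32          # one warp has exactly 32 threads
assert num_threads == num_warps * 32 == 256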

with hidet.lang.script_module() as script_module:

    @hidet.script
    def relu(x: float32) -> float32:
        return x if x > 0.0 else 0.0

    @hidet.script
    def matmul_relu_kernel(
        a_ptr: ~float32,
        b_ptr: ~float32,
        c_ptr: ~float32,
        m_size: int32,
        n_size: int32,
        k_size: int32,
    ):
        attrs.func_name = 'matmul_kernel'
        attrs.cuda.block_dim = num_threads
        attrs.cuda.grid_dim = (
            (m_size + block_m_size - 1) // block_m_size,
            (n_size + block_n_size - 1) // block_n_size,
        )

        a = as_tensor_pointer(a_ptr, float32, [m_size, k_size])
        b = as_tensor_pointer(b_ptr, float32, [k_size, n_size])
        c = as_tensor_pointer(c_ptr, float32, [m_size, n_size])

        # define tensors in shared memory
        smem_a = shared_tensor(float32, shape=[block_m_size, block_k_size])
        smem_b = shared_tensor(float32, shape=[block_k_size, block_n_size])

        # define the accumulation tensor in register
        regs_c = register_tensor(
            dtype=float32,
            # shape will be inferred from the layout automatically,
            # in this case, the shape is [64, 256]
            layout=(
                local_layout(warps_m, warps_n)
                * row_major(warp_m, warp_n)
                * local_layout(warp_map_m, warp_map_n)
                * row_major(thread_m, thread_n)
            ),
        )

        # initialize the registers
        mma_mapping = (
            spatial(warps_m, warps_n)
            .repeat(warp_m, warp_n)
            .spatial(warp_map_m, warp_map_n)
            .repeat(thread_m, thread_n)
        )
        for i, j in mma_mapping.on(threadIdx.x):
            regs_c[i, j] = 0.0

        # iterate over the k tiles
        num_k_tiles = (k_size + block_k_size - 1) // block_k_size
        for k_tile in range(num_k_tiles):
            # load smem_a [block_m_size, block_k_size] from global memory
            for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(threadIdx.x):
                global_i, global_k = (i + blockIdx.x * block_m_size, k + k_tile * block_k_size)
                smem_a[i, k] = (
                    a[global_i, global_k] if global_i < m_size and global_k < k_size else 0.0
                )

            # load smem_b [block_k_size, block_n_size] from global memory
            for k, j in auto_map(block_k_size, block_n_size, workers=num_threads).on(threadIdx.x):
                global_k, global_j = (k + k_tile * block_k_size, j + blockIdx.y * block_n_size)
                smem_b[k, j] = (
                    b[global_k, global_j] if global_k < k_size and global_j < n_size else 0.0
                )

            # synchronize all threads in the block
            syncthreads()

            # simt matrix multiply accumulate (mma): regs_c = regs_c + smem_a @ smem_b
            for i, j in mma_mapping.on(threadIdx.x):
                for k in range(block_k_size):
                    regs_c[i, j] += smem_a[i, k] * smem_b[k, j]

            # synchronize all threads in the block
            syncthreads()

        # store regs_c back to global memory
        for i, j in mma_mapping.on(threadIdx.x):
            global_i = i + blockIdx.x * block_m_size
            global_j = j + blockIdx.y * block_n_size
            if global_i < m_size and global_j < n_size:
                c[global_i, global_j] = relu(regs_c[i, j])


module = script_module.build()
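To see how mma_mapping distributes the 64x256 tile over the 256 threads, the following plain-Python model (our own illustration, not hidet API; it assumes each outer mapping consumes the high part of the worker index and each inner mapping the low part) enumerates the (i, j) elements of regs_c that a given thread covers. The arithmetic mirrors the index expressions visible in the generated CUDA source shown at the end of this example.

def mma_indices(tid):
    # Plain-Python model of spatial(4, 2).repeat(2, 2).spatial(2, 16).repeat(4, 4).
    warp_id, lane_id = tid // 32, tid % 32
    a1, b1 = warp_id // 2, warp_id % 2    # spatial(warps_m, warps_n): which warp
    a3, b3 = lane_id // 16, lane_id % 16  # spatial(warp_map_m, warp_map_n): thread in warp
    for a2 in range(2):                   # repeat(warp_m, warp_n): warp-level repetition
        for b2 in range(2):
            for a4 in range(4):           # repeat(thread_m, thread_n): thread-level repetition
                for b4 in range(4):
                    i = ((a1 * 2 + a2) * 2 + a3) * 4 + a4
                    j = ((b1 * 2 + b2) * 16 + b3) * 4 + b4
                    yield i, j


# every thread contributes 64 elements, and together they cover the 64x256 tile exactly
covered = set()
for tid in range(256):
    covered.update(mma_indices(tid))
assert len(covered) == 64 * 256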


def hidet_matmul_relu(a: torch.Tensor, b: torch.Tensor):
    m_size, n_size, k_size = a.shape[0], b.shape[1], a.shape[1]
    c = torch.empty([m_size, n_size], device='cuda')
    module(a, b, c, m_size, n_size, k_size)
    return c


def torch_matmul_relu(a: torch.Tensor, b: torch.Tensor):
    return torch.matmul(a, b).relu()

Run the program with different input sizes. This implementation achieves about 30% of the performance of cuBLAS kernels. For more efficient implementations, please refer to the ones in the hidet package.

for m, n, k in [(1024, 1024, 1024), (256, 256, 256), (32, 32, 32)]:
    a = torch.randn(m, k, dtype=torch.float32, device='cuda')
    b = torch.randn(k, n, dtype=torch.float32, device='cuda')

    c1 = hidet_matmul_relu(a, b)
    c2 = torch_matmul_relu(a, b)

    torch.testing.assert_close(c1, c2, atol=1e-4, rtol=1e-4)

    hidet_latency = hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50)
    torch_latency = hidet.utils.benchmark_func(lambda: torch_matmul_relu(a, b), repeat=50)
    print(f'{m}x{k}x{n}:')
    print(' torch: {:.3f} ms'.format(torch_latency))
    print(' hidet: {:.3f} ms'.format(hidet_latency))
1024x1024x1024:
 torch: 0.057 ms
 hidet: 0.130 ms
256x256x256:
 torch: 0.012 ms
 hidet: 0.036 ms
32x32x32:
 torch: 0.007 ms
 hidet: 0.007 ms

Get the source code:

print(module.source())
#include <stdint.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/complex.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>


static __device__ __forceinline__ float hidet_relu(float x) {
  return ((0.0f < x) ? x : 0.0f);
}

static __global__ void __launch_bounds__(256) hidet_matmul_kernel(float * __restrict__ a_ptr, float * __restrict__ b_ptr, float * __restrict__ c_ptr, int32_t m_size, int32_t n_size, int32_t k_size) {
  float *a = ((float*)(a_ptr));
  float *b = ((float*)(b_ptr));
  float *c = ((float*)(c_ptr));
  __shared__ float smem_a[512];
  __shared__ float smem_b[2048];
  float regs_c[64];
  for (int32_t i = 0; (i < 4); i = (i + 1)) {
    for (int32_t i_1 = 0; (i_1 < 4); i_1 = (i_1 + 1)) {
      regs_c[((i * 4) + i_1)] = 0.0f;
    }
  }
  for (int32_t i_2 = 0; (i_2 < 4); i_2 = (i_2 + 1)) {
    for (int32_t i_3 = 0; (i_3 < 4); i_3 = (i_3 + 1)) {
      regs_c[(16 + ((i_2 * 4) + i_3))] = 0.0f;
    }
  }
  for (int32_t i_4 = 0; (i_4 < 4); i_4 = (i_4 + 1)) {
    for (int32_t i_5 = 0; (i_5 < 4); i_5 = (i_5 + 1)) {
      regs_c[(32 + ((i_4 * 4) + i_5))] = 0.0f;
    }
  }
  for (int32_t i_6 = 0; (i_6 < 4); i_6 = (i_6 + 1)) {
    for (int32_t i_7 = 0; (i_7 < 4); i_7 = (i_7 + 1)) {
      regs_c[(48 + ((i_6 * 4) + i_7))] = 0.0f;
    }
  }
  for (int32_t k_tile = 0; (k_tile < ((k_size + 7) / 8)); k_tile = (k_tile + 1)) {
    int32_t global_i = (((int)threadIdx.x / 8) + ((int)blockIdx.x * 64));
    int32_t global_k = (((int)threadIdx.x % 8) + (k_tile * 8));
    smem_a[((((int)threadIdx.x / 8) * 8) + ((int)threadIdx.x % 8))] = (((global_i < m_size) && (global_k < k_size)) ? a[((global_i * k_size) + global_k)] : 0.0f);
    int32_t global_k_1 = (((int)threadIdx.x % 8) + (k_tile * 8));
    smem_a[(((((int)threadIdx.x / 8) * 8) + ((int)threadIdx.x % 8)) + 256)] = (((((((int)threadIdx.x / 8) + ((int)blockIdx.x * 64)) + 32) < m_size) && (global_k_1 < k_size)) ? a[((((((int)threadIdx.x / 8) + ((int)blockIdx.x * 64)) + 32) * k_size) + global_k_1)] : 0.0f);
    for (int32_t i_8 = 0; (i_8 < 8); i_8 = (i_8 + 1)) {
      int32_t global_k_2 = (i_8 + (k_tile * 8));
      int32_t global_j = ((int)threadIdx.x + ((int)blockIdx.y * 256));
      smem_b[((i_8 * 256) + (int)threadIdx.x)] = (((global_k_2 < k_size) && (global_j < n_size)) ? b[((global_k_2 * n_size) + global_j)] : 0.0f);
    }
    __syncthreads();
    for (int32_t i_9 = 0; (i_9 < 4); i_9 = (i_9 + 1)) {
      for (int32_t i_10 = 0; (i_10 < 4); i_10 = (i_10 + 1)) {
        for (int32_t k = 0; (k < 8); k = (k + 1)) {
          regs_c[((i_9 * 4) + i_10)] = (regs_c[((i_9 * 4) + i_10)] + (smem_a[((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_9) * 8) + k)] * smem_b[((k * 256) + (((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_10))]));
        }
      }
    }
    for (int32_t i_11 = 0; (i_11 < 4); i_11 = (i_11 + 1)) {
      for (int32_t i_12 = 0; (i_12 < 4); i_12 = (i_12 + 1)) {
        for (int32_t k_1 = 0; (k_1 < 8); k_1 = (k_1 + 1)) {
          regs_c[(16 + ((i_11 * 4) + i_12))] = (regs_c[(16 + ((i_11 * 4) + i_12))] + (smem_a[((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_11) * 8) + k_1)] * smem_b[(((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_12) + (k_1 * 256)) + 64)]));
        }
      }
    }
    for (int32_t i_13 = 0; (i_13 < 4); i_13 = (i_13 + 1)) {
      for (int32_t i_14 = 0; (i_14 < 4); i_14 = (i_14 + 1)) {
        for (int32_t k_2 = 0; (k_2 < 8); k_2 = (k_2 + 1)) {
          regs_c[(32 + ((i_13 * 4) + i_14))] = (regs_c[(32 + ((i_13 * 4) + i_14))] + (smem_a[(((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_13) * 8) + k_2) + 64)] * smem_b[((k_2 * 256) + (((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_14))]));
        }
      }
    }
    for (int32_t i_15 = 0; (i_15 < 4); i_15 = (i_15 + 1)) {
      for (int32_t i_16 = 0; (i_16 < 4); i_16 = (i_16 + 1)) {
        for (int32_t k_3 = 0; (k_3 < 8); k_3 = (k_3 + 1)) {
          regs_c[(48 + ((i_15 * 4) + i_16))] = (regs_c[(48 + ((i_15 * 4) + i_16))] + (smem_a[(((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_15) * 8) + k_3) + 64)] * smem_b[(((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_16) + (k_3 * 256)) + 64)]));
        }
      }
    }
    __syncthreads();
  }
  for (int32_t i_17 = 0; (i_17 < 4); i_17 = (i_17 + 1)) {
    for (int32_t i_18 = 0; (i_18 < 4); i_18 = (i_18 + 1)) {
      int32_t global_i_1 = (((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_17) + ((int)blockIdx.x * 64));
      int32_t global_j_1 = ((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_18) + ((int)blockIdx.y * 256));
      if ((global_i_1 < m_size) && (global_j_1 < n_size)) {
        c[((global_i_1 * n_size) + global_j_1)] = hidet_relu(regs_c[((i_17 * 4) + i_18)]);
      }
    }
  }
  for (int32_t i_19 = 0; (i_19 < 4); i_19 = (i_19 + 1)) {
    for (int32_t i_20 = 0; (i_20 < 4); i_20 = (i_20 + 1)) {
      int32_t global_i_2 = (((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_19) + ((int)blockIdx.x * 64));
      if ((global_i_2 < m_size) && ((((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_20) + ((int)blockIdx.y * 256)) + 64) < n_size)) {
        c[((global_i_2 * n_size) + (((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_20) + ((int)blockIdx.y * 256)) + 64))] = hidet_relu(regs_c[(16 + ((i_19 * 4) + i_20))]);
      }
    }
  }
  for (int32_t i_21 = 0; (i_21 < 4); i_21 = (i_21 + 1)) {
    for (int32_t i_22 = 0; (i_22 < 4); i_22 = (i_22 + 1)) {
      int32_t global_j_2 = ((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_22) + ((int)blockIdx.y * 256));
      if ((((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_21) + ((int)blockIdx.x * 64)) + 8) < m_size) && (global_j_2 < n_size)) {
        c[((((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_21) + ((int)blockIdx.x * 64)) + 8) * n_size) + global_j_2)] = hidet_relu(regs_c[(32 + ((i_21 * 4) + i_22))]);
      }
    }
  }
  for (int32_t i_23 = 0; (i_23 < 4); i_23 = (i_23 + 1)) {
    for (int32_t i_24 = 0; (i_24 < 4); i_24 = (i_24 + 1)) {
      if ((((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_23) + ((int)blockIdx.x * 64)) + 8) < m_size) && ((((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_24) + ((int)blockIdx.y * 256)) + 64) < n_size)) {
        c[((((((((((int)threadIdx.x / 64) * 4) + (((int)threadIdx.x % 32) / 16)) * 4) + i_23) + ((int)blockIdx.x * 64)) + 8) * n_size) + (((((((((int)threadIdx.x / 32) % 2) * 32) + ((int)threadIdx.x % 16)) * 4) + i_24) + ((int)blockIdx.y * 256)) + 64))] = hidet_relu(regs_c[(48 + ((i_23 * 4) + i_24))]);
      }
    }
  }
}

DLL void hidet_launch(float * __restrict__ a_ptr, float * __restrict__ b_ptr, float * __restrict__ c_ptr, int32_t m_size, int32_t n_size, int32_t k_size) {
  if ((0 < ((m_size + 63) / 64)) && (0 < ((n_size + 255) / 256))) {
    if (65535 < ((n_size + 255) / 256)) {
      printf("Launching kernel with grid_dim = (%d, %d, %d), block_dim = (%d, %d, %d)\n", ((m_size + 63) / 64), ((n_size + 255) / 256), 1, 256, 1, 1);
      assert(false);  // Invalid launch configuration
    }
    hidet_matmul_kernel<<<dim3(((m_size + 63) / 64), ((n_size + 255) / 256), 1), dim3(256, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(a_ptr, b_ptr, c_ptr, m_size, n_size, k_size);
    {cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";}
  }
}
