Naive Matrix Multiplication

In this example, we show how to write a matrix multiplication program for the GPU that supports arbitrary input sizes (within CUDA's grid-dimension limits).

import torch
import hidet
from hidet.lang import attrs
from hidet.lang.types import f32, i32, tensor_pointer
from hidet.lang.cuda import threadIdx, blockIdx

# Direct hidet's compilation cache to ./outs/cache.
hidet.option.cache_dir('./outs/cache')

# Functions decorated with @hidet.script inside this context are collected into
# `script_module`, which is compiled by `script_module.build()` below.
with hidet.script_module() as script_module:

    @hidet.script
    def matmul_kernel(a_ptr: ~f32, b_ptr: ~f32, c_ptr: ~f32, m_size: i32, n_size: i32, k_size: i32):
        # Naive matmul kernel: c[m_size, n_size] = a[m_size, k_size] @ b[k_size, n_size].
        # `~f32` is a pointer to float32 (emitted as `float *` in the generated source);
        # each thread computes exactly one element of c.
        attrs.func_kind = 'cuda_kernel'
        # 16 x 16 = 256 threads per block.
        attrs.cuda.block_dim = 16, 16
        # Ceiling division so partially filled tiles at the bottom/right edges still get
        # a block. Grid x tiles the rows (m), grid y tiles the columns (n). CUDA caps
        # grid y at 65535 blocks; the generated launcher checks n-tiles against that limit.
        attrs.cuda.grid_dim = (m_size + 15) // 16, (n_size + 15) // 16

        # define three tensor pointers that hold the shape and dtype information
        # (views over the raw pointers; indexing is dense row-major,
        # e.g. c[i, j] -> c_ptr[i * n_size + j])
        a = tensor_pointer(dtype=f32, shape=[m_size, k_size], init=a_ptr)
        b = tensor_pointer(dtype=f32, shape=[k_size, n_size], init=b_ptr)
        c = tensor_pointer(dtype=f32, shape=[m_size, n_size], init=c_ptr)

        # Global row (i) and column (j) of the output element owned by this thread.
        # NOTE(review): i comes from threadIdx.x, so neighbouring threads in a warp touch
        # b and c with a stride of n_size elements (uncoalesced). Deriving j from the
        # x dimension instead would make those accesses contiguous.
        i = blockIdx.x * 16 + threadIdx.x
        j = blockIdx.y * 16 + threadIdx.y

        # Threads in partial edge tiles fall outside the matrix and do nothing.
        if i < m_size and j < n_size:
            # Dot product of row i of a and column j of b, accumulated directly in c.
            # NOTE(review): the generated code reads and writes c_ptr on every k
            # iteration; a local accumulator stored once at the end would avoid that.
            c[i, j] = 0.0
            for k in range(k_size):
                c[i, j] += a[i, k] * b[k, j]


# Compile the script module; the result can be called like a function (see `matmul`).
module = script_module.build()

A compiled Hidet module can be called directly with PyTorch tensors.

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two float32 CUDA matrices with the compiled hidet kernel.

    Args:
        a: Tensor of shape [m, k], dtype float32, on a CUDA device.
        b: Tensor of shape [k, n], dtype float32, on the same device as ``a``.

    Returns:
        A new float32 tensor of shape [m, n] holding ``a @ b``.

    Raises:
        ValueError: if the inputs are not 2-D, their inner dimensions differ,
            or they are not on the same CUDA device.
        TypeError: if either input is not float32.
    """
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(f'expected 2-D matrices, got shapes {tuple(a.shape)} and {tuple(b.shape)}')
    if a.shape[1] != b.shape[0]:
        # The kernel trusts k_size for both operands; a mismatch would read b out of bounds.
        raise ValueError(f'inner dimensions do not match: {tuple(a.shape)} @ {tuple(b.shape)}')
    if a.dtype != torch.float32 or b.dtype != torch.float32:
        # The kernel takes raw f32 pointers; other dtypes would be silently reinterpreted.
        raise TypeError(f'expected float32 inputs, got {a.dtype} and {b.dtype}')
    if not a.is_cuda or a.device != b.device:
        raise ValueError(f'expected both inputs on the same CUDA device, got {a.device} and {b.device}')

    # The kernel indexes its pointers as dense row-major matrices, so strided views
    # (e.g. a transposed tensor) must be materialized first. No-op if already contiguous.
    a = a.contiguous()
    b = b.contiguous()

    m_size, n_size, k_size = a.shape[0], b.shape[1], a.shape[1]
    c = torch.empty([m_size, n_size], dtype=torch.float32, device=a.device)
    module(a, b, c, m_size, n_size, k_size)
    return c

Run the compiled kernel with different input sizes and check the results against PyTorch's reference implementation.

# Each case is (m, n, k). None of these m/n sizes is a multiple of 16, so the
# kernel's edge-tile bounds check is exercised.
test_shapes = [(234, 345, 567), (123, 456, 789)]
for shape in test_shapes:
    m_size, n_size, k_size = shape
    lhs = torch.randn(m_size, k_size, device='cuda')
    rhs = torch.randn(k_size, n_size, device='cuda')

    # Compare the hidet kernel against PyTorch's reference matmul.
    actual, expected = matmul(lhs, rhs), torch.matmul(lhs, rhs)
    torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

# Show the CUDA source hidet generated for the kernel.
print(module.source())
#include <stdint.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/complex.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>


static __global__ void __launch_bounds__(256) hidet_matmul_kernel(float * __restrict__ a_ptr, float * __restrict__ b_ptr, float * __restrict__ c_ptr, int32_t m_size, int32_t n_size, int32_t k_size) {
  int32_t i = (((int)blockIdx.x * 16) + (int)threadIdx.x);
  int32_t j = (((int)blockIdx.y * 16) + (int)threadIdx.y);
  if ((i < m_size) && (j < n_size)) {
    c_ptr[((i * n_size) + j)] = 0.0f;
    for (int32_t k = 0; (k < k_size); k = (k + 1)) {
      c_ptr[((i * n_size) + j)] = (c_ptr[((i * n_size) + j)] + (a_ptr[((i * k_size) + k)] * b_ptr[((k * n_size) + j)]));
    }
  }
}

DLL void hidet_launch(float * __restrict__ a_ptr, float * __restrict__ b_ptr, float * __restrict__ c_ptr, int32_t m_size, int32_t n_size, int32_t k_size) {
  if ((0 < ((m_size + 15) / 16)) && (0 < ((n_size + 15) / 16))) {
    if (65535 < ((n_size + 15) / 16)) {
      printf("Launching kernel with grid_dim = (%d, %d, %d), block_dim = (%d, %d, %d)\n", ((m_size + 15) / 16), ((n_size + 15) / 16), 1, 16, 16, 1);
      assert(false);  // Invalid launch configuration
    }
    hidet_matmul_kernel<<<dim3(((m_size + 15) / 16), ((n_size + 15) / 16), 1), dim3(16, 16, 1), 0, (cudaStream_t)get_cuda_stream()>>>(a_ptr, b_ptr, c_ptr, m_size, n_size, k_size);
    {cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";}
  }
}

Total running time of the script: (0 minutes 0.005 seconds)

Gallery generated by Sphinx-Gallery