Using Template-based Scheduling

In the previous tutorial, we learned how to define a new operator with rule-based scheduling. Rule-based scheduling is a convenient way to define a new operator, but it is not efficient enough for operators with a large amount of reduction. In this tutorial, we will learn how to define a new operator with template-based scheduling. Template-based scheduling allows us to define a tensor program template that is instantiated for different input shapes and tunable hyper-parameters.

Override the implement_cuda() method

The Task class has two methods, implement_cpu() and implement_cuda(), that can be overridden when we define a new task.

import hidet
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule


class BatchMatmulFp16Task(Task):
    def __init__(self, a: TensorNode, b: TensorNode):
        batch_size, m_size, k_size = a.const_shape()
        batch_size, k_size, n_size = b.const_shape()
        c = compute(
            name='c',
            shape=[batch_size, m_size, n_size],
            fcompute=lambda p, i, j: reduce(
                shape=[k_size],
                fcompute=lambda k: a[p, i, k] * b[p, k, j],
                reduce_type='sum',
            ),
        )
        super().__init__(
            name='batch_matmul_fp16',
            inputs=[a, b],
            outputs=[c],
            attributes={
                'batch_size': batch_size,
                'm_size': m_size,
                'n_size': n_size,
                'k_size': k_size,
            },
        )

    def allow_epilogue(self) -> bool:
        return False

    def implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return batch_matmul_mma_fp16_schedule(self)

In the above task definition, we override the implement_cuda() method to use template-based scheduling. Inside implement_cuda(), we call the batch_matmul_mma_fp16_schedule() function to get a tensor program that implements the computation defined in the task.

Implement the tensor program

We can implement the batch_matmul_mma_fp16_schedule() function as follows. The function is fairly involved; understanding it requires knowledge of both CUDA programming and Hidet Script, so feel free to skip it for now.

Note

This function defines the tensor program with Hidet Script. Hidet Script is another domain-specific language in Hidet that allows developers to write tensor programs in Python syntax. We will add more documentation introducing Hidet Script once it becomes more stable.
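
To give a feel for the language before diving into the full schedule, here is a minimal sketch of a Hidet Script program: a hypothetical vector_add_kernel that is not part of this tutorial. It uses only constructs that also appear in the schedule below (hidet.script_module(), the hidet.script decorator, the attr launch attributes, and the CUDA block/thread indices). Treat it as an illustration under these assumptions rather than a reference implementation.

import hidet
from hidet.lang import f16, attr
from hidet.lang.cuda import blockIdx, threadIdx

n = 1024  # assumed fixed problem size, for illustration only

with hidet.script_module() as demo_module:

    @hidet.script
    def vector_add_kernel(a: f16[n], b: f16[n], c: f16[n]):
        """Element-wise addition of two length-n float16 vectors."""
        # launch configuration: enough 256-thread blocks to cover n elements
        attr.cuda_grid_dim = (n + 255) // 256
        attr.cuda_block_dim = 256
        i = blockIdx.x * 256 + threadIdx.x
        if i < n:
            c[i] = a[i] + b[i]

demo_ir_module = demo_module.ir_module()  # lower to an IRModule, as implement_cuda() must return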

def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule:
    from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf, cast
    from hidet.lang.mapping import repeat, spatial
    from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
    from hidet.lang.cuda import MmaConfig, mma_sync

    # get the workload size
    bs = task.attrs['batch_size']
    m_size = task.attrs['m_size']
    n_size = task.attrs['n_size']
    k_size = task.attrs['k_size']

    # define the template hyper-parameters
    # each thread block computes a (block_m, block_n) tile of C and iterates over the
    # K dimension in steps of block_k; the block tile is split among warps, and each
    # warp tile is further split into mma instructions:
    #   block_m = warp_count_m * warp_m,  warp_m = mma_count_m * mma_m  (same for n, k)
    mma_config = MmaConfig.m16n8k8_f16_f16()
    block_m, block_n, block_k = 128, 128, 8
    warp_m, warp_n, warp_k = 64, 64, 8
    warp_count_m, warp_count_n, warp_count_k = 2, 2, 1
    mma_m, mma_n, mma_k = mma_config.m, mma_config.n, mma_config.k  # 16, 8, 8
    mma_count_m, mma_count_n, mma_count_k = 4, 8, 1
    threads = warp_count_m * warp_count_n * warp_count_k * 32  # 4 warps => 128 threads

    # define the tensor program
    with hidet.script_module() as module:

        @hidet.script
        def load_regs_a(
            smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements]
        ):
            """Load A registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                warp_id
            ):
                for mi in range(mma_count_m):
                    p = 0
                    for i, k in mma_config.a_load_map.on(lane_id):
                        regs_a[mi, p] = smem_a[
                            wi * warp_m + mi * mma_m + i, wk * warp_k + k
                        ]
                        p += 1

        @hidet.script
        def load_regs_b(
            smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements]
        ):
            """Load B registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                warp_id
            ):
                for mj in range(mma_count_n):
                    p = 0
                    for k, j in mma_config.b_load_map.on(lane_id):
                        regs_b[mj, p] = smem_b[
                            wk * warp_k + k, wj * warp_n + mj * mma_n + j
                        ]
                        p += 1

        @hidet.script
        def warp_mma(
            regs_a: f16[4, mma_config.a_elements],
            regs_b: f16[8, mma_config.b_elements],
            regs_c: f16[4, 8, mma_config.c_elements],
        ):
            """Perform warp-level matrix multiplication."""
            for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                mma_sync(mma_config, ~regs_a[mi, 0], ~regs_b[mj, 0], ~regs_c[mi, mj, 0])

        @hidet.script
        def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size]):
            """Store C registers to global memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            gmem_c = c[blockIdx.z, offset_m:, offset_n:]
            for k_round in range(warp_count_k):
                for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
                    warp_id
                ):
                    if wk == k_round:
                        for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                            p = 0
                            for i, j in mma_config.c_store_map.on(lane_id):
                                gmem_c.write(
                                    [
                                        wi * warp_m + mi * mma_m + i,
                                        wj * warp_n + mj * mma_n + j,
                                    ],
                                    regs_c[mi, mj, p],
                                    protected=True,
                                )
                                p += 1

        @hidet.script
        def batch_matmul_kernel(
            a: f16[bs, m_size, k_size],
            b: f16[bs, k_size, n_size],
            c: f16[bs, m_size, n_size],
        ):
            """Batch matrix multiplication kernel."""
            attr.cuda_grid_dim = (
                (m_size + block_m - 1) // block_m,
                (n_size + block_n - 1) // block_n,
                bs,
            )
            attr.cuda_block_dim = threads
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            smem_a = tensor('shared', 'float16', [block_m, block_k])
            smem_b = tensor('shared', 'float16', [block_k, block_n])
            regs_a = tensor('register', 'float16', [4, mma_config.a_elements])
            regs_b = tensor('register', 'float16', [8, mma_config.b_elements])
            regs_c = tensor('register', 'float16', [4, 8, mma_config.c_elements])

            for i, j, p in grid(4, 8, mma_config.c_elements):
                regs_c[i, j, p] = 0.0

            for k0 in range((k_size + block_k - 1) // block_k):
                offset_k = k0 * block_k
                gmem_a = a[blockIdx.z, offset_m:, offset_k:]
                gmem_b = b[blockIdx.z, offset_k:, offset_n:]
                for i, k in repeat(8, 1).spatial(16, 8).on(threadIdx.x):
                    smem_a[i, k] = gmem_a.read([i, k], protected=True)
                for k, j in repeat(8, 1).spatial(1, 128).on(threadIdx.x):
                    smem_b[k, j] = gmem_b.read([k, j], protected=True)
                syncthreads()
                load_regs_a(smem_a, regs_a)
                load_regs_b(smem_b, regs_b)
                warp_mma(regs_a, regs_b, regs_c)
                syncthreads()
            store_c(regs_c, c)

    ir_module = module.ir_module()
    return ir_module

Define the operator

The remaining part is the same as adding a new operator with rule-based scheduling.

from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like


class BatchMatmulFp16Op(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        assert a.dtype == hidet.float16 and b.dtype == hidet.float16
        super().__init__(
            inputs=[a, b],
            attributes={},
            task=BatchMatmulFp16Task(input_like(a, 'a'), input_like(b, 'b')),
        )


def batch_matmul_fp16(a: Tensor, b: Tensor) -> Tensor:
    return BatchMatmulFp16Op(a, b).get_output(0)


def demo_usage():
    a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    c = batch_matmul_fp16(a, b)
    print(a)
    print(b)
    print(c)


demo_usage()
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[ 0.31  1.1 ]
  [ 0.72 -0.45]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[-1.12  0.42]
  [ 1.34 -0.24]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[ 1.13 -0.14]
  [-1.4   0.41]]]
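
As a follow-up, we can sanity-check the new operator on a slightly larger workload. The snippet below is a minimal check that is not part of the original tutorial: it assumes hidet tensors provide cpu() and numpy() methods and compares against numpy, with loose tolerances because the kernel accumulates in float16.

import numpy as np


def check_correctness():
    a = hidet.randn([2, 64, 64], dtype='float16', device='cuda')
    b = hidet.randn([2, 64, 64], dtype='float16', device='cuda')
    c = batch_matmul_fp16(a, b)
    # reference result computed on the CPU with numpy
    expected = np.matmul(a.cpu().numpy(), b.cpu().numpy())
    np.testing.assert_allclose(c.cpu().numpy(), expected, rtol=1e-2, atol=1e-1)


check_correctness()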

Generated Source Code

If you are interested in the generated source code, here it is:

# we hide the code to get the source path for simplicity
print('Generated source path (relative to hidet cache root): \n{}'.format(relative_path))
print()
print('Generated source code:')
with open(source_path, 'r') as f:
    print(f.read())
Generated source path (relative to hidet cache root):
docs-cache/ops/cuda_space_0/batch_matmul_fp16/6af8f0282257d7b7/source.cu

Generated source code:
#include <stdint.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
extern "C" {

__device__ __forceinline__ void hidet_cuda_mma_sync_aligned_m16n8k8_row_col_f16_f16_f16_f16(half * __restrict__ a, half * __restrict__ b, half * __restrict__ c) {
  uint32_t *ra;
  uint32_t *rb;
  uint32_t *rc;
  ra = ((uint32_t*)(a));
  rb = ((uint32_t*)(b));
  rc = ((uint32_t*)(c));
  asm ("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" : "+r"(rc[0]), "+r"(rc[1]) : "r"(ra[0]), "r"(ra[1]), "r"(rb[0]));
}

__global__ void __launch_bounds__(128) hidet_batch_matmul_kernel(half * __restrict__ a, half * __restrict__ b, half * __restrict__ c) {
  __shared__ half smem_a[1024];
  __shared__ half smem_b[1024];
  half regs_a[16];
  half regs_b[16];
  half regs_c[128];
  for (int32_t i = 0; (i < 4); i = (i + 1)) {
    for (int32_t j = 0; (j < 8); j = (j + 1)) {
      for (int32_t p = 0; (p < 4); p = (p + 1)) {
        regs_c[(((i * 32) + (j * 4)) + p)] = ((half)(0.0f));
      }
    }
  }
  for (int32_t i_1 = 0; (i_1 < 8); i_1 = (i_1 + 1)) {
    smem_a[((((i_1 * 16) + ((int)threadIdx.x / 8)) * 8) + ((int)threadIdx.x % 8))] = (((((i_1 * 16) + ((int)threadIdx.x / 8)) < 2) && (((int)threadIdx.x % 8) < 2)) ? a[((((i_1 * 16) + ((int)threadIdx.x / 8)) * 2) + ((int)threadIdx.x % 8))] : half(0.0f));
  }
  for (int32_t i_2 = 0; (i_2 < 8); i_2 = (i_2 + 1)) {
    smem_b[((i_2 * 128) + (int)threadIdx.x)] = (((i_2 < 2) && ((int)threadIdx.x < 2)) ? b[((i_2 * 2) + (int)threadIdx.x)] : half(0.0f));
  }
  __syncthreads();
  half *smem_a_1 = smem_a;
  half *regs_a_1 = regs_a;
  int32_t lane_id = ((int)threadIdx.x % 32);
  for (int32_t mi = 0; (mi < 4); mi = (mi + 1)) {
    int32_t p_1 = 0;
    for (int32_t i_3 = 0; (i_3 < 2); i_3 = (i_3 + 1)) {
      for (int32_t i_4 = 0; (i_4 < 2); i_4 = (i_4 + 1)) {
        regs_a_1[((mi * 4) + p_1)] = smem_a_1[((((((((int)threadIdx.x / 32) / 2) * 64) + (mi * 16)) + ((i_3 * 8) + (lane_id / 4))) * 8) + (((lane_id % 4) * 2) + i_4))];
        p_1 = (p_1 + 1);
      }
    }
  }
  half *smem_b_1 = smem_b;
  half *regs_b_1 = regs_b;
  int32_t lane_id_1 = ((int)threadIdx.x % 32);
  for (int32_t mj = 0; (mj < 8); mj = (mj + 1)) {
    int32_t p_2 = 0;
    for (int32_t i_5 = 0; (i_5 < 2); i_5 = (i_5 + 1)) {
      regs_b_1[((mj * 2) + p_2)] = smem_b_1[(((((lane_id_1 % 4) * 2) + i_5) * 128) + ((((((int)threadIdx.x / 32) % 2) * 64) + (mj * 8)) + (lane_id_1 / 4)))];
      p_2 = (p_2 + 1);
    }
  }
  half *regs_a_2 = regs_a;
  half *regs_b_2 = regs_b;
  half *regs_c_1 = regs_c;
  for (int32_t i_6 = 0; (i_6 < 4); i_6 = (i_6 + 1)) {
    for (int32_t i_7 = 0; (i_7 < 8); i_7 = (i_7 + 1)) {
      hidet_cuda_mma_sync_aligned_m16n8k8_row_col_f16_f16_f16_f16(&regs_a_2[(i_6 * 4)], &regs_b_2[(i_7 * 2)], &regs_c_1[((i_6 * 32) + (i_7 * 4))]);
    }
  }
  __syncthreads();
  half *regs_c_2 = regs_c;
  half *c_1 = c;
  int32_t warp_id = ((int)threadIdx.x / 32);
  int32_t lane_id_2 = ((int)threadIdx.x % 32);
  for (int32_t i_8 = 0; (i_8 < 4); i_8 = (i_8 + 1)) {
    for (int32_t i_9 = 0; (i_9 < 8); i_9 = (i_9 + 1)) {
      int32_t p_3 = 0;
      for (int32_t i_10 = 0; (i_10 < 2); i_10 = (i_10 + 1)) {
        for (int32_t i_11 = 0; (i_11 < 2); i_11 = (i_11 + 1)) {
          if ((((((warp_id / 2) * 64) + (i_8 * 16)) + ((i_10 * 8) + (lane_id_2 / 4))) < 2) && (((((warp_id % 2) * 64) + (i_9 * 8)) + (((lane_id_2 % 4) * 2) + i_11)) < 2)) {
            c_1[((((((warp_id / 2) * 64) + (i_8 * 16)) + ((i_10 * 8) + (lane_id_2 / 4))) * 2) + ((((warp_id % 2) * 64) + (i_9 * 8)) + (((lane_id_2 % 4) * 2) + i_11)))] = regs_c_2[(((i_8 * 32) + (i_9 * 4)) + p_3)];
          }
          p_3 = (p_3 + 1);
        }
      }
    }
  }
}

__host__ void hidet_launch(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
  assert(((void)"Expect 3 arguments", (num_args == 3)));
  assert(((void)"The 0-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[0] == 3)));
  assert(((void)"The 1-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[1] == 3)));
  assert(((void)"The 2-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[2] == 3)));
  hidet_batch_matmul_kernel<<<dim3(1, 1, 1), dim3(128, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((half*)(args[0])), ((half*)(args[1])), ((half*)(args[2])));
}

}

Summary

In this tutorial, we have shown how to use the template-based scheduling mechanism to add new operators. In short, we override the implement_cuda() or implement_cpu() method of the task class and implement the task to get an IR module. In this example, we used Hidet Script to implement the task, but other approaches, such as the IR builder, can also be used.
