Using Template-based Scheduling¶
In the previous tutorial, we learned how to define a new operator with rule-based scheduling. Rule-based scheduling is a convenient way to define a new operator, but it is not efficient enough for operators with a large amount of reduction. In this tutorial, we will learn how to define a new operator with template-based scheduling. Template-based scheduling allows us to define a tensor-program template, which is then instantiated for different input shapes and tunable hyper-parameters.
Override the implement_cuda() method¶
The Task class has two methods, implement_cpu() and implement_cuda(), that can be overridden when we define a new task.
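As a skeleton (a sketch only; implement_cpu() is assumed here to share the signature of the implement_cuda() override shown later in this tutorial), overriding these hooks looks like this:
from hidet.ir.task import Task
from hidet.ir.func import IRModule

class MyTask(Task):
    def implement_cpu(self, working_dir: str) -> IRModule:
        ...  # return an IRModule implementing the task on CPU

    def implement_cuda(self, working_dir: str) -> IRModule:
        ...  # return an IRModule implementing the task on CUDA
The full BatchMatmulFp16Task example below fills in both the computation definition and the CUDA override: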
import hidet
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
class BatchMatmulFp16Task(Task):
def __init__(self, a: TensorNode, b: TensorNode):
batch_size, m_size, k_size = a.const_shape()
batch_size, k_size, n_size = b.const_shape()
c = compute(
name='c',
shape=[batch_size, m_size, n_size],
fcompute=lambda p, i, j: reduce(
shape=[k_size],
fcompute=lambda k: a[p, i, k] * b[p, k, j],
reduce_type='sum',
),
)
super().__init__(
name='batch_matmul_fp16',
inputs=[a, b],
outputs=[c],
attributes={
'batch_size': batch_size,
'm_size': m_size,
'n_size': n_size,
'k_size': k_size,
},
)
def allow_epilogue(self) -> bool:
return False
def implement_cuda(self, working_dir: str) -> IRModule:
# override this method to use template-based scheduling
return batch_matmul_mma_fp16_schedule(self)
In the above task definition, we override the implement_cuda() method to use template-based scheduling. Inside the implement_cuda() method, we call the batch_matmul_mma_fp16_schedule() function to get a tensor program that implements the computation defined in the task.
Implement the tensor program¶
We can implement the batch_matmul_mma_fp16_schedule() function as follows. The function is fairly involved: to understand what it does, you need to know both CUDA programming and Hidet Script. Feel free to skip it for now.
Note
This function defines the tensor program with Hidet Script. Hidet Script is another domain-specific language in Hidet that allows developers to write tensor programs in Python syntax. We will add more documentation introducing Hidet Script once it becomes more stable.
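To get a feel for the syntax first, here is a minimal sketch (not part of the original example) that uses the same hidet.script_module() / @hidet.script pattern and the same hidet.lang primitives as the schedule function below:
import hidet
from hidet.lang import attr, f16
from hidet.lang.cuda import blockIdx, threadIdx

with hidet.script_module() as module:
    @hidet.script
    def vector_add(a: f16[1024], b: f16[1024], c: f16[1024]):
        """Each of the 1024 launched threads computes one element of c."""
        attr.cuda_grid_dim = 8      # 8 thread blocks
        attr.cuda_block_dim = 128   # 128 threads per block
        i = blockIdx.x * 128 + threadIdx.x
        c[i] = a[i] + b[i]

ir_module = module.ir_module()  # lower the script into an IRModule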
def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule:
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf, cast
from hidet.lang.mapping import repeat, spatial
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.lang.cuda import MmaConfig, mma_sync
# get the workload size
bs = task.attrs['batch_size']
m_size = task.attrs['m_size']
n_size = task.attrs['n_size']
k_size = task.attrs['k_size']
# define the template hyper-parameters
mma_config = MmaConfig.m16n8k8_f16_f16()
block_m, block_n, block_k = 128, 128, 8
warp_m, warp_n, warp_k = 64, 64, 8
warp_count_m, warp_count_n, warp_count_k = 2, 2, 1
mma_m, mma_n, mma_k = mma_config.m, mma_config.n, mma_config.k # 16, 8, 8
mma_count_m, mma_count_n, mma_count_k = 4, 8, 1
threads = warp_count_m * warp_count_n * warp_count_k * 32
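# Tiling hierarchy (explanatory comments, not in the original code): each thread
# block computes a block_m x block_n (128 x 128) tile of C, split across
# warp_count_m * warp_count_n * warp_count_k = 4 warps; each warp covers a
# warp_m x warp_n (64 x 64) tile by issuing mma_count_m x mma_count_n = 4 x 8
# m16n8k8 mma instructions per k step, giving threads = 4 * 32 = 128 per block.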
# define the tensor program
with hidet.script_module() as module:
@hidet.script
def load_regs_a(
smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements]
):
"""Load A registers from shared memory."""
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
warp_id
):
for mi in range(mma_count_m):
p = 0
for i, k in mma_config.a_load_map.on(lane_id):
regs_a[mi, p] = smem_a[
wi * warp_m + mi * mma_m + i, wk * warp_k + k
]
p += 1
@hidet.script
def load_regs_b(
smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements]
):
"""Load B registers from shared memory."""
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
warp_id
):
for mj in range(mma_count_n):
p = 0
for k, j in mma_config.b_load_map.on(lane_id):
regs_b[mj, p] = smem_b[
wk * warp_k + k, wj * warp_n + mj * mma_n + j
]
p += 1
@hidet.script
def warp_mma(
regs_a: f16[4, mma_config.a_elements],
regs_b: f16[8, mma_config.b_elements],
regs_c: f16[4, 8, mma_config.c_elements],
):
"""Perform warp-level matrix multiplication."""
for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
mma_sync(mma_config, ~regs_a[mi, 0], ~regs_b[mj, 0], ~regs_c[mi, mj, 0])
@hidet.script
def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size]):
"""Store C registers to global memory."""
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
gmem_c = c[blockIdx.z, offset_m:, offset_n:]
for k_round in range(warp_count_k):
for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(
warp_id
):
if wk == k_round:
for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
p = 0
for i, j in mma_config.c_store_map.on(lane_id):
gmem_c.write(
[
wi * warp_m + mi * mma_m + i,
wj * warp_n + mj * mma_n + j,
],
regs_c[mi, mj, p],
protected=True,
)
p += 1
@hidet.script
def batch_matmul_kernel(
a: f16[bs, m_size, k_size],
b: f16[bs, k_size, n_size],
c: f16[bs, m_size, n_size],
):
"""Batch matrix multiplication kernel."""
attr.cuda_grid_dim = (
(m_size + block_m - 1) // block_m,
(n_size + block_n - 1) // block_n,
bs,
)
attr.cuda_block_dim = threads
offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
smem_a = tensor('shared', 'float16', [block_m, block_k])
smem_b = tensor('shared', 'float16', [block_k, block_n])
regs_a = tensor('register', 'float16', [4, mma_config.a_elements])
regs_b = tensor('register', 'float16', [8, mma_config.b_elements])
regs_c = tensor('register', 'float16', [4, 8, mma_config.c_elements])
for i, j, p in grid(4, 8, mma_config.c_elements):
regs_c[i, j, p] = 0.0
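# Main loop over k tiles (comment added for clarity): each iteration loads one
# block_m x block_k tile of A and one block_k x block_n tile of B into shared
# memory, synchronizes, and lets every warp accumulate into its register tile.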
for k0 in range((k_size + block_k - 1) // block_k):
offset_k = k0 * block_k
gmem_a = a[blockIdx.z, offset_m:, offset_k:]
gmem_b = b[blockIdx.z, offset_k:, offset_n:]
for i, k in repeat(8, 1).spatial(16, 8).on(threadIdx.x):
smem_a[i, k] = gmem_a.read([i, k], protected=True)
for k, j in repeat(8, 1).spatial(1, 128).on(threadIdx.x):
smem_b[k, j] = gmem_b.read([k, j], protected=True)
syncthreads()
load_regs_a(smem_a, regs_a)
load_regs_b(smem_b, regs_b)
warp_mma(regs_a, regs_b, regs_c)
syncthreads()
store_c(regs_c, c)
ir_module = module.ir_module()
return ir_module
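As a quick sanity check of the launch configuration, we can evaluate the grid and block dimensions by hand for the 1 x 2 x 2 inputs used in the demo below; the result matches the dim3(1, 1, 1) grid and dim3(128, 1, 1) block that appear in the generated source at the end of this tutorial.
bs, m_size, n_size = 1, 2, 2
block_m, block_n = 128, 128
warp_count_m, warp_count_n, warp_count_k = 2, 2, 1

grid_dim = (
    (m_size + block_m - 1) // block_m,
    (n_size + block_n - 1) // block_n,
    bs,
)
block_dim = warp_count_m * warp_count_n * warp_count_k * 32
print(grid_dim, block_dim)  # (1, 1, 1) 128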
Define the operator¶
The remaining part is the same as in the rule-based scheduling method for adding a new operator.
from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like
class BatchMatmulFp16Op(Operator):
def __init__(self, a: Tensor, b: Tensor):
assert a.dtype == hidet.float16 and b.dtype == hidet.float16
super().__init__(
inputs=[a, b],
attributes={},
task=BatchMatmulFp16Task(input_like(a, 'a'), input_like(b, 'b')),
)
def batch_matmul_fp16(a: Tensor, b: Tensor) -> Tensor:
return BatchMatmulFp16Op(a, b).get_output(0)
def demo_usage():
a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
c = batch_matmul_fp16(a, b)
print(a)
print(b)
print(c)
demo_usage()
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[ 0.31 1.1 ]
[ 0.72 -0.45]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[-1.12 0.42]
[ 1.34 -0.24]]]
Tensor(shape=(1, 2, 2), dtype='float16', device='cuda:0')
[[[ 1.13 -0.14]
[-1.4 0.41]]]
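To sanity-check the result numerically, one option (a sketch, not part of the original example) is to compare against a float32 matrix product computed on the host with NumPy, using a loose tolerance because the kernel accumulates in float16:
import numpy as np

def check_batch_matmul_fp16():
    a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    c = batch_matmul_fp16(a, b)
    # float32 reference computed on the host
    c_ref = np.matmul(
        a.cpu().numpy().astype(np.float32),
        b.cpu().numpy().astype(np.float32),
    )
    np.testing.assert_allclose(
        c.cpu().numpy().astype(np.float32), c_ref, atol=1e-2, rtol=1e-2
    )

check_batch_matmul_fp16()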
Generated Source Code¶
If you are interested in the generated source code, here it is:
# we hide the code to get the source path for simplicity
print('Generated source path (relative to hidet cache root): \n{}'.format(relative_path))
print()
print('Generated source code:')
with open(source_path, 'r') as f:
print(f.read())
Generated source path (relative to hidet cache root):
docs-cache/ops/cuda_space_0/batch_matmul_fp16/6af8f0282257d7b7/source.cu
Generated source code:
#include <stdint.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
extern "C" {
__device__ __forceinline__ void hidet_cuda_mma_sync_aligned_m16n8k8_row_col_f16_f16_f16_f16(half * __restrict__ a, half * __restrict__ b, half * __restrict__ c) {
uint32_t *ra;
uint32_t *rb;
uint32_t *rc;
ra = ((uint32_t*)(a));
rb = ((uint32_t*)(b));
rc = ((uint32_t*)(c));
asm ("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" : "+r"(rc[0]), "+r"(rc[1]) : "r"(ra[0]), "r"(ra[1]), "r"(rb[0]));
}
__global__ void __launch_bounds__(128) hidet_batch_matmul_kernel(half * __restrict__ a, half * __restrict__ b, half * __restrict__ c) {
__shared__ half smem_a[1024];
__shared__ half smem_b[1024];
half regs_a[16];
half regs_b[16];
half regs_c[128];
for (int32_t i = 0; (i < 4); i = (i + 1)) {
for (int32_t j = 0; (j < 8); j = (j + 1)) {
for (int32_t p = 0; (p < 4); p = (p + 1)) {
regs_c[(((i * 32) + (j * 4)) + p)] = ((half)(0.0f));
}
}
}
for (int32_t i_1 = 0; (i_1 < 8); i_1 = (i_1 + 1)) {
smem_a[((((i_1 * 16) + ((int)threadIdx.x / 8)) * 8) + ((int)threadIdx.x % 8))] = (((((i_1 * 16) + ((int)threadIdx.x / 8)) < 2) && (((int)threadIdx.x % 8) < 2)) ? a[((((i_1 * 16) + ((int)threadIdx.x / 8)) * 2) + ((int)threadIdx.x % 8))] : half(0.0f));
}
for (int32_t i_2 = 0; (i_2 < 8); i_2 = (i_2 + 1)) {
smem_b[((i_2 * 128) + (int)threadIdx.x)] = (((i_2 < 2) && ((int)threadIdx.x < 2)) ? b[((i_2 * 2) + (int)threadIdx.x)] : half(0.0f));
}
__syncthreads();
half *smem_a_1 = smem_a;
half *regs_a_1 = regs_a;
int32_t lane_id = ((int)threadIdx.x % 32);
for (int32_t mi = 0; (mi < 4); mi = (mi + 1)) {
int32_t p_1 = 0;
for (int32_t i_3 = 0; (i_3 < 2); i_3 = (i_3 + 1)) {
for (int32_t i_4 = 0; (i_4 < 2); i_4 = (i_4 + 1)) {
regs_a_1[((mi * 4) + p_1)] = smem_a_1[((((((((int)threadIdx.x / 32) / 2) * 64) + (mi * 16)) + ((i_3 * 8) + (lane_id / 4))) * 8) + (((lane_id % 4) * 2) + i_4))];
p_1 = (p_1 + 1);
}
}
}
half *smem_b_1 = smem_b;
half *regs_b_1 = regs_b;
int32_t lane_id_1 = ((int)threadIdx.x % 32);
for (int32_t mj = 0; (mj < 8); mj = (mj + 1)) {
int32_t p_2 = 0;
for (int32_t i_5 = 0; (i_5 < 2); i_5 = (i_5 + 1)) {
regs_b_1[((mj * 2) + p_2)] = smem_b_1[(((((lane_id_1 % 4) * 2) + i_5) * 128) + ((((((int)threadIdx.x / 32) % 2) * 64) + (mj * 8)) + (lane_id_1 / 4)))];
p_2 = (p_2 + 1);
}
}
half *regs_a_2 = regs_a;
half *regs_b_2 = regs_b;
half *regs_c_1 = regs_c;
for (int32_t i_6 = 0; (i_6 < 4); i_6 = (i_6 + 1)) {
for (int32_t i_7 = 0; (i_7 < 8); i_7 = (i_7 + 1)) {
hidet_cuda_mma_sync_aligned_m16n8k8_row_col_f16_f16_f16_f16(&regs_a_2[(i_6 * 4)], &regs_b_2[(i_7 * 2)], &regs_c_1[((i_6 * 32) + (i_7 * 4))]);
}
}
__syncthreads();
half *regs_c_2 = regs_c;
half *c_1 = c;
int32_t warp_id = ((int)threadIdx.x / 32);
int32_t lane_id_2 = ((int)threadIdx.x % 32);
for (int32_t i_8 = 0; (i_8 < 4); i_8 = (i_8 + 1)) {
for (int32_t i_9 = 0; (i_9 < 8); i_9 = (i_9 + 1)) {
int32_t p_3 = 0;
for (int32_t i_10 = 0; (i_10 < 2); i_10 = (i_10 + 1)) {
for (int32_t i_11 = 0; (i_11 < 2); i_11 = (i_11 + 1)) {
if ((((((warp_id / 2) * 64) + (i_8 * 16)) + ((i_10 * 8) + (lane_id_2 / 4))) < 2) && (((((warp_id % 2) * 64) + (i_9 * 8)) + (((lane_id_2 % 4) * 2) + i_11)) < 2)) {
c_1[((((((warp_id / 2) * 64) + (i_8 * 16)) + ((i_10 * 8) + (lane_id_2 / 4))) * 2) + ((((warp_id % 2) * 64) + (i_9 * 8)) + (((lane_id_2 % 4) * 2) + i_11)))] = regs_c_2[(((i_8 * 32) + (i_9 * 4)) + p_3)];
}
p_3 = (p_3 + 1);
}
}
}
}
}
__host__ void hidet_launch(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
assert(((void)"Expect 3 arguments", (num_args == 3)));
assert(((void)"The 0-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[0] == 3)));
assert(((void)"The 1-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[1] == 3)));
assert(((void)"The 2-th argument should be tensor_pointer(float16, [1, 2, 2])", (arg_types[2] == 3)));
hidet_batch_matmul_kernel<<<dim3(1, 1, 1), dim3(128, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((half*)(args[0])), ((half*)(args[1])), ((half*)(args[2])));
}
}
Summary¶
In this tutorial, we have shown how to use the template-based scheduling mechanism to add new operators. In short, we override the implement_cuda() or implement_cpu() method of the task class and implement the task to get an IR module. In this example we used Hidet Script to implement the task, but other approaches, such as the IR builder, can also be used.
Total running time of the script: ( 0 minutes 1.767 seconds)