Optimize PyTorch Model

Hidet provides a backend for PyTorch dynamo to optimize PyTorch models. To use this backend, specify 'hidet' as the backend when calling torch.compile():

# optimize the model with the 'hidet' backend
model_hidet = torch.compile(model, backend='hidet')
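
For a complete, runnable sketch (torchvision's resnet18 is used here as a stand-in; any nn.Module works the same way):

import torch
import hidet  # importing hidet makes the 'hidet' backend available to torch.compile

# load a pretrained model and a random input (shapes are illustrative)
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True).cuda().eval()
x = torch.randn(1, 3, 224, 224).cuda()

# optimize the model with the hidet backend
model_hidet = torch.compile(model, backend='hidet')

with torch.no_grad():
    y = model_hidet(x)  # the first call triggers the compilation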

Note

Currently, all operators in hidet are generated by hidet itself, with no dependency on kernel libraries such as cuDNN or cuBLAS. In the future, we might support lowering some operators to these libraries if they perform better.

Under the hood, hidet converts the PyTorch model to hidet's graph representation and optimizes the computation graph (e.g., sub-graph rewriting, fusion, and constant folding). After that, each operator is lowered through hidet's scheduling system to generate the final kernel.

Hidet provides some configurations to control the hidet backend of torch dynamo.
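
All of these options are set through hidet.torch.dynamo_config. For reference, a typical setup combining the options described below might look like this (the values are illustrative, not recommendations):

import torch
import hidet

# configure the hidet backend; each option is explained in the following subsections
hidet.torch.dynamo_config.search_space(2)       # search a larger schedule space
hidet.torch.dynamo_config.use_fp16(True)        # run the model in float16
hidet.torch.dynamo_config.use_cuda_graph(True)  # dispatch kernels with CUDA Graph (the default)

model_opt = torch.compile(model, backend='hidet')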

Search in a larger search space

Some operators are compute-intensive, and their schedules are critical to performance. To achieve the best performance for given input shapes, we usually need to search a schedule space to find the best schedule. However, searching a larger schedule space takes more time to optimize the model. By default, hidet uses the default schedule to generate kernels for all input shapes. To search a larger schedule space for better performance, configure the search space via search_space():

# There are three search spaces:
# 0 - use default schedule, no search [Default]
# 1 - search in a small schedule space (usually 1~30 schedules)
# 2 - search in a large schedule space (usually more than 30 schedules)
hidet.torch.dynamo_config.search_space(2)

# After configuring the search space, optimize the model as usual
model_opt = torch.compile(model, backend='hidet')

# The actual search happens on the first run, when the input shapes are known
outputs = model_opt(inputs)

Please note that the search space set through search_space() is read and used when we first run the model, not when we call torch.compile().
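
Because the search runs on the first call, the first inference can be much slower than later ones. A small timing sketch (using standard PyTorch utilities; model_opt and inputs are from the snippet above) makes this visible:

import time
import torch

torch.cuda.synchronize()
start = time.time()
outputs = model_opt(inputs)   # first run: hidet searches the schedule space here
torch.cuda.synchronize()
print(f'first run (includes tuning): {time.time() - start:.2f} s')

torch.cuda.synchronize()
start = time.time()
outputs = model_opt(inputs)   # later runs only execute the tuned kernels
torch.cuda.synchronize()
print(f'second run: {time.time() - start:.4f} s')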

Check the correctness

It is important to make sure the optimized model is correct. Hidet provides a configuration to print the numerical difference between each hidet-generated operator and the corresponding original PyTorch operator. You can enable it via correctness_report():

# enable the correctness checking
hidet.torch.dynamo_config.correctness_report()

After enabling the correctness report, every time a new graph is received for compilation, hidet will print the numerical differences, computed on dummy inputs (for now, torch dynamo does not expose the actual inputs to backends, so we cannot use them). Let's take the resnet18 model as an example:

import torch
import hidet

# a random input and a pretrained resnet18 model on the GPU
x = torch.randn(1, 3, 224, 224).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.cuda().eval()

with torch.no_grad():
    # enable the correctness report before compiling the model
    hidet.torch.dynamo_config.correctness_report()
    model_opt = torch.compile(model, backend='hidet')
    model_opt(x)

Output:

    kind           operator                                                                          dtype    error    attention
--  -------------  --------------------------------------------------------------------------------  -------  -------  -----------
0   placeholder                                                                                      float32  0.0e+00
1   call_module    Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)      float32  0.0e+00
2   call_module    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)   float32  1.6e-07
3   call_module    ReLU(inplace=True)                                                                float32  1.2e-07
4   call_module    MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)        float32  1.2e-07
5   call_module    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)     float32  9.2e-07
6   call_module    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)   float32  5.3e-07
7   call_module    ReLU(inplace=True)                                                                float32  4.8e-07
8   call_module    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)     float32  4.3e-07
9   call_module    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)   float32  1.0e-06
10  call_function  operator.iadd                                                                     float32  1.0e-06
11  call_module    ReLU(inplace=True)                                                                float32  1.0e-06
12  call_module    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)     float32  1.2e-06
13  call_module    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)   float32  8.1e-07
14  call_module    ReLU(inplace=True)                                                                float32  7.0e-07
15  call_module    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)     float32  4.9e-07
16  call_module    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)   float32  1.4e-06
17  call_function  operator.iadd                                                                     float32  1.5e-06
18  call_module    ReLU(inplace=True)                                                                float32  1.5e-06
19  call_module    Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)    float32  1.6e-06
20  call_module    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  8.1e-07
21  call_module    ReLU(inplace=True)                                                                float32  7.0e-07
22  call_module    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  6.3e-07
23  call_module    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.1e-06
24  call_module    Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)                    float32  6.0e-07
25  call_module    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  8.0e-07
26  call_function  operator.iadd                                                                     float32  1.2e-06
27  call_module    ReLU(inplace=True)                                                                float32  1.2e-06
28  call_module    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  1.1e-06
29  call_module    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.1e-06
30  call_module    ReLU(inplace=True)                                                                float32  1.1e-06
31  call_module    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  5.7e-07
32  call_module    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.3e-06
33  call_function  operator.iadd                                                                     float32  1.6e-06
34  call_module    ReLU(inplace=True)                                                                float32  1.4e-06
35  call_module    Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)   float32  1.1e-06
36  call_module    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.1e-06
37  call_module    ReLU(inplace=True)                                                                float32  8.3e-07
38  call_module    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  6.0e-06
39  call_module    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  5.0e-06
40  call_module    Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)                   float32  3.8e-07
41  call_module    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  3.9e-07
42  call_function  operator.iadd                                                                     float32  4.2e-06
43  call_module    ReLU(inplace=True)                                                                float32  4.0e-06
44  call_module    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  7.0e-06
45  call_module    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  5.0e-06
46  call_module    ReLU(inplace=True)                                                                float32  4.9e-06
47  call_module    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  2.3e-06
48  call_module    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  4.7e-06
49  call_function  operator.iadd                                                                     float32  4.7e-06
50  call_module    ReLU(inplace=True)                                                                float32  4.7e-06
51  call_module    Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)   float32  2.0e-06
52  call_module    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.4e-06
53  call_module    ReLU(inplace=True)                                                                float32  1.1e-06
54  call_module    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  8.6e-07
55  call_module    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  2.6e-06
56  call_module    Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)                   float32  1.6e-06
57  call_module    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  3.2e-06
58  call_function  operator.iadd                                                                     float32  3.3e-06
59  call_module    ReLU(inplace=True)                                                                float32  2.5e-06
60  call_module    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  2.0e-06
61  call_module    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  2.2e-06
62  call_module    ReLU(inplace=True)                                                                float32  2.0e-06
63  call_module    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)   float32  8.7e-07
64  call_module    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  float32  1.1e-05
65  call_function  operator.iadd                                                                     float32  1.1e-05
66  call_module    ReLU(inplace=True)                                                                float32  1.0e-05
67  call_module    AdaptiveAvgPool2d(output_size=(1, 1))                                             float32  1.5e-06
68  call_function  torch.flatten                                                                     float32  1.5e-06
69  call_module    Linear(in_features=512, out_features=1000, bias=True)                             float32  3.0e-06
70  output                                                                                           float32  3.0e-06

The correctness report prints the harmonic mean of the absolute error and relative error for each operator:

\[e_h = \frac{|actual - expected|}{|expected| + 1} \quad \left(\frac{1}{e_h} = \frac{1}{e_a} + \frac{1}{e_r}\right)\]

where \(actual\) and \(expected\) are the actual and expected results of the operator, respectively, and \(e_a\) and \(e_r\) are the absolute and relative errors.

Tip

Usually, we can expect:

  • for float32: \(e_h \leq 10^{-5}\), and

  • for float16: \(e_h \leq 10^{-2}\).
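
As a quick illustration of this metric (not part of hidet's API), the quantities can be computed directly for a single value and the identity above checked numerically:

import torch

actual = torch.tensor(1.0002)
expected = torch.tensor(1.0000)

e_a = (actual - expected).abs()                    # absolute error
e_r = (actual - expected).abs() / expected.abs()   # relative error
e_h = (actual - expected).abs() / (expected.abs() + 1)

# 1/e_h == 1/e_a + 1/e_r, so e_h is never larger than either error
print(e_h.item(), 1.0 / (1.0 / e_a + 1.0 / e_r).item())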

Operator configurations

Use CUDA Graph to dispatch kernels

Hidet provides a configuration to use CUDA Graph to dispatch kernels. CUDA Graph is a CUDA feature (available since CUDA 10) that allows us to record kernel launches and replay them later, which is useful when the same sequence of kernels is dispatched many times. Hidet enables CUDA Graph by default. You can disable it via use_cuda_graph():

# disable CUDA Graph
hidet.torch.dynamo_config.use_cuda_graph(False)

in case you want to use PyTorch's own CUDA Graph feature instead.
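
If you disable hidet's CUDA Graph dispatch, you can still capture the compiled model with PyTorch's own CUDA Graph API. A sketch following the capture pattern from the PyTorch documentation (warm-up iterations and static input/output buffers are required; model is assumed to live on the GPU):

import torch
import hidet

hidet.torch.dynamo_config.use_cuda_graph(False)
model_opt = torch.compile(model, backend='hidet')

static_x = torch.randn(1, 3, 224, 224).cuda()

# warm up (and trigger hidet's compilation) on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model_opt(static_x)
torch.cuda.current_stream().wait_stream(s)

# capture one forward pass into a CUDA graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = model_opt(static_x)

# replay: copy new data into the static input buffer and relaunch the recorded kernels
static_x.copy_(torch.randn(1, 3, 224, 224).cuda())
g.replay()
print(static_y.shape)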

Use low-precision data type

Hidet provides a configuration to use low-precision data types. By default, hidet uses the same data type as the original PyTorch model. You can configure it via use_fp16() and use_fp16_reduction():

# automatically transform the model to use float16 data type
hidet.torch.dynamo_config.use_fp16(True)

# use float16 as the accumulation data type in operators with reduction
hidet.torch.dynamo_config.use_fp16_reduction(True)

You do not need to change the inputs fed to the model; hidet will automatically cast the inputs to the configured data type in the optimized model.
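
Since float16 reduces numerical precision, it can be worth comparing the optimized model against the original float32 model with a relaxed tolerance (the thresholds below are illustrative; model and x are assumed to be the GPU model and input from the earlier examples):

import torch
import hidet

hidet.torch.dynamo_config.use_fp16(True)
model_opt = torch.compile(model, backend='hidet')

with torch.no_grad():
    y_ref = model(x)       # original float32 model
    y_opt = model_opt(x)   # hidet-optimized float16 model; inputs are cast automatically

# compare in float32 with a tolerance appropriate for float16
print(torch.allclose(y_ref, y_opt.to(y_ref.dtype), rtol=1e-2, atol=1e-2))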