Visualize Flow Graph

Visualization is a key component of any machine learning tool: it gives us a better understanding of the model.

We customized the popular Netron viewer to visualize the flow graph of a hidet model. The customized Netron viewer can be found here; a link is also available at the bottom of the documentation sidebar.

In this tutorial, we will show you how to visualize the flow graph of a model.

Define model

We first define a model with a self-attention layer.

import math
import hidet
from hidet import Tensor
from hidet.graph import nn, ops


class SelfAttention(nn.Module):
    def __init__(self, hidden_size=768, num_attention_heads=12):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads

        self.query_layer = nn.Linear(hidden_size, hidden_size)
        self.key_layer = nn.Linear(hidden_size, hidden_size)
        self.value_layer = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        batch_size, seq_length, hidden_size = x.shape
        x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size])
        x = x.rearrange([[0, 2], [1], [3]])
        return x  # [batch_size * num_attention_heads, seq_length, attention_head_size]

    def forward(self, hidden_states: Tensor, attention_mask: Tensor):
        batch_size, seq_length, _ = hidden_states.shape
        query = self.transpose_for_scores(self.query_layer(hidden_states))
        key = self.transpose_for_scores(self.key_layer(hidden_states))
        value = self.transpose_for_scores(self.value_layer(hidden_states))
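        # scaled dot-product attention: scores = Q @ K^T / sqrt(head_size)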
        attention_scores = ops.matmul(query, ops.transpose(key, [-1, -2])) / math.sqrt(
            self.attention_head_size
        )
        attention_scores = attention_scores + attention_mask
        attention_probs = ops.softmax(attention_scores, axis=-1)
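        # attention-weighted sum over the value vectors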
        context = ops.matmul(attention_probs, value)
        context = context.reshape(
            [batch_size, self.num_attention_heads, seq_length, self.attention_head_size]
        )
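        # restore the [batch_size, seq_length, hidden_size] layout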
        context = context.rearrange([[0], [2], [1, 3]])
        return context


model = SelfAttention()
print(model)
SelfAttention(
  (query_layer): Linear(in_features=768, out_features=768)
  (key_layer): Linear(in_features=768, out_features=768)
  (value_layer): Linear(in_features=768, out_features=768)
)

Generate flow graph

Then we generate the flow graph of the model by tracing it on example inputs.

graph = model.flow_graph_for(
    inputs=[hidet.randn([1, 128, 768]), hidet.ones([1, 128], dtype='int32')]
)
print(graph)
Graph(x: float32[1, 128, 768][cpu], x_1: int32[1, 128][cpu]){
  c = Constant(float32[768, 768][cpu])
  c_1 = Constant(float32[768][cpu])
  c_2 = Constant(float32[768, 768][cpu])
  c_3 = Constant(float32[768][cpu])
  c_4 = Constant(float32[768, 768][cpu])
  c_5 = Constant(float32[768][cpu])
  x_2: float32[1, 128, 768][cpu] = Matmul(x, c, require_prologue=False)
  x_3: float32[1, 128, 768][cpu] = Add(x_2, c_1)
  x_4: float32[1, 128, 12, 64][cpu] = Reshape(x_3, shape=[1, 128, 12, 64])
  x_5: float32[12, 128, 64][cpu] = Rearrange(x_4, plan=[[0, 2], [1], [3]])
  x_6: float32[1, 128, 768][cpu] = Matmul(x, c_2, require_prologue=False)
  x_7: float32[1, 128, 768][cpu] = Add(x_6, c_3)
  x_8: float32[1, 128, 12, 64][cpu] = Reshape(x_7, shape=[1, 128, 12, 64])
  x_9: float32[12, 128, 64][cpu] = Rearrange(x_8, plan=[[0, 2], [1], [3]])
  x_10: float32[12, 64, 128][cpu] = PermuteDims(x_9, axes=[0, 2, 1])
  x_11: float32[12, 128, 128][cpu] = Matmul(x_5, x_10, require_prologue=False)
  x_12: float32[12, 128, 128][cpu] = DivideScalar(x_11, scalar=8.0f)
  x_13: float32[12, 128, 128][cpu] = Add(x_12, x_1)
  x_14: float32[12, 128, 128][cpu] = Softmax(x_13, axis=2)
  x_15: float32[1, 128, 768][cpu] = Matmul(x, c_4, require_prologue=False)
  x_16: float32[1, 128, 768][cpu] = Add(x_15, c_5)
  x_17: float32[1, 128, 12, 64][cpu] = Reshape(x_16, shape=[1, 128, 12, 64])
  x_18: float32[12, 128, 64][cpu] = Rearrange(x_17, plan=[[0, 2], [1], [3]])
  x_19: float32[12, 128, 64][cpu] = Matmul(x_14, x_18, require_prologue=False)
  x_20: float32[1, 12, 128, 64][cpu] = Reshape(x_19, shape=[1, 12, 128, 64])
  x_21: float32[1, 128, 768][cpu] = Rearrange(x_20, plan=[[0], [2], [1, 3]])
  return x_21
}
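
Before visualizing, we can run the traced graph to confirm that it executes and produces the expected output shape. This is a minimal sketch; it assumes a FlowGraph can be called directly with input tensors, as elsewhere in the hidet documentation.

x = hidet.randn([1, 128, 768])
mask = hidet.ones([1, 128], dtype='int32')
output = graph(x, mask)  # run the flow graph with concrete inputs
print(output.shape)      # expect [1, 128, 768], matching x_21 above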

Dump netron graph

To visualize the flow graph, we dump the graph structure to a JSON file using the hidet.utils.netron.dump() function.

from hidet.utils import netron

with open('attention-graph.json', 'w') as f:
    netron.dump(graph, f)

The above code generates a JSON file named attention-graph.json.

You can download the generated attention-graph.json and open it with the customized Netron viewer.
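
Since the dump is plain JSON, we can also sanity-check the file from Python before loading it into the viewer. This is a minimal sketch using only the standard library; the exact schema is whatever netron.dump() emits.

import json

# parse the dumped file to verify it is well-formed json
with open('attention-graph.json') as f:
    graph_json = json.load(f)
print(type(graph_json))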

Visualize optimization intermediate graphs

Hidet also provides a way to visualize the intermediate graphs of the optimization passes.

To get the JSON files for the intermediate graphs, we need to add an instrument to the pass context that dumps the graph before optimizing it. We can use the PassContext.save_graph_instrument() method to do that.

with hidet.graph.PassContext() as ctx:
    # print the time cost of each pass
    ctx.profile_pass_instrument(print_stdout=True)

    # save the intermediate graph of each pass to './outs' directory
    ctx.save_graph_instrument(out_dir='./outs')

    # run the optimization passes
    graph_opt = hidet.graph.optimize(graph)
ConvChannelLastPass started...
ConvChannelLastPass 0.007 seconds
SubgraphRewritePass started...
SubgraphRewritePass 0.015 seconds
AutoMixPrecisionPass started...
AutoMixPrecisionPass 0.007 seconds
SelectiveQuantizePass started...
  SubgraphRewritePass started...
  SubgraphRewritePass 0.015 seconds
SelectiveQuantizePass 0.022 seconds
ResolveVariantPass started...
ResolveVariantPass 0.008 seconds
FuseOperatorPass started...
FuseOperatorPass 0.023 seconds
EliminateBarrierPass started...
EliminateBarrierPass 0.005 seconds

The above code generates a directory named ./outs that contains the JSON files for the intermediate graphs.
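
As a quick check, we can list the files the instrument wrote. This is a minimal sketch using only the standard library; the exact file names are determined by hidet's save-graph instrument.

import os

# each json file under './outs' holds one intermediate graph
for name in sorted(os.listdir('./outs')):
    print(name)

The optimized graph: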

print(graph_opt)
Graph(x: float32[1, 128, 768][cpu], x_1: int32[1, 128][cpu]){
  c = Constant(float32[768, 768][cpu])
  c_1 = Constant(float32[768][cpu])
  c_2 = Constant(float32[768, 768][cpu])
  c_3 = Constant(float32[768][cpu])
  c_4 = Constant(float32[768, 768][cpu])
  c_5 = Constant(float32[768][cpu])
  x_2: float32[12, 128, 64][cpu] = FusedMatmul(x, c, c_1, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange), anchor=0)
  x_3: float32[12, 64, 128][cpu] = FusedMatmul(x, c_2, c_3, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange, PermuteDims), anchor=0)
  x_4: float32[12, 128, 128][cpu] = FusedMatmul(x_2, x_3, x_1, fused_graph=FlowGraph(Matmul, DivideScalar, Add), anchor=0)
  x_5: float32[12, 128, 128][cpu] = Softmax(x_4, axis=2)
  x_6: float32[12, 128, 64][cpu] = FusedMatmul(x, c_4, c_5, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange), anchor=0)
  x_7: float32[1, 128, 768][cpu] = FusedMatmul(x_5, x_6, fused_graph=FlowGraph(Matmul, Reshape, Rearrange), anchor=0)
  return x_7
}
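
To see what fusion bought us, we can compare the operator counts of the two graphs. A minimal sketch, assuming FlowGraph exposes its operators through the nodes attribute:

print('operators before optimization:', len(graph.nodes))
print('operators after optimization: ', len(graph_opt.nodes))

From the listings above, the 20 operators of the original graph are fused into 6.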

Summary

This tutorial shows how to visualize the flow graph of a model and the intermediate graphs of the optimization passes.
