.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/developer-guides/add-subgraph-rewrite-rule.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_gallery_developer-guides_add-subgraph-rewrite-rule.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_developer-guides_add-subgraph-rewrite-rule.py:

Add Sub-Graph Rewrite Rule
==========================

This tutorial shows how to add a sub-graph rewrite rule to the graph optimization pipeline. Sub-graph rewriting is an
important technique in graph optimization: it replaces a sub-graph with another sub-graph that is usually more
efficient than the original one. For example, we can replace a sub-graph that concatenates the outputs of two matrix
multiplications sharing the same input with a single matrix multiplication on the concatenated weights:

.. figure:: /_static/img/subgraph-rewrite-example.svg
    :align: center
    :scale: 70%

    The sub-graph rewrite rule that fuses two matrix multiplications.

.. seealso::
    :class: margin

    TASO systematically studies the sub-graph rewrite optimization
    for deep learning workloads.

After the rewrite, the graph becomes more efficient: we only need to launch a single kernel, and the `fused` matrix
multiplication usually exposes more parallelism to utilize the underlying hardware. We can also fuse multiple
convolutions into a single one, or perform other sub-graph rewrites.

Sub-graph rewrite in Hidet
--------------------------

In Hidet, we use a *sub-graph rewrite rule* to describe the rewrite. A sub-graph rewrite rule contains two parts:

- **Sub-graph pattern**: a sub-graph pattern that we use to match sub-graphs in the graph. The pattern is a directed
  acyclic graph (DAG) where each node is an operator pattern and each edge is a tensor pattern. We only specify the
  operator type for each node and whether the (input) tensors are symbolic or concrete.
- **Target sub-graph constructor**: when we find a sub-graph that matches the pattern, we use the constructor to
  construct the target sub-graph that replaces the matched sub-graph. When constructing the target sub-graph, we can
  use the matched tensors and nodes to further determine whether the rewrite is applicable. If applicable, the
  constructor returns the target sub-graph; otherwise, it returns ``None``.
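The division of labor between the two parts can be sketched in plain Python. This is a hypothetical, stripped-down interface for illustration only, not hidet's actual classes: the pattern is reduced to a predicate over a toy sub-graph representation, and the constructor returns the replacement or ``None`` when the rewrite does not apply.

```python
from typing import Optional


class ToyRewriteRule:
    # Hypothetical sketch: a rule pairs a sub-graph pattern with a target constructor.
    def matches(self, subgraph: dict) -> bool:
        # "Pattern": two matmul operators that share the same input tensor.
        ops = subgraph.get("ops", [])
        return (
            len(ops) == 2
            and all(op["type"] == "matmul" for op in ops)
            and ops[0]["input"] == ops[1]["input"]
        )

    def target(self, subgraph: dict) -> Optional[dict]:
        # "Constructor": build the fused replacement, or return None if not applicable.
        if not self.matches(subgraph):
            return None
        a, b = subgraph["ops"]
        return {
            "type": "matmul",
            "input": a["input"],
            "weight": ("concat", a["weight"], b["weight"]),
        }


rule = ToyRewriteRule()
g = {"ops": [{"type": "matmul", "input": "x", "weight": "w1"},
             {"type": "matmul", "input": "x", "weight": "w2"}]}
fused = rule.target(g)  # the two matmuls collapse into one over concatenated weights
```

The real hidet interface, shown later in this tutorial, expresses the pattern declaratively (as a DAG of tensor and operator patterns) rather than as a hand-written predicate, but the pattern/constructor split is the same.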
In the above example, the sub-graph pattern contains three input tensors, where ``x1`` is a symbolic tensor and
``w1``, ``w2`` are two concrete tensors (i.e., we know the concrete values of ``w1`` and ``w2``). There are three
operators in the pattern: the first two are matrix multiplications and the last one is a concatenation. The output
tensor of the concatenation is the output tensor of the pattern. When we find a sub-graph that matches the pattern,
we use the constructor to construct the target sub-graph and replace the matched sub-graph with it.
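Before implementing the rule, we can sanity-check the algebra behind it with a few lines of plain Python (no hidet required; the small 2x2 inputs are arbitrary): two matrix multiplications followed by a concatenation along the last axis equal one matrix multiplication with the weights concatenated.

```python
def matmul(a, b):
    # Naive matrix multiplication over nested lists: a is m x k, b is k x n.
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def concat_cols(a, b):
    # Concatenate two matrices along the last (column) axis.
    return [ra + rb for ra, rb in zip(a, b)]

x = [[1.0, 2.0], [3.0, 4.0]]
c1 = [[1.0], [0.0]]
c2 = [[0.0], [1.0]]

lhs = concat_cols(matmul(x, c1), matmul(x, c2))  # source: two matmuls + concat
rhs = matmul(x, concat_cols(c1, c2))             # target: one fused matmul
assert lhs == rhs
```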
.. note::

    **Difference between sub-graph rewrite and operator resolving**. Although
    :ref:`operator resolving` can be conceptually considered a special case of sub-graph rewrite, the two are
    implemented with different mechanisms and run at different points in the pipeline. First, operator resolving can
    be performed efficiently, so we can add an arbitrary number of operator resolve rules, whereas sub-graph
    rewriting is usually more expensive. Second, we run the sub-graph rewrite pass before the operator resolving
    pass, so that we can use generic operators in the sub-graph patterns without worrying about operator resolving.

Add a sub-graph rewrite rule
----------------------------

Let's implement the sub-graph rewrite rule shown in the example above. Before we start, we first create a new model
that contains the sub-graph we want to rewrite:

.. GENERATED FROM PYTHON SOURCE LINES 63-85

.. code-block:: default

    from typing import Optional, List

    import hidet
    from hidet import Tensor, FlowGraph, Operator
    from hidet import ops
    from hidet.graph.transforms.graph_patterns import MatchDict


    def example_model(x: Tensor, w0: Tensor, w1: Tensor, w2: Tensor):
        x = ops.matmul(x, w0)
        y1 = ops.matmul(x, w1)
        y2 = ops.matmul(x, w2)
        y = ops.relu(ops.concat([y1, y2], axis=1))
        return y


    x = hidet.symbol([3, 3])
    w0, w1, w2 = hidet.randn([3, 3]), hidet.randn([3, 3]), hidet.randn([3, 3])
    y = example_model(x, w0, w1, w2)
    graph: FlowGraph = hidet.trace_from(y, inputs=[x])
    print(graph)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Graph(x: float32[3, 3][cpu]){
      c = Constant(float32[3, 3][cpu])
      c_1 = Constant(float32[3, 3][cpu])
      c_2 = Constant(float32[3, 3][cpu])
      x_1: float32[3, 3][cpu] = Matmul(x, c, require_prologue=False)
      x_2: float32[3, 3][cpu] = Matmul(x_1, c_1, require_prologue=False)
      x_3: float32[3, 3][cpu] = Matmul(x_1, c_2, require_prologue=False)
      x_4: float32[3, 6][cpu] = Concat(x_2, x_3, axis=1)
      x_5: float32[3, 6][cpu] = Relu(x_4)
      return x_5
    }

.. GENERATED FROM PYTHON SOURCE LINES 86-88

Then, we define and register the sub-graph rewrite rule.

.. GENERATED FROM PYTHON SOURCE LINES 88-135

.. code-block:: default

    from hidet.graph.ops.matmul import MatmulOp
    from hidet.graph.ops.transform import ConcatOp
    from hidet.graph.transforms import TensorPattern, SubgraphRewriteRule
    from hidet.graph.transforms import op_pattern, register_rewrite_rule
    from hidet.utils import same_list


    # register the rewrite rule; only registered rewrite rules will be applied
    @register_rewrite_rule
    class FuseTwoMatmulRewriteRule(SubgraphRewriteRule):
        def __init__(self):
            super().__init__(name="new: [matmul(x, c1), matmul(x,c2)] => matmul(x, [c1, c2])")
            self.x = TensorPattern()  # x can match either a symbolic or concrete tensor
            self.c1 = TensorPattern(is_const=True)  # c1 can only match a concrete tensor
            self.c2 = TensorPattern(is_const=True)  # c2 can only match a concrete tensor
            self.y1 = op_pattern(MatmulOp, [self.x, self.c1])  # pattern: y1 = matmul(x, c1)
            self.y2 = op_pattern(MatmulOp, [self.x, self.c2])  # pattern: y2 = matmul(x, c2)
            self.y = op_pattern(ConcatOp, [self.y1, self.y2])  # pattern: y = concat(y1, y2)

        def source(self) -> List[TensorPattern]:
            # Return the output tensors of the source sub-graph pattern.
            return [self.y]

        def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
            # The target sub-graph constructor.
            # The matched dictionary has type Dict[TensorPattern, Tensor],
            # mapping the patterns to the matched tensors.
            x, c1, c2, y = [matched[t] for t in [self.x, self.c1, self.c2, self.y]]
            concat: Operator = y.op

            # We can use the matched tensors to determine whether the rewrite is applicable.
            # For example, in this case, we check whether the two weight matrices share the
            # same shape except for the last dimension.
            if (
                2 <= len(c1.shape) == len(c2.shape)
                and same_list(c1.shape[:-1], c2.shape[:-1])
                and concat.attrs["axis"] == len(y.shape) - 1
            ):
                # If applicable, we construct the target sub-graph and return the output tensors.
                c = ops.concat([c1, c2], axis=-1)
                y = ops.matmul(x, c)
                return [y]
            else:
                # If not, we return None to indicate that the rewrite is not applicable.
                return None

.. GENERATED FROM PYTHON SOURCE LINES 136-137

We can check that the rewrite rule has been registered:

.. GENERATED FROM PYTHON SOURCE LINES 137-144

.. code-block:: default

    from hidet.graph.transforms import registered_rewrite_rules, clear_registered_rewrite_rules

    print('Registered rewrite rules:')
    for rule in registered_rewrite_rules():
        assert isinstance(rule, SubgraphRewriteRule)
        print(rule.name)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Registered rewrite rules:
    a + x => x + a
    x - a => x + (-a)
    (x + a) + b => x + (a + b)
    (x + a) * b => x * b + a * b
    (x + a) + (y + b) => (x + y) + (a + b)
    reshape(x) * scale
    reshape(x) + bias
    squeeze(x) * c => squeeze(x * c)
    y1 = cast(x), y2 = cast(x) => y1 = y2 = z = cast(x)
    y1,2,3 = cast(x) => y1 = y2 = y3 = z = cast(x)
    cast(cast(x)) => x
    binaryOp(unaryOp_left(x), unaryOp_right(x)) => compositeOp(x)
    binaryOp(unaryOp(x), x) => compositeOp(x)
    binaryOp(x, unaryOp(x)) => compositeOp(x)
    conv2d(x, w) * scale => conv2d(x, w * scale)
    conv2d(x, w1)|conv2d(x, w2)|conv2d(x, w3) => conv2d(x, w1 + w2 + w3)
    conv2d(x, w1)|conv2d(x, w2) => conv2d(x, w1 + w2)
    3 branches of matmul(x, branch c) + branch b ==> matmul(x, c) + b followed by split
    matmul(x, c1)|matmul(x, c2)|matmul(x, c3) => matmul(x, concat(c1, c2, c3)) followed by split
    matmul(x, c1)|matmul(x, c2) ==> matmul(x, concat(c1, c2)) followed by split
    new: [matmul(x, c1), matmul(x,c2)] => matmul(x, [c1, c2])

.. GENERATED FROM PYTHON SOURCE LINES 145-150

Apply the rewrite rule
----------------------

Besides the predefined rewrite rules, we can see that the rewrite rule we just registered appears on the last line
of the output above. In this tutorial, to prevent the default rewrite rules from being applied, we first clear the
registered rewrite rules and then register the rule we just defined:

.. GENERATED FROM PYTHON SOURCE LINES 150-153

.. code-block:: default

    clear_registered_rewrite_rules()
    register_rewrite_rule(FuseTwoMatmulRewriteRule())  # a second way to register the rewrite rule

.. GENERATED FROM PYTHON SOURCE LINES 154-155

The rewrite process is done in a graph optimization pass called `subgraph_rewrite_pass`.

.. GENERATED FROM PYTHON SOURCE LINES 155-161

.. code-block:: default

    from hidet.graph.transforms import subgraph_rewrite_pass

    rewrite_pass = subgraph_rewrite_pass()
    rewritten_graph: FlowGraph = rewrite_pass(graph)
    print(rewritten_graph)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Graph(x: float32[3, 3][cpu]){
      c = Constant(float32[3, 3][cpu])
      c_1 = Constant(float32[3, 6][cpu])
      x_1: float32[3, 3][cpu] = Matmul(x, c, require_prologue=False)
      x_2: float32[3, 6][cpu] = Matmul(x_1, c_1, require_prologue=False)
      x_3: float32[3, 6][cpu] = Relu(x_2)
      return x_3
    }

.. GENERATED FROM PYTHON SOURCE LINES 162-167

We can see that the rewritten graph contains two matmul operators instead of three. There is no concat operator in
the rewritten graph because the inputs of the concat operator are constant tensors, so it has been folded.
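The folding step can be illustrated with a minimal sketch (hypothetical helper names, not hidet's actual constant-folding pass): an operator whose inputs are all constants can be evaluated during optimization and replaced by the resulting constant, which is why the concatenation of the two weights vanishes from the graph.

```python
def fold_constant_op(op, inputs, is_constant):
    # If every input is a constant, evaluate the operator at optimization
    # time; the node then becomes a constant in the graph.
    if all(is_constant(t) for t in inputs):
        return op(*inputs)
    # At least one input is symbolic: keep the operator in the graph.
    return None

def concat(a, b):
    # Concatenation of flat lists stands in for tensor concat.
    return a + b

c1, c2 = [1.0, 2.0], [3.0, 4.0]  # two constant "weights"
folded = fold_constant_op(concat, (c1, c2), is_constant=lambda t: True)
# folded is the precomputed constant; no concat operator remains
```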

We do not need to call the rewrite pass explicitly. It will be called automatically, together with the other graph
optimization passes, when we call :func:`hidet.graph.optimize`.

.. GENERATED FROM PYTHON SOURCE LINES 167-170

.. code-block:: default

    graph_opt: FlowGraph = hidet.graph.optimize(graph)
    print(graph_opt)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Graph(x: float32[3, 3][cpu]){
      c = Constant(float32[3, 3][cpu])
      c_1 = Constant(float32[3, 6][cpu])
      x_1: float32[3, 3][cpu] = Matmul(x, c, require_prologue=False)
      x_2: float32[3, 6][cpu] = FusedMatmul(x_1, c_1, fused_graph=FlowGraph(Matmul, Relu), anchor=0)
      return x_2
    }

.. GENERATED FROM PYTHON SOURCE LINES 171-175

Summary
-------

In this tutorial, we learned how to define and register a sub-graph rewrite rule. It is an important component of
the graph optimization framework; Hidet uses it to implement several horizontal fusion rules.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.254 seconds)

.. _sphx_glr_download_gallery_developer-guides_add-subgraph-rewrite-rule.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: add-subgraph-rewrite-rule.py <add-subgraph-rewrite-rule.py>`

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: add-subgraph-rewrite-rule.ipynb <add-subgraph-rewrite-rule.ipynb>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_