Sub-graph Rewrite Pass
- class hidet.graph.transforms.subgraph_rewrite.TensorPattern(is_const=False, is_symbolic=False, trace=None)
The tensor pattern represents a tensor in the pattern graph.
- class hidet.graph.transforms.subgraph_rewrite.OperatorPattern(op_cls, inputs, num_outputs=1)
The operator pattern represents an operator in the pattern graph.
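The following is a minimal pure-Python sketch of a pattern graph and a recursive matcher. It illustrates the roles of the two pattern classes and is not hidet's implementation. `Tensor`, `Op`, `TensorPat`, `OpPat`, and `match` are hypothetical toy stand-ins. The `trace` field here simply points at the operator pattern that must produce the tensor, which may differ from what hidet's own `trace` argument holds, and `is_symbolic` is not modeled. A pattern tensor with no `trace` acts as a wildcard, while `is_const=True` restricts it to constant tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Toy computation graph. eq=False keeps identity-based equality and hashing,
# so nodes can be used as dict keys.
@dataclass(eq=False)
class Tensor:
    name: str
    value: Optional[float] = None    # set only for constant tensors
    op: Optional["Op"] = None        # producing operator; None for graph inputs

@dataclass(eq=False)
class Op:
    kind: str                        # e.g. "matmul", "add"
    inputs: List[Tensor]

# Toy pattern graph, loosely mirroring the roles of TensorPattern / OperatorPattern.
@dataclass(eq=False)
class TensorPat:
    is_const: bool = False           # only match constant tensors
    trace: Optional["OpPat"] = None  # operator pattern that must produce this tensor

@dataclass(eq=False)
class OpPat:
    kind: str
    inputs: List[TensorPat]

def match(p: TensorPat, t: Tensor, found: Dict[TensorPat, Tensor]) -> bool:
    """Match pattern tensor p against concrete tensor t, recording bindings in `found`."""
    if p in found:                   # a pattern tensor used twice must bind to one tensor
        return found[p] is t
    if p.is_const and t.value is None:
        return False
    if p.trace is not None:          # t must be produced by a matching operator
        op = t.op
        if op is None or op.kind != p.trace.kind or len(op.inputs) != len(p.trace.inputs):
            return False
        for ip, it in zip(p.trace.inputs, op.inputs):
            if not match(ip, it, found):
                return False
    found[p] = t
    return True

# Pattern: add(matmul(a, w), b) where w and b must be constants.
a_p, w_p, b_p = TensorPat(), TensorPat(is_const=True), TensorPat(is_const=True)
mm_p = TensorPat(trace=OpPat("matmul", [a_p, w_p]))
y_p = TensorPat(trace=OpPat("add", [mm_p, b_p]))

# Concrete graph: y = x @ w + b
x, w, b = Tensor("x"), Tensor("w", value=0.5), Tensor("b", value=1.0)
y = Tensor("y", op=Op("add", [Tensor("mm", op=Op("matmul", [x, w])), b]))

found: Dict[TensorPat, Tensor] = {}
print(match(y_p, y, found), found[a_p].name)   # True x
```

The resulting dictionary from pattern nodes to graph nodes plays the part of the match dict that a rule's target constructor consumes.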
- class hidet.graph.transforms.subgraph_rewrite.SubgraphRewriteRule(name='')
A sub-graph rewrite rule defines a sub-graph pattern (called the source) to match in the computation graph, and a constructor for the target sub-graph that replaces the matched sub-graph.
To define a new sub-graph rewrite rule, subclass SubgraphRewriteRule and implement its source() and target() methods. The source() method returns the list of output tensors of the sub-graph pattern. The target() method receives the match dict, which maps the tensors and operators in the pattern to the matched tensors and operators in the computation graph, and returns the list of output tensors of the target sub-graph.
After defining the sub-graph rewrite rule, you need to register it with the sub-graph rewrite rule registry.
- source()
The output tensors of the source template graph to match in the computation graph.
- target()
The output tensors of the target sub-graph used to replace the matched pattern. Returning None means the target sub-graph could not be generated, in which case the transformation is not applied.
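The source()/target() contract, including the None return, can be illustrated with a minimal sketch. Everything in it is hypothetical: `RewriteRule` stands in for `SubgraphRewriteRule`, the toy `Tensor`, `TensorPat`, and `OpPat` classes replace hidet's graph and pattern types, and the match dicts are written by hand in place of a real matcher. The example rule rewrites `mul(x, c)` into `x` and declines when the matched constant is not 1.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Toy stand-ins; eq=False keeps identity equality, so nodes can be dict keys.
@dataclass(eq=False)
class Tensor:
    name: str
    value: Optional[float] = None    # set only for constant tensors

@dataclass(eq=False)
class TensorPat:
    is_const: bool = False           # only match constant tensors
    trace: Optional["OpPat"] = None  # operator pattern that must produce the tensor

@dataclass(eq=False)
class OpPat:
    kind: str
    inputs: List[TensorPat]

class RewriteRule:
    """Hypothetical base class with a source()/target() contract."""
    def __init__(self, name: str = ""):
        self.name = name

    def source(self) -> List[TensorPat]:
        """Output tensors of the pattern to match."""
        raise NotImplementedError

    def target(self, matched: Dict[TensorPat, Tensor]) -> Optional[List[Tensor]]:
        """Replacement outputs, or None to skip the rewrite."""
        raise NotImplementedError

class MulByOneRule(RewriteRule):
    """Rewrite y = mul(x, c) into x, but only when the constant c equals 1."""
    def __init__(self):
        super().__init__("mul(x, 1) => x")
        # Keep pattern nodes as attributes so target() can look up their matches.
        self.x = TensorPat()
        self.c = TensorPat(is_const=True)
        self.y = TensorPat(trace=OpPat("mul", [self.x, self.c]))

    def source(self) -> List[TensorPat]:
        return [self.y]

    def target(self, matched: Dict[TensorPat, Tensor]) -> Optional[List[Tensor]]:
        if matched[self.c].value != 1.0:
            return None              # pattern matched, but the rewrite does not apply
        return [matched[self.x]]     # replaces the single source output

rule = MulByOneRule()
x = Tensor("x")
# Hand-written match dicts stand in for what a pattern matcher would produce.
kept = rule.target({rule.x: x, rule.c: Tensor("two", value=2.0)})
folded = rule.target({rule.x: x, rule.c: Tensor("one", value=1.0)})
print(kept, folded[0].name)          # None x
```

Storing the pattern nodes as attributes in `__init__` lets target() look up exactly the nodes that source() returned, because the match dict is keyed by pattern-node identity.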
Register a sub-graph rewrite rule.
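A registry combined with a fixed-point driver gives a rough picture of how registered rules could be applied to a graph. This is an assumed, simplified model rather than hidet's pass: `RULES`, `register_rule`, `collect`, and `rewrite` are hypothetical names, only single-output source patterns are handled, and each successful rewrite redirects every use of the matched output to the tensor returned by target().

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Toy graph and pattern nodes; eq=False keeps identity equality and hashing.
@dataclass(eq=False)
class Tensor:
    name: str
    value: Optional[float] = None    # set only for constant tensors
    op: Optional["Op"] = None        # producing operator; None for inputs/constants

@dataclass(eq=False)
class Op:
    kind: str
    inputs: List[Tensor]

@dataclass(eq=False)
class TensorPat:
    is_const: bool = False           # only match constant tensors
    trace: Optional["OpPat"] = None  # operator pattern that must produce the tensor

@dataclass(eq=False)
class OpPat:
    kind: str
    inputs: List[TensorPat]

def match(p: TensorPat, t: Tensor, found: Dict[TensorPat, Tensor]) -> bool:
    """Recursively match pattern p against tensor t, recording bindings in found."""
    if p in found:
        return found[p] is t
    if p.is_const and t.value is None:
        return False
    if p.trace is not None:
        op = t.op
        if op is None or op.kind != p.trace.kind or len(op.inputs) != len(p.trace.inputs):
            return False
        if not all(match(ip, it, found) for ip, it in zip(p.trace.inputs, op.inputs)):
            return False
    found[p] = t
    return True

class MulByOneRule:
    """mul(x, c) => x when the constant c equals 1; declines (returns None) otherwise."""
    def __init__(self):
        self.x, self.c = TensorPat(), TensorPat(is_const=True)
        self.y = TensorPat(trace=OpPat("mul", [self.x, self.c]))

    def source(self) -> List[TensorPat]:
        return [self.y]

    def target(self, m: Dict[TensorPat, Tensor]) -> Optional[List[Tensor]]:
        return [m[self.x]] if m[self.c].value == 1.0 else None

RULES: list = []                     # hypothetical global rule registry

def register_rule(rule) -> None:
    RULES.append(rule)

def collect(outputs: List[Tensor]) -> List[Tensor]:
    """All tensors reachable from outputs, producers before consumers."""
    seen, order = set(), []
    def visit(t: Tensor) -> None:
        if id(t) in seen:
            return
        seen.add(id(t))
        for i in (t.op.inputs if t.op is not None else []):
            visit(i)
        order.append(t)
    for o in outputs:
        visit(o)
    return order

def rewrite(outputs: List[Tensor]) -> List[Tensor]:
    """Apply registered single-output rules until none of them changes the graph."""
    changed = True
    while changed:
        changed = False
        for rule in RULES:
            for t in collect(outputs):
                found: Dict[TensorPat, Tensor] = {}
                if not match(rule.source()[0], t, found):
                    continue
                new = rule.target(found)
                if new is None:      # rule declined: leave the graph untouched
                    continue
                for u in collect(outputs):  # redirect every use of t to new[0]
                    if u.op is not None:
                        u.op.inputs = [new[0] if i is t else i for i in u.op.inputs]
                outputs = [new[0] if o is t else o for o in outputs]
                changed = True
                break
    return outputs

register_rule(MulByOneRule())
x, one, two = Tensor("x"), Tensor("one", value=1.0), Tensor("two", value=2.0)
a = Tensor("a", op=Op("mul", [x, one]))      # a = x * 1: folded to x
b = Tensor("b", op=Op("mul", [a, two]))      # b = a * 2: rule declines
out = rewrite([b])
print([i.name for i in out[0].op.inputs])    # ['x', 'two']
```

The driver rescans the graph after every successful rewrite so that later matches see the updated graph. A rule whose target() returns None is skipped for that match and leaves the graph unchanged.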