Sub-graph Rewrite Pass

class hidet.graph.transforms.subgraph_rewrite.TensorPattern(is_const=False, is_symbolic=False, trace=None)

The tensor pattern represents a tensor in the pattern graph.

class hidet.graph.transforms.subgraph_rewrite.OperatorPattern(op_cls, inputs, num_outputs=1)

The operator pattern represents an operator in the pattern graph.

class hidet.graph.transforms.subgraph_rewrite.SubgraphRewriteRule(name='')

A sub-graph rewrite rule defines a sub-graph pattern (called source) to match in the computation graph, and a target sub-graph constructor to replace the matched sub-graph.

To define a new sub-graph rewrite rule, create a class that inherits from SubgraphRewriteRule and implement its source() and target() methods. The source() method returns the list of output tensors of the sub-graph pattern. The target() method receives the match dict, which maps the tensors/operators in the pattern to the matched tensors/operators in the computation graph, and returns the list of output tensors of the target sub-graph.

After defining the sub-graph rewrite rule, register it with the sub-graph rewrite rule registry via register_rewrite_rule().
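The workflow described above (build a pattern, return its outputs from source(), build the replacement from the match dict in target()) can be sketched without installing hidet. Everything in the block below is a stand-in written for illustration: the classes reuse the documented names and constructor signatures only, while the Tensor and Neg classes, the match() helper, and the interpretation of trace as a (producer, output index) pair are assumptions of this sketch and not hidet's implementation.

```python
from typing import Dict, List, Optional


# Stand-in classes: they copy only the documented constructor signatures.
# They are NOT hidet's classes; the real ones carry more behavior.
class TensorPattern:
    def __init__(self, is_const=False, is_symbolic=False, trace=None):
        self.is_const = is_const
        self.is_symbolic = is_symbolic
        # Assumption of this sketch: trace is (producing OperatorPattern, output index).
        self.trace = trace


class OperatorPattern:
    def __init__(self, op_cls, inputs, num_outputs=1):
        self.op_cls = op_cls
        self.inputs = inputs
        self.outputs = [TensorPattern(trace=(self, i)) for i in range(num_outputs)]


class SubgraphRewriteRule:
    def __init__(self, name=''):
        self.name = name

    def source(self) -> List[TensorPattern]:
        raise NotImplementedError

    def target(self, matched: Dict) -> Optional[list]:
        raise NotImplementedError


# Hypothetical toy computation graph: tensors remember the operator that made them.
class Tensor:
    def __init__(self, name, producer=None):
        self.name = name
        self.producer = producer


class Neg:
    def __init__(self, x):
        self.inputs = [x]
        self.output = Tensor(f'neg({x.name})', producer=self)


def neg(x):
    return Neg(x).output


class DoubleNegRule(SubgraphRewriteRule):
    """Toy rule: neg(neg(x)) => x."""

    def __init__(self):
        super().__init__('neg(neg(x)) => x')
        self.x = TensorPattern()
        self.inner = OperatorPattern(Neg, [self.x])
        self.outer = OperatorPattern(Neg, [self.inner.outputs[0]])

    def source(self):
        # Output tensors of the pattern sub-graph.
        return [self.outer.outputs[0]]

    def target(self, matched):
        # Look up the concrete tensor bound to the pattern input x.
        return [matched[self.x]]


def match(pattern, tensor, matched):
    """Toy matcher: bind pattern nodes to graph nodes, walking producers."""
    if pattern in matched:
        return matched[pattern] is tensor
    matched[pattern] = tensor
    if pattern.trace is None:
        return True  # a free tensor pattern matches any tensor
    op_pattern, _ = pattern.trace
    op = tensor.producer
    if op is None or not isinstance(op, op_pattern.op_cls):
        return False
    if len(op_pattern.inputs) != len(op.inputs):
        return False
    matched[op_pattern] = op
    return all(match(p, t, matched) for p, t in zip(op_pattern.inputs, op.inputs))


x = Tensor('x')
y = neg(neg(x))
rule = DoubleNegRule()
matched = {}
matched_ok = match(rule.source()[0], y, matched)
replacement = rule.target(matched)
```

In this sketch the match dict ends up with both tensor patterns and operator patterns as keys, which mirrors the Dict[TensorPattern | OperatorPattern, Tensor | Operator] type documented for target().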

source()

Returns the output tensors of the source pattern graph, which is matched against the computation graph.

Return type:

List[TensorPattern]

target(matched)

Returns the output tensors of the target sub-graph that replaces the matched pattern. Returning None indicates that the target sub-graph could not be generated, in which case the transformation is not applied.

Parameters:

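matched (Dict[TensorPattern | OperatorPattern, Tensor | Operator]) – The match dict that maps the tensors/operators in the pattern to the matched tensors/operators in the computation graph.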

Return type:

List[Tensor] | None
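The None return value of target() lets a rule decline a rewrite even though its pattern matched structurally. A minimal stand-alone sketch follows. The MulByOneRule class, the string keys used in place of pattern objects, and the rewrite_outputs() driver are all hypothetical and only illustrate the convention; hidet's actual pass is not shown here.

```python
from typing import Dict, List, Optional


class MulByOneRule:
    """Hypothetical rule: x * c => x, valid only when the constant c equals 1."""

    def __init__(self):
        # Plain strings stand in for TensorPattern objects used as dict keys.
        self.x = 'x'
        self.c = 'c'

    def target(self, matched: Dict) -> Optional[List]:
        if matched[self.c] != 1:
            return None  # the pattern matched, but the rewrite would be wrong
        return [matched[self.x]]


def rewrite_outputs(rule, matched: Dict, original_outputs: List) -> List:
    """Toy driver: keep the original outputs whenever the rule returns None."""
    new_outputs = rule.target(matched)
    return original_outputs if new_outputs is None else new_outputs


rule = MulByOneRule()
applied = rewrite_outputs(rule, {'x': 'a', 'c': 1}, ['a * 1'])
skipped = rewrite_outputs(rule, {'x': 'a', 'c': 2}, ['a * 2'])
```

Here applied holds the replacement tensor, while skipped is the untouched original output list because the rule returned None.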

hidet.graph.transforms.subgraph_rewrite.register_rewrite_rule(rule)

Register a sub-graph rewrite rule.

Parameters:

rule (SubgraphRewriteRule or Type[SubgraphRewriteRule]) – The rule to be registered. If it is a type, it will be instantiated with default arguments. Otherwise, it should be an instance of SubgraphRewriteRule.
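The type-or-instance behavior described for the rule parameter can be sketched as follows. The toy_register() function, the toy_registry list, and the stand-in SubgraphRewriteRule class are hypothetical names invented for this illustration. Returning the registered instance and raising TypeError for other arguments are choices of the sketch, not documented behavior of register_rewrite_rule().

```python
import inspect


class SubgraphRewriteRule:
    """Stand-in base class (not hidet's); keeps only the documented name argument."""

    def __init__(self, name=''):
        self.name = name


toy_registry = []  # hypothetical registry; hidet's internal storage may differ


def toy_register(rule):
    """Accept either a rule class or a rule instance, as the parameter docs describe."""
    if inspect.isclass(rule) and issubclass(rule, SubgraphRewriteRule):
        rule = rule()  # a type is instantiated with its default arguments
    elif not isinstance(rule, SubgraphRewriteRule):
        raise TypeError(f'expected a SubgraphRewriteRule or a subclass, got {rule!r}')
    toy_registry.append(rule)
    return rule  # design choice of this sketch


class NamedRule(SubgraphRewriteRule):
    def __init__(self, name='default rule'):
        super().__init__(name)


from_type = toy_register(NamedRule)                     # class: instantiated with defaults
from_instance = toy_register(NamedRule('custom rule'))  # instance: stored as given
```

Passing the class is convenient when the rule needs no configuration. Passing an instance is the only option when the constructor takes arguments without defaults.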