# Source code for hidet.graph.transforms.graph_patterns.base

from __future__ import annotations
from typing import List, Optional, Dict, Union, Tuple, Type
from hidet.graph.flow_graph import Operator, Tensor
from hidet.graph.ops.arithmetic import AddOp, SubtractOp, MultiplyOp, NegativeOp

class TensorPattern:
    """
    The tensor pattern represents a tensor in the pattern graph.

    A tensor pattern either has no producer (``trace is None``), in which case it acts as a
    free variable of the pattern, or it is the ``idx``-th output of an ``OperatorPattern``
    recorded as ``trace == (operator_pattern, idx)``.

    Parameters
    ----------
    is_const: bool
        Require the matched tensor to be a constant tensor.
    is_symbolic: bool
        Require the matched tensor to be a symbolic tensor.
    trace: Optional[Tuple[OperatorPattern, int]]
        The producing operator pattern and the output index, or None.
    """

    def __init__(self, is_const=False, is_symbolic=False, trace=None):
        self.is_const: bool = is_const
        self.is_symbolic: bool = is_symbolic
        # The two storage requirements are mutually exclusive.
        assert not (is_const and is_symbolic), 'Can not be const and symbolic at the same time'
        self.trace: Optional[Tuple[OperatorPattern, int]] = trace
        # (operator pattern, input index) pairs of every operator pattern consuming this tensor.
        self.uses: List[Tuple[OperatorPattern, int]] = []

    def __repr__(self):
        if self.trace is not None:
            producer, out_idx = self.trace
            text = str(producer)
            # Only show the output index when the producer has several outputs.
            return text if len(producer.outputs) == 1 else '{}[{}]'.format(text, out_idx)
        if self.is_const:
            return 'c'
        return 's' if self.is_symbolic else 'v'

    def __add__(self, other):
        node = OperatorPattern(AddOp, inputs=[self, other])
        return node.outputs[0]

    def __sub__(self, other):
        node = OperatorPattern(SubtractOp, inputs=[self, other])
        return node.outputs[0]

    def __mul__(self, other):
        node = OperatorPattern(MultiplyOp, inputs=[self, other])
        return node.outputs[0]

    def __neg__(self):
        node = OperatorPattern(NegativeOp, inputs=[self])
        return node.outputs[0]

    def op(self) -> Optional[OperatorPattern]:
        """Return the producing operator pattern, or None for a producer-less tensor pattern."""
        return None if self.trace is None else self.trace[0]

    def add_use(self, op: OperatorPattern, idx: int):
        """Record that ``op`` consumes this tensor pattern as its ``idx``-th input."""
        self.uses.append((op, idx))

    @staticmethod
    def tensor(is_const=False, is_symbolic=False):
        """Create a single producer-less tensor pattern."""
        return TensorPattern(is_const=is_const, is_symbolic=is_symbolic)

    @staticmethod
    def tensors(num, is_const=False, is_symbolic=False):
        """Create ``num`` independent producer-less tensor patterns."""
        return [TensorPattern.tensor(is_const, is_symbolic) for _ in range(num)]
class OperatorPattern:
    """
    The operator pattern represents an operator in the pattern graph.

    Parameters
    ----------
    op_cls: Type[Operator]
        The operator class this pattern stands for.
    inputs: List[TensorPattern]
        The input tensor patterns. Each of them records this operator pattern as a use.
    num_outputs: int
        The number of output tensor patterns to create. Default is 1.
    """

    def __init__(self, op_cls, inputs, num_outputs=1):
        self.op_cls = op_cls
        self.inputs: List[TensorPattern] = inputs
        # Outputs are produced by this operator, so they are symbolic and trace back to it.
        self.outputs = [TensorPattern(is_symbolic=True, trace=(self, idx)) for idx in range(num_outputs)]
        # Register (this operator, input slot) as a use of every input tensor pattern.
        for idx, input_tensor in enumerate(self.inputs):
            input_tensor.add_use(self, idx)

    def __repr__(self):
        input_items = [str(v) for v in self.inputs]
        unary_ops = {NegativeOp: '-'}
        binary_ops = {AddOp: '+', SubtractOp: '-', MultiplyOp: '*'}
        if self.op_cls in unary_ops:
            return '({}{})'.format(unary_ops[self.op_cls], input_items[0])
        if self.op_cls in binary_ops:
            return '({} {} {})'.format(input_items[0], binary_ops[self.op_cls], input_items[1])
        name = self.op_cls.__name__
        # Drop the conventional "Op" suffix (e.g. AddOp -> Add), but never truncate
        # class names that do not carry it.
        if name.endswith('Op'):
            name = name[:-2]
        return '{}({})'.format(name, ', '.join(input_items))
# A match dict maps each tensor/operator in the pattern graph to the tensor/operator it was
# matched to in the computation graph; it is what SubgraphRewriteRule.target() receives.
MatchDict = Dict[Union[TensorPattern, OperatorPattern], Union[Tensor, Operator]]
class SubgraphRewriteRule:
    """
    A sub-graph rewrite rule defines a sub-graph pattern (called source) to match in the computation graph, and a
    target sub-graph constructor to replace the matched sub-graph.

    When defining a new sub-graph rewrite rule, you need to define a new class inherited from SubgraphRewriteRule and
    implement the source() and target() methods. The source() method returns a list of output tensors in the
    sub-graph pattern, while the target() method returns a list of output tensors in the target sub-graph, given
    the match dict that maps the tensors/operators in the pattern to the matched tensors/operators in the
    computation graph.

    After defining the sub-graph rewrite rule, you need to register it to the sub-graph rewrite rule registry via
    :func:`register_rewrite_rule`.

    Parameters
    ----------
    name: str
        The name of the rule. Defaults to the name of the (sub)class when empty.
    """

    def __init__(self, name=""):
        # Fall back to the concrete class name so every rule has a readable identifier.
        self.name = name if name else self.__class__.__name__

    def source(self) -> List[TensorPattern]:
        """
        The output tensors in the source template graph to match in the computation graph.
        """
        raise NotImplementedError()

    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
        """
        The output tensors in the target sub-graph used to replace the matched pattern.

        Returning None means the target sub-graph could not be generated, and the transformation
        should not be applied.
        """
        raise NotImplementedError()
def op_pattern(
    op_cls: Type[Operator], input_patterns: List[TensorPattern], num_outputs=1
) -> Union[TensorPattern, List[TensorPattern]]:
    """
    Create an operator pattern with the given operator class and input patterns, and return the output patterns.

    Parameters
    ----------
    op_cls: Type[Operator]
        The operator class. This operator pattern will only be matched to an operator of the same class.

    input_patterns: List[TensorPattern]
        The input patterns of the operator.

    num_outputs: int
        The number of output tensors of the operator. Default is 1.

    Returns
    -------
    ret: Union[TensorPattern, List[TensorPattern]]
        The output tensor pattern(s) of the operator: a single pattern when num_outputs is 1,
        otherwise the list of output patterns.
    """
    op = OperatorPattern(op_cls, input_patterns, num_outputs)
    if num_outputs == 1:
        return op.outputs[0]
    else:
        return op.outputs


# For every tensor of the computation graph, the (operator, index) pairs that use it.
# NOTE(review): the operator is Optional; None presumably marks a use as a graph output, and the
# index is presumably the input slot on the operator -- confirm against the usage builder.
Usage = Dict[Tensor, List[Tuple[Optional[Operator], int]]]


class NotMatchedException(Exception):
    """Raised by PatternMatcher when the pattern cannot be matched to the target."""


class PatternMatcher:
    """
    PatternMatcher matches a pattern to a subgraph in a larger graph. It starts from a tensor, or an operator, and
    tries to match the subgraph spanned from the start point.

    The spanning rules:
    1. A tensor spans to its producing operator and its consuming operators (i.e., uses).
    2. An operator spans to its input and output tensors.

    The matching rules:
    1. For tensor:
        a) check the storage requirement (e.g., constant and symbolic)
        b) check the output index in the producer's output array
    2. For operator:
        a) check the operator type.

    Because the operator also spans to its outputs, as long as the pattern is connected, we only need to start from
    a single tensor or operator.
    """

    def __init__(self, usage: Usage):
        # pattern -> matched graph object (lists are keyed by id, since they are unhashable)
        self.matched = {}
        # matched graph object -> pattern key (the inverse of self.matched)
        self.reverse_matched = {}
        self.usage: Usage = usage

    @staticmethod
    def check(cond: bool, msg=""):
        """Raise NotMatchedException with ``msg`` when ``cond`` is false."""
        if not cond:
            raise NotMatchedException(msg)

    @staticmethod
    def _key(obj):
        """Return a hashable dictionary key for ``obj``: lists are keyed by identity."""
        return id(obj) if isinstance(obj, list) else obj

    def match(self, pattern, target):
        """Match ``pattern`` to ``target`` and recursively span to connected nodes."""
        key = self._key(pattern)
        if key in self.matched:
            # The pattern was already matched; it must have been matched to this very object.
            self.check(target is self.matched[key], 'tried to match a pattern to two different objects')
            return
        self.matched[key] = target
        self.reverse_matched[self._key(target)] = key
        if isinstance(pattern, (list, tuple)):
            self.match_Sequence(pattern, target)
        elif isinstance(pattern, TensorPattern):
            self.match_TensorPattern(pattern, target)
        elif isinstance(pattern, OperatorPattern):
            self.match_OperatorPattern(pattern, target)
        else:
            raise NotImplementedError()

    def match_Sequence(self, pattern, target):
        self.check(isinstance(target, (list, tuple)), 'target should be tuple or list')
        self.check(len(pattern) == len(target), 'sequence length does not match')
        for a, b in zip(pattern, target):
            self.match(a, b)

    def match_TensorPattern(self, pattern: TensorPattern, target):
        self.check(isinstance(target, Tensor), "expect target with type 'Tensor'")
        if pattern.is_const:
            self.check(target.storage is not None, 'requires const tensor')
            # Constant pattern tensors are leaves: neither producer nor uses are spanned.
            return
        if pattern.is_symbolic:
            self.check(target.storage is None, 'requires symbolic tensor')

        # Span to the producer: same output index of a matching producing operator.
        if pattern.trace:
            self.check(target.trace is not None)
            self.check(pattern.trace[1] == target.trace[1])
            self.match(pattern.trace[0], target.trace[0])

        # Span to the uses: each consumer in the pattern is greedily matched to the first
        # not-yet-matched consumer of the target whose class is compatible (no backtracking).
        # A tensor missing from the usage map simply has no uses.
        actual_uses: List[Tuple[Optional[Operator], int]] = self.usage.get(target, [])
        for desire_operator, _ in pattern.uses:
            if desire_operator in self.matched:
                # this desired operator in the pattern has already been spanned
                continue
            for actual_operator, _ in actual_uses:
                if actual_operator in self.reverse_matched:
                    # this actual operator has already been matched
                    continue
                if not issubclass(type(actual_operator), desire_operator.op_cls):
                    continue
                self.match(desire_operator, actual_operator)
                break
            else:
                self.check(False, "A usage of input tensor has not been spanned.")

    def match_OperatorPattern(self, pattern: OperatorPattern, target: Operator):
        self.check(isinstance(target, pattern.op_cls), "expect target with type 'Operator'")
        # Operators of the same class may still differ in arity; report that as a mismatch instead
        # of an AssertionError, which graph_pattern_match would not catch.
        self.check(len(pattern.inputs) == len(target.inputs), 'number of inputs does not match')
        self.check(len(pattern.outputs) == len(target.outputs), 'number of outputs does not match')
        for a, b in zip(pattern.inputs, target.inputs):
            self.match(a, b)
        for a, b in zip(pattern.outputs, target.outputs):
            self.match(a, b)


def graph_pattern_match(pattern: TensorPattern, target: Tensor, usage: Usage) -> Optional[MatchDict]:
    """
    Try to match the pattern rooted at ``pattern`` to the sub-graph around ``target``.

    Parameters
    ----------
    pattern: TensorPattern
        The pattern tensor to start from.
    target: Tensor
        The tensor in the computation graph to start from.
    usage: Usage
        The uses of every tensor in the computation graph.

    Returns
    -------
    ret: Optional[MatchDict]
        The match dict on success, otherwise None.
    """
    # A producer-less pattern is decided here directly: it only matches a producer-less target
    # that satisfies the const/symbolic requirement.
    if pattern.trace is None:
        if target.trace is not None:
            return None
        if (pattern.is_const and target.storage is None) or (pattern.is_symbolic and target.storage is not None):
            return None
        return {pattern: target}
    # peek for early stop, only for performance
    if target.trace and not issubclass(target.trace[0].__class__, pattern.trace[0].op_cls):
        return None

    # formal match
    matcher = PatternMatcher(usage)
    try:
        matcher.match(pattern, target)
    except NotMatchedException:
        return None
    return matcher.matched


_registered_rewrite_rules: List[SubgraphRewriteRule] = []


def registered_rewrite_rules():
    """
    Return a copy of the currently registered rewrite rules.

    The ``register_all_patterns`` sub-module is imported first (on demand) so that the rules it
    presumably registers are in place.
    """
    # pylint: disable=unused-import
    from . import register_all_patterns  # register on demand

    return list(_registered_rewrite_rules)


def clear_registered_rewrite_rules():
    """Remove all rules from the rewrite rule registry."""
    _registered_rewrite_rules.clear()
def register_rewrite_rule(rule: Union[SubgraphRewriteRule, Type[SubgraphRewriteRule]]):
    """
    Register a sub-graph rewrite rule.

    Parameters
    ----------
    rule: SubgraphRewriteRule or Type[SubgraphRewriteRule]
        The rule to be registered. If it is a type, it will be instantiated with default arguments. Otherwise, it
        should be an instance of SubgraphRewriteRule.

    Returns
    -------
    ret: Optional[Type[SubgraphRewriteRule]]
        The class itself when a class was given (which allows using this function as a class
        decorator), otherwise None.

    Raises
    ------
    TypeError
        If ``rule`` is neither a SubgraphRewriteRule instance nor a subclass of SubgraphRewriteRule.
    """
    if isinstance(rule, SubgraphRewriteRule):
        _registered_rewrite_rules.append(rule)
        return None
    # issubclass() raises its own, less helpful TypeError when ``rule`` is not a class at all,
    # so check for a class first and fall through to the descriptive error below.
    if isinstance(rule, type) and issubclass(rule, SubgraphRewriteRule):
        _registered_rewrite_rules.append(rule())
        return rule
    raise TypeError('rule should be a SubgraphRewriteRule or a subclass of SubgraphRewriteRule')
def deregister_rewrite_rule(rule: SubgraphRewriteRule):
    """
    Remove a sub-graph rewrite rule from the list of currently registered rules.

    Parameters
    ----------
    rule: SubgraphRewriteRule
        The rule instance to be deregistered.

    Raises
    ------
    TypeError
        If ``rule`` is not a SubgraphRewriteRule instance.
    ValueError
        If ``rule`` is not currently registered (propagated from ``list.remove``).
    """
    if not isinstance(rule, SubgraphRewriteRule):
        raise TypeError('rule should be a SubgraphRewriteRule')
    _registered_rewrite_rules.remove(rule)
    return None