Resolve Operator Pass
- class hidet.graph.transforms.resolve_variant.ResolveRule
A resolve rule defines how to resolve an operator into other operators.
- resolve(op)
To define a resolve rule, subclass this class and override this method.
- Parameters
op (Operator) – The operator to be resolved.
- Returns
ret – A list of tensors if the operator can be resolved, otherwise None. If a list is returned, its tensors replace the outputs of the original operator, so the list must contain exactly as many tensors as the original operator has outputs.
- Return type
Optional[List[Tensor]]
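The return contract of resolve() can be modeled without hidet. The sketch below is a hypothetical toy, not hidet's pass implementation: ToyOperator, ToyPowRule, and apply_rule are invented names, and tensors are represented by plain strings. It shows the two outcomes described above: None leaves the operator unchanged, while a list must supply exactly one replacement per original output.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyOperator:
    """Hypothetical stand-in for an operator; not hidet's Operator class."""
    kind: str           # operator kind, e.g. "pow"
    exponent: float     # hypothetical attribute standing in for a second input
    outputs: List[str]  # names of the operator's output tensors


class ToyPowRule:
    """Hypothetical rule following the documented resolve() contract."""

    def resolve(self, op: ToyOperator) -> Optional[List[str]]:
        if op.kind == "pow" and op.exponent == 2:
            return ["square(x)"]  # one replacement for pow's single output
        return None  # cannot resolve: the operator is left as-is


def apply_rule(rule, op: ToyOperator) -> Dict[str, str]:
    """Map each output of `op` to its replacement (or to itself if unresolved)."""
    ret = rule.resolve(op)
    if ret is None:
        return {name: name for name in op.outputs}
    # Enforce the contract: exactly one replacement per original output.
    if len(ret) != len(op.outputs):
        raise ValueError(f"expected {len(op.outputs)} tensors, got {len(ret)}")
    return dict(zip(op.outputs, ret))


print(apply_rule(ToyPowRule(), ToyOperator("pow", 2.0, ["y"])))  # {'y': 'square(x)'}
print(apply_rule(ToyPowRule(), ToyOperator("pow", 3.0, ["y"])))  # {'y': 'y'}
```

Raising on a length mismatch, rather than silently truncating with zip, is a design choice in this sketch that surfaces a buggy rule immediately.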
- hidet.graph.transforms.resolve_variant.register_resolve_rule(op_cls)
Register a resolve rule for an operator class.
- Parameters
op_cls (Type[Operator]) – The operator class to be registered.
- Returns
ret – A decorator that registers the decorated ResolveRule subclass as a resolve rule for op_cls.
- Return type
Callable[[Type[ResolveRule]], Type[ResolveRule]]
Notes
In the following example, we define a resolve rule for operator PowOp that resolves pow(x, 2.0) to square(x).

```python
from typing import List, Optional

from hidet import ops
from hidet.graph import Tensor
from hidet.graph.ops.definitions import PowOp
from hidet.graph.transforms import ResolveRule, register_resolve_rule


@register_resolve_rule(PowOp)
class PowResolveRule(ResolveRule):
    def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
        a: Tensor = op.inputs[0]  # base x
        b: Tensor = op.inputs[1]  # exponent
        # Only rewrite when the exponent is a concrete (non-symbolic) scalar equal to 2.
        if not b.is_symbolic() and len(b.shape) == 0 and b.scalar() == 2:
            return [ops.square(a)]  # one tensor, matching PowOp's single output
        return None  # leave every other pow() unchanged
```
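register_resolve_rule(op_cls) is documented as returning a decorator of type Callable[[Type[ResolveRule]], Type[ResolveRule]]. The sketch below shows one common way to write such a class-registering decorator in plain Python. It does not import hidet, and every name in it (OpBase, PowLikeOp, RuleBase, RULES, register_rule, rules_for) is hypothetical; in particular, whether hidet stores rule classes or instances, and how it looks them up, is an assumption not stated on this page.

```python
from typing import Callable, Dict, List, Optional, Type


class OpBase:
    """Hypothetical operator base class (stand-in for an Operator)."""


class PowLikeOp(OpBase):
    """Hypothetical operator class used only in this sketch."""


class RuleBase:
    """Hypothetical rule base class mirroring the documented resolve() hook."""

    def resolve(self, op: OpBase) -> Optional[List[object]]:
        raise NotImplementedError


# Hypothetical registry: operator class -> rule instances registered for it.
RULES: Dict[Type[OpBase], List[RuleBase]] = {}


def register_rule(op_cls: Type[OpBase]) -> Callable[[Type[RuleBase]], Type[RuleBase]]:
    """Return a decorator that files a rule class under `op_cls`."""

    def decorator(rule_cls: Type[RuleBase]) -> Type[RuleBase]:
        # Assumption: one instance per rule class is created at registration time.
        RULES.setdefault(op_cls, []).append(rule_cls())
        return rule_cls  # hand the class back unchanged so its name stays bound

    return decorator


@register_rule(PowLikeOp)
class NoOpRule(RuleBase):
    def resolve(self, op: OpBase) -> Optional[List[object]]:
        return None  # never resolves; present only to show registration


def rules_for(op: OpBase) -> List[RuleBase]:
    """Look up the rules registered for the operator's exact class."""
    return RULES.get(type(op), [])
```

Because the decorator returns the class unchanged, the decorated name remains an ordinary class; registration is purely a side effect recorded in the registry.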