Resolve Operator Pass

class hidet.graph.transforms.resolve_variant.ResolveRule

A resolve rule defines how to resolve an operator to other operators.

resolve(op)

To define a resolve rule, subclass this class and override this method.

Parameters

op (Operator) – The operator to be resolved.

Returns

ret – A list of tensors if the operator can be resolved; otherwise None. When a list is returned, its tensors replace the outputs of the original operator, so it must contain exactly as many tensors as the original operator has outputs.

Return type

Optional[List[Tensor]]
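
The return contract above can be illustrated without hidet. The following is a minimal sketch in plain Python, not hidet's implementation: FakeOp, Rule, RenameRule, and apply_rule are hypothetical stand-ins, and strings take the place of tensors. It shows the two cases a caller has to handle: None means the operator is left alone, and a list must match the original output count.

```python
from typing import List, Optional


class FakeOp:
    """Hypothetical stand-in for an operator: a name plus its outputs."""

    def __init__(self, name: str, outputs: List[str]):
        self.name = name
        self.outputs = outputs


class Rule:
    """Hypothetical stand-in for a resolve rule; subclasses override resolve()."""

    def resolve(self, op: FakeOp) -> Optional[List[str]]:
        return None  # default: the operator cannot be resolved


class RenameRule(Rule):
    """Resolves only operators named 'pow2'; declines everything else."""

    def resolve(self, op: FakeOp) -> Optional[List[str]]:
        if op.name != "pow2":
            return None
        return ["square_out"]


def apply_rule(rule: Rule, op: FakeOp) -> List[str]:
    """Return the outputs to use for `op` after trying `rule` (assumed driver)."""
    new_outputs = rule.resolve(op)
    if new_outputs is None:
        return op.outputs  # unresolved: keep the original outputs
    if len(new_outputs) != len(op.outputs):
        raise ValueError("a rule must return one tensor per original output")
    return new_outputs
```

Checking the length in the driver rather than in each rule keeps individual rules short and makes a miscounted return value fail loudly at the point of replacement.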

hidet.graph.transforms.resolve_variant.register_resolve_rule(op_cls)

Register a resolve rule for an operator class.

Parameters

op_cls (Type[Operator]) – The operator class to register the rule for.

Returns

ret – The decorator to apply to a ResolveRule subclass.

Return type

Callable[[Type[ResolveRule]], Type[ResolveRule]]

Notes

In the following example, we define a resolve rule for operator PowOp to resolve pow(x, 2.0) to square(x).

from typing import List, Optional

# The graph-level Tensor is exported from the top-level hidet package.
from hidet import Tensor, ops
from hidet.graph.ops.definitions import PowOp
from hidet.graph.transforms import ResolveRule, register_resolve_rule


@register_resolve_rule(PowOp)
class PowResolveRule(ResolveRule):
    def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
        a: Tensor = op.inputs[0]  # base
        b: Tensor = op.inputs[1]  # exponent
        # Rewrite only when the exponent is a concrete scalar equal to 2.
        if not b.is_symbolic() and len(b.shape) == 0 and b.scalar() == 2:
            return [ops.square(a)]
        return None  # leave every other pow unchanged
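
The decorator pattern behind a function like register_resolve_rule can be sketched in plain Python. This is an illustration under stated assumptions, not hidet's actual code: Operator, PowOp, Rule, REGISTRY, register_rule, and rules_for are hypothetical names, and storing rule instances in a dictionary keyed by operator class is an assumption about how such a registry might be organized.

```python
from typing import Callable, Dict, List, Optional, Type


class Operator:
    """Hypothetical stand-in for an operator base class."""


class PowOp(Operator):
    """Hypothetical stand-in for a concrete operator."""


class Rule:
    """Hypothetical stand-in for a resolve rule base class."""

    def resolve(self, op: Operator) -> Optional[list]:
        return None


# Maps an operator class to the rule instances registered for it.
REGISTRY: Dict[Type[Operator], List[Rule]] = {}


def register_rule(op_cls: Type[Operator]) -> Callable[[Type[Rule]], Type[Rule]]:
    """Return a class decorator that records a rule for `op_cls`."""

    def decorator(rule_cls: Type[Rule]) -> Type[Rule]:
        if not issubclass(rule_cls, Rule):
            raise TypeError("expected a Rule subclass")
        REGISTRY.setdefault(op_cls, []).append(rule_cls())
        return rule_cls  # hand the class back so its name stays usable

    return decorator


@register_rule(PowOp)
class PowRule(Rule):
    def resolve(self, op: Operator) -> Optional[list]:
        return ["resolved"]


def rules_for(op: Operator) -> List[Rule]:
    """Look up the rules registered for the exact class of `op`."""
    return REGISTRY.get(type(op), [])
```

Because the decorator returns the class it receives, the signature matches the documented return type Callable[[Type[ResolveRule]], Type[ResolveRule]], and the decorated class can still be imported and tested directly.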