Resolve Operator Pass

class hidet.graph.transforms.resolve_variant.ResolveRule

A resolve rule defines how to resolve an operator into other operators.
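To show the shape of a rule without depending on hidet, here is a minimal pure-Python sketch. `ToyTensor`, `ToyOp`, `ToyResolveRule` and `MulByOneRule` are hypothetical stand-ins, not hidet classes. The sketch assumes only the contract described on this page: a subclass overrides `resolve` and returns either a list of replacement tensors or `None`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyTensor:
    # Hypothetical stand-in for a graph tensor: a name and an optional constant value.
    name: str
    const: Optional[float] = None


@dataclass
class ToyOp:
    # Hypothetical stand-in for an operator: a kind label plus its input tensors.
    kind: str
    inputs: List[ToyTensor]


class ToyResolveRule:
    # Hypothetical base class mirroring ResolveRule: subclasses override resolve().
    def resolve(self, op: ToyOp) -> Optional[List[ToyTensor]]:
        raise NotImplementedError("subclasses must override resolve()")


class MulByOneRule(ToyResolveRule):
    # Resolves mul(x, 1) to x; declines (returns None) for anything else.
    def resolve(self, op: ToyOp) -> Optional[List[ToyTensor]]:
        if op.kind == "mul" and op.inputs[1].const == 1.0:
            return [op.inputs[0]]
        return None


x = ToyTensor("x")
print(MulByOneRule().resolve(ToyOp("mul", [x, ToyTensor("c", 1.0)])))
print(MulByOneRule().resolve(ToyOp("mul", [x, ToyTensor("c", 3.0)])))
```

Returning `None` rather than raising lets a rule decline operators it does not handle, leaving them untouched.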

resolve(op)

To define a resolve rule, subclass this class and override this method.

Parameters:

op (Operator) – The operator to be resolved.

Returns:

ret – A list of tensors if the operator can be resolved, otherwise None. If a list is returned, its tensors replace the outputs of the original operator, so the list must contain exactly as many tensors as the original operator has outputs.

Return type:

List[Tensor], optional
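The return contract can be made concrete with a small hypothetical helper, `apply_resolution`, that a pass might use to consume a rule's result. Tensors are represented by their names (strings) for brevity; this is a sketch of the documented contract, not hidet's internal code.

```python
from typing import Callable, Dict, List, Optional, Sequence


def apply_resolution(
    outputs: Sequence[str],
    resolve: Callable[[], Optional[List[str]]],
) -> Optional[Dict[str, str]]:
    """Hypothetical helper: map each original output to its replacement.

    Returns None when the rule declines to resolve the operator.
    """
    new_outputs = resolve()
    if new_outputs is None:
        return None  # the operator is left untouched
    if len(new_outputs) != len(outputs):
        raise ValueError(
            f"rule returned {len(new_outputs)} tensors for an operator "
            f"with {len(outputs)} outputs"
        )
    # The i-th returned tensor replaces the i-th original output.
    return dict(zip(outputs, new_outputs))


print(apply_resolution(["y"], lambda: ["square_x"]))
print(apply_resolution(["y"], lambda: None))
```

Pairing outputs by position is the simplest reading of "the returned tensors will be used to replace the outputs"; a length mismatch is treated as an error rather than silently truncated.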

hidet.graph.transforms.resolve_variant.register_resolve_rule(op_cls)

Register a resolve rule for an operator class.

Parameters:

op_cls (Type[Operator]) – The operator class to be registered.

Returns:

ret – A decorator that registers the decorated ResolveRule subclass as a resolve rule for op_cls.

Return type:

Callable[[Type[ResolveRule]], Type[ResolveRule]]
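A decorator with this return type can be sketched as a class-keyed registry. `ToyOperator`, `ToyResolveRule`, `RULES`, `toy_register_resolve_rule` and `rules_for` are hypothetical names used for illustration; this is not hidet's implementation.

```python
from typing import Callable, Dict, List, Type


class ToyOperator:
    # Hypothetical stand-in for an operator base class.
    pass


class ToyResolveRule:
    # Hypothetical stand-in for the ResolveRule base class.
    def resolve(self, op):
        return None


# Maps each operator class to the rule classes registered for it.
RULES: Dict[Type[ToyOperator], List[Type[ToyResolveRule]]] = {}


def toy_register_resolve_rule(
    op_cls: Type[ToyOperator],
) -> Callable[[Type[ToyResolveRule]], Type[ToyResolveRule]]:
    def decorator(rule_cls: Type[ToyResolveRule]) -> Type[ToyResolveRule]:
        RULES.setdefault(op_cls, []).append(rule_cls)
        # Return the class unchanged so it can still be referenced by name.
        return rule_cls

    return decorator


def rules_for(op: ToyOperator) -> List[ToyResolveRule]:
    # Instantiate every rule registered for the operator's exact class.
    return [rule_cls() for rule_cls in RULES.get(type(op), [])]


class ToyPowOp(ToyOperator):
    pass


@toy_register_resolve_rule(ToyPowOp)
class ToyPowRule(ToyResolveRule):
    pass


print([type(rule).__name__ for rule in rules_for(ToyPowOp())])
```

Keying on the exact operator class is a simplification made for this sketch; whether the real registry also matches subclasses is not stated here.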

Notes

In the following example, we define a resolve rule for the operator PowOp that resolves pow(x, 2.0) to square(x).

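from typing import List, Optional

from hidet import ops
from hidet.graph import Tensor
from hidet.graph.ops import PowOp
from hidet.graph.transforms import ResolveRule, register_resolve_rule

@register_resolve_rule(PowOp)
class PowResolveRule(ResolveRule):
    def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
        a: Tensor = op.inputs[0]  # base
        b: Tensor = op.inputs[1]  # exponent
        # Only fire when the exponent is a concrete (non-symbolic) scalar equal to 2.
        if not b.is_symbolic() and len(b.shape) == 0 and b.scalar() == 2:
            return [ops.square(a)]  # one tensor for PowOp's single output
        return None  # leave every other pow operator unchanged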
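How a pass might apply such rules can be sketched on toy expression trees, where an expression is a name, a constant, or a tuple `(kind, *operands)`. `pow2_to_square`, `mul_by_one` and `resolve_all` are hypothetical, and the fixed-point strategy (re-resolving each replacement until no rule fires) is an assumption made for illustration, not a description of hidet's pass.

```python
from typing import Callable, List, Optional, Union

# A toy expression: a tensor name, a constant, or a tuple (kind, *operands).
Expr = Union[str, float, tuple]
Rule = Callable[[tuple], Optional[List[Expr]]]


def pow2_to_square(node: tuple) -> Optional[List[Expr]]:
    # Mirrors the PowOp example: pow(a, 2) -> square(a); otherwise decline.
    if node[0] == "pow" and node[2] == 2:
        return [("square", node[1])]
    return None


def mul_by_one(node: tuple) -> Optional[List[Expr]]:
    # mul(a, 1) -> a; otherwise decline.
    if node[0] == "mul" and node[2] == 1:
        return [node[1]]
    return None


def resolve_all(expr: Expr, rules: List[Rule]) -> Expr:
    if not isinstance(expr, tuple):
        return expr  # names and constants are left as they are
    # Resolve operands first so rules see already-resolved inputs.
    node = (expr[0],) + tuple(resolve_all(e, rules) for e in expr[1:])
    for rule in rules:
        result = rule(node)
        if result is not None:
            # Each toy node has one output, so exactly one replacement is expected.
            if len(result) != 1:
                raise ValueError("expected exactly one replacement")
            # The replacement may itself be resolvable, so resolve it again.
            return resolve_all(result[0], rules)
    return node


print(resolve_all(("pow", ("mul", "x", 1.0), 2.0), [pow2_to_square, mul_by_one]))
# prints ('square', 'x')
```

Resolving operands first means the pow rule sees `x` rather than `mul(x, 1)`, so the two rules compose in a single call.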