Add Operator Resolve Rule

This is a tutorial introduces the operator resolving mechanism and how to add a new operator resolve rule. An operator resolve rule is used to resolve an operator to other operators. Usually, we would resolve a more generic operator to more specific and efficient operators. The operator resolving rules allow us to reuse existing highly-optimized operators to implement a new operator, while organizing the operators in a more modular way.

Operator Resolving

The core idea of the operator resolving is to resolve a generic operator to more specific and high-optimized operators. When we define a new operator, we can also attach an operator resolve rule to it. The rule defines how to resolve the operator to other operators with the same semantics. After the operator is resolved, the original operator will be replaced by the resolved operators. This process is transparent to the user and is done automatically by a pass when we optimize a flow graph.

There are typical two scenarios that we need to resolve an operator to other operators:

  • Resolve a generic operator to specialized variants: We can provide a generic operator and lots of its specialized variants. When optimizing the model, we can resolve the generic operator to the most suitable specialized operator. For example, in Hidet, we provided a generic matmul() operator with the same semantics as the numpy equivalent numpy.matmul. This operator is a generic operator and is scheduled automatically by our auto-scheduler, thus it is not very efficient. But we also provided a lot of specialized variants of the operators such as highly-optimized batch_matmul() that only accepts \(A=[B, M, K]\) and \(B=[B, K, N]\). During the operator resolving, we first reshape and broadcast the input tensors to align the input shapes with the specialized operator, then use the specialized operator to compute the result, and finally reshape the output tensor to get the correct output shape.

Tip

During the operator resolving, we might introduce some extra operators to adjust the input tensors (e.g., reshape(), broadcast(), transpose(), etc.). These extra operators are usually fused into the resolved operators or surrounding operators of the original operator in the later optimization pass. Thus, the extra overhead is usually negligible.

../../_images/resolve-example-matmul.svg

The resolve rule for Matmul operator.

  • Reuse a new operator to existing operators: When we add a new operator and the new operator can be implemented by existing operators, we can use a resolve rule to resolve the new operator to the existing highly-optimized operators to reduce the development effort.

../../_images/resolve-example-conv2d.svg

This rule resolves the generic conv2d() operator to matrix multiplication using the img2col algorithm.

The operator resolving pass would repeat the resolving process until no more operators can be resolved. Thus, in the conv2d example, we will first resolve conv2d() to matmul(), and then to batch_matmul().

Add Operator Resolve Rule

To add a resolve rule to an operator, we need to

  1. define a subclass of ResolveRule and then

  2. register the rule by decorating it with register_resolve_rule().

In the following example, we resolve the pow() operator to normal multiplications if the exponent is a constant integer 3.

Before we start, let’s have a look at the original behavior when there is no such resolve rule.

import hidet

a = hidet.symbol(shape=[2, 3], device='cuda')
b = hidet.ops.pow(a, hidet.asarray(3, device='cuda'))
graph = hidet.trace_from(b, inputs=[a])
print('Original graph:')
print(graph)

print('Optimized graph without resolving Pow:')
graph_opt = hidet.graph.optimize(graph)
print(graph_opt)
Original graph:
Graph(x: float32[2, 3][cuda]){
  c = Constant(int64[][cuda])
  x_1: float32[2, 3][cuda] = Pow(x, c)
  return x_1
}
Optimized graph without resolving Pow:
Graph(x: float32[2, 3][cuda]){
  c = Constant(int64[][cuda])
  x_1: float32[2, 3][cuda] = Pow(x, c)
  return x_1
}

The original graph contains a pow() operator, and the optimized graph is the same as the original graph. Now let’s add the resolve rule and see what happens.

from typing import Optional, List
from hidet import Tensor
from hidet.graph.ops.arithmetic import PowOp
from hidet.graph.transforms import register_resolve_rule, ResolveRule


@register_resolve_rule(PowOp)
class PowResolveRule(ResolveRule):
    def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
        a: Tensor = op.inputs[0]  # get the base tensor
        b: Tensor = op.inputs[1]  # get the exponent tensor
        if not b.is_symbolic() and len(b.shape) == 0 and int(b) == 3:
            # if the exponent is a constant integer 3, resolve the operator to a * a * a
            return [a * a * a]
        # otherwise, return None to indicate that the operator cannot be resolved
        # and the original operator will be kept
        return None


# optimize the original graph again
# the Pow operator will be resolved to a * a * a
# after that, the two multiplications will be fused into one operator
graph_opt_new = hidet.graph.optimize(graph)
print('Optimized graph after resolving Pow:')
print(graph_opt_new)
Optimized graph after resolving Pow:
Graph(x: float32[2, 3][cuda]){
  x_1: float32[2, 3][cuda] = FusedMultiply(x, fused_graph=FlowGraph(Multiply, Multiply), anchor=1)
  return x_1
}

See also

register_resolve_rule(), ResolveRule for the details of the resolve rule.

Summary

In this tutorial, we learned how to resolve an operator to other operators. We also learned how to add a resolve rule to an operator. The resolve rule is a powerful tool to reuse existing operators to implement new operators. We can also use it to resolve a generic operator to more specialized variants.

Total running time of the script: (0 minutes 0.081 seconds)

Gallery generated by Sphinx-Gallery