Source code for hidet.ir.compute.reduce_operations

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from typing import Union
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import Expr, Constant, LogicalAnd, LogicalOr


class ReduceType(Enum):
    """
    The kinds of reduction that can be requested by name.

    Each member's value is the short string name of the reduction; both ``str()`` and
    ``repr()`` of a member return that short name.
    """

    Sum = 'sum'
    Product = 'prod'
    Max = 'max'
    Min = 'min'
    Average = 'avg'
    And = 'and'
    Or = 'or'

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.value
class ReduceOperation:
    """
    Base class describing how a reduction is initialized, combined and finalized.

    Subclasses override :meth:`initial_value` and :meth:`combine`, and optionally the
    arg-reduce, finalize and atomic hooks.
    """

    @staticmethod
    def from_name(name: Union[ReduceType, str]) -> ReduceOperation:
        """
        Create the reduce operation that corresponds to the given reduce type.

        Parameters
        ----------
        name: Union[ReduceType, str]
            A :class:`ReduceType` member or its string value (e.g., ``'sum'``).

        Returns
        -------
        op: ReduceOperation
            A new instance of the matching reduce operation.

        Raises
        ------
        ValueError
            If ``name`` is not a valid :class:`ReduceType` value, or the reduce type has no
            registered operation.
        """
        # ReduceType(...) accepts both an existing member and its string value
        name = ReduceType(name)
        name2operation = {
            ReduceType.Sum: SumReduce,
            ReduceType.Product: ProductReduce,
            ReduceType.Max: MaxReduce,
            ReduceType.Min: MinReduce,
            ReduceType.Average: AverageReduce,
            ReduceType.And: AndReduce,
            ReduceType.Or: OrReduce,
        }
        if name not in name2operation:
            raise ValueError('Can not recognize reduce type {}'.format(name))
        return name2operation[name]()

    def __str__(self):
        # lower-cased class name with its last six characters (the 'reduce' suffix of the
        # subclasses in this module) removed, e.g., SumReduce -> 'sum'
        return self.__class__.__name__.lower()[:-6]

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        """
        The initial value of the reduction.

        Parameters
        ----------
        dtype: Union[DataType, str]
            The data type (or its name) of elements to conduct the reduction.

        Returns
        -------
        init_value: Constant
            The initial value of the reduction.
        """
        raise NotImplementedError()

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        """
        Reduce two values.

        Parameters
        ----------
        lhs: Expr
            The left hand side value.
        rhs: Expr
            The right hand side value.

        Returns
        -------
        result: Expr
            The result of the reduction.
        """
        raise NotImplementedError()

    def arg_combine(self, lhs_value: Expr, rhs_value: Expr):
        """
        For some reductions like argmin and argmax, we need to combine the arg (index) instead of
        the value itself. This function returns a condition that holds when
        combine(lhs_value, rhs_value) == lhs_value.

        Only need to override this function if the reduction supports arg_reduce
        (e.g., argmin, argmax).

        Parameters
        ----------
        lhs_value: Expr
            The left hand side value.
        rhs_value: Expr
            The right hand side value.

        Returns
        -------
        result: Expr
            A condition that is true if the left hand side value should be kept, otherwise false.

        Raises
        ------
        ValueError
            If the reduction does not support argument reduce (the default).
        """
        raise ValueError('{} reduction does not argument reduce.'.format(str(self)))

    def require_finalize(self) -> bool:
        """
        Whether the reduction requires a finalization step.

        For some reduction, the finalization step is required to get the final result. For example,
        the average reduction requires a finalization step to divide the sum by the size of the
        reduction.

        Returns
        -------
        result: bool
            True if the reduction requires a finalization step, otherwise False.
        """
        return False

    def finalize(self, acc: Expr, size: Expr) -> Expr:
        """
        Finalize the reduction result.

        Parameters
        ----------
        acc: Expr
            The accumulated value.
        size: Expr
            The number of elements to conduct the reduction.

        Returns
        -------
        result: Expr
            The final result of the reduction. By default, the accumulated value is returned as is.
        """
        return acc

    def has_atomic(self, dtype: Union[DataType, str]):
        """
        Whether the reduction has atomic reduction support.

        For some reduction, cuda natively supports atomic operations (min, max, sum).

        Parameters
        ----------
        dtype: Union[DataType, str]
            The data type (or its name) of elements to conduct the reduction.

        Returns
        -------
        result: bool
            True if the reduction has an associated atomic operation, otherwise False.
        """
        return False

    def atomic_combine(self, acc: Expr, rhs_value: Expr):
        """
        Performs atomic combine to the acc variable by including the rhs value.

        Parameters
        ----------
        acc: Expr
            A memory address in shared or global memory.
        rhs_value: Expr
            New value to be atomically combined.

        Returns
        -------
        The value produced by the underlying atomic primitive; the combine is inplace to acc.
        """
        raise NotImplementedError()


class MinReduce(ReduceOperation):
    """Minimum reduction; supports arg-reduce and, for integer types, atomic combine."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        # start from the data type's max_value so that any element replaces it
        return dtype.max_value

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        from hidet.ir import primitives  # pylint: disable=import-outside-toplevel

        return primitives.min(lhs, rhs)

    def arg_combine(self, lhs_value: Expr, rhs_value: Expr):
        from hidet.ir.expr import LessThan  # pylint: disable=import-outside-toplevel

        # keep the left index when the left value is strictly smaller
        return LessThan(lhs_value, rhs_value)

    def has_atomic(self, dtype: Union[DataType, str]):
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        return dtype.is_integer()

    def atomic_combine(self, acc: Expr, rhs_value: Expr):
        from hidet.lang.cuda import atomic_min  # pylint: disable=import-outside-toplevel

        return atomic_min(acc, rhs_value)


class MaxReduce(ReduceOperation):
    """Maximum reduction; supports arg-reduce and, for integer types, atomic combine."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        # start from the data type's min_value so that any element replaces it
        return dtype.min_value

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        from hidet.ir import primitives  # pylint: disable=import-outside-toplevel

        return primitives.max(lhs, rhs)

    def arg_combine(self, lhs_value: Expr, rhs_value: Expr):
        from hidet.ir.expr import LessThan  # pylint: disable=import-outside-toplevel

        # keep the left index when the left value is strictly larger (rhs < lhs)
        return LessThan(rhs_value, lhs_value)

    def has_atomic(self, dtype: Union[DataType, str]):
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        return dtype.is_integer()

    def atomic_combine(self, acc: Expr, rhs_value: Expr):
        from hidet.lang.cuda import atomic_max  # pylint: disable=import-outside-toplevel

        return atomic_max(acc, rhs_value)


class SumReduce(ReduceOperation):
    """Sum reduction; combines with ``+`` and reports atomic support for every data type."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        return dtype.zero

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        return lhs + rhs

    def has_atomic(self, dtype: Union[DataType, str]):
        return True

    def atomic_combine(self, acc: Expr, rhs_value: Expr):
        from hidet.lang.cuda import atomic_add  # pylint: disable=import-outside-toplevel

        return atomic_add(acc, rhs_value)


class AverageReduce(ReduceOperation):
    """Average reduction; accumulates a sum and divides by the reduction size in finalize."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        return dtype.zero

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        return lhs + rhs

    def require_finalize(self) -> bool:
        return True

    def finalize(self, acc: Expr, size: Expr) -> Expr:
        return acc / size

    def has_atomic(self, dtype: Union[DataType, str]):
        return True

    def atomic_combine(self, acc: Expr, rhs_value: Expr):
        from hidet.lang.cuda import atomic_add  # pylint: disable=import-outside-toplevel

        return atomic_add(acc, rhs_value)


class AndReduce(ReduceOperation):
    """Logical-and reduction over bool elements."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        dtype = data_type(dtype)
        assert dtype.name == 'bool', 'AndReduce only support bool type'
        return dtype.one

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        return LogicalAnd(lhs, rhs)


class OrReduce(ReduceOperation):
    """Logical-or reduction over bool elements."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        dtype = data_type(dtype)
        assert dtype.name == 'bool', 'OrReduce only support bool type'
        return dtype.zero

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        return LogicalOr(lhs, rhs)


class ProductReduce(ReduceOperation):
    """Product reduction; combines with ``*``."""

    def initial_value(self, dtype: Union[DataType, str]) -> Constant:
        # accept a data type name like the other reductions do; a plain str has no `.one`
        if isinstance(dtype, str):
            dtype = data_type(dtype)
        return dtype.one

    def combine(self, lhs: Expr, rhs: Expr) -> Expr:
        return lhs * rhs