Source code for hidet.utils.py

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TypeVar, Iterable, Tuple, List, Union, Sequence, Optional
import cProfile
import contextlib
import io
import itertools
import os
import pstats
import time
import numpy as np
from tabulate import tabulate


def unique(seq: Sequence) -> List:
    """Return the items of ``seq`` without duplicates, keeping first occurrences in order.

    Items are tracked in a set, so they must be hashable.

    Parameters
    ----------
    seq: Sequence
        The items to de-duplicate.

    Returns
    -------
    ret: List
        A new list holding each distinct item once, in order of first appearance.
    """
    seen = set()
    result = []
    for element in seq:
        if element in seen:
            continue
        seen.add(element)
        result.append(element)
    return result


def prod(seq: Iterable):
    """Multiply all items of ``seq`` together, left to right.

    Parameters
    ----------
    seq: Iterable
        The factors. Any objects supporting ``*`` are accepted.

    Returns
    -------
    ret:
        ``1`` for an empty input, the single item itself for a one-item input, otherwise
        ``((seq[0] * seq[1]) * seq[2]) * ...``.
    """
    factors = iter(seq)
    try:
        result = next(factors)
    except StopIteration:
        return 1
    for factor in factors:
        # Deliberately not ``*=``: an in-place multiply could mutate the first item.
        result = result * factor
    return result


def median(seq: Iterable):
    """Return the middle item of ``seq`` after sorting.

    For an even number of items the upper of the two middle items is returned (no averaging).

    Parameters
    ----------
    seq: Iterable
        The items to take the median of. They must be mutually comparable.

    Returns
    -------
    ret:
        The item at index ``len // 2`` of the sorted items, or ``None`` when ``seq`` is empty.
    """
    ordered = sorted(seq)
    if not ordered:
        return None
    return ordered[len(ordered) // 2]


def clip(
    x: Union[int, float], low: Optional[Union[int, float]], high: Optional[Union[int, float]]
) -> Union[int, float]:
    """Limit ``x`` to the range bounded by ``low`` and ``high``.

    The lower bound is applied first and the upper bound second, so when ``low > high`` the
    result is ``high``.

    Parameters
    ----------
    x: Union[int, float]
        The value to limit.
    low: Optional[Union[int, float]]
        The lower bound, or ``None`` for no lower bound.
    high: Optional[Union[int, float]]
        The upper bound, or ``None`` for no upper bound.

    Returns
    -------
    ret: Union[int, float]
        The limited value.
    """
    raised = x if low is None else max(x, low)
    return raised if high is None else min(raised, high)


TypeA = TypeVar('TypeA')
TypeB = TypeVar('TypeB')


def strict_zip(a: Sequence[TypeA], b: Sequence[TypeB]) -> Iterable[Tuple[TypeA, TypeB]]:
    """Zip two sequences, refusing inputs whose lengths differ.

    Parameters
    ----------
    a: Sequence[TypeA]
        The first sequence.
    b: Sequence[TypeB]
        The second sequence.

    Returns
    -------
    ret: Iterable[Tuple[TypeA, TypeB]]
        The ``zip`` iterator over both sequences.

    Raises
    ------
    ValueError
        If ``len(a) != len(b)``.
    """
    len_a, len_b = len(a), len(b)
    if len_a == len_b:
        return zip(a, b)
    raise ValueError(f'Expect two sequence have the same length in zip, got length {len_a} and {len_b}.')


class COLORS:
    """Terminal escape sequences used by the coloring helpers in this module.

    Each attribute is a string starting with the escape character ``\\033``; ``nocolor``
    relies on that prefix to find and strip these sequences.
    """

    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    MAGENTA = '\033[95m'  # same sequence as HEADER
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'  # appended by the helpers after the colored text
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def green(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.OKGREEN`` ... ``COLORS.ENDC``."""
    return f'{COLORS.OKGREEN}{fmt.format(v)}{COLORS.ENDC}'


def cyan(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.OKCYAN`` ... ``COLORS.ENDC``."""
    return f'{COLORS.OKCYAN}{fmt.format(v)}{COLORS.ENDC}'


def blue(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.OKBLUE`` ... ``COLORS.ENDC``."""
    return f'{COLORS.OKBLUE}{fmt.format(v)}{COLORS.ENDC}'


def yellow(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.WARNING`` ... ``COLORS.ENDC``."""
    return f'{COLORS.WARNING}{fmt.format(v)}{COLORS.ENDC}'


def red(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.FAIL`` ... ``COLORS.ENDC``."""
    return f'{COLORS.FAIL}{fmt.format(v)}{COLORS.ENDC}'


def magenta(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.MAGENTA`` ... ``COLORS.ENDC``."""
    return f'{COLORS.MAGENTA}{fmt.format(v)}{COLORS.ENDC}'


def bold(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in ``COLORS.BOLD`` ... ``COLORS.ENDC``."""
    return f'{COLORS.BOLD}{fmt.format(v)}{COLORS.ENDC}'


def color(v, fmt='{}', fg='default', bg='default'):
    """Format ``v`` with ``fmt`` and wrap it in an escape sequence selecting ``fg`` and ``bg``.

    Parameters
    ----------
    v:
        The value to format.
    fmt: str
        The format string applied to ``v``.
    fg: str
        Foreground color name: one of black, red, green, yellow, blue, magenta, cyan, white,
        default.
    bg: str
        Background color name, from the same set as ``fg``.

    Returns
    -------
    ret: str
        ``'\\033[<fg code>;<bg code>m'`` + formatted text + ``'\\033[0m'``.

    Raises
    ------
    KeyError
        If ``fg`` or ``bg`` is not a known color name.
    """
    names = ("black", "red", "green", "yellow", "blue", "magenta", "cyan", "white")
    # Foreground codes run 30..37 in the order above, with 39 for "default".
    fg_code = {name: 30 + offset for offset, name in enumerate(names)}
    fg_code["default"] = 39
    # Each background code is the matching foreground code plus 10.
    bg_code = {name: code + 10 for name, code in fg_code.items()}
    return f'\033[{fg_code[fg]};{bg_code[bg]}m{fmt.format(v)}\033[0m'


def color_table():
    """Print one sample line for every foreground/background pair accepted by ``color``.

    A header line is printed first; rows are grouped by background, cycling through all
    foregrounds within each group.
    """
    names = ("default", "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white")
    row_fmt = '{:>10} {:>10}   {}'
    print('{:>10} {:>10}   {:<10}'.format('fg', 'bg', 'text'))
    for bg_name in names:
        for fg_name in names:
            sample = color('sample text', fg=fg_name, bg=bg_name)
            print(row_fmt.format(fg_name, bg_name, sample))


def color_rgb(v, fg, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap it in a foreground escape sequence built from ``fg``.

    Parameters
    ----------
    v:
        The value to format.
    fg:
        An indexable whose first three items are placed into ``'\\033[38;2;{};{};{}m'``.
    fmt: str
        The format string applied to ``v``.

    Returns
    -------
    ret: str
        The escape prefix, the formatted text, then ``'\\033[0m'``.
    """
    prefix = f'\033[38;2;{fg[0]};{fg[1]};{fg[2]}m'
    reset = '\033[0m'
    template = prefix + fmt + reset
    return template.format(v)


def color_text(v, fmt='{}', idx: int = 0):
    """Format ``v`` with ``fmt``, optionally colored with one of two preset RGB colors.

    Parameters
    ----------
    v:
        The value to format.
    fmt: str
        The format string applied to ``v``.
    idx: int
        ``0`` returns the plain formatted text; ``1`` and ``2`` select the presets
        ``(153, 96, 52)`` and ``(135, 166, 73)`` and delegate to ``color_rgb``.

    Returns
    -------
    ret: str
        The formatted (and possibly colored) text.

    Raises
    ------
    KeyError
        If ``idx`` is not 0, 1 or 2.
    """
    if idx != 0:
        palette = {1: (153, 96, 52), 2: (135, 166, 73)}
        return color_rgb(v, palette[idx], fmt=fmt)
    return fmt.format(v)


def nocolor(s: str) -> str:
    """Remove every escape sequence defined on ``COLORS`` from ``s``.

    Only the exact sequences stored as ``COLORS`` attributes are removed.

    Parameters
    ----------
    s: str
        The text to clean.

    Returns
    -------
    ret: str
        ``s`` with all occurrences of those sequences deleted.
    """
    # String class attributes that start with the escape character are the color codes.
    codes = [attr for attr in vars(COLORS).values() if isinstance(attr, str) and attr[0] == '\033']
    for code in codes:
        s = s.replace(code, '')
    return s


def str_indent(msg: str, indent=0) -> str:
    """Prefix every line of ``msg`` with ``indent`` spaces.

    Lines are split on ``'\\n'``; empty lines (including a trailing one) are indented too.

    Parameters
    ----------
    msg: str
        The text to indent.
    indent: int
        The number of spaces to put in front of each line.

    Returns
    -------
    ret: str
        The indented text.
    """
    padding = ' ' * indent
    return '\n'.join(padding + line for line in msg.split('\n'))


class Timer:
    """Context manager that measures the wall-clock time spent inside its ``with`` block.

    Parameters
    ----------
    msg:
        Text printed in front of the elapsed time on exit. Nothing is reported when ``None``.
    file:
        Optional report destination: a path (opened with mode ``'w'``, color codes stripped)
        or an object with a ``write`` method (receives the colored message plus a newline).
    verbose: bool
        When false, nothing is reported on exit.
    stdout: bool
        When true, the report is printed.
    """

    def __init__(self, msg=None, file=None, verbose=True, stdout=True):
        self.msg = msg
        self.file = file
        self.verbose = verbose
        self.stdout = stdout
        # Both timestamps come from time.time() and are filled in by __enter__ / __exit__.
        self.start_time = None
        self.end_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()
        if self.msg is None or not self.verbose:
            return
        elapsed = self.time2str(self.end_time - self.start_time)
        msg = f'{self.msg} {green(elapsed)}'
        if self.stdout:
            print(msg)
        if not self.file:
            return
        if isinstance(self.file, str):
            with open(self.file, 'w') as f:
                f.write(nocolor(msg))
        else:
            self.file.write(msg + '\n')

    def elapsed_seconds(self) -> float:
        """Return ``end_time - start_time``; only meaningful after the ``with`` block has exited."""
        return self.end_time - self.start_time

    def time2str(self, seconds: float) -> str:
        """Render a duration with one decimal in ms, seconds, minutes or hours, by magnitude."""
        if seconds < 1:
            value, unit = seconds * 1000, 'ms'
        elif seconds < 60:
            value, unit = seconds, 'seconds'
        elif seconds < 60 * 60:
            value, unit = seconds / 60, 'minutes'
        else:
            value, unit = seconds / 60 / 60, 'hours'
        return f'{value:.1f} {unit}'


def repeat_until_converge(func, obj, limit=None):
    """Apply ``func`` repeatedly until it returns the very object it was given.

    Convergence is detected by identity (``is``), not equality. ``func`` is always called at
    least once.

    Parameters
    ----------
    func:
        A callable taking the current object and returning the next one.
    obj:
        The starting object.
    limit: Optional[int]
        Maximum number of applications; ``None`` means no limit.

    Returns
    -------
    ret:
        The object after convergence, or after ``limit`` applications, whichever comes first.
    """
    rounds = 0
    while True:
        previous, obj = obj, func(obj)
        rounds += 1
        if obj is previous or (limit is not None and rounds >= limit):
            return obj


def get_next_file_index(dirname: str) -> int:
    """Return the smallest non-negative integer not used as a name prefix in ``dirname``.

    The prefix of an entry is the part of its name before the first ``'_'`` (the whole name
    when there is no underscore). Entries whose prefix is not an integer are ignored.

    Parameters
    ----------
    dirname: str
        The directory to scan. It is listed with ``os.listdir``, so it must exist.

    Returns
    -------
    ret: int
        The first index in 0, 1, 2, ... that no entry uses.
    """
    used = set()
    for fname in os.listdir(dirname):
        prefix = fname.split('_')[0]
        # Non-numeric prefixes simply do not reserve an index.
        with contextlib.suppress(ValueError):
            used.add(int(prefix))
    # ``used`` is finite, so the unbounded counter always reaches a free index.
    return next(idx for idx in itertools.count() if idx not in used)


def factorize(n):
    """
    Get all positive divisors of n in ascending order.

    example: factorize(12) => [1, 2, 3, 4, 6, 12]

    Parameters
    ----------
    n: int
        The integer to find the divisors of.

    Returns
    -------
    ret: List[int]
        The sorted divisors of n. Empty when n < 1, because the search starts at 1 and stops
        once i * i exceeds n.
    """
    i = 1
    ret = []
    # Divisors come in pairs (i, n // i) with i * i <= n, so scanning up to sqrt(n) is enough.
    while i * i <= n:
        if n % i == 0:
            ret.append(i)
            # For a perfect square, i == n // i; do not add it twice.
            if i * i != n:
                ret.append(n // i)
        i += 1
    return sorted(ret)
def _is_immutable(obj):
    """Tell whether ``obj`` should be compared by value rather than by identity in ``same_list``.

    True for int, float, str and tuple instances, and for hidet ``Constant`` / ``Device``
    objects, except a ``Constant`` whose type reports ``is_tensor()``.
    """
    # Imported here rather than at module level (presumably to avoid a circular import — TODO confirm).
    from hidet.ir.expr import Constant
    from hidet.graph.operator import Device

    if isinstance(obj, (int, float, str, tuple)):
        return True
    if isinstance(obj, Constant) and obj.type.is_tensor():
        return False
    if isinstance(obj, (Constant, Device)):
        return True
    return False


def same_list(lhs, rhs, use_equal=False):
    """Check whether two sequences hold the same items, position by position.

    Parameters
    ----------
    lhs, rhs:
        The sequences to compare.
    use_equal: bool
        When true every pair is compared with ``!=``. Otherwise pairs in which both items are
        immutable (see ``_is_immutable``) are compared with ``!=`` and all other pairs by
        identity.

    Returns
    -------
    ret: bool
        True if the lengths match and every pair matches.
    """
    if len(lhs) != len(rhs):
        return False
    for left, right in zip(lhs, rhs):
        if use_equal or (_is_immutable(left) and _is_immutable(right)):
            if left != right:
                return False
        elif left is not right:
            return False
    return True


def index_of(value: object, lst: Sequence, allow_missing=True) -> int:
    """Return the position of the first item in ``lst`` that is the same object as ``value``.

    Parameters
    ----------
    value: object
        The object to look for; matching is by identity (``is``).
    lst: Sequence
        The sequence to search.
    allow_missing: bool
        When true a missing value yields ``-1``; otherwise it raises.

    Returns
    -------
    ret: int
        The index of the match, or ``-1`` if absent and ``allow_missing`` is true.

    Raises
    ------
    ValueError
        If the value is absent and ``allow_missing`` is false.
    """
    for i, v in enumerate(lst):
        if v is value:
            return i
    if allow_missing:
        return -1
    raise ValueError('value not found in list')


class HidetProfiler:
    """Context manager that runs ``cProfile`` over its ``with`` block.

    Parameters
    ----------
    display_on_exit: bool
        When true, the collected statistics are printed when the block exits.
    """

    def __init__(self, display_on_exit=True):
        self.pr = cProfile.Profile()
        self.display_on_exit = display_on_exit

    def __enter__(self):
        self.pr.enable()
        # Return the profiler so ``with HidetProfiler() as p`` binds something usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pr.disable()
        if self.display_on_exit:
            print(self.result())

    def result(self):
        """Return the statistics sorted by cumulative time, as text."""
        s = io.StringIO()
        ps = pstats.Stats(self.pr, stream=s).sort_stats('cumulative')
        ps.print_stats()
        return str(s.getvalue())


class TableRowContext:
    """Collects the cells of one table row; the row is appended to the table on ``with`` exit."""

    def __init__(self, tb):
        self.tb = tb
        self.row = []

    def __iadd__(self, other):
        """Add a cell, or several when ``other`` is a tuple or list."""
        if isinstance(other, (tuple, list)):
            self.row.extend(other)
        else:
            self.row.append(other)
        # ``x += y`` rebinds x to this return value; without it the name would become None.
        return self

    def append(self, other):
        self.row.append(other)

    def extend(self, other):
        self.row.extend(other)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.tb.rows.append(self.row)


class TableBuilder:
    """Accumulates headers and rows, and renders them with ``tabulate`` via ``str()``.

    Parameters
    ----------
    headers:
        Initial column names.
    tablefmt: str
        Passed to ``tabulate`` as ``tablefmt``.
    floatfmt: str
        Passed to ``tabulate`` as ``floatfmt``.
    """

    def __init__(self, headers=tuple(), tablefmt='simple', floatfmt='.3f'):
        self.headers = list(headers)
        self.rows = []
        self.tablefmt = tablefmt
        self.floatfmt = floatfmt

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def __iadd__(self, row):
        self.rows.append(row)
        return self

    def __str__(self):
        return str(tabulate(self.rows, self.headers, tablefmt=self.tablefmt, floatfmt=self.floatfmt))

    def new_row(self) -> TableRowContext:
        """Create a row collector bound to this table."""
        return TableRowContext(tb=self)

    def extend_header(self, column_names):
        """Append ``column_names`` to the header list."""
        self.headers.extend(column_names)
def initialize(*args, **kwargs):
    """
    Decorate an initialization function. After decorating with this function, the initialization function will be
    called after the definition.

    Parameters
    ----------
    args:
        The positional arguments of initializing.
    kwargs:
        The keyword arguments of initializing.

    Returns
    -------
    ret:
        A decorator that will call given function with args and kwargs, and return None (to prevent this function
        to be called again).
    """

    def decorator(f):
        # Run the initializer right away; the implicit None return replaces the decorated
        # name, so the function cannot be invoked a second time through that name.
        f(*args, **kwargs)

    return decorator
def gcd(a: int, b: int) -> int:
    """
    Get the greatest common divisor of non-negative integers a and b.

    Parameters
    ----------
    a: int
        The lhs operand.
    b: int
        The rhs operand.

    Returns
    -------
    ret: int
        The greatest common divisor. gcd(a, 0) is a, so gcd(0, 0) is 0.
    """
    assert a >= 0 and b >= 0
    # Euclid's algorithm. The remainder of non-negative operands stays non-negative, so the
    # precondition only needs checking once.
    while b != 0:
        a, b = b, a % b
    return a
def lcm(a: int, b: int) -> int:
    """
    Get the least common multiple of non-negative integers a and b.

    Parameters
    ----------
    a: int
        The lhs operand.
    b: int
        The rhs operand.

    Returns
    -------
    ret: int
        The least common multiple. It is 0 when either operand is 0.
    """
    # gcd(0, 0) is 0, so the general formula below would divide by zero for this one input.
    if a == 0 and b == 0:
        return 0
    # Divide before multiplying to keep the intermediate value small.
    return a // gcd(a, b) * b
def is_power_of_two(n: int) -> bool:
    """
    Check if an integer is a power of two: 1, 2, 4, 8, 16, 32, ...

    Parameters
    ----------
    n: int
        The integer to check.

    Returns
    -------
    ret: bool
        True if n is a power of two, False otherwise.
    """
    # A positive power of two has exactly one bit set; n & (n - 1) clears the lowest set bit.
    return n > 0 and (n & (n - 1)) == 0


def cdiv(n: int, d: int) -> int:
    """
    Get the ceiling of n / d.

    Parameters
    ----------
    n: int
        The numerator.
    d: int
        The denominator. The formula used assumes d > 0.

    Returns
    -------
    ret: int
        The ceiling of n / d.
    """
    return (n + d - 1) // d
def error_tolerance(a: Union[np.ndarray, 'Tensor'], b: Union[np.ndarray, 'Tensor']) -> float:
    """
    Given two tensors with the same shape and data type, this function finds the minimal e, such that

        abs(a - b) <= abs(b) * e + e

    The value is found by 20 rounds of bisection over [0, 10], so the result is approximate and never exceeds 10.

    Parameters
    ----------
    a: Union[np.ndarray, hidet.Tensor]
        The first tensor.
    b: Union[np.ndarray, hidet.Tensor]
        The second tensor.

    Returns
    -------
    ret: float
        The error tolerance between a and b.
    """
    if not (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)):
        # hidet is only needed to recognise its own Tensor type, so plain numpy inputs skip the import.
        from hidet.graph import Tensor

        if isinstance(a, Tensor):
            a = a.numpy()
        if isinstance(b, Tensor):
            b = b.numpy()
    # A dtype object is never an instance of np.floating (that is the scalar class), so the
    # kind of the dtype has to be tested with np.issubdtype for this cast to ever happen.
    if np.issubdtype(a.dtype, np.floating):
        a = a.astype(np.float32)
        b = b.astype(np.float32)
    lf = 0.0
    rg = 10.0
    for _ in range(20):
        mid = (lf + rg) / 2.0
        if np.allclose(a, b, rtol=mid, atol=mid):
            rg = mid
        else:
            lf = mid
    return (lf + rg) / 2.0
def _as_numpy_array(value):
    """Convert an ndarray, hidet Tensor or torch Tensor to a numpy array for comparison."""
    from numpy import ndarray

    if isinstance(value, ndarray):
        return value

    # Deferred so that comparing plain numpy arrays never has to import hidet.
    from hidet import Tensor

    if isinstance(value, Tensor):
        return value.cpu().numpy()
    # torch tensors are recognised by type name and module, so torch is never imported here.
    if type(value).__name__ == 'Tensor' and type(value).__module__ == 'torch':
        return value.cpu().numpy()
    raise TypeError(f'Unsupported type: {type(value)}')


def assert_close(actual, expected, rtol=1e-5, atol=1e-5):
    """
    Assert that two tensors are element-wise close, using numpy.testing.assert_allclose.

    Parameters
    ----------
    actual: Union[np.ndarray, hidet.Tensor, torch.Tensor]
        The value under test.
    expected: Union[np.ndarray, hidet.Tensor, torch.Tensor]
        The reference value.
    rtol: float
        The relative tolerance.
    atol: float
        The absolute tolerance.

    Raises
    ------
    TypeError
        If either argument is not one of the supported types.
    AssertionError
        If the two values differ by more than the tolerances allow.
    """
    import numpy.testing

    # ``actual`` is converted first so that it is also the first to be reported as unsupported.
    actual = _as_numpy_array(actual)
    expected = _as_numpy_array(expected)
    numpy.testing.assert_allclose(actual, expected, rtol, atol)


if __name__ == '__main__':
    # color_table()
    print(color_text('sample', idx=1))
    print(color_text('sample', idx=2))