hidet.ir.compute

Tip

Refer to the hidet tutorial on defining computations for how to use these compute primitives to define a computation task.

Classes:

ReduceType(value)

An enumeration of the reduction types accepted by reduce() and arg_reduce().

Functions:

scalar_input(name, dtype)

Define an input scalar node.

tensor_input(name, dtype, shape[, layout])

Define an input tensor node.

compute(name, shape, fcompute[, layout])

Define a grid compute node.

reduce(shape, fcompute, reduce_type[, ...])

Define a reduction node.

arg_reduce(extent, fcompute, reduce_type[, ...])

Define an arg reduction node.

class hidet.ir.compute.ReduceType(value)

An enumeration of the reduction types accepted by reduce() and arg_reduce().

hidet.ir.compute.scalar_input(name, dtype)

Define an input scalar node.

Parameters:
  • name (str) – The name of the input scalar.

  • dtype (str or DataType) – The data type of the input scalar.

Returns:

ret – The input scalar node.

Return type:

ScalarInput

hidet.ir.compute.tensor_input(name, dtype, shape, layout=None)

Define an input tensor node.

Parameters:
  • name (str) – The name of the input tensor.

  • dtype (str or DataType) – The scalar type of the tensor.

  • shape (Sequence[Expr or int]) – The shape of the tensor.

  • layout (DataLayout, optional) – The layout of the tensor.

Returns:

ret – The input tensor node.

Return type:

TensorInput

hidet.ir.compute.compute(name, shape, fcompute, layout=None)

Define a grid compute node.

Parameters:
  • name (str) – The name of the compute node.

  • shape (Sequence[Union[int, Expr]]) – The shape of the compute node.

  • fcompute (Callable[[Sequence[Var]], Expr]) – The compute function. It is called with the index variables, one per dimension of shape, and returns the output value at that index.

  • layout (DataLayout, optional) – The layout of the compute node.

Returns:

ret – The grid compute node.

Return type:

TensorNode
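The semantics of a grid compute node can be modeled in plain Python. The sketch below does not call hidet: `eval_grid_compute` is a hypothetical helper, tensors are modeled as dicts keyed by index tuples, a scalar input is modeled as a plain float, and `fcompute` is assumed to receive one index argument per dimension (as in `lambda i, j: ...`).

```python
import itertools
from typing import Callable, Dict, Sequence, Tuple


def eval_grid_compute(
    shape: Sequence[int], fcompute: Callable[..., float]
) -> Dict[Tuple[int, ...], float]:
    """Hypothetical reference model (not hidet): evaluate fcompute at every grid index."""
    # Enumerate every index tuple of the grid in row-major order and
    # call fcompute with one argument per dimension.
    return {idx: fcompute(*idx) for idx in itertools.product(*(range(n) for n in shape))}


# Models: alpha = scalar_input('alpha', 'float32')
#         a = tensor_input('a', 'float32', [2, 3]); b likewise
#         c = compute('c', [2, 3], lambda i, j: alpha * a[i, j] + b[i, j])
alpha = 2.0
a = {(i, j): float(i * 3 + j) for i in range(2) for j in range(3)}
b = {(i, j): 10.0 for i in range(2) for j in range(3)}
c = eval_grid_compute([2, 3], lambda i, j: alpha * a[i, j] + b[i, j])
print(c[1, 2])  # 20.0
```

In hidet itself, compute() builds a symbolic node rather than evaluating values; the model only shows which value each output index receives.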

hidet.ir.compute.reduce(shape, fcompute, reduce_type, accumulate_dtype='float32', name=None)

Define a reduction node.

Parameters:
  • shape (Sequence[int or Expr]) – The shape of the reduction domain.

  • fcompute (Callable[[Sequence[Var]], Expr]) – The compute function. It is called with the reduction variables, one per dimension of shape, and returns the value to accumulate at that point of the domain.

  • reduce_type (ReduceType or str) – The type of the reduction.

  • accumulate_dtype (str or DataType) – The data type of the accumulator. Defaults to 'float32'.

  • name (Optional[str]) – The name hint for the output. If not specified, the name will be generated automatically.

Returns:

ret – The reduction node.

Return type:

ReduceCompute
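A reduction can be modeled the same way. In this plain-Python sketch, `eval_reduce` and the `_REDUCERS` table are hypothetical: only the reduce types `'sum'`, `'max'` and `'min'` are modeled, and `accumulate_dtype` is ignored because Python floats are used throughout. The matrix multiplication mirrors the pattern of calling reduce() inside the fcompute of compute(), one reduction per output element.

```python
import itertools
import math
from typing import Callable, Sequence

# Hypothetical mapping from reduce_type strings to (initial value, combine function).
_REDUCERS = {
    'sum': (0.0, lambda acc, v: acc + v),
    'max': (-math.inf, max),
    'min': (math.inf, min),
}


def eval_reduce(shape: Sequence[int], fcompute: Callable[..., float], reduce_type: str) -> float:
    """Hypothetical reference model (not hidet): fold fcompute over the reduction domain."""
    init, combine = _REDUCERS[reduce_type]
    acc = init
    # Visit every point of the (possibly multi-dimensional) reduction domain.
    for idx in itertools.product(*(range(n) for n in shape)):
        acc = combine(acc, fcompute(*idx))
    return acc


# Models a 2x2 matmul: c[i, j] = reduce([2], lambda k: a[i, k] * b[k, j], 'sum')
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[eval_reduce([2], lambda k: a[i][k] * b[k][j], 'sum') for j in range(2)] for i in range(2)]
print(c)  # [[19.0, 22.0], [43.0, 50.0]]
```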

hidet.ir.compute.arg_reduce(extent, fcompute, reduce_type, index_dtype='int64', name=None)

Define an arg reduction node.

Parameters:
  • extent (int or Expr) – The extent (size) of the one-dimensional reduction domain.

  • fcompute (Callable[[Var], Expr]) – The compute function. It takes a reduction variable and returns the value to compare.

  • reduce_type (str or ReduceType) – The type of the reduction.

  • index_dtype (str or DataType) – The data type of the returned index. Defaults to 'int64'.

  • name (str, optional) – The name of the output. If not specified, the name will be generated automatically.

Returns:

ret – The arg reduction node.

Return type:

ScalarNode
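An arg reduction returns an index rather than a value. The sketch below is a plain-Python model, not hidet's implementation: `eval_arg_reduce` is hypothetical, only `'max'` and `'min'` are modeled, the result is a Python `int` instead of a value of `index_dtype`, and ties are assumed to resolve to the first index.

```python
from typing import Callable


def eval_arg_reduce(extent: int, fcompute: Callable[[int], float], reduce_type: str) -> int:
    """Hypothetical reference model (not hidet): index k in [0, extent) with the extreme value."""
    if reduce_type not in ('max', 'min'):
        raise ValueError(f'unsupported arg reduction: {reduce_type}')
    if extent < 1:
        raise ValueError('extent must be at least 1')
    best_index, best_value = 0, fcompute(0)
    for k in range(1, extent):
        value = fcompute(k)
        # Strict comparison keeps the first index on ties (an assumption of this model).
        better = value > best_value if reduce_type == 'max' else value < best_value
        if better:
            best_index, best_value = k, value
    return best_index


row = [3.0, 7.0, 7.0, -1.0]
print(eval_arg_reduce(len(row), lambda k: row[k], 'max'))  # 1
print(eval_arg_reduce(len(row), lambda k: row[k], 'min'))  # 3
```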