hidet.graph.frontend.torch

hidet.graph.frontend.from_torch(module, concrete_args=None)[source]

Convert a torch.nn.Module or torch.fx.GraphModule to a hidet.nn.Module.

Parameters:
  • module (torch.nn.Module or torch.fx.GraphModule) – The torch module to convert.

  • concrete_args (Dict[str, Any] or None) – The concrete arguments to the module. If provided, will be used to make some arguments concrete during symbolic tracing.

Returns:

ret – The converted hidet module, which is a subclass of hidet.nn.Module.

Return type:

Interpreter
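
For illustration, a minimal sketch of converting a small PyTorch module; the TinyNet class and its shapes are made up for the example, and from_torch is the function documented above:

  import torch
  from hidet.graph.frontend import from_torch

  class TinyNet(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.fc = torch.nn.Linear(16, 8)

      def forward(self, x):
          return torch.relu(self.fc(x))

  # Convert the torch module; concrete_args could be passed to fix some
  # forward() arguments during symbolic tracing.
  hidet_module = from_torch(TinyNet())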

class hidet.graph.frontend.torch.DynamoConfig[source]
reset()[source]

Reset the configuration to the default values

search_space(level=2)[source]

The schedule search space for operator kernel tuning. Candidates are: 0, 1, 2

  • 0:

    Use the default schedule, without tuning.

  • 1:

    Tune the schedule in a small search space. Usually takes less than one minute to tune a kernel.

  • 2:

    Tune the schedule in a large search space. Usually achieves the best performance, but takes a longer time to tune.

Parameters:

level (int) – The search space level.
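
In practice the configuration is applied before compiling a model with the hidet backend of torch.compile. A minimal sketch, assuming the singleton config object is exposed as hidet.torch.dynamo_config (as in the hidet documentation):

  import torch
  import hidet

  # Tune kernels in the large search space (level 2).
  hidet.torch.dynamo_config.search_space(2)

  model = torch.nn.Linear(16, 8).cuda()
  model_opt = torch.compile(model, backend='hidet')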

use_tensor_core(flag=True)[source]

Whether to use tensor cores.

parallel_k(strategy='default')[source]

Parallelization on the k dimension of matrix multiplication. Candidates are: default, disabled, search

  • default:

    Default parallelization strategy. A heuristic is used to decide whether to parallelize on the k dimension and to choose the split factor.

  • disabled:

    Disable parallelization on the k dimension.

  • search:

    Search for the best parallelization strategy. Takes more time but usually achieves the best performance.

Parameters:

strategy (str) – The parallelization strategy.
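
For example, to search over split-k strategies (assuming the same hidet.torch.dynamo_config object as above):

  import hidet

  # Slower to tune, but usually yields the fastest matmul kernels.
  hidet.torch.dynamo_config.parallel_k('search')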

use_fp16(flag=True)[source]

Whether to use the float16 data type.

use_fp16_reduction(flag=True)[source]

Whether to use the float16 data type for reductions.

use_attention(flag=False)[source]

Whether to use the fused attention schedule.
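
These three options are often enabled together for mixed-precision inference. A hedged sketch, again assuming the hidet.torch.dynamo_config object:

  import hidet

  cfg = hidet.torch.dynamo_config
  cfg.use_fp16(True)            # compute in float16
  cfg.use_fp16_reduction(True)  # also accumulate reductions in float16
  cfg.use_attention(True)       # enable the fused attention schedule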

use_cuda_graph(flag=True)[source]

Whether to use CUDA graphs.

print_input_graph(flag=True)[source]

Whether to print the input graph

dump_graph_ir(output_dir)[source]

Dump the graph IR to the given output directory.

Parameters:

output_dir (str) – The output directory to dump the graph ir.
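
A sketch of enabling the graph-inspection options; the output directory name is arbitrary:

  import hidet

  cfg = hidet.torch.dynamo_config
  cfg.print_input_graph(True)      # print the torch.fx graph received by the backend
  cfg.dump_graph_ir('./graph_ir')  # write the hidet graph IR into this directory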

correctness_report(flag=True)[source]

Whether to check correctness and print an error report.

steal_weights(flag=True)[source]

Whether to clear the PyTorch weights of certain layers after converting them to Hidet tensors. This saves some GPU memory.
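
A sketch of a debugging-oriented setup that checks numerical agreement with PyTorch while releasing the converted weights (assuming hidet.torch.dynamo_config):

  import hidet

  cfg = hidet.torch.dynamo_config
  cfg.correctness_report(True)  # compare hidet outputs against PyTorch and report the error
  cfg.steal_weights(True)       # free converted PyTorch weights to save GPU memory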