From c90a466809db39d38c82858a8cf95e66b767363b Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 9 Feb 2026 22:44:58 +0000
Subject: [PATCH] docs: update torch.compile documentation

---
 docsrc/dynamo/torch_compile.rst         | 59 +++++++++++++++++++++++--
 docsrc/getting_started/installation.rst |  7 +--
 docsrc/ts/ptq.rst                       |  3 ++
 3 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/docsrc/dynamo/torch_compile.rst b/docsrc/dynamo/torch_compile.rst
index 14498fc0c3..0724f1078d 100644
--- a/docsrc/dynamo/torch_compile.rst
+++ b/docsrc/dynamo/torch_compile.rst
@@ -46,7 +46,20 @@ Custom Setting Usage
         "optimization_level": 4,
         "use_python_runtime": False,})

-.. note:: Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+.. note:: Torch-TensorRT supports FP32, FP16, and INT8 precision layers. For INT8 quantization, use the TensorRT Model Optimizer (modelopt) for post-training quantization (PTQ). See :ref:`vgg16_ptq` for an example.
+
+Advanced Precision Control
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For fine-grained control over mixed precision execution, TensorRT 10.12+ provides additional settings:
+
+* ``use_explicit_typing``: Enable explicit type specification (required for TensorRT 10.12+)
+* ``enable_autocast``: Enable rule-based autocast for automatic precision selection
+* ``autocast_low_precision_type``: Target precision for autocast (e.g., ``torch.float16``)
+* ``autocast_excluded_nodes``: Specific nodes to exclude from autocast
+* ``autocast_excluded_ops``: Operation types to exclude from autocast
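+
+For example, these settings can be passed through the backend ``options`` alongside the other compilation settings (the values shown here are illustrative):
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    ...
+    # Run autocast-selected layers in FP16 while keeping explicit typing enabled
+    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
+                                    options={"use_explicit_typing": True,
+                                             "enable_autocast": True,
+                                             "autocast_low_precision_type": torch.float16})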
+
+For detailed information and examples, see :ref:`mixed_precision`.

 Compilation
 -----------------
@@ -98,14 +111,54 @@ Compilation can also be helpful in demonstrating graph breaks and the feasibilit
     print(f"Graph breaks: {explanation.graph_break_count}")

     optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, options={"truncate_long_and_double": True})

+Engine Caching
+^^^^^^^^^^^^^^^^^
+Engine caching can significantly reduce recompilation times by saving built TensorRT engines to disk and reusing them when possible. This is particularly useful for JIT workflows where graphs may be invalidated and recompiled. When enabled, engines are saved with a hash of their corresponding PyTorch subgraph and can be reloaded in subsequent compilations—even across different Python sessions.
+
+To enable engine caching, use the ``cache_built_engines`` and ``reuse_cached_engines`` options:
+
+.. code-block:: python
+
+    import torch_tensorrt
+    ...
+    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
+                                    options={"cache_built_engines": True,
+                                             "reuse_cached_engines": True,
+                                             "immutable_weights": False,
+                                             "engine_cache_dir": "/tmp/torch_trt_cache",
+                                             "engine_cache_size": 1 << 30})  # 1GB
+
+.. note:: To use engine caching, ``immutable_weights`` must be set to ``False`` to allow engine refitting. When a cached engine is loaded, the weights are refitted rather than the entire engine being rebuilt, which can reduce compilation times by orders of magnitude.
+
+For more details and examples, see :ref:`engine_caching_example`.
+
 Dynamic Shape Support
 -----------------
-The Torch-TensorRT `torch.compile` backend will currently require recompilation for each new batch size encountered, and it is preferred to use the `dynamic=False` argument when compiling with this backend. Full dynamic shape support is planned for a future release.
+The Torch-TensorRT `torch.compile` backend now supports dynamic shapes, allowing models to handle varying input dimensions without recompilation. You can specify dynamic dimensions using the ``torch._dynamo.mark_dynamic`` API:
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    ...
+    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
+    # Mark dimension 0 (batch) as dynamic with range [1, 8]
+    torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
+    optimized_model = torch.compile(model, backend="tensorrt")
+    optimized_model(inputs)  # First compilation
+
+    # No recompilation with different batch size in the dynamic range
+    inputs_bs4 = torch.randn((4, 3, 224, 224), dtype=torch.float32).cuda()
+    optimized_model(inputs_bs4)
+
+Without dynamic shapes, the model will recompile for each new input shape encountered. For more control over dynamic shapes, consider using the AOT compilation path with ``torch_tensorrt.compile`` as described in :ref:`dynamic_shapes`. For a complete tutorial on dynamic shape compilation, see :ref:`compile_with_dynamic_inputs`.

 Recompilation Conditions
 -----------------
-Once the model has been compiled, subsequent inference inputs with the same shape and data type, which traverse the graph in the same way will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session. Support for engine cache serialization is planned for a future release.
+Once the model has been compiled, subsequent inference inputs with the same shape and data type, which traverse the graph in the same way, will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session.
+
+To persist engine caches across Python sessions, use the ``cache_built_engines`` and ``reuse_cached_engines`` options as described in the Engine Caching section above.

 Recompilation is generally triggered by one of two events: encountering inputs of different sizes or inputs which traverse the model code differently. The latter scenario can occur when the model code includes conditional logic, complex loops, or data-dependent-shapes. `torch.compile` handles guarding in both of these scenario and determines when recompilation is necessary.

diff --git a/docsrc/getting_started/installation.rst b/docsrc/getting_started/installation.rst
index ceba409fec..f23c5b7362 100644
--- a/docsrc/getting_started/installation.rst
+++ b/docsrc/getting_started/installation.rst
@@ -15,6 +15,7 @@ You need to have CUDA, PyTorch, and TensorRT (python package is sufficient) inst

  * https://developer.nvidia.com/cuda
  * https://pytorch.org
+ * TensorRT 10.0 or later (TensorRT 10.12+ recommended for latest features like explicit typing)


 Installing Torch-TensorRT
@@ -208,9 +209,9 @@ A tarball with the include files and library can then be found in ``bazel-bin``
 Choosing the Right ABI
 ^^^^^^^^^^^^^^^^^^^^^^^^

-For the old versions, there were two ABI options to compile Torch-TensorRT which were incompatible with each other,
-pre-cxx11-abi and cxx11-abi. The complexity came from the different distributions of PyTorch. Fortunately, PyTorch
-has switched to cxx11-abi for all distributions. Below is a table with general pairings of PyTorch distribution
+In older versions, there were two ABI options for compiling Torch-TensorRT, which were incompatible with each other:
+pre-cxx11-abi and cxx11-abi. The complexity came from the different distributions of PyTorch. Fortunately, PyTorch
+has switched to cxx11-abi for all distributions. Below is a table with general pairings of PyTorch distribution
 sources and the recommended commands:

 +-------------------------------------------------------------+----------------------------------------------------------+--------------------------------------------------------------------+
diff --git a/docsrc/ts/ptq.rst b/docsrc/ts/ptq.rst
index f855c6679c..ada0ea6e35 100644
--- a/docsrc/ts/ptq.rst
+++ b/docsrc/ts/ptq.rst
@@ -3,6 +3,9 @@
 Post Training Quantization (PTQ)
 =================================

+.. warning::
+    This guide describes the legacy PTQ workflow for the TorchScript frontend. **For new projects, use the TensorRT Model Optimizer (modelopt) with the Dynamo frontend instead.** See :ref:`vgg16_ptq` for the recommended approach.
+
 Post Training Quantization (PTQ) is a technique to reduce the required computational resources for inference while still preserving the accuracy of your model by mapping the traditional FP32 activation space to a reduced INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain