From 87c64d3b9145f176be26dbe8c133dc2972caa576 Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 20 Jan 2026 23:18:28 -0800
Subject: [PATCH] remove downcast of int64 to int32

---
 .../dynamo/conversion/impl/embedding.py       |   6 +-
 .../dynamo/conversion/impl/select.py          |   9 +-
 .../dynamo/conversion/test_embedding_aten.py  |  21 ++++
 .../py/dynamo/conversion/test_gather_aten.py  | 109 ++++++++++++++++++
 .../conversion/test_index_select_aten.py      | 109 ++++++++++++++++++
 .../py/dynamo/conversion/test_scatter_aten.py |  96 +++++++++++++++
 6 files changed, 342 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
index 0a723618eb..1734c1502b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -30,10 +30,8 @@ def embedding(
 ) -> TRTTensor:
     indices_tensor = input
     embedding_tensor = weight
-    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
-        raise RuntimeError(
-            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
-        )
+    # Note: TensorRT's Gather layer supports both int32 and int64 indices
+    # https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
     indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
     embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
     # unsupported parameters
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 3c714a8fb5..7d89545239 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -494,8 +494,7 @@ def scatter(
     input_shape = input.shape
     index_shape = index.shape
     index_shape_list = list(index_shape)
-    if index.dtype == trt.int64:
-        index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
+    # Note: TensorRT's Scatter layer supports both int32 and int64 indices
     dim = get_positive_dim(dim, len(input_shape))
     src_tensor = src
     # scatter.value
@@ -530,7 +529,9 @@ def gather(
 ) -> TRTTensor:
     input_shape = input.shape
     dim = get_positive_dim(dim, len(input_shape))
-    index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
+    # Note: TensorRT's Gather layer supports both int32 and int64 indices
+    # https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
+    index = get_trt_tensor(ctx, index, name + "_index_tensor")
     gather_layer = ctx.net.add_gather(input, index, axis=dim)
     gather_layer.mode = trt.GatherMode.ELEMENT
     set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
@@ -857,7 +858,7 @@ def index_put_converter(
         values_expanded,
         (-1,),
     )
-    indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
+    # Note: TensorRT's Scatter layer supports both int32 and int64 indices
     if accumulate:
         zero_tensor = impl.full.full(
             ctx,
diff --git a/tests/py/dynamo/conversion/test_embedding_aten.py b/tests/py/dynamo/conversion/test_embedding_aten.py
index c04d89ff9e..d00ad05c3b 100644
--- a/tests/py/dynamo/conversion/test_embedding_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_aten.py
@@ -29,6 +29,27 @@ class TestEmbeddingConverter(DispatchTestCase):
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
                 sparse=True,
             ),
+            # int64 indices - TensorRT now supports int64 indices for gather operations
+            param(
+                test_name="1d_indices_int64",
+                indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int64),
+                weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=False,
+            ),
+            param(
+                test_name="2d_indices_int64",
+                indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int64),
+                weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
+            ),
+            param(
+                test_name="3d_indices_int64",
+                indices_tensor=torch.tensor(
+                    [[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int64
+                ),
+                weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
+            ),
         ]
     )
     def test_embedding(
diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py
index 7e96027ef5..5900ece8ee 100644
--- a/tests/py/dynamo/conversion/test_gather_aten.py
+++ b/tests/py/dynamo/conversion/test_gather_aten.py
@@ -70,3 +70,112 @@ def forward(self, input, index):
         input = torch.zeros(3, 5, dtype=torch.int32)
         inputs = [input, index]
         self.run_test(TestModule(), inputs)
+
+
+class TestGatherInt64IndexConverter(DispatchTestCase):
+    """Test cases for gather with int64 indices.
+    TensorRT now supports int64 indices for gather operations.
+    https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
+    """
+
+    @parameterized.expand(
+        [
+            (
+                "gather_zero_dim_indexOne_int64",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+            ),
+            (
+                "gather_zero_dim_indexTwo_int64",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+            ),
+            (
+                "gather_one_dim_indexOne_int64",
+                1,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+            ),
+            (
+                "gather_one_dim_indexTwo_int64",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+            ),
+        ]
+    )
+    def test_gather_index_int64_constant(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.gather.default(input, dim, index)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        self.run_test(TestModule(), inputs)
+
+    @parameterized.expand(
+        [
+            (
+                "gather_zero_dim_indexOne_int64_input",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+            ),
+            (
+                "gather_zero_dim_indexTwo_int64_input",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+            ),
+            (
+                "gather_one_dim_indexOne_int64_input",
+                1,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+            ),
+            (
+                "gather_one_dim_indexTwo_int64_input",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+            ),
+        ]
+    )
+    def test_gather_index_int64_input(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.gather.default(input, dim, index)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs)
+
+    @parameterized.expand(
+        [
+            (
+                "gather_float_input_int64_index",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+            ),
+            (
+                "gather_float_input_int64_index_dim1",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+            ),
+        ]
+    )
+    def test_gather_float_input_int64_index(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.gather.default(input, dim, index)
+
+        input = torch.randn(3, 5, dtype=torch.float32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py
index b1339efdcf..fe9e0dc2c1 100644
--- a/tests/py/dynamo/conversion/test_index_select_aten.py
+++ b/tests/py/dynamo/conversion/test_index_select_aten.py
@@ -135,5 +135,114 @@ def forward(self, source_tensor, indice_tensor):
         )
 
+
+class TestIndexSelectInt64Converter(DispatchTestCase):
+    """Test cases for index_select with int64 indices.
+    TensorRT now supports int64 indices for gather operations.
+    https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_i_gather_layer.html
+    """
+
+    @parameterized.expand(
+        [
+            ("1d_input_int64", (10,), 0, (1,)),
+            ("2d_input_dim_0_int64", (10, 3), 0, (0, 2)),
+            ("2d_input_dim_1_int64", (5, 10), 1, (1, 2, 3)),
+            ("3d_input_dim_0_int64", (10, 5, 10), 0, (0, 5)),
+            ("3d_input_dim_2_int64", (10, 5, 10), 2, (3, 3, 4)),
+            ("3d_input_dim_-1_int64", (10, 5, 10), -1, (3, 3, 4)),
+        ]
+    )
+    def test_index_select_int64(self, _, source_shape, dim, indices_val):
+        class TestIndexSelect(torch.nn.Module):
+            def forward(self, source_tensor, indices_tensor):
+                return torch.ops.aten.index_select.default(
+                    source_tensor, dim, indices_tensor
+                )
+
+        input = [
+            torch.randn(*source_shape, dtype=torch.float32),
+            torch.tensor([*indices_val], dtype=torch.int64),
+        ]
+
+        self.run_test(
+            TestIndexSelect(),
+            input,
+        )
+
+    @parameterized.expand(
+        [
+            param(
+                # 1d_source_tensor_int64_index
+                source_tensor=torch.randn((3,), dtype=torch.float32),
+                source_tensor_1=torch.randn((5,), dtype=torch.float32),
+                dynamic_shapes={
+                    "source_tensor": {0: torch.export.Dim("dyn_dim", min=3, max=6)},
+                    "indice_tensor": {},
+                },
+                dim=0,
+                indice_tensor=torch.tensor(
+                    [
+                        1,
+                    ],
+                    dtype=torch.int64,
+                ),
+            ),
+            param(
+                # 2d_source_tensor_int64_index
+                source_tensor=torch.randn((3, 3), dtype=torch.float32),
+                source_tensor_1=torch.randn((4, 6), dtype=torch.float32),
+                dynamic_shapes={
+                    "source_tensor": {
+                        0: torch.export.Dim("dyn_dim1", min=3, max=6),
+                        1: torch.export.Dim("dyn_dim2", min=2, max=7),
+                    },
+                    "indice_tensor": {},
+                },
+                dim=-1,
+                indice_tensor=torch.tensor([0, 2], dtype=torch.int64),
+            ),
+        ]
+    )
+    def test_index_select_int64_dynamic_shape(
+        self, source_tensor, source_tensor_1, dynamic_shapes, dim, indice_tensor
+    ):
+        class IndexSelect(torch.nn.Module):
+            def forward(self, source_tensor, indice_tensor):
+                return torch.ops.aten.index_select.default(
+                    source_tensor,
+                    dim,
+                    indice_tensor,
+                )
+
+        inputs = (source_tensor, indice_tensor)
+        mod = IndexSelect()
+
+        fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes)
+        trt_mod = torch_tensorrt.dynamo.compile(
+            fx_mod,
+            inputs=inputs,
+            enabled_precisions={torch.float32},
+            min_block_size=1,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+        )
+        # use a different shape of inputs for inference
+        inputs = (source_tensor_1, indice_tensor)
+        with torch.no_grad():
+            cuda_inputs = []
+            for i in inputs:
+                cuda_inputs.append(i.cuda())
+            ref_outputs = mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_inputs)
+            for out, ref in zip(outputs, ref_outputs):
+                torch.testing.assert_close(
+                    out,
+                    ref,
+                    rtol=RTOL,
+                    atol=ATOL,
+                    equal_nan=True,
+                    check_dtype=True,
+                )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_scatter_aten.py b/tests/py/dynamo/conversion/test_scatter_aten.py
index 85b7094a5c..b408ef917b 100644
--- a/tests/py/dynamo/conversion/test_scatter_aten.py
+++ b/tests/py/dynamo/conversion/test_scatter_aten.py
@@ -173,5 +173,101 @@ def forward(self, input, index):
         self.run_test(TestModule(), inputs, int32_reqd=True)
 
+
+class TestScatterInt64IndexConverter(DispatchTestCase):
+    """Test cases for scatter with int64 indices.
+    TensorRT now supports int64 indices for scatter operations.
+    """
+
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_int64_index_value",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                1,
+            ),
+            (
+                "scatter_one_dim_int64_index_value",
+                1,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                1,
+            ),
+            (
+                "scatter_zero_dim_int64_indexTwo_value",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]], dtype=torch.int64),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_int64_index_constant(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_int64_index_input",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                1,
+            ),
+            (
+                "scatter_one_dim_int64_index_input",
+                1,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_int64_index_input(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            (
+                "scatter_src_zero_dim_int64_index",
+                0,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_src_one_dim_int64_index",
+                1,
+                torch.tensor([[0, 1, 2, 0]], dtype=torch.int64),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_src_int64_index(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
 
 if __name__ == "__main__":
     run_tests()
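
Note: as a quick sanity check of the behavior this patch enables, below is a minimal sketch (not part of the patch) that compiles an aten.gather with int64 indices end-to-end, with no manual downcast to int32. It assumes a CUDA device and a TensorRT version whose IGatherLayer accepts int64 indices; the module name GatherInt64 and the sample tensors are illustrative only.

import torch
import torch_tensorrt


class GatherInt64(torch.nn.Module):
    def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # idx stays int64 end-to-end; the converter no longer casts it to int32.
        return torch.ops.aten.gather.default(x, 1, idx)


x = torch.randn(3, 5, device="cuda")
idx = torch.tensor(
    [[0, 1, 2, 0], [1, 2, 1, 1], [4, 0, 3, 2]], dtype=torch.int64, device="cuda"
)

ep = torch.export.export(GatherInt64(), (x, idx))
trt_mod = torch_tensorrt.dynamo.compile(
    ep,
    inputs=(x, idx),
    enabled_precisions={torch.float32},
    min_block_size=1,
)
# A pure gather should match eager PyTorch exactly.
torch.testing.assert_close(trt_mod(x, idx), GatherInt64()(x, idx))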