Bug Description
Torch-TensorRT fails at runtime when a symbolic dimension is derived from an int64 tensor and later reused in a reshape after an aten.index.Tensor gather-style op.
A minimized repro crashes with:
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:249] Expected nbNames == 0 to be true but got false
TensorRT shape error:
IShuffleLayer ... reshaping failed ... Reshape dimension of -1 has no solution ... reshape dims{379350532 1 -1 2}
This matches the failure pattern seen in the larger ASR decoder FX graph (index -> _reshape_copy path with integer index tensors). Notably, if the graph is compiled with the use_python_runtime flag, the test case runs fine.
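For completeness, the workaround mentioned above only requires one extra entry in the compile options. A sketch of the options dict from the repro with that change (a workaround, not a fix):

```python
import torch

# Same options dict as in the repro below, plus use_python_runtime.
# With the Python runtime selected, the repro runs without the shape error.
options = {
    "truncate_double": True,
    "enabled_precisions": {torch.float, torch.half},
    "min_block_size": 1,
    "optimization_level": 1,
    "debug": True,
    "use_python_runtime": True,  # workaround: avoids the C++ runtime failure
}
```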
To Reproduce
Steps to reproduce the behavior:
- Save this repro as repro_issue.py:
```python
import torch
import torch_tensorrt
import traceback

class TestModule(torch.nn.Module):
    def forward(self, x, targets, cache_length):
        # symbolic size from int64 tensor
        B = targets.size(0)
        # int64 tensor used as index
        idx = cache_length + torch.arange(1, device=x.device)
        y = x[:, idx, :]
        # reshape uses symbolic dim from int64-derived size
        z = y.reshape(B, 1, -1, 2)
        return z

B, S, D = 16, 128, 1024
x = torch.randn(B, S, D).cuda()
targets = torch.randint(0, 10, (B, 1), dtype=torch.int64).cuda()
cache_length = torch.tensor(0, dtype=torch.int64).cuda()
torch._dynamo.mark_dynamic(targets, 0)
torch._dynamo.mark_dynamic(x, 0)

model = TestModule().eval().cuda()
try:
    compiled_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "truncate_double": True,
            "enabled_precisions": {torch.float, torch.half},
            "min_block_size": 1,
            "optimization_level": 1,
            "debug": True,
        },
    )
    compiled_model(x, targets, cache_length)
except Exception:
    traceback.print_exc()
```
- Run: python repro_issue.py
- Observe the runtime failure in TensorRT execution with an inferShapes/reshape error and Expected nbNames == 0 ... got false.
Error excerpt:
```
ERROR: [Torch-TensorRT] - IExecutionContext::inferShapes: Error Code 7: Internal Error
... [SHUFFLE]-[aten_ops._reshape_copy.default]-[_reshape_copy]
... [GATHER]-[aten_ops.index.Tensor]-[index_index_gather]_output
Reshape dimension of -1 has no solution
Instruction: RESHAPE input dims{16 1 1024} reshape dims{379350532 1 -1 2}
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:249]
Expected nbNames == 0 to be true but got false
```
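The garbage leading dimension explains why the -1 has no solution: TensorRT can only infer the wildcard when the input volume is divisible by the product of the known reshape dims, and with 379350532 leaking in where the symbolic B = 16 should be, it is not. A quick sanity check of that arithmetic:

```python
# Input to the shuffle layer has dims {16 1 1024}; the reshape asks for
# {379350532 1 -1 2}. The -1 is solvable only if the input volume is
# divisible by the product of the known dims.
input_volume = 16 * 1 * 1024                  # 16384
bad_known = 379350532 * 1 * 2                 # garbage value leaked into the shape
good_known = 16 * 1 * 2                       # intended B = targets.size(0) = 16

print(input_volume % bad_known == 0)          # False -> "no solution"
print(input_volume // good_known)             # 512 -> the expected third dim
```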
Expected behavior
The compiled Torch-TensorRT engine should execute successfully and return a valid tensor with shape [B, 1, 512, 2] (for B=16) without shape-inference/internal runtime errors.
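As a cross-check, the same forward logic in eager mode produces that shape; a minimal CPU-only sketch (the .cuda() calls from the repro are dropped, since they are not needed to verify the shape math):

```python
import torch

class TestModule(torch.nn.Module):
    def forward(self, x, targets, cache_length):
        B = targets.size(0)                                  # symbolic size from int64 tensor
        idx = cache_length + torch.arange(1, device=x.device)  # int64 index tensor
        y = x[:, idx, :]                                     # gather -> shape (B, 1, D)
        return y.reshape(B, 1, -1, 2)                        # -1 resolves to D // 2

B, S, D = 16, 128, 1024
x = torch.randn(B, S, D)
targets = torch.randint(0, 10, (B, 1), dtype=torch.int64)
cache_length = torch.tensor(0, dtype=torch.int64)

out = TestModule()(x, targets, cache_length)
print(tuple(out.shape))  # (16, 1, 512, 2)
```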
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.11.0
- PyTorch Version (e.g. 1.0): 2.11.0.dev20260121+cu130
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux 6.14.0-29-generic (glibc 2.39)
- How you installed PyTorch (conda, pip, libtorch, source): pip (in virtualenv)
- Build command you used (if compiling from source): N/A (not built in this repro environment)
- Are you using local sources or building from archives: Torch-TensorRT is from local editable source (Editable project location: /home/wenbingl/scratch/g/t-trt)
- Python version: 3.12.3
- CUDA version: 13.0 (from torch.version.cuda)
- GPU models and configuration: 1x NVIDIA RTX PRO 6000 Blackwell Server Edition
- Any other relevant information:
  - Backend: torch.compile(..., backend="tensorrt")
  - Debug logs enabled (options["debug"] = True)
  - Repro requires dynamic shape marking (torch._dynamo.mark_dynamic)
Additional context
- The minimized repro is derived from a real decoder graph where int64 tensors are used both for indexing and for symbolic shape propagation.
- In the debug output, the _reshape_copy shape operand includes a symbolic value sourced from targets.size(0), and the runtime appears to treat it incorrectly (as a garbage large integer), causing the shuffle/reshape failure.
- This looks like a shape-tensor handling issue across aten.index.Tensor + _reshape_copy when integer-derived symbolic dims are involved.