Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
ConverterPriority,
dynamo_tensorrt_converter,
has_static_shapes_in_args,
)
Expand Down Expand Up @@ -568,12 +569,29 @@ def index_nonbool_validator(
return True


def index_has_bool_indices(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    """Return True if any index tensor of this ``aten.index.Tensor`` node is boolean.

    Used as a capability validator to split ``aten.index.Tensor`` conversion
    into two paths: boolean (mask) indexing has a data-dependent output shape
    and needs an output allocator, while integer indexing does not.

    Args:
        node: The FX node for the ``aten.index.Tensor`` call. ``node.args[1]``
            is the sequence of index tensors, where entries may be ``None``
            for unindexed dimensions.
        settings: Compilation settings; unused, present only to match the
            validator signature expected by the converter registry.

    Returns:
        True if at least one non-``None`` index node carries a fake-tensor
        ``val`` in its meta whose dtype is ``torch.bool``; False otherwise.
    """
    indices = node.args[1]
    return any(
        ind is not None
        and (val := ind.meta.get("val")) is not None
        and val.dtype == torch.bool
        for ind in indices
    )


# Integer indexing: output shape is deterministic (depends on index tensor
# shape, not values), so no output allocator is needed. This is the common
# case and is checked first via HIGH priority.
@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor,
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings),
and not index_has_bool_indices(node, settings),
priority=ConverterPriority.HIGH,
supports_dynamic_shapes=True,
requires_output_allocator=True,
requires_output_allocator=False,
)
@enforce_tensor_types(
{
Expand All @@ -597,6 +615,38 @@ def aten_ops_index(
)


# Boolean indexing: internally uses nonzero() which produces data-dependent
# output shapes, so an output allocator is required.
@dynamo_tensorrt_converter(
    torch.ops.aten.index.Tensor,
    # Only claim the node when the index dtypes are supported AND at least
    # one index tensor is boolean; the non-boolean case is handled by the
    # HIGH-priority aten_ops_index converter registered for the same op.
    # NOTE(review): this requires index_nonbool_validator(...) AND
    # index_has_bool_indices(...) to both be True. If index_nonbool_validator
    # rejects boolean indices (as its name suggests — its body is not visible
    # here), this conjunction can never hold and the converter never fires;
    # confirm the intended validator combination.
    capability_validator=lambda node, settings: index_dtype_validator(node, settings)
    and index_nonbool_validator(node, settings)
    and index_has_bool_indices(node, settings),
    supports_dynamic_shapes=True,
    # Output shape is only known at runtime (depends on how many mask
    # elements are True), hence the output allocator requirement.
    requires_output_allocator=True,
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),  # the tensor being indexed must already be a TRT tensor
    }
)
def aten_ops_index_bool(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.index.Tensor`` when boolean index tensors are present.

    Thin wrapper that delegates to ``impl.select.index``; ``args[0]`` is the
    input tensor and ``args[1]`` the sequence of (optional) index tensors.
    """
    return impl.select.index(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True)
def aten_ops_tanh(
ctx: ConversionContext,
Expand Down