Fix bicubic antialias export: use cubic_coeff_a=-0.5 instead of -0.75#2849
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Fix cubic_coeff_a value for bicubic antialias in ONNX export" to "Fix bicubic antialias export: use cubic_coeff_a=-0.5 instead of -0.75" on Mar 11, 2026.
justinchuby approved these changes on Mar 11, 2026.
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##             main    #2849   +/- ##
=======================================
  Coverage   71.86%   71.86%
=======================================
  Files         239      239
  Lines       29139    29139
  Branches     2875     2875
=======================================
  Hits        20942    20942
  Misses       7219     7219
  Partials      978      978
```

View full report in Codecov by Sentry.
When exporting `F.interpolate(mode='bicubic', antialias=True)`, the ONNX Resize node was emitted with `cubic_coeff_a=-0.75` (OpenCV-compatible), but PyTorch uses `-0.5` (Keys/PIL-compatible) for the antialias path. This caused ~32x higher numerical error vs. PyTorch when running the exported model in ONNX Runtime.

Changes

- `_aten_upsample_output_size` / `_aten_upsample_scales`: Added a `cubic_coeff_a: float = -0.75` parameter (the default preserves existing behavior for non-antialias cases) and threaded it through to `op.Resize`.
- `aten__upsample_bicubic2d_aa`: Pass `cubic_coeff_a=-0.5` to match PyTorch's runtime behavior when `antialias=True`.

Original prompt
This section details the original issue to resolve.
<issue_title>ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5)</issue_title>
<issue_description>### 🐛 Describe the bug
ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5)
Bug
When exporting `F.interpolate(mode='bicubic', antialias=True)` to ONNX via the dynamo exporter, the Resize node is written with `cubic_coeff_a=-0.75`. However, PyTorch internally uses `cubic_coeff_a=-0.5` (Keys interpolation) when `antialias=True`, as documented in the source. The exported ONNX model therefore produces different results than PyTorch when run in ONNX Runtime (or any runtime that correctly respects the `cubic_coeff_a` attribute).

The `-0.75` value was originally hardcoded in PR pytorch/pytorch#24805 for the non-antialias case and was carried forward without accounting for the antialias path. The distinction between `-0.5` (Keys, PIL-compatible) and `-0.75` (OpenCV-compatible) based on the antialias flag was introduced in the ATen kernels via pytorch/vision#3810 and pytorch#68819.

The legacy TorchScript exporter does not support `antialias=True` at all (`UnsupportedOperatorError`), so this only affects the dynamo exporter.

To reproduce
Output:
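As a standalone, library-free illustration of why the two coefficients disagree numerically, the following sketch evaluates the Keys (1981) cubic convolution kernel with both values. This is illustrative only, not the PyTorch or ONNX Runtime implementation:

```python
def cubic_kernel(x: float, a: float) -> float:
    """Keys cubic convolution kernel, parameterized by coefficient `a`.

    PyTorch's antialiased bicubic path uses a=-0.5 (Keys/PIL-compatible);
    the non-antialias path, like OpenCV, uses a=-0.75.
    """
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

# Tap weights for a sample point midway between source pixels: both
# coefficient choices sum to 1 (partition of unity), but the individual
# weights differ, so a Resize node carrying the wrong value diverges
# from PyTorch's output.
taps = (-1.5, -0.5, 0.5, 1.5)
w_keys = [cubic_kernel(d, -0.5) for d in taps]     # [-0.0625, 0.5625, 0.5625, -0.0625]
w_opencv = [cubic_kernel(d, -0.75) for d in taps]  # [-0.09375, 0.59375, 0.59375, -0.09375]
```

This makes the bug's mechanism concrete: every interpolated pixel is a weighted sum over these taps, so a systematic weight mismatch accumulates into the elevated mean error reported below.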
Patching `cubic_coeff_a` to `-0.5` reduces mean error by 32x, confirming that PyTorch uses `-0.5` at runtime but the exporter writes `-0.75`.

Expected behavior

When `antialias=True`, the ONNX Resize node should be exported with `cubic_coeff_a=-0.5` to match PyTorch's runtime behavior. When `antialias=False`, `cubic_coeff_a=-0.75` is correct.

Versions
Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 4.2.3
Libc version: glibc-2.31
Python version: 3.12.12 (main, Feb 3 2026, 22:51:04) [Clang 21.1.4 ] (64-bit runtime)
Python platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 565.57.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical...
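The expected behavior described in the issue reduces to a simple selection rule on the antialias flag. A hypothetical helper (the name and signature are illustrative, not the exporter's actual API) makes the rule explicit:

```python
def expected_cubic_coeff_a(antialias: bool) -> float:
    """Hypothetical helper, not part of the ONNX Script exporter's API.

    Maps PyTorch's antialias flag to the cubic_coeff_a value the exported
    ONNX Resize node should carry: Keys/PIL-compatible -0.5 when
    antialiasing, OpenCV-compatible -0.75 otherwise.
    """
    return -0.5 if antialias else -0.75
```

A fixed exporter should emit `expected_cubic_coeff_a(antialias)` for both the antialias and non-antialias bicubic paths, which is exactly what threading the new `cubic_coeff_a` parameter through to `op.Resize` accomplishes.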