From 3646461761895708e8b18780a1a2705bcb526625 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Jun 2025 15:38:56 +0000 Subject: [PATCH 1/5] Initial plan for issue From 8cc5271cfe7d91c7bc374f5cd69c9816f1faa71f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Jun 2025 15:49:09 +0000 Subject: [PATCH 2/5] Implement quantize_per_channel and dequantize_per_channel for torchlib Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- .../torch_lib/ops/quantized_decomposed.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index 92962a9ea6..9dfd8a4da1 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -14,7 +14,9 @@ from onnxscript.function_libs.torch_lib.ops import common from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_opset import opset23 as op23 from onnxscript.onnx_types import TensorType +from typing import Optional @torch_op( @@ -61,3 +63,61 @@ def quantized_decomposed_dequantize_per_tensor( return dequantized assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" return op.Cast(dequantized, to=out_dtype) + + +@torch_op( + ( + "quantized_decomposed::quantize_per_channel", + "quantized_decomposed::quantize_per_channel.tensor", + "quantized_decomposed::quantize_per_channel.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_quantize_per_channel( + input: TensorType, + scales: TensorType, + zero_points: TensorType, + axis: int, + quant_min: int, + quant_max: int, + dtype: int, +) -> TensorType: + """Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values. + + Uses ONNX QuantizeLinear with per-axis quantization support. + """ + # Use opset23 for per-axis quantization support + return op23.QuantizeLinear(input, scales, zero_points, axis=axis, output_dtype=dtype) + + +@torch_op( + ( + "quantized_decomposed::dequantize_per_channel", + "quantized_decomposed::dequantize_per_channel.tensor", + "quantized_decomposed::dequantize_per_channel.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_dequantize_per_channel( + input: TensorType, + scales: TensorType, + zero_points: Optional[TensorType], + axis: int, + quant_min: int, + quant_max: int, + dtype: int, + out_dtype: int = -1, +) -> TensorType: + """Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values. + + Uses ONNX DequantizeLinear with per-axis quantization support. + """ + # Use opset23 for per-axis quantization support with optional output_dtype + if out_dtype in (-1, None): + # Use default output type (same as scales type) + return op23.DequantizeLinear(input, scales, zero_points, axis=axis) + else: + assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" + return op23.DequantizeLinear(input, scales, zero_points, axis=axis, output_dtype=out_dtype) From 82c8f9edfeaf57e9908cd0f98c2acac62e5aa050 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Jun 2025 00:28:13 +0000 Subject: [PATCH 3/5] Format code using lintrunner Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- .../torch_lib/ops/quantized_decomposed.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index 9dfd8a4da1..ae82f1572b 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -11,12 +11,13 @@ from __future__ import annotations +from typing import Optional + from onnxscript.function_libs.torch_lib.ops import common from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_opset import opset23 as op23 from onnxscript.onnx_types import TensorType -from typing import Optional @torch_op( @@ -84,7 +85,7 @@ def quantized_decomposed_quantize_per_channel( ) -> TensorType: """Affine per channel quantization for the Tensor using the same quantization parameters for each channel/axis to map from floating point to quantized values. - + Uses ONNX QuantizeLinear with per-axis quantization support. """ # Use opset23 for per-axis quantization support @@ -111,7 +112,7 @@ def quantized_decomposed_dequantize_per_channel( ) -> TensorType: """Affine per channel dequantization for the Tensor using the same quantization parameters for each channel/axis to map from quantized values to floating point values. - + Uses ONNX DequantizeLinear with per-axis quantization support. """ # Use opset23 for per-axis quantization support with optional output_dtype @@ -120,4 +121,6 @@ def quantized_decomposed_dequantize_per_channel( return op23.DequantizeLinear(input, scales, zero_points, axis=axis) else: assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" - return op23.DequantizeLinear(input, scales, zero_points, axis=axis, output_dtype=out_dtype) + return op23.DequantizeLinear( + input, scales, zero_points, axis=axis, output_dtype=out_dtype + ) From 7124f096f0dfb65381b9ac262aaa701d5ad053aa Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Fri, 19 Jun 2026 18:19:58 +0000 Subject: [PATCH 4/5] Add numerical parity tests for quantize/dequantize_per_channel Adds e2e numerical-parity tests (eager torch vs ORT) for quantize_per_channel and dequantize_per_channel covering int8 and uint8, axis 0 and non-zero axes, distinct per-channel scales/zero_points, narrow quant_min/quant_max clamping, and a quantize->dequantize round trip. Fixes the lowering to make parity hold: - Per-axis QuantizeLinear/DequantizeLinear (opset 13+ axis attribute) is used instead of opset23-only attributes (output_dtype/precision/ block_size), which were emitted into an opset-18 graph and rejected by ONNX Runtime. - zero_points (int64 from torch) is now cast to the quantized dtype so it matches the int8/uint8 tensor, and scales (float64 from torch) is cast to the input/float32 type as required by Q/DQ. - quantize_per_channel now clamps to quant_min/quant_max (via Clip) to match torch's reference semantics; dequantize defaults to float32. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../torch_lib/ops/quantized_decomposed.py | 46 +++++-- .../function_libs/torch_lib/e2e_ops_tests.py | 112 ++++++++++++++++++ 2 files changed, 145 insertions(+), 13 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index ae82f1572b..7d77208f84 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -13,10 +13,10 @@ from typing import Optional +from onnxscript import ir from onnxscript.function_libs.torch_lib.ops import common from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.onnx_opset import opset18 as op -from onnxscript.onnx_opset import opset23 as op23 from onnxscript.onnx_types import TensorType @@ -86,10 +86,23 @@ def quantized_decomposed_quantize_per_channel( """Affine per channel quantization for the Tensor using the same quantization parameters for each channel/axis to map from floating point to quantized values. - Uses ONNX QuantizeLinear with per-axis quantization support. + Reference: + https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py + ``res = clamp(round(input / scales) + zero_points, quant_min, quant_max)`` """ - # Use opset23 for per-axis quantization support - return op23.QuantizeLinear(input, scales, zero_points, axis=axis, output_dtype=dtype) + # ONNX QuantizeLinear requires the scale to share the input element type and the + # zero_point to share the (quantized) output element type. PyTorch passes scales as + # float64 and zero_points as int64, so cast both to the expected ONNX types. + scales = op.CastLike(scales, input) + zero_points = op.Cast(zero_points, to=dtype) + quantized = op.QuantizeLinear(input, scales, zero_points, axis=axis) + # QuantizeLinear saturates to the full range of ``dtype``. PyTorch clamps to the + # explicit ``quant_min``/``quant_max`` instead, so clamp to match its semantics. + return op.Clip( + quantized, + common.constant(quant_min, dtype=dtype), + common.constant(quant_max, dtype=dtype), + ) @torch_op( @@ -113,14 +126,21 @@ def quantized_decomposed_dequantize_per_channel( """Affine per channel dequantization for the Tensor using the same quantization parameters for each channel/axis to map from quantized values to floating point values. - Uses ONNX DequantizeLinear with per-axis quantization support. + Reference: + https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py + ``res = (input - zero_points) * scales`` cast to ``out_dtype`` (float32 by default). + ``quant_min``/``quant_max``/``dtype`` are metadata only and unused in the computation. """ - # Use opset23 for per-axis quantization support with optional output_dtype + # ONNX DequantizeLinear requires a floating point scale and a zero_point that shares + # the (quantized) input element type. PyTorch passes scales as float64 and zero_points + # as int64, so cast scales to float32 (the PyTorch default output type) and zero_points + # to the quantized input type. + scales = op.Cast(scales, to=ir.DataType.FLOAT) + if zero_points is not None: + zero_points = op.Cast(zero_points, to=dtype) + dequantized = op.DequantizeLinear(input, scales, zero_points, axis=axis) if out_dtype in (-1, None): - # Use default output type (same as scales type) - return op23.DequantizeLinear(input, scales, zero_points, axis=axis) - else: - assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" - return op23.DequantizeLinear( - input, scales, zero_points, axis=axis, output_dtype=out_dtype - ) + # PyTorch defaults to float32, which DequantizeLinear already produces. + return dequantized + assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" + return op.Cast(dequantized, to=out_dtype) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index d344723408..b6f0f4fca4 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -9,6 +9,9 @@ # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch + +# Importing this module registers the quantized_decomposed::* operators used below. +import torch.ao.quantization.fx._decomposed # noqa: F401 from torch.onnx._internal.exporter import _testing @@ -629,6 +632,115 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_quantize_per_channel_int8(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -128, 127, torch.int8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_uint8(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([10, 128, 250], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, 0, 255, torch.uint8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_non_zero_axis(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05, 0.3], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3, 2], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 1, -128, 127, torch.int8 + ) + + x = torch.randn(2, 4, 3) * 4 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_clamps_to_quant_min_max(self): + # quant_min/quant_max are narrower than the int8 range, so values must be + # clamped to [-20, 20] to match the PyTorch reference semantics. + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2], dtype=torch.float64) + zero_points = torch.tensor([0, 0], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -20, 20, torch.int8 + ) + + x = torch.randn(2, 4) * 10 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_int8(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, -128, 127, torch.int8 + ) + + q = torch.randint(-128, 128, (3, 4), dtype=torch.int8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_uint8(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([10, 128, 250], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, 0, 255, torch.uint8 + ) + + q = torch.randint(0, 256, (3, 4), dtype=torch.uint8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_non_zero_axis(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([1, 2, 3], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 2, -128, 127, torch.int8 + ) + + q = torch.randint(-128, 128, (2, 4, 3), dtype=torch.int8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_dequantize_per_channel_roundtrip(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + q = torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -128, 127, torch.int8 + ) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, -128, 127, torch.int8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + @parameterized.parameterized.expand( [ # Multiple advanced indices, all 1D tensors. From ffc4f68f250c012ae99fc07bfc6e2f9134f87a79 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Fri, 19 Jun 2026 21:59:26 +0000 Subject: [PATCH 5/5] Sync main and verify quantize/dequantize_per_channel Resolve merge conflict in e2e_ops_tests.py by keeping both main's torchvision import (for deform_conv2d tests) and the quantized_decomposed registration import for the new per_channel parity tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/function_libs/torch_lib/e2e_ops_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 02828494be..e15154973c 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -9,10 +9,10 @@ # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch -import torchvision # Importing this module registers the quantized_decomposed::* operators used below. import torch.ao.quantization.fx._decomposed # noqa: F401 +import torchvision from torch.onnx._internal.exporter import _testing