diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index 92962a9ea6..7d77208f84 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -11,6 +11,9 @@ from __future__ import annotations +from typing import Optional + +from onnxscript import ir from onnxscript.function_libs.torch_lib.ops import common from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.onnx_opset import opset18 as op @@ -61,3 +64,83 @@ def quantized_decomposed_dequantize_per_tensor( return dequantized assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" return op.Cast(dequantized, to=out_dtype) + + +@torch_op( + ( + "quantized_decomposed::quantize_per_channel", + "quantized_decomposed::quantize_per_channel.tensor", + "quantized_decomposed::quantize_per_channel.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_quantize_per_channel( + input: TensorType, + scales: TensorType, + zero_points: TensorType, + axis: int, + quant_min: int, + quant_max: int, + dtype: int, +) -> TensorType: + """Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values. + + Reference: + https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py + ``res = clamp(round(input / scales) + zero_points, quant_min, quant_max)`` + """ + # ONNX QuantizeLinear requires the scale to share the input element type and the + # zero_point to share the (quantized) output element type. PyTorch passes scales as + # float64 and zero_points as int64, so cast both to the expected ONNX types. + scales = op.CastLike(scales, input) + zero_points = op.Cast(zero_points, to=dtype) + quantized = op.QuantizeLinear(input, scales, zero_points, axis=axis) + # QuantizeLinear saturates to the full range of ``dtype``. PyTorch clamps to the + # explicit ``quant_min``/``quant_max`` instead, so clamp to match its semantics. + return op.Clip( + quantized, + common.constant(quant_min, dtype=dtype), + common.constant(quant_max, dtype=dtype), + ) + + +@torch_op( + ( + "quantized_decomposed::dequantize_per_channel", + "quantized_decomposed::dequantize_per_channel.tensor", + "quantized_decomposed::dequantize_per_channel.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_dequantize_per_channel( + input: TensorType, + scales: TensorType, + zero_points: Optional[TensorType], + axis: int, + quant_min: int, + quant_max: int, + dtype: int, + out_dtype: int = -1, +) -> TensorType: + """Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values. + + Reference: + https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py + ``res = (input - zero_points) * scales`` cast to ``out_dtype`` (float32 by default). + ``quant_min``/``quant_max``/``dtype`` are metadata only and unused in the computation. + """ + # ONNX DequantizeLinear requires a floating point scale and a zero_point that shares + # the (quantized) input element type. PyTorch passes scales as float64 and zero_points + # as int64, so cast scales to float32 (the PyTorch default output type) and zero_points + # to the quantized input type. + scales = op.Cast(scales, to=ir.DataType.FLOAT) + if zero_points is not None: + zero_points = op.Cast(zero_points, to=dtype) + dequantized = op.DequantizeLinear(input, scales, zero_points, axis=axis) + if out_dtype in (-1, None): + # PyTorch defaults to float32, which DequantizeLinear already produces. + return dequantized + assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" + return op.Cast(dequantized, to=out_dtype) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 440e2316c1..e15154973c 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -9,6 +9,9 @@ # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch + +# Importing this module registers the quantized_decomposed::* operators used below. +import torch.ao.quantization.fx._decomposed # noqa: F401 import torchvision from torch.onnx._internal.exporter import _testing @@ -705,6 +708,115 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_quantize_per_channel_int8(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -128, 127, torch.int8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_uint8(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([10, 128, 250], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, 0, 255, torch.uint8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_non_zero_axis(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05, 0.3], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3, 2], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 1, -128, 127, torch.int8 + ) + + x = torch.randn(2, 4, 3) * 4 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_per_channel_clamps_to_quant_min_max(self): + # quant_min/quant_max are narrower than the int8 range, so values must be + # clamped to [-20, 20] to match the PyTorch reference semantics. + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2], dtype=torch.float64) + zero_points = torch.tensor([0, 0], dtype=torch.int64) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -20, 20, torch.int8 + ) + + x = torch.randn(2, 4) * 10 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_int8(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, -128, 127, torch.int8 + ) + + q = torch.randint(-128, 128, (3, 4), dtype=torch.int8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_uint8(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([10, 128, 250], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, 0, 255, torch.uint8 + ) + + q = torch.randint(0, 256, (3, 4), dtype=torch.uint8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_dequantize_per_channel_non_zero_axis(self): + class Model(torch.nn.Module): + def forward(self, q): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([1, 2, 3], dtype=torch.int64) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 2, -128, 127, torch.int8 + ) + + q = torch.randint(-128, 128, (2, 4, 3), dtype=torch.int8) + onnx_program = torch.onnx.export(Model(), (q,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_quantize_dequantize_per_channel_roundtrip(self): + class Model(torch.nn.Module): + def forward(self, x): + scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64) + zero_points = torch.tensor([0, 5, -3], dtype=torch.int64) + q = torch.ops.quantized_decomposed.quantize_per_channel( + x, scales, zero_points, 0, -128, 127, torch.int8 + ) + return torch.ops.quantized_decomposed.dequantize_per_channel( + q, scales, zero_points, 0, -128, 127, torch.int8 + ) + + x = torch.randn(3, 4) * 5 + onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + @parameterized.parameterized.expand( [ # Multiple advanced indices, all 1D tensors.