Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

from __future__ import annotations

from typing import Optional

from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
Expand Down Expand Up @@ -61,3 +64,83 @@ def quantized_decomposed_dequantize_per_tensor(
return dequantized
assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
return op.Cast(dequantized, to=out_dtype)


@torch_op(
    (
        "quantized_decomposed::quantize_per_channel",
        "quantized_decomposed::quantize_per_channel.tensor",
        "quantized_decomposed::quantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_quantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: TensorType,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
) -> TensorType:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values.

    Reference:
        https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
        ``res = clamp(round(input / scales) + zero_points, quant_min, quant_max)``

    Args:
        input: Tensor to quantize.
        scales: Per-channel scales; cast to the element type of ``input``.
        zero_points: Per-channel zero points; cast to ``dtype``.
        axis: Channel axis, forwarded to ``QuantizeLinear``.
        quant_min: Lower clamp bound applied to the quantized result.
        quant_max: Upper clamp bound applied to the quantized result.
        dtype: ONNX element type of the quantized output; used as the cast target
            for ``zero_points`` and as the type of the clamp bounds.

    Returns:
        The quantized tensor, clipped to ``[quant_min, quant_max]``.
    """
    # ONNX QuantizeLinear requires the scale to share the input element type and the
    # zero_point to share the (quantized) output element type. PyTorch passes scales as
    # float64 and zero_points as int64, so cast both to the expected ONNX types.
    scales = op.CastLike(scales, input)
    zero_points = op.Cast(zero_points, to=dtype)
    quantized = op.QuantizeLinear(input, scales, zero_points, axis=axis)
    # QuantizeLinear saturates to the full range of ``dtype``. PyTorch clamps to the
    # explicit ``quant_min``/``quant_max`` instead, so clamp to match its semantics.
    return op.Clip(
        quantized,
        common.constant(quant_min, dtype=dtype),
        common.constant(quant_max, dtype=dtype),
    )


@torch_op(
    (
        "quantized_decomposed::dequantize_per_channel",
        "quantized_decomposed::dequantize_per_channel.tensor",
        "quantized_decomposed::dequantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: Optional[TensorType],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
    out_dtype: int = -1,
) -> TensorType:
    """Dequantize a tensor per channel, mapping quantized values back to floating point.

    Each channel along ``axis`` uses its own scale and (optional) zero point.

    Reference:
        https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
        ``res = (input - zero_points) * scales`` cast to ``out_dtype`` (float32 by default).

    Args:
        input: Quantized tensor to dequantize.
        scales: Per-channel scales; cast to float32 before use.
        zero_points: Per-channel zero points, cast to ``dtype`` when given; may be None.
        axis: Channel axis, forwarded to ``DequantizeLinear``.
        quant_min: Metadata only; not used in the computation.
        quant_max: Metadata only; not used in the computation.
        dtype: ONNX element type of the quantized input; cast target for ``zero_points``.
        out_dtype: ONNX element type of the result. ``-1`` or ``None`` keeps the
            output of ``DequantizeLinear`` as is.

    Returns:
        The dequantized tensor, cast to ``out_dtype`` when one is requested.
    """
    # DequantizeLinear needs a floating point scale; PyTorch hands over float64 scales,
    # so bring them to float32, which is also the PyTorch default output type.
    float_scales = op.Cast(scales, to=ir.DataType.FLOAT)
    # The zero point must share the quantized input element type; PyTorch passes int64.
    # A missing zero point is forwarded unchanged.
    typed_zero_points = None if zero_points is None else op.Cast(zero_points, to=dtype)
    result = op.DequantizeLinear(input, float_scales, typed_zero_points, axis=axis)

    if out_dtype is None or out_dtype == -1:
        # PyTorch defaults to float32, which DequantizeLinear already produces.
        return result

    assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
    return op.Cast(result, to=out_dtype)
112 changes: 112 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
import torch

# Importing this module registers the quantized_decomposed::* operators used below.
import torch.ao.quantization.fx._decomposed # noqa: F401
import torchvision
from torch.onnx._internal.exporter import _testing

Expand Down Expand Up @@ -705,6 +708,115 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

def test_quantize_per_channel_int8(self):
    """Export int8 per-channel quantization along axis 0 and validate the ONNX program."""

    class Model(torch.nn.Module):
        def forward(self, x):
            # PyTorch supplies float64 scales and int64 zero points for this op.
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([0, 5, -3], dtype=torch.int64)
            return torch.ops.quantized_decomposed.quantize_per_channel(
                x, channel_scales, channel_zero_points, 0, -128, 127, torch.int8
            )

    sample = torch.randn(3, 4) * 5
    program = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_quantize_per_channel_uint8(self):
    """Export uint8 per-channel quantization along axis 0 and validate the ONNX program."""

    class Model(torch.nn.Module):
        def forward(self, x):
            # Zero points span the uint8 range, including values above 127.
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([10, 128, 250], dtype=torch.int64)
            return torch.ops.quantized_decomposed.quantize_per_channel(
                x, channel_scales, channel_zero_points, 0, 0, 255, torch.uint8
            )

    sample = torch.randn(3, 4) * 5
    program = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_quantize_per_channel_non_zero_axis(self):
    """Export int8 per-channel quantization along axis 1 of a 3-D input."""

    class Model(torch.nn.Module):
        def forward(self, x):
            # Four parameters, one per entry of dimension 1 of the (2, 4, 3) input.
            channel_scales = torch.tensor([0.1, 0.2, 0.05, 0.3], dtype=torch.float64)
            channel_zero_points = torch.tensor([0, 5, -3, 2], dtype=torch.int64)
            return torch.ops.quantized_decomposed.quantize_per_channel(
                x, channel_scales, channel_zero_points, 1, -128, 127, torch.int8
            )

    sample = torch.randn(2, 4, 3) * 4
    program = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_quantize_per_channel_clamps_to_quant_min_max(self):
    """Export quantization whose quant_min/quant_max are narrower than int8's range."""

    # The [-20, 20] bounds are tighter than what int8 can hold, so the exported graph
    # has to clamp to them in order to match the PyTorch reference semantics.
    class Model(torch.nn.Module):
        def forward(self, x):
            channel_scales = torch.tensor([0.1, 0.2], dtype=torch.float64)
            channel_zero_points = torch.tensor([0, 0], dtype=torch.int64)
            return torch.ops.quantized_decomposed.quantize_per_channel(
                x, channel_scales, channel_zero_points, 0, -20, 20, torch.int8
            )

    sample = torch.randn(2, 4) * 10
    program = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_dequantize_per_channel_int8(self):
    """Export int8 per-channel dequantization along axis 0 and validate the ONNX program."""

    class Model(torch.nn.Module):
        def forward(self, q):
            # PyTorch supplies float64 scales and int64 zero points for this op.
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([0, 5, -3], dtype=torch.int64)
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                q, channel_scales, channel_zero_points, 0, -128, 127, torch.int8
            )

    quantized_sample = torch.randint(-128, 128, (3, 4), dtype=torch.int8)
    program = torch.onnx.export(Model(), (quantized_sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_dequantize_per_channel_uint8(self):
    """Export uint8 per-channel dequantization along axis 0 and validate the ONNX program."""

    class Model(torch.nn.Module):
        def forward(self, q):
            # Zero points span the uint8 range, including values above 127.
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([10, 128, 250], dtype=torch.int64)
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                q, channel_scales, channel_zero_points, 0, 0, 255, torch.uint8
            )

    quantized_sample = torch.randint(0, 256, (3, 4), dtype=torch.uint8)
    program = torch.onnx.export(Model(), (quantized_sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_dequantize_per_channel_non_zero_axis(self):
    """Export int8 per-channel dequantization along axis 2 of a 3-D input."""

    class Model(torch.nn.Module):
        def forward(self, q):
            # Three parameters, one per entry of dimension 2 of the (2, 4, 3) input.
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([1, 2, 3], dtype=torch.int64)
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                q, channel_scales, channel_zero_points, 2, -128, 127, torch.int8
            )

    quantized_sample = torch.randint(-128, 128, (2, 4, 3), dtype=torch.int8)
    program = torch.onnx.export(Model(), (quantized_sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

def test_quantize_dequantize_per_channel_roundtrip(self):
    """Export a quantize -> dequantize chain that shares one set of per-channel params."""

    class Model(torch.nn.Module):
        def forward(self, x):
            channel_scales = torch.tensor([0.1, 0.2, 0.05], dtype=torch.float64)
            channel_zero_points = torch.tensor([0, 5, -3], dtype=torch.int64)
            # Quantize to int8 first ...
            quantized = torch.ops.quantized_decomposed.quantize_per_channel(
                x, channel_scales, channel_zero_points, 0, -128, 127, torch.int8
            )
            # ... then map straight back to floating point with the same parameters.
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                quantized, channel_scales, channel_zero_points, 0, -128, 127, torch.int8
            )

    sample = torch.randn(3, 4) * 5
    program = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    _testing.assert_onnx_program(program)

@parameterized.parameterized.expand(
[
# Multiple advanced indices, all 1D tensors.
Expand Down
Loading