Source code for torch.nn.qat.modules.conv
import torch
import torch.nn as nn
from torch.nn.intrinsic import ConvReLU2d, ConvReLU3d


class Conv2d(nn.Conv2d):
r"""
A Conv2d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv2d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
for documentation.
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
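
    Examples (illustrative; the qconfig helper shown is the standard
    torch.ao.quantization one, and the "fbgemm" backend string is an
    assumption, not part of the original docstring)::

        >>> import torch
        >>> from torch.ao.quantization import get_default_qat_qconfig
        >>> qconfig = get_default_qat_qconfig("fbgemm")
        >>> m = Conv2d(16, 33, 3, stride=2, qconfig=qconfig)
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)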
"""

    # Note: fused QAT subclasses (e.g. torch.nn.intrinsic.qat.ConvReLU2d)
    # override _FLOAT_MODULE with the fused float module they convert from;
    # `from_float` below relies on this attribute.
    _FLOAT_MODULE = nn.Conv2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', qconfig=None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias, padding_mode=padding_mode,
                         **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        # Fake-quantize module for the weight, constructed from the qconfig.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        # Only the weight is fake-quantized here; activation fake-quant, if
        # any, is attached to the module by the prepare_qat workflow.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for ' +
            cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        if type(mod) == ConvReLU2d:
            # Fused subclasses set _FLOAT_MODULE to the fused module, so the
            # type check above passes; unwrap to reach the inner Conv2d.
            mod = mod[0]
        qconfig = mod.qconfig
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
                       groups=mod.groups, bias=mod.bias is not None,
                       padding_mode=mod.padding_mode, qconfig=qconfig)
        # Share (rather than copy) the float module's parameters, so training
        # the QAT module updates the same tensors.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        conv = torch.nn.Conv2d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode)
        # Detach so the returned float module's parameters are cut off from
        # the QAT module's autograd graph.
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        return conv
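

# Illustrative usage (not part of the original module): a minimal sketch of
# the float -> QAT -> float round trip for Conv2d. The qconfig helper is the
# standard one from torch.ao.quantization; the "fbgemm" backend string is an
# assumption.
if __name__ == "__main__":
    from torch.ao.quantization import get_default_qat_qconfig

    float_conv = nn.Conv2d(3, 8, kernel_size=3)
    float_conv.qconfig = get_default_qat_qconfig("fbgemm")

    qat_conv = Conv2d.from_float(float_conv)    # parameters are shared
    out = qat_conv(torch.randn(1, 3, 32, 32))   # weight goes through fake-quant
    restored = qat_conv.to_float()              # back to a plain nn.Conv2d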


class Conv3d(nn.Conv3d):
r"""
A Conv3d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv3d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
for documentation.
Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
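
    Examples (illustrative; qconfig construction mirrors the Conv2d example
    above and is an assumption, not part of the original docstring)::

        >>> import torch
        >>> from torch.ao.quantization import get_default_qat_qconfig
        >>> qconfig = get_default_qat_qconfig("fbgemm")
        >>> m = Conv3d(16, 33, 3, stride=2, qconfig=qconfig)
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)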
"""

    # Note: fused QAT subclasses (e.g. torch.nn.intrinsic.qat.ConvReLU3d)
    # override _FLOAT_MODULE with the fused float module they convert from;
    # `from_float` below relies on this attribute.
    _FLOAT_MODULE = nn.Conv3d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # Fake-quantize module for the weight, constructed from the qconfig.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        # Only the weight is fake-quantized here; activation fake-quant, if
        # any, is attached to the module by the prepare_qat workflow.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if type(mod) == ConvReLU3d:
            # Fused subclasses set _FLOAT_MODULE to the fused module, so the
            # type check above passes; unwrap to reach the inner Conv3d.
            mod = mod[0]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        # Share (rather than copy) the float module's parameters, so training
        # the QAT module updates the same tensors.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        conv = torch.nn.Conv3d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode)
        # Detach so the returned float module's parameters are cut off from
        # the QAT module's autograd graph.
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        return conv
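

# Illustrative usage (not part of the original module): the same round trip
# as the Conv2d sketch above, for Conv3d. The qconfig helper is the standard
# one from torch.ao.quantization; the "fbgemm" backend string is an assumption.
if __name__ == "__main__":
    from torch.ao.quantization import get_default_qat_qconfig

    float_conv3d = nn.Conv3d(2, 4, kernel_size=3)
    float_conv3d.qconfig = get_default_qat_qconfig("fbgemm")

    qat_conv3d = Conv3d.from_float(float_conv3d)
    out = qat_conv3d(torch.randn(1, 2, 8, 16, 16))  # weight goes through fake-quant
    restored = qat_conv3d.to_float()                # back to a plain nn.Conv3d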