import collections import os from itertools import repeat from typing import Any import random import cv2 import torch import torch.nn as nn import torch.nn.functional as F from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img from lama_cleaner.model.base import InpaintModel from torch.nn.functional import conv2d, conv_transpose2d import torch.utils.checkpoint as checkpoint import numpy as np from lama_cleaner.schema import Config class EasyDict(dict): """Convenience class that behaves like a dict but allows access with the attribute syntax.""" def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError: raise AttributeError(name) def __setattr__(self, name: str, value: Any) -> None: self[name] = value def __delattr__(self, name: str) -> None: del self[name] activation_funcs = { 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), } def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse to_2tuple = _ntuple(2) def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): """Slow reference implementation of `bias_act()` using standard TensorFlow ops. """ assert isinstance(x, torch.Tensor) assert clamp is None or clamp >= 0 spec = activation_funcs[act] alpha = float(alpha if alpha is not None else spec.def_alpha) gain = float(gain if gain is not None else spec.def_gain) clamp = float(clamp if clamp is not None else -1) # Add bias. if b is not None: assert isinstance(b, torch.Tensor) and b.ndim == 1 assert 0 <= dim < x.ndim assert b.shape[0] == x.shape[dim] x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) # Evaluate activation function. alpha = float(alpha) x = spec.func(x, alpha=alpha) # Scale by gain. gain = float(gain) if gain != 1: x = x * gain # Clamp. if clamp >= 0: x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type return x def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, and scales the result by `gain`. Each of the steps is optional. In most cases, the fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports first and second order gradients, but not third order gradients. Args: x: Input activation tensor. Can be of any shape. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type as `x`. The shape must be known, and it must match the dimension of `x` corresponding to `dim`. dim: The dimension in `x` corresponding to the elements of `b`. The value of `dim` is ignored if `b` is not specified. act: Name of the activation function to evaluate, or `"linear"` to disable. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. See `activation_funcs` for a full list. `None` is not allowed. alpha: Shape parameter for the activation function, or `None` to use the default. gain: Scaling factor for the output tensor, or `None` to use default. See `activation_funcs` for the default scaling of each activation function. If unsure, consider specifying 1. clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable the clamping (default). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the same shape and datatype as `x`. """ assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) def _get_filter_size(f): if f is None: return 1, 1 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] fw = f.shape[-1] fh = f.shape[0] fw = int(fw) fh = int(fh) assert fw >= 1 and fh >= 1 return fw, fh def _get_weight_shape(w): shape = [int(sz) for sz in w.shape] return shape def _parse_scaling(scaling): if isinstance(scaling, int): scaling = [scaling, scaling] assert isinstance(scaling, (list, tuple)) assert all(isinstance(x, int) for x in scaling) sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy def _parse_padding(padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, int) for x in padding) if len(padding) == 2: padx, pady = padding padding = [padx, padx, pady, pady] padx0, padx1, pady0, pady1 = padding return padx0, padx1, pady0, pady1 def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. Args: f: Torch tensor, numpy array, or python list of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), `[]` (impulse), or `None` (identity). device: Result device (default: cpu). normalize: Normalize the filter so that it retains the magnitude for constant input signal (DC)? (default: True). flip_filter: Flip the filter? (default: False). gain: Overall scaling factor for signal magnitude (default: 1). separable: Return a separable filter? (default: select automatically). Returns: Float32 tensor of the shape `[filter_height, filter_width]` (non-separable) or `[filter_taps]` (separable). """ # Validate. if f is None: f = 1 f = torch.as_tensor(f, dtype=torch.float32) assert f.ndim in [0, 1, 2] assert f.numel() > 0 if f.ndim == 0: f = f[np.newaxis] # Separable? if separable is None: separable = (f.ndim == 1 and f.numel() >= 8) if f.ndim == 1 and not separable: f = f.ger(f) assert f.ndim == (1 if separable else 2) # Apply normalize, flip, gain, and device. if normalize: f /= f.sum() if flip_filter: f = f.flip(list(range(f.ndim))) f = f * (gain ** (f.ndim / 2)) f = f.to(device=device) return f def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Pad, upsample, filter, and downsample a batch of 2D images. Performs the following sequence of operations for each channel: 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 2. Pad the image with the specified number of zeros on each side (`padding`). Negative padding corresponds to cropping the image. 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it so that the footprint of all output pixels lies within the input image. 4. Downsample the image by keeping every Nth pixel (`down`). This sequence of operations bears close resemblance to scipy.signal.upfirdn(). The fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary order. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ # assert isinstance(x, torch.Tensor) # assert impl in ['ref', 'cuda'] return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. """ # Validate arguments. assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] assert f.dtype == torch.float32 and not f.requires_grad batch_size, num_channels, in_height, in_width = x.shape # upx, upy = _parse_scaling(up) # downx, downy = _parse_scaling(down) upx, upy = up, up downx, downy = down, down # padx0, padx1, pady0, pady1 = _parse_padding(padding) padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3] # Upsample by inserting zeros. x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] # Setup filter. f = f * (gain ** (f.ndim / 2)) f = f.to(x.dtype) if not flip_filter: f = f.flip(list(range(f.ndim))) # Convolve with the filter. f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) if f.ndim == 4: x = conv2d(input=x, weight=f, groups=num_channels) else: x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) # Downsample by throwing away pixels. x = x[:, :, ::downy, ::downx] return x def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Upsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a multiple of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ upx, upy = _parse_scaling(up) # upx, upy = up, up padx0, padx1, pady0, pady1 = _parse_padding(padding) # padx0, padx1, pady0, pady1 = padding, padding, padding, padding fw, fh = _get_filter_size(f) p = [ padx0 + (fw + upx - 1) // 2, padx1 + (fw - upx) // 2, pady0 + (fh + upy - 1) // 2, pady1 + (fh - upy) // 2, ] return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Downsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a fraction of the input. User-specified padding is applied on top of that, with negative values indicating cropping. Pixels outside the image are assumed to be zero. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the input. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ downx, downy = _parse_scaling(down) # padx0, padx1, pady0, pady1 = _parse_padding(padding) padx0, padx1, pady0, pady1 = padding, padding, padding, padding fw, fh = _get_filter_size(f) p = [ padx0 + (fw - downx + 1) // 2, padx1 + (fw - downx) // 2, pady0 + (fh - downy + 1) // 2, pady1 + (fh - downy) // 2, ] return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. """ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). w = w.flip([2, 3]) # Workaround performance pitfall in cuDNN 8.0.5, triggered when using # 1x1 kernel + memory_format=channels_last + less than 64 channels. if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: if out_channels <= 4 and groups == 1: in_shape = x.shape x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) else: x = x.to(memory_format=torch.contiguous_format) w = w.to(memory_format=torch.contiguous_format) x = conv2d(x, w, groups=groups) return x.to(memory_format=torch.channels_last) # Otherwise => execute using conv2d_gradfix. op = conv_transpose2d if transpose else conv2d return op(x, w, stride=stride, padding=padding, groups=groups) def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. Args: x: Input tensor of shape `[batch_size, in_channels, in_height, in_width]`. w: Weight tensor of shape `[out_channels, in_channels//groups, kernel_height, kernel_width]`. f: Low-pass filter for up/downsampling. Must be prepared beforehand by calling setup_filter(). None = identity (default). up: Integer upsampling factor (default: 1). down: Integer downsampling factor (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). groups: Split input channels into N groups (default: 1). flip_weight: False = convolution, True = correlation (default: True). flip_filter: False = convolution, True = correlation (default: False). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) fw, fh = _get_filter_size(f) # px0, px1, py0, py1 = _parse_padding(padding) px0, px1, py0, py1 = padding, padding, padding, padding # Adjust padding to account for up/downsampling. if up > 1: px0 += (fw + up - 1) // 2 px1 += (fw - up) // 2 py0 += (fh + up - 1) // 2 py1 += (fh - up) // 2 if down > 1: px0 += (fw - down + 1) // 2 px1 += (fw - down) // 2 py0 += (fh - down + 1) // 2 py1 += (fh - down) // 2 # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. if kw == 1 and kh == 1 and (down > 1 and up == 1): x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) return x # Fast path: upsampling with optional downsampling => use transpose strided convolution. if up > 1: if groups == 1: w = w.transpose(0, 1) else: w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) w = w.transpose(1, 2) w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) px0 -= kw - 1 px1 -= kw - up py0 -= kh - 1 py1 -= kh - up pxt = max(min(-px0, -px1), 0) pyt = max(min(-py0, -py1), 0) x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, flip_filter=flip_filter) if down > 1: x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. if up == 1 and down == 1: if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight) # Fallback: Generic reference implementation. x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) if down > 1: x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x # ---------------------------------------------------------------------------- def normalize_2nd_moment(x, dim=1, eps=1e-8): return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() class FullyConnectedLayer(nn.Module): def __init__(self, in_features, # Number of input features. out_features, # Number of output features. bias=True, # Apply additive bias before the activation function? activation='linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier=1, # Learning rate multiplier. bias_init=0, # Initial value for the additive bias. ): super().__init__() self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None self.activation = activation self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x): w = self.weight * self.weight_gain b = self.bias if b is not None and self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: # out = torch.addmm(b.unsqueeze(0), x, w.t()) x = x.matmul(w.t()) out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) else: x = x.matmul(w.t()) out = bias_act(x, b, act=self.activation, dim=x.ndim - 1) return out class Conv2dLayer(nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. bias=True, # Apply additive bias before the activation function? activation='linear', # Activation function: 'relu', 'lrelu', etc. up=1, # Integer upsampling factor. down=1, # Integer downsampling factor. resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. conv_clamp=None, # Clamp the output to +-X, None = disable clamping. trainable=True, # Update the weights of this layer during training? ): super().__init__() self.activation = activation self.up = up self.down = down self.register_buffer('resample_filter', setup_filter(resample_filter)) self.conv_clamp = conv_clamp self.padding = kernel_size // 2 self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) self.act_gain = activation_funcs[activation].def_gain weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]) bias = torch.zeros([out_channels]) if bias else None if trainable: self.weight = torch.nn.Parameter(weight) self.bias = torch.nn.Parameter(bias) if bias is not None else None else: self.register_buffer('weight', weight) if bias is not None: self.register_buffer('bias', bias) else: self.bias = None def forward(self, x, gain=1): w = self.weight * self.weight_gain x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down, padding=self.padding) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) return out class ModulatedConv2d(nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. style_dim, # dimension of the style code demodulate=True, # perfrom demodulation up=1, # Integer upsampling factor. down=1, # Integer downsampling factor. resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. conv_clamp=None, # Clamp the output to +-X, None = disable clamping. ): super().__init__() self.demodulate = demodulate self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])) self.out_channels = out_channels self.kernel_size = kernel_size self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) self.padding = self.kernel_size // 2 self.up = up self.down = down self.register_buffer('resample_filter', setup_filter(resample_filter)) self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) def forward(self, x, style): batch, in_channels, height, width = x.shape style = self.affine(style).view(batch, 1, in_channels, 1, 1) weight = self.weight * self.weight_gain * style if self.demodulate: decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size) x = x.view(1, batch * in_channels, height, width) x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, groups=batch) out = x.view(batch, self.out_channels, *x.shape[2:]) return out class StyleConv(torch.nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. style_dim, # Intermediate latent (W) dimensionality. resolution, # Resolution of this layer. kernel_size=3, # Convolution kernel size. up=1, # Integer upsampling factor. use_noise=False, # Enable noise input? activation='lrelu', # Activation function: 'relu', 'lrelu', etc. resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. demodulate=True, # perform demodulation ): super().__init__() self.conv = ModulatedConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, style_dim=style_dim, demodulate=demodulate, up=up, resample_filter=resample_filter, conv_clamp=conv_clamp) self.use_noise = use_noise self.resolution = resolution if use_noise: self.register_buffer('noise_const', torch.randn([resolution, resolution])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) self.activation = activation self.act_gain = activation_funcs[activation].def_gain self.conv_clamp = conv_clamp def forward(self, x, style, noise_mode='random', gain=1): x = self.conv(x, style) assert noise_mode in ['random', 'const', 'none'] if self.use_noise: if noise_mode == 'random': xh, xw = x.size()[-2:] noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \ * self.noise_strength if noise_mode == 'const': noise = self.noise_const * self.noise_strength x = x + noise act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) return out class ToRGB(torch.nn.Module): def __init__(self, in_channels, out_channels, style_dim, kernel_size=1, resample_filter=[1, 3, 3, 1], conv_clamp=None, demodulate=False): super().__init__() self.conv = ModulatedConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, style_dim=style_dim, demodulate=demodulate, resample_filter=resample_filter, conv_clamp=conv_clamp) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) self.register_buffer('resample_filter', setup_filter(resample_filter)) self.conv_clamp = conv_clamp def forward(self, x, style, skip=None): x = self.conv(x, style) out = bias_act(x, self.bias, clamp=self.conv_clamp) if skip is not None: if skip.shape != out.shape: skip = upsample2d(skip, self.resample_filter) out = out + skip return out def get_style_code(a, b): return torch.cat([a, b], dim=1) class DecBlockFirst(nn.Module): def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): super().__init__() self.fc = FullyConnectedLayer(in_features=in_channels * 2, out_features=in_channels * 4 ** 2, activation=activation) self.conv = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=4, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, ws, gs, E_features, noise_mode='random'): x = self.fc(x).view(x.shape[0], -1, 4, 4) x = x + E_features[2] style = get_style_code(ws[:, 0], gs) x = self.conv(x, style, noise_mode=noise_mode) style = get_style_code(ws[:, 1], gs) img = self.toRGB(x, style, skip=None) return x, img class DecBlockFirstV2(nn.Module): def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): super().__init__() self.conv0 = Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, ) self.conv1 = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=4, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, ws, gs, E_features, noise_mode='random'): # x = self.fc(x).view(x.shape[0], -1, 4, 4) x = self.conv0(x) x = x + E_features[2] style = get_style_code(ws[:, 0], gs) x = self.conv1(x, style, noise_mode=noise_mode) style = get_style_code(ws[:, 1], gs) img = self.toRGB(x, style, skip=None) return x, img class DecBlock(nn.Module): def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): # res = 2, ..., resolution_log2 super().__init__() self.res = res self.conv0 = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, up=2, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.conv1 = StyleConv(in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, img, ws, gs, E_features, noise_mode='random'): style = get_style_code(ws[:, self.res * 2 - 5], gs) x = self.conv0(x, style, noise_mode=noise_mode) x = x + E_features[self.res] style = get_style_code(ws[:, self.res * 2 - 4], gs) x = self.conv1(x, style, noise_mode=noise_mode) style = get_style_code(ws[:, self.res * 2 - 3], gs) img = self.toRGB(x, style, skip=img) return x, img class MappingNet(torch.nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality, 0 = no latent. c_dim, # Conditioning label (C) dimensionality, 0 = no label. w_dim, # Intermediate latent (W) dimensionality. num_ws, # Number of intermediate latents to output, None = do not broadcast. num_layers=8, # Number of mapping layers. embed_features=None, # Label embedding dimensionality, None = same as w_dim. layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. activation='lrelu', # Activation function: 'relu', 'lrelu', etc. lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.num_ws = num_ws self.num_layers = num_layers self.w_avg_beta = w_avg_beta if embed_features is None: embed_features = w_dim if c_dim == 0: embed_features = 0 if layer_features is None: layer_features = w_dim features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) if num_ws is not None and w_avg_beta is not None: self.register_buffer('w_avg', torch.zeros([w_dim])) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): # Embed, normalize, and concat inputs. x = None with torch.autograd.profiler.record_function('input'): if self.z_dim > 0: x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: y = normalize_2nd_moment(self.embed(c.to(torch.float32))) x = torch.cat([x, y], dim=1) if x is not None else y # Main layers. for idx in range(self.num_layers): layer = getattr(self, f'fc{idx}') x = layer(x) # Update moving average of W. if self.w_avg_beta is not None and self.training and not skip_w_avg_update: with torch.autograd.profiler.record_function('update_w_avg'): self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) # Broadcast. if self.num_ws is not None: with torch.autograd.profiler.record_function('broadcast'): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: with torch.autograd.profiler.record_function('truncate'): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) return x class DisFromRGB(nn.Module): def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 super().__init__() self.conv = Conv2dLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation, ) def forward(self, x): return self.conv(x) class DisBlock(nn.Module): def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 super().__init__() self.conv0 = Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, ) self.conv1 = Conv2dLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=3, down=2, activation=activation, ) self.skip = Conv2dLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, down=2, bias=False, ) def forward(self, x): skip = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x) x = self.conv1(x, gain=np.sqrt(0.5)) out = skip + x return out class MinibatchStdLayer(torch.nn.Module): def __init__(self, group_size, num_channels=1): super().__init__() self.group_size = group_size self.num_channels = num_channels def forward(self, x): N, C, H, W = x.shape G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N F = self.num_channels c = C // F y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. return x class Discriminator(torch.nn.Module): def __init__(self, c_dim, # Conditioning label (C) dimensionality. img_resolution, # Input resolution. img_channels, # Number of input color channels. channel_base=32768, # Overall multiplier for the number of channels. channel_max=512, # Maximum number of channels in any layer. channel_decay=1, cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. activation='lrelu', mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 def nf(stage): return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max) if cmap_dim == None: cmap_dim = nf(2) if c_dim == 0: cmap_dim = 0 self.cmap_dim = cmap_dim if c_dim > 0: self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] for res in range(resolution_log2, 2, -1): Dis.append(DisBlock(nf(res), nf(res - 1), activation)) if mbstd_num_channels > 0: Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) self.Dis = nn.Sequential(*Dis) self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) def forward(self, images_in, masks_in, c): x = torch.cat([masks_in - 0.5, images_in], dim=1) x = self.Dis(x) x = self.fc1(self.fc0(x.flatten(start_dim=1))) if self.c_dim > 0: cmap = self.mapping(None, c) if self.cmap_dim > 0: x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) return x def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} return NF[2 ** stage] class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu') self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features) def forward(self, x): x = self.fc1(x) x = self.fc2(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size: int, H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) # B = windows.shape[0] / (H * W / window_size / window_size) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class Conv2dLayerPartial(nn.Module): def __init__(self, in_channels, # Number of input channels. out_channels, # Number of output channels. kernel_size, # Width and height of the convolution kernel. bias=True, # Apply additive bias before the activation function? activation='linear', # Activation function: 'relu', 'lrelu', etc. up=1, # Integer upsampling factor. down=1, # Integer downsampling factor. resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. conv_clamp=None, # Clamp the output to +-X, None = disable clamping. trainable=True, # Update the weights of this layer during training? ): super().__init__() self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter, conv_clamp, trainable) self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) self.slide_winsize = kernel_size ** 2 self.stride = down self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 def forward(self, x, mask=None): if mask is not None: with torch.no_grad(): if self.weight_maskUpdater.type() != x.type(): self.weight_maskUpdater = self.weight_maskUpdater.to(x) update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding) mask_ratio = self.slide_winsize / (update_mask + 1e-8) update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 mask_ratio = torch.mul(mask_ratio, update_mask) x = self.conv(x) x = torch.mul(x, mask_ratio) return x, update_mask else: x = self.conv(x) return x, None class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.q = FullyConnectedLayer(in_features=dim, out_features=dim) self.k = FullyConnectedLayer(in_features=dim, out_features=dim) self.v = FullyConnectedLayer(in_features=dim, out_features=dim) self.proj = FullyConnectedLayer(in_features=dim, out_features=dim) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask_windows=None, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape norm_x = F.normalize(x, p=2.0, dim=-1) q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1) v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) attn = (q @ k) * self.scale if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) if mask_windows is not None: attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill( attn_mask_windows == 1, float(0.0)) with torch.no_grad(): mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1) attn = self.softmax(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) return x, mask_windows class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" if self.shift_size > 0: down_ratio = 1 self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu') mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: attn_mask = self.calculate_mask(self.input_resolution) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def calculate_mask(self, x_size): # calculate attention mask for SW-MSA H, W = x_size img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) return attn_mask def forward(self, x, x_size, mask=None): # H, W = self.input_resolution H, W = x_size B, L, C = x.shape # assert L == H * W, "input feature has wrong size" shortcut = x x = x.view(B, H, W, C) if mask is not None: mask = mask.view(B, H, W, 1) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) if mask is not None: shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x if mask is not None: shifted_mask = mask # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C if mask is not None: mask_windows = window_partition(shifted_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) else: mask_windows = None # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size if self.input_resolution == x_size: attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C else: attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to( x.device)) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C if mask is not None: mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1) shifted_mask = window_reverse(mask_windows, self.window_size, H, W) # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) if mask is not None: mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if mask is not None: mask = shifted_mask x = x.view(B, H * W, C) if mask is not None: mask = mask.view(B, H * W, 1) # FFN x = self.fuse(torch.cat([shortcut, x], dim=-1)) x = self.mlp(x) return x, mask class PatchMerging(nn.Module): def __init__(self, in_channels, out_channels, down=2): super().__init__() self.conv = Conv2dLayerPartial(in_channels=in_channels, out_channels=out_channels, kernel_size=3, activation='lrelu', down=down, ) self.down = down def forward(self, x, x_size, mask=None): x = token2feature(x, x_size) if mask is not None: mask = token2feature(mask, x_size) x, mask = self.conv(x, mask) if self.down != 1: ratio = 1 / self.down x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio)) x = feature2token(x) if mask is not None: mask = feature2token(mask) return x, x_size, mask class PatchUpsampling(nn.Module): def __init__(self, in_channels, out_channels, up=2): super().__init__() self.conv = Conv2dLayerPartial(in_channels=in_channels, out_channels=out_channels, kernel_size=3, activation='lrelu', up=up, ) self.up = up def forward(self, x, x_size, mask=None): x = token2feature(x, x_size) if mask is not None: mask = token2feature(mask, x_size) x, mask = self.conv(x, mask) if self.up != 1: x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up)) x = feature2token(x) if mask is not None: mask = feature2token(mask) return x, x_size, mask class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # patch merging layer if downsample is not None: # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) self.downsample = downsample else: self.downsample = None # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, down_ratio=down_ratio, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu') def forward(self, x, x_size, mask=None): if self.downsample is not None: x, x_size, mask = self.downsample(x, x_size, mask) identity = x for blk in self.blocks: if self.use_checkpoint: x, mask = checkpoint.checkpoint(blk, x, x_size, mask) else: x, mask = blk(x, x_size, mask) if mask is not None: mask = token2feature(mask, x_size) x, mask = self.conv(token2feature(x, x_size), mask) x = feature2token(x) + identity if mask is not None: mask = feature2token(mask) return x, x_size, mask class ToToken(nn.Module): def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): super().__init__() self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size, activation='lrelu') def forward(self, x, mask): x, mask = self.proj(x, mask) return x, mask class EncFromRGB(nn.Module): def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 super().__init__() self.conv0 = Conv2dLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation, ) self.conv1 = Conv2dLayer(in_channels=out_channels, out_channels=out_channels, kernel_size=3, activation=activation, ) def forward(self, x): x = self.conv0(x) x = self.conv1(x) return x class ConvBlockDown(nn.Module): def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log super().__init__() self.conv0 = Conv2dLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=3, activation=activation, down=2, ) self.conv1 = Conv2dLayer(in_channels=out_channels, out_channels=out_channels, kernel_size=3, activation=activation, ) def forward(self, x): x = self.conv0(x) x = self.conv1(x) return x def token2feature(x, x_size): B, N, C = x.shape h, w = x_size x = x.permute(0, 2, 1).reshape(B, C, h, w) return x def feature2token(x): B, C, H, W = x.shape x = x.view(B, C, -1).transpose(1, 2) return x class Encoder(nn.Module): def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1): super().__init__() self.resolution = [] for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 res = 2 ** i self.resolution.append(res) if i == res_log2: block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) else: block = ConvBlockDown(nf(i + 1), nf(i), activation) setattr(self, 'EncConv_Block_%dx%d' % (res, res), block) def forward(self, x): out = {} for res in self.resolution: res_log2 = int(np.log2(res)) x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x) out[res_log2] = x return out class ToStyle(nn.Module): def __init__(self, in_channels, out_channels, activation, drop_rate): super().__init__() self.conv = nn.Sequential( Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2), Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2), Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, down=2), ) self.pool = nn.AdaptiveAvgPool2d(1) self.fc = FullyConnectedLayer(in_features=in_channels, out_features=out_channels, activation=activation) # self.dropout = nn.Dropout(drop_rate) def forward(self, x): x = self.conv(x) x = self.pool(x) x = self.fc(x.flatten(start_dim=1)) # x = self.dropout(x) return x class DecBlockFirstV2(nn.Module): def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): super().__init__() self.res = res self.conv0 = Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, ) self.conv1 = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, ws, gs, E_features, noise_mode='random'): # x = self.fc(x).view(x.shape[0], -1, 4, 4) x = self.conv0(x) x = x + E_features[self.res] style = get_style_code(ws[:, 0], gs) x = self.conv1(x, style, noise_mode=noise_mode) style = get_style_code(ws[:, 1], gs) img = self.toRGB(x, style, skip=None) return x, img class DecBlock(nn.Module): def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): # res = 4, ..., resolution_log2 super().__init__() self.res = res self.conv0 = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, up=2, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.conv1 = StyleConv(in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, img, ws, gs, E_features, noise_mode='random'): style = get_style_code(ws[:, self.res * 2 - 9], gs) x = self.conv0(x, style, noise_mode=noise_mode) x = x + E_features[self.res] style = get_style_code(ws[:, self.res * 2 - 8], gs) x = self.conv1(x, style, noise_mode=noise_mode) style = get_style_code(ws[:, self.res * 2 - 7], gs) img = self.toRGB(x, style, skip=img) return x, img class Decoder(nn.Module): def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels): super().__init__() self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels) for res in range(5, res_log2 + 1): setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res), DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels)) self.res_log2 = res_log2 def forward(self, x, ws, gs, E_features, noise_mode='random'): x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) for res in range(5, self.res_log2 + 1): block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res)) x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) return img class DecStyleBlock(nn.Module): def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): super().__init__() self.res = res self.conv0 = StyleConv(in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, up=2, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.conv1 = StyleConv(in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, resolution=2 ** res, kernel_size=3, use_noise=use_noise, activation=activation, demodulate=demodulate, ) self.toRGB = ToRGB(in_channels=out_channels, out_channels=img_channels, style_dim=style_dim, kernel_size=1, demodulate=False, ) def forward(self, x, img, style, skip, noise_mode='random'): x = self.conv0(x, style, noise_mode=noise_mode) x = x + skip x = self.conv1(x, style, noise_mode=noise_mode) img = self.toRGB(x, style, skip=img) return x, img class FirstStage(nn.Module): def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True, activation='lrelu'): super().__init__() res = 64 self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3, activation=activation) self.enc_conv = nn.ModuleList() down_time = int(np.log2(img_resolution // res)) # 根据图片尺寸构建 swim transformer 的层数 for i in range(down_time): # from input size to 64 self.enc_conv.append( Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation) ) # from 64 -> 16 -> 64 depths = [2, 3, 4, 3, 2] ratios = [1, 1 / 2, 1 / 2, 2, 2] num_heads = 6 window_sizes = [8, 16, 16, 16, 8] drop_path_rate = 0.1 dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] self.tran = nn.ModuleList() for i, depth in enumerate(depths): res = int(res * ratios[i]) if ratios[i] < 1: merge = PatchMerging(dim, dim, down=int(1 / ratios[i])) elif ratios[i] > 1: merge = PatchUpsampling(dim, dim, up=ratios[i]) else: merge = None self.tran.append( BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads, window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], downsample=merge) ) # global style down_conv = [] for i in range(int(np.log2(16))): down_conv.append( Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)) down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) self.down_conv = nn.Sequential(*down_conv) self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation) self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation) self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation) style_dim = dim * 3 self.dec_conv = nn.ModuleList() for i in range(down_time): # from 64 to input size res = res * 2 self.dec_conv.append( DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels)) def forward(self, images_in, masks_in, ws, noise_mode='random'): x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) skips = [] x, mask = self.conv_first(x, masks_in) # input size skips.append(x) for i, block in enumerate(self.enc_conv): # input size to 64 x, mask = block(x, mask) if i != len(self.enc_conv) - 1: skips.append(x) x_size = x.size()[-2:] x = feature2token(x) mask = feature2token(mask) mid = len(self.tran) // 2 for i, block in enumerate(self.tran): # 64 to 16 if i < mid: x, x_size, mask = block(x, x_size, mask) skips.append(x) elif i > mid: x, x_size, mask = block(x, x_size, None) x = x + skips[mid - i] else: x, x_size, mask = block(x, x_size, None) mul_map = torch.ones_like(x) * 0.5 mul_map = F.dropout(mul_map, training=True) ws = self.ws_style(ws[:, -1]) add_n = self.to_square(ws).unsqueeze(1) add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze( -1) x = x * mul_map + add_n * (1 - mul_map) gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)) style = torch.cat([gs, ws], dim=1) x = token2feature(x, x_size).contiguous() img = None for i, block in enumerate(self.dec_conv): x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode) # ensemble img = img * (1 - masks_in) + images_in * masks_in return img class SynthesisNet(nn.Module): def __init__(self, w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output image resolution. img_channels=3, # Number of color channels. channel_base=32768, # Overall multiplier for the number of channels. channel_decay=1.0, channel_max=512, # Maximum number of channels in any layer. activation='lrelu', # Activation function: 'relu', 'lrelu', etc. drop_rate=0.5, use_noise=False, demodulate=True, ): super().__init__() resolution_log2 = int(np.log2(img_resolution)) assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 self.num_layers = resolution_log2 * 2 - 3 * 2 self.img_resolution = img_resolution self.resolution_log2 = resolution_log2 # first stage self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False, demodulate=demodulate) # second stage self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16) self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation) self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate) style_dim = w_dim + nf(2) * 2 self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels) def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False): out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) # encoder x = images_in * masks_in + out_stg1 * (1 - masks_in) x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1) E_features = self.enc(x) fea_16 = E_features[4] mul_map = torch.ones_like(fea_16) * 0.5 mul_map = F.dropout(mul_map, training=True) add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False) fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) E_features[4] = fea_16 # style gs = self.to_style(fea_16) # decoder img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode) # ensemble img = img * (1 - masks_in) + images_in * masks_in if not return_stg1: return img else: return img, out_stg1 class Generator(nn.Module): def __init__(self, z_dim, # Input latent (Z) dimensionality, 0 = no latent. c_dim, # Conditioning label (C) dimensionality, 0 = no label. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # resolution of generated image img_channels, # Number of input color channels. synthesis_kwargs={}, # Arguments for SynthesisNetwork. mapping_kwargs={}, # Arguments for MappingNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.synthesis = SynthesisNet(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) self.mapping = MappingNet(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.synthesis.num_layers, **mapping_kwargs) def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, noise_mode='none', return_stg1=False): ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, skip_w_avg_update=skip_w_avg_update) img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) return img class Discriminator(torch.nn.Module): def __init__(self, c_dim, # Conditioning label (C) dimensionality. img_resolution, # Input resolution. img_channels, # Number of input color channels. channel_base=32768, # Overall multiplier for the number of channels. channel_max=512, # Maximum number of channels in any layer. channel_decay=1, cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. activation='lrelu', mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 if cmap_dim == None: cmap_dim = nf(2) if c_dim == 0: cmap_dim = 0 self.cmap_dim = cmap_dim if c_dim > 0: self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] for res in range(resolution_log2, 2, -1): Dis.append(DisBlock(nf(res), nf(res - 1), activation)) if mbstd_num_channels > 0: Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) self.Dis = nn.Sequential(*Dis) self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) # for 64x64 Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)] for res in range(resolution_log2, 2, -1): Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation)) if mbstd_num_channels > 0: Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation)) self.Dis_stg1 = nn.Sequential(*Dis_stg1) self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation) self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim) def forward(self, images_in, masks_in, images_stg1, c): x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1)) x = self.fc1(self.fc0(x.flatten(start_dim=1))) x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1)) x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1))) if self.c_dim > 0: cmap = self.mapping(None, c) if self.cmap_dim > 0: x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) return x, x_stg1 MAT_MODEL_URL = os.environ.get( "MAT_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth", ) class MAT(InpaintModel): min_size = 512 pad_mod = 512 pad_to_square = True def init_model(self, device): seed = 240 # pick up a random number random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3) self.model = load_model(G, MAT_MODEL_URL, device) self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512] self.label = torch.zeros([1, self.model.c_dim], device=device) @staticmethod def is_downloaded() -> bool: return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) def forward(self, image, mask, config: Config): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W] mask area == 255 return: BGR IMAGE """ image = norm_img(image) # [0, 1] image = image * 2 - 1 # [0, 1] -> [-1, 1] mask = (mask > 127) * 255 mask = 255 - mask mask = norm_img(mask) image = torch.from_numpy(image).unsqueeze(0).to(self.device) mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none') output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8) output = output[0].cpu().numpy() cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return cur_res