diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index b90831a..09ece1a 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -11,6 +11,15 @@ import torch from lama_cleaner.const import MPS_SUPPORT_MODELS from loguru import logger from torch.hub import download_url_to_file, get_dir +import hashlib + + +def md5sum(filename): + md5 = hashlib.md5() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * md5.block_size), b""): + md5.update(chunk) + return md5.hexdigest() def switch_mps_device(model_name, device): @@ -33,12 +42,22 @@ def get_cache_path_by_url(url): return cached_file -def download_model(url): +def download_model(url, model_md5: str = None): cached_file = get_cache_path_by_url(url) if not os.path.exists(cached_file): sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) hash_prefix = None download_url_to_file(url, cached_file, hash_prefix, progress=True) + if model_md5: + _md5 = md5sum(cached_file) + if model_md5 == _md5: + logger.info(f"Download model success, md5: {_md5}") + else: + logger.error( + f"Download model failed, md5: {_md5}, expected: {model_md5}. Please delete model at {cached_file} and restart lama-cleaner" + ) + exit(-1) + return cached_file @@ -48,42 +67,49 @@ def ceil_modulo(x, mod): return (x // mod + 1) * mod -def \ - load_jit_model(url_or_path, device): +def handle_error(model_path, model_md5, e): + _md5 = md5sum(model_path) + if _md5 != model_md5: + logger.error( + f"Model md5: {_md5}, expected: {model_md5}, please delete {model_path} and restart lama-cleaner." + f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + ) + else: + logger.error( + f"Failed to load model {model_path}," + f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}" + ) + exit(-1) + + +def load_jit_model(url_or_path, device, model_md5: str): if os.path.exists(url_or_path): model_path = url_or_path else: - model_path = download_model(url_or_path) + model_path = download_model(url_or_path, model_md5) + logger.info(f"Loading model from: {model_path}") try: model = torch.jit.load(model_path, map_location="cpu").to(device) except Exception as e: - logger.error( - f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n" - f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" - f"If all above operations doesn't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}" - ) - exit(-1) + handle_error(model_path, model_md5, e) model.eval() return model -def load_model(model: torch.nn.Module, url_or_path, device): +def load_model(model: torch.nn.Module, url_or_path, device, model_md5): if os.path.exists(url_or_path): model_path = url_or_path else: - model_path = download_model(url_or_path) + model_path = download_model(url_or_path, model_md5) try: + logger.info(f"Loading model from: {model_path}") state_dict = torch.load(model_path, map_location="cpu") model.load_state_dict(state_dict, strict=True) model.to(device) - logger.info(f"Load model from: {model_path}") - except: - logger.error( - f"Failed to load {model_path}, delete model and restart lama-cleaner" - ) - exit(-1) + except Exception as e: + handle_error(model_path, model_md5, e) model.eval() return model diff --git a/lama_cleaner/interactive_seg.py b/lama_cleaner/interactive_seg.py index 4470ffe..9b07349 100644 --- a/lama_cleaner/interactive_seg.py +++ b/lama_cleaner/interactive_seg.py @@ -156,12 +156,13 @@ INTERACTIVE_SEG_MODEL_URL = os.environ.get( "INTERACTIVE_SEG_MODEL_URL", "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt", ) +INTERACTIVE_SEG_MODEL_MD5 = os.environ.get("INTERACTIVE_SEG_MODEL_MD5", "8ca44b6e02bca78f62ec26a3c32376cf") class InteractiveSeg: def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3): device = torch.device('cpu') - model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval() + model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5).eval() self.predictor = ISPredictor(model, device, infer_size=infer_size, open_kernel_size=open_kernel_size, diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index 0b17775..07292c6 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -8,23 +8,42 @@ import torch.fft as fft from lama_cleaner.schema import Config -from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size +from lama_cleaner.helper import ( + load_model, + get_cache_path_by_url, + norm_img, + boxes_from_mask, + resize_max_size, +) from lama_cleaner.model.base import InpaintModel from torch import conv2d, nn import torch.nn.functional as F -from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \ - MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d +from lama_cleaner.model.utils import ( + setup_filter, + _parse_scaling, + _parse_padding, + Conv2dLayer, + FullyConnectedLayer, + MinibatchStdLayer, + activation_funcs, + conv2d_resample, + bias_act, + upsample2d, + normalize_2nd_moment, + downsample2d, +) -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): assert isinstance(x, torch.Tensor) - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. - """ + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" # Validate arguments. assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: @@ -42,8 +61,15 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. - x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] # Setup filter. f = f * (gain ** (f.ndim / 2)) @@ -65,19 +91,20 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): class EncoderEpilogue(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels. - cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. - z_dim, # Output Latent (Z) dimensionality. - resolution, # Resolution of this block. - img_channels, # Number of input color channels. - architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. - mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. - mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - ): - assert architecture in ['orig', 'skip', 'resnet'] + def __init__( + self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ["orig", "skip", "resnet"] super().__init__() self.in_channels = in_channels self.cmap_dim = cmap_dim @@ -85,13 +112,27 @@ class EncoderEpilogue(torch.nn.Module): self.img_channels = img_channels self.architecture = architecture - if architecture == 'skip': - self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation) - self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, - num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None - self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, - conv_clamp=conv_clamp) - self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation) + if architecture == "skip": + self.fromrgb = Conv2dLayer( + self.img_channels, in_channels, kernel_size=1, activation=activation + ) + self.mbstd = ( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + if mbstd_num_channels > 0 + else None + ) + self.conv = Conv2dLayer( + in_channels + mbstd_num_channels, + in_channels, + kernel_size=3, + activation=activation, + conv_clamp=conv_clamp, + ) + self.fc = FullyConnectedLayer( + in_channels * (resolution**2), z_dim, activation=activation + ) self.dropout = torch.nn.Dropout(p=0.5) def forward(self, x, cmap, force_fp32=False): @@ -118,23 +159,29 @@ class EncoderEpilogue(torch.nn.Module): class EncoderBlock(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels, 0 = first block. - tmp_channels, # Number of intermediate channels. - out_channels, # Number of output channels. - resolution, # Resolution of this block. - img_channels, # Number of input color channels. - first_layer_idx, # Index of the first layer. - architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - use_fp16=False, # Use FP16 for this block? - fp16_channels_last=False, # Use channels-last memory format with FP16? - freeze_layers=0, # Freeze-D: Number of layers to freeze. - ): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + freeze_layers=0, # Freeze-D: Number of layers to freeze. + ): assert in_channels in [0, tmp_channels] - assert architecture in ['orig', 'skip', 'resnet'] + assert architecture in ["orig", "skip", "resnet"] super().__init__() self.in_channels = in_channels self.resolution = resolution @@ -142,42 +189,73 @@ class EncoderBlock(torch.nn.Module): self.first_layer_idx = first_layer_idx self.architecture = architecture self.use_fp16 = use_fp16 - self.channels_last = (use_fp16 and fp16_channels_last) - self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) self.num_layers = 0 def trainable_gen(): while True: layer_idx = self.first_layer_idx + self.num_layers - trainable = (layer_idx >= freeze_layers) + trainable = layer_idx >= freeze_layers self.num_layers += 1 yield trainable trainable_iter = trainable_gen() if in_channels == 0: - self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation, - trainable=next(trainable_iter), conv_clamp=conv_clamp, - channels_last=self.channels_last) + self.fromrgb = Conv2dLayer( + self.img_channels, + tmp_channels, + kernel_size=1, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) - self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, - trainable=next(trainable_iter), conv_clamp=conv_clamp, - channels_last=self.channels_last) + self.conv0 = Conv2dLayer( + tmp_channels, + tmp_channels, + kernel_size=3, + activation=activation, + trainable=next(trainable_iter), + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) - self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, - trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, - channels_last=self.channels_last) + self.conv1 = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=3, + activation=activation, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) - if architecture == 'resnet': - self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, - trainable=next(trainable_iter), resample_filter=resample_filter, - channels_last=self.channels_last) + if architecture == "resnet": + self.skip = Conv2dLayer( + tmp_channels, + out_channels, + kernel_size=1, + bias=False, + down=2, + trainable=next(trainable_iter), + resample_filter=resample_filter, + channels_last=self.channels_last, + ) def forward(self, x, img, force_fp32=False): # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 dtype = torch.float32 - memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) # Input. if x is not None: @@ -188,10 +266,14 @@ class EncoderBlock(torch.nn.Module): img = img.to(dtype=dtype, memory_format=memory_format) y = self.fromrgb(img) x = x + y if x is not None else y - img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + img = ( + downsample2d(img, self.resample_filter) + if self.architecture == "skip" + else None + ) # Main layers. - if self.architecture == 'resnet': + if self.architecture == "resnet": y = self.skip(x, gain=np.sqrt(0.5)) x = self.conv0(x) feat = x.clone() @@ -207,29 +289,35 @@ class EncoderBlock(torch.nn.Module): class EncoderNetwork(torch.nn.Module): - def __init__(self, - c_dim, # Conditioning label (C) dimensionality. - z_dim, # Input latent (Z) dimensionality. - img_resolution, # Input resolution. - img_channels, # Number of input color channels. - architecture='orig', # Architecture: 'orig', 'skip', 'resnet'. - channel_base=16384, # Overall multiplier for the number of channels. - channel_max=512, # Maximum number of channels in any layer. - num_fp16_res=0, # Use FP16 for the N highest resolutions. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. - block_kwargs={}, # Arguments for DiscriminatorBlock. - mapping_kwargs={}, # Arguments for MappingNetwork. - epilogue_kwargs={}, # Arguments for EncoderEpilogue. - ): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + z_dim, # Input latent (Z) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture="orig", # Architecture: 'orig', 'skip', 'resnet'. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for EncoderEpilogue. + ): super().__init__() self.c_dim = c_dim self.z_dim = z_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels - self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] - channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + self.block_resolutions = [ + 2**i for i in range(self.img_resolution_log2, 2, -1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) + for res in self.block_resolutions + [4] + } fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) if cmap_dim is None: @@ -237,29 +325,51 @@ class EncoderNetwork(torch.nn.Module): if c_dim == 0: cmap_dim = 0 - common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + common_kwargs = dict( + img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp + ) cur_layer_idx = 0 for res in self.block_resolutions: in_channels = channels_dict[res] if res < img_resolution else 0 tmp_channels = channels_dict[res] out_channels = channels_dict[res // 2] - use_fp16 = (res >= fp16_resolution) + use_fp16 = res >= fp16_resolution use_fp16 = False - block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res, - first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) - setattr(self, f'b{res}', block) + block = EncoderBlock( + in_channels, + tmp_channels, + out_channels, + resolution=res, + first_layer_idx=cur_layer_idx, + use_fp16=use_fp16, + **block_kwargs, + **common_kwargs, + ) + setattr(self, f"b{res}", block) cur_layer_idx += block.num_layers if c_dim > 0: - self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, - **mapping_kwargs) - self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, - **common_kwargs) + self.mapping = MappingNetwork( + z_dim=0, + c_dim=c_dim, + w_dim=cmap_dim, + num_ws=None, + w_avg_beta=None, + **mapping_kwargs, + ) + self.b4 = EncoderEpilogue( + channels_dict[4], + cmap_dim=cmap_dim, + z_dim=z_dim * 2, + resolution=4, + **epilogue_kwargs, + **common_kwargs, + ) def forward(self, img, c, **block_kwargs): x = None feats = {} for res in self.block_resolutions: - block = getattr(self, f'b{res}') + block = getattr(self, f"b{res}") x, img, feat = block(x, img, **block_kwargs) feats[res] = feat @@ -270,8 +380,9 @@ class EncoderNetwork(torch.nn.Module): feats[4] = const_e B, _ = x.shape - z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype, - device=x.device) ## Noise for Co-Modulation + z = torch.zeros( + (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device + ) ## Noise for Co-Modulation return x, z, feats @@ -310,11 +421,15 @@ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c def _unbroadcast(x, shape): extra_dims = x.ndim - len(shape) assert extra_dims >= 0 - dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + dim = [ + i + for i in range(x.ndim) + if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) + ] if len(dim): x = x.sum(dim=dim, keepdim=True) if extra_dims: - x = x.reshape(-1, *x.shape[extra_dims + 1:]) + x = x.reshape(-1, *x.shape[extra_dims + 1 :]) assert x.shape == shape return x @@ -338,9 +453,12 @@ def modulated_conv2d( # Pre-normalize inputs to avoid FP16 overflow. if x.dtype == torch.float16 and demodulate: - weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], - keepdim=True)) # max_Ikk - styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + weight = weight * ( + 1 + / np.sqrt(in_channels * kh * kw) + / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True) + ) # max_Ikk + styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I # Calculate per-sample weights and demodulation coefficients. w = None @@ -355,10 +473,19 @@ def modulated_conv2d( # Execute by scaling the activations before and after the convolution. if not fused_modconv: x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) - x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, - padding=padding, flip_weight=flip_weight) + x = conv2d_resample.conv2d_resample( + x=x, + w=weight.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + flip_weight=flip_weight, + ) if demodulate and noise is not None: - x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + x = fma( + x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype) + ) elif demodulate: x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) elif noise is not None: @@ -369,8 +496,16 @@ def modulated_conv2d( batch_size = int(batch_size) x = x.reshape(1, -1, *x.shape[2:]) w = w.reshape(-1, in_channels, kh, kw) - x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, - groups=batch_size, flip_weight=flip_weight) + x = conv2d_resample( + x=x, + w=w.to(x.dtype), + f=resample_filter, + up=up, + down=down, + padding=padding, + groups=batch_size, + flip_weight=flip_weight, + ) x = x.reshape(batch_size, -1, *x.shape[2:]) if noise is not None: x = x.add_(noise) @@ -378,54 +513,77 @@ def modulated_conv2d( class SynthesisLayer(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - w_dim, # Intermediate latent (W) dimensionality. - resolution, # Resolution of this layer. - kernel_size=3, # Convolution kernel size. - up=1, # Integer upsampling factor. - use_noise=True, # Enable noise input? - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - channels_last=False, # Use channels_last format for the weights? - ): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? + ): super().__init__() self.resolution = resolution self.up = up self.use_noise = use_noise self.activation = activation self.conv_clamp = conv_clamp - self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.register_buffer("resample_filter", setup_filter(resample_filter)) self.padding = kernel_size // 2 self.act_gain = activation_funcs[activation].def_gain self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) - memory_format = torch.channels_last if channels_last else torch.contiguous_format + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) self.weight = torch.nn.Parameter( - torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) if use_noise: - self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.register_buffer("noise_const", torch.randn([resolution, resolution])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) - def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1): - assert noise_mode in ['random', 'const', 'none'] + def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1): + assert noise_mode in ["random", "const", "none"] in_resolution = self.resolution // self.up styles = self.affine(w) noise = None - if self.use_noise and noise_mode == 'random': - noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], - device=x.device) * self.noise_strength - if self.use_noise and noise_mode == 'const': + if self.use_noise and noise_mode == "random": + noise = ( + torch.randn( + [x.shape[0], 1, self.resolution, self.resolution], device=x.device + ) + * self.noise_strength + ) + if self.use_noise and noise_mode == "const": noise = self.noise_const * self.noise_strength - flip_weight = (self.up == 1) # slightly faster - x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, - padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, - fused_modconv=fused_modconv) + flip_weight = self.up == 1 # slightly faster + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + noise=noise, + up=self.up, + padding=self.padding, + resample_filter=self.resample_filter, + flip_weight=flip_weight, + fused_modconv=fused_modconv, + ) act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None @@ -438,33 +596,52 @@ class SynthesisLayer(torch.nn.Module): class ToRGBLayer(torch.nn.Module): - def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): + def __init__( + self, + in_channels, + out_channels, + w_dim, + kernel_size=1, + conv_clamp=None, + channels_last=False, + ): super().__init__() self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) - memory_format = torch.channels_last if channels_last else torch.contiguous_format + memory_format = ( + torch.channels_last if channels_last else torch.contiguous_format + ) self.weight = torch.nn.Parameter( - torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( + memory_format=memory_format + ) + ) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) - self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) def forward(self, x, w, fused_modconv=True): styles = self.affine(w) * self.weight_gain - x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = modulated_conv2d( + x=x, + weight=self.weight, + styles=styles, + demodulate=False, + fused_modconv=fused_modconv, + ) x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) return x class SynthesisForeword(torch.nn.Module): - def __init__(self, - z_dim, # Output Latent (Z) dimensionality. - resolution, # Resolution of this block. - in_channels, - img_channels, # Number of input color channels. - architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - - ): + def __init__( + self, + z_dim, # Output Latent (Z) dimensionality. + resolution, # Resolution of this block. + in_channels, + img_channels, # Number of input color channels. + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + ): super().__init__() self.in_channels = in_channels self.z_dim = z_dim @@ -472,11 +649,20 @@ class SynthesisForeword(torch.nn.Module): self.img_channels = img_channels self.architecture = architecture - self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation) - self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4) + self.fc = FullyConnectedLayer( + self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation + ) + self.conv = SynthesisLayer( + self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4 + ) - if architecture == 'skip': - self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3) + if architecture == "skip": + self.torgb = ToRGBLayer( + self.in_channels, + self.img_channels, + kernel_size=1, + w_dim=(z_dim // 2) * 3, + ) def forward(self, x, ws, feats, img, force_fp32=False): _ = force_fp32 # unused @@ -505,7 +691,7 @@ class SynthesisForeword(torch.nn.Module): mod_vector.append(x_global.clone()) mod_vector = torch.cat(mod_vector, dim=1) - if self.architecture == 'skip': + if self.architecture == "skip": img = self.torgb(x, mod_vector) img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) @@ -521,7 +707,7 @@ class SELayer(nn.Module): nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=False), nn.Linear(channel // reduction, channel, bias=False), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, x): @@ -533,16 +719,32 @@ class SELayer(nn.Module): class FourierUnit(nn.Module): - - def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', - spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + def __init__( + self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode="bilinear", + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm="ortho", + ): # bn_layer not used super(FourierUnit, self).__init__() self.groups = groups - self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), - out_channels=out_channels * 2, - kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False, + ) self.relu = torch.nn.ReLU(inplace=False) # squeeze and excitation block @@ -563,8 +765,12 @@ class FourierUnit(nn.Module): if self.spatial_scale_factor is not None: orig_size = x.shape[-2:] - x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, - align_corners=False) + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False, + ) r_size = x.size() # (batch, c, h, w/2+1, 2) @@ -572,12 +778,26 @@ class FourierUnit(nn.Module): ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) ffted = torch.stack((ffted.real, ffted.imag), dim=-1) ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) - ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + ffted = ffted.view( + ( + batch, + -1, + ) + + ffted.size()[3:] + ) if self.spectral_pos_encoding: height, width = ffted.shape[-2:] - coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) - coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + coords_vert = ( + torch.linspace(0, 1, height)[None, None, :, None] + .expand(batch, 1, height, width) + .to(ffted) + ) + coords_hor = ( + torch.linspace(0, 1, width)[None, None, None, :] + .expand(batch, 1, height, width) + .to(ffted) + ) ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) if self.use_se: @@ -586,22 +806,46 @@ class FourierUnit(nn.Module): ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) ffted = self.relu(ffted) - ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( - 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = ( + ffted.view( + ( + batch, + -1, + 2, + ) + + ffted.size()[2:] + ) + .permute(0, 1, 3, 4, 2) + .contiguous() + ) # (batch,c, t, h, w/2+1, 2) ffted = torch.complex(ffted[..., 0], ffted[..., 1]) ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] - output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm + ) if self.spatial_scale_factor is not None: - output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False, + ) return output class SpectralTransform(nn.Module): - - def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): + def __init__( + self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs, + ): # bn_layer not used super(SpectralTransform, self).__init__() self.enable_lfu = enable_lfu @@ -612,18 +856,18 @@ class SpectralTransform(nn.Module): self.stride = stride self.conv1 = nn.Sequential( - nn.Conv2d(in_channels, out_channels // - 2, kernel_size=1, groups=groups, bias=False), + nn.Conv2d( + in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False + ), # nn.BatchNorm2d(out_channels // 2), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - self.fu = FourierUnit( - out_channels // 2, out_channels // 2, groups, **fu_kwargs) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs) if self.enable_lfu: - self.lfu = FourierUnit( - out_channels // 2, out_channels // 2, groups) + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups) self.conv2 = torch.nn.Conv2d( - out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False + ) def forward(self, x): @@ -635,10 +879,10 @@ class SpectralTransform(nn.Module): n, c, h, w = x.shape split_no = 2 split_s = h // split_no - xs = torch.cat(torch.split( - x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() - xs = torch.cat(torch.split(xs, split_s, dim=-1), - dim=1).contiguous() + xs = torch.cat( + torch.split(x[:, : c // 4], split_s, dim=-2), dim=1 + ).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous() xs = self.lfu(xs) xs = xs.repeat(1, 1, split_no, split_no).contiguous() else: @@ -650,11 +894,23 @@ class SpectralTransform(nn.Module): class FFC(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size, - ratio_gin, ratio_gout, stride=1, padding=0, - dilation=1, groups=1, bias=False, enable_lfu=True, - padding_type='reflect', gated=False, **spectral_kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type="reflect", + gated=False, + **spectral_kwargs, + ): super(FFC, self).__init__() assert stride == 1 or stride == 2, "Stride should be 1 or 2." @@ -672,20 +928,55 @@ class FFC(nn.Module): self.global_in_num = in_cg module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d - self.convl2l = module(in_cl, out_cl, kernel_size, - stride, padding, dilation, groups, bias, padding_mode=padding_type) + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d - self.convl2g = module(in_cl, out_cg, kernel_size, - stride, padding, dilation, groups, bias, padding_mode=padding_type) + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d - self.convg2l = module(in_cg, out_cl, kernel_size, - stride, padding, dilation, groups, bias, padding_mode=padding_type) + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type, + ) module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform self.convg2g = module( - in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + in_cg, + out_cg, + stride, + 1 if groups == 1 else groups // 2, + enable_lfu, + **spectral_kwargs, + ) self.gated = gated - module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + module = ( + nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + ) self.gate = module(in_channels, 2, 1) def forward(self, x, fname=None): @@ -714,17 +1005,40 @@ class FFC(nn.Module): class FFC_BN_ACT(nn.Module): - - def __init__(self, in_channels, out_channels, - kernel_size, ratio_gin, ratio_gout, - stride=1, padding=0, dilation=1, groups=1, bias=False, - norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity, - padding_type='reflect', - enable_lfu=True, **kwargs): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.SyncBatchNorm, + activation_layer=nn.Identity, + padding_type="reflect", + enable_lfu=True, + **kwargs, + ): super(FFC_BN_ACT, self).__init__() - self.ffc = FFC(in_channels, out_channels, kernel_size, - ratio_gin, ratio_gout, stride, padding, dilation, - groups, bias, enable_lfu, padding_type=padding_type, **kwargs) + self.ffc = FFC( + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs, + ) lnorm = nn.Identity if ratio_gout == 1 else norm_layer gnorm = nn.Identity if ratio_gout == 0 else norm_layer global_channels = int(out_channels * ratio_gout) @@ -737,31 +1051,61 @@ class FFC_BN_ACT(nn.Module): self.act_g = gact(inplace=True) def forward(self, x, fname=None): - x_l, x_g = self.ffc(x, fname=fname, ) + x_l, x_g = self.ffc( + x, + fname=fname, + ) x_l = self.act_l(x_l) x_g = self.act_g(x_g) return x_l, x_g class FFCResnetBlock(nn.Module): - def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, - spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75): + def __init__( + self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + ratio_gin=0.75, + ratio_gout=0.75, + ): super().__init__() - self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, - norm_layer=norm_layer, - activation_layer=activation_layer, - padding_type=padding_type, - ratio_gin=ratio_gin, ratio_gout=ratio_gout) - self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, - norm_layer=norm_layer, - activation_layer=activation_layer, - padding_type=padding_type, - ratio_gin=ratio_gin, ratio_gout=ratio_gout) + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) self.inline = inline def forward(self, x, fname=None): if self.inline: - x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + x_l, x_g = ( + x[:, : -self.conv1.ffc.global_in_num], + x[:, -self.conv1.ffc.global_in_num :], + ) else: x_l, x_g = x if type(x) is tuple else (x, 0) @@ -788,35 +1132,41 @@ class ConcatTupleLayer(nn.Module): class FFCBlock(torch.nn.Module): - def __init__(self, - dim, # Number of output/input channels. - kernel_size, # Width and height of the convolution kernel. - padding, - ratio_gin=0.75, - ratio_gout=0.75, - activation='linear', # Activation function: 'relu', 'lrelu', etc. - ): + def __init__( + self, + dim, # Number of output/input channels. + kernel_size, # Width and height of the convolution kernel. + padding, + ratio_gin=0.75, + ratio_gout=0.75, + activation="linear", # Activation function: 'relu', 'lrelu', etc. + ): super().__init__() - if activation == 'linear': + if activation == "linear": self.activation = nn.Identity else: self.activation = nn.ReLU self.padding = padding self.kernel_size = kernel_size - self.ffc_block = FFCResnetBlock(dim=dim, - padding_type='reflect', - norm_layer=nn.SyncBatchNorm, - activation_layer=self.activation, - dilation=1, - ratio_gin=ratio_gin, - ratio_gout=ratio_gout) + self.ffc_block = FFCResnetBlock( + dim=dim, + padding_type="reflect", + norm_layer=nn.SyncBatchNorm, + activation_layer=self.activation, + dilation=1, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) self.concat_layer = ConcatTupleLayer() def forward(self, gen_ft, mask, fname=None): x = gen_ft.float() - x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:] + x_l, x_g = ( + x[:, : -self.ffc_block.conv1.ffc.global_in_num], + x[:, -self.ffc_block.conv1.ffc.global_in_num :], + ) id_l, id_g = x_l, x_g x_l, x_g = self.ffc_block((x_l, x_g), fname=fname) @@ -827,17 +1177,24 @@ class FFCBlock(torch.nn.Module): class FFCSkipLayer(torch.nn.Module): - def __init__(self, - dim, # Number of input/output channels. - kernel_size=3, # Convolution kernel size. - ratio_gin=0.75, - ratio_gout=0.75, - ): + def __init__( + self, + dim, # Number of input/output channels. + kernel_size=3, # Convolution kernel size. + ratio_gin=0.75, + ratio_gout=0.75, + ): super().__init__() self.padding = kernel_size // 2 - self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU, - padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout) + self.ffc_act = FFCBlock( + dim=dim, + kernel_size=kernel_size, + activation=nn.ReLU, + padding=self.padding, + ratio_gin=ratio_gin, + ratio_gout=ratio_gout, + ) def forward(self, gen_ft, mask, fname=None): x = self.ffc_act(gen_ft, mask, fname=fname) @@ -845,21 +1202,27 @@ class FFCSkipLayer(torch.nn.Module): class SynthesisBlock(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels, 0 = first block. - out_channels, # Number of output channels. - w_dim, # Intermediate latent (W) dimensionality. - resolution, # Resolution of this block. - img_channels, # Number of output color channels. - is_last, # Is this the last block? - architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - use_fp16=False, # Use FP16 for this block? - fp16_channels_last=False, # Use channels-last memory format with FP16? - **layer_kwargs, # Arguments for SynthesisLayer. - ): - assert architecture in ['orig', 'skip', 'resnet'] + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture="skip", # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ["orig", "skip", "resnet"] super().__init__() self.in_channels = in_channels self.w_dim = w_dim @@ -868,8 +1231,8 @@ class SynthesisBlock(torch.nn.Module): self.is_last = is_last self.architecture = architecture self.use_fp16 = use_fp16 - self.channels_last = (use_fp16 and fp16_channels_last) - self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.channels_last = use_fp16 and fp16_channels_last + self.register_buffer("resample_filter", setup_filter(resample_filter)) self.num_conv = 0 self.num_torgb = 0 self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1} @@ -880,68 +1243,134 @@ class SynthesisBlock(torch.nn.Module): self.ffc_skip.append(FFCSkipLayer(dim=out_channels)) if in_channels == 0: - self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + self.const = torch.nn.Parameter( + torch.randn([out_channels, resolution, resolution]) + ) if in_channels != 0: - self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2, - resample_filter=resample_filter, conv_clamp=conv_clamp, - channels_last=self.channels_last, **layer_kwargs) + self.conv0 = SynthesisLayer( + in_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + up=2, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) self.num_conv += 1 - self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, - conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.conv1 = SynthesisLayer( + out_channels, + out_channels, + w_dim=w_dim * 3, + resolution=resolution, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + **layer_kwargs, + ) self.num_conv += 1 - if is_last or architecture == 'skip': - self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3, - conv_clamp=conv_clamp, channels_last=self.channels_last) + if is_last or architecture == "skip": + self.torgb = ToRGBLayer( + out_channels, + img_channels, + w_dim=w_dim * 3, + conv_clamp=conv_clamp, + channels_last=self.channels_last, + ) self.num_torgb += 1 - if in_channels != 0 and architecture == 'resnet': - self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, - resample_filter=resample_filter, channels_last=self.channels_last) + if in_channels != 0 and architecture == "resnet": + self.skip = Conv2dLayer( + in_channels, + out_channels, + kernel_size=1, + bias=False, + up=2, + resample_filter=resample_filter, + channels_last=self.channels_last, + ) - def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs): + def forward( + self, + x, + mask, + feats, + img, + ws, + fname=None, + force_fp32=False, + fused_modconv=None, + **layer_kwargs, + ): dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 dtype = torch.float32 - memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + memory_format = ( + torch.channels_last + if self.channels_last and not force_fp32 + else torch.contiguous_format + ) if fused_modconv is None: - fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + fused_modconv = (not self.training) and ( + dtype == torch.float32 or int(x.shape[0]) == 1 + ) x = x.to(dtype=dtype, memory_format=memory_format) - x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format) + x_skip = ( + feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format) + ) # Main layers. if self.in_channels == 0: x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs) - elif self.architecture == 'resnet': + elif self.architecture == "resnet": y = self.skip(x, gain=np.sqrt(0.5)) - x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) if len(self.ffc_skip) > 0: - mask = F.interpolate(mask, size=x_skip.shape[2:], ) + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) z = x + x_skip for fres in self.ffc_skip: z = fres(z, mask) x = x + z else: x = x + x_skip - x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = self.conv1( + x, + ws[1].clone(), + fused_modconv=fused_modconv, + gain=np.sqrt(0.5), + **layer_kwargs, + ) x = y.add_(x) else: - x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv0( + x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) if len(self.ffc_skip) > 0: - mask = F.interpolate(mask, size=x_skip.shape[2:], ) + mask = F.interpolate( + mask, + size=x_skip.shape[2:], + ) z = x + x_skip for fres in self.ffc_skip: z = fres(z, mask) x = x + z else: x = x + x_skip - x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1( + x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs + ) # ToRGB. if img is not None: img = upsample2d(img, self.resample_filter) - if self.is_last or self.architecture == 'skip': + if self.is_last or self.architecture == "skip": y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv) y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) img = img.add_(y) if img is not None else y @@ -953,28 +1382,37 @@ class SynthesisBlock(torch.nn.Module): class SynthesisNetwork(torch.nn.Module): - def __init__(self, - w_dim, # Intermediate latent (W) dimensionality. - z_dim, # Output Latent (Z) dimensionality. - img_resolution, # Output image resolution. - img_channels, # Number of color channels. - channel_base=16384, # Overall multiplier for the number of channels. - channel_max=512, # Maximum number of channels in any layer. - num_fp16_res=0, # Use FP16 for the N highest resolutions. - **block_kwargs, # Arguments for SynthesisBlock. - ): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + z_dim, # Output Latent (Z) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base=16384, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=0, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 super().__init__() self.w_dim = w_dim self.img_resolution = img_resolution self.img_resolution_log2 = int(np.log2(img_resolution)) self.img_channels = img_channels - self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)] - channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + self.block_resolutions = [ + 2**i for i in range(3, self.img_resolution_log2 + 1) + ] + channels_dict = { + res: min(channel_base // res, channel_max) for res in self.block_resolutions + } fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) - self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max), - z_dim=z_dim * 2, resolution=4) + self.foreword = SynthesisForeword( + img_channels=img_channels, + in_channels=min(channel_base // 4, channel_max), + z_dim=z_dim * 2, + resolution=4, + ) self.num_ws = self.img_resolution_log2 * 2 - 2 for res in self.block_resolutions: @@ -983,12 +1421,20 @@ class SynthesisNetwork(torch.nn.Module): else: in_channels = min(channel_base // (res // 2), channel_max) out_channels = channels_dict[res] - use_fp16 = (res >= fp16_resolution) + use_fp16 = res >= fp16_resolution use_fp16 = False - is_last = (res == self.img_resolution) - block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, - img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) - setattr(self, f'b{res}', block) + is_last = res == self.img_resolution + block = SynthesisBlock( + in_channels, + out_channels, + w_dim=w_dim, + resolution=res, + img_channels=img_channels, + is_last=is_last, + use_fp16=use_fp16, + **block_kwargs, + ) + setattr(self, f"b{res}", block) def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs): @@ -997,7 +1443,7 @@ class SynthesisNetwork(torch.nn.Module): x, img = self.foreword(x_global, ws, feats, img) for res in self.block_resolutions: - block = getattr(self, f'b{res}') + block = getattr(self, f"b{res}") mod_vector0 = [] mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5]) mod_vector0.append(x_global.clone()) @@ -1012,23 +1458,32 @@ class SynthesisNetwork(torch.nn.Module): mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3]) mod_vector_rgb.append(x_global.clone()) mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1) - x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs) + x, img = block( + x, + mask, + feats, + img, + (mod_vector0, mod_vector1, mod_vector_rgb), + fname=fname, + **block_kwargs, + ) return img class MappingNetwork(torch.nn.Module): - def __init__(self, - z_dim, # Input latent (Z) dimensionality, 0 = no latent. - c_dim, # Conditioning label (C) dimensionality, 0 = no label. - w_dim, # Intermediate latent (W) dimensionality. - num_ws, # Number of intermediate latents to output, None = do not broadcast. - num_layers=8, # Number of mapping layers. - embed_features=None, # Label embedding dimensionality, None = same as w_dim. - layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. - w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. - ): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim @@ -1043,23 +1498,32 @@ class MappingNetwork(torch.nn.Module): embed_features = 0 if layer_features is None: layer_features = w_dim - features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] - layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) - setattr(self, f'fc{idx}', layer) + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) if num_ws is not None and w_avg_beta is not None: - self.register_buffer('w_avg', torch.zeros([w_dim])) + self.register_buffer("w_avg", torch.zeros([w_dim])) - def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): # Embed, normalize, and concat inputs. x = None - with torch.autograd.profiler.record_function('input'): + with torch.autograd.profiler.record_function("input"): if self.z_dim > 0: x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: @@ -1068,58 +1532,85 @@ class MappingNetwork(torch.nn.Module): # Main layers. for idx in range(self.num_layers): - layer = getattr(self, f'fc{idx}') + layer = getattr(self, f"fc{idx}") x = layer(x) # Update moving average of W. if self.w_avg_beta is not None and self.training and not skip_w_avg_update: - with torch.autograd.profiler.record_function('update_w_avg'): - self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + with torch.autograd.profiler.record_function("update_w_avg"): + self.w_avg.copy_( + x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta) + ) # Broadcast. if self.num_ws is not None: - with torch.autograd.profiler.record_function('broadcast'): + with torch.autograd.profiler.record_function("broadcast"): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: - with torch.autograd.profiler.record_function('truncate'): + with torch.autograd.profiler.record_function("truncate"): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: - x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) return x class Generator(torch.nn.Module): - def __init__(self, - z_dim, # Input latent (Z) dimensionality. - c_dim, # Conditioning label (C) dimensionality. - w_dim, # Intermediate latent (W) dimensionality. - img_resolution, # Output resolution. - img_channels, # Number of output color channels. - encoder_kwargs={}, # Arguments for EncoderNetwork. - mapping_kwargs={}, # Arguments for MappingNetwork. - synthesis_kwargs={}, # Arguments for SynthesisNetwork. - ): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + encoder_kwargs={}, # Arguments for EncoderNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels - self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, - img_channels=img_channels, **encoder_kwargs) - self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, - img_channels=img_channels, **synthesis_kwargs) + self.encoder = EncoderNetwork( + c_dim=c_dim, + z_dim=z_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **encoder_kwargs, + ) + self.synthesis = SynthesisNetwork( + z_dim=z_dim, + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) self.num_ws = self.synthesis.num_ws - self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + self.mapping = MappingNetwork( + z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs + ) - def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + def forward( + self, + img, + c, + fname=None, + truncation_psi=1, + truncation_cutoff=None, + **synthesis_kwargs, + ): mask = img[:, -1].unsqueeze(1) x_global, z, feats = self.encoder(img, c) - ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) + ws = self.mapping( + z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff + ) img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs) return img @@ -1128,6 +1619,7 @@ FCF_MODEL_URL = os.environ.get( "FCF_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth", ) +FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211") class FcF(InpaintModel): @@ -1145,10 +1637,23 @@ class FcF(InpaintModel): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256} - G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3, - synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2}) - self.model = load_model(G, FCF_MODEL_URL, device) + kwargs = { + "channel_base": 1 * 32768, + "channel_max": 512, + "num_fp16_res": 4, + "conv_clamp": 256, + } + G = Generator( + z_dim=512, + c_dim=0, + w_dim=512, + img_resolution=512, + img_channels=3, + synthesis_kwargs=kwargs, + encoder_kwargs=kwargs, + mapping_kwargs={"num_layers": 2}, + ) + self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5) self.label = torch.zeros([1, self.model.c_dim], device=device) @staticmethod @@ -1176,10 +1681,16 @@ class FcF(InpaintModel): inpaint_result = self._pad_forward(resize_image, resize_mask, config) # only paste masked area result - inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC) + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) original_pixel_indices = crop_mask < 127 - inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices] + inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][ + original_pixel_indices + ] crop_result.append((inpaint_result, crop_box)) @@ -1208,8 +1719,15 @@ class FcF(InpaintModel): erased_img = image * (1 - mask) input_image = torch.cat([0.5 - mask, erased_img], dim=1) - output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none') - output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8) + output = self.model( + input_image, self.label, truncation_psi=0.1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) output = output[0].cpu().numpy() cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return cur_res diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index 68d6010..bdcdf0d 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -16,6 +16,7 @@ LAMA_MODEL_URL = os.environ.get( "LAMA_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", ) +LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500") class LaMa(InpaintModel): @@ -23,7 +24,7 @@ class LaMa(InpaintModel): pad_mod = 8 def init_model(self, device, **kwargs): - self.model = load_jit_model(LAMA_MODEL_URL, device).eval() + self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval() @staticmethod def is_downloaded() -> bool: diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index c224069..a5b6d12 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -26,17 +26,27 @@ LDM_ENCODE_MODEL_URL = os.environ.get( "LDM_ENCODE_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", ) +LDM_ENCODE_MODEL_MD5 = os.environ.get( + "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296" +) LDM_DECODE_MODEL_URL = os.environ.get( "LDM_DECODE_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", ) +LDM_DECODE_MODEL_MD5 = os.environ.get( + "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c" +) LDM_DIFFUSION_MODEL_URL = os.environ.get( "LDM_DIFFUSION_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", ) +LDM_DIFFUSION_MODEL_MD5 = os.environ.get( + "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d" +) + class DDPM(nn.Module): # classic DDPM with Gaussian diffusion, in image space @@ -234,9 +244,15 @@ class LDM(InpaintModel): self.device = device def init_model(self, device, **kwargs): - self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device) - self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device) - self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device) + self.diffusion_model = load_jit_model( + LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5 + ) + self.cond_stage_model_decode = load_jit_model( + LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5 + ) + self.cond_stage_model_encode = load_jit_model( + LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5 + ) if self.fp16 and "cuda" in str(device): self.diffusion_model = self.diffusion_model.half() self.cond_stage_model_decode = self.cond_stage_model_decode.half() diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py index e1ece6e..7726274 100644 --- a/lama_cleaner/model/manga.py +++ b/lama_cleaner/model/manga.py @@ -11,67 +11,21 @@ from lama_cleaner.helper import get_cache_path_by_url, load_jit_model from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config -# def norm(np_img): -# return np_img / 255 * 2 - 1.0 -# -# -# @torch.no_grad() -# def run(): -# name = 'manga_1080x740.jpg' -# img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}' -# mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}' -# erika_model = torch.jit.load('erika.jit') -# manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit') -# -# img = cv2.imread(img_p) -# gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE) -# mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE) -# -# kernel = np.ones((9, 9), dtype=np.uint8) -# mask = cv2.dilate(mask, kernel, 2) -# # cv2.imwrite("mask.jpg", mask) -# # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask])) -# # cv2.waitKey(0) -# # exit() -# -# # img = pad(img) -# gray_img = pad(gray_img).astype(np.float32) -# mask = pad(mask) -# -# # pad_mod = 16 -# import time -# start = time.time() -# y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :])) -# y = torch.clamp(y, 0, 255) -# lines = y.cpu().numpy() -# print(f"erika_model time: {time.time() - start}") -# -# cv2.imwrite('lines.png', lines[0][0]) -# -# start = time.time() -# masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :]) -# masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0)) -# noise = torch.randn_like(masks) -# -# images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :]) -# lines = torch.from_numpy(norm(lines)) -# -# outputs = manga_inpaintor_model(images, lines, masks, noise) -# print(f"manga_inpaintor_model time: {time.time() - start}") -# -# outputs_merged = (outputs * masks) + (images * (1 - masks)) -# outputs_merged = outputs_merged * 127.5 + 127.5 -# outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8) -# cv2.imwrite(f'output_{name}', outputs_merged) - MANGA_INPAINTOR_MODEL_URL = os.environ.get( "MANGA_INPAINTOR_MODEL_URL", - "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit" + "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit", ) +MANGA_INPAINTOR_MODEL_MD5 = os.environ.get( + "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c" +) + MANGA_LINE_MODEL_URL = os.environ.get( "MANGA_LINE_MODEL_URL", - "https://github.com/Sanster/models/releases/download/manga/erika.jit" + "https://github.com/Sanster/models/releases/download/manga/erika.jit", +) +MANGA_LINE_MODEL_MD5 = os.environ.get( + "MANGA_LINE_MODEL_MD5", "8f157c142718f11e233d3750a65e0794" ) @@ -80,8 +34,12 @@ class Manga(InpaintModel): pad_mod = 16 def init_model(self, device, **kwargs): - self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device) - self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device) + self.inpaintor_model = load_jit_model( + MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5 + ) + self.line_model = load_jit_model( + MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5 + ) self.seed = 42 @staticmethod @@ -105,7 +63,9 @@ class Manga(InpaintModel): torch.cuda.manual_seed_all(seed) gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) - gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device) + gray_img = torch.from_numpy( + gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32) + ).to(self.device) start = time.time() lines = self.line_model(gray_img) torch.cuda.empty_cache() diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index 58b55bf..5c6de9c 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -10,34 +10,52 @@ import torch.utils.checkpoint as checkpoint from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img from lama_cleaner.model.base import InpaintModel -from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \ - upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment +from lama_cleaner.model.utils import ( + setup_filter, + Conv2dLayer, + FullyConnectedLayer, + conv2d_resample, + bias_act, + upsample2d, + activation_funcs, + MinibatchStdLayer, + to_2tuple, + normalize_2nd_moment, +) from lama_cleaner.schema import Config class ModulatedConv2d(nn.Module): - def __init__(self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size, # Width and height of the convolution kernel. - style_dim, # dimension of the style code - demodulate=True, # perfrom demodulation - up=1, # Integer upsampling factor. - down=1, # Integer downsampling factor. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output to +-X, None = disable clamping. - ): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + style_dim, # dimension of the style code + demodulate=True, # perfrom demodulation + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + ): super().__init__() self.demodulate = demodulate - self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])) + self.weight = torch.nn.Parameter( + torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]) + ) self.out_channels = out_channels self.kernel_size = kernel_size - self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) self.padding = self.kernel_size // 2 self.up = up self.down = down - self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.register_buffer("resample_filter", setup_filter(resample_filter)) self.conv_clamp = conv_clamp self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) @@ -51,44 +69,61 @@ class ModulatedConv2d(nn.Module): decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) - weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size) + weight = weight.view( + batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size + ) x = x.view(1, batch * in_channels, height, width) - x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down, - padding=self.padding, groups=batch) + x = conv2d_resample( + x=x, + w=weight, + f=self.resample_filter, + up=self.up, + down=self.down, + padding=self.padding, + groups=batch, + ) out = x.view(batch, self.out_channels, *x.shape[2:]) return out class StyleConv(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - style_dim, # Intermediate latent (W) dimensionality. - resolution, # Resolution of this layer. - kernel_size=3, # Convolution kernel size. - up=1, # Integer upsampling factor. - use_noise=False, # Enable noise input? - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - demodulate=True, # perform demodulation - ): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + style_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=False, # Enable noise input? + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + demodulate=True, # perform demodulation + ): super().__init__() - self.conv = ModulatedConv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - style_dim=style_dim, - demodulate=demodulate, - up=up, - resample_filter=resample_filter, - conv_clamp=conv_clamp) + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + up=up, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) self.use_noise = use_noise self.resolution = resolution if use_noise: - self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.register_buffer("noise_const", torch.randn([resolution, resolution])) self.noise_strength = torch.nn.Parameter(torch.zeros([])) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) @@ -96,47 +131,55 @@ class StyleConv(torch.nn.Module): self.act_gain = activation_funcs[activation].def_gain self.conv_clamp = conv_clamp - def forward(self, x, style, noise_mode='random', gain=1): + def forward(self, x, style, noise_mode="random", gain=1): x = self.conv(x, style) - assert noise_mode in ['random', 'const', 'none'] + assert noise_mode in ["random", "const", "none"] if self.use_noise: - if noise_mode == 'random': + if noise_mode == "random": xh, xw = x.size()[-2:] - noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \ - * self.noise_strength - if noise_mode == 'const': + noise = ( + torch.randn([x.shape[0], 1, xh, xw], device=x.device) + * self.noise_strength + ) + if noise_mode == "const": noise = self.noise_const * self.noise_strength x = x + noise act_gain = self.act_gain * gain act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None - out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) + out = bias_act( + x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp + ) return out class ToRGB(torch.nn.Module): - def __init__(self, - in_channels, - out_channels, - style_dim, - kernel_size=1, - resample_filter=[1, 3, 3, 1], - conv_clamp=None, - demodulate=False): + def __init__( + self, + in_channels, + out_channels, + style_dim, + kernel_size=1, + resample_filter=[1, 3, 3, 1], + conv_clamp=None, + demodulate=False, + ): super().__init__() - self.conv = ModulatedConv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - style_dim=style_dim, - demodulate=demodulate, - resample_filter=resample_filter, - conv_clamp=conv_clamp) + self.conv = ModulatedConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + style_dim=style_dim, + demodulate=demodulate, + resample_filter=resample_filter, + conv_clamp=conv_clamp, + ) self.bias = torch.nn.Parameter(torch.zeros([out_channels])) - self.register_buffer('resample_filter', setup_filter(resample_filter)) + self.register_buffer("resample_filter", setup_filter(resample_filter)) self.conv_clamp = conv_clamp def forward(self, x, style, skip=None): @@ -156,28 +199,41 @@ def get_style_code(a, b): class DecBlockFirst(nn.Module): - def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): super().__init__() - self.fc = FullyConnectedLayer(in_features=in_channels * 2, - out_features=in_channels * 4 ** 2, - activation=activation) - self.conv = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=4, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.fc = FullyConnectedLayer( + in_features=in_channels * 2, + out_features=in_channels * 4**2, + activation=activation, + ) + self.conv = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, ws, gs, E_features, noise_mode='random'): + def forward(self, x, ws, gs, E_features, noise_mode="random"): x = self.fc(x).view(x.shape[0], -1, 4, 4) x = x + E_features[2] style = get_style_code(ws[:, 0], gs) @@ -189,30 +245,42 @@ class DecBlockFirst(nn.Module): class DecBlockFirstV2(nn.Module): - def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + def __init__( + self, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): super().__init__() - self.conv0 = Conv2dLayer(in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - ) - self.conv1 = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=4, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=4, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, ws, gs, E_features, noise_mode='random'): + def forward(self, x, ws, gs, E_features, noise_mode="random"): # x = self.fc(x).view(x.shape[0], -1, 4, 4) x = self.conv0(x) x = x + E_features[2] @@ -225,38 +293,50 @@ class DecBlockFirstV2(nn.Module): class DecBlock(nn.Module): - def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, - img_channels): # res = 2, ..., resolution_log2 + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 2, ..., resolution_log2 super().__init__() self.res = res - self.conv0 = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - up=2, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.conv1 = StyleConv(in_channels=out_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, img, ws, gs, E_features, noise_mode='random'): + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): style = get_style_code(ws[:, self.res * 2 - 5], gs) x = self.conv0(x, style, noise_mode=noise_mode) x = x + E_features[self.res] @@ -269,18 +349,19 @@ class DecBlock(nn.Module): class MappingNet(torch.nn.Module): - def __init__(self, - z_dim, # Input latent (Z) dimensionality, 0 = no latent. - c_dim, # Conditioning label (C) dimensionality, 0 = no label. - w_dim, # Intermediate latent (W) dimensionality. - num_ws, # Number of intermediate latents to output, None = do not broadcast. - num_layers=8, # Number of mapping layers. - embed_features=None, # Label embedding dimensionality, None = same as w_dim. - layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. - w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. - ): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim @@ -295,23 +376,32 @@ class MappingNet(torch.nn.Module): embed_features = 0 if layer_features is None: layer_features = w_dim - features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + features_list = ( + [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + ) if c_dim > 0: self.embed = FullyConnectedLayer(c_dim, embed_features) for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] - layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) - setattr(self, f'fc{idx}', layer) + layer = FullyConnectedLayer( + in_features, + out_features, + activation=activation, + lr_multiplier=lr_multiplier, + ) + setattr(self, f"fc{idx}", layer) if num_ws is not None and w_avg_beta is not None: - self.register_buffer('w_avg', torch.zeros([w_dim])) + self.register_buffer("w_avg", torch.zeros([w_dim])) - def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): + def forward( + self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False + ): # Embed, normalize, and concat inputs. x = None - with torch.autograd.profiler.record_function('input'): + with torch.autograd.profiler.record_function("input"): if self.z_dim > 0: x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0: @@ -320,64 +410,76 @@ class MappingNet(torch.nn.Module): # Main layers. for idx in range(self.num_layers): - layer = getattr(self, f'fc{idx}') + layer = getattr(self, f"fc{idx}") x = layer(x) # Update moving average of W. if self.w_avg_beta is not None and self.training and not skip_w_avg_update: - with torch.autograd.profiler.record_function('update_w_avg'): - self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + with torch.autograd.profiler.record_function("update_w_avg"): + self.w_avg.copy_( + x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta) + ) # Broadcast. if self.num_ws is not None: - with torch.autograd.profiler.record_function('broadcast'): + with torch.autograd.profiler.record_function("broadcast"): x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) # Apply truncation. if truncation_psi != 1: - with torch.autograd.profiler.record_function('truncate'): + with torch.autograd.profiler.record_function("truncate"): assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi) else: - x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + x[:, :truncation_cutoff] = self.w_avg.lerp( + x[:, :truncation_cutoff], truncation_psi + ) return x class DisFromRGB(nn.Module): - def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 super().__init__() - self.conv = Conv2dLayer(in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - activation=activation, - ) + self.conv = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) def forward(self, x): return self.conv(x) class DisBlock(nn.Module): - def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 super().__init__() - self.conv0 = Conv2dLayer(in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - ) - self.conv1 = Conv2dLayer(in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - down=2, - activation=activation, - ) - self.skip = Conv2dLayer(in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - down=2, - bias=False, - ) + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + down=2, + activation=activation, + ) + self.skip = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + down=2, + bias=False, + ) def forward(self, x): skip = self.skip(x, gain=np.sqrt(0.5)) @@ -389,29 +491,32 @@ class DisBlock(nn.Module): class Discriminator(torch.nn.Module): - def __init__(self, - c_dim, # Conditioning label (C) dimensionality. - img_resolution, # Input resolution. - img_channels, # Number of input color channels. - channel_base=32768, # Overall multiplier for the number of channels. - channel_max=512, # Maximum number of channels in any layer. - channel_decay=1, - cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. - activation='lrelu', - mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. - mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. - ): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 def nf(stage): - return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max) + return np.clip( + int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max + ) if cmap_dim == None: cmap_dim = nf(2) @@ -420,18 +525,28 @@ class Discriminator(torch.nn.Module): self.cmap_dim = cmap_dim if c_dim > 0: - self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] for res in range(resolution_log2, 2, -1): Dis.append(DisBlock(nf(res), nf(res - 1), activation)) if mbstd_num_channels > 0: - Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) - Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) self.Dis = nn.Sequential(*Dis) - self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) def forward(self, images_in, masks_in, c): @@ -450,16 +565,27 @@ class Discriminator(torch.nn.Module): def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} - return NF[2 ** stage] + return NF[2**stage] class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu') - self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features) + self.fc1 = FullyConnectedLayer( + in_features=in_features, out_features=hidden_features, activation="lrelu" + ) + self.fc2 = FullyConnectedLayer( + in_features=hidden_features, out_features=out_features + ) def forward(self, x): x = self.fc1(x) @@ -477,7 +603,9 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows @@ -493,30 +621,48 @@ def window_reverse(windows, window_size: int, H: int, W: int): """ B = int(windows.shape[0] / (H * W / window_size / window_size)) # B = windows.shape[0] / (H * W / window_size / window_size) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class Conv2dLayerPartial(nn.Module): - def __init__(self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size, # Width and height of the convolution kernel. - bias=True, # Apply additive bias before the activation function? - activation='linear', # Activation function: 'relu', 'lrelu', etc. - up=1, # Integer upsampling factor. - down=1, # Integer downsampling factor. - resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output to +-X, None = disable clamping. - trainable=True, # Update the weights of this layer during training? - ): + def __init__( + self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[ + 1, + 3, + 3, + 1, + ], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + trainable=True, # Update the weights of this layer during training? + ): super().__init__() - self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter, - conv_clamp, trainable) + self.conv = Conv2dLayer( + in_channels, + out_channels, + kernel_size, + bias, + activation, + up, + down, + resample_filter, + conv_clamp, + trainable, + ) self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) - self.slide_winsize = kernel_size ** 2 + self.slide_winsize = kernel_size**2 self.stride = down self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 @@ -525,8 +671,13 @@ class Conv2dLayerPartial(nn.Module): with torch.no_grad(): if self.weight_maskUpdater.type() != x.type(): self.weight_maskUpdater = self.weight_maskUpdater.to(x) - update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, - padding=self.padding) + update_mask = F.conv2d( + mask, + self.weight_maskUpdater, + bias=None, + stride=self.stride, + padding=self.padding, + ) mask_ratio = self.slide_winsize / (update_mask + 1e-8) update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 mask_ratio = torch.mul(mask_ratio, update_mask) @@ -539,7 +690,7 @@ class Conv2dLayerPartial(nn.Module): class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. + r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. @@ -551,15 +702,24 @@ class WindowAttention(nn.Module): proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0., - proj_drop=0.): + def __init__( + self, + dim, + window_size, + num_heads, + down_ratio=1, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = FullyConnectedLayer(in_features=dim, out_features=dim) self.k = FullyConnectedLayer(in_features=dim, out_features=dim) @@ -576,23 +736,40 @@ class WindowAttention(nn.Module): """ B_, N, C = x.shape norm_x = F.normalize(x, p=2.0, dim=-1) - q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1) - v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = ( + self.q(norm_x) + .reshape(B_, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.k(norm_x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 3, 1) + ) + v = ( + self.v(x) + .view(B_, -1, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) attn = (q @ k) * self.scale if mask is not None: nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) if mask_windows is not None: attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) - attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill( - attn_mask_windows == 1, float(0.0)) + attn = attn + attn_mask_windows.masked_fill( + attn_mask_windows == 0, float(-100.0) + ).masked_fill(attn_mask_windows == 1, float(0.0)) with torch.no_grad(): - mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1) + mask_windows = torch.clamp( + torch.sum(mask_windows, dim=1, keepdim=True), 0, 1 + ).repeat(1, N, 1) attn = self.softmax(attn) @@ -602,7 +779,7 @@ class WindowAttention(nn.Module): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. + r"""Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. @@ -619,9 +796,23 @@ class SwinTransformerBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, + dim, + input_resolution, + num_heads, + down_ratio=1, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -633,18 +824,34 @@ class SwinTransformerBlock(nn.Module): # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" if self.shift_size > 0: down_ratio = 1 - self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + down_ratio=down_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) - self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu') + self.fuse = FullyConnectedLayer( + in_features=dim * 2, out_features=dim, activation="lrelu" + ) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) if self.shift_size > 0: attn_mask = self.calculate_mask(self.input_resolution) @@ -657,22 +864,30 @@ class SwinTransformerBlock(nn.Module): # calculate attention mask for SW-MSA H, W = x_size img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) return attn_mask @@ -689,17 +904,25 @@ class SwinTransformerBlock(nn.Module): # cyclic shift if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) if mask is not None: - shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_mask = torch.roll( + mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) else: shifted_x = x if mask is not None: shifted_mask = mask # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C if mask is not None: mask_windows = window_partition(shifted_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) @@ -708,11 +931,13 @@ class SwinTransformerBlock(nn.Module): # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size if self.input_resolution == x_size: - attn_windows, mask_windows = self.attn(x_windows, mask_windows, - mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows, mask_windows = self.attn( + x_windows, mask_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C else: - attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to( - x.device)) # nW*B, window_size*window_size, C + attn_windows, mask_windows = self.attn( + x_windows, mask_windows, mask=self.calculate_mask(x_size).to(x.device) + ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -723,9 +948,13 @@ class SwinTransformerBlock(nn.Module): # reverse cyclic shift if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) if mask is not None: - mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + mask = torch.roll( + shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: x = shifted_x if mask is not None: @@ -744,12 +973,13 @@ class SwinTransformerBlock(nn.Module): class PatchMerging(nn.Module): def __init__(self, in_channels, out_channels, down=2): super().__init__() - self.conv = Conv2dLayerPartial(in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation='lrelu', - down=down, - ) + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + down=down, + ) self.down = down def forward(self, x, x_size, mask=None): @@ -769,12 +999,13 @@ class PatchMerging(nn.Module): class PatchUpsampling(nn.Module): def __init__(self, in_channels, out_channels, up=2): super().__init__() - self.conv = Conv2dLayerPartial(in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation='lrelu', - up=up, - ) + self.conv = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation="lrelu", + up=up, + ) self.up = up def forward(self, x, x_size, mask=None): @@ -791,7 +1022,7 @@ class PatchUpsampling(nn.Module): class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. @@ -809,9 +1040,24 @@ class BasicLayer(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ - def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1, - mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + down_ratio=1, + mlp_ratio=2.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): super().__init__() self.dim = dim @@ -827,18 +1073,32 @@ class BasicLayer(nn.Module): self.downsample = None # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, down_ratio=down_ratio, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + down_ratio=down_ratio, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) - self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu') + self.conv = Conv2dLayerPartial( + in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu" + ) def forward(self, x, x_size, mask=None): if self.downsample is not None: @@ -862,8 +1122,12 @@ class ToToken(nn.Module): def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): super().__init__() - self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size, - activation='lrelu') + self.proj = Conv2dLayerPartial( + in_channels=in_channels, + out_channels=dim, + kernel_size=kernel_size, + activation="lrelu", + ) def forward(self, x, mask): x, mask = self.proj(x, mask) @@ -872,18 +1136,22 @@ class ToToken(nn.Module): class EncFromRGB(nn.Module): - def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2 + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log2 super().__init__() - self.conv0 = Conv2dLayer(in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - activation=activation, - ) - self.conv1 = Conv2dLayer(in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - ) + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + activation=activation, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) def forward(self, x): x = self.conv0(x) @@ -893,20 +1161,24 @@ class EncFromRGB(nn.Module): class ConvBlockDown(nn.Module): - def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log + def __init__( + self, in_channels, out_channels, activation + ): # res = 2, ..., resolution_log super().__init__() - self.conv0 = Conv2dLayer(in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - down=2, - ) - self.conv1 = Conv2dLayer(in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - ) + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + down=2, + ) + self.conv1 = Conv2dLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + activation=activation, + ) def forward(self, x): x = self.conv0(x) @@ -929,25 +1201,33 @@ def feature2token(x): class Encoder(nn.Module): - def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1): + def __init__( + self, + res_log2, + img_channels, + activation, + patch_size=5, + channels=16, + drop_path_rate=0.1, + ): super().__init__() self.resolution = [] for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 - res = 2 ** i + res = 2**i self.resolution.append(res) if i == res_log2: block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) else: block = ConvBlockDown(nf(i + 1), nf(i), activation) - setattr(self, 'EncConv_Block_%dx%d' % (res, res), block) + setattr(self, "EncConv_Block_%dx%d" % (res, res), block) def forward(self, x): out = {} for res in self.resolution: res_log2 = int(np.log2(res)) - x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x) + x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x) out[res_log2] = x return out @@ -957,18 +1237,33 @@ class ToStyle(nn.Module): def __init__(self, in_channels, out_channels, activation, drop_rate): super().__init__() self.conv = nn.Sequential( - Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, - down=2), - Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, - down=2), - Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation, - down=2), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), + Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + down=2, + ), ) self.pool = nn.AdaptiveAvgPool2d(1) - self.fc = FullyConnectedLayer(in_features=in_channels, - out_features=out_channels, - activation=activation) + self.fc = FullyConnectedLayer( + in_features=in_channels, out_features=out_channels, activation=activation + ) # self.dropout = nn.Dropout(drop_rate) def forward(self, x): @@ -981,32 +1276,45 @@ class ToStyle(nn.Module): class DecBlockFirstV2(nn.Module): - def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): super().__init__() self.res = res - self.conv0 = Conv2dLayer(in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - ) - self.conv1 = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.conv0 = Conv2dLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + activation=activation, + ) + self.conv1 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, ws, gs, E_features, noise_mode='random'): + def forward(self, x, ws, gs, E_features, noise_mode="random"): # x = self.fc(x).view(x.shape[0], -1, 4, 4) x = self.conv0(x) x = x + E_features[self.res] @@ -1019,38 +1327,50 @@ class DecBlockFirstV2(nn.Module): class DecBlock(nn.Module): - def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, - img_channels): # res = 4, ..., resolution_log2 + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): # res = 4, ..., resolution_log2 super().__init__() self.res = res - self.conv0 = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - up=2, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.conv1 = StyleConv(in_channels=out_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, img, ws, gs, E_features, noise_mode='random'): + def forward(self, x, img, ws, gs, E_features, noise_mode="random"): style = get_style_code(ws[:, self.res * 2 - 9], gs) x = self.conv0(x, style, noise_mode=noise_mode) x = x + E_features[self.res] @@ -1063,55 +1383,84 @@ class DecBlock(nn.Module): class Decoder(nn.Module): - def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels): + def __init__( + self, res_log2, activation, style_dim, use_noise, demodulate, img_channels + ): super().__init__() - self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels) + self.Dec_16x16 = DecBlockFirstV2( + 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels + ) for res in range(5, res_log2 + 1): - setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res), - DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels)) + setattr( + self, + "Dec_%dx%d" % (2**res, 2**res), + DecBlock( + res, + nf(res - 1), + nf(res), + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ), + ) self.res_log2 = res_log2 - def forward(self, x, ws, gs, E_features, noise_mode='random'): + def forward(self, x, ws, gs, E_features, noise_mode="random"): x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) for res in range(5, self.res_log2 + 1): - block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res)) + block = getattr(self, "Dec_%dx%d" % (2**res, 2**res)) x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) return img class DecStyleBlock(nn.Module): - def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels): + def __init__( + self, + res, + in_channels, + out_channels, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ): super().__init__() self.res = res - self.conv0 = StyleConv(in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - up=2, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.conv1 = StyleConv(in_channels=out_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2 ** res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB(in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) + self.conv0 = StyleConv( + in_channels=in_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + up=2, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.conv1 = StyleConv( + in_channels=out_channels, + out_channels=out_channels, + style_dim=style_dim, + resolution=2**res, + kernel_size=3, + use_noise=use_noise, + activation=activation, + demodulate=demodulate, + ) + self.toRGB = ToRGB( + in_channels=out_channels, + out_channels=img_channels, + style_dim=style_dim, + kernel_size=1, + demodulate=False, + ) - def forward(self, x, img, style, skip, noise_mode='random'): + def forward(self, x, img, style, skip, noise_mode="random"): x = self.conv0(x, style, noise_mode=noise_mode) x = x + skip x = self.conv1(x, style, noise_mode=noise_mode) @@ -1121,19 +1470,37 @@ class DecStyleBlock(nn.Module): class FirstStage(nn.Module): - def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True, - activation='lrelu'): + def __init__( + self, + img_channels, + img_resolution=256, + dim=180, + w_dim=512, + use_noise=False, + demodulate=True, + activation="lrelu", + ): super().__init__() res = 64 - self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3, - activation=activation) + self.conv_first = Conv2dLayerPartial( + in_channels=img_channels + 1, + out_channels=dim, + kernel_size=3, + activation=activation, + ) self.enc_conv = nn.ModuleList() down_time = int(np.log2(img_resolution // res)) # 根据图片尺寸构建 swim transformer 的层数 for i in range(down_time): # from input size to 64 self.enc_conv.append( - Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation) + Conv2dLayerPartial( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) ) # from 64 -> 16 -> 64 @@ -1154,30 +1521,59 @@ class FirstStage(nn.Module): else: merge = None self.tran.append( - BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads, - window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], - downsample=merge) + BasicLayer( + dim=dim, + input_resolution=[res, res], + depth=depth, + num_heads=num_heads, + window_size=window_sizes[i], + drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], + downsample=merge, + ) ) # global style down_conv = [] for i in range(int(np.log2(16))): down_conv.append( - Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)) + Conv2dLayer( + in_channels=dim, + out_channels=dim, + kernel_size=3, + down=2, + activation=activation, + ) + ) down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) self.down_conv = nn.Sequential(*down_conv) - self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation) - self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation) - self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation) + self.to_style = FullyConnectedLayer( + in_features=dim, out_features=dim * 2, activation=activation + ) + self.ws_style = FullyConnectedLayer( + in_features=w_dim, out_features=dim, activation=activation + ) + self.to_square = FullyConnectedLayer( + in_features=dim, out_features=16 * 16, activation=activation + ) style_dim = dim * 3 self.dec_conv = nn.ModuleList() for i in range(down_time): # from 64 to input size res = res * 2 self.dec_conv.append( - DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels)) + DecStyleBlock( + res, + dim, + dim, + activation, + style_dim, + use_noise, + demodulate, + img_channels, + ) + ) - def forward(self, images_in, masks_in, ws, noise_mode='random'): + def forward(self, images_in, masks_in, ws, noise_mode="random"): x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) skips = [] @@ -1206,16 +1602,25 @@ class FirstStage(nn.Module): mul_map = F.dropout(mul_map, training=True) ws = self.ws_style(ws[:, -1]) add_n = self.to_square(ws).unsqueeze(1) - add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze( - -1) + add_n = ( + F.interpolate( + add_n, size=x.size(1), mode="linear", align_corners=False + ) + .squeeze(1) + .unsqueeze(-1) + ) x = x * mul_map + add_n * (1 - mul_map) - gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)) + gs = self.to_style( + self.down_conv(token2feature(x, x_size)).flatten(start_dim=1) + ) style = torch.cat([gs, ws], dim=1) x = token2feature(x, x_size).contiguous() img = None for i, block in enumerate(self.dec_conv): - x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode) + x, img = block( + x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode + ) # ensemble img = img * (1 - masks_in) + images_in * masks_in @@ -1224,38 +1629,55 @@ class FirstStage(nn.Module): class SynthesisNet(nn.Module): - def __init__(self, - w_dim, # Intermediate latent (W) dimensionality. - img_resolution, # Output image resolution. - img_channels=3, # Number of color channels. - channel_base=32768, # Overall multiplier for the number of channels. - channel_decay=1.0, - channel_max=512, # Maximum number of channels in any layer. - activation='lrelu', # Activation function: 'relu', 'lrelu', etc. - drop_rate=0.5, - use_noise=False, - demodulate=True, - ): + def __init__( + self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels=3, # Number of color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_decay=1.0, + channel_max=512, # Maximum number of channels in any layer. + activation="lrelu", # Activation function: 'relu', 'lrelu', etc. + drop_rate=0.5, + use_noise=False, + demodulate=True, + ): super().__init__() resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.num_layers = resolution_log2 * 2 - 3 * 2 self.img_resolution = img_resolution self.resolution_log2 = resolution_log2 # first stage - self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False, - demodulate=demodulate) + self.first_stage = FirstStage( + img_channels, + img_resolution=img_resolution, + w_dim=w_dim, + use_noise=False, + demodulate=demodulate, + ) # second stage - self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16) - self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation) - self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate) + self.enc = Encoder( + resolution_log2, img_channels, activation, patch_size=5, channels=16 + ) + self.to_square = FullyConnectedLayer( + in_features=w_dim, out_features=16 * 16, activation=activation + ) + self.to_style = ToStyle( + in_channels=nf(4), + out_channels=nf(2) * 2, + activation=activation, + drop_rate=drop_rate, + ) style_dim = w_dim + nf(2) * 2 - self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels) + self.dec = Decoder( + resolution_log2, activation, style_dim, use_noise, demodulate, img_channels + ) - def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False): + def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False): out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) # encoder @@ -1267,7 +1689,9 @@ class SynthesisNet(nn.Module): mul_map = torch.ones_like(fea_16) * 0.5 mul_map = F.dropout(mul_map, training=True) add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) - add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False) + add_n = F.interpolate( + add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False + ) fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) E_features[4] = fea_16 @@ -1287,15 +1711,16 @@ class SynthesisNet(nn.Module): class Generator(nn.Module): - def __init__(self, - z_dim, # Input latent (Z) dimensionality, 0 = no latent. - c_dim, # Conditioning label (C) dimensionality, 0 = no label. - w_dim, # Intermediate latent (W) dimensionality. - img_resolution, # resolution of generated image - img_channels, # Number of input color channels. - synthesis_kwargs={}, # Arguments for SynthesisNetwork. - mapping_kwargs={}, # Arguments for MappingNetwork. - ): + def __init__( + self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # resolution of generated image + img_channels, # Number of input color channels. + synthesis_kwargs={}, # Arguments for SynthesisNetwork. + mapping_kwargs={}, # Arguments for MappingNetwork. + ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim @@ -1303,44 +1728,64 @@ class Generator(nn.Module): self.img_resolution = img_resolution self.img_channels = img_channels - self.synthesis = SynthesisNet(w_dim=w_dim, - img_resolution=img_resolution, - img_channels=img_channels, - **synthesis_kwargs) - self.mapping = MappingNet(z_dim=z_dim, - c_dim=c_dim, - w_dim=w_dim, - num_ws=self.synthesis.num_layers, - **mapping_kwargs) + self.synthesis = SynthesisNet( + w_dim=w_dim, + img_resolution=img_resolution, + img_channels=img_channels, + **synthesis_kwargs, + ) + self.mapping = MappingNet( + z_dim=z_dim, + c_dim=c_dim, + w_dim=w_dim, + num_ws=self.synthesis.num_layers, + **mapping_kwargs, + ) - def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, - noise_mode='none', return_stg1=False): - ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, - skip_w_avg_update=skip_w_avg_update) + def forward( + self, + images_in, + masks_in, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + skip_w_avg_update=False, + noise_mode="none", + return_stg1=False, + ): + ws = self.mapping( + z, + c, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + skip_w_avg_update=skip_w_avg_update, + ) img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) return img class Discriminator(torch.nn.Module): - def __init__(self, - c_dim, # Conditioning label (C) dimensionality. - img_resolution, # Input resolution. - img_channels, # Number of input color channels. - channel_base=32768, # Overall multiplier for the number of channels. - channel_max=512, # Maximum number of channels in any layer. - channel_decay=1, - cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. - activation='lrelu', - mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. - mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. - ): + def __init__( + self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + channel_decay=1, + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + activation="lrelu", + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 if cmap_dim == None: @@ -1350,18 +1795,28 @@ class Discriminator(torch.nn.Module): self.cmap_dim = cmap_dim if c_dim > 0: - self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None) + self.mapping = MappingNet( + z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None + ) Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)] for res in range(resolution_log2, 2, -1): Dis.append(DisBlock(nf(res), nf(res - 1), activation)) if mbstd_num_channels > 0: - Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) - Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation)) + Dis.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis.append( + Conv2dLayer( + nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation + ) + ) self.Dis = nn.Sequential(*Dis) - self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) # for 64x64 @@ -1370,12 +1825,27 @@ class Discriminator(torch.nn.Module): Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation)) if mbstd_num_channels > 0: - Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)) - Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation)) + Dis_stg1.append( + MinibatchStdLayer( + group_size=mbstd_group_size, num_channels=mbstd_num_channels + ) + ) + Dis_stg1.append( + Conv2dLayer( + nf(2) // 2 + mbstd_num_channels, + nf(2) // 2, + kernel_size=3, + activation=activation, + ) + ) self.Dis_stg1 = nn.Sequential(*Dis_stg1) - self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation) - self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim) + self.fc0_stg1 = FullyConnectedLayer( + nf(2) // 2 * 4**2, nf(2) // 2, activation=activation + ) + self.fc1_stg1 = FullyConnectedLayer( + nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim + ) def forward(self, images_in, masks_in, images_stg1, c): x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1)) @@ -1389,7 +1859,9 @@ class Discriminator(torch.nn.Module): if self.cmap_dim > 0: x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) - x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * ( + 1 / np.sqrt(self.cmap_dim) + ) return x, x_stg1 @@ -1399,6 +1871,8 @@ MAT_MODEL_URL = os.environ.get( "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth", ) +MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed") + class MAT(InpaintModel): name = "mat" @@ -1413,7 +1887,7 @@ class MAT(InpaintModel): torch.manual_seed(seed) G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3) - self.model = load_model(G, MAT_MODEL_URL, device) + self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5) self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512] self.label = torch.zeros([1, self.model.c_dim], device=device) @@ -1438,8 +1912,15 @@ class MAT(InpaintModel): image = torch.from_numpy(image).unsqueeze(0).to(self.device) mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) - output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none') - output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8) + output = self.model( + image, mask, self.z, self.label, truncation_psi=1, noise_mode="none" + ) + output = ( + (output.permute(0, 2, 3, 1) * 127.5 + 127.5) + .round() + .clamp(0, 255) + .to(torch.uint8) + ) output = output[0].cpu().numpy() cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return cur_res diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 9ccdcf4..948c906 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -17,21 +17,33 @@ ZITS_INPAINT_MODEL_URL = os.environ.get( "ZITS_INPAINT_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", ) +ZITS_INPAINT_MODEL_MD5 = os.environ.get( + "ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154" +) ZITS_EDGE_LINE_MODEL_URL = os.environ.get( "ZITS_EDGE_LINE_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", ) +ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get( + "ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f" +) ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", ) +ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get( + "ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927" +) ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( "ZITS_WIRE_FRAME_MODEL_URL", "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", ) +ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get( + "ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b" +) def resize(img, height, width, center_crop=False): @@ -219,12 +231,12 @@ class ZITS(InpaintModel): self.sample_edge_line_iterations = 1 def init_model(self, device, **kwargs): - self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device) - self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device) + self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5) + self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5) self.structure_upsample = load_jit_model( - ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 ) - self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device) + self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5) @staticmethod def is_downloaded() -> bool: diff --git a/lama_cleaner/tests/test_model_md5.py b/lama_cleaner/tests/test_model_md5.py new file mode 100644 index 0000000..7480fcd --- /dev/null +++ b/lama_cleaner/tests/test_model_md5.py @@ -0,0 +1,54 @@ +import os +import tempfile +from pathlib import Path + + +def test_load_model(): + from lama_cleaner.interactive_seg import InteractiveSeg + from lama_cleaner.model_manager import ModelManager + + interactive_seg_model = InteractiveSeg() + + models = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "manga", + ] + for m in models: + ModelManager( + name=m, + device="cpu", + no_half=False, + hf_access_token="", + disable_nsfw=False, + sd_cpu_textencoder=True, + sd_run_local=True, + local_files_only=True, + cpu_offload=True, + enable_xformers=False, + ) + + +# def create_empty_file(tmp_dir, name): +# tmp_model_dir = os.path.join(tmp_dir, "torch", "hub", "checkpoints") +# Path(tmp_model_dir).mkdir(exist_ok=True, parents=True) +# path = os.path.join(tmp_model_dir, name) +# with open(path, "w") as f: +# f.write("1") +# +# +# def test_load_model_error(): +# MODELS = [ +# ("big-lama.pt", "e3aa4aaa15225a33ec84f9f4bc47e500"), +# ("cond_stage_model_encode.pt", "23239fc9081956a3e70de56472b3f296"), +# ("cond_stage_model_decode.pt", "fe419cd15a750d37a4733589d0d3585c"), +# ("diffusion.pt", "b0afda12bf790c03aba2a7431f11d22d"), +# ] +# with tempfile.TemporaryDirectory() as tmp_dir: +# os.environ["XDG_CACHE_HOME"] = tmp_dir +# for name, md5 in MODELS: +# create_empty_file(tmp_dir, name) +# test_load_model()