diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index a1ce1a1..db72076 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -11,6 +11,8 @@ from lama_cleaner.parse_args import parse_args def entry_point(): args = parse_args() + if args is None: + return # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 from lama_cleaner.server import main diff --git a/lama_cleaner/diffusers_utils.py b/lama_cleaner/diffusers_utils.py new file mode 100644 index 0000000..d53fa7f --- /dev/null +++ b/lama_cleaner/diffusers_utils.py @@ -0,0 +1,28 @@ +import json +from pathlib import Path +from typing import Dict, List + + +def folder_name_to_show_name(name: str) -> str: + return name.replace("models--", "").replace("--", "/") + + +def _scan_models(cache_dir, class_name: str) -> List[str]: + cache_dir = Path(cache_dir) + res = [] + for it in cache_dir.glob("**/*/model_index.json"): + with open(it, "r", encoding="utf-8") as f: + data = json.load(f) + if data["_class_name"] == class_name: + name = folder_name_to_show_name(it.parent.parent.parent.name) + if name not in res: + res.append(name) + return res + + +def scan_models(cache_dir) -> List[str]: + return _scan_models(cache_dir, "StableDiffusionPipeline") + + +def scan_inpainting_models(cache_dir) -> List[str]: + return _scan_models(cache_dir, "StableDiffusionInpaintPipeline") diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py new file mode 100644 index 0000000..f0f7cb2 --- /dev/null +++ b/lama_cleaner/download.py @@ -0,0 +1,24 @@ +import os + +from loguru import logger +from pathlib import Path + + +def cli_download_model(model: str, model_dir: str): + if os.path.isfile(model_dir): + raise ValueError(f"invalid --model-dir: {model_dir} is a file") + + if not os.path.exists(model_dir): + logger.info(f"Create model cache directory: {model_dir}") + Path(model_dir).mkdir(exist_ok=True, parents=True) + + os.environ["XDG_CACHE_HOME"] = model_dir + + from lama_cleaner.model_manager import models + + if model in models: + logger.info(f"Downloading {model}...") + models[model].download() + logger.info(f"Done.") + else: + logger.info(f"Downloading model from Huggingface: {model}") diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 726a2be..81cdf0c 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -51,6 +51,10 @@ class InpaintModel: """ ... + @staticmethod + def download(): + ... + def _pad_forward(self, image, mask, config: Config): origin_height, origin_width = image.shape[:2] pad_image = pad_img_to_modulo( diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index 07292c6..a64885a 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -14,6 +14,7 @@ from lama_cleaner.helper import ( norm_img, boxes_from_mask, resize_max_size, + download_model, ) from lama_cleaner.model.base import InpaintModel from torch import conv2d, nn @@ -870,7 +871,6 @@ class SpectralTransform(nn.Module): ) def forward(self, x): - x = self.downsample(x) x = self.conv1(x) output = self.fu(x) @@ -1437,7 +1437,6 @@ class SynthesisNetwork(torch.nn.Module): setattr(self, f"b{res}", block) def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs): - img = None x, img = self.foreword(x_global, ws, feats, img) @@ -1656,6 +1655,10 @@ class FcF(InpaintModel): self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5) self.label = torch.zeros([1, self.model.c_dim], device=device) + @staticmethod + def download(): + download_model(FCF_MODEL_URL, FCF_MODEL_MD5) + @staticmethod def is_downloaded() -> bool: return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL)) diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index 4d176da..27476b7 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -4,7 +4,6 @@ import torch from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import set_seed from lama_cleaner.schema import Config @@ -15,18 +14,21 @@ class InstructPix2Pix(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers import StableDiffusionInstructPix2PixPipeline - fp16 = not kwargs.get('no_half', False) - model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} - if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + fp16 = not kwargs.get("no_half", False) + + model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)} + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") - model_kwargs.update(dict( - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False - )) + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) - use_gpu = device == torch.device('cuda') and torch.cuda.is_available() + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", @@ -36,15 +38,23 @@ class InstructPix2Pix(DiffusionInpaintModel): ) self.model.enable_attention_slicing() - if kwargs.get('enable_xformers', False): + if kwargs.get("enable_xformers", False): self.model.enable_xformers_memory_efficient_attention() - if kwargs.get('cpu_offload', False) and use_gpu: + if kwargs.get("cpu_offload", False) and use_gpu: logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) else: self.model = self.model.to(device) + @staticmethod + def download(): + from diffusers import StableDiffusionInstructPix2PixPipeline + + StableDiffusionInstructPix2PixPipeline.from_pretrained( + "timbrooks/instruct-pix2pix", revision="fp16" + ) + def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -60,7 +70,7 @@ class InstructPix2Pix(DiffusionInpaintModel): image_guidance_scale=config.p2p_image_guidance_scale, guidance_scale=config.p2p_guidance_scale, output_type="np", - generator=torch.manual_seed(config.sd_seed) + generator=torch.manual_seed(config.sd_seed), ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 7a89d9a..12d8209 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -76,3 +76,11 @@ class Kandinsky(DiffusionInpaintModel): class Kandinsky22(Kandinsky): name = "kandinsky2.2" model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" + + @staticmethod + def download(): + from diffusers import AutoPipelineForInpainting + + AutoPipelineForInpainting.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder-inpaint" + ) diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index bdcdf0d..ebbb6c9 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -8,6 +8,7 @@ from lama_cleaner.helper import ( norm_img, get_cache_path_by_url, load_jit_model, + download_model, ) from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config @@ -23,6 +24,10 @@ class LaMa(InpaintModel): name = "lama" pad_mod = 8 + @staticmethod + def download(): + download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5) + def init_model(self, device, **kwargs): self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval() diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index a5b6d12..29956ad 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -260,6 +260,12 @@ class LDM(InpaintModel): self.model = LatentDiffusion(self.diffusion_model, device) + @staticmethod + def download(): + download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5) + download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5) + download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5) + @staticmethod def is_downloaded() -> bool: model_paths = [ diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py index f6e27e5..d61f408 100644 --- a/lama_cleaner/model/manga.py +++ b/lama_cleaner/model/manga.py @@ -7,7 +7,7 @@ import torch import time from loguru import logger -from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config @@ -42,6 +42,11 @@ class Manga(InpaintModel): ) self.seed = 42 + @staticmethod + def download(): + download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5) + download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5) + @staticmethod def is_downloaded() -> bool: model_paths = [ diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index 3e09bf4..2fbe11c 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -8,7 +8,12 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img +from lama_cleaner.helper import ( + load_model, + get_cache_path_by_url, + norm_img, + download_model, +) from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.utils import ( setup_filter, @@ -52,7 +57,7 @@ class ModulatedConv2d(nn.Module): ) self.out_channels = out_channels self.kernel_size = kernel_size - self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) self.padding = self.kernel_size // 2 self.up = up self.down = down @@ -213,7 +218,7 @@ class DecBlockFirst(nn.Module): super().__init__() self.fc = FullyConnectedLayer( in_features=in_channels * 2, - out_features=in_channels * 4 ** 2, + out_features=in_channels * 4**2, activation=activation, ) self.conv = StyleConv( @@ -312,7 +317,7 @@ class DecBlock(nn.Module): in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, up=2, use_noise=use_noise, @@ -323,7 +328,7 @@ class DecBlock(nn.Module): in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, use_noise=use_noise, activation=activation, @@ -507,7 +512,7 @@ class Discriminator(torch.nn.Module): self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 def nf(stage): @@ -543,7 +548,7 @@ class Discriminator(torch.nn.Module): ) self.Dis = nn.Sequential(*Dis) - self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) def forward(self, images_in, masks_in, c): @@ -562,7 +567,7 @@ class Discriminator(torch.nn.Module): def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} - return NF[2 ** stage] + return NF[2**stage] class Mlp(nn.Module): @@ -659,7 +664,7 @@ class Conv2dLayerPartial(nn.Module): ) self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) - self.slide_winsize = kernel_size ** 2 + self.slide_winsize = kernel_size**2 self.stride = down self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 @@ -715,7 +720,7 @@ class WindowAttention(nn.Module): self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = FullyConnectedLayer(in_features=dim, out_features=dim) self.k = FullyConnectedLayer(in_features=dim, out_features=dim) @@ -1211,7 +1216,7 @@ class Encoder(nn.Module): self.resolution = [] for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16 - res = 2 ** i + res = 2**i self.resolution.append(res) if i == res_log2: block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) @@ -1296,7 +1301,7 @@ class DecBlockFirstV2(nn.Module): in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, use_noise=use_noise, activation=activation, @@ -1341,7 +1346,7 @@ class DecBlock(nn.Module): in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, up=2, use_noise=use_noise, @@ -1352,7 +1357,7 @@ class DecBlock(nn.Module): in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, use_noise=use_noise, activation=activation, @@ -1389,7 +1394,7 @@ class Decoder(nn.Module): for res in range(5, res_log2 + 1): setattr( self, - "Dec_%dx%d" % (2 ** res, 2 ** res), + "Dec_%dx%d" % (2**res, 2**res), DecBlock( res, nf(res - 1), @@ -1406,7 +1411,7 @@ class Decoder(nn.Module): def forward(self, x, ws, gs, E_features, noise_mode="random"): x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) for res in range(5, self.res_log2 + 1): - block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res)) + block = getattr(self, "Dec_%dx%d" % (2**res, 2**res)) x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) return img @@ -1431,7 +1436,7 @@ class DecStyleBlock(nn.Module): in_channels=in_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, up=2, use_noise=use_noise, @@ -1442,7 +1447,7 @@ class DecStyleBlock(nn.Module): in_channels=out_channels, out_channels=out_channels, style_dim=style_dim, - resolution=2 ** res, + resolution=2**res, kernel_size=3, use_noise=use_noise, activation=activation, @@ -1640,7 +1645,7 @@ class SynthesisNet(nn.Module): ): super().__init__() resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.num_layers = resolution_log2 * 2 - 3 * 2 self.img_resolution = img_resolution @@ -1781,7 +1786,7 @@ class Discriminator(torch.nn.Module): self.img_channels = img_channels resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4 + assert img_resolution == 2**resolution_log2 and img_resolution >= 4 self.resolution_log2 = resolution_log2 if cmap_dim == None: @@ -1812,7 +1817,7 @@ class Discriminator(torch.nn.Module): ) self.Dis = nn.Sequential(*Dis) - self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation) + self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation) self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim) # for 64x64 @@ -1837,7 +1842,7 @@ class Discriminator(torch.nn.Module): self.Dis_stg1 = nn.Sequential(*Dis_stg1) self.fc0_stg1 = FullyConnectedLayer( - nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation + nf(2) // 2 * 4**2, nf(2) // 2, activation=activation ) self.fc1_stg1 = FullyConnectedLayer( nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim @@ -1898,6 +1903,10 @@ class MAT(InpaintModel): self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype) # fmt: on + @staticmethod + def download(): + download_model(MAT_MODEL_URL, MAT_MODEL_MD5) + @staticmethod def is_downloaded() -> bool: return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index 4b5ae74..a9606a4 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -2,11 +2,9 @@ import PIL import PIL.Image import cv2 import torch -from diffusers import DiffusionPipeline from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import set_seed from lama_cleaner.schema import Config @@ -16,35 +14,40 @@ class PaintByExample(DiffusionInpaintModel): min_size = 512 def init_model(self, device: torch.device, **kwargs): - fp16 = not kwargs.get('no_half', False) - use_gpu = device == torch.device('cuda') and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} + from diffusers import DiffusionPipeline - if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + fp16 = not kwargs.get("no_half", False) + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)} + + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Paint By Example Model NSFW checker") - model_kwargs.update(dict( - safety_checker=None, - requires_safety_checker=False - )) + model_kwargs.update( + dict(safety_checker=None, requires_safety_checker=False) + ) self.model = DiffusionPipeline.from_pretrained( - "Fantasy-Studio/Paint-by-Example", - torch_dtype=torch_dtype, - **model_kwargs + "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs ) self.model.enable_attention_slicing() - if kwargs.get('enable_xformers', False): + if kwargs.get("enable_xformers", False): self.model.enable_xformers_memory_efficient_attention() # TODO: gpu_id - if kwargs.get('cpu_offload', False) and use_gpu: + if kwargs.get("cpu_offload", False) and use_gpu: self.model.image_encoder = self.model.image_encoder.to(device) self.model.enable_sequential_cpu_offload(gpu_id=0) else: self.model = self.model.to(device) + @staticmethod + def download(): + from diffusers import DiffusionPipeline + + DiffusionPipeline.from_pretrained("Fantasy-Studio/Paint-by-Example") + def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -56,8 +59,8 @@ class PaintByExample(DiffusionInpaintModel): mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), example_image=config.paint_by_example_example_image, num_inference_steps=config.paint_by_example_steps, - output_type='np.array', - generator=torch.manual_seed(config.paint_by_example_seed) + output_type="np.array", + generator=torch.manual_seed(config.paint_by_example_seed), ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index f466625..0c0b190 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -132,6 +132,12 @@ class SD(DiffusionInpaintModel): # model will be downloaded when app start, and can't switch in frontend settings return True + @classmethod + def download(cls): + from diffusers import StableDiffusionInpaintPipeline + + StableDiffusionInpaintPipeline.from_pretrained(cls.model_id_or_path) + class SD15(SD): name = "sd1.5" diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index 7615795..f64bfa8 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -5,7 +5,6 @@ import torch from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import torch_gc, get_scheduler from lama_cleaner.schema import Config @@ -51,6 +50,14 @@ class SDXL(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) + @staticmethod + def download(): + from diffusers import AutoPipelineForInpainting + + AutoPipelineForInpainting.from_pretrained( + "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" + ) + def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -85,7 +92,6 @@ class SDXL(DiffusionInpaintModel): output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - @staticmethod def is_downloaded() -> bool: # model will be downloaded when app start, and can't switch in frontend settings diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 664ca15..748623e 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -5,7 +5,7 @@ import cv2 import torch import torch.nn.functional as F -from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model from lama_cleaner.schema import Config import numpy as np @@ -171,14 +171,19 @@ def load_image(img, mask, device, sigma256=3.0): try: import skimage + gray_256 = skimage.color.rgb2gray(img_256) edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float) # cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8)) # cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8)) except: gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) - gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256) - edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2)) + gray_256_blured = cv2.GaussianBlur( + gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256 + ) + edge_256 = cv2.Canny( + gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2) + ) # cv2.imwrite("opencv_edge.jpg", edge_256) @@ -233,12 +238,27 @@ class ZITS(InpaintModel): self.sample_edge_line_iterations = 1 def init_model(self, device, **kwargs): - self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5) - self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5) + self.wireframe = load_jit_model( + ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5 + ) + self.edge_line = load_jit_model( + ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5 + ) self.structure_upsample = load_jit_model( ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 ) - self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5) + self.inpaint = load_jit_model( + ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5 + ) + + @staticmethod + def download(): + download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5) + download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5) + download_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 + ) + download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5) @staticmethod def is_downloaded() -> bool: @@ -385,12 +405,20 @@ class ZITS(InpaintModel): if score > mask_th: try: import skimage + rr, cc, value = skimage.draw.line_aa( *to_int(line[0:2]), *to_int(line[2:4]) ) lmap[rr, cc] = np.maximum(lmap[rr, cc], value) except: - cv2.line(lmap, to_int(line[0:2][::-1]), to_int(line[2:4][::-1]), (1, 1, 1), 1, cv2.LINE_AA) + cv2.line( + lmap, + to_int(line[0:2][::-1]), + to_int(line[2:4][::-1]), + (1, 1, 1), + 1, + cv2.LINE_AA, + ) lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8) lines_tensor.append(to_tensor(lmap).unsqueeze(0)) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index e27d8ea..f672fb7 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -6,13 +6,29 @@ from pathlib import Path from loguru import logger from lama_cleaner.const import * +from lama_cleaner.download import cli_download_model from lama_cleaner.runtime import dump_environment_info +DOWNLOAD_SUBCOMMAND = "download" + + +def download_parse_args(parser): + subparsers = parser.add_subparsers(dest="subcommand") + subparser = subparsers.add_parser(DOWNLOAD_SUBCOMMAND, help="Download models") + subparser.add_argument( + "--model", help="Erase model name(lama/mat...) or model id on huggingface" + ) + subparser.add_argument( + "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP + ) + def parse_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + download_parse_args(parser) + parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", default=8080, type=int) @@ -166,9 +182,12 @@ def parse_args(): ) args = parser.parse_args() - # collect system info to help debug dump_environment_info() + if args.subcommand == DOWNLOAD_SUBCOMMAND: + cli_download_model(args.model, args.model_dir) + return + if args.install_plugins_package: from lama_cleaner.installer import install_plugins_package diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index 61efdcc..25d53c5 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -185,7 +185,7 @@ def main(config_file: str): ) sd_controlnet_method = gr.Radio( SD_CONTROLNET_CHOICES, - lable="ControlNet method", + label="ControlNet method", value=init_config.sd_controlnet_method, ) no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")