add download command

This commit is contained in:
Qing 2023-11-16 21:12:06 +08:00
parent 20e660aa4a
commit 1d145d1cd6
17 changed files with 233 additions and 67 deletions

View File

@ -11,6 +11,8 @@ from lama_cleaner.parse_args import parse_args
def entry_point():
args = parse_args()
if args is None:
return
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
from lama_cleaner.server import main

View File

@ -0,0 +1,28 @@
import json
from pathlib import Path
from typing import Dict, List
def folder_name_to_show_name(name: str) -> str:
    """Convert a cache folder name into a display name.

    Every ``models--`` marker is removed and each remaining ``--`` becomes
    ``/``, e.g. ``models--runwayml--stable-diffusion-v1-5`` turns into
    ``runwayml/stable-diffusion-v1-5``.
    """
    return name.replace("models--", "").replace("--", "/")


def _scan_models(cache_dir, class_name: str) -> List[str]:
    """List cached models whose ``model_index.json`` names ``class_name``.

    Args:
        cache_dir: Directory searched recursively for ``model_index.json``.
        class_name: Value the file's ``_class_name`` entry must equal.

    Returns:
        Display names in discovery order, without duplicates. Each name is
        derived from the folder three levels above the matching index file.
    """
    # A dict keeps insertion order and gives O(1) duplicate detection.
    found: Dict[str, None] = {}
    for index_file in Path(cache_dir).glob("**/*/model_index.json"):
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # An unreadable or half-written index file (json.JSONDecodeError
            # and UnicodeDecodeError are ValueErrors) must not abort the
            # scan of every other model in the cache.
            continue

        # Valid JSON is not necessarily an object, and an object is not
        # guaranteed to carry "_class_name"; neither case is a match.
        if not isinstance(data, dict) or data.get("_class_name") != class_name:
            continue

        show_name = folder_name_to_show_name(index_file.parent.parent.parent.name)
        found.setdefault(show_name)
    return list(found)


def scan_models(cache_dir) -> List[str]:
    """Return display names of cached ``StableDiffusionPipeline`` models."""
    return _scan_models(cache_dir, "StableDiffusionPipeline")


def scan_inpainting_models(cache_dir) -> List[str]:
    """Return display names of cached ``StableDiffusionInpaintPipeline`` models."""
    return _scan_models(cache_dir, "StableDiffusionInpaintPipeline")

24
lama_cleaner/download.py Normal file
View File

@ -0,0 +1,24 @@
import os
from loguru import logger
from pathlib import Path
def cli_download_model(model: str, model_dir: str):
    """Handle the ``download`` CLI sub-command.

    Args:
        model: Key looked up in ``lama_cleaner.model_manager.models``; any
            other value is only reported as a Huggingface model id.
        model_dir: Cache directory; created (with parents) when missing and
            exported as ``XDG_CACHE_HOME``.

    Raises:
        ValueError: If ``model_dir`` points at an existing regular file.
    """
    if os.path.isfile(model_dir):
        raise ValueError(f"invalid --model-dir: {model_dir} is a file")

    cache_dir_missing = not os.path.exists(model_dir)
    if cache_dir_missing:
        logger.info(f"Create model cache directory: {model_dir}")
        Path(model_dir).mkdir(exist_ok=True, parents=True)

    # Exported before the import below on purpose.
    # NOTE(review): presumably so libraries pulled in by model_manager pick
    # up the cache location at import time - confirm.
    os.environ["XDG_CACHE_HOME"] = model_dir

    from lama_cleaner.model_manager import models

    if model not in models:
        # Unknown names are only logged here; no download call is made.
        logger.info(f"Downloading model from Huggingface: {model}")
        return

    logger.info(f"Downloading {model}...")
    models[model].download()
    logger.info("Done.")

View File

@ -51,6 +51,10 @@ class InpaintModel:
"""
...
# Placeholder download hook: the body is just `...`, so calling it on this
# class does nothing.
# NOTE(review): concrete models appear to supply their own download() with
# real fetch logic - confirm against the subclasses.
@staticmethod
def download():
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(

View File

@ -14,6 +14,7 @@ from lama_cleaner.helper import (
norm_img,
boxes_from_mask,
resize_max_size,
download_model,
)
from lama_cleaner.model.base import InpaintModel
from torch import conv2d, nn
@ -870,7 +871,6 @@ class SpectralTransform(nn.Module):
)
def forward(self, x):
x = self.downsample(x)
x = self.conv1(x)
output = self.fu(x)
@ -1437,7 +1437,6 @@ class SynthesisNetwork(torch.nn.Module):
setattr(self, f"b{res}", block)
def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
img = None
x, img = self.foreword(x_global, ws, feats, img)
@ -1656,6 +1655,10 @@ class FcF(InpaintModel):
self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5)
self.label = torch.zeros([1, self.model.c_dim], device=device)
# Request the FcF checkpoint: passes FCF_MODEL_URL and its expected MD5
# (FCF_MODEL_MD5) to the download_model helper imported from
# lama_cleaner.helper. Takes no arguments and returns nothing.
@staticmethod
def download():
download_model(FCF_MODEL_URL, FCF_MODEL_MD5)
# Report whether the FcF checkpoint is present: True when the path that
# get_cache_path_by_url derives from FCF_MODEL_URL exists on disk.
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))

View File

@ -4,7 +4,6 @@ import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
@ -15,18 +14,21 @@ class InstructPix2Pix(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
from diffusers import StableDiffusionInstructPix2PixPipeline
fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
fp16 = not kwargs.get("no_half", False)
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
))
requires_safety_checker=False,
)
)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix",
@ -36,15 +38,23 @@ class InstructPix2Pix(DiffusionInpaintModel):
)
self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and use_gpu:
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
@staticmethod
def download():
    """Load the ``timbrooks/instruct-pix2pix`` pipeline at revision ``fp16``.

    The pipeline object is discarded.
    NOTE(review): presumably called only for the side effect of fetching
    the files into the local cache - confirm.
    """
    from diffusers import StableDiffusionInstructPix2PixPipeline as pipeline_cls

    hub_options = {"revision": "fp16"}
    pipeline_cls.from_pretrained("timbrooks/instruct-pix2pix", **hub_options)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -60,7 +70,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.p2p_guidance_scale,
output_type="np",
generator=torch.manual_seed(config.sd_seed)
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -76,3 +76,11 @@ class Kandinsky(DiffusionInpaintModel):
class Kandinsky22(Kandinsky):
    """Kandinsky variant identified by the name ``kandinsky2.2``."""

    name = "kandinsky2.2"
    model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"

    @classmethod
    def download(cls):
        """Load the pipeline named by ``cls.model_name``; the result is discarded.

        Still callable as ``Kandinsky22.download()`` with no arguments.
        """
        from diffusers import AutoPipelineForInpainting

        # Read the repo id from the class attribute instead of repeating the
        # literal, so the two cannot drift apart.
        AutoPipelineForInpainting.from_pretrained(cls.model_name)

View File

@ -8,6 +8,7 @@ from lama_cleaner.helper import (
norm_img,
get_cache_path_by_url,
load_jit_model,
download_model,
)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
@ -23,6 +24,10 @@ class LaMa(InpaintModel):
name = "lama"
pad_mod = 8
# Request the LaMa checkpoint: passes LAMA_MODEL_URL and its expected MD5
# (LAMA_MODEL_MD5) to the download_model helper imported from
# lama_cleaner.helper. Takes no arguments and returns nothing.
@staticmethod
def download():
download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5)
def init_model(self, device, **kwargs):
self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

View File

@ -260,6 +260,12 @@ class LDM(InpaintModel):
self.model = LatentDiffusion(self.diffusion_model, device)
@staticmethod
def download():
    """Request the three LDM checkpoints one after another.

    ``download_model`` is called once per (URL, MD5) pair, in the order
    diffusion, decode, encode.
    """
    checkpoints = (
        (LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5),
        (LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5),
        (LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5),
    )
    for url, md5 in checkpoints:
        download_model(url, md5)
@staticmethod
def is_downloaded() -> bool:
model_paths = [

View File

@ -7,7 +7,7 @@ import torch
import time
from loguru import logger
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
@ -42,6 +42,11 @@ class Manga(InpaintModel):
)
self.seed = 42
@staticmethod
def download():
    """Request both Manga checkpoints: the inpaintor first, then the line model.

    ``download_model`` is called once per (URL, MD5) pair.
    """
    checkpoints = (
        (MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5),
        (MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5),
    )
    for url, md5 in checkpoints:
        download_model(url, md5)
@staticmethod
def is_downloaded() -> bool:
model_paths = [

View File

@ -8,7 +8,12 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
from lama_cleaner.helper import (
load_model,
get_cache_path_by_url,
norm_img,
download_model,
)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.utils import (
setup_filter,
@ -1898,6 +1903,10 @@ class MAT(InpaintModel):
self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
# fmt: on
# Request the MAT checkpoint: passes MAT_MODEL_URL and its expected MD5
# (MAT_MODEL_MD5) to the download_model helper imported from
# lama_cleaner.helper. Takes no arguments and returns nothing.
@staticmethod
def download():
download_model(MAT_MODEL_URL, MAT_MODEL_MD5)
# Report whether the MAT checkpoint is present: True when the path that
# get_cache_path_by_url derives from MAT_MODEL_URL exists on disk.
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))

View File

@ -2,11 +2,9 @@ import PIL
import PIL.Image
import cv2
import torch
from diffusers import DiffusionPipeline
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
@ -16,35 +14,40 @@ class PaintByExample(DiffusionInpaintModel):
min_size = 512
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get('no_half', False)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
from diffusers import DiffusionPipeline
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
fp16 = not kwargs.get("no_half", False)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Paint By Example Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
requires_safety_checker=False
))
model_kwargs.update(
dict(safety_checker=None, requires_safety_checker=False)
)
self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype,
**model_kwargs
"Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
)
self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
# TODO: gpu_id
if kwargs.get('cpu_offload', False) and use_gpu:
if kwargs.get("cpu_offload", False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
@staticmethod
def download():
    """Load the ``Fantasy-Studio/Paint-by-Example`` pipeline and discard it.

    NOTE(review): presumably called only for the side effect of fetching
    the files into the local cache - confirm.
    """
    from diffusers import DiffusionPipeline as pipeline_cls

    repo_id = "Fantasy-Studio/Paint-by-Example"
    pipeline_cls.from_pretrained(repo_id)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -56,8 +59,8 @@ class PaintByExample(DiffusionInpaintModel):
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image,
num_inference_steps=config.paint_by_example_steps,
output_type='np.array',
generator=torch.manual_seed(config.paint_by_example_seed)
output_type="np.array",
generator=torch.manual_seed(config.paint_by_example_seed),
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -132,6 +132,12 @@ class SD(DiffusionInpaintModel):
# model will be downloaded when app start, and can't switch in frontend settings
return True
@classmethod
def download(cls):
    """Load the inpaint pipeline named by ``cls.model_id_or_path``.

    The pipeline object is discarded. Because the id is read from the
    class, each subclass uses its own ``model_id_or_path``.
    """
    from diffusers import StableDiffusionInpaintPipeline as pipeline_cls

    target = cls.model_id_or_path
    pipeline_cls.from_pretrained(target)
class SD15(SD):
name = "sd1.5"

View File

@ -5,7 +5,6 @@ import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
@ -51,6 +50,14 @@ class SDXL(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
@staticmethod
def download():
    """Load the SDXL inpainting pipeline from its hub id and discard it.

    NOTE(review): presumably called only for the side effect of fetching
    the files into the local cache - confirm.
    """
    from diffusers import AutoPipelineForInpainting as pipeline_cls

    repo_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    pipeline_cls.from_pretrained(repo_id)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -85,7 +92,6 @@ class SDXL(DiffusionInpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings

View File

@ -5,7 +5,7 @@ import cv2
import torch
import torch.nn.functional as F
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
from lama_cleaner.schema import Config
import numpy as np
@ -171,14 +171,19 @@ def load_image(img, mask, device, sigma256=3.0):
try:
import skimage
gray_256 = skimage.color.rgb2gray(img_256)
edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float)
# cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
# cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
except:
gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256)
edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2))
gray_256_blured = cv2.GaussianBlur(
gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
)
edge_256 = cv2.Canny(
gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
)
# cv2.imwrite("opencv_edge.jpg", edge_256)
@ -233,12 +238,27 @@ class ZITS(InpaintModel):
self.sample_edge_line_iterations = 1
def init_model(self, device, **kwargs):
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5)
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5)
self.wireframe = load_jit_model(
ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
)
self.edge_line = load_jit_model(
ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
)
self.structure_upsample = load_jit_model(
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
)
self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5)
self.inpaint = load_jit_model(
ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
)
@staticmethod
def download():
    """Request the four ZITS checkpoints one after another.

    ``download_model`` is called once per (URL, MD5) pair, in the order
    wireframe, edge-line, structure-upsample, inpaint.
    """
    checkpoints = (
        (ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5),
        (ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5),
        (ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5),
        (ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5),
    )
    for url, md5 in checkpoints:
        download_model(url, md5)
@staticmethod
def is_downloaded() -> bool:
@ -385,12 +405,20 @@ class ZITS(InpaintModel):
if score > mask_th:
try:
import skimage
rr, cc, value = skimage.draw.line_aa(
*to_int(line[0:2]), *to_int(line[2:4])
)
lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
except:
cv2.line(lmap, to_int(line[0:2][::-1]), to_int(line[2:4][::-1]), (1, 1, 1), 1, cv2.LINE_AA)
cv2.line(
lmap,
to_int(line[0:2][::-1]),
to_int(line[2:4][::-1]),
(1, 1, 1),
1,
cv2.LINE_AA,
)
lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
lines_tensor.append(to_tensor(lmap).unsqueeze(0))

View File

@ -6,13 +6,29 @@ from pathlib import Path
from loguru import logger
from lama_cleaner.const import *
from lama_cleaner.download import cli_download_model
from lama_cleaner.runtime import dump_environment_info
DOWNLOAD_SUBCOMMAND = "download"
def download_parse_args(parser):
    """Register the ``download`` sub-command on ``parser``.

    Creates a sub-parser group whose selected command name is stored in
    ``args.subcommand`` and attaches the ``download`` command to it.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.
    """
    subparsers = parser.add_subparsers(dest="subcommand")
    download_parser = subparsers.add_parser(
        DOWNLOAD_SUBCOMMAND, help="Download models"
    )
    download_parser.add_argument(
        "--model",
        type=str,
        # Without a model there is nothing to download: fail at parse time
        # with a usage error instead of handing None to the download code.
        required=True,
        help="Erase model name(lama/mat...) or model id on huggingface",
    )
    download_parser.add_argument(
        "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP
    )
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
download_parse_args(parser)
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
@ -166,9 +182,12 @@ def parse_args():
)
args = parser.parse_args()
# collect system info to help debug
dump_environment_info()
if args.subcommand == DOWNLOAD_SUBCOMMAND:
cli_download_model(args.model, args.model_dir)
return
if args.install_plugins_package:
from lama_cleaner.installer import install_plugins_package

View File

@ -185,7 +185,7 @@ def main(config_file: str):
)
sd_controlnet_method = gr.Radio(
SD_CONTROLNET_CHOICES,
lable="ControlNet method",
label="ControlNet method",
value=init_config.sd_controlnet_method,
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")