From cbe6577890bda6b3595551cdfd5a9783deb50edb Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 15 Dec 2023 12:40:29 +0800 Subject: [PATCH] update --- lama_cleaner/model/base.py | 7 +++---- lama_cleaner/model/controlnet.py | 5 +++++ lama_cleaner/model/instruct_pix2pix.py | 2 +- lama_cleaner/model/kandinsky.py | 2 +- lama_cleaner/model/paint_by_example.py | 2 +- lama_cleaner/model/sd.py | 9 +++++++-- lama_cleaner/model/sdxl.py | 11 +++++++++-- lama_cleaner/model_manager.py | 9 +++++---- lama_cleaner/server.py | 4 +++- 9 files changed, 35 insertions(+), 16 deletions(-) diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 193746a..56b572a 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -14,7 +14,7 @@ from lama_cleaner.helper import ( ) from lama_cleaner.model.helper.g_diffuser_bot import expand_image from lama_cleaner.model.utils import get_scheduler -from lama_cleaner.schema import Config, HDStrategy, SDSampler +from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo class InpaintModel: @@ -266,9 +266,8 @@ class InpaintModel: class DiffusionInpaintModel(InpaintModel): def __init__(self, device, **kwargs): - if kwargs.get("model_id_or_path"): - # 用于自定义 diffusers 模型 - self.model_id_or_path = kwargs["model_id_or_path"] + self.model_info: ModelInfo = kwargs["model_info"] + self.model_id_or_path = self.model_info.path super().__init__(device, **kwargs) @torch.no_grad() diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index 485585e..3442e6f 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -75,6 +75,11 @@ class ControlNet(DiffusionInpaintModel): sd_controlnet_method, torch_dtype=torch_dtype ) if model_info.is_single_file_diffusers: + if self.model_info.model_type == ModelType.DIFFUSERS_SD: + model_kwargs["num_in_channels"] = 4 + else: + model_kwargs["num_in_channels"] = 9 + self.model = PipeClass.from_single_file( model_info.path, controlnet=controlnet ).to(torch_dtype) diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index b569092..2b99abb 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -8,7 +8,7 @@ from lama_cleaner.schema import Config class InstructPix2Pix(DiffusionInpaintModel): - name = "instruct_pix2pix" + name = "timbrooks/instruct-pix2pix" pad_mod = 8 min_size = 512 diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 38dfc33..b44f9df 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -73,7 +73,7 @@ class Kandinsky(DiffusionInpaintModel): class Kandinsky22(Kandinsky): - name = "kandinsky2.2" + name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" @staticmethod diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index a9606a4..8237cbd 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -9,7 +9,7 @@ from lama_cleaner.schema import Config class PaintByExample(DiffusionInpaintModel): - name = "paint_by_example" + name = "Fantasy-Studio/Paint-by-Example" pad_mod = 8 min_size = 512 diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index a81a849..48c8f65 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -9,7 +9,7 @@ from loguru import logger from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper -from lama_cleaner.schema import Config +from lama_cleaner.schema import Config, ModelType class SD(DiffusionInpaintModel): @@ -36,7 +36,12 @@ class SD(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - if os.path.isfile(self.model_id_or_path): + if self.model_info.is_single_file_diffusers: + if self.model_info.model_type == ModelType.DIFFUSERS_SD: + model_kwargs["num_in_channels"] = 4 + else: + model_kwargs["num_in_channels"] = 9 + self.model = StableDiffusionInpaintPipeline.from_single_file( self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs ) diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index af04941..d30a22b 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKL from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import Config, ModelType class SDXL(DiffusionInpaintModel): @@ -26,9 +26,16 @@ class SDXL(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + if self.model_info.model_type == ModelType.DIFFUSERS_SDXL: + num_in_channels = 4 + else: + num_in_channels = 9 + if os.path.isfile(self.model_id_or_path): self.model = StableDiffusionXLInpaintPipeline.from_single_file( - self.model_id_or_path, torch_dtype=torch_dtype + self.model_id_or_path, + torch_dtype=torch_dtype, + num_in_channels=num_in_channels, ) else: vae = AutoencoderKL.from_pretrained( diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 6149ecf..f5b75f8 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -27,25 +27,26 @@ class ModelManager: if name not in self.available_models: raise NotImplementedError(f"Unsupported model: {name}") - sd_controlnet_enabled = kwargs.get("sd_controlnet", False) model_info = self.available_models[name] + kwargs = {**kwargs, "model_info": model_info} + sd_controlnet_enabled = kwargs.get("sd_controlnet", False) if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]: return models[name](device, **kwargs) if sd_controlnet_enabled: - return ControlNet(device, **{**kwargs, "model_info": model_info}) + return ControlNet(device, **kwargs) else: if model_info.model_type in [ ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD, ]: - return SD(device, model_id_or_path=model_info.path, **kwargs) + return SD(device, **kwargs) if model_info.model_type in [ ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL, ]: - return SDXL(device, model_id_or_path=model_info.path, **kwargs) + return SDXL(device, **kwargs) raise NotImplementedError(f"Unsupported model: {name}") diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 2922ee3..7da42f2 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os import hashlib +import traceback os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -453,7 +454,8 @@ def switch_model(): try: model.switch(new_name) except Exception as e: - error_message = str(e) + traceback.print_exc() + error_message = f"{type(e).__name__} - {str(e)}" logger.error(error_message) return f"Switch model failed: {error_message}", 500 return f"ok, switch to {new_name}", 200