This commit is contained in:
Qing 2023-12-15 12:40:29 +08:00
parent 142aa64cc6
commit cbe6577890
9 changed files with 35 additions and 16 deletions

View File

@ -14,7 +14,7 @@ from lama_cleaner.helper import (
) )
from lama_cleaner.model.helper.g_diffuser_bot import expand_image from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
class InpaintModel: class InpaintModel:
@ -266,9 +266,8 @@ class InpaintModel:
class DiffusionInpaintModel(InpaintModel): class DiffusionInpaintModel(InpaintModel):
def __init__(self, device, **kwargs): def __init__(self, device, **kwargs):
if kwargs.get("model_id_or_path"): self.model_info: ModelInfo = kwargs["model_info"]
# 用于自定义 diffusers 模型 self.model_id_or_path = self.model_info.path
self.model_id_or_path = kwargs["model_id_or_path"]
super().__init__(device, **kwargs) super().__init__(device, **kwargs)
@torch.no_grad() @torch.no_grad()

View File

@ -75,6 +75,11 @@ class ControlNet(DiffusionInpaintModel):
sd_controlnet_method, torch_dtype=torch_dtype sd_controlnet_method, torch_dtype=torch_dtype
) )
if model_info.is_single_file_diffusers: if model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
self.model = PipeClass.from_single_file( self.model = PipeClass.from_single_file(
model_info.path, controlnet=controlnet model_info.path, controlnet=controlnet
).to(torch_dtype) ).to(torch_dtype)

View File

@ -8,7 +8,7 @@ from lama_cleaner.schema import Config
class InstructPix2Pix(DiffusionInpaintModel): class InstructPix2Pix(DiffusionInpaintModel):
name = "instruct_pix2pix" name = "timbrooks/instruct-pix2pix"
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512

View File

@ -73,7 +73,7 @@ class Kandinsky(DiffusionInpaintModel):
class Kandinsky22(Kandinsky): class Kandinsky22(Kandinsky):
name = "kandinsky2.2" name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@staticmethod @staticmethod

View File

@ -9,7 +9,7 @@ from lama_cleaner.schema import Config
class PaintByExample(DiffusionInpaintModel): class PaintByExample(DiffusionInpaintModel):
name = "paint_by_example" name = "Fantasy-Studio/Paint-by-Example"
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512

View File

@ -9,7 +9,7 @@ from loguru import logger
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.schema import Config from lama_cleaner.schema import Config, ModelType
class SD(DiffusionInpaintModel): class SD(DiffusionInpaintModel):
@ -36,7 +36,12 @@ class SD(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
if os.path.isfile(self.model_id_or_path): if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
self.model = StableDiffusionInpaintPipeline.from_single_file( self.model = StableDiffusionInpaintPipeline.from_single_file(
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
) )

View File

@ -8,7 +8,7 @@ from diffusers import AutoencoderKL
from loguru import logger from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config from lama_cleaner.schema import Config, ModelType
class SDXL(DiffusionInpaintModel): class SDXL(DiffusionInpaintModel):
@ -26,9 +26,16 @@ class SDXL(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
num_in_channels = 4
else:
num_in_channels = 9
if os.path.isfile(self.model_id_or_path): if os.path.isfile(self.model_id_or_path):
self.model = StableDiffusionXLInpaintPipeline.from_single_file( self.model = StableDiffusionXLInpaintPipeline.from_single_file(
self.model_id_or_path, torch_dtype=torch_dtype self.model_id_or_path,
torch_dtype=torch_dtype,
num_in_channels=num_in_channels,
) )
else: else:
vae = AutoencoderKL.from_pretrained( vae = AutoencoderKL.from_pretrained(

View File

@ -27,25 +27,26 @@ class ModelManager:
if name not in self.available_models: if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}") raise NotImplementedError(f"Unsupported model: {name}")
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
model_info = self.available_models[name] model_info = self.available_models[name]
kwargs = {**kwargs, "model_info": model_info}
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]: if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
return models[name](device, **kwargs) return models[name](device, **kwargs)
if sd_controlnet_enabled: if sd_controlnet_enabled:
return ControlNet(device, **{**kwargs, "model_info": model_info}) return ControlNet(device, **kwargs)
else: else:
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
]: ]:
return SD(device, model_id_or_path=model_info.path, **kwargs) return SD(device, **kwargs)
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL, ModelType.DIFFUSERS_SDXL,
]: ]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs) return SDXL(device, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}") raise NotImplementedError(f"Unsupported model: {name}")

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import hashlib import hashlib
import traceback
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -453,7 +454,8 @@ def switch_model():
try: try:
model.switch(new_name) model.switch(new_name)
except Exception as e: except Exception as e:
error_message = str(e) traceback.print_exc()
error_message = f"{type(e).__name__} - {str(e)}"
logger.error(error_message) logger.error(error_message)
return f"Switch model failed: {error_message}", 500 return f"Switch model failed: {error_message}", 500
return f"ok, switch to {new_name}", 200 return f"ok, switch to {new_name}", 200