update
This commit is contained in:
parent
142aa64cc6
commit
cbe6577890
@ -14,7 +14,7 @@ from lama_cleaner.helper import (
|
|||||||
)
|
)
|
||||||
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
||||||
from lama_cleaner.model.utils import get_scheduler
|
from lama_cleaner.model.utils import get_scheduler
|
||||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
class InpaintModel:
|
class InpaintModel:
|
||||||
@ -266,9 +266,8 @@ class InpaintModel:
|
|||||||
|
|
||||||
class DiffusionInpaintModel(InpaintModel):
|
class DiffusionInpaintModel(InpaintModel):
|
||||||
def __init__(self, device, **kwargs):
|
def __init__(self, device, **kwargs):
|
||||||
if kwargs.get("model_id_or_path"):
|
self.model_info: ModelInfo = kwargs["model_info"]
|
||||||
# 用于自定义 diffusers 模型
|
self.model_id_or_path = self.model_info.path
|
||||||
self.model_id_or_path = kwargs["model_id_or_path"]
|
|
||||||
super().__init__(device, **kwargs)
|
super().__init__(device, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -75,6 +75,11 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
sd_controlnet_method, torch_dtype=torch_dtype
|
sd_controlnet_method, torch_dtype=torch_dtype
|
||||||
)
|
)
|
||||||
if model_info.is_single_file_diffusers:
|
if model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
self.model = PipeClass.from_single_file(
|
self.model = PipeClass.from_single_file(
|
||||||
model_info.path, controlnet=controlnet
|
model_info.path, controlnet=controlnet
|
||||||
).to(torch_dtype)
|
).to(torch_dtype)
|
||||||
|
@ -8,7 +8,7 @@ from lama_cleaner.schema import Config
|
|||||||
|
|
||||||
|
|
||||||
class InstructPix2Pix(DiffusionInpaintModel):
|
class InstructPix2Pix(DiffusionInpaintModel):
|
||||||
name = "instruct_pix2pix"
|
name = "timbrooks/instruct-pix2pix"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ class Kandinsky(DiffusionInpaintModel):
|
|||||||
|
|
||||||
|
|
||||||
class Kandinsky22(Kandinsky):
|
class Kandinsky22(Kandinsky):
|
||||||
name = "kandinsky2.2"
|
name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||||
model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -9,7 +9,7 @@ from lama_cleaner.schema import Config
|
|||||||
|
|
||||||
|
|
||||||
class PaintByExample(DiffusionInpaintModel):
|
class PaintByExample(DiffusionInpaintModel):
|
||||||
name = "paint_by_example"
|
name = "Fantasy-Studio/Paint-by-Example"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config, ModelType
|
||||||
|
|
||||||
|
|
||||||
class SD(DiffusionInpaintModel):
|
class SD(DiffusionInpaintModel):
|
||||||
@ -36,7 +36,12 @@ class SD(DiffusionInpaintModel):
|
|||||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
|
||||||
if os.path.isfile(self.model_id_or_path):
|
if self.model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
||||||
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
|
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
|
||||||
)
|
)
|
||||||
|
@ -8,7 +8,7 @@ from diffusers import AutoencoderKL
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config, ModelType
|
||||||
|
|
||||||
|
|
||||||
class SDXL(DiffusionInpaintModel):
|
class SDXL(DiffusionInpaintModel):
|
||||||
@ -26,9 +26,16 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
|
||||||
|
num_in_channels = 4
|
||||||
|
else:
|
||||||
|
num_in_channels = 9
|
||||||
|
|
||||||
if os.path.isfile(self.model_id_or_path):
|
if os.path.isfile(self.model_id_or_path):
|
||||||
self.model = StableDiffusionXLInpaintPipeline.from_single_file(
|
self.model = StableDiffusionXLInpaintPipeline.from_single_file(
|
||||||
self.model_id_or_path, torch_dtype=torch_dtype
|
self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
num_in_channels=num_in_channels,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
|
@ -27,25 +27,26 @@ class ModelManager:
|
|||||||
if name not in self.available_models:
|
if name not in self.available_models:
|
||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
|
|
||||||
model_info = self.available_models[name]
|
model_info = self.available_models[name]
|
||||||
|
kwargs = {**kwargs, "model_info": model_info}
|
||||||
|
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
|
||||||
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
||||||
return models[name](device, **kwargs)
|
return models[name](device, **kwargs)
|
||||||
|
|
||||||
if sd_controlnet_enabled:
|
if sd_controlnet_enabled:
|
||||||
return ControlNet(device, **{**kwargs, "model_info": model_info})
|
return ControlNet(device, **kwargs)
|
||||||
else:
|
else:
|
||||||
if model_info.model_type in [
|
if model_info.model_type in [
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
ModelType.DIFFUSERS_SD,
|
ModelType.DIFFUSERS_SD,
|
||||||
]:
|
]:
|
||||||
return SD(device, model_id_or_path=model_info.path, **kwargs)
|
return SD(device, **kwargs)
|
||||||
|
|
||||||
if model_info.model_type in [
|
if model_info.model_type in [
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
ModelType.DIFFUSERS_SDXL,
|
ModelType.DIFFUSERS_SDXL,
|
||||||
]:
|
]:
|
||||||
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
|
return SDXL(device, **kwargs)
|
||||||
|
|
||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import traceback
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
|
||||||
@ -453,7 +454,8 @@ def switch_model():
|
|||||||
try:
|
try:
|
||||||
model.switch(new_name)
|
model.switch(new_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
traceback.print_exc()
|
||||||
|
error_message = f"{type(e).__name__} - {str(e)}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
return f"Switch model failed: {error_message}", 500
|
return f"Switch model failed: {error_message}", 500
|
||||||
return f"ok, switch to {new_name}", 200
|
return f"ok, switch to {new_name}", 200
|
||||||
|
Loading…
Reference in New Issue
Block a user