lots update

This commit is contained in:
  parent 0ba6c121e0
  commit f0b852725f
@@ -4,6 +4,11 @@ from enum import Enum
 from pydantic import BaseModel
 
+DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
+DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
+DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
+DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
+
 MPS_UNSUPPORT_MODELS = [
     "lama",
     "ldm",
@@ -15,22 +20,8 @@ MPS_UNSUPPORT_MODELS = [
 ]
 
 DEFAULT_MODEL = "lama"
-AVAILABLE_MODELS = [
-    "lama",
-    "ldm",
-    "zits",
-    "mat",
-    "fcf",
-    "manga",
-    "cv2",
-]
-
-DIFFUSERS_MODEL_FP16_REVERSION = [
-    "runwayml/stable-diffusion-inpainting",
-    "Sanster/anything-4.0-inpainting",
-    "Sanster/Realistic_Vision_V1.4-inpainting",
-    "stabilityai/stable-diffusion-2-inpainting",
-    "timbrooks/instruct-pix2pix",
-]
+AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
 
 AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
 DEFAULT_DEVICE = "cuda"
@@ -5,23 +5,23 @@ from typing import List
 from loguru import logger
 from pathlib import Path
 
-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
-from lama_cleaner.runtime import setup_model_dir
-from lama_cleaner.schema import (
-    ModelInfo,
-    ModelType,
-    DIFFUSERS_SD_INPAINT_CLASS_NAME,
-    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
-    DIFFUSERS_SD_CLASS_NAME,
-    DIFFUSERS_SDXL_CLASS_NAME,
-)
+from lama_cleaner.const import (
+    DEFAULT_MODEL_DIR,
+    DIFFUSERS_SD_CLASS_NAME,
+    DIFFUSERS_SD_INPAINT_CLASS_NAME,
+    DIFFUSERS_SDXL_CLASS_NAME,
+    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
+)
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
+from lama_cleaner.model_info import ModelInfo, ModelType
+from lama_cleaner.runtime import setup_model_dir
 
 
 def cli_download_model(model: str, model_dir: Path):
     setup_model_dir(model_dir)
     from lama_cleaner.model import models
 
-    if model in models:
+    if model in models and models[model].is_erase_model:
         logger.info(f"Downloading {model}...")
         models[model].download()
         logger.info(f"Done.")
@@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path):
         logger.info(f"Downloading model from Huggingface: {model}")
         from diffusers import DiffusionPipeline
 
-        downloaded_path = DiffusionPipeline.download(
+        downloaded_path = handle_from_pretrained_exceptions(
+            DiffusionPipeline.download,
             pretrained_model_name=model,
-            variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
+            variant="fp16",
             resume_download=True,
         )
         logger.info(f"Done. Downloaded to {downloaded_path}")
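Review note: `handle_from_pretrained_exceptions` replaces the hard-coded `DIFFUSERS_MODEL_FP16_REVERSION` list here and in `controlnet.py` below, but its body is not part of this diff (it is imported from `lama_cleaner.model.utils`). A minimal sketch of what such a wrapper could look like, assuming it simply retries without the fp16 variant when those weights are missing; the real implementation may differ:

    from loguru import logger


    def handle_from_pretrained_exceptions(func, **kwargs):
        # Hypothetical sketch, not the body shipped in this commit.
        try:
            return func(**kwargs)
        except ValueError as e:
            # diffusers raises ValueError when variant="fp16" files are absent;
            # retry with the default ("main") weights instead.
            if kwargs.get("variant") == "fp16" and "variant" in str(e):
                logger.info("fp16 variant not found, falling back to main weights")
                kwargs.pop("variant")
                return func(**kwargs)
            raise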
@@ -43,21 +44,33 @@ def folder_name_to_show_name(name: str) -> str:
 
 def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
     cache_dir = Path(cache_dir)
+    stable_diffusion_dir = cache_dir / "stable_diffusion"
+    stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
     # logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
     res = []
-    for it in cache_dir.glob(f"*.*"):
+    for it in stable_diffusion_dir.glob(f"*.*"):
         if it.suffix not in [".safetensors", ".ckpt"]:
             continue
         if "inpaint" in str(it).lower():
-            if "sdxl" in str(it).lower():
-                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
-            else:
-                model_type = ModelType.DIFFUSERS_SD_INPAINT
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
         else:
-            if "sdxl" in str(it).lower():
-                model_type = ModelType.DIFFUSERS_SDXL
-            else:
-                model_type = ModelType.DIFFUSERS_SD
+            model_type = ModelType.DIFFUSERS_SD
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=str(it.absolute()),
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
+
+    for it in stable_diffusion_xl_dir.glob(f"*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        if "inpaint" in str(it).lower():
+            model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+        else:
+            model_type = ModelType.DIFFUSERS_SDXL
     res.append(
         ModelInfo(
             name=it.name,
@@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]:
             name = folder_name_to_show_name(it.parent.parent.parent.name)
             if name in diffusers_model_names:
                 continue
-            if _class_name == DIFFUSERS_SD_CLASS_NAME:
+            if "PowerPaint" in name:
+                model_type = ModelType.DIFFUSERS_OTHER
+            elif _class_name == DIFFUSERS_SD_CLASS_NAME:
                 model_type = ModelType.DIFFUSERS_SD
             elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
                 model_type = ModelType.DIFFUSERS_SD_INPAINT
@@ -290,3 +290,7 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
         return cv2.drawContours(new_mask, contours, max_index, 255, -1)
     else:
         return mask
+
+
+def is_mac():
+    return sys.platform == "darwin"
@@ -9,6 +9,7 @@ from .mat import MAT
 from .mi_gan import MIGAN
 from .opencv2 import OpenCV2
 from .paint_by_example import PaintByExample
+from .power_paint.power_paint import PowerPaint
 from .sd import SD15, SD2, Anything4, RealisticVision14, SD
 from .sdxl import SDXL
 from .zits import ZITS
@@ -30,4 +31,5 @@ models = {
     InstructPix2Pix.name: InstructPix2Pix,
     Kandinsky22.name: Kandinsky22,
     SDXL.name: SDXL,
+    PowerPaint.name: PowerPaint,
 }
@@ -14,7 +14,7 @@ from lama_cleaner.helper import (
 )
 from lama_cleaner.model.helper.g_diffuser_bot import expand_image
 from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
+from lama_cleaner.schema import Config, HDStrategy, SDSampler
 
 
 class InpaintModel:
@@ -271,7 +271,7 @@ class InpaintModel:
 
 class DiffusionInpaintModel(InpaintModel):
     def __init__(self, device, **kwargs):
-        self.model_info: ModelInfo = kwargs["model_info"]
+        self.model_info = kwargs["model_info"]
         self.model_id_or_path = self.model_info.path
         super().__init__(device, **kwargs)
@@ -5,7 +5,6 @@ import torch
 from diffusers import ControlNetModel, DiffusionPipeline
 from loguru import logger
 
-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
 from lama_cleaner.model.base import DiffusionInpaintModel
 from lama_cleaner.model.helper.controlnet_preprocess import (
     make_canny_control_image,
@@ -14,8 +13,8 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
     make_inpaint_control_image,
 )
 from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
-from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, ModelInfo, ModelType
+from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
+from lama_cleaner.schema import Config, ModelType
 
 
 class ControlNet(DiffusionInpaintModel):
@@ -39,11 +38,11 @@ class ControlNet(DiffusionInpaintModel):
 
     def init_model(self, device: torch.device, **kwargs):
         fp16 = not kwargs.get("no_half", False)
-        model_info: ModelInfo = kwargs["model_info"]
-        sd_controlnet_method = kwargs["sd_controlnet_method"]
+        model_info = kwargs["model_info"]
+        controlnet_method = kwargs["controlnet_method"]
 
         self.model_info = model_info
-        self.sd_controlnet_method = sd_controlnet_method
+        self.controlnet_method = controlnet_method
 
         model_kwargs = {}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
controlnet = ControlNetModel.from_pretrained(
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
|
pretrained_model_name_or_path=controlnet_method,
|
||||||
|
resume_download=True,
|
||||||
)
|
)
|
||||||
if model_info.is_single_file_diffusers:
|
if model_info.is_single_file_diffusers:
|
||||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
@ -88,17 +88,12 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
model_info.path, controlnet=controlnet, **model_kwargs
|
model_info.path, controlnet=controlnet, **model_kwargs
|
||||||
).to(torch_dtype)
|
).to(torch_dtype)
|
||||||
else:
|
else:
|
||||||
self.model = PipeClass.from_pretrained(
|
self.model = handle_from_pretrained_exceptions(
|
||||||
model_info.path,
|
PipeClass.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=model_info.path,
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
revision="fp16"
|
variant="fp16",
|
||||||
if (
|
dtype=torch_dtype,
|
||||||
model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
|
|
||||||
and use_gpu
|
|
||||||
and fp16
|
|
||||||
)
|
|
||||||
else "main",
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -116,23 +111,23 @@ class ControlNet(DiffusionInpaintModel):
         self.callback = kwargs.pop("callback", None)
 
     def switch_controlnet_method(self, new_method: str):
-        self.sd_controlnet_method = new_method
+        self.controlnet_method = new_method
         controlnet = ControlNetModel.from_pretrained(
             new_method, torch_dtype=self.torch_dtype, resume_download=True
         ).to(self.model.device)
         self.model.controlnet = controlnet
 
     def _get_control_image(self, image, mask):
-        if "canny" in self.sd_controlnet_method:
+        if "canny" in self.controlnet_method:
             control_image = make_canny_control_image(image)
-        elif "openpose" in self.sd_controlnet_method:
+        elif "openpose" in self.controlnet_method:
             control_image = make_openpose_control_image(image)
-        elif "depth" in self.sd_controlnet_method:
+        elif "depth" in self.controlnet_method:
             control_image = make_depth_control_image(image)
-        elif "inpaint" in self.sd_controlnet_method:
+        elif "inpaint" in self.controlnet_method:
             control_image = make_inpaint_control_image(image, mask)
         else:
-            raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
+            raise NotImplementedError(f"{self.controlnet_method} not implemented")
         return control_image
 
     def forward(self, image, mask, config: Config):
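Note: the hunks above rename the `sd_controlnet_method` kwarg and attribute to `controlnet_method` throughout `ControlNet`, so existing callers must update the keyword. A hedged construction sketch; the method string and the other kwarg values are illustrative, not taken from this commit:

    import torch

    model = ControlNet(
        device=torch.device("cuda"),
        model_info=model_info,  # assumed: a ModelInfo for the checkpoint to load
        controlnet_method="lllyasviel/control_v11p_sd15_inpaint",  # was sd_controlnet_method
        no_half=False,
        disable_nsfw=True,
    )
    # the renamed attribute is also what switch_controlnet_method updates:
    model.switch_controlnet_method("lllyasviel/control_v11p_sd15_canny")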
@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
         }
 
         self.model = AutoPipelineForInpainting.from_pretrained(
-            self.model_id_or_path, **model_kwargs
+            self.name, **model_kwargs
         ).to(device)
 
         self.callback = kwargs.pop("callback", None)
@@ -66,4 +66,3 @@ class Kandinsky(DiffusionInpaintModel):
 
 class Kandinsky22(Kandinsky):
     name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
-    model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@@ -16,7 +16,7 @@ from lama_cleaner.model.base import InpaintModel
 
 MIGAN_MODEL_URL = os.environ.get(
     "MIGAN_MODEL_URL",
-    "/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt",
+    "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
 )
 MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
 
@@ -28,7 +28,7 @@ class PaintByExample(DiffusionInpaintModel):
         )
 
         self.model = DiffusionPipeline.from_pretrained(
-            "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
+            self.name, torch_dtype=torch_dtype, **model_kwargs
         )
 
         # TODO: gpu_id
@@ -1,3 +0,0 @@
-from .pipeline_stable_diffusion_controlnet_inpaint import (
-    StableDiffusionControlNetInpaintPipeline,
-)
@@ -1,638 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import gc
-from typing import Union, List, Optional, Callable, Dict, Any
-
-# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
-
-import torch
-import PIL.Image
-
-from diffusers.pipelines.controlnet.pipeline_controlnet import *
-from diffusers.utils import replace_example_docstring
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-        ```py
-        >>> # !pip install opencv-python transformers accelerate
-        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
-        >>> from diffusers.utils import load_image
-        >>> import numpy as np
-        >>> import torch
-
-        >>> import cv2
-        >>> from PIL import Image
-        >>> # download an image
-        >>> image = load_image(
-        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-        ... )
-        >>> image = np.array(image)
-        >>> mask_image = load_image(
-        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-        ... )
-        >>> mask_image = np.array(mask_image)
-        >>> # get canny image
-        >>> canny_image = cv2.Canny(image, 100, 200)
-        >>> canny_image = canny_image[:, :, None]
-        >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
-        >>> canny_image = Image.fromarray(canny_image)
-
-        >>> # load control net and stable diffusion v1-5
-        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
-        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
-        ... )
-
-        >>> # speed up diffusion process with faster scheduler and memory optimization
-        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-        >>> # remove following line if xformers is not installed
-        >>> pipe.enable_xformers_memory_efficient_attention()
-
-        >>> pipe.enable_model_cpu_offload()
-
-        >>> # generate image
-        >>> generator = torch.manual_seed(0)
-        >>> image = pipe(
-        ...     "futuristic-looking doggo",
-        ...     num_inference_steps=20,
-        ...     generator=generator,
-        ...     image=image,
-        ...     control_image=canny_image,
-        ...     mask_image=mask_image
-        ... ).images[0]
-        ```
-"""
-
-
-def prepare_mask_and_masked_image(image, mask):
-    """
-    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
-    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
-    ``image`` and ``1`` for the ``mask``.
-    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
-    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
-    Args:
-        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
-            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
-        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
-            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
-    Raises:
-        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
-        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
-        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
-            (ot the other way around).
-    Returns:
-        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
-            dimensions: ``batch x channels x height x width``.
-    """
-    if isinstance(image, torch.Tensor):
-        if not isinstance(mask, torch.Tensor):
-            raise TypeError(
-                f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
-            )
-
-        # Batch single image
-        if image.ndim == 3:
-            assert (
-                image.shape[0] == 3
-            ), "Image outside a batch should be of shape (3, H, W)"
-            image = image.unsqueeze(0)
-
-        # Batch and add channel dim for single mask
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)
-
-        # Batch single mask or add channel dim
-        if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
-            if mask.shape[0] == 1:
-                mask = mask.unsqueeze(0)
-
-            # Batched masks no channel dim
-            else:
-                mask = mask.unsqueeze(1)
-
-        assert (
-            image.ndim == 4 and mask.ndim == 4
-        ), "Image and Mask must have 4 dimensions"
-        assert (
-            image.shape[-2:] == mask.shape[-2:]
-        ), "Image and Mask must have the same spatial dimensions"
-        assert (
-            image.shape[0] == mask.shape[0]
-        ), "Image and Mask must have the same batch size"
-
-        # Check image is in [-1, 1]
-        if image.min() < -1 or image.max() > 1:
-            raise ValueError("Image should be in [-1, 1] range")
-
-        # Check mask is in [0, 1]
-        if mask.min() < 0 or mask.max() > 1:
-            raise ValueError("Mask should be in [0, 1] range")
-
-        # Binarize mask
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-        # Image as float32
-        image = image.to(dtype=torch.float32)
-    elif isinstance(mask, torch.Tensor):
-        raise TypeError(
-            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
-        )
-    else:
-        # preprocess image
-        if isinstance(image, (PIL.Image.Image, np.ndarray)):
-            image = [image]
-
-        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
-            image = [np.array(i.convert("RGB"))[None, :] for i in image]
-            image = np.concatenate(image, axis=0)
-        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
-            image = np.concatenate([i[None, :] for i in image], axis=0)
-
-        image = image.transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-        # preprocess mask
-        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
-            mask = [mask]
-
-        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = np.concatenate(
-                [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
-            )
-            mask = mask.astype(np.float32) / 255.0
-        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
-            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
-
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-        mask = torch.from_numpy(mask)
-
-        masked_image = image * (mask < 0.5)
-
-    return mask, masked_image
-
-
-class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
-    r"""
-    Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
-
-    This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        controlnet ([`ControlNetModel`]):
-            Provides additional conditioning to the unet during the denoising process
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-        safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
-            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
-    """
-
-    @classmethod
-    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
-        from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-            download_from_original_stable_diffusion_ckpt,
-        )
-
-        controlnet = kwargs.pop("controlnet", None)
-
-        pipe = download_from_original_stable_diffusion_ckpt(
-            pretrained_model_link_or_path,
-            num_in_channels=9,
-            from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
-            device="cpu",
-            load_safety_checker=False,
-        )
-
-        inpaint_pipe = cls(
-            vae=pipe.vae,
-            text_encoder=pipe.text_encoder,
-            tokenizer=pipe.tokenizer,
-            unet=pipe.unet,
-            controlnet=controlnet,
-            scheduler=pipe.scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
-
-        del pipe
-        gc.collect()
-        return inpaint_pipe
-
-    def prepare_mask_latents(
-        self,
-        mask,
-        masked_image,
-        batch_size,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        do_classifier_free_guidance,
-    ):
-        # resize the mask to latents shape as we concatenate the mask to the latents
-        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
-        # and half precision
-        mask = torch.nn.functional.interpolate(
-            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
-        )
-        mask = mask.to(device=device, dtype=dtype)
-
-        masked_image = masked_image.to(device=device, dtype=dtype)
-
-        # encode the mask image into latents space so we can concatenate it to the latents
-        if isinstance(generator, list):
-            masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
-                    generator=generator[i]
-                )
-                for i in range(batch_size)
-            ]
-            masked_image_latents = torch.cat(masked_image_latents, dim=0)
-        else:
-            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
-                generator=generator
-            )
-        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
-
-        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        if mask.shape[0] < batch_size:
-            if not batch_size % mask.shape[0] == 0:
-                raise ValueError(
-                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
-                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
-                    " of masks that you pass is divisible by the total requested batch size."
-                )
-            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
-        if masked_image_latents.shape[0] < batch_size:
-            if not batch_size % masked_image_latents.shape[0] == 0:
-                raise ValueError(
-                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
-                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
-                    " Make sure the number of images that you pass is divisible by the total requested batch size."
-                )
-            masked_image_latents = masked_image_latents.repeat(
-                batch_size // masked_image_latents.shape[0], 1, 1, 1
-            )
-
-        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
-        masked_image_latents = (
-            torch.cat([masked_image_latents] * 2)
-            if do_classifier_free_guidance
-            else masked_image_latents
-        )
-
-        # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
-        return mask, masked_image_latents
-
-    def _default_height_width(self, height, width, image):
-        if isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[3]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[2]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
-    @torch.no_grad()
-    @replace_example_docstring(EXAMPLE_DOC_STRING)
-    def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        control_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-        ] = None,
-        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
-    ):
-        r"""
-        Function invoked when calling the pipeline for generation.
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            image (`PIL.Image.Image`):
-                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
-                be masked out with `mask_image` and repainted according to `prompt`.
-            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
-                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
-                also be accepted as an image. The control image is automatically resized to fit the output image.
-            mask_image (`PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
-                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
-                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
-                instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image.
-            num_inference_steps (`int`, *optional*, defaults to 50):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            guidance_scale (`float`, *optional*, defaults to 7.5):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
-                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            eta (`float`, *optional*, defaults to 0.0):
-                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-                [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
-                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet.
-        Examples:
-        Returns:
-            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
-            When returning a tuple, the first element is a list with the generated images, and the second element is a
-            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-            (nsfw) content, according to the `safety_checker`.
-        """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, control_image)
-
-        # 1. Check inputs. Raise error if not correct
-        self.check_inputs(
-            prompt=prompt,
-            image=control_image,
-            callback_steps=callback_steps,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-        )
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        # 3. Encode input prompt
-        prompt_embeds = self._encode_prompt(
-            prompt,
-            device,
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-        )
-
-        # 4. Prepare image
-        control_image = self.prepare_image(
-            control_image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
-
-        if do_classifier_free_guidance:
-            control_image = torch.cat([control_image] * 2)
-
-        # 5. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
-
-        # 6. Prepare latent variables
-        num_channels_latents = self.controlnet.config.in_channels
-        latents = self.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-        )
-
-        # EXTRA: prepare mask latents
-        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
-        mask, masked_image_latents = self.prepare_mask_latents(
-            mask,
-            masked_image,
-            batch_size * num_images_per_prompt,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            do_classifier_free_guidance,
-        )
-
-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                )
-                latent_model_input = self.scheduler.scale_model_input(
-                    latent_model_input, t
-                )
-
-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=control_image,
-                    return_dict=False,
-                )
-
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
-                # predict the noise residual
-                latent_model_input = torch.cat(
-                    [latent_model_input, mask, masked_image_latents], dim=1
-                )
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
-                        noise_pred_text - noise_pred_uncond
-                    )
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs
-                ).prev_sample
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
-
-        # If we do sequential model offloading, let's offload unet and controlnet
-        # manually for max memory savings
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.unet.to("cpu")
-            self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
-
-        if output_type == "latent":
-            image = latents
-            has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
-
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
-
-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
-
-        if not return_dict:
-            return (image, has_nsfw_concept)
-
-        return StableDiffusionPipelineOutput(
-            images=image, nsfw_content_detected=has_nsfw_concept
-        )
lama_cleaner/model/power_paint/__init__.py (new file, 0 lines)

lama_cleaner/model/power_paint/pipeline_powerpaint.py (new file, 1243 lines)
File diff suppressed because it is too large.

lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py (new file, 1775 lines)
File diff suppressed because it is too large.

lama_cleaner/model/power_paint/power_paint.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+from PIL import Image
+import PIL.Image
+import cv2
+import torch
+from loguru import logger
+
+from lama_cleaner.model.base import DiffusionInpaintModel
+from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
+from lama_cleaner.schema import Config
+from .powerpaint_tokenizer import add_task_to_prompt
+
+
+class PowerPaint(DiffusionInpaintModel):
+    name = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
+    pad_mod = 8
+    min_size = 512
+    lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
+
+    def init_model(self, device: torch.device, **kwargs):
+        from .pipeline_powerpaint import StableDiffusionInpaintPipeline
+        from .powerpaint_tokenizer import PowerPaintTokenizer
+
+        fp16 = not kwargs.get("no_half", False)
+        model_kwargs = {}
+        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
+            logger.info("Disable Stable Diffusion Model NSFW checker")
+            model_kwargs.update(
+                dict(
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                )
+            )
+
+        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
+        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
+
+        self.model = handle_from_pretrained_exceptions(
+            StableDiffusionInpaintPipeline.from_pretrained,
+            pretrained_model_name_or_path=self.name,
+            variant="fp16",
+            torch_dtype=torch_dtype,
+            **model_kwargs,
+        )
+        self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
+
+        if kwargs.get("cpu_offload", False) and use_gpu:
+            logger.info("Enable sequential cpu offload")
+            self.model.enable_sequential_cpu_offload(gpu_id=0)
+        else:
+            self.model = self.model.to(device)
+            if kwargs["sd_cpu_textencoder"]:
+                logger.info("Run Stable Diffusion TextEncoder on CPU")
+                self.model.text_encoder = CPUTextEncoderWrapper(
+                    self.model.text_encoder, torch_dtype
+                )
+
+        self.callback = kwargs.pop("callback", None)
+
+    def forward(self, image, mask, config: Config):
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W, 1] 255 means area to repaint
+        return: BGR IMAGE
+        """
+        self.set_scheduler(config)
+
+        img_h, img_w = image.shape[:2]
+        promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
+            config.prompt, config.negative_prompt, config.powerpaint_task
+        )
+
+        output = self.model(
+            image=PIL.Image.fromarray(image),
+            promptA=promptA,
+            promptB=promptB,
+            tradoff=config.fitting_degree,
+            tradoff_nag=config.fitting_degree,
+            negative_promptA=negative_promptA,
+            negative_promptB=negative_promptB,
+            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
+            num_inference_steps=config.sd_steps,
+            strength=config.sd_strength,
+            guidance_scale=config.sd_guidance_scale,
+            output_type="np",
+            callback=self.callback,
+            height=img_h,
+            width=img_w,
+            generator=torch.manual_seed(config.sd_seed),
+            callback_steps=1,
+        ).images[0]
+
+        output = (output * 255).round().astype("uint8")
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+        return output
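With `PowerPaint.name: PowerPaint` registered in the `models` dict (see the `@@ -30,4 +31,5 @@` hunk above), the class is reachable through the registry like every other model. A small sketch grounded in the class attributes added here:

    # Illustrative only: look up the new registry entry added by this commit.
    from lama_cleaner.model import models

    PowerPaintCls = models["Sanster/PowerPaint-V1-stable-diffusion-inpainting"]
    print(PowerPaintCls.min_size)  # 512, the class attribute defined above
    print(PowerPaintCls.pad_mod)   # 8, the class attribute defined above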
lama_cleaner/model/power_paint/powerpaint_tokenizer.py (new file, 540 lines)
@ -0,0 +1,540 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import copy
|
||||||
|
import random
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
from lama_cleaner.schema import PowerPaintTask
|
||||||
|
|
||||||
|
|
||||||
|
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
|
||||||
|
if task == PowerPaintTask.object_remove:
|
||||||
|
promptA = prompt + " P_ctxt"
|
||||||
|
promptB = prompt + " P_ctxt"
|
||||||
|
negative_promptA = negative_prompt + " P_obj"
|
||||||
|
negative_promptB = negative_prompt + " P_obj"
|
||||||
|
elif task == PowerPaintTask.shape_guided:
|
||||||
|
promptA = prompt + " P_shape"
|
||||||
|
promptB = prompt + " P_ctxt"
|
||||||
|
negative_promptA = negative_prompt
|
||||||
|
negative_promptB = negative_prompt
|
||||||
|
elif task == PowerPaintTask.outpainting:
|
||||||
|
promptA = prompt + " P_ctxt"
|
||||||
|
promptB = prompt + " P_ctxt"
|
||||||
|
negative_promptA = negative_prompt + " P_obj"
|
||||||
|
negative_promptB = negative_prompt + " P_obj"
|
||||||
|
else:
|
||||||
|
promptA = prompt + " P_obj"
|
||||||
|
promptB = prompt + " P_obj"
|
||||||
|
negative_promptA = negative_prompt
|
||||||
|
negative_promptB = negative_prompt
|
||||||
|
|
||||||
|
return promptA, promptB, negative_promptA, negative_promptB
|
||||||
|
|
||||||
|
|
||||||
|
class PowerPaintTokenizer:
    def __init__(self, tokenizer: CLIPTokenizer):
        self.wrapped = tokenizer
        self.token_map = {}
        placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
        num_vec_per_token = 10
        for placeholder_token in placeholder_tokens:
            output = []
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                output.append(ith_token)
            self.token_map[placeholder_token] = output

    def __getattr__(self, name: str) -> Any:
        if name == "wrapped":
            return super().__getattr__("wrapped")

        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            try:
                return super().__getattr__(name)
            except AttributeError:
                raise AttributeError(
                    f"'{name}' cannot be found in both "
                    f"'{self.__class__.__name__}' and "
                    f"'{self.__class__.__name__}.tokenizer'."
                )

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in the current tokenizer.
        """
        token_ids = self.__call__(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(
        self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
    ):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            output = []
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token}; "
                    "keep placeholder tokens independent."
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self,
        text: Union[str, List[str]],
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(
                    self.replace_placeholder_tokens_in_text(
                        text[i], vector_shuffle=vector_shuffle
                    )
                )
            return output

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(
        self, text: Union[str, List[str]]
    ) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(self.replace_text_with_placeholder_tokens(text[i]))
            return output

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )

        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token indices.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(
        self, token_ids, return_raw: bool = False, *args, **kwargs
    ) -> Union[str, List[str]]:
        """Decode the token indices to text.

        Args:
            token_ids: The token indices to be decoded.
            return_raw: Whether to keep the placeholder tokens in the text.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        replaced_text = self.replace_text_with_placeholder_tokens(text)
        return replaced_text
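A minimal usage sketch of the wrapper above. The model id is an assumption used only to obtain a CLIP tokenizer; it is not part of this commit:

import copy
import random
from typing import Any, List, Union

from transformers import CLIPTokenizer

base = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative id
tokenizer = PowerPaintTokenizer(base)

# "P_ctxt" is textually expanded to its ten mapped tokens
# ("P_ctxt_0" ... "P_ctxt_9") before the wrapped tokenizer ever runs.
expanded = tokenizer.replace_placeholder_tokens_in_text("a photo of P_ctxt")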
class EmbeddingLayerWithFixes(nn.Module):
    """The revised embedding layer to support external embeddings. The design
    of this class is inspired by
    https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224  # noqa

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(
        self,
        wrapped: nn.Embedding,
        external_embeddings: Optional[Union[dict, List[dict]]] = None,
    ):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        if external_embeddings:
            self.add_embeddings(external_embeddings)

        self.trainable_embeddings = nn.ParameterDict()

    @property
    def weight(self):
        """Get the weight of the wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in the list of
        'external_embeddings'.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether overlaps exist in the token ids of 'external_embeddings'.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # check if 'end' has overlapping
        for idx in range(len(ids_range) - 1):
            name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
            assert ids_range[idx][1] <= ids_range[idx + 1][0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Use case:

        >>> # 1. Add token to tokenizer and get the token id.
        >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
        >>> # 'how much' in kiswahili
        >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
        >>>
        >>> # 2. Add external embeddings to the model.
        >>> new_embedding = {
        >>>     'name': 'ngapi',  # 'how much' in kiswahili
        >>>     'embedding': torch.ones(1, 15) * 4,
        >>>     'start': tokenizer.get_token_info('ngapi')['start'],
        >>>     'end': tokenizer.get_token_info('ngapi')['end'],
        >>>     'trainable': False  # if True, will be registered as a parameter
        >>> }
        >>> embedding_layer = nn.Embedding(10, 15)
        >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
        >>> embedding_layer_wrapper.add_embeddings(new_embedding)
        >>>
        >>> # 3. Forward tokenizer and embedding layer!
        >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
        >>> input_ids = tokenizer(
        >>>     input_text, padding='max_length', truncation=True,
        >>>     return_tensors='pt')['input_ids']
        >>> out_feat = embedding_layer_wrapper(input_ids)
        >>>
        >>> # 4. Let's validate the result!
        >>> assert (out_feat[0, 3:7] == 4).all()
        >>> assert (out_feat[1, 5:9] == 4).all()

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
        """
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        self.external_embeddings += embeddings
        self.check_duplicate_names(self.external_embeddings)
        self.check_ids_overlap(self.external_embeddings)

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            trainable = embedding.get("trainable", False)
            if trainable:
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = [emb["name"] for emb in embeddings]
        added_emb_info = ", ".join(added_emb_info)
        print(f"Successfully add external embeddings: {added_emb_info}.", "current")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print(
                "Successfully add trainable external embeddings: "
                f"{added_trainable_emb_info}",
                "current",
            )

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace external input ids with 0.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace the external embedding in the embedding layer. Note that, in
        this function we use `torch.cat` to avoid inplace modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after the
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.
        """
        new_embedding = []

        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        target_ids_to_replace = [i for i in range(start, end)]
        ext_emb = external_embedding["embedding"]

        # do not need to replace
        if not (input_ids == start).any():
            return embedding

        # start replace
        s_idx, e_idx = 0, 0
        while e_idx < len(input_ids):
            if input_ids[e_idx] == start:
                if e_idx != 0:
                    # add embeddings that do not need to be replaced
                    new_embedding.append(embedding[s_idx:e_idx])

                # check if the next embedding that needs to be replaced is valid
                actually_ids_to_replace = [
                    int(i) for i in input_ids[e_idx : e_idx + end - start]
                ]
                assert actually_ids_to_replace == target_ids_to_replace, (
                    f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
                    f"Expect '{target_ids_to_replace}' for embedding "
                    f"'{name}' but found '{actually_ids_to_replace}'."
                )

                new_embedding.append(ext_emb)

                s_idx = e_idx + end - start
                e_idx = s_idx + 1
            else:
                e_idx += 1

        if e_idx == len(input_ids):
            new_embedding.append(embedding[s_idx:e_idx])

        return torch.cat(new_embedding, dim=0)

    def forward(
        self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
    ):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids, shape like [bz, LENGTH] or
                [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        vecs = []

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(
                    input_id, new_embedding, external_embedding
                )
            vecs.append(new_embedding)

        return torch.stack(vecs)
def add_tokens(
    tokenizer,
    text_encoder,
    placeholder_tokens: list,
    initialize_tokens: list = None,
    num_vectors_per_token: int = 1,
):
    """Add tokens for training.

    # TODO: support adding tokens as a dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_tokens should be the same length as initialize_tokens"
    for ii in range(len(placeholder_tokens)):
        tokenizer.add_placeholder_token(
            placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
        )

    # text_encoder.set_embedding_layer()
    embedding_layer = text_encoder.text_model.embeddings.token_embedding
    text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
        embedding_layer
    )
    embedding_layer = text_encoder.text_model.embeddings.token_embedding

    assert embedding_layer is not None, (
        "Getting the embedding layer is not supported for the current text "
        "encoder. Please check your configuration."
    )
    initialize_embedding = []
    if initialize_tokens is not None:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            initialize_embedding.append(
                temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
            )
    else:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer("a").input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            len_emb = temp_embedding.shape[0]
            init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            initialize_embedding.append(init_weight)

    # initialize_embedding = torch.cat(initialize_embedding, dim=0)

    token_info_all = []
    for ii in range(len(placeholder_tokens)):
        token_info = tokenizer.get_token_info(placeholder_tokens[ii])
        token_info["embedding"] = initialize_embedding[ii]
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
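A hedged sketch of how this helper wires the pieces together at training time. The model id and the placeholder name "P_new" are assumptions for illustration, not values taken from this commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = PowerPaintTokenizer(
    CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative id
)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Registers "P_new_0" ... "P_new_9" in the tokenizer, wraps the token
# embedding in EmbeddingLayerWithFixes, and attaches trainable vectors
# initialized from the embedding of "a".
add_tokens(
    tokenizer,
    text_encoder,
    placeholder_tokens=["P_new"],
    initialize_tokens=["a"],
    num_vectors_per_token=10,
)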
@@ -3,9 +3,9 @@ import cv2
 import torch
 from loguru import logger

-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
 from lama_cleaner.model.base import DiffusionInpaintModel
 from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
 from lama_cleaner.schema import Config, ModelType
@@ -40,20 +40,18 @@ class SD(DiffusionInpaintModel):
             model_kwargs["num_in_channels"] = 9

             self.model = StableDiffusionInpaintPipeline.from_single_file(
-                self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
+                self.model_id_or_path, dtype=torch_dtype, **model_kwargs
             )
         else:
-            self.model = StableDiffusionInpaintPipeline.from_pretrained(
-                self.model_id_or_path,
-                revision="fp16"
-                if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
-                else "main",
-                torch_dtype=torch_dtype,
+            self.model = handle_from_pretrained_exceptions(
+                StableDiffusionInpaintPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
+                variant="fp16",
+                dtype=torch_dtype,
                 **model_kwargs,
             )

         if kwargs.get("cpu_offload", False) and use_gpu:
-            # TODO: gpu_id
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
         else:
@@ -98,20 +96,20 @@ class SD(DiffusionInpaintModel):


 class SD15(SD):
-    name = "sd1.5"
+    name = "runwayml/stable-diffusion-inpainting"
     model_id_or_path = "runwayml/stable-diffusion-inpainting"


 class Anything4(SD):
-    name = "anything4"
+    name = "Sanster/anything-4.0-inpainting"
     model_id_or_path = "Sanster/anything-4.0-inpainting"


 class RealisticVision14(SD):
-    name = "realisticVision1.4"
+    name = "Sanster/Realistic_Vision_V1.4-inpainting"
     model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"


 class SD2(SD):
-    name = "sd2"
+    name = "stabilityai/stable-diffusion-2-inpainting"
     model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
@@ -8,11 +8,12 @@ from diffusers import AutoencoderKL
 from loguru import logger

 from lama_cleaner.model.base import DiffusionInpaintModel
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
 from lama_cleaner.schema import Config, ModelType


 class SDXL(DiffusionInpaintModel):
-    name = "sdxl"
+    name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
     pad_mod = 8
     min_size = 512
     lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
@@ -34,18 +35,19 @@ class SDXL(DiffusionInpaintModel):
         if os.path.isfile(self.model_id_or_path):
             self.model = StableDiffusionXLInpaintPipeline.from_single_file(
                 self.model_id_or_path,
-                torch_dtype=torch_dtype,
+                dtype=torch_dtype,
                 num_in_channels=num_in_channels,
             )
         else:
             vae = AutoencoderKL.from_pretrained(
                 "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
             )
-            self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
-                self.model_id_or_path,
-                revision="main",
+            self.model = handle_from_pretrained_exceptions(
+                StableDiffusionXLInpaintPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
                 torch_dtype=torch_dtype,
                 vae=vae,
+                variant="fp16",
             )

         if kwargs.get("cpu_offload", False) and use_gpu:
@@ -1,6 +1,7 @@
 import gc
 import math
 import random
+import traceback
 from typing import Any

 import torch
@@ -16,8 +17,11 @@ from diffusers import (
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
     UniPCMultistepScheduler,
-    LCMScheduler
+    LCMScheduler,
 )
+from huggingface_hub.utils import RevisionNotFoundError
+from loguru import logger
+from requests import HTTPError

 from lama_cleaner.schema import SDSampler
 from torch import conv2d, conv_transpose2d
@@ -944,3 +948,20 @@ def get_scheduler(sd_sampler, scheduler_config):
         return LCMScheduler.from_config(scheduler_config)
     else:
         raise ValueError(sd_sampler)
+
+
+def handle_from_pretrained_exceptions(func, **kwargs):
+    try:
+        return func(**kwargs)
+    except ValueError as e:
+        # handle the exception: the fp16 variant is missing, so fall back
+        # to the legacy fp16 revision
+        if "You are trying to load the model files of the `variant=fp16`" in str(e):
+            logger.info("variant=fp16 not found, try revision=fp16")
+            return func(**{**kwargs, "variant": None, "revision": "fp16"})
+        raise e
+    except OSError as e:
+        previous_traceback = traceback.format_exc()
+        if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
+            logger.info("revision=fp16 not found, try revision=main")
+            return func(**{**kwargs, "variant": None, "revision": "main"})
+        raise e
+    except Exception as e:
+        raise e
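For context, this mirrors how the SD and SDXL loaders above call the new helper; a hedged example with an illustrative model id:

import torch
from diffusers import StableDiffusionInpaintPipeline

# Tries variant="fp16" first; on the known ValueError it falls back to
# revision="fp16", and on a 404 RevisionNotFoundError it falls back to "main".
pipe = handle_from_pretrained_exceptions(
    StableDiffusionInpaintPipeline.from_pretrained,
    pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=torch.float16,
)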
lama_cleaner/model_info.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from enum import Enum
from typing import List

from pydantic import computed_field, BaseModel

from lama_cleaner.const import (
    SDXL_CONTROLNET_CHOICES,
    SD2_CONTROLNET_CHOICES,
    SD_CONTROLNET_CHOICES,
)
from lama_cleaner.model import InstructPix2Pix, Kandinsky22, PowerPaint, SD2
from lama_cleaner.schema import ModelType


class ModelInfo(BaseModel):
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    @computed_field
    @property
    def need_prompt(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [
            InstructPix2Pix.name,
            Kandinsky22.name,
            PowerPaint.name,
        ]

    @computed_field
    @property
    def controlnets(self) -> List[str]:
        if self.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return SDXL_CONTROLNET_CHOICES
        if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
            if self.name in [SD2.name]:
                return SD2_CONTROLNET_CHOICES
            else:
                return SD_CONTROLNET_CHOICES
        if self.name == PowerPaint.name:
            return SD_CONTROLNET_CHOICES
        return []

    @computed_field
    @property
    def support_strength(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_outpainting(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [Kandinsky22.name, PowerPaint.name]

    @computed_field
    @property
    def support_lcm_lora(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    @computed_field
    @property
    def support_controlnet(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [PowerPaint.name]

    @computed_field
    @property
    def support_freeu(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ] or self.name in [InstructPix2Pix.name]
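A small sketch of how the computed fields behave. Since pydantic v2 serializes @computed_field properties through model_dump(), the capability flags travel with the model info sent to the frontend; the model id below is illustrative:

info = ModelInfo(
    name="runwayml/stable-diffusion-inpainting",
    path="runwayml/stable-diffusion-inpainting",
    model_type=ModelType.DIFFUSERS_SD_INPAINT,
)
assert info.support_controlnet and info.need_prompt
assert "support_lcm_lora" in info.model_dump()  # flags are serialized too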
@@ -7,7 +7,8 @@ from lama_cleaner.download import scan_models
 from lama_cleaner.helper import switch_mps_device
 from lama_cleaner.model import models, ControlNet, SD, SDXL
 from lama_cleaner.model.utils import torch_gc
-from lama_cleaner.schema import Config, ModelInfo, ModelType
+from lama_cleaner.model_info import ModelInfo, ModelType
+from lama_cleaner.schema import Config


 class ModelManager:
@@ -18,13 +19,20 @@ class ModelManager:
         self.available_models: Dict[str, ModelInfo] = {}
         self.scan_models()

-        self.sd_controlnet = False
-        self.sd_controlnet_method = ""
+        self.enable_controlnet = kwargs.get("enable_controlnet", False)
+        controlnet_method = kwargs.get("controlnet_method", None)
+        if (
+            controlnet_method is None
+            and name in self.available_models
+            and self.available_models[name].support_controlnet
+        ):
+            controlnet_method = self.available_models[name].controlnets[0]
+        self.controlnet_method = controlnet_method
         self.model = self.init_model(name, device, **kwargs)

     @property
     def current_model(self) -> Dict:
-        return self.available_models[name].model_dump()
+        return self.available_models[self.name].model_dump()

     def init_model(self, name: str, device, **kwargs):
         logger.info(f"Loading model: {name}")
@@ -35,15 +43,14 @@ class ModelManager:
         kwargs = {
             **kwargs,
             "model_info": model_info,
-            "sd_controlnet": self.sd_controlnet,
-            "sd_controlnet_method": self.sd_controlnet_method,
+            "enable_controlnet": self.enable_controlnet,
+            "controlnet_method": self.controlnet_method,
         }

-        if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
-            return models[name](device, **kwargs)
-
-        if self.sd_controlnet:
+        if model_info.support_controlnet and self.enable_controlnet:
             return ControlNet(device, **kwargs)
+        elif model_info.name in models:
+            return models[name](device, **kwargs)
         else:
             if model_info.model_type in [
                 ModelType.DIFFUSERS_SD_INPAINT,
@@ -75,15 +82,15 @@ class ModelManager:
             return

         old_name = self.name
-        old_sd_controlnet_method = self.sd_controlnet_method
+        old_controlnet_method = self.controlnet_method
         self.name = new_name

         if (
             self.available_models[new_name].support_controlnet
-            and self.sd_controlnet_method
+            and self.controlnet_method
             not in self.available_models[new_name].controlnets
         ):
-            self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
+            self.controlnet_method = self.available_models[new_name].controlnets[0]
         try:
             # TODO: enable/disable controlnet without reload model
             del self.model
@@ -94,7 +101,7 @@ class ModelManager:
             )
         except Exception as e:
             self.name = old_name
-            self.sd_controlnet_method = old_sd_controlnet_method
+            self.controlnet_method = old_controlnet_method
             logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
             self.model = self.init_model(
                 old_name, switch_mps_device(old_name, self.device), **self.kwargs
@@ -106,24 +113,24 @@ class ModelManager:
             return

         if (
-            self.sd_controlnet
+            self.enable_controlnet
             and config.controlnet_method
-            and self.sd_controlnet_method != config.controlnet_method
+            and self.controlnet_method != config.controlnet_method
         ):
-            old_sd_controlnet_method = self.sd_controlnet_method
-            self.sd_controlnet_method = config.controlnet_method
+            old_controlnet_method = self.controlnet_method
+            self.controlnet_method = config.controlnet_method
             self.model.switch_controlnet_method(config.controlnet_method)
             logger.info(
-                f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
+                f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
             )
-        elif self.sd_controlnet != config.controlnet_enabled:
-            self.sd_controlnet = config.controlnet_enabled
-            self.sd_controlnet_method = config.controlnet_method
+        elif self.enable_controlnet != config.enable_controlnet:
+            self.enable_controlnet = config.enable_controlnet
+            self.controlnet_method = config.controlnet_method

             self.model = self.init_model(
                 self.name, switch_mps_device(self.name, self.device), **self.kwargs
             )
-            if not config.controlnet_enabled:
+            if not config.enable_controlnet:
                 logger.info(f"Disable controlnet")
             else:
                 logger.info(f"Enable controlnet: {config.controlnet_method}")
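A hedged sketch of the renamed ControlNet flags from the caller's side, assuming the constructor keeps its name/device signature; the argument values are illustrative only:

manager = ModelManager(
    name="lama",
    device="cpu",  # or a torch.device, depending on the caller
    enable_controlnet=False,
)
print(manager.current_model)  # serialized ModelInfo for the active model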
@@ -1,19 +1,8 @@
-from typing import Optional, List
 from enum import Enum
+from typing import Optional

 from PIL.Image import Image
-from pydantic import BaseModel, computed_field
+from pydantic import BaseModel

-from lama_cleaner.const import (
-    SDXL_CONTROLNET_CHOICES,
-    SD2_CONTROLNET_CHOICES,
-    SD_CONTROLNET_CHOICES,
-)
-
-DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
-DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
-DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
-DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"

 class ModelType(str, Enum):
@@ -25,103 +14,6 @@ class ModelType(str, Enum):
     DIFFUSERS_OTHER = "diffusers_other"


-FREEU_DEFAULT_CONFIGS = {
-    ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
-    ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
-}
-
-
-class ModelInfo(BaseModel):
-    name: str
-    path: str
-    model_type: ModelType
-    is_single_file_diffusers: bool = False
-
-    @computed_field
-    @property
-    def need_prompt(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SD_INPAINT,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ] or self.name in [
-            "timbrooks/instruct-pix2pix",
-            "kandinsky-community/kandinsky-2-2-decoder-inpaint",
-        ]
-
-    @computed_field
-    @property
-    def controlnets(self) -> List[str]:
-        if self.model_type in [
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ]:
-            return SDXL_CONTROLNET_CHOICES
-        if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
-            if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
-                return SD2_CONTROLNET_CHOICES
-            else:
-                return SD_CONTROLNET_CHOICES
-        return []
-
-    @computed_field
-    @property
-    def support_strength(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SD_INPAINT,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ]
-
-    @computed_field
-    @property
-    def support_outpainting(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SD_INPAINT,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ] or self.name in [
-            "kandinsky-community/kandinsky-2-2-decoder-inpaint",
-        ]
-
-    @computed_field
-    @property
-    def support_lcm_lora(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SD_INPAINT,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ]
-
-    @computed_field
-    @property
-    def support_controlnet(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-            ModelType.DIFFUSERS_SDXL,
-            ModelType.DIFFUSERS_SD_INPAINT,
-            ModelType.DIFFUSERS_SDXL_INPAINT,
-        ]
-
-    @computed_field
-    @property
-    def support_freeu(self) -> bool:
-        return (
-            self.model_type
-            in [
-                ModelType.DIFFUSERS_SD,
-                ModelType.DIFFUSERS_SDXL,
-                ModelType.DIFFUSERS_SD_INPAINT,
-                ModelType.DIFFUSERS_SDXL_INPAINT,
-            ]
-            or "timbrooks/instruct-pix2pix" in self.name
-        )
-
-
 class HDStrategy(str, Enum):
     # Use original image size
     ORIGINAL = "Original"
@@ -157,6 +49,13 @@ class FREEUConfig(BaseModel):
     b2: float = 1.4


+class PowerPaintTask(str, Enum):
+    text_guided = "text-guided"
+    shape_guided = "shape-guided"
+    object_remove = "object-remove"
+    outpainting = "outpainting"
+
+
 class Config(BaseModel):
     class Config:
         arbitrary_types_allowed = True
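Because PowerPaintTask subclasses str, the enum round-trips cleanly through form values and JSON:

assert PowerPaintTask("object-remove") is PowerPaintTask.object_remove
assert PowerPaintTask.text_guided == "text-guided"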
@@ -239,6 +138,11 @@ class Config(BaseModel):
     p2p_image_guidance_scale: float = 1.5

     # ControlNet
-    controlnet_enabled: bool = False
+    enable_controlnet: bool = False
     controlnet_conditioning_scale: float = 0.4
-    controlnet_method: str = "control_v11p_sd15_canny"
+    controlnet_method: str = "lllyasviel/control_v11p_sd15_canny"
+
+    # PowerPaint
+    powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided
+    # control the fitting degree of the generated objects to the mask shape.
+    fitting_degree: float = 1.0
@@ -63,6 +63,7 @@ from lama_cleaner.helper import (
     numpy_to_bytes,
     resize_max_size,
     pil_to_bytes,
+    is_mac,
 )

 NUM_THREADS = str(multiprocessing.cpu_count())
@@ -285,9 +286,10 @@ def process():
         cv2_radius=form["cv2Radius"],
         paint_by_example_example_image=paint_by_example_example_image,
         p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
-        controlnet_enabled=form["controlnet_enabled"],
+        enable_controlnet=form["enable_controlnet"],
         controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
         controlnet_method=form["controlnet_method"],
+        powerpaint_task=form["powerpaintTask"],
     )

     if config.sd_seed == -1:
@@ -305,6 +307,8 @@ def process():
         if "CUDA out of memory. " in str(e):
             # NOTE: the string may change?
             return "CUDA out of memory", 500
+        elif "Invalid buffer size" in str(e) and is_mac():
+            return "Out of memory", 500
         else:
             logger.exception(e)
             return f"{str(e)}", 500
@@ -423,8 +427,8 @@ def get_server_config():
         "plugins": list(global_config.plugins.keys()),
         "enableFileManager": global_config.enable_file_manager,
         "enableAutoSaving": global_config.enable_auto_saving,
-        "enableControlnet": global_config.model_manager.sd_controlnet,
-        "controlnetMethod": global_config.model_manager.sd_controlnet_method,
+        "enableControlnet": global_config.model_manager.enable_controlnet,
+        "controlnetMethod": global_config.model_manager.controlnet_method,
         "disableModelSwitch": global_config.disable_model_switch,
         "isDesktop": global_config.is_desktop,
     }, 200
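The new `is_mac` guard comes from lama_cleaner.helper; its implementation is not part of this hunk, but a platform check along these lines is the likely shape (an assumption, not the committed code):

import platform

def is_mac() -> bool:
    # Darwin is the kernel name macOS reports via platform.system()
    return platform.system() == "Darwin"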
lama_cleaner/tests/utils.py (new, empty file)
@@ -15,8 +15,8 @@ def save_config(
     port,
     model,
     sd_local_model_path,
-    sd_controlnet,
-    sd_controlnet_method,
+    enable_controlnet,
+    controlnet_method,
     device,
     gui,
     no_gui_auto_close,
@@ -176,13 +176,13 @@ def main(config_file: str):
         sd_local_model_path = gr.Textbox(
             init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}"
         )
-        sd_controlnet = gr.Checkbox(
-            init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
+        enable_controlnet = gr.Checkbox(
+            init_config.enable_controlnet, label=f"{SD_CONTROLNET_HELP}"
         )
-        sd_controlnet_method = gr.Radio(
+        controlnet_method = gr.Radio(
             SD_CONTROLNET_CHOICES,
             label="ControlNet method",
-            value=init_config.sd_controlnet_method,
+            value=init_config.controlnet_method,
         )
         no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
         cpu_offload = gr.Checkbox(
@@ -205,8 +205,8 @@ def main(config_file: str):
             port,
             model,
             sd_local_model_path,
-            sd_controlnet,
-            sd_controlnet_method,
+            enable_controlnet,
+            controlnet_method,
             device,
             gui,
             no_gui_auto_close,
@@ -1,5 +1,5 @@
-import { EXTENDER_ALL, EXTENDER_X, EXTENDER_Y } from "@/lib/const"
 import { useStore } from "@/lib/states"
+import { ExtenderDirection } from "@/lib/types"
 import { cn } from "@/lib/utils"
 import React, { useEffect, useState } from "react"
 import { twMerge } from "tailwind-merge"
@@ -107,7 +107,7 @@ const Extender = (props: Props) => {
       const newY = evData.initY + offsetY
       let clampedY = newY
       let clampedHeight = newHeight
-      if (extenderDirection === EXTENDER_ALL) {
+      if (extenderDirection === ExtenderDirection.xy) {
         if (clampedY > 0) {
           clampedY = 0
           clampedHeight = evData.initHeight - Math.abs(evData.initY)
@@ -124,7 +124,7 @@ const Extender = (props: Props) => {
     const moveBottom = () => {
       const newHeight = evData.initHeight + offsetY
       let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight)
-      if (extenderDirection === EXTENDER_ALL) {
+      if (extenderDirection === ExtenderDirection.xy) {
         if (clampedHeight < Math.abs(clampedY) + imageHeight) {
           clampedHeight = Math.abs(clampedY) + imageHeight
         }
@@ -138,7 +138,7 @@ const Extender = (props: Props) => {
       const newX = evData.initX + offsetX
       let clampedX = newX
       let clampedWidth = newWidth
-      if (extenderDirection === EXTENDER_ALL) {
+      if (extenderDirection === ExtenderDirection.xy) {
         if (clampedX > 0) {
           clampedX = 0
           clampedWidth = evData.initWidth - Math.abs(evData.initX)
@@ -155,7 +155,7 @@ const Extender = (props: Props) => {
     const moveRight = () => {
       const newWidth = evData.initWidth + offsetX
       let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth)
-      if (extenderDirection === EXTENDER_ALL) {
+      if (extenderDirection === ExtenderDirection.xy) {
        if (clampedWidth < Math.abs(clampedX) + imageWdith) {
          clampedWidth = Math.abs(clampedX) + imageWdith
        }
@@ -296,7 +296,9 @@ const Extender = (props: Props) => {
         onPointerDown={onCropPointerDown}
         className="absolute top-0 h-full w-full"
       >
-        {[EXTENDER_Y, EXTENDER_ALL].includes(extenderDirection) ? (
+        {[ExtenderDirection.y, ExtenderDirection.xy].includes(
+          extenderDirection
+        ) ? (
           <>
             <div
               className="absolute pointer-events-auto top-0 left-0 w-full cursor-ns-resize h-[12px] mt-[-6px]"
@@ -313,7 +315,9 @@ const Extender = (props: Props) => {
           <></>
         )}

-        {[EXTENDER_X, EXTENDER_ALL].includes(extenderDirection) ? (
+        {[ExtenderDirection.x, ExtenderDirection.xy].includes(
+          extenderDirection
+        ) ? (
           <>
             <div
               className="absolute pointer-events-auto top-0 right-0 h-full cursor-ew-resize w-[12px] mr-[-6px]"
@@ -330,7 +334,7 @@ const Extender = (props: Props) => {
           <></>
         )}

-        {extenderDirection === EXTENDER_ALL ? (
+        {extenderDirection === ExtenderDirection.xy ? (
           <>
             {createDragHandle("cursor-nw-resize", "top", "left")}
             {createDragHandle("cursor-ne-resize", "top", "right")}
@@ -36,9 +36,9 @@ const PromptInput = () => {
     updateSettings({ prompt: target.value })
   }

-  const handleRepaintClick = async () => {
-    if (prompt.length !== 0 && !isProcessing) {
-      await runInpainting()
+  const handleRepaintClick = () => {
+    if (!isProcessing) {
+      runInpainting()
     }
   }

@@ -69,7 +69,7 @@ const PromptInput = () => {
         <Button
           size="sm"
           onClick={handleRepaintClick}
-          disabled={prompt.length === 0 || isProcessing}
+          disabled={isProcessing}
           onMouseEnter={onMouseEnter}
           onMouseLeave={onMouseLeave}
         >
@@ -1,7 +1,7 @@
 import { IconButton } from "@/components/ui/button"
 import { useToggle } from "@uidotdev/usehooks"
 import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog"
-import { HelpCircle, Settings } from "lucide-react"
+import { Settings } from "lucide-react"
 import { zodResolver } from "@hookform/resolvers/zod"
 import { useForm } from "react-hook-form"
 import * as z from "zod"
@@ -179,12 +179,12 @@ export function SettingsDialog() {
       <div key={info.name} onClick={() => onModelSelect(info)}>
         <div
           className={cn([
-            info.name === model.name ? "bg-muted " : "hover:bg-muted",
+            info.name === model.name ? "bg-muted" : "hover:bg-muted",
             "rounded-md px-2 py-1 my-1",
             "cursor-default",
           ])}
         >
-          <div className="text-base max-w-sm">{info.name}</div>
+          <div className="text-base">{info.name}</div>
         </div>
         <Separator />
       </div>
@@ -223,13 +223,13 @@ export function SettingsDialog() {
       <div className="space-y-4 rounded-md">
         <div className="flex gap-1 items-center justify-start">
           <div className="font-medium">Available models</div>
-          {/* <IconButton tooltip="How to download new model" asChild>
-            <HelpCircle size={16} strokeWidth={1.5} className="opacity-50" />
+          {/* <IconButton tooltip="How to download new model">
+            <Info size={20} strokeWidth={2} className="opacity-50" />
           </IconButton> */}
         </div>
         <Tabs defaultValue={defaultTab}>
           <TabsList>
-            <TabsTrigger value={MODEL_TYPE_INPAINT}>Erase</TabsTrigger>
+            <TabsTrigger value={MODEL_TYPE_INPAINT}>Inpaint</TabsTrigger>
             <TabsTrigger value={MODEL_TYPE_DIFFUSERS_SD}>
               Stable Diffusion
             </TabsTrigger>
@ -11,21 +11,14 @@ import {
|
|||||||
SelectValue,
|
SelectValue,
|
||||||
} from "../ui/select"
|
} from "../ui/select"
|
||||||
import { Textarea } from "../ui/textarea"
|
import { Textarea } from "../ui/textarea"
|
||||||
import { SDSampler } from "@/lib/types"
|
import { ExtenderDirection, PowerPaintTask, SDSampler } from "@/lib/types"
|
||||||
import { Separator } from "../ui/separator"
|
import { Separator } from "../ui/separator"
|
||||||
import { Move, MoveHorizontal, MoveVertical, Upload } from "lucide-react"
|
|
||||||
import { Button, ImageUploadButton } from "../ui/button"
|
import { Button, ImageUploadButton } from "../ui/button"
|
||||||
import { Slider } from "../ui/slider"
|
import { Slider } from "../ui/slider"
|
||||||
import { useImage } from "@/hooks/useImage"
|
import { useImage } from "@/hooks/useImage"
|
||||||
import {
|
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
|
||||||
EXTENDER_ALL,
|
|
||||||
EXTENDER_X,
|
|
||||||
EXTENDER_Y,
|
|
||||||
INSTRUCT_PIX2PIX,
|
|
||||||
PAINT_BY_EXAMPLE,
|
|
||||||
} from "@/lib/const"
|
|
||||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "../ui/tabs"
|
|
||||||
import { RowContainer, LabelTitle } from "./LabelTitle"
|
import { RowContainer, LabelTitle } from "./LabelTitle"
|
||||||
|
import { Upload } from "lucide-react"
|
||||||
|
|
||||||
const ExtenderButton = ({
|
const ExtenderButton = ({
|
||||||
text,
|
text,
|
||||||
@ -38,8 +31,7 @@ const ExtenderButton = ({
|
|||||||
return (
|
return (
|
||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
size="sm"
|
className="p-1 h-7"
|
||||||
className="p-1"
|
|
||||||
disabled={!showExtender}
|
disabled={!showExtender}
|
||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
>
|
>
|
||||||
@ -129,6 +121,7 @@ const DiffusionOptions = () => {
|
|||||||
|
|
||||||
<div className="pr-2">
|
<div className="pr-2">
|
||||||
<Select
|
<Select
|
||||||
|
defaultValue={settings.controlnetMethod}
|
||||||
value={settings.controlnetMethod}
|
value={settings.controlnetMethod}
|
||||||
onValueChange={(value) => {
|
onValueChange={(value) => {
|
||||||
updateSettings({ controlnetMethod: value })
|
updateSettings({ controlnetMethod: value })
|
||||||
@@ -467,96 +460,104 @@ const DiffusionOptions = () => {
             />
           </RowContainer>

-          <Tabs
-            defaultValue={settings.extenderDirection}
-            onValueChange={(value) => updateExtenderDirection(value)}
-            className="flex flex-col justify-center items-center"
-          >
-            <TabsList className="w-[140px] mb-2">
-              <TabsTrigger value={EXTENDER_X} disabled={!settings.showExtender}>
-                <MoveHorizontal size={20} strokeWidth={1} />
-              </TabsTrigger>
-              <TabsTrigger value={EXTENDER_Y} disabled={!settings.showExtender}>
-                <MoveVertical size={20} strokeWidth={1} />
-              </TabsTrigger>
-              <TabsTrigger
-                value={EXTENDER_ALL}
-                disabled={!settings.showExtender}
-              >
-                <Move size={20} strokeWidth={1} />
-              </TabsTrigger>
-            </TabsList>
-            <TabsContent
-              value={EXTENDER_X}
-              className="flex gap-2 justify-center mt-0"
-            >
-              <ExtenderButton
-                text="1.25x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.25)}
-              />
-              <ExtenderButton
-                text="1.5x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.5)}
-              />
-              <ExtenderButton
-                text="1.75x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.75)}
-              />
-              <ExtenderButton
-                text="2.0x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 2.0)}
-              />
-            </TabsContent>
-            <TabsContent
-              value={EXTENDER_Y}
-              className="flex gap-2 justify-center mt-0"
-            >
-              <ExtenderButton
-                text="1.25x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.25)}
-              />
-              <ExtenderButton
-                text="1.5x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.5)}
-              />
-              <ExtenderButton
-                text="1.75x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.75)}
-              />
-              <ExtenderButton
-                text="2.0x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 2.0)}
-              />
-            </TabsContent>
-            <TabsContent
-              value={EXTENDER_ALL}
-              className="flex gap-2 justify-center mt-0"
-            >
-              <ExtenderButton
-                text="1.25x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.25)}
-              />
-              <ExtenderButton
-                text="1.5x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.5)}
-              />
-              <ExtenderButton
-                text="1.75x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.75)}
-              />
-              <ExtenderButton
-                text="2.0x"
-                onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 2.0)}
-              />
-            </TabsContent>
-          </Tabs>
+          <RowContainer>
+            <Select
+              defaultValue={settings.extenderDirection}
+              value={settings.extenderDirection}
+              onValueChange={(value) => {
+                updateExtenderDirection(value as ExtenderDirection)
+              }}
+            >
+              <SelectTrigger
+                className="w-[65px] h-7"
+                disabled={!settings.showExtender}
+              >
+                <SelectValue placeholder="Select axis" />
+              </SelectTrigger>
+              <SelectContent align="end">
+                <SelectGroup>
+                  {Object.values(ExtenderDirection).map((v) => (
+                    <SelectItem key={v} value={v}>
+                      {v}
+                    </SelectItem>
+                  ))}
+                </SelectGroup>
+              </SelectContent>
+            </Select>
+
+            <div className="flex gap-1 justify-center mt-0">
+              <ExtenderButton
+                text="1.25x"
+                onClick={() =>
+                  updateExtenderByBuiltIn(settings.extenderDirection, 1.25)
+                }
+              />
+              <ExtenderButton
+                text="1.5x"
+                onClick={() =>
+                  updateExtenderByBuiltIn(settings.extenderDirection, 1.5)
+                }
+              />
+              <ExtenderButton
+                text="1.75x"
+                onClick={() =>
+                  updateExtenderByBuiltIn(settings.extenderDirection, 1.75)
+                }
+              />
+              <ExtenderButton
+                text="2.0x"
+                onClick={() =>
+                  updateExtenderByBuiltIn(settings.extenderDirection, 2.0)
+                }
+              />
+            </div>
+          </RowContainer>
         </div>
         <Separator />
       </>
     )
   }

+  const renderPowerPaintTaskType = () => {
+    if (settings.model.name !== POWERPAINT) {
+      return null
+    }
+
+    return (
+      <RowContainer>
+        <LabelTitle
+          text="Task"
+          toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
+        />
+        <Select
+          defaultValue={settings.powerpaintTask}
+          value={settings.powerpaintTask}
+          onValueChange={(value: PowerPaintTask) => {
+            updateSettings({ powerpaintTask: value })
+          }}
+          disabled={settings.showExtender}
+        >
+          <SelectTrigger className="w-[140px]">
+            <SelectValue placeholder="Select task" />
+          </SelectTrigger>
+          <SelectContent align="end">
+            <SelectGroup>
+              {[
+                PowerPaintTask.text_guided,
+                PowerPaintTask.object_remove,
+                PowerPaintTask.shape_guided,
+              ].map((task) => (
+                <SelectItem key={task} value={task}>
+                  {task}
+                </SelectItem>
+              ))}
+            </SelectGroup>
+          </SelectContent>
+        </Select>
+      </RowContainer>
+    )
+  }
+
   return (
     <div className="flex flex-col gap-4 mt-4">
       <RowContainer>
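
The hunk above collapses the three per-axis `TabsContent` panels into a single `Select` over `ExtenderDirection`, so one shared row of scale buttons can simply pass the currently selected direction. A minimal standalone sketch of the pattern — the enum shape and handler name come from the diff, everything else is illustrative:

```ts
// A string enum's values double as the option list and the wire format.
enum ExtenderDirection {
  x = "x",
  y = "y",
  xy = "xy",
}

// Object.values on a string enum yields only the values (no reverse keys),
// which is what makes the `Object.values(ExtenderDirection).map(...)`
// pattern in the hunk safe.
const options = Object.values(ExtenderDirection)
console.log(options) // ["x", "y", "xy"]

// A Select's onValueChange hands back a plain string, so it is narrowed
// before reaching the typed store action, as the diff does with `as`.
function updateExtenderDirection(direction: ExtenderDirection) {
  console.log(`extend along: ${direction}`)
}

const raw: string = "xy" // e.g. the value from onValueChange
updateExtenderDirection(raw as ExtenderDirection)
```
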
@@ -577,6 +578,7 @@ const DiffusionOptions = () => {
       </RowContainer>

       {renderExtender()}
+      {renderPowerPaintTaskType()}

       <div className="flex flex-col gap-1">
         <LabelTitle
@@ -642,20 +644,20 @@ const DiffusionOptions = () => {
         <RowContainer>
           <LabelTitle text="Sampler" />
           <Select
-            value={settings.sdSampler as string}
-            onValueChange={(value) => {
-              const sampler = value as SDSampler
-              updateSettings({ sdSampler: sampler })
+            defaultValue={settings.sdSampler}
+            value={settings.sdSampler}
+            onValueChange={(value: SDSampler) => {
+              updateSettings({ sdSampler: value })
             }}
           >
-            <SelectTrigger className="w-[100px]">
+            <SelectTrigger className="w-[120px]">
               <SelectValue placeholder="Select sampler" />
             </SelectTrigger>
             <SelectContent align="end">
               <SelectGroup>
                 {Object.values(SDSampler).map((sampler) => (
-                  <SelectItem key={sampler as string} value={sampler as string}>
-                    {sampler as string}
+                  <SelectItem key={sampler} value={sampler}>
+                    {sampler}
                   </SelectItem>
                 ))}
               </SelectGroup>
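
The same enum-driven `Select` pattern is applied to the sampler: typing the `onValueChange` parameter as `SDSampler` removes all four `as string` casts. A reduced sketch (the `SDSampler` members shown are assumptions, not the app's full list):

```ts
// Assumed subset of the real SDSampler enum, for illustration only.
enum SDSampler {
  ddim = "ddim",
  uni_pc = "uni_pc",
}

type Settings = { sdSampler: SDSampler }

// Typing the callback parameter once means no cast at the call site
// and no cast when rendering the value.
function makeHandler(update: (patch: Partial<Settings>) => void) {
  return (value: SDSampler) => update({ sdSampler: value })
}

const handler = makeHandler((patch) => console.log(patch))
handler(SDSampler.uni_pc) // { sdSampler: "uni_pc" }
```
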
@@ -707,9 +709,9 @@ const DiffusionOptions = () => {
         <RowContainer>
           <Slider
             className="w-[180px]"
-            defaultValue={[5]}
+            defaultValue={[settings.sdMaskBlur]}
             min={0}
-            max={35}
+            max={96}
             step={1}
             value={[Math.floor(settings.sdMaskBlur)]}
             onValueChange={(vals) => updateSettings({ sdMaskBlur: vals[0] })}

@@ -1,4 +1,4 @@
-import { ModelInfo, Rect } from "@/lib/types"
+import { ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
 import { Settings } from "@/lib/states"
 import { srcToFile } from "@/lib/utils"
 import axios from "axios"
@@ -22,7 +22,6 @@ export default async function inpaint(
   const fd = new FormData()
   fd.append("image", imageFile)
   fd.append("mask", mask)
-
   fd.append("ldmSteps", settings.ldmSteps.toString())
   fd.append("ldmSampler", settings.ldmSampler.toString())
   fd.append("zitsWireframe", settings.zitsWireframe.toString())
@@ -51,6 +50,7 @@ export default async function inpaint(
   fd.append("sdSteps", settings.sdSteps.toString())
   fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString())
   fd.append("sdSampler", settings.sdSampler.toString())
+
   if (settings.seedFixed) {
     fd.append("sdSeed", settings.seed.toString())
   } else {
@@ -76,13 +76,20 @@ export default async function inpaint(
   fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString())

   // ControlNet
-  fd.append("controlnet_enabled", settings.enableControlnet.toString())
+  fd.append("enable_controlnet", settings.enableControlnet.toString())
   fd.append(
     "controlnet_conditioning_scale",
     settings.controlnetConditioningScale.toString()
   )
   fd.append("controlnet_method", settings.controlnetMethod.toString())

+  // PowerPaint
+  if (settings.showExtender) {
+    fd.append("powerpaintTask", PowerPaintTask.outpainting)
+  } else {
+    fd.append("powerpaintTask", settings.powerpaintTask)
+  }
+
   try {
     const res = await fetch(`${API_ENDPOINT}/inpaint`, {
       method: "POST",
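
The request-building change carries one behavioral rule worth noting: when the extender is open, the PowerPaint task sent to the server is forced to outpainting, regardless of what the dropdown says. A self-contained sketch of that rule (the helper name is illustrative; the `powerpaintTask` field name and enum values come from the diff):

```ts
enum PowerPaintTask {
  text_guided = "text-guided",
  shape_guided = "shape-guided",
  object_remove = "object-remove",
  outpainting = "outpainting",
}

// The extender UI implies outpainting, so it overrides the user's task.
function appendPowerPaintFields(
  fd: FormData,
  showExtender: boolean,
  task: PowerPaintTask
) {
  fd.append("powerpaintTask", showExtender ? PowerPaintTask.outpainting : task)
}

const fd = new FormData()
appendPowerPaintFields(fd, true, PowerPaintTask.object_remove)
console.log(fd.get("powerpaintTask")) // "outpainting"
```
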

@@ -8,14 +8,13 @@ export const MODEL_TYPE_DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
 export const MODEL_TYPE_OTHER = "diffusers_other"
 export const BRUSH_COLOR = "#ffcc00bb"

-export const EXTENDER_X = "extender_x"
-export const EXTENDER_Y = "extender_y"
-export const EXTENDER_ALL = "extender_all"
-
 export const LDM = "ldm"
 export const CV2 = "cv2"

 export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example"
 export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"
 export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
+export const POWERPAINT = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"

 export const DEFAULT_NEGATIVE_PROMPT =
   "out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"

@@ -5,6 +5,7 @@ import { castDraft } from "immer"
 import { createWithEqualityFn } from "zustand/traditional"
 import {
   CV2Flag,
+  ExtenderDirection,
   FreeuConfig,
   LDMSampler,
   Line,
@@ -12,6 +13,7 @@ import {
   ModelInfo,
   PluginParams,
   Point,
+  PowerPaintTask,
   SDSampler,
   Size,
   SortBy,
@@ -21,9 +23,6 @@ import {
   BRUSH_COLOR,
   DEFAULT_BRUSH_SIZE,
   DEFAULT_NEGATIVE_PROMPT,
-  EXTENDER_ALL,
-  EXTENDER_X,
-  EXTENDER_Y,
   MODEL_TYPE_INPAINT,
   PAINT_BY_EXAMPLE,
 } from "./const"
@@ -60,7 +59,7 @@ export type Settings = {
   enableUploadMask: boolean
   showCropper: boolean
   showExtender: boolean
-  extenderDirection: string
+  extenderDirection: ExtenderDirection

   // For LDM
   ldmSteps: number
@@ -99,6 +98,9 @@ export type Settings = {
   enableLCMLora: boolean
   enableFreeu: boolean
   freeuConfig: FreeuConfig
+
+  // PowerPaint
+  powerpaintTask: PowerPaintTask
 }

 type ServerConfig = {
@@ -178,9 +180,9 @@ type AppAction = {
   setExtenderWidth: (newValue: number) => void
   setExtenderHeight: (newValue: number) => void
   setIsCropperExtenderResizing: (newValue: boolean) => void
-  updateExtenderDirection: (newValue: string) => void
+  updateExtenderDirection: (newValue: ExtenderDirection) => void
   resetExtender: (width: number, height: number) => void
-  updateExtenderByBuiltIn: (direction: string, scale: number) => void
+  updateExtenderByBuiltIn: (direction: ExtenderDirection, scale: number) => void

   setServerConfig: (newValue: ServerConfig) => void
   setSeed: (newValue: number) => void
@@ -296,7 +298,7 @@ const defaultValues: AppState = {
     enableControlnet: false,
     showCropper: false,
     showExtender: false,
-    extenderDirection: EXTENDER_ALL,
+    extenderDirection: ExtenderDirection.xy,
     enableDownloadMask: false,
     enableManualInpainting: false,
     enableUploadMask: false,
@@ -309,7 +311,7 @@ const defaultValues: AppState = {
     negativePrompt: DEFAULT_NEGATIVE_PROMPT,
     seed: 42,
     seedFixed: false,
-    sdMaskBlur: 5,
+    sdMaskBlur: 35,
    sdStrength: 1.0,
    sdSteps: 50,
    sdGuidanceScale: 7.5,
@@ -322,6 +324,7 @@ const defaultValues: AppState = {
     enableLCMLora: false,
     enableFreeu: false,
     freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
+    powerpaintTask: PowerPaintTask.text_guided,
   },
 }

@@ -894,7 +897,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
         state.isCropperExtenderResizing = newValue
       }),

-    updateExtenderDirection: (newValue: string) => {
+    updateExtenderDirection: (newValue: ExtenderDirection) => {
       console.log(
         `updateExtenderDirection: ${JSON.stringify(get().extenderState)}`
       )
@@ -908,7 +911,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
       get().updateExtenderByBuiltIn(newValue, 1.5)
     },

-    updateExtenderByBuiltIn: (direction: string, scale: number) => {
+    updateExtenderByBuiltIn: (
+      direction: ExtenderDirection,
+      scale: number
+    ) => {
       const newExtenderState = { ...defaultValues.extenderState }
       let { x, y, width, height } = newExtenderState
       const { imageWidth, imageHeight } = get()
@@ -916,15 +922,15 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
       height = imageHeight

       switch (direction) {
-        case EXTENDER_X:
+        case ExtenderDirection.x:
           x = -Math.ceil((imageWidth * (scale - 1)) / 2)
           width = Math.ceil(imageWidth * scale)
           break
-        case EXTENDER_Y:
+        case ExtenderDirection.y:
           y = -Math.ceil((imageHeight * (scale - 1)) / 2)
           height = Math.ceil(imageHeight * scale)
           break
-        case EXTENDER_ALL:
+        case ExtenderDirection.xy:
           x = -Math.ceil((imageWidth * (scale - 1)) / 2)
           y = -Math.ceil((imageHeight * (scale - 1)) / 2)
           width = Math.ceil(imageWidth * scale)
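
The switch above only swaps string constants for enum members; the geometry is unchanged. As a quick check of that centering math (standalone, values chosen for clarity):

```ts
// For a 512px-wide image at scale 1.5 the crop grows to 768px and shifts
// left by half of the added width, keeping the original image centered.
const imageWidth = 512
const scale = 1.5

const x = -Math.ceil((imageWidth * (scale - 1)) / 2)
const width = Math.ceil(imageWidth * scale)

console.log({ x, width }) // { x: -128, width: 768 }
```
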

@@ -93,3 +93,16 @@ export interface Size {
   width: number
   height: number
 }
+
+export enum ExtenderDirection {
+  x = "x",
+  y = "y",
+  xy = "xy",
+}
+
+export enum PowerPaintTask {
+  text_guided = "text-guided",
+  shape_guided = "shape-guided",
+  object_remove = "object-remove",
+  outpainting = "outpainting",
+}
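
A side benefit of replacing the string constants with these `types.ts` string enums, sketched below: a `switch` over `ExtenderDirection` can be made exhaustive with a `never` guard, so adding a new direction becomes a compile error instead of a silent fall-through (the `assertNever` helper is illustrative, not part of the diff):

```ts
enum ExtenderDirection {
  x = "x",
  y = "y",
  xy = "xy",
}

// If a new member is added to the enum, the call below stops compiling
// until the switch handles it.
function assertNever(v: never): never {
  throw new Error(`unhandled direction: ${String(v)}`)
}

function axes(d: ExtenderDirection): string[] {
  switch (d) {
    case ExtenderDirection.x:
      return ["horizontal"]
    case ExtenderDirection.y:
      return ["vertical"]
    case ExtenderDirection.xy:
      return ["horizontal", "vertical"]
    default:
      return assertNever(d)
  }
}

console.log(axes(ExtenderDirection.xy)) // ["horizontal", "vertical"]
```
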