lots update

This commit is contained in:
Qing 2023-12-27 22:00:07 +08:00
parent 0ba6c121e0
commit f0b852725f
33 changed files with 4085 additions and 1000 deletions

View File

@@ -4,6 +4,11 @@ from enum import Enum
 from pydantic import BaseModel

+DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
+DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
+DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
+DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
+
 MPS_UNSUPPORT_MODELS = [
     "lama",
     "ldm",
@@ -15,22 +20,8 @@ MPS_UNSUPPORT_MODELS = [
 ]

 DEFAULT_MODEL = "lama"
-AVAILABLE_MODELS = [
-    "lama",
-    "ldm",
-    "zits",
-    "mat",
-    "fcf",
-    "manga",
-    "cv2",
-]
-DIFFUSERS_MODEL_FP16_REVERSION = [
-    "runwayml/stable-diffusion-inpainting",
-    "Sanster/anything-4.0-inpainting",
-    "Sanster/Realistic_Vision_V1.4-inpainting",
-    "stabilityai/stable-diffusion-2-inpainting",
-    "timbrooks/instruct-pix2pix",
-]
+AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
 AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
 DEFAULT_DEVICE = "cuda"
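The four new DIFFUSERS_*_CLASS_NAME constants are matched against the pipeline class name of a scanned diffusers model further down in this commit (scan_models in download.py). A small sketch of that mapping, with import paths taken from the new imports in download.py and an illustrative class name:

# Sketch only: classify a scanned diffusers model by its pipeline class name,
# as scan_models appears to do later in this commit.
from lama_cleaner.const import (
    DIFFUSERS_SD_CLASS_NAME,
    DIFFUSERS_SD_INPAINT_CLASS_NAME,
    DIFFUSERS_SDXL_CLASS_NAME,
    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
)
from lama_cleaner.model_info import ModelType

_class_name = "StableDiffusionXLInpaintPipeline"  # illustrative value
if _class_name == DIFFUSERS_SD_CLASS_NAME:
    model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
    model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
    model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
    model_type = ModelType.DIFFUSERS_SDXL_INPAINT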

View File

@@ -5,23 +5,23 @@ from typing import List
 from loguru import logger
 from pathlib import Path

-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
-from lama_cleaner.runtime import setup_model_dir
-from lama_cleaner.schema import (
-    ModelInfo,
-    ModelType,
-    DIFFUSERS_SD_INPAINT_CLASS_NAME,
-    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
+from lama_cleaner.const import (
+    DEFAULT_MODEL_DIR,
     DIFFUSERS_SD_CLASS_NAME,
+    DIFFUSERS_SD_INPAINT_CLASS_NAME,
     DIFFUSERS_SDXL_CLASS_NAME,
+    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
 )
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
+from lama_cleaner.model_info import ModelInfo, ModelType
+from lama_cleaner.runtime import setup_model_dir


 def cli_download_model(model: str, model_dir: Path):
     setup_model_dir(model_dir)
     from lama_cleaner.model import models

-    if model in models:
+    if model in models and models[model].is_erase_model:
         logger.info(f"Downloading {model}...")
         models[model].download()
         logger.info(f"Done.")
@@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path):
         logger.info(f"Downloading model from Huggingface: {model}")
         from diffusers import DiffusionPipeline

-        downloaded_path = DiffusionPipeline.download(
+        downloaded_path = handle_from_pretrained_exceptions(
+            DiffusionPipeline.download,
             pretrained_model_name=model,
-            variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
+            variant="fp16",
             resume_download=True,
         )
         logger.info(f"Done. Downloaded to {downloaded_path}")
@@ -43,19 +44,15 @@ def folder_name_to_show_name(name: str) -> str:

 def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
     cache_dir = Path(cache_dir)
+    stable_diffusion_dir = cache_dir / "stable_diffusion"
+    stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
     # logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
     res = []
-    for it in cache_dir.glob(f"*.*"):
+    for it in stable_diffusion_dir.glob(f"*.*"):
         if it.suffix not in [".safetensors", ".ckpt"]:
             continue
         if "inpaint" in str(it).lower():
-            if "sdxl" in str(it).lower():
-                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
-            else:
-                model_type = ModelType.DIFFUSERS_SD_INPAINT
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
         else:
-            if "sdxl" in str(it).lower():
-                model_type = ModelType.DIFFUSERS_SDXL
-            else:
-                model_type = ModelType.DIFFUSERS_SD
+            model_type = ModelType.DIFFUSERS_SD
         res.append(
@@ -66,6 +63,22 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
                 is_single_file_diffusers=True,
             )
         )
+
+    for it in stable_diffusion_xl_dir.glob(f"*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        if "inpaint" in str(it).lower():
+            model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+        else:
+            model_type = ModelType.DIFFUSERS_SDXL
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=str(it.absolute()),
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
     return res
@@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]:
             name = folder_name_to_show_name(it.parent.parent.parent.name)
             if name in diffusers_model_names:
                 continue
-            if _class_name == DIFFUSERS_SD_CLASS_NAME:
+            if "PowerPaint" in name:
+                model_type = ModelType.DIFFUSERS_OTHER
+            elif _class_name == DIFFUSERS_SD_CLASS_NAME:
                 model_type = ModelType.DIFFUSERS_SD
             elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
                 model_type = ModelType.DIFFUSERS_SD_INPAINT
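handle_from_pretrained_exceptions (imported from lama_cleaner.model.utils) is called here and in the model files below, but its body is not part of this diff. A minimal sketch of what such a wrapper could look like, assuming it simply retries without the fp16 variant when diffusers reports that the variant is unavailable:

# Hypothetical sketch only; the real helper lives in lama_cleaner/model/utils.py
# and is not shown in this commit.
from loguru import logger


def handle_from_pretrained_exceptions(func, **kwargs):
    try:
        return func(**kwargs)
    except ValueError as e:
        # Assumption: diffusers raises a ValueError mentioning the variant when the
        # requested fp16 files do not exist; fall back to the default variant.
        if kwargs.get("variant") and "variant" in str(e):
            logger.info(f"Retrying without variant={kwargs['variant']}")
            kwargs.pop("variant")
            return func(**kwargs)
        raise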

View File

@@ -290,3 +290,7 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
         return cv2.drawContours(new_mask, contours, max_index, 255, -1)
     else:
         return mask
+
+
+def is_mac():
+    return sys.platform == "darwin"

View File

@@ -9,6 +9,7 @@ from .mat import MAT
 from .mi_gan import MIGAN
 from .opencv2 import OpenCV2
 from .paint_by_example import PaintByExample
+from .power_paint.power_paint import PowerPaint
 from .sd import SD15, SD2, Anything4, RealisticVision14, SD
 from .sdxl import SDXL
 from .zits import ZITS
@@ -30,4 +31,5 @@ models = {
     InstructPix2Pix.name: InstructPix2Pix,
     Kandinsky22.name: Kandinsky22,
     SDXL.name: SDXL,
+    PowerPaint.name: PowerPaint,
 }
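The models dict above is the registry that both the CLI downloader and the model manager key into. A hedged sketch of the lookup path, reusing the is_erase_model / download() calls shown in download.py; the instantiation line is commented out because the full keyword set is not part of this diff:

# Sketch: look a model class up by name and fetch its weights if it is a local
# erase model; construction kwargs beyond `device` are assumptions.
from lama_cleaner.model import models

name = "lama"
if name in models and models[name].is_erase_model:
    models[name].download()
# model = models[name](device="cpu", ...)  # extra kwargs such as model_info are
#                                          # passed through **kwargs, see base.py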

View File

@@ -14,7 +14,7 @@ from lama_cleaner.helper import (
 )
 from lama_cleaner.model.helper.g_diffuser_bot import expand_image
 from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
+from lama_cleaner.schema import Config, HDStrategy, SDSampler


 class InpaintModel:
@@ -271,7 +271,7 @@ class InpaintModel:

 class DiffusionInpaintModel(InpaintModel):
     def __init__(self, device, **kwargs):
-        self.model_info: ModelInfo = kwargs["model_info"]
+        self.model_info = kwargs["model_info"]
         self.model_id_or_path = self.model_info.path
         super().__init__(device, **kwargs)

View File

@@ -5,7 +5,6 @@ import torch
 from diffusers import ControlNetModel, DiffusionPipeline
 from loguru import logger

-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
 from lama_cleaner.model.base import DiffusionInpaintModel
 from lama_cleaner.model.helper.controlnet_preprocess import (
     make_canny_control_image,
@@ -14,8 +13,8 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
     make_inpaint_control_image,
 )
 from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
-from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, ModelInfo, ModelType
+from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
+from lama_cleaner.schema import Config, ModelType


 class ControlNet(DiffusionInpaintModel):
@@ -39,11 +38,11 @@ class ControlNet(DiffusionInpaintModel):
     def init_model(self, device: torch.device, **kwargs):
         fp16 = not kwargs.get("no_half", False)
-        model_info: ModelInfo = kwargs["model_info"]
-        sd_controlnet_method = kwargs["sd_controlnet_method"]
+        model_info = kwargs["model_info"]
+        controlnet_method = kwargs["controlnet_method"]

         self.model_info = model_info
-        self.sd_controlnet_method = sd_controlnet_method
+        self.controlnet_method = controlnet_method

         model_kwargs = {}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
@@ -76,7 +75,8 @@ class ControlNet(DiffusionInpaintModel):
             )

         controlnet = ControlNetModel.from_pretrained(
-            sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
+            pretrained_model_name_or_path=controlnet_method,
+            resume_download=True,
         )
         if model_info.is_single_file_diffusers:
             if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -88,17 +88,12 @@ class ControlNet(DiffusionInpaintModel):
                 model_info.path, controlnet=controlnet, **model_kwargs
             ).to(torch_dtype)
         else:
-            self.model = PipeClass.from_pretrained(
-                model_info.path,
+            self.model = handle_from_pretrained_exceptions(
+                PipeClass.from_pretrained,
+                pretrained_model_name_or_path=model_info.path,
                 controlnet=controlnet,
-                revision="fp16"
-                if (
-                    model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
-                    and use_gpu
-                    and fp16
-                )
-                else "main",
-                torch_dtype=torch_dtype,
+                variant="fp16",
+                dtype=torch_dtype,
                 **model_kwargs,
             )
@@ -116,23 +111,23 @@ class ControlNet(DiffusionInpaintModel):
         self.callback = kwargs.pop("callback", None)

     def switch_controlnet_method(self, new_method: str):
-        self.sd_controlnet_method = new_method
+        self.controlnet_method = new_method
         controlnet = ControlNetModel.from_pretrained(
             new_method, torch_dtype=self.torch_dtype, resume_download=True
         ).to(self.model.device)
         self.model.controlnet = controlnet

     def _get_control_image(self, image, mask):
-        if "canny" in self.sd_controlnet_method:
+        if "canny" in self.controlnet_method:
             control_image = make_canny_control_image(image)
-        elif "openpose" in self.sd_controlnet_method:
+        elif "openpose" in self.controlnet_method:
             control_image = make_openpose_control_image(image)
-        elif "depth" in self.sd_controlnet_method:
+        elif "depth" in self.controlnet_method:
             control_image = make_depth_control_image(image)
-        elif "inpaint" in self.sd_controlnet_method:
+        elif "inpaint" in self.controlnet_method:
             control_image = make_inpaint_control_image(image, mask)
         else:
-            raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
+            raise NotImplementedError(f"{self.controlnet_method} not implemented")

         return control_image

     def forward(self, image, mask, config: Config):

View File

@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
         }

         self.model = AutoPipelineForInpainting.from_pretrained(
-            self.model_id_or_path, **model_kwargs
+            self.name, **model_kwargs
         ).to(device)

         self.callback = kwargs.pop("callback", None)
@@ -66,4 +66,3 @@ class Kandinsky(DiffusionInpaintModel):

 class Kandinsky22(Kandinsky):
     name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
-    model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"

View File

@@ -16,7 +16,7 @@ from lama_cleaner.model.base import InpaintModel

 MIGAN_MODEL_URL = os.environ.get(
     "MIGAN_MODEL_URL",
-    "/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt",
+    "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
 )
 MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")

View File

@@ -28,7 +28,7 @@ class PaintByExample(DiffusionInpaintModel):
         )

         self.model = DiffusionPipeline.from_pretrained(
-            "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
+            self.name, torch_dtype=torch_dtype, **model_kwargs
         )

         # TODO: gpu_id

View File

@@ -1,3 +0,0 @@
-from .pipeline_stable_diffusion_controlnet_inpaint import (
-    StableDiffusionControlNetInpaintPipeline,
-)

View File

@@ -1,638 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from typing import Union, List, Optional, Callable, Dict, Any
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
import torch
import PIL.Image
from diffusers.pipelines.controlnet.pipeline_controlnet import *
from diffusers.utils import replace_example_docstring
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> import cv2
>>> from PIL import Image
>>> # download an image
>>> image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
... )
>>> image = np.array(image)
>>> mask_image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
... )
>>> mask_image = np.array(mask_image)
>>> # get canny image
>>> canny_image = cv2.Canny(image, 100, 200)
>>> canny_image = canny_image[:, :, None]
>>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
>>> canny_image = Image.fromarray(canny_image)
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> # remove following line if xformers is not installed
>>> pipe.enable_xformers_memory_efficient_attention()
>>> pipe.enable_model_cpu_offload()
>>> # generate image
>>> generator = torch.manual_seed(0)
>>> image = pipe(
... "futuristic-looking doggo",
... num_inference_steps=20,
... generator=generator,
... image=image,
... control_image=canny_image,
... mask_image=mask_image
... ).images[0]
```
"""
def prepare_mask_and_masked_image(image, mask):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(or the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(
f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
)
# Batch single image
if image.ndim == 3:
assert (
image.shape[0] == 3
), "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert (
image.ndim == 4 and mask.ndim == 4
), "Image and Mask must have 4 dimensions"
assert (
image.shape[-2:] == mask.shape[-2:]
), "Image and Mask must have the same spatial dimensions"
assert (
image.shape[0] == mask.shape[0]
), "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(
f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
)
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
controlnet ([`ControlNetModel`]):
Provides additional conditioning to the unet during the denoising process
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
controlnet = kwargs.pop("controlnet", None)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
num_in_channels=9,
from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
device="cpu",
load_safety_checker=False,
)
inpaint_pipe = cls(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
controlnet=controlnet,
scheduler=pipe.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
del pipe
gc.collect()
return inpaint_pipe
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
device,
generator,
do_classifier_free_guidance,
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
# encode the mask image into latents space so we can concatenate it to the latents
if isinstance(generator, list):
masked_image_latents = [
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
generator=generator[i]
)
for i in range(batch_size)
]
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
generator=generator
)
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2)
if do_classifier_free_guidance
else masked_image_latents
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
def _default_height_width(self, height, width, image):
if isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[3]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[2]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
control_image: Union[
torch.FloatTensor,
PIL.Image.Image,
List[torch.FloatTensor],
List[PIL.Image.Image],
] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
also be accepted as an image. The control image is automatically resized to fit the output image.
mask_image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, control_image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
image=control_image,
callback_steps=callback_steps,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare image
control_image = self.prepare_image(
control_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
if do_classifier_free_guidance:
control_image = torch.cat([control_image] * 2)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.controlnet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# EXTRA: prepare mask latents
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=control_image,
return_dict=False,
)
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual
latent_model_input = torch.cat(
[latent_model_input, mask, masked_image_latents], dim=1
)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
from PIL import Image
import PIL.Image
import cv2
import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config
from .powerpaint_tokenizer import add_task_to_prompt
class PowerPaint(DiffusionInpaintModel):
name = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
def init_model(self, device: torch.device, **kwargs):
from .pipeline_powerpaint import StableDiffusionInpaintPipeline
from .powerpaint_tokenizer import PowerPaintTokenizer
fp16 = not kwargs.get("no_half", False)
model_kwargs = {}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = handle_from_pretrained_exceptions(
StableDiffusionInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.name,
variant="fp16",
torch_dtype=torch_dtype,
**model_kwargs,
)
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
img_h, img_w = image.shape[:2]
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
config.prompt, config.negative_prompt, config.powerpaint_task
)
output = self.model(
image=PIL.Image.fromarray(image),
promptA=promptA,
promptB=promptB,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
num_inference_steps=config.sd_steps,
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
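For orientation, a hedged call sketch for the new model. The Config fields used are exactly the ones forward() reads above; whether Config accepts only these fields, and the exact PowerPaint constructor kwargs, are assumptions, so the instantiation lines are left as comments:

# Illustrative only: drive PowerPaint for object removal.
import cv2
import numpy as np
from lama_cleaner.schema import Config, PowerPaintTask

image = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB
mask = np.zeros((*image.shape[:2], 1), dtype=np.uint8)            # [H, W, 1]
mask[100:200, 100:200] = 255                                      # 255 = repaint

config = Config(
    prompt="",
    negative_prompt="",
    powerpaint_task=PowerPaintTask.object_remove,
    fitting_degree=1.0,
    sd_steps=30,
    sd_strength=1.0,
    sd_guidance_scale=7.5,
    sd_seed=42,
)
# model = PowerPaint("cuda", model_info=..., disable_nsfw=True,
#                    sd_cpu_textencoder=False, no_half=False)
# result_bgr = model.forward(image, mask, config)  # returns BGR per the docstring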

View File

@@ -0,0 +1,540 @@
import torch
import torch.nn as nn
import copy
import random
from typing import Any, List, Optional, Union
from transformers import CLIPTokenizer
from lama_cleaner.schema import PowerPaintTask
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
if task == PowerPaintTask.object_remove:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
elif task == PowerPaintTask.shape_guided:
promptA = prompt + " P_shape"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.outpainting:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
else:
promptA = prompt + " P_obj"
promptB = prompt + " P_obj"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
return promptA, promptB, negative_promptA, negative_promptB
class PowerPaintTokenizer:
def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer
self.token_map = {}
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
num_vec_per_token = 10
for placeholder_token in placeholder_tokens:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
output.append(ith_token)
self.token_map[placeholder_token] = output
def __getattr__(self, name: str) -> Any:
if name == "wrapped":
return super().__getattr__("wrapped")
try:
return getattr(self.wrapped, name)
except AttributeError:
try:
return super().__getattr__(name)
except AttributeError:
raise AttributeError(
"'name' cannot be found in both "
f"'{self.__class__.__name__}' and "
f"'{self.__class__.__name__}.tokenizer'."
)
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
"""Attempt to add tokens to the tokenizer.
Args:
tokens (Union[str, List[str]]): The tokens to be added.
"""
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
assert num_added_tokens != 0, (
f"The tokenizer already contains the token {tokens}. Please pass "
"a different `placeholder_token` that is not already in the "
"tokenizer."
)
def get_token_info(self, token: str) -> dict:
"""Get the information of a token, including its start and end index in
the current tokenizer.
Args:
token (str): The token to be queried.
Returns:
dict: The information of the token, including its start and end
index in current tokenizer.
"""
token_ids = self.__call__(token).input_ids
start, end = token_ids[1], token_ids[-2] + 1
return {"name": token, "start": start, "end": end}
def add_placeholder_token(
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
):
"""Add placeholder tokens to the tokenizer.
Args:
placeholder_token (str): The placeholder token to be added.
num_vec_per_token (int, optional): The number of vectors of
the added placeholder token.
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
"""
output = []
if num_vec_per_token == 1:
self.try_adding_tokens(placeholder_token, *args, **kwargs)
output.append(placeholder_token)
else:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
self.try_adding_tokens(ith_token, *args, **kwargs)
output.append(ith_token)
for token in self.token_map:
if token in placeholder_token:
raise ValueError(
f"The tokenizer already has placeholder token {token} "
f"that can get confused with {placeholder_token} "
"keep placeholder tokens independent"
)
self.token_map[placeholder_token] = output
def replace_placeholder_tokens_in_text(
self,
text: Union[str, List[str]],
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
) -> Union[str, List[str]]:
"""Replace the keywords in text with placeholder tokens. This function
will be called in `self.__call__` and `self.encode`.
Args:
text (Union[str, List[str]]): The text to be processed.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(
self.replace_placeholder_tokens_in_text(
text[i], vector_shuffle=vector_shuffle
)
)
return output
for placeholder_token in self.token_map:
if placeholder_token in text:
tokens = self.token_map[placeholder_token]
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
if vector_shuffle:
tokens = copy.copy(tokens)
random.shuffle(tokens)
text = text.replace(placeholder_token, " ".join(tokens))
return text
def replace_text_with_placeholder_tokens(
self, text: Union[str, List[str]]
) -> Union[str, List[str]]:
"""Replace the placeholder tokens in text with the original keywords.
This function will be called in `self.decode`.
Args:
text (Union[str, List[str]]): The text to be processed.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_text_with_placeholder_tokens(text[i]))
return output
for placeholder_token, tokens in self.token_map.items():
merged_tokens = " ".join(tokens)
if merged_tokens in text:
text = text.replace(merged_tokens, placeholder_token)
return text
def __call__(
self,
text: Union[str, List[str]],
*args,
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
**kwargs,
):
"""The call function of the wrapper.
Args:
text (Union[str, List[str]]): The text to be tokenized.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
)
return self.wrapped.__call__(replaced_text, *args, **kwargs)
def encode(self, text: Union[str, List[str]], *args, **kwargs):
"""Encode the passed text to token index.
Args:
text (Union[str, List[str]]): The text to be encode.
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(text)
return self.wrapped(replaced_text, *args, **kwargs)
def decode(
self, token_ids, return_raw: bool = False, *args, **kwargs
) -> Union[str, List[str]]:
"""Decode the token index to text.
Args:
token_ids: The token index to be decoded.
return_raw: Whether keep the placeholder token in the text.
Defaults to False.
*args, **kwargs: The arguments for `self.wrapped.decode`.
Returns:
Union[str, List[str]]: The decoded text.
"""
text = self.wrapped.decode(token_ids, *args, **kwargs)
if return_raw:
return text
replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text
class EmbeddingLayerWithFixes(nn.Module):
"""The revised embedding layer to support external embeddings. This design
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
jack.py#L224 # noqa.
Args:
wrapped (nn.Embedding): The embedding layer to be wrapped.
external_embeddings (Union[dict, List[dict]], optional): The external
embeddings added to this layer. Defaults to None.
"""
def __init__(
self,
wrapped: nn.Embedding,
external_embeddings: Optional[Union[dict, List[dict]]] = None,
):
super().__init__()
self.wrapped = wrapped
self.num_embeddings = wrapped.weight.shape[0]
self.external_embeddings = []
if external_embeddings:
self.add_embeddings(external_embeddings)
self.trainable_embeddings = nn.ParameterDict()
@property
def weight(self):
"""Get the weight of wrapped embedding layer."""
return self.wrapped.weight
def check_duplicate_names(self, embeddings: List[dict]):
"""Check whether duplicate names exist in list of 'external
embeddings'.
Args:
embeddings (List[dict]): A list of embedding to be check.
"""
names = [emb["name"] for emb in embeddings]
assert len(names) == len(set(names)), (
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
)
def check_ids_overlap(self, embeddings):
"""Check whether overlap exist in token ids of 'external_embeddings'.
Args:
embeddings (List[dict]): A list of embedding to be check.
"""
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
ids_range.sort() # sort by 'start'
# check if 'end' has overlapping
for idx in range(len(ids_range) - 1):
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
)
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
"""Add external embeddings to this layer.
Use case:
>>> 1. Add token to tokenizer and get the token id.
>>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
>>> # 'how much' in kiswahili
>>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
>>>
>>> 2. Add external embeddings to the model.
>>> new_embedding = {
>>> 'name': 'ngapi', # 'how much' in kiswahili
>>> 'embedding': torch.ones(1, 15) * 4,
>>> 'start': tokenizer.get_token_info('kwaheri')['start'],
>>> 'end': tokenizer.get_token_info('kwaheri')['end'],
>>> 'trainable': False # if True, will registry as a parameter
>>> }
>>> embedding_layer = nn.Embedding(10, 15)
>>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
>>> embedding_layer_wrapper.add_embeddings(new_embedding)
>>>
>>> 3. Forward tokenizer and embedding layer!
>>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
>>> input_ids = tokenizer(
>>> input_text, padding='max_length', truncation=True,
>>> return_tensors='pt')['input_ids']
>>> out_feat = embedding_layer_wrapper(input_ids)
>>>
>>> 4. Let's validate the result!
>>> assert (out_feat[0, 3: 7] == 2.3).all()
>>> assert (out_feat[2, 5: 9] == 2.3).all()
Args:
embeddings (Union[dict, list[dict]]): The external embeddings to
be added. Each dict must contain the following 4 fields: 'name'
(the name of this embedding), 'embedding' (the embedding
tensor), 'start' (the start token id of this embedding), 'end'
(the end token id of this embedding). For example:
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
"""
if isinstance(embeddings, dict):
embeddings = [embeddings]
self.external_embeddings += embeddings
self.check_duplicate_names(self.external_embeddings)
self.check_ids_overlap(self.external_embeddings)
# set for trainable
added_trainable_emb_info = []
for embedding in embeddings:
trainable = embedding.get("trainable", False)
if trainable:
name = embedding["name"]
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
self.trainable_embeddings[name] = embedding["embedding"]
added_trainable_emb_info.append(name)
added_emb_info = [emb["name"] for emb in embeddings]
added_emb_info = ", ".join(added_emb_info)
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
if added_trainable_emb_info:
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
print(
"Successfully add trainable external embeddings: "
f"{added_trainable_emb_info}",
"current",
)
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Replace external input ids to 0.
Args:
input_ids (torch.Tensor): The input ids to be replaced.
Returns:
torch.Tensor: The replaced input ids.
"""
input_ids_fwd = input_ids.clone()
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
return input_ids_fwd
def replace_embeddings(
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
) -> torch.Tensor:
"""Replace external embedding to the embedding layer. Noted that, in
this function we use `torch.cat` to avoid inplace modification.
Args:
input_ids (torch.Tensor): The original token ids. Shape like
[LENGTH, ].
embedding (torch.Tensor): The embedding of token ids after
`replace_input_ids` function.
external_embedding (dict): The external embedding to be replaced.
Returns:
torch.Tensor: The replaced embedding.
"""
new_embedding = []
name = external_embedding["name"]
start = external_embedding["start"]
end = external_embedding["end"]
target_ids_to_replace = [i for i in range(start, end)]
ext_emb = external_embedding["embedding"]
# do not need to replace
if not (input_ids == start).any():
return embedding
# start replace
s_idx, e_idx = 0, 0
while e_idx < len(input_ids):
if input_ids[e_idx] == start:
if e_idx != 0:
# add embedding do not need to replace
new_embedding.append(embedding[s_idx:e_idx])
# check if the next embedding need to replace is valid
actually_ids_to_replace = [
int(i) for i in input_ids[e_idx : e_idx + end - start]
]
assert actually_ids_to_replace == target_ids_to_replace, (
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
f"Expect '{target_ids_to_replace}' for embedding "
f"'{name}' but found '{actually_ids_to_replace}'."
)
new_embedding.append(ext_emb)
s_idx = e_idx + end - start
e_idx = s_idx + 1
else:
e_idx += 1
if e_idx == len(input_ids):
new_embedding.append(embedding[s_idx:e_idx])
return torch.cat(new_embedding, dim=0)
def forward(
self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
):
"""The forward function.
Args:
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
[LENGTH, ].
external_embeddings (Optional[List[dict]]): The external
embeddings. If not passed, only `self.external_embeddings`
will be used. Defaults to None.
input_ids: shape like [bz, LENGTH] or [LENGTH].
"""
assert input_ids.ndim in [1, 2]
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if external_embeddings is None and not self.external_embeddings:
return self.wrapped(input_ids)
input_ids_fwd = self.replace_input_ids(input_ids)
inputs_embeds = self.wrapped(input_ids_fwd)
vecs = []
if external_embeddings is None:
external_embeddings = []
elif isinstance(external_embeddings, dict):
external_embeddings = [external_embeddings]
embeddings = self.external_embeddings + external_embeddings
for input_id, embedding in zip(input_ids, inputs_embeds):
new_embedding = embedding
for external_embedding in embeddings:
new_embedding = self.replace_embeddings(
input_id, new_embedding, external_embedding
)
vecs.append(new_embedding)
return torch.stack(vecs)
def add_tokens(
tokenizer,
text_encoder,
placeholder_tokens: list,
initialize_tokens: list = None,
num_vectors_per_token: int = 1,
):
"""Add token for training.
# TODO: support add tokens as dict, then we can load pretrained tokens.
"""
if initialize_tokens is not None:
assert len(initialize_tokens) == len(
placeholder_tokens
), "placeholder_token should be the same length as initialize_token"
for ii in range(len(placeholder_tokens)):
tokenizer.add_placeholder_token(
placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
)
# text_encoder.set_embedding_layer()
embedding_layer = text_encoder.text_model.embeddings.token_embedding
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
embedding_layer
)
embedding_layer = text_encoder.text_model.embeddings.token_embedding
assert embedding_layer is not None, (
"Do not support get embedding layer for current text encoder. "
"Please check your configuration."
)
initialize_embedding = []
if initialize_tokens is not None:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
initialize_embedding.append(
temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
)
else:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer("a").input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
len_emb = temp_embedding.shape[0]
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
initialize_embedding.append(init_weight)
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
token_info_all = []
for ii in range(len(placeholder_tokens)):
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
token_info["embedding"] = initialize_embedding[ii]
token_info["trainable"] = True
token_info_all.append(token_info)
embedding_layer.add_embeddings(token_info_all)
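A minimal usage sketch of add_tokens, for orientation: it assumes a tokenizer wrapper that implements add_placeholder_token() and get_token_info() (required by the function above but not part of this diff) and a standard CLIP text encoder from transformers; the token names and vector count are illustrative only.

from transformers import CLIPTextModel

# Illustrative sketch: "P_obj" and num_vectors_per_token=10 are made-up values, and
# wrapped_tokenizer stands in for a tokenizer wrapper providing
# add_placeholder_token()/get_token_info(), which a plain CLIPTokenizer does not have.
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="text_encoder"
)
add_tokens(
    tokenizer=wrapped_tokenizer,
    text_encoder=text_encoder,
    placeholder_tokens=["P_obj"],
    initialize_tokens=["a"],
    num_vectors_per_token=10,
)
# After the call, text_encoder.text_model.embeddings.token_embedding is an
# EmbeddingLayerWithFixes whose newly added embeddings are marked trainable.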

View File

@ -3,9 +3,9 @@ import cv2
import torch import torch
from loguru import logger from loguru import logger
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config, ModelType from lama_cleaner.schema import Config, ModelType
@ -40,20 +40,18 @@ class SD(DiffusionInpaintModel):
model_kwargs["num_in_channels"] = 9 model_kwargs["num_in_channels"] = 9
self.model = StableDiffusionInpaintPipeline.from_single_file( self.model = StableDiffusionInpaintPipeline.from_single_file(
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs self.model_id_or_path, dtype=torch_dtype, **model_kwargs
) )
else: else:
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = handle_from_pretrained_exceptions(
self.model_id_or_path, StableDiffusionInpaintPipeline.from_pretrained,
revision="fp16" pretrained_model_name_or_path=self.model_id_or_path,
if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION variant="fp16",
else "main", dtype=torch_dtype,
torch_dtype=torch_dtype,
**model_kwargs, **model_kwargs,
) )
if kwargs.get("cpu_offload", False) and use_gpu: if kwargs.get("cpu_offload", False) and use_gpu:
# TODO: gpu_id
logger.info("Enable sequential cpu offload") logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0) self.model.enable_sequential_cpu_offload(gpu_id=0)
else: else:
@ -98,20 +96,20 @@ class SD(DiffusionInpaintModel):
class SD15(SD): class SD15(SD):
name = "sd1.5" name = "runwayml/stable-diffusion-inpainting"
model_id_or_path = "runwayml/stable-diffusion-inpainting" model_id_or_path = "runwayml/stable-diffusion-inpainting"
class Anything4(SD): class Anything4(SD):
name = "anything4" name = "Sanster/anything-4.0-inpainting"
model_id_or_path = "Sanster/anything-4.0-inpainting" model_id_or_path = "Sanster/anything-4.0-inpainting"
class RealisticVision14(SD): class RealisticVision14(SD):
name = "realisticVision1.4" name = "Sanster/Realistic_Vision_V1.4-inpainting"
model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting" model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"
class SD2(SD): class SD2(SD):
name = "sd2" name = "stabilityai/stable-diffusion-2-inpainting"
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"

View File

@ -8,11 +8,12 @@ from diffusers import AutoencoderKL
from loguru import logger from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config, ModelType from lama_cleaner.schema import Config, ModelType
class SDXL(DiffusionInpaintModel): class SDXL(DiffusionInpaintModel):
name = "sdxl" name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdxl" lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
@ -34,18 +35,19 @@ class SDXL(DiffusionInpaintModel):
if os.path.isfile(self.model_id_or_path): if os.path.isfile(self.model_id_or_path):
self.model = StableDiffusionXLInpaintPipeline.from_single_file( self.model = StableDiffusionXLInpaintPipeline.from_single_file(
self.model_id_or_path, self.model_id_or_path,
torch_dtype=torch_dtype, dtype=torch_dtype,
num_in_channels=num_in_channels, num_in_channels=num_in_channels,
) )
else: else:
vae = AutoencoderKL.from_pretrained( vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
) )
self.model = StableDiffusionXLInpaintPipeline.from_pretrained( self.model = handle_from_pretrained_exceptions(
self.model_id_or_path, StableDiffusionXLInpaintPipeline.from_pretrained,
revision="main", pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
vae=vae, vae=vae,
variant="fp16",
) )
if kwargs.get("cpu_offload", False) and use_gpu: if kwargs.get("cpu_offload", False) and use_gpu:

View File

@ -1,6 +1,7 @@
import gc import gc
import math import math
import random import random
import traceback
from typing import Any from typing import Any
import torch import torch
@ -16,8 +17,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
UniPCMultistepScheduler, UniPCMultistepScheduler,
LCMScheduler LCMScheduler,
) )
from huggingface_hub.utils import RevisionNotFoundError
from loguru import logger
from requests import HTTPError
from lama_cleaner.schema import SDSampler from lama_cleaner.schema import SDSampler
from torch import conv2d, conv_transpose2d from torch import conv2d, conv_transpose2d
@ -944,3 +948,20 @@ def get_scheduler(sd_sampler, scheduler_config):
return LCMScheduler.from_config(scheduler_config) return LCMScheduler.from_config(scheduler_config)
else: else:
raise ValueError(sd_sampler) raise ValueError(sd_sampler)
def handle_from_pretrained_exceptions(func, **kwargs):
try:
return func(**kwargs)
    except ValueError as e:
        # fall back to other revisions when the requested fp16 variant does not exist
        if "You are trying to load the model files of the `variant=fp16`" in str(e):
            logger.info("variant=fp16 not found, try revision=fp16")
            return func(**{**kwargs, "variant": None, "revision": "fp16"})
        # re-raise unrecognized errors instead of silently returning None
        raise e
    except OSError as e:
        previous_traceback = traceback.format_exc()
        if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
            logger.info("revision=fp16 not found, try revision=main")
            return func(**{**kwargs, "variant": None, "revision": "main"})
        raise e
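A minimal sketch of how this helper is meant to be called, mirroring the SD/SDXL loading code above; the model id and dtype here are illustrative.

import torch
from diffusers import StableDiffusionInpaintPipeline

# Try variant="fp16" first; on failure the helper retries with revision="fp16",
# then with revision="main".
pipe = handle_from_pretrained_exceptions(
    StableDiffusionInpaintPipeline.from_pretrained,
    pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=torch.float16,
)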

lama_cleaner/model_info.py (new file, 100 lines)
View File

@ -0,0 +1,100 @@
from enum import Enum
from typing import List
from pydantic import computed_field, BaseModel
from lama_cleaner.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
from lama_cleaner.model import InstructPix2Pix, Kandinsky22, PowerPaint, SD2
from lama_cleaner.schema import ModelType
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
InstructPix2Pix.name,
Kandinsky22.name,
PowerPaint.name,
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if self.name in [SD2.name]:
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
if self.name == PowerPaint.name:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [Kandinsky22.name, PowerPaint.name]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [PowerPaint.name]
@computed_field
@property
def support_freeu(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [InstructPix2Pix.name]
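As a quick illustration of the computed fields above, a ModelInfo for a diffusers SD inpainting model reports prompt and ControlNet support as follows; the name and path are illustrative, and the controlnets list comes from SD_CONTROLNET_CHOICES.

info = ModelInfo(
    name="runwayml/stable-diffusion-inpainting",
    path="runwayml/stable-diffusion-inpainting",
    model_type=ModelType.DIFFUSERS_SD_INPAINT,
)
assert info.need_prompt and info.support_controlnet and info.support_lcm_lora
print(info.controlnets)  # SD_CONTROLNET_CHOICES for non-SD2 models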

View File

@ -7,7 +7,8 @@ from lama_cleaner.download import scan_models
from lama_cleaner.helper import switch_mps_device from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model import models, ControlNet, SD, SDXL from lama_cleaner.model import models, ControlNet, SD, SDXL
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
from lama_cleaner.schema import Config, ModelInfo, ModelType from lama_cleaner.model_info import ModelInfo, ModelType
from lama_cleaner.schema import Config
class ModelManager: class ModelManager:
@ -18,13 +19,20 @@ class ModelManager:
self.available_models: Dict[str, ModelInfo] = {} self.available_models: Dict[str, ModelInfo] = {}
self.scan_models() self.scan_models()
self.sd_controlnet = False self.enable_controlnet = kwargs.get("enable_controlnet", False)
self.sd_controlnet_method = "" controlnet_method = kwargs.get("controlnet_method", None)
if (
controlnet_method is None
and name in self.available_models
and self.available_models[name].support_controlnet
):
controlnet_method = self.available_models[name].controlnets[0]
self.controlnet_method = controlnet_method
self.model = self.init_model(name, device, **kwargs) self.model = self.init_model(name, device, **kwargs)
@property @property
def current_model(self) -> Dict: def current_model(self) -> Dict:
return self.available_models[name].model_dump() return self.available_models[self.name].model_dump()
def init_model(self, name: str, device, **kwargs): def init_model(self, name: str, device, **kwargs):
logger.info(f"Loading model: {name}") logger.info(f"Loading model: {name}")
@ -35,15 +43,14 @@ class ModelManager:
kwargs = { kwargs = {
**kwargs, **kwargs,
"model_info": model_info, "model_info": model_info,
"sd_controlnet": self.sd_controlnet, "enable_controlnet": self.enable_controlnet,
"sd_controlnet_method": self.sd_controlnet_method, "controlnet_method": self.controlnet_method,
} }
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]: if model_info.support_controlnet and self.enable_controlnet:
return models[name](device, **kwargs)
if self.sd_controlnet:
return ControlNet(device, **kwargs) return ControlNet(device, **kwargs)
elif model_info.name in models:
return models[name](device, **kwargs)
else: else:
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD_INPAINT,
@ -75,15 +82,15 @@ class ModelManager:
return return
old_name = self.name old_name = self.name
old_sd_controlnet_method = self.sd_controlnet_method old_controlnet_method = self.controlnet_method
self.name = new_name self.name = new_name
if ( if (
self.available_models[new_name].support_controlnet self.available_models[new_name].support_controlnet
and self.sd_controlnet_method and self.controlnet_method
not in self.available_models[new_name].controlnets not in self.available_models[new_name].controlnets
): ):
self.sd_controlnet_method = self.available_models[new_name].controlnets[0] self.controlnet_method = self.available_models[new_name].controlnets[0]
try: try:
# TODO: enable/disable controlnet without reload model # TODO: enable/disable controlnet without reload model
del self.model del self.model
@ -94,7 +101,7 @@ class ModelManager:
) )
except Exception as e: except Exception as e:
self.name = old_name self.name = old_name
self.sd_controlnet_method = old_sd_controlnet_method self.controlnet_method = old_controlnet_method
logger.info(f"Switch model from {old_name} to {new_name} failed, rollback") logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
self.model = self.init_model( self.model = self.init_model(
old_name, switch_mps_device(old_name, self.device), **self.kwargs old_name, switch_mps_device(old_name, self.device), **self.kwargs
@ -106,24 +113,24 @@ class ModelManager:
return return
if ( if (
self.sd_controlnet self.enable_controlnet
and config.controlnet_method and config.controlnet_method
and self.sd_controlnet_method != config.controlnet_method and self.controlnet_method != config.controlnet_method
): ):
old_sd_controlnet_method = self.sd_controlnet_method old_controlnet_method = self.controlnet_method
self.sd_controlnet_method = config.controlnet_method self.controlnet_method = config.controlnet_method
self.model.switch_controlnet_method(config.controlnet_method) self.model.switch_controlnet_method(config.controlnet_method)
logger.info( logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}" f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
) )
elif self.sd_controlnet != config.controlnet_enabled: elif self.enable_controlnet != config.enable_controlnet:
self.sd_controlnet = config.controlnet_enabled self.enable_controlnet = config.enable_controlnet
self.sd_controlnet_method = config.controlnet_method self.controlnet_method = config.controlnet_method
self.model = self.init_model( self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs self.name, switch_mps_device(self.name, self.device), **self.kwargs
) )
if not config.controlnet_enabled: if not config.enable_controlnet:
logger.info(f"Disable controlnet") logger.info(f"Disable controlnet")
else: else:
logger.info(f"Enable controlnet: {config.controlnet_method}") logger.info(f"Enable controlnet: {config.controlnet_method}")

View File

@ -1,19 +1,8 @@
from typing import Optional, List
from enum import Enum from enum import Enum
from typing import Optional
from PIL.Image import Image from PIL.Image import Image
from pydantic import BaseModel, computed_field from pydantic import BaseModel
from lama_cleaner.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
class ModelType(str, Enum): class ModelType(str, Enum):
@ -25,103 +14,6 @@ class ModelType(str, Enum):
DIFFUSERS_OTHER = "diffusers_other" DIFFUSERS_OTHER = "diffusers_other"
FREEU_DEFAULT_CONFIGS = {
ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
}
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
"timbrooks/instruct-pix2pix",
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_strength(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_outpainting(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
]
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return (
self.model_type
in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
or "timbrooks/instruct-pix2pix" in self.name
)
class HDStrategy(str, Enum): class HDStrategy(str, Enum):
# Use original image size # Use original image size
ORIGINAL = "Original" ORIGINAL = "Original"
@ -157,6 +49,13 @@ class FREEUConfig(BaseModel):
b2: float = 1.4 b2: float = 1.4
class PowerPaintTask(str, Enum):
text_guided = "text-guided"
shape_guided = "shape-guided"
object_remove = "object-remove"
outpainting = "outpainting"
class Config(BaseModel): class Config(BaseModel):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -239,6 +138,11 @@ class Config(BaseModel):
p2p_image_guidance_scale: float = 1.5 p2p_image_guidance_scale: float = 1.5
# ControlNet # ControlNet
controlnet_enabled: bool = False enable_controlnet: bool = False
controlnet_conditioning_scale: float = 0.4 controlnet_conditioning_scale: float = 0.4
controlnet_method: str = "control_v11p_sd15_canny" controlnet_method: str = "lllyasviel/control_v11p_sd15_canny"
# PowerPaint
powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided
# control the fitting degree of the generated objects to the mask shape.
fitting_degree: float = 1.0
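For reference, the renamed ControlNet fields and the new PowerPaint fields look like this when passed as request parameters; the values are illustrative only, and the remaining Config fields keep their usual values.

# Illustrative request fields introduced or renamed by this change.
powerpaint_fields = dict(
    enable_controlnet=True,                                  # was: controlnet_enabled
    controlnet_method="lllyasviel/control_v11p_sd15_canny",  # default is now a full HF repo id
    controlnet_conditioning_scale=0.4,
    powerpaint_task=PowerPaintTask.object_remove,
    fitting_degree=1.0,  # how tightly generated objects fit the mask shape
)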

View File

@ -63,6 +63,7 @@ from lama_cleaner.helper import (
numpy_to_bytes, numpy_to_bytes,
resize_max_size, resize_max_size,
pil_to_bytes, pil_to_bytes,
is_mac,
) )
NUM_THREADS = str(multiprocessing.cpu_count()) NUM_THREADS = str(multiprocessing.cpu_count())
@ -285,9 +286,10 @@ def process():
cv2_radius=form["cv2Radius"], cv2_radius=form["cv2Radius"],
paint_by_example_example_image=paint_by_example_example_image, paint_by_example_example_image=paint_by_example_example_image,
p2p_image_guidance_scale=form["p2pImageGuidanceScale"], p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
controlnet_enabled=form["controlnet_enabled"], enable_controlnet=form["enable_controlnet"],
controlnet_conditioning_scale=form["controlnet_conditioning_scale"], controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
controlnet_method=form["controlnet_method"], controlnet_method=form["controlnet_method"],
powerpaint_task=form["powerpaintTask"],
) )
if config.sd_seed == -1: if config.sd_seed == -1:
@ -305,6 +307,8 @@ def process():
if "CUDA out of memory. " in str(e): if "CUDA out of memory. " in str(e):
# NOTE: the string may change? # NOTE: the string may change?
return "CUDA out of memory", 500 return "CUDA out of memory", 500
elif "Invalid buffer size" in str(e) and is_mac():
return "Out of memory", 500
else: else:
logger.exception(e) logger.exception(e)
return f"{str(e)}", 500 return f"{str(e)}", 500
@ -423,8 +427,8 @@ def get_server_config():
"plugins": list(global_config.plugins.keys()), "plugins": list(global_config.plugins.keys()),
"enableFileManager": global_config.enable_file_manager, "enableFileManager": global_config.enable_file_manager,
"enableAutoSaving": global_config.enable_auto_saving, "enableAutoSaving": global_config.enable_auto_saving,
"enableControlnet": global_config.model_manager.sd_controlnet, "enableControlnet": global_config.model_manager.enable_controlnet,
"controlnetMethod": global_config.model_manager.sd_controlnet_method, "controlnetMethod": global_config.model_manager.controlnet_method,
"disableModelSwitch": global_config.disable_model_switch, "disableModelSwitch": global_config.disable_model_switch,
"isDesktop": global_config.is_desktop, "isDesktop": global_config.is_desktop,
}, 200 }, 200
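With the renamed manager attributes, the /server_config response keeps the same keys; a sketch of the returned payload is shown below with illustrative values.

# Illustrative /server_config payload; actual values depend on the running server.
server_config = {
    "plugins": [],
    "enableFileManager": False,
    "enableAutoSaving": False,
    "enableControlnet": False,
    "controlnetMethod": "lllyasviel/control_v11p_sd15_canny",
    "disableModelSwitch": False,
    "isDesktop": False,
}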

View File

View File

@ -15,8 +15,8 @@ def save_config(
port, port,
model, model,
sd_local_model_path, sd_local_model_path,
sd_controlnet, enable_controlnet,
sd_controlnet_method, controlnet_method,
device, device,
gui, gui,
no_gui_auto_close, no_gui_auto_close,
@ -176,13 +176,13 @@ def main(config_file: str):
sd_local_model_path = gr.Textbox( sd_local_model_path = gr.Textbox(
init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}" init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}"
) )
sd_controlnet = gr.Checkbox( enable_controlnet = gr.Checkbox(
init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}" init_config.enable_controlnet, label=f"{SD_CONTROLNET_HELP}"
) )
sd_controlnet_method = gr.Radio( controlnet_method = gr.Radio(
SD_CONTROLNET_CHOICES, SD_CONTROLNET_CHOICES,
label="ControlNet method", label="ControlNet method",
value=init_config.sd_controlnet_method, value=init_config.controlnet_method,
) )
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox( cpu_offload = gr.Checkbox(
@ -205,8 +205,8 @@ def main(config_file: str):
port, port,
model, model,
sd_local_model_path, sd_local_model_path,
sd_controlnet, enable_controlnet,
sd_controlnet_method, controlnet_method,
device, device,
gui, gui,
no_gui_auto_close, no_gui_auto_close,

View File

@ -1,5 +1,5 @@
import { EXTENDER_ALL, EXTENDER_X, EXTENDER_Y } from "@/lib/const"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { ExtenderDirection } from "@/lib/types"
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils"
import React, { useEffect, useState } from "react" import React, { useEffect, useState } from "react"
import { twMerge } from "tailwind-merge" import { twMerge } from "tailwind-merge"
@ -107,7 +107,7 @@ const Extender = (props: Props) => {
const newY = evData.initY + offsetY const newY = evData.initY + offsetY
let clampedY = newY let clampedY = newY
let clampedHeight = newHeight let clampedHeight = newHeight
if (extenderDirection === EXTENDER_ALL) { if (extenderDirection === ExtenderDirection.xy) {
if (clampedY > 0) { if (clampedY > 0) {
clampedY = 0 clampedY = 0
clampedHeight = evData.initHeight - Math.abs(evData.initY) clampedHeight = evData.initHeight - Math.abs(evData.initY)
@ -124,7 +124,7 @@ const Extender = (props: Props) => {
const moveBottom = () => { const moveBottom = () => {
const newHeight = evData.initHeight + offsetY const newHeight = evData.initHeight + offsetY
let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight)
if (extenderDirection === EXTENDER_ALL) { if (extenderDirection === ExtenderDirection.xy) {
if (clampedHeight < Math.abs(clampedY) + imageHeight) { if (clampedHeight < Math.abs(clampedY) + imageHeight) {
clampedHeight = Math.abs(clampedY) + imageHeight clampedHeight = Math.abs(clampedY) + imageHeight
} }
@ -138,7 +138,7 @@ const Extender = (props: Props) => {
const newX = evData.initX + offsetX const newX = evData.initX + offsetX
let clampedX = newX let clampedX = newX
let clampedWidth = newWidth let clampedWidth = newWidth
if (extenderDirection === EXTENDER_ALL) { if (extenderDirection === ExtenderDirection.xy) {
if (clampedX > 0) { if (clampedX > 0) {
clampedX = 0 clampedX = 0
clampedWidth = evData.initWidth - Math.abs(evData.initX) clampedWidth = evData.initWidth - Math.abs(evData.initX)
@ -155,7 +155,7 @@ const Extender = (props: Props) => {
const moveRight = () => { const moveRight = () => {
const newWidth = evData.initWidth + offsetX const newWidth = evData.initWidth + offsetX
let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth)
if (extenderDirection === EXTENDER_ALL) { if (extenderDirection === ExtenderDirection.xy) {
if (clampedWidth < Math.abs(clampedX) + imageWdith) { if (clampedWidth < Math.abs(clampedX) + imageWdith) {
clampedWidth = Math.abs(clampedX) + imageWdith clampedWidth = Math.abs(clampedX) + imageWdith
} }
@ -296,7 +296,9 @@ const Extender = (props: Props) => {
onPointerDown={onCropPointerDown} onPointerDown={onCropPointerDown}
className="absolute top-0 h-full w-full" className="absolute top-0 h-full w-full"
> >
{[EXTENDER_Y, EXTENDER_ALL].includes(extenderDirection) ? ( {[ExtenderDirection.y, ExtenderDirection.xy].includes(
extenderDirection
) ? (
<> <>
<div <div
className="absolute pointer-events-auto top-0 left-0 w-full cursor-ns-resize h-[12px] mt-[-6px]" className="absolute pointer-events-auto top-0 left-0 w-full cursor-ns-resize h-[12px] mt-[-6px]"
@ -313,7 +315,9 @@ const Extender = (props: Props) => {
<></> <></>
)} )}
{[EXTENDER_X, EXTENDER_ALL].includes(extenderDirection) ? ( {[ExtenderDirection.x, ExtenderDirection.xy].includes(
extenderDirection
) ? (
<> <>
<div <div
className="absolute pointer-events-auto top-0 right-0 h-full cursor-ew-resize w-[12px] mr-[-6px]" className="absolute pointer-events-auto top-0 right-0 h-full cursor-ew-resize w-[12px] mr-[-6px]"
@ -330,7 +334,7 @@ const Extender = (props: Props) => {
<></> <></>
)} )}
{extenderDirection === EXTENDER_ALL ? ( {extenderDirection === ExtenderDirection.xy ? (
<> <>
{createDragHandle("cursor-nw-resize", "top", "left")} {createDragHandle("cursor-nw-resize", "top", "left")}
{createDragHandle("cursor-ne-resize", "top", "right")} {createDragHandle("cursor-ne-resize", "top", "right")}

View File

@ -36,9 +36,9 @@ const PromptInput = () => {
updateSettings({ prompt: target.value }) updateSettings({ prompt: target.value })
} }
const handleRepaintClick = async () => { const handleRepaintClick = () => {
if (prompt.length !== 0 && !isProcessing) { if (!isProcessing) {
await runInpainting() runInpainting()
} }
} }
@ -69,7 +69,7 @@ const PromptInput = () => {
<Button <Button
size="sm" size="sm"
onClick={handleRepaintClick} onClick={handleRepaintClick}
disabled={prompt.length === 0 || isProcessing} disabled={isProcessing}
onMouseEnter={onMouseEnter} onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave} onMouseLeave={onMouseLeave}
> >

View File

@ -1,7 +1,7 @@
import { IconButton } from "@/components/ui/button" import { IconButton } from "@/components/ui/button"
import { useToggle } from "@uidotdev/usehooks" import { useToggle } from "@uidotdev/usehooks"
import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog" import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog"
import { HelpCircle, Settings } from "lucide-react" import { Settings } from "lucide-react"
import { zodResolver } from "@hookform/resolvers/zod" import { zodResolver } from "@hookform/resolvers/zod"
import { useForm } from "react-hook-form" import { useForm } from "react-hook-form"
import * as z from "zod" import * as z from "zod"
@ -179,12 +179,12 @@ export function SettingsDialog() {
<div key={info.name} onClick={() => onModelSelect(info)}> <div key={info.name} onClick={() => onModelSelect(info)}>
<div <div
className={cn([ className={cn([
info.name === model.name ? "bg-muted " : "hover:bg-muted", info.name === model.name ? "bg-muted" : "hover:bg-muted",
"rounded-md px-2 py-1 my-1", "rounded-md px-2 py-1 my-1",
"cursor-default", "cursor-default",
])} ])}
> >
<div className="text-base max-w-sm">{info.name}</div> <div className="text-base">{info.name}</div>
</div> </div>
<Separator /> <Separator />
</div> </div>
@ -223,13 +223,13 @@ export function SettingsDialog() {
<div className="space-y-4 rounded-md"> <div className="space-y-4 rounded-md">
<div className="flex gap-1 items-center justify-start"> <div className="flex gap-1 items-center justify-start">
<div className="font-medium">Available models</div> <div className="font-medium">Available models</div>
{/* <IconButton tooltip="How to download new model" asChild> {/* <IconButton tooltip="How to download new model">
<HelpCircle size={16} strokeWidth={1.5} className="opacity-50" /> <Info size={20} strokeWidth={2} className="opacity-50" />
</IconButton> */} </IconButton> */}
</div> </div>
<Tabs defaultValue={defaultTab}> <Tabs defaultValue={defaultTab}>
<TabsList> <TabsList>
<TabsTrigger value={MODEL_TYPE_INPAINT}>Erase</TabsTrigger> <TabsTrigger value={MODEL_TYPE_INPAINT}>Inpaint</TabsTrigger>
<TabsTrigger value={MODEL_TYPE_DIFFUSERS_SD}> <TabsTrigger value={MODEL_TYPE_DIFFUSERS_SD}>
Stable Diffusion Stable Diffusion
</TabsTrigger> </TabsTrigger>

View File

@ -11,21 +11,14 @@ import {
SelectValue, SelectValue,
} from "../ui/select" } from "../ui/select"
import { Textarea } from "../ui/textarea" import { Textarea } from "../ui/textarea"
import { SDSampler } from "@/lib/types" import { ExtenderDirection, PowerPaintTask, SDSampler } from "@/lib/types"
import { Separator } from "../ui/separator" import { Separator } from "../ui/separator"
import { Move, MoveHorizontal, MoveVertical, Upload } from "lucide-react"
import { Button, ImageUploadButton } from "../ui/button" import { Button, ImageUploadButton } from "../ui/button"
import { Slider } from "../ui/slider" import { Slider } from "../ui/slider"
import { useImage } from "@/hooks/useImage" import { useImage } from "@/hooks/useImage"
import { import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
EXTENDER_ALL,
EXTENDER_X,
EXTENDER_Y,
INSTRUCT_PIX2PIX,
PAINT_BY_EXAMPLE,
} from "@/lib/const"
import { Tabs, TabsContent, TabsList, TabsTrigger } from "../ui/tabs"
import { RowContainer, LabelTitle } from "./LabelTitle" import { RowContainer, LabelTitle } from "./LabelTitle"
import { Upload } from "lucide-react"
const ExtenderButton = ({ const ExtenderButton = ({
text, text,
@ -38,8 +31,7 @@ const ExtenderButton = ({
return ( return (
<Button <Button
variant="outline" variant="outline"
size="sm" className="p-1 h-7"
className="p-1"
disabled={!showExtender} disabled={!showExtender}
onClick={onClick} onClick={onClick}
> >
@ -129,6 +121,7 @@ const DiffusionOptions = () => {
<div className="pr-2"> <div className="pr-2">
<Select <Select
defaultValue={settings.controlnetMethod}
value={settings.controlnetMethod} value={settings.controlnetMethod}
onValueChange={(value) => { onValueChange={(value) => {
updateSettings({ controlnetMethod: value }) updateSettings({ controlnetMethod: value })
@ -467,96 +460,104 @@ const DiffusionOptions = () => {
/> />
</RowContainer> </RowContainer>
<Tabs <RowContainer>
<Select
defaultValue={settings.extenderDirection} defaultValue={settings.extenderDirection}
onValueChange={(value) => updateExtenderDirection(value)} value={settings.extenderDirection}
className="flex flex-col justify-center items-center" onValueChange={(value) => {
updateExtenderDirection(value as ExtenderDirection)
}}
> >
<TabsList className="w-[140px] mb-2"> <SelectTrigger
<TabsTrigger value={EXTENDER_X} disabled={!settings.showExtender}> className="w-[65px] h-7"
<MoveHorizontal size={20} strokeWidth={1} />
</TabsTrigger>
<TabsTrigger value={EXTENDER_Y} disabled={!settings.showExtender}>
<MoveVertical size={20} strokeWidth={1} />
</TabsTrigger>
<TabsTrigger
value={EXTENDER_ALL}
disabled={!settings.showExtender} disabled={!settings.showExtender}
> >
<Move size={20} strokeWidth={1} /> <SelectValue placeholder="Select axis" />
</TabsTrigger> </SelectTrigger>
</TabsList> <SelectContent align="end">
<SelectGroup>
{Object.values(ExtenderDirection).map((v) => (
<SelectItem key={v} value={v}>
{v}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
<TabsContent <div className="flex gap-1 justify-center mt-0">
value={EXTENDER_X}
className="flex gap-2 justify-center mt-0"
>
<ExtenderButton <ExtenderButton
text="1.25x" text="1.25x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.25)} onClick={() =>
updateExtenderByBuiltIn(settings.extenderDirection, 1.25)
}
/> />
<ExtenderButton <ExtenderButton
text="1.5x" text="1.5x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.5)} onClick={() =>
updateExtenderByBuiltIn(settings.extenderDirection, 1.5)
}
/> />
<ExtenderButton <ExtenderButton
text="1.75x" text="1.75x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.75)} onClick={() =>
updateExtenderByBuiltIn(settings.extenderDirection, 1.75)
}
/> />
<ExtenderButton <ExtenderButton
text="2.0x" text="2.0x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 2.0)} onClick={() =>
updateExtenderByBuiltIn(settings.extenderDirection, 2.0)
}
/> />
</TabsContent> </div>
<TabsContent </RowContainer>
value={EXTENDER_Y}
className="flex gap-2 justify-center mt-0"
>
<ExtenderButton
text="1.25x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.25)}
/>
<ExtenderButton
text="1.5x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.5)}
/>
<ExtenderButton
text="1.75x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.75)}
/>
<ExtenderButton
text="2.0x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 2.0)}
/>
</TabsContent>
<TabsContent
value={EXTENDER_ALL}
className="flex gap-2 justify-center mt-0"
>
<ExtenderButton
text="1.25x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.25)}
/>
<ExtenderButton
text="1.5x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.5)}
/>
<ExtenderButton
text="1.75x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.75)}
/>
<ExtenderButton
text="2.0x"
onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 2.0)}
/>
</TabsContent>
</Tabs>
</div> </div>
<Separator /> <Separator />
</> </>
) )
} }
const renderPowerPaintTaskType = () => {
if (settings.model.name !== POWERPAINT) {
return null
}
return (
<RowContainer>
<LabelTitle
text="Task"
toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
/>
<Select
defaultValue={settings.powerpaintTask}
value={settings.powerpaintTask}
onValueChange={(value: PowerPaintTask) => {
updateSettings({ powerpaintTask: value })
}}
disabled={settings.showExtender}
>
<SelectTrigger className="w-[140px]">
<SelectValue placeholder="Select task" />
</SelectTrigger>
<SelectContent align="end">
<SelectGroup>
{[
PowerPaintTask.text_guided,
PowerPaintTask.object_remove,
PowerPaintTask.shape_guided,
].map((task) => (
<SelectItem key={task} value={task}>
{task}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</RowContainer>
)
}
return ( return (
<div className="flex flex-col gap-4 mt-4"> <div className="flex flex-col gap-4 mt-4">
<RowContainer> <RowContainer>
@ -577,6 +578,7 @@ const DiffusionOptions = () => {
</RowContainer> </RowContainer>
{renderExtender()} {renderExtender()}
{renderPowerPaintTaskType()}
<div className="flex flex-col gap-1"> <div className="flex flex-col gap-1">
<LabelTitle <LabelTitle
@ -642,20 +644,20 @@ const DiffusionOptions = () => {
<RowContainer> <RowContainer>
<LabelTitle text="Sampler" /> <LabelTitle text="Sampler" />
<Select <Select
value={settings.sdSampler as string} defaultValue={settings.sdSampler}
onValueChange={(value) => { value={settings.sdSampler}
const sampler = value as SDSampler onValueChange={(value: SDSampler) => {
updateSettings({ sdSampler: sampler }) updateSettings({ sdSampler: value })
}} }}
> >
<SelectTrigger className="w-[100px]"> <SelectTrigger className="w-[120px]">
<SelectValue placeholder="Select sampler" /> <SelectValue placeholder="Select sampler" />
</SelectTrigger> </SelectTrigger>
<SelectContent align="end"> <SelectContent align="end">
<SelectGroup> <SelectGroup>
{Object.values(SDSampler).map((sampler) => ( {Object.values(SDSampler).map((sampler) => (
<SelectItem key={sampler as string} value={sampler as string}> <SelectItem key={sampler} value={sampler}>
{sampler as string} {sampler}
</SelectItem> </SelectItem>
))} ))}
</SelectGroup> </SelectGroup>
@ -707,9 +709,9 @@ const DiffusionOptions = () => {
<RowContainer> <RowContainer>
<Slider <Slider
className="w-[180px]" className="w-[180px]"
defaultValue={[5]} defaultValue={[settings.sdMaskBlur]}
min={0} min={0}
max={35} max={96}
step={1} step={1}
value={[Math.floor(settings.sdMaskBlur)]} value={[Math.floor(settings.sdMaskBlur)]}
onValueChange={(vals) => updateSettings({ sdMaskBlur: vals[0] })} onValueChange={(vals) => updateSettings({ sdMaskBlur: vals[0] })}

View File

@ -1,4 +1,4 @@
import { ModelInfo, Rect } from "@/lib/types" import { ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
import { Settings } from "@/lib/states" import { Settings } from "@/lib/states"
import { srcToFile } from "@/lib/utils" import { srcToFile } from "@/lib/utils"
import axios from "axios" import axios from "axios"
@ -22,7 +22,6 @@ export default async function inpaint(
const fd = new FormData() const fd = new FormData()
fd.append("image", imageFile) fd.append("image", imageFile)
fd.append("mask", mask) fd.append("mask", mask)
fd.append("ldmSteps", settings.ldmSteps.toString()) fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString()) fd.append("ldmSampler", settings.ldmSampler.toString())
fd.append("zitsWireframe", settings.zitsWireframe.toString()) fd.append("zitsWireframe", settings.zitsWireframe.toString())
@ -51,6 +50,7 @@ export default async function inpaint(
fd.append("sdSteps", settings.sdSteps.toString()) fd.append("sdSteps", settings.sdSteps.toString())
fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString()) fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString())
fd.append("sdSampler", settings.sdSampler.toString()) fd.append("sdSampler", settings.sdSampler.toString())
if (settings.seedFixed) { if (settings.seedFixed) {
fd.append("sdSeed", settings.seed.toString()) fd.append("sdSeed", settings.seed.toString())
} else { } else {
@ -76,13 +76,20 @@ export default async function inpaint(
fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString()) fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString())
// ControlNet // ControlNet
fd.append("controlnet_enabled", settings.enableControlnet.toString()) fd.append("enable_controlnet", settings.enableControlnet.toString())
fd.append( fd.append(
"controlnet_conditioning_scale", "controlnet_conditioning_scale",
settings.controlnetConditioningScale.toString() settings.controlnetConditioningScale.toString()
) )
fd.append("controlnet_method", settings.controlnetMethod.toString()) fd.append("controlnet_method", settings.controlnetMethod.toString())
// PowerPaint
if (settings.showExtender) {
fd.append("powerpaintTask", PowerPaintTask.outpainting)
} else {
fd.append("powerpaintTask", settings.powerpaintTask)
}
try { try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, { const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: "POST", method: "POST",

View File

@ -8,14 +8,13 @@ export const MODEL_TYPE_DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
export const MODEL_TYPE_OTHER = "diffusers_other" export const MODEL_TYPE_OTHER = "diffusers_other"
export const BRUSH_COLOR = "#ffcc00bb" export const BRUSH_COLOR = "#ffcc00bb"
export const EXTENDER_X = "extender_x"
export const EXTENDER_Y = "extender_y"
export const EXTENDER_ALL = "extender_all"
export const LDM = "ldm" export const LDM = "ldm"
export const CV2 = "cv2" export const CV2 = "cv2"
export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example" export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example"
export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix" export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"
export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint" export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
export const POWERPAINT = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
export const DEFAULT_NEGATIVE_PROMPT = export const DEFAULT_NEGATIVE_PROMPT =
"out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature" "out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"

View File

@ -5,6 +5,7 @@ import { castDraft } from "immer"
import { createWithEqualityFn } from "zustand/traditional" import { createWithEqualityFn } from "zustand/traditional"
import { import {
CV2Flag, CV2Flag,
ExtenderDirection,
FreeuConfig, FreeuConfig,
LDMSampler, LDMSampler,
Line, Line,
@ -12,6 +13,7 @@ import {
ModelInfo, ModelInfo,
PluginParams, PluginParams,
Point, Point,
PowerPaintTask,
SDSampler, SDSampler,
Size, Size,
SortBy, SortBy,
@ -21,9 +23,6 @@ import {
BRUSH_COLOR, BRUSH_COLOR,
DEFAULT_BRUSH_SIZE, DEFAULT_BRUSH_SIZE,
DEFAULT_NEGATIVE_PROMPT, DEFAULT_NEGATIVE_PROMPT,
EXTENDER_ALL,
EXTENDER_X,
EXTENDER_Y,
MODEL_TYPE_INPAINT, MODEL_TYPE_INPAINT,
PAINT_BY_EXAMPLE, PAINT_BY_EXAMPLE,
} from "./const" } from "./const"
@ -60,7 +59,7 @@ export type Settings = {
enableUploadMask: boolean enableUploadMask: boolean
showCropper: boolean showCropper: boolean
showExtender: boolean showExtender: boolean
extenderDirection: string extenderDirection: ExtenderDirection
// For LDM // For LDM
ldmSteps: number ldmSteps: number
@ -99,6 +98,9 @@ export type Settings = {
enableLCMLora: boolean enableLCMLora: boolean
enableFreeu: boolean enableFreeu: boolean
freeuConfig: FreeuConfig freeuConfig: FreeuConfig
// PowerPaint
powerpaintTask: PowerPaintTask
} }
type ServerConfig = { type ServerConfig = {
@ -178,9 +180,9 @@ type AppAction = {
setExtenderWidth: (newValue: number) => void setExtenderWidth: (newValue: number) => void
setExtenderHeight: (newValue: number) => void setExtenderHeight: (newValue: number) => void
setIsCropperExtenderResizing: (newValue: boolean) => void setIsCropperExtenderResizing: (newValue: boolean) => void
updateExtenderDirection: (newValue: string) => void updateExtenderDirection: (newValue: ExtenderDirection) => void
resetExtender: (width: number, height: number) => void resetExtender: (width: number, height: number) => void
updateExtenderByBuiltIn: (direction: string, scale: number) => void updateExtenderByBuiltIn: (direction: ExtenderDirection, scale: number) => void
setServerConfig: (newValue: ServerConfig) => void setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void setSeed: (newValue: number) => void
@ -296,7 +298,7 @@ const defaultValues: AppState = {
enableControlnet: false, enableControlnet: false,
showCropper: false, showCropper: false,
showExtender: false, showExtender: false,
extenderDirection: EXTENDER_ALL, extenderDirection: ExtenderDirection.xy,
enableDownloadMask: false, enableDownloadMask: false,
enableManualInpainting: false, enableManualInpainting: false,
enableUploadMask: false, enableUploadMask: false,
@ -309,7 +311,7 @@ const defaultValues: AppState = {
negativePrompt: DEFAULT_NEGATIVE_PROMPT, negativePrompt: DEFAULT_NEGATIVE_PROMPT,
seed: 42, seed: 42,
seedFixed: false, seedFixed: false,
sdMaskBlur: 5, sdMaskBlur: 35,
sdStrength: 1.0, sdStrength: 1.0,
sdSteps: 50, sdSteps: 50,
sdGuidanceScale: 7.5, sdGuidanceScale: 7.5,
@ -322,6 +324,7 @@ const defaultValues: AppState = {
enableLCMLora: false, enableLCMLora: false,
enableFreeu: false, enableFreeu: false,
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 }, freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
powerpaintTask: PowerPaintTask.text_guided,
}, },
} }
@ -894,7 +897,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
state.isCropperExtenderResizing = newValue state.isCropperExtenderResizing = newValue
}), }),
updateExtenderDirection: (newValue: string) => { updateExtenderDirection: (newValue: ExtenderDirection) => {
console.log( console.log(
`updateExtenderDirection: ${JSON.stringify(get().extenderState)}` `updateExtenderDirection: ${JSON.stringify(get().extenderState)}`
) )
@ -908,7 +911,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
get().updateExtenderByBuiltIn(newValue, 1.5) get().updateExtenderByBuiltIn(newValue, 1.5)
}, },
updateExtenderByBuiltIn: (direction: string, scale: number) => { updateExtenderByBuiltIn: (
direction: ExtenderDirection,
scale: number
) => {
const newExtenderState = { ...defaultValues.extenderState } const newExtenderState = { ...defaultValues.extenderState }
let { x, y, width, height } = newExtenderState let { x, y, width, height } = newExtenderState
const { imageWidth, imageHeight } = get() const { imageWidth, imageHeight } = get()
@ -916,15 +922,15 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
height = imageHeight height = imageHeight
switch (direction) { switch (direction) {
case EXTENDER_X: case ExtenderDirection.x:
x = -Math.ceil((imageWidth * (scale - 1)) / 2) x = -Math.ceil((imageWidth * (scale - 1)) / 2)
width = Math.ceil(imageWidth * scale) width = Math.ceil(imageWidth * scale)
break break
case EXTENDER_Y: case ExtenderDirection.y:
y = -Math.ceil((imageHeight * (scale - 1)) / 2) y = -Math.ceil((imageHeight * (scale - 1)) / 2)
height = Math.ceil(imageHeight * scale) height = Math.ceil(imageHeight * scale)
break break
case EXTENDER_ALL: case ExtenderDirection.xy:
x = -Math.ceil((imageWidth * (scale - 1)) / 2) x = -Math.ceil((imageWidth * (scale - 1)) / 2)
y = -Math.ceil((imageHeight * (scale - 1)) / 2) y = -Math.ceil((imageHeight * (scale - 1)) / 2)
width = Math.ceil(imageWidth * scale) width = Math.ceil(imageWidth * scale)

View File

@ -93,3 +93,16 @@ export interface Size {
width: number width: number
height: number height: number
} }
export enum ExtenderDirection {
x = "x",
y = "y",
xy = "xy",
}
export enum PowerPaintTask {
text_guided = "text-guided",
shape_guided = "shape-guided",
object_remove = "object-remove",
outpainting = "outpainting",
}