Qing 2023-12-24 15:32:27 +08:00
parent 0e5e16ba20
commit 371db2d771
31 changed files with 441 additions and 439 deletions

View File

@ -10,11 +10,8 @@ from lama_cleaner.parse_args import parse_args
def entry_point():
args = parse_args()
if args is None:
return
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
from lama_cleaner.server import main
from lama_cleaner.server import typer_app
main(args)
typer_app()
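
The deferred import is the point of this change: diffusers resolves its default cache directory at import time, so XDG_CACHE_HOME must be exported before the module is first imported. A minimal sketch of the ordering constraint (the cache path is hypothetical):

import os

# Must be set before diffusers is imported for the first time, because
# diffusers/utils/constants.py reads the environment at import time.
os.environ["XDG_CACHE_HOME"] = "/data/model_cache"  # hypothetical path

import diffusers  # noqa: E402 -- now picks up the custom cache location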

View File

@ -103,6 +103,5 @@ if __name__ == "__main__":
device=device,
disable_nsfw=True,
sd_cpu_textencoder=True,
hf_access_token="123"
)
benchmark(model, args.times, args.empty_cache)

View File

@ -21,19 +21,17 @@ AVAILABLE_MODELS = [
"zits",
"mat",
"fcf",
"manga",
"cv2",
"sd1.5",
"sdxl",
"anything4",
"realisticVision1.4",
"cv2",
"manga",
"sd2",
"sdxl",
"paint_by_example",
"instruct_pix2pix",
"kandinsky2.2",
"sdxl",
]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
DIFFUSERS_MODEL_FP16_REVERSION = [
"runwayml/stable-diffusion-inpainting",
"Sanster/anything-4.0-inpainting",
@ -46,26 +44,22 @@ AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda"
NO_HALF_HELP = """
Using full precision model.
If your generate result is always black or green, use this argument. (sd/paint_by_example)
Using full precision (fp32) model.
If your diffusion model's generated result is always black or green, use this argument.
"""
CPU_OFFLOAD_HELP = """
Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
Offloads the diffusion model's weights to CPU RAM, significantly reducing vRAM usage.
"""
DISABLE_NSFW_HELP = """
Disable NSFW checker. (sd/paint_by_example)
Disable the NSFW checker for diffusion models.
"""
SD_CPU_TEXTENCODER_HELP = """
Run Stable Diffusion text encoder model on CPU to save GPU memory.
CPU_TEXTENCODER_HELP = """
Run the diffusion model's text encoder on CPU to reduce vRAM usage.
"""
SD_CONTROLNET_HELP = """
Run Stable Diffusion normal or inpainting model with ControlNet.
"""
DEFAULT_SD_CONTROLNET_METHOD = "lllyasviel/control_v11p_sd15_canny"
SD_CONTROLNET_CHOICES = [
"lllyasviel/control_v11p_sd15_canny",
# "lllyasviel/control_v11p_sd15_seg",
@ -74,46 +68,36 @@ SD_CONTROLNET_CHOICES = [
"lllyasviel/control_v11f1p_sd15_depth",
]
DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers"
SD2_CONTROLNET_CHOICES = [
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-openpose-diffusers",
]
DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
SDXL_CONTROLNET_CHOICES = [
"thibaud/controlnet-openpose-sdxl-1.0",
"destitech/controlnet-inpaint-dreamer-sdxl"
"destitech/controlnet-inpaint-dreamer-sdxl",
"diffusers/controlnet-canny-sdxl-1.0",
"diffusers/controlnet-canny-sdxl-1.0-mid",
"diffusers/controlnet-canny-sdxl-1.0-small"
"diffusers/controlnet-canny-sdxl-1.0-small",
"diffusers/controlnet-depth-sdxl-1.0",
"diffusers/controlnet-depth-sdxl-1.0-mid",
"diffusers/controlnet-depth-sdxl-1.0-small",
]
SD_LOCAL_MODEL_HELP = """
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
"""
LOCAL_FILES_ONLY_HELP = """
Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
"""
ENABLE_XFORMERS_HELP = """
Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
When loading diffusion models, use local files only and do not connect to the HuggingFace server.
"""
DEFAULT_MODEL_DIR = os.getenv(
"XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
)
MODEL_DIR_HELP = """
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
MODEL_DIR_HELP = f"""
Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to {DEFAULT_MODEL_DIR}
"""
OUTPUT_DIR_HELP = """
Result images will be saved to output directory automatically without confirmation.
Result images will be saved to the output directory automatically.
"""
INPUT_HELP = """
@ -125,37 +109,45 @@ GUI_HELP = """
Launch Lama Cleaner as desktop app
"""
NO_GUI_AUTO_CLOSE_HELP = """
Prevent the backend from auto-closing after the GUI window is closed.
"""
QUALITY_HELP = """
Quality of image encoding, 0-100. Default is 95; higher quality produces larger files.
"""
class RealESRGANModelName(str, Enum):
class Choices(str, Enum):
@classmethod
def values(cls):
return [member.value for member in cls]
class RealESRGANModel(Choices):
realesr_general_x4v3 = "realesr-general-x4v3"
RealESRGAN_x4plus = "RealESRGAN_x4plus"
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
class Device(Choices):
cpu = "cpu"
cuda = "cuda"
mps = "mps"
class InteractiveSegModel(Choices):
vit_b = "vit_b"
vit_l = "vit_l"
vit_h = "vit_h"
mobile_sam = "mobile_sam"
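
A small self-contained sketch of what the new Choices base class buys: members are plain strings, and .values() feeds directly into CLI option choices:

from enum import Enum

class Choices(str, Enum):
    @classmethod
    def values(cls):
        return [member.value for member in cls]

class Device(Choices):
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"

# .values() yields the raw strings, convenient for typer/argparse choices.
assert Device.values() == ["cpu", "cuda", "mps"]
# Because Choices subclasses str, a member compares equal to its value.
assert Device.cpu == "cpu"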
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Larger models give better segmentation but run slower."
AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h", "vit_t"]
AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
REMOVE_BG_HELP = "Enable the remove background plugin. Always runs on CPU"
ANIMESEG_HELP = "Enable anime segmentation. Always runs on CPU"
REALESRGAN_HELP = "Enable RealESRGAN super resolution"
REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
GFPGAN_HELP = (
"Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
)
GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
GIF_HELP = "Enable GIF plugin. Makes a GIF comparing the original and cleaned image"
@ -164,8 +156,6 @@ class Config(BaseModel):
port: int = 8080
model: str = DEFAULT_MODEL
sd_local_model_path: str = None
sd_controlnet: bool = False
sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD
device: str = DEFAULT_DEVICE
gui: bool = False
no_gui_auto_close: bool = False
@ -173,7 +163,6 @@ class Config(BaseModel):
cpu_offload: bool = False
disable_nsfw: bool = False
sd_cpu_textencoder: bool = False
enable_xformers: bool = False
local_files_only: bool = False
model_dir: str = DEFAULT_MODEL_DIR
input: str = None
@ -186,7 +175,7 @@ class Config(BaseModel):
enable_anime_seg: bool = False
enable_realesrgan: bool = False
realesrgan_device: str = "cpu"
realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
realesrgan_model: str = RealESRGANModel.realesr_general_x4v3.value
realesrgan_no_half: bool = False
enable_gfpgan: bool = False
gfpgan_device: str = "cpu"

View File

@ -6,6 +6,7 @@ from loguru import logger
from pathlib import Path
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
from lama_cleaner.runtime import setup_model_dir
from lama_cleaner.schema import (
ModelInfo,
ModelType,
@ -16,16 +17,8 @@ from lama_cleaner.schema import (
)
def cli_download_model(model: str, model_dir: str):
if os.path.isfile(model_dir):
raise ValueError(f"invalid --model-dir: {model_dir} is a file")
if not os.path.exists(model_dir):
logger.info(f"Create model cache directory: {model_dir}")
Path(model_dir).mkdir(exist_ok=True, parents=True)
os.environ["XDG_CACHE_HOME"] = model_dir
def cli_download_model(model: str, model_dir: Path):
setup_model_dir(model_dir)
from lama_cleaner.model import models
if model in models:
@ -38,7 +31,7 @@ def cli_download_model(model: str, model_dir: str):
downloaded_path = DiffusionPipeline.download(
pretrained_model_name=model,
revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
resume_download=True,
)
logger.info(f"Done. Downloaded to {downloaded_path}")
@ -101,7 +94,7 @@ def scan_inpaint_models() -> List[ModelInfo]:
from lama_cleaner.model import models
for name, m in models.items():
if m.is_erase_model:
if m.is_erase_model and m.is_downloaded():
res.append(
ModelInfo(
name=name,

View File

@ -41,7 +41,7 @@ class InpaintModel:
@staticmethod
@abc.abstractmethod
def is_downloaded() -> bool:
...
return False
@abc.abstractmethod
def forward(self, image, mask, config: Config):
@ -67,6 +67,8 @@ class InpaintModel:
logger.info(f"final forward pad size: {pad_image.shape}")
image, mask = self.forward_pre_process(image, mask, config)
result = self.forward(pad_image, pad_mask, config)
result = result[0:origin_height, 0:origin_width, :]
@ -77,6 +79,9 @@ class InpaintModel:
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
return result
def forward_pre_process(self, image, mask, config):
return image, mask
def forward_post_process(self, result, image, mask, config):
return result, image, mask
@ -400,6 +405,13 @@ class DiffusionInpaintModel(InpaintModel):
scheduler = get_scheduler(sd_sampler, scheduler_config)
self.model.scheduler = scheduler
def forward_pre_process(self, image, mask, config):
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
return image, mask
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
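
The blur radius is mapped through 2 * sd_mask_blur + 1 because cv2.GaussianBlur requires an odd kernel size; a standalone sketch of the same mask preprocessing (mask shape and radius are illustrative):

import cv2
import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255  # a square region to inpaint

sd_mask_blur = 11             # illustrative radius
k = 2 * sd_mask_blur + 1      # any radius maps to an odd kernel size
blurred = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
print(blurred.shape)          # (512, 512, 1): channel axis restored for the model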

View File

@ -17,14 +17,6 @@ from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, ModelInfo, ModelType
# For backward compatibility
controlnet_name_map = {
"control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny",
"control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose",
"control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint",
"control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
}
class ControlNet(DiffusionInpaintModel):
name = "controlnet"
@ -49,9 +41,6 @@ class ControlNet(DiffusionInpaintModel):
fp16 = not kwargs.get("no_half", False)
model_info: ModelInfo = kwargs["model_info"]
sd_controlnet_method = kwargs["sd_controlnet_method"]
sd_controlnet_method = controlnet_name_map.get(
sd_controlnet_method, sd_controlnet_method
)
self.model_info = model_info
self.sd_controlnet_method = sd_controlnet_method
@ -113,12 +102,6 @@ class ControlNet(DiffusionInpaintModel):
**model_kwargs,
)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
@ -162,10 +145,6 @@ class ControlNet(DiffusionInpaintModel):
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
self.model.scheduler = scheduler
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
img_h, img_w = image.shape[:2]
control_image = self._get_control_image(image, mask)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
@ -190,8 +169,3 @@ class ControlNet(DiffusionInpaintModel):
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True

View File

@ -31,30 +31,15 @@ class InstructPix2Pix(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix",
revision="fp16" if use_gpu and fp16 else "main",
torch_dtype=torch_dtype,
**model_kwargs
self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
)
self.model.enable_attention_slicing()
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
@staticmethod
def download():
from diffusers import StableDiffusionInstructPix2PixPipeline
StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix", revision="fp16"
)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -76,8 +61,3 @@ class InstructPix2Pix(DiffusionInpaintModel):
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True

View File

@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
}
self.model = AutoPipelineForInpainting.from_pretrained(
self.model_name, **model_kwargs
self.model_id_or_path, **model_kwargs
).to(device)
self.callback = kwargs.pop("callback", None)
@ -40,9 +40,6 @@ class Kandinsky(DiffusionInpaintModel):
self.model.scheduler = scheduler
generator = torch.manual_seed(config.sd_seed)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
mask = mask.astype(np.float32) / 255
img_h, img_w = image.shape[:2]
@ -66,20 +63,7 @@ class Kandinsky(DiffusionInpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True
class Kandinsky22(Kandinsky):
name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
@staticmethod
def download():
from diffusers import AutoPipelineForInpainting
AutoPipelineForInpainting.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder-inpaint"
)
name = "kandinsky2.2"
model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"

View File

@ -31,10 +31,6 @@ class PaintByExample(DiffusionInpaintModel):
"Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
)
self.model.enable_attention_slicing()
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
# TODO: gpu_id
if kwargs.get("cpu_offload", False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)
@ -68,8 +64,3 @@ class PaintByExample(DiffusionInpaintModel):
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings
return True

View File

@ -1,8 +1,5 @@
import os
import PIL.Image
import cv2
import numpy as np
import torch
from loguru import logger
@ -49,23 +46,12 @@ class SD(DiffusionInpaintModel):
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="fp16"
if (
self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
and use_gpu
and fp16
)
if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
else "main",
torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"],
**model_kwargs,
)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get("cpu_offload", False) and use_gpu:
# TODO: gpu_id
logger.info("Enable sequential cpu offload")
@ -88,10 +74,6 @@ class SD(DiffusionInpaintModel):
"""
self.set_scheduler(config)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
img_h, img_w = image.shape[:2]
output = self.model(
@ -114,17 +96,6 @@ class SD(DiffusionInpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True
@classmethod
def download(cls):
from diffusers import StableDiffusionInpaintPipeline
StableDiffusionInpaintPipeline.from_pretrained(cls.model_id_or_path)
class SD15(SD):
name = "sd1.5"

View File

@ -45,16 +45,9 @@ class SDXL(DiffusionInpaintModel):
self.model_id_or_path,
revision="main",
torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"],
vae=vae,
)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
@ -65,14 +58,6 @@ class SDXL(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
@staticmethod
def download():
from diffusers import AutoPipelineForInpainting
AutoPipelineForInpainting.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -81,10 +66,6 @@ class SDXL(DiffusionInpaintModel):
"""
self.set_scheduler(config)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
img_h, img_w = image.shape[:2]
output = self.model(
@ -106,8 +87,3 @@ class SDXL(DiffusionInpaintModel):
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True

View File

@ -3,7 +3,6 @@ from typing import List, Dict
import torch
from loguru import logger
from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
from lama_cleaner.download import scan_models
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model import models, ControlNet, SD, SDXL
@ -19,16 +18,25 @@ class ModelManager:
self.available_models: Dict[str, ModelInfo] = {}
self.scan_models()
self.sd_controlnet = kwargs.get("sd_controlnet", False)
self.sd_controlnet_method = kwargs.get(
"sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
)
self.sd_controlnet = False
self.sd_controlnet_method = ""
self.model = self.init_model(name, device, **kwargs)
def init_model(self, name: str, device, **kwargs):
def _map_old_name(self, name: str) -> str:
for old_name, model_cls in models.items():
if name == old_name and hasattr(model_cls, "model_id_or_path"):
name = model_cls.model_id_or_path
break
return name
@property
def current_model(self) -> Dict:
name = self._map_old_name(self.name)
return self.available_models[name].model_dump()
def init_model(self, name: str, device, **kwargs):
name = self._map_old_name(name)
logger.info(f"Loading model: {name}")
if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}")
@ -86,6 +94,7 @@ class ModelManager:
):
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
try:
# TODO: enable/disable controlnet without reloading the model
del self.model
torch_gc()

View File

@ -55,7 +55,7 @@ def parse_args():
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
parser.add_argument(
"--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
"--sd-cpu-textencoder", action="store_true", help=CPU_TEXTENCODER_HELP
)
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
parser.add_argument(
@ -66,16 +66,10 @@ def parse_args():
parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
)
parser.add_argument(
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
)
parser.add_argument(
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
)
parser.add_argument("--gui", action="store_true", help=GUI_HELP)
parser.add_argument(
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
)
parser.add_argument(
"--gui-size",
default=[1600, 1000],

View File

@ -22,7 +22,7 @@ SEGMENT_ANYTHING_MODELS = {
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"md5": "4b8939a88964f0f4ff5f5b2642c598a6",
},
"vit_t": {
"mobile_sam": {
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
"md5": "f3c0d8cda613564d499310dab6c812cd",
},

View File

@ -3,7 +3,7 @@ from enum import Enum
import cv2
from loguru import logger
from lama_cleaner.const import RealESRGANModelName
from lama_cleaner.const import RealESRGANModel
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
@ -18,7 +18,7 @@ class RealESRGANUpscaler(BasePlugin):
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
REAL_ESRGAN_MODELS = {
RealESRGANModelName.realesr_general_x4v3: {
RealESRGANModel.realesr_general_x4v3: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
"scale": 4,
"model": lambda: SRVGGNetCompact(
@ -31,7 +31,7 @@ class RealESRGANUpscaler(BasePlugin):
),
"model_md5": "91a7644643c884ee00737db24e478156",
},
RealESRGANModelName.RealESRGAN_x4plus: {
RealESRGANModel.RealESRGAN_x4plus: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"scale": 4,
"model": lambda: RRDBNet(
@ -44,7 +44,7 @@ class RealESRGANUpscaler(BasePlugin):
),
"model_md5": "99ec365d4afad750833258a1a24f44ca",
},
RealESRGANModelName.RealESRGAN_x4plus_anime_6B: {
RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"scale": 4,
"model": lambda: RRDBNet(

View File

@ -109,7 +109,7 @@ sam_model_registry = {
"vit_h": build_sam,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
"vit_t": build_sam_vit_t,
"mobile_sam": build_sam_vit_t,
}

View File

@ -1,10 +1,16 @@
# https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py
import os
import platform
import sys
from pathlib import Path
import packaging.version
from loguru import logger
from rich import print
from typing import Dict, Any
from lama_cleaner.const import Device
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"):
@ -21,7 +27,6 @@ _CANDIDATES = [
"diffusers",
"transformers",
"opencv-python",
"xformers",
"accelerate",
"lama-cleaner",
"rembg",
@ -48,3 +53,34 @@ def dump_environment_info() -> Dict[str, str]:
info.update(_package_versions)
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
return info
def check_device(device: Device) -> Device:
if device == Device.cuda:
import platform
if platform.system() == "Darwin":
logger.warning("MacOS does not support cuda, use cpu instead")
return Device.cpu
else:
import torch
if not torch.cuda.is_available():
logger.warning("CUDA is not available, use cpu instead")
return Device.cpu
elif device == Device.mps:
import torch
if not torch.backends.mps.is_available():
logger.warning("mps is not available, use cpu instead")
return Device.cpu
return device
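
A usage sketch, assuming the Device enum and the check_device function above: the requested device degrades to cpu with a warning when the backend is unavailable.

device = check_device(Device.cuda)  # returns Device.cpu on machines without CUDA
print(f"running on {device.value}")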
def setup_model_dir(model_dir: Path):
model_dir = model_dir.expanduser().absolute()
os.environ["U2NET_HOME"] = str(model_dir)
os.environ["XDG_CACHE_HOME"] = str(model_dir)
if not model_dir.exists():
logger.info(f"Create model directory: {model_dir}")
model_dir.mkdir(exist_ok=True, parents=True)

View File

@ -1,10 +1,18 @@
#!/usr/bin/env python3
import json
import os
import hashlib
import traceback
import typer
from typer import Option
from lama_cleaner.download import cli_download_model, scan_models
from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import hashlib
import traceback
from dataclasses import dataclass
import imghdr
import io
@ -20,12 +28,7 @@ import torch
from PIL import Image
from loguru import logger
from lama_cleaner.const import (
SD15_MODELS,
SD_CONTROLNET_CHOICES,
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
)
from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
@ -39,6 +42,8 @@ from lama_cleaner.plugins import (
)
from lama_cleaner.schema import Config
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
@ -103,24 +108,35 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition", "X-seed"])
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
model: ModelManager = None
thumb: FileManager = None
output_dir: str = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
enable_file_manager: bool = False
enable_auto_saving: bool = False
@dataclass
class GlobalConfig:
model_manager: ModelManager = None
file_manager: FileManager = None
output_dir: Path = None
input_image_path: Path = None
disable_model_switch: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
@property
def enable_auto_saving(self) -> bool:
return self.output_dir is not None
@property
def enable_file_manager(self) -> bool:
return self.file_manager is not None
global_config = GlobalConfig()
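
GlobalConfig replaces the previous pile of module-level globals; note how enable_auto_saving and enable_file_manager became read-only properties derived from state, so the flags can never drift out of sync. A minimal self-contained sketch of the pattern:

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class AppConfig:  # stand-in for GlobalConfig
    output_dir: Optional[Path] = None

    @property
    def enable_auto_saving(self) -> bool:
        # Derived from output_dir instead of tracked as a separate flag.
        return self.output_dir is not None

cfg = AppConfig()
assert cfg.enable_auto_saving is False
cfg.output_dir = Path("/tmp/out")
assert cfg.enable_auto_saving is True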
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents):
@app.route("/save_image", methods=["POST"])
def save_image():
if output_dir is None:
if global_config.output_dir is None:
return "--output-dir is None", 500
input = request.files
@ -143,7 +159,7 @@ def save_image():
origin_image_bytes = input["image"].read() # RGB
ext = get_image_ext(origin_image_bytes)
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
save_path = os.path.join(output_dir, filename)
save_path = str(global_config.output_dir / filename)
if alpha_channel is not None:
if alpha_channel.shape[:2] != image.shape[:2]:
@ -157,7 +173,7 @@ def save_image():
img_bytes = pil_to_bytes(
pil_image,
ext,
quality=image_quality,
quality=global_config.image_quality,
exif_infos=exif_infos,
)
with open(save_path, "wb") as fw:
@ -169,9 +185,11 @@ def save_image():
@app.route("/medias/<tab>")
def medias(tab):
if tab == "image":
response = make_response(jsonify(thumb.media_names), 200)
response = make_response(jsonify(global_config.file_manager.media_names), 200)
else:
response = make_response(jsonify(thumb.output_media_names), 200)
response = make_response(
jsonify(global_config.file_manager.output_media_names), 200
)
# response.last_modified = thumb.modified_time[tab]
# response.cache_control.no_cache = True
# response.cache_control.max_age = 0
@ -182,8 +200,8 @@ def medias(tab):
@app.route("/media/<tab>/<filename>")
def media_file(tab, filename):
if tab == "image":
return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename)
return send_from_directory(global_config.file_manager.root_directory, filename)
return send_from_directory(global_config.file_manager.output_dir, filename)
@app.route("/media_thumbnail/<tab>/<filename>")
@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename):
if height:
height = int(float(height))
directory = thumb.root_directory
directory = global_config.file_manager.root_directory
if tab == "output":
directory = thumb.output_dir
thumb_filename, (width, height) = thumb.get_thumbnail(
directory = global_config.file_manager.output_dir
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
@ -257,13 +275,11 @@ def process():
croper_y=form["croperY"],
croper_height=form["croperHeight"],
croper_width=form["croperWidth"],
use_extender=form["useExtender"],
extender_x=form["extenderX"],
extender_y=form["extenderY"],
extender_height=form["extenderHeight"],
extender_width=form["extenderWidth"],
sd_scale=form["sdScale"],
sd_mask_blur=form["sdMaskBlur"],
sd_strength=form["sdStrength"],
@ -294,7 +310,7 @@ def process():
start = time.time()
try:
res_np_img = model(image, mask, config)
res_np_img = global_config.model_manager(image, mask, config)
except RuntimeError as e:
if "CUDA out of memory. " in str(e):
# NOTE: the string may change?
@ -322,7 +338,7 @@ def process():
pil_to_bytes(
Image.fromarray(res_np_img),
ext,
quality=image_quality,
quality=global_config.image_quality,
exif_infos=exif_infos,
)
)
@ -345,7 +361,7 @@ def run_plugin():
form = request.form
files = request.files
name = form["name"]
if name not in plugins:
if name not in global_config.plugins:
return "Plugin not found", 500
origin_image_bytes = files["image"].read() # RGB
@ -359,7 +375,7 @@ def run_plugin():
if name == InteractiveSeg.name:
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
form["img_md5"] = img_md5
bgr_res = plugins[name](rgb_np_img, files, form)
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
except RuntimeError as e:
torch.cuda.empty_cache()
if "CUDA out of memory. " in str(e):
@ -401,7 +417,7 @@ def run_plugin():
pil_to_bytes(
Image.fromarray(rgb_res),
ext,
quality=image_quality,
quality=global_config.image_quality,
exif_infos=exif_infos,
)
),
@ -414,41 +430,40 @@ def run_plugin():
@app.route("/server_config", methods=["GET"])
def get_server_config():
return {
"plugins": list(plugins.keys()),
"enableFileManager": enable_file_manager,
"enableAutoSaving": enable_auto_saving,
"enableControlnet": model.sd_controlnet,
"controlnetMethod": model.sd_controlnet_method,
"disableModelSwitch": is_disable_model_switch,
"plugins": list(global_config.plugins.keys()),
"enableFileManager": global_config.enable_file_manager,
"enableAutoSaving": global_config.enable_auto_saving,
"enableControlnet": global_config.model_manager.sd_controlnet,
"controlnetMethod": global_config.model_manager.sd_controlnet_method,
"disableModelSwitch": global_config.disable_model_switch,
"isDesktop": global_config.is_desktop,
}, 200
@app.route("/models", methods=["GET"])
def get_models():
return [it.model_dump() for it in model.scan_models()]
return [it.model_dump() for it in global_config.model_manager.scan_models()]
@app.route("/model")
def current_model():
return model.available_models[model.name].model_dump(), 200
@app.route("/is_desktop")
def get_is_desktop():
return str(is_desktop), 200
return (
global_config.model_manager.current_model,
200,
)
@app.route("/model", methods=["POST"])
def switch_model():
if is_disable_model_switch:
if global_config.disable_model_switch:
return "Switch model is disabled", 400
new_name = request.form.get("name")
if new_name == model.name:
if new_name == global_config.model_manager.name:
return "Same model", 200
try:
model.switch(new_name)
global_config.model_manager.switch(new_name)
except Exception as e:
traceback.print_exc()
error_message = f"{type(e).__name__} - {str(e)}"
@ -464,160 +479,230 @@ def index():
@app.route("/inputimage")
def get_cli_input_image():
if input_image_path:
with open(input_image_path, "rb") as f:
if global_config.input_image_path:
with open(global_config.input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
input_image_path,
global_config.input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
download_name=Path(global_config.input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def build_plugins(args):
global plugins
if args.enable_interactive_seg:
def build_plugins(
enable_interactive_seg: bool,
interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device,
enable_remove_bg: bool,
enable_anime_seg: bool,
enable_realesrgan: bool,
realesrgan_device: Device,
realesrgan_model: str,
enable_gfpgan: bool,
gfpgan_device: Device,
enable_restoreformer: bool,
restoreformer_device: Device,
no_half: bool,
):
if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg(
args.interactive_seg_model, args.interactive_seg_device
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
interactive_seg_model, interactive_seg_device
)
if args.enable_remove_bg:
if enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
global_config.plugins[RemoveBG.name] = RemoveBG()
if args.enable_anime_seg:
if enable_anime_seg:
logger.info(f"Initialize {AnimeSeg.name} plugin")
plugins[AnimeSeg.name] = AnimeSeg()
global_config.plugins[AnimeSeg.name] = AnimeSeg()
if args.enable_realesrgan:
if enable_realesrgan:
logger.info(
f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
)
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
args.realesrgan_model,
args.realesrgan_device,
no_half=args.realesrgan_no_half,
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
realesrgan_model,
realesrgan_device,
no_half=no_half,
)
if args.enable_gfpgan:
if enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
if args.enable_realesrgan:
if enable_realesrgan:
logger.info("Use realesrgan as GFPGAN background upscaler")
else:
logger.info(
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
)
plugins[GFPGANPlugin.name] = GFPGANPlugin(
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
gfpgan_device,
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
)
if args.enable_restoreformer:
if enable_restoreformer:
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
args.restoreformer_device,
upscaler=plugins.get(RealESRGANUpscaler.name, None),
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
restoreformer_device,
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
)
def main(args):
global model
global device
global input_image_path
global is_disable_model_switch
global enable_file_manager
global is_desktop
global thumb
global output_dir
global image_quality
global enable_auto_saving
build_plugins(args)
@typer_app.command(help="Install all plugins dependencies")
def install_plugins_packages():
from lama_cleaner.installer import install_plugins_package
image_quality = args.quality
output_dir = args.output_dir
if output_dir:
output_dir = os.path.abspath(output_dir)
logger.info(f"Output dir: {output_dir}")
enable_auto_saving = True
install_plugins_package()
device = torch.device(args.device)
is_disable_model_switch = args.disable_model_switch
is_desktop = args.gui
if is_disable_model_switch:
logger.info(
f"Start with --disable-model-switch, model switch on frontend is disable"
)
if args.input and os.path.isdir(args.input):
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
def download(
model: str = Option(
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
),
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
cli_download_model(model, model_dir)
@typer_app.command(help="List downloaded models")
def list_model(
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
setup_model_dir(model_dir)
scanned_models = scan_models()
for it in scanned_models:
print(it.name)
@typer_app.command(help="Start lama cleaner server")
def start(
host: str = Option("127.0.0.1"),
port: int = Option(8080),
model: str = Option(
DEFAULT_MODEL,
help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. "
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
),
no_half: bool = Option(False, help=NO_HALF_HELP),
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu),
gui: bool = Option(False, help=GUI_HELP),
disable_model_switch: bool = Option(False),
input: Path = Option(None, help=INPUT_HELP),
output_dir: Path = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
),
quality: int = Option(95, help=QUALITY_HELP),
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
interactive_seg_model: InteractiveSegModel = Option(
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu),
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
enable_gfpgan: bool = Option(False),
gfpgan_device: Device = Option(Device.cpu),
enable_restoreformer: bool = Option(False),
restoreformer_device: Device = Option(Device.cpu),
):
global global_config
dump_environment_info()
if input:
if not input.exists():
logger.error(f"invalid --input: {input} not exists")
exit()
if input.is_dir():
logger.info(f"Initialize file manager")
thumb = FileManager(app)
enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
file_manager = FileManager(app)
app.config["THUMBNAIL_MEDIA_ROOT"] = input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
output_dir, "lama_cleaner_thumbnails"
)
thumb.output_dir = Path(output_dir)
# thumb.start()
# try:
# while True:
# time.sleep(1)
# finally:
# thumb.image_dir_observer.stop()
# thumb.image_dir_observer.join()
# thumb.output_dir_observer.stop()
# thumb.output_dir_observer.join()
file_manager.output_dir = output_dir
else:
input_image_path = args.input
global_config.input_image_path = input
# For backward compatibility
model_name_map = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
"sd2": "stabilityai/stable-diffusion-2-inpainting",
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
}
device = check_device(device)
setup_model_dir(model_dir)
model = ModelManager(
name=model_name_map.get(args.model, args.model),
sd_controlnet=args.sd_controlnet,
sd_controlnet_method=args.sd_controlnet_method,
device=device,
no_half=args.no_half,
hf_access_token=args.hf_access_token,
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder,
cpu_offload=args.cpu_offload,
enable_xformers=args.sd_enable_xformers or args.enable_xformers,
if local_files_only:
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
if model not in AVAILABLE_MODELS:
scanned_models = scan_models()
if model not in [it.name for it in scanned_models]:
logger.error(
f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
)
exit()
global_config.image_quality = quality
global_config.disable_model_switch = disable_model_switch
global_config.is_desktop = gui
build_plugins(
enable_interactive_seg,
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
realesrgan_model,
enable_gfpgan,
gfpgan_device,
enable_restoreformer,
restoreformer_device,
no_half,
)
if output_dir:
output_dir = output_dir.expanduser().absolute()
logger.info(f"Image will auto save to output dir: {output_dir}")
global_config.output_dir = output_dir
global_config.model_manager = ModelManager(
name=model,
device=torch.device(device),
no_half=no_half,
disable_nsfw=disable_nsfw_checker,
sd_cpu_textencoder=cpu_textencoder,
cpu_offload=cpu_offload,
callback=diffuser_callback,
)
if args.gui:
app_width, app_height = args.gui_size
if gui:
from flaskwebgui import FlaskUI
ui = FlaskUI(
app,
socketio=socketio,
width=app_width,
height=app_height,
host=args.host,
port=args.port,
close_server_on_exit=not args.no_gui_auto_close,
width=1200,
height=800,
host=host,
port=port,
close_server_on_exit=True,
idle_interval=60,
)
ui.run()
else:
socketio.run(
app,
host=args.host,
port=args.port,
debug=args.debug,
host=host,
port=port,
allow_unsafe_werkzeug=True,
)

View File

@ -39,7 +39,6 @@ def test_runway_sd_1_5(
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
sd_controlnet_method=sd_controlnet_method,
@ -87,7 +86,6 @@ def test_local_file_path(sd_device, sampler):
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
@ -125,7 +123,6 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
@ -166,7 +163,6 @@ def test_controlnet_switch(sd_device, sampler):
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,

View File

@ -21,7 +21,6 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
model = ModelManager(
name=model_name,
device=torch.device(device),
hf_access_token="",
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload,
@ -52,7 +51,6 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
model = ModelManager(
name=model_name,
device=torch.device(device),
hf_access_token="",
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload,

View File

@ -17,11 +17,9 @@ def test_load_model():
name=m,
device="cpu",
no_half=False,
hf_access_token="",
disable_nsfw=False,
sd_cpu_textencoder=True,
cpu_offload=True,
enable_xformers=False,
)

View File

@ -16,11 +16,9 @@ def test_model_switch():
sd_controlnet=True,
sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
device=torch.device("mps"),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
enable_xformers=False,
callback=None,
)
@ -34,11 +32,9 @@ def test_controlnet_switch_onoff(caplog):
sd_controlnet=True,
sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
device=torch.device("mps"),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
enable_xformers=False,
callback=None,
)
@ -61,11 +57,9 @@ def test_controlnet_switch_method(caplog):
sd_controlnet=True,
sd_controlnet_method=old_method,
device=torch.device("mps"),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
enable_xformers=False,
callback=None,
)

View File

@ -41,7 +41,6 @@ def test_outpainting(name, sd_device, rect):
model = ModelManager(
name=name,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
@ -86,7 +85,6 @@ def test_kandinsky_outpainting(name, sd_device, rect):
model = ModelManager(
name=name,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,

View File

@ -38,7 +38,6 @@ def test_runway_sd_1_5_all_samplers(
model = ModelManager(
name="runwayml/stable-diffusion-inpainting",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)
@ -69,7 +68,6 @@ def test_runway_sd_lcm_lora(sd_device, strategy, sampler):
model = ModelManager(
name="runwayml/stable-diffusion-inpainting",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)
@ -102,7 +100,6 @@ def test_runway_sd_freeu(sd_device, strategy, sampler):
model = ModelManager(
name="runwayml/stable-diffusion-inpainting",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)
@ -136,7 +133,6 @@ def test_runway_sd_sd_strength(sd_device, strategy, sampler):
model = ModelManager(
name="runwayml/stable-diffusion-inpainting",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)
@ -165,7 +161,6 @@ def test_runway_norm_sd_model(sd_device, strategy, sampler):
model = ModelManager(
name="runwayml/stable-diffusion-v1-5",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)
@ -192,7 +187,6 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
model = ModelManager(
name="runwayml/stable-diffusion-inpainting",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
@ -229,7 +223,6 @@ def test_local_file_path(sd_device, sampler, name):
model = ModelManager(
name=name,
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=False,

View File

@ -29,7 +29,6 @@ def test_sdxl(sd_device, strategy, sampler):
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
@ -70,7 +69,6 @@ def test_sdxl_lcm_lora_and_freeu(sd_device, strategy, sampler):
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
@ -131,7 +129,6 @@ def test_sdxl_outpainting(sd_device, rect):
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(sd_device),
hf_access_token="",
disable_nsfw=True,
sd_cpu_textencoder=False,
)

View File

@ -24,7 +24,6 @@ def save_config(
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
@ -102,9 +101,6 @@ def main(config_file: str):
with gr.Column():
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
no_gui_auto_close = gr.Checkbox(
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
)
with gr.Column():
model_dir = gr.Textbox(
@ -193,14 +189,11 @@ def main(config_file: str):
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
)
sd_cpu_textencoder = gr.Checkbox(
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
init_config.sd_cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
)
disable_nsfw = gr.Checkbox(
init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
)
enable_xformers = gr.Checkbox(
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
)
local_files_only = gr.Checkbox(
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
)
@ -221,7 +214,6 @@ def main(config_file: str):
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,

View File

@ -71,6 +71,8 @@ const Cropper = (props: Props) => {
setY,
setWidth,
setHeight,
isResizing,
setIsResizing,
] = useStore((state) => [
state.imageWidth,
state.imageHeight,
@ -80,9 +82,11 @@ const Cropper = (props: Props) => {
state.setCropperY,
state.setCropperWidth,
state.setCropperHeight,
state.isCropperExtenderResizing,
state.setIsCropperExtenderResizing,
])
const [isResizing, setIsResizing] = useState(false)
// const [isResizing, setIsResizing] = useState(false)
const [isMoving, setIsMoving] = useState(false)
useEffect(() => {

View File

@ -65,6 +65,7 @@ export default function Editor(props: EditorProps) {
updateAppState,
runMannually,
runInpainting,
isCropperExtenderResizing,
] = useStore((state) => [
state.disableShortCuts,
state.windowSize,
@ -87,6 +88,7 @@ export default function Editor(props: EditorProps) {
state.updateAppState,
state.runMannually(),
state.runInpainting,
state.isCropperExtenderResizing,
])
const baseBrushSize = useStore((state) => state.editorState.baseBrushSize)
const brushSize = useStore((state) => state.getBrushSize())
@ -537,7 +539,7 @@ export default function Editor(props: EditorProps) {
}
const toggleShowBrush = (newState: boolean) => {
if (newState !== showBrush && !isPanning) {
if (newState !== showBrush && !isPanning && !isCropperExtenderResizing) {
setShowBrush(newState)
}
}
@ -693,7 +695,7 @@ export default function Editor(props: EditorProps) {
limitToBounds={false}
doubleClick={{ disabled: true }}
initialScale={minScale}
minScale={minScale * 0.6}
minScale={minScale * 0.3}
onPanning={(ref) => {
if (!panned) {
setPanned(true)

View File

@ -54,6 +54,8 @@ const Extender = (props: Props) => {
setWidth,
setHeight,
extenderDirection,
isResizing,
setIsResizing,
] = useStore((state) => [
state.isInpainting,
state.imageHeight,
@ -64,10 +66,10 @@ const Extender = (props: Props) => {
state.setExtenderWidth,
state.setExtenderHeight,
state.settings.extenderDirection,
state.isCropperExtenderResizing,
state.setIsCropperExtenderResizing,
])
const [isResizing, setIsResizing] = useState(false)
const [evData, setEVData] = useState<EVData>({
initX: 0,
initY: 0,
@ -122,10 +124,9 @@ const Extender = (props: Props) => {
const moveBottom = () => {
const newHeight = evData.initHeight + offsetY
let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight)
if (extenderDirection === EXTENDER_ALL) {
if (clampedY + clampedHeight < imageHeight) {
clampedHeight = imageHeight
if (clampedHeight < Math.abs(clampedY) + imageHeight) {
clampedHeight = Math.abs(clampedY) + imageHeight
}
}
setHeight(clampedHeight)
@ -155,8 +156,8 @@ const Extender = (props: Props) => {
const newWidth = evData.initWidth + offsetX
let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth)
if (extenderDirection === EXTENDER_ALL) {
if (clampedX + clampedWidth < imageWdith) {
clampedWidth = imageWdith
if (clampedWidth < Math.abs(clampedX) + imageWdith) {
clampedWidth = Math.abs(clampedX) + imageWdith
}
}
setWidth(clampedWidth)
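
The corrected clamp matters when the extender's top or left edge sits outside the image (negative coordinate): the old test y + height < imageHeight could pass even though the box no longer reached the image's far edge. A worked Python sketch of the new rule (values illustrative):

def clamp_extender_height(y: int, height: int, image_height: int) -> int:
    # New rule: the extender must span from its own origin past the image,
    # i.e. height >= abs(y) + image_height when extending in all directions.
    if height < abs(y) + image_height:
        return abs(y) + image_height
    return height

# Dragged 100px above a 512px image: the old check (-100 + 512 < 512) clamped
# height to 512, leaving the bottom 100px uncovered; the new rule gives 612.
assert clamp_extender_height(y=-100, height=512, image_height=512) == 612
assert clamp_extender_height(y=0, height=600, image_height=512) == 600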

View File

@ -105,6 +105,7 @@ const LabelTitle = ({
{text}
</Label>
</TooltipTrigger>
{toolTip ? (
<TooltipContent className="flex flex-col max-w-xs text-sm" side="left">
<p>{toolTip}</p>
{url ? (
@ -117,6 +118,9 @@ const LabelTitle = ({
<></>
)}
</TooltipContent>
) : (
<></>
)}
</Tooltip>
)
}
@ -172,7 +176,11 @@ const SidePanel = () => {
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4">
<div className="flex justify-between items-center pr-2">
<LabelTitle text="Controlnet" />
<LabelTitle
text="ControlNet"
toolTip="Using an additional conditioning image to control how an image is generated"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
/>
<Switch
id="controlnet"
checked={settings.enableControlnet}
@ -271,7 +279,11 @@ const SidePanel = () => {
return (
<div className="flex flex-col gap-4">
<div className="flex justify-between items-center pr-2">
<LabelTitle text="Freeu" />
<LabelTitle
text="FreeU"
toolTip="FreeU is a technique for improving image quality. Different models may require different FreeU-specific hyperparameters, which can be viewed in the more info section."
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu"
/>
<Switch
id="freeu"
checked={settings.enableFreeu}
@ -408,7 +420,10 @@ const SidePanel = () => {
return (
<div>
<RowContainer>
<div>Example Image</div>
<LabelTitle
text="Example Image"
toolTip="An example image to guide image generation."
/>
<ImageUploadButton
tooltip="Upload example image"
onFileUpload={(file) => {
@ -450,8 +465,9 @@ const SidePanel = () => {
return (
<div className="flex flex-col gap-1">
<LabelTitle
htmlFor="image-guidance-scale"
text="Image guidance scale"
toolTip="Push the generated image towards the inital image. Higher image guidance scale encourages generated images that are closely linked to the source image, usually at the expense of lower image quality."
url="https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix"
/>
<RowContainer>
<Slider
@ -518,11 +534,17 @@ const SidePanel = () => {
}
const renderExtender = () => {
if (!settings.model.support_outpainting) {
return null
}
return (
<>
<div className="flex flex-col gap-4">
<RowContainer>
<LabelTitle text="Extender" />
<LabelTitle
text="Extender"
toolTip="Perform outpainting on images to expand it's content."
/>
<Switch
id="extender"
checked={settings.showExtender}
@ -709,7 +731,10 @@ const SidePanel = () => {
>
<div className="flex flex-col gap-4 mt-4">
<RowContainer>
<LabelTitle text="Cropper" />
<LabelTitle
text="Cropper"
toolTip="Inpainting on part of image, improve inference speed and reduce memory usage."
/>
<Switch
id="cropper"
checked={settings.showCropper}
@ -725,7 +750,11 @@ const SidePanel = () => {
{renderExtender()}
<div className="flex flex-col gap-1">
<LabelTitle htmlFor="steps" text="Steps" />
<LabelTitle
htmlFor="steps"
text="Steps"
toolTip="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
/>
<RowContainer>
<Slider
className="w-[180px]"

View File

@ -150,8 +150,11 @@ type AppState = {
interactiveSegState: InteractiveSegState
fileManagerState: FileManagerState
cropperState: CropperState
extenderState: CropperState
isCropperExtenderResizing: boolean
serverConfig: ServerConfig
settings: Settings
@ -177,6 +180,7 @@ type AppAction = {
setExtenderY: (newValue: number) => void
setExtenderWidth: (newValue: number) => void
setExtenderHeight: (newValue: number) => void
setIsCropperExtenderResizing: (newValue: boolean) => void
updateExtenderDirection: (newValue: string) => void
resetExtender: (width: number, height: number) => void
updateExtenderByBuiltIn: (direction: string, scale: number) => void
@ -261,6 +265,7 @@ const defaultValues: AppState = {
width: 512,
height: 512,
},
isCropperExtenderResizing: false,
fileManagerState: {
sortBy: SortBy.CTIME,
@ -889,6 +894,11 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
state.extenderState.height = newValue
}),
setIsCropperExtenderResizing: (newValue: boolean) =>
set((state) => {
state.isCropperExtenderResizing = newValue
}),
updateExtenderDirection: (newValue: string) => {
console.log(
`updateExtenderDirection: ${JSON.stringify(get().extenderState)}`