From 371db2d7711e6f6806bfcb7b25da2d41570fe676 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 24 Dec 2023 15:32:27 +0800 Subject: [PATCH] update --- lama_cleaner/__init__.py | 7 +- lama_cleaner/benchmark.py | 1 - lama_cleaner/const.py | 81 ++-- lama_cleaner/download.py | 17 +- lama_cleaner/model/base.py | 14 +- lama_cleaner/model/controlnet.py | 26 -- lama_cleaner/model/instruct_pix2pix.py | 22 +- lama_cleaner/model/kandinsky.py | 22 +- lama_cleaner/model/paint_by_example.py | 9 - lama_cleaner/model/sd.py | 31 +- lama_cleaner/model/sdxl.py | 24 -- lama_cleaner/model_manager.py | 21 +- lama_cleaner/parse_args.py | 8 +- lama_cleaner/plugins/interactive_seg.py | 2 +- lama_cleaner/plugins/realesrgan.py | 8 +- .../plugins/segment_anything/build_sam.py | 2 +- lama_cleaner/runtime.py | 40 +- lama_cleaner/server.py | 405 +++++++++++------- lama_cleaner/tests/test_controlnet.py | 4 - lama_cleaner/tests/test_instruct_pix2pix.py | 2 - lama_cleaner/tests/test_model_md5.py | 2 - lama_cleaner/tests/test_model_switch.py | 6 - lama_cleaner/tests/test_outpainting.py | 2 - lama_cleaner/tests/test_sd_model.py | 7 - lama_cleaner/tests/test_sdxl.py | 3 - lama_cleaner/web_config.py | 10 +- web_app/src/components/Cropper.tsx | 6 +- web_app/src/components/Editor.tsx | 6 +- web_app/src/components/Extender.tsx | 15 +- web_app/src/components/SidePanel.tsx | 67 ++- web_app/src/lib/states.ts | 10 + 31 files changed, 441 insertions(+), 439 deletions(-) diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index db72076..0a6697e 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -10,11 +10,8 @@ from lama_cleaner.parse_args import parse_args def entry_point(): - args = parse_args() - if args is None: - return # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 - from lama_cleaner.server import main + from lama_cleaner.server import typer_app - main(args) + typer_app() diff --git a/lama_cleaner/benchmark.py b/lama_cleaner/benchmark.py index f30915d..feb29e9 100644 --- a/lama_cleaner/benchmark.py +++ b/lama_cleaner/benchmark.py @@ -103,6 +103,5 @@ if __name__ == "__main__": device=device, disable_nsfw=True, sd_cpu_textencoder=True, - hf_access_token="123" ) benchmark(model, args.times, args.empty_cache) diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index c1ec963..486fd11 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -21,19 +21,17 @@ AVAILABLE_MODELS = [ "zits", "mat", "fcf", + "manga", + "cv2", "sd1.5", - "sdxl", "anything4", "realisticVision1.4", - "cv2", - "manga", "sd2", + "sdxl", "paint_by_example", "instruct_pix2pix", "kandinsky2.2", - "sdxl", ] -SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] DIFFUSERS_MODEL_FP16_REVERSION = [ "runwayml/stable-diffusion-inpainting", "Sanster/anything-4.0-inpainting", @@ -46,26 +44,22 @@ AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] DEFAULT_DEVICE = "cuda" NO_HALF_HELP = """ -Using full precision model. -If your generate result is always black or green, use this argument. (sd/paint_by_exmaple) +Using full precision(fp32) model. +If your diffusion model generate result is always black or green, use this argument. """ CPU_OFFLOAD_HELP = """ -Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example) +Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage. """ DISABLE_NSFW_HELP = """ -Disable NSFW checker. (sd/paint_by_example) +Disable NSFW checker for diffusion model. """ -SD_CPU_TEXTENCODER_HELP = """ -Run Stable Diffusion text encoder model on CPU to save GPU memory. +CPU_TEXTENCODER_HELP = """ +Run diffusion models text encoder on CPU to reduce vRAM usage. """ -SD_CONTROLNET_HELP = """ -Run Stable Diffusion normal or inpainting model with ControlNet. -""" -DEFAULT_SD_CONTROLNET_METHOD = "lllyasviel/control_v11p_sd15_canny" SD_CONTROLNET_CHOICES = [ "lllyasviel/control_v11p_sd15_canny", # "lllyasviel/control_v11p_sd15_seg", @@ -74,46 +68,36 @@ SD_CONTROLNET_CHOICES = [ "lllyasviel/control_v11f1p_sd15_depth", ] -DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers" SD2_CONTROLNET_CHOICES = [ "thibaud/controlnet-sd21-canny-diffusers", "thibaud/controlnet-sd21-depth-diffusers", "thibaud/controlnet-sd21-openpose-diffusers", ] -DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0" SDXL_CONTROLNET_CHOICES = [ "thibaud/controlnet-openpose-sdxl-1.0", - "destitech/controlnet-inpaint-dreamer-sdxl" + "destitech/controlnet-inpaint-dreamer-sdxl", "diffusers/controlnet-canny-sdxl-1.0", "diffusers/controlnet-canny-sdxl-1.0-mid", - "diffusers/controlnet-canny-sdxl-1.0-small" + "diffusers/controlnet-canny-sdxl-1.0-small", "diffusers/controlnet-depth-sdxl-1.0", "diffusers/controlnet-depth-sdxl-1.0-mid", "diffusers/controlnet-depth-sdxl-1.0-small", ] -SD_LOCAL_MODEL_HELP = """ -Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. -""" - LOCAL_FILES_ONLY_HELP = """ -Use local files only, not connect to Hugging Face server. (sd/paint_by_example) -""" - -ENABLE_XFORMERS_HELP = """ -Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example) +When loading diffusion models, using local files only, not connect to HuggingFace server. """ DEFAULT_MODEL_DIR = os.getenv( "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache") ) -MODEL_DIR_HELP = """ -Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache +MODEL_DIR_HELP = f""" +Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR} """ OUTPUT_DIR_HELP = """ -Result images will be saved to output directory automatically without confirmation. +Result images will be saved to output directory automatically. """ INPUT_HELP = """ @@ -125,37 +109,45 @@ GUI_HELP = """ Launch Lama Cleaner as desktop app """ -NO_GUI_AUTO_CLOSE_HELP = """ -Prevent backend auto close after the GUI window closed. -""" - QUALITY_HELP = """ Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size. """ -class RealESRGANModelName(str, Enum): +class Choices(str, Enum): + @classmethod + def values(cls): + return [member.value for member in cls] + + +class RealESRGANModel(Choices): realesr_general_x4v3 = "realesr-general-x4v3" RealESRGAN_x4plus = "RealESRGAN_x4plus" RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" -RealESRGANModelNameList = [e.value for e in RealESRGANModelName] +class Device(Choices): + cpu = "cpu" + cuda = "cuda" + mps = "mps" + + +class InteractiveSegModel(Choices): + vit_b = "vit_b" + vit_l = "vit_l" + vit_h = "vit_h" + mobile_sam = "mobile_sam" + INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." -AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h", "vit_t"] -AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"] REMOVE_BG_HELP = "Enable remove background. Always run on CPU" ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU" REALESRGAN_HELP = "Enable realesrgan super resolution" -REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] GFPGAN_HELP = ( "Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan" ) -GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan" -RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" @@ -164,8 +156,6 @@ class Config(BaseModel): port: int = 8080 model: str = DEFAULT_MODEL sd_local_model_path: str = None - sd_controlnet: bool = False - sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD device: str = DEFAULT_DEVICE gui: bool = False no_gui_auto_close: bool = False @@ -173,7 +163,6 @@ class Config(BaseModel): cpu_offload: bool = False disable_nsfw: bool = False sd_cpu_textencoder: bool = False - enable_xformers: bool = False local_files_only: bool = False model_dir: str = DEFAULT_MODEL_DIR input: str = None @@ -186,7 +175,7 @@ class Config(BaseModel): enable_anime_seg: bool = False enable_realesrgan: bool = False realesrgan_device: str = "cpu" - realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value + realesrgan_model: str = RealESRGANModel.realesr_general_x4v3.value realesrgan_no_half: bool = False enable_gfpgan: bool = False gfpgan_device: str = "cpu" diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index 63711af..aef26c4 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -6,6 +6,7 @@ from loguru import logger from pathlib import Path from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR +from lama_cleaner.runtime import setup_model_dir from lama_cleaner.schema import ( ModelInfo, ModelType, @@ -16,16 +17,8 @@ from lama_cleaner.schema import ( ) -def cli_download_model(model: str, model_dir: str): - if os.path.isfile(model_dir): - raise ValueError(f"invalid --model-dir: {model_dir} is a file") - - if not os.path.exists(model_dir): - logger.info(f"Create model cache directory: {model_dir}") - Path(model_dir).mkdir(exist_ok=True, parents=True) - - os.environ["XDG_CACHE_HOME"] = model_dir - +def cli_download_model(model: str, model_dir: Path): + setup_model_dir(model_dir) from lama_cleaner.model import models if model in models: @@ -38,7 +31,7 @@ def cli_download_model(model: str, model_dir: str): downloaded_path = DiffusionPipeline.download( pretrained_model_name=model, - revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main", + variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main", resume_download=True, ) logger.info(f"Done. Downloaded to {downloaded_path}") @@ -101,7 +94,7 @@ def scan_inpaint_models() -> List[ModelInfo]: from lama_cleaner.model import models for name, m in models.items(): - if m.is_erase_model: + if m.is_erase_model and m.is_downloaded(): res.append( ModelInfo( name=name, diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 72fe2ed..b4b43ba 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -41,7 +41,7 @@ class InpaintModel: @staticmethod @abc.abstractmethod def is_downloaded() -> bool: - ... + return False @abc.abstractmethod def forward(self, image, mask, config: Config): @@ -67,6 +67,8 @@ class InpaintModel: logger.info(f"final forward pad size: {pad_image.shape}") + image, mask = self.forward_pre_process(image, mask, config) + result = self.forward(pad_image, pad_mask, config) result = result[0:origin_height, 0:origin_width, :] @@ -77,6 +79,9 @@ class InpaintModel: result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) return result + def forward_pre_process(self, image, mask, config): + return image, mask + def forward_post_process(self, result, image, mask, config): return result, image, mask @@ -400,6 +405,13 @@ class DiffusionInpaintModel(InpaintModel): scheduler = get_scheduler(sd_sampler, scheduler_config) self.model.scheduler = scheduler + def forward_pre_process(self, image, mask, config): + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + return image, mask + def forward_post_process(self, result, image, mask, config): if config.sd_match_histograms: result = self._match_histograms(result, image[:, :, ::-1], mask) diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index e5a02ce..29591f8 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -17,14 +17,6 @@ from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.utils import get_scheduler from lama_cleaner.schema import Config, ModelInfo, ModelType -# 为了兼容性 -controlnet_name_map = { - "control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny", - "control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose", - "control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint", - "control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth", -} - class ControlNet(DiffusionInpaintModel): name = "controlnet" @@ -49,9 +41,6 @@ class ControlNet(DiffusionInpaintModel): fp16 = not kwargs.get("no_half", False) model_info: ModelInfo = kwargs["model_info"] sd_controlnet_method = kwargs["sd_controlnet_method"] - sd_controlnet_method = controlnet_name_map.get( - sd_controlnet_method, sd_controlnet_method - ) self.model_info = model_info self.sd_controlnet_method = sd_controlnet_method @@ -113,12 +102,6 @@ class ControlNet(DiffusionInpaintModel): **model_kwargs, ) - # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing - self.model.enable_attention_slicing() - # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention - if kwargs.get("enable_xformers", False): - self.model.enable_xformers_memory_efficient_attention() - if kwargs.get("cpu_offload", False) and use_gpu: logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) @@ -162,10 +145,6 @@ class ControlNet(DiffusionInpaintModel): scheduler = get_scheduler(config.sd_sampler, scheduler_config) self.model.scheduler = scheduler - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] - img_h, img_w = image.shape[:2] control_image = self._get_control_image(image, mask) mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") @@ -190,8 +169,3 @@ class ControlNet(DiffusionInpaintModel): output = (output * 255).round().astype("uint8") output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index 4645fa5..18990b1 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -31,30 +31,15 @@ class InstructPix2Pix(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( - "timbrooks/instruct-pix2pix", - revision="fp16" if use_gpu and fp16 else "main", - torch_dtype=torch_dtype, - **model_kwargs + self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs ) - self.model.enable_attention_slicing() - if kwargs.get("enable_xformers", False): - self.model.enable_xformers_memory_efficient_attention() - if kwargs.get("cpu_offload", False) and use_gpu: logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) else: self.model = self.model.to(device) - @staticmethod - def download(): - from diffusers import StableDiffusionInstructPix2PixPipeline - - StableDiffusionInstructPix2PixPipeline.from_pretrained( - "timbrooks/instruct-pix2pix", revision="fp16" - ) - def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -76,8 +61,3 @@ class InstructPix2Pix(DiffusionInpaintModel): output = (output * 255).round().astype("uint8") output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index b44f9df..d12254e 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel): } self.model = AutoPipelineForInpainting.from_pretrained( - self.model_name, **model_kwargs + self.model_id_or_path, **model_kwargs ).to(device) self.callback = kwargs.pop("callback", None) @@ -40,9 +40,6 @@ class Kandinsky(DiffusionInpaintModel): self.model.scheduler = scheduler generator = torch.manual_seed(config.sd_seed) - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] mask = mask.astype(np.float32) / 255 img_h, img_w = image.shape[:2] @@ -66,20 +63,7 @@ class Kandinsky(DiffusionInpaintModel): output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True - class Kandinsky22(Kandinsky): - name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" - model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" - - @staticmethod - def download(): - from diffusers import AutoPipelineForInpainting - - AutoPipelineForInpainting.from_pretrained( - "kandinsky-community/kandinsky-2-2-decoder-inpaint" - ) + name = "kandinsky2.2" + model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint" diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index c8417d9..3c2783b 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -31,10 +31,6 @@ class PaintByExample(DiffusionInpaintModel): "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs ) - self.model.enable_attention_slicing() - if kwargs.get("enable_xformers", False): - self.model.enable_xformers_memory_efficient_attention() - # TODO: gpu_id if kwargs.get("cpu_offload", False) and use_gpu: self.model.image_encoder = self.model.image_encoder.to(device) @@ -68,8 +64,3 @@ class PaintByExample(DiffusionInpaintModel): output = (output * 255).round().astype("uint8") output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 48c8f65..56134e2 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -1,8 +1,5 @@ -import os - import PIL.Image import cv2 -import numpy as np import torch from loguru import logger @@ -49,23 +46,12 @@ class SD(DiffusionInpaintModel): self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model_id_or_path, revision="fp16" - if ( - self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION - and use_gpu - and fp16 - ) + if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION else "main", torch_dtype=torch_dtype, - use_auth_token=kwargs["hf_access_token"], **model_kwargs, ) - # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing - self.model.enable_attention_slicing() - # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention - if kwargs.get("enable_xformers", False): - self.model.enable_xformers_memory_efficient_attention() - if kwargs.get("cpu_offload", False) and use_gpu: # TODO: gpu_id logger.info("Enable sequential cpu offload") @@ -88,10 +74,6 @@ class SD(DiffusionInpaintModel): """ self.set_scheduler(config) - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] - img_h, img_w = image.shape[:2] output = self.model( @@ -114,17 +96,6 @@ class SD(DiffusionInpaintModel): output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True - - @classmethod - def download(cls): - from diffusers import StableDiffusionInpaintPipeline - - StableDiffusionInpaintPipeline.from_pretrained(cls.model_id_or_path) - class SD15(SD): name = "sd1.5" diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index 5260d0a..428c670 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -45,16 +45,9 @@ class SDXL(DiffusionInpaintModel): self.model_id_or_path, revision="main", torch_dtype=torch_dtype, - use_auth_token=kwargs["hf_access_token"], vae=vae, ) - # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing - self.model.enable_attention_slicing() - # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention - if kwargs.get("enable_xformers", False): - self.model.enable_xformers_memory_efficient_attention() - if kwargs.get("cpu_offload", False) and use_gpu: logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) @@ -65,14 +58,6 @@ class SDXL(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) - @staticmethod - def download(): - from diffusers import AutoPipelineForInpainting - - AutoPipelineForInpainting.from_pretrained( - "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" - ) - def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -81,10 +66,6 @@ class SDXL(DiffusionInpaintModel): """ self.set_scheduler(config) - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] - img_h, img_w = image.shape[:2] output = self.model( @@ -106,8 +87,3 @@ class SDXL(DiffusionInpaintModel): output = (output * 255).round().astype("uint8") output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - - @staticmethod - def is_downloaded() -> bool: - # model will be downloaded when app start, and can't switch in frontend settings - return True diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 69bbf99..9d5ac5f 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -3,7 +3,6 @@ from typing import List, Dict import torch from loguru import logger -from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD from lama_cleaner.download import scan_models from lama_cleaner.helper import switch_mps_device from lama_cleaner.model import models, ControlNet, SD, SDXL @@ -19,16 +18,25 @@ class ModelManager: self.available_models: Dict[str, ModelInfo] = {} self.scan_models() - self.sd_controlnet = kwargs.get("sd_controlnet", False) - self.sd_controlnet_method = kwargs.get( - "sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD - ) + self.sd_controlnet = False + self.sd_controlnet_method = "" self.model = self.init_model(name, device, **kwargs) - def init_model(self, name: str, device, **kwargs): + def _map_old_name(self, name: str) -> str: for old_name, model_cls in models.items(): if name == old_name and hasattr(model_cls, "model_id_or_path"): name = model_cls.model_id_or_path + break + return name + + @property + def current_model(self) -> Dict: + name = self._map_old_name(self.name) + return self.available_models[name].model_dump() + + def init_model(self, name: str, device, **kwargs): + name = self._map_old_name(name) + logger.info(f"Loading model: {name}") if name not in self.available_models: raise NotImplementedError(f"Unsupported model: {name}") @@ -86,6 +94,7 @@ class ModelManager: ): self.sd_controlnet_method = self.available_models[new_name].controlnets[0] try: + # TODO: enable/disable controlnet without reload model del self.model torch_gc() diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 7c301bf..4c5f412 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -55,7 +55,7 @@ def parse_args(): parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) parser.add_argument( - "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP + "--sd-cpu-textencoder", action="store_true", help=CPU_TEXTENCODER_HELP ) parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument( @@ -66,16 +66,10 @@ def parse_args(): parser.add_argument( "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP ) - parser.add_argument( - "--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP - ) parser.add_argument( "--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES ) parser.add_argument("--gui", action="store_true", help=GUI_HELP) - parser.add_argument( - "--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP - ) parser.add_argument( "--gui-size", default=[1600, 1000], diff --git a/lama_cleaner/plugins/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py index 36cd8d2..160d59f 100644 --- a/lama_cleaner/plugins/interactive_seg.py +++ b/lama_cleaner/plugins/interactive_seg.py @@ -22,7 +22,7 @@ SEGMENT_ANYTHING_MODELS = { "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "md5": "4b8939a88964f0f4ff5f5b2642c598a6", }, - "vit_t": { + "mobile_sam": { "url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt", "md5": "f3c0d8cda613564d499310dab6c812cd", }, diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index 6dcf761..36f522a 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -3,7 +3,7 @@ from enum import Enum import cv2 from loguru import logger -from lama_cleaner.const import RealESRGANModelName +from lama_cleaner.const import RealESRGANModel from lama_cleaner.helper import download_model from lama_cleaner.plugins.base_plugin import BasePlugin @@ -18,7 +18,7 @@ class RealESRGANUpscaler(BasePlugin): from realesrgan.archs.srvgg_arch import SRVGGNetCompact REAL_ESRGAN_MODELS = { - RealESRGANModelName.realesr_general_x4v3: { + RealESRGANModel.realesr_general_x4v3: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "scale": 4, "model": lambda: SRVGGNetCompact( @@ -31,7 +31,7 @@ class RealESRGANUpscaler(BasePlugin): ), "model_md5": "91a7644643c884ee00737db24e478156", }, - RealESRGANModelName.RealESRGAN_x4plus: { + RealESRGANModel.RealESRGAN_x4plus: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", "scale": 4, "model": lambda: RRDBNet( @@ -44,7 +44,7 @@ class RealESRGANUpscaler(BasePlugin): ), "model_md5": "99ec365d4afad750833258a1a24f44ca", }, - RealESRGANModelName.RealESRGAN_x4plus_anime_6B: { + RealESRGANModel.RealESRGAN_x4plus_anime_6B: { "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", "scale": 4, "model": lambda: RRDBNet( diff --git a/lama_cleaner/plugins/segment_anything/build_sam.py b/lama_cleaner/plugins/segment_anything/build_sam.py index 396daf3..2212c53 100644 --- a/lama_cleaner/plugins/segment_anything/build_sam.py +++ b/lama_cleaner/plugins/segment_anything/build_sam.py @@ -109,7 +109,7 @@ sam_model_registry = { "vit_h": build_sam, "vit_l": build_sam_vit_l, "vit_b": build_sam_vit_b, - "vit_t": build_sam_vit_t, + "mobile_sam": build_sam_vit_t, } diff --git a/lama_cleaner/runtime.py b/lama_cleaner/runtime.py index d8cc2e0..d373dfc 100644 --- a/lama_cleaner/runtime.py +++ b/lama_cleaner/runtime.py @@ -1,10 +1,16 @@ # https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py +import os import platform import sys +from pathlib import Path + import packaging.version +from loguru import logger from rich import print from typing import Dict, Any +from lama_cleaner.const import Device + _PY_VERSION: str = sys.version.split()[0].rstrip("+") if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"): @@ -21,7 +27,6 @@ _CANDIDATES = [ "diffusers", "transformers", "opencv-python", - "xformers", "accelerate", "lama-cleaner", "rembg", @@ -38,7 +43,7 @@ for name in _CANDIDATES: def dump_environment_info() -> Dict[str, str]: - """Dump information about the machine to help debugging issues. """ + """Dump information about the machine to help debugging issues.""" # Generic machine info info: Dict[str, Any] = { @@ -48,3 +53,34 @@ def dump_environment_info() -> Dict[str, str]: info.update(_package_versions) print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") return info + + +def check_device(device: Device) -> Device: + if device == Device.cuda: + import platform + + if platform.system() == "Darwin": + logger.warning("MacOS does not support cuda, use cpu instead") + return Device.cpu + else: + import torch + + if not torch.cuda.is_available(): + logger.warning("CUDA is not available, use cpu instead") + return Device.cpu + elif device == Device.mps: + import torch + + if not torch.backends.mps.is_available(): + logger.warning("mps is not available, use cpu instead") + return Device.cpu + return device + + +def setup_model_dir(model_dir: Path): + model_dir = model_dir.expanduser().absolute() + os.environ["U2NET_HOME"] = str(model_dir) + os.environ["XDG_CACHE_HOME"] = str(model_dir) + if not model_dir.exists(): + logger.info(f"Create model directory: {model_dir}") + model_dir.mkdir(exist_ok=True, parents=True) diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 327a991..4fb42f5 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,10 +1,18 @@ #!/usr/bin/env python3 import json import os -import hashlib -import traceback + +import typer +from typer import Option + +from lama_cleaner.download import cli_download_model, scan_models +from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +import hashlib +import traceback +from dataclasses import dataclass + import imghdr import io @@ -20,12 +28,7 @@ import torch from PIL import Image from loguru import logger -from lama_cleaner.const import ( - SD15_MODELS, - SD_CONTROLNET_CHOICES, - SDXL_CONTROLNET_CHOICES, - SD2_CONTROLNET_CHOICES, -) +from lama_cleaner.const import * from lama_cleaner.file_manager import FileManager from lama_cleaner.model.utils import torch_gc from lama_cleaner.model_manager import ModelManager @@ -39,6 +42,8 @@ from lama_cleaner.plugins import ( ) from lama_cleaner.schema import Config +typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) + try: torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) @@ -103,23 +108,34 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app.config["JSON_AS_ASCII"] = False -CORS(app, expose_headers=["Content-Disposition", "X-seed"]) +CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"]) sio_logger = logging.getLogger("sio-logger") sio_logger.setLevel(logging.ERROR) socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") -model: ModelManager = None -thumb: FileManager = None -output_dir: str = None -device = None -input_image_path: str = None -is_disable_model_switch: bool = False -enable_file_manager: bool = False -enable_auto_saving: bool = False -is_desktop: bool = False -image_quality: int = 95 -plugins = {} + +@dataclass +class GlobalConfig: + model_manager: ModelManager = None + file_manager: FileManager = None + output_dir: Path = None + input_image_path: Path = None + disable_model_switch: bool = False + is_desktop: bool = False + image_quality: int = 95 + plugins = {} + + @property + def enable_auto_saving(self) -> bool: + return self.output_dir is not None + + @property + def enable_file_manager(self) -> bool: + return self.file_manager is not None + + +global_config = GlobalConfig() def get_image_ext(img_bytes): @@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents): @app.route("/save_image", methods=["POST"]) def save_image(): - if output_dir is None: + if global_config.output_dir is None: return "--output-dir is None", 500 input = request.files @@ -143,7 +159,7 @@ def save_image(): origin_image_bytes = input["image"].read() # RGB ext = get_image_ext(origin_image_bytes) image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True) - save_path = os.path.join(output_dir, filename) + save_path = str(global_config.output_dir / filename) if alpha_channel is not None: if alpha_channel.shape[:2] != image.shape[:2]: @@ -157,7 +173,7 @@ def save_image(): img_bytes = pil_to_bytes( pil_image, ext, - quality=image_quality, + quality=global_config.image_quality, exif_infos=exif_infos, ) with open(save_path, "wb") as fw: @@ -169,9 +185,11 @@ def save_image(): @app.route("/medias/") def medias(tab): if tab == "image": - response = make_response(jsonify(thumb.media_names), 200) + response = make_response(jsonify(global_config.file_manager.media_names), 200) else: - response = make_response(jsonify(thumb.output_media_names), 200) + response = make_response( + jsonify(global_config.file_manager.output_media_names), 200 + ) # response.last_modified = thumb.modified_time[tab] # response.cache_control.no_cache = True # response.cache_control.max_age = 0 @@ -182,8 +200,8 @@ def medias(tab): @app.route("/media//") def media_file(tab, filename): if tab == "image": - return send_from_directory(thumb.root_directory, filename) - return send_from_directory(thumb.output_dir, filename) + return send_from_directory(global_config.file_manager.root_directory, filename) + return send_from_directory(global_config.file_manager.output_dir, filename) @app.route("/media_thumbnail//") @@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename): if height: height = int(float(height)) - directory = thumb.root_directory + directory = global_config.file_manager.root_directory if tab == "output": - directory = thumb.output_dir - thumb_filename, (width, height) = thumb.get_thumbnail( + directory = global_config.file_manager.output_dir + thumb_filename, (width, height) = global_config.file_manager.get_thumbnail( directory, filename, width, height ) thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" @@ -257,13 +275,11 @@ def process(): croper_y=form["croperY"], croper_height=form["croperHeight"], croper_width=form["croperWidth"], - use_extender=form["useExtender"], extender_x=form["extenderX"], extender_y=form["extenderY"], extender_height=form["extenderHeight"], extender_width=form["extenderWidth"], - sd_scale=form["sdScale"], sd_mask_blur=form["sdMaskBlur"], sd_strength=form["sdStrength"], @@ -294,7 +310,7 @@ def process(): start = time.time() try: - res_np_img = model(image, mask, config) + res_np_img = global_config.model_manager(image, mask, config) except RuntimeError as e: if "CUDA out of memory. " in str(e): # NOTE: the string may change? @@ -322,7 +338,7 @@ def process(): pil_to_bytes( Image.fromarray(res_np_img), ext, - quality=image_quality, + quality=global_config.image_quality, exif_infos=exif_infos, ) ) @@ -345,7 +361,7 @@ def run_plugin(): form = request.form files = request.files name = form["name"] - if name not in plugins: + if name not in global_config.plugins: return "Plugin not found", 500 origin_image_bytes = files["image"].read() # RGB @@ -359,7 +375,7 @@ def run_plugin(): if name == InteractiveSeg.name: img_md5 = hashlib.md5(origin_image_bytes).hexdigest() form["img_md5"] = img_md5 - bgr_res = plugins[name](rgb_np_img, files, form) + bgr_res = global_config.plugins[name](rgb_np_img, files, form) except RuntimeError as e: torch.cuda.empty_cache() if "CUDA out of memory. " in str(e): @@ -401,7 +417,7 @@ def run_plugin(): pil_to_bytes( Image.fromarray(rgb_res), ext, - quality=image_quality, + quality=global_config.image_quality, exif_infos=exif_infos, ) ), @@ -414,41 +430,40 @@ def run_plugin(): @app.route("/server_config", methods=["GET"]) def get_server_config(): return { - "plugins": list(plugins.keys()), - "enableFileManager": enable_file_manager, - "enableAutoSaving": enable_auto_saving, - "enableControlnet": model.sd_controlnet, - "controlnetMethod": model.sd_controlnet_method, - "disableModelSwitch": is_disable_model_switch, + "plugins": list(global_config.plugins.keys()), + "enableFileManager": global_config.enable_file_manager, + "enableAutoSaving": global_config.enable_auto_saving, + "enableControlnet": global_config.model_manager.sd_controlnet, + "controlnetMethod": global_config.model_manager.sd_controlnet_method, + "disableModelSwitch": global_config.disable_model_switch, + "isDesktop": global_config.is_desktop, }, 200 @app.route("/models", methods=["GET"]) def get_models(): - return [it.model_dump() for it in model.scan_models()] + return [it.model_dump() for it in global_config.model_manager.scan_models()] @app.route("/model") def current_model(): - return model.available_models[model.name].model_dump(), 200 - - -@app.route("/is_desktop") -def get_is_desktop(): - return str(is_desktop), 200 + return ( + global_config.model_manager.current_model, + 200, + ) @app.route("/model", methods=["POST"]) def switch_model(): - if is_disable_model_switch: + if global_config.disable_model_switch: return "Switch model is disabled", 400 new_name = request.form.get("name") - if new_name == model.name: + if new_name == global_config.model_manager.name: return "Same model", 200 try: - model.switch(new_name) + global_config.model_manager.switch(new_name) except Exception as e: traceback.print_exc() error_message = f"{type(e).__name__} - {str(e)}" @@ -464,160 +479,230 @@ def index(): @app.route("/inputimage") def get_cli_input_image(): - if input_image_path: - with open(input_image_path, "rb") as f: + if global_config.input_image_path: + with open(global_config.input_image_path, "rb") as f: image_in_bytes = f.read() return send_file( - input_image_path, + global_config.input_image_path, as_attachment=True, - download_name=Path(input_image_path).name, + download_name=Path(global_config.input_image_path).name, mimetype=f"image/{get_image_ext(image_in_bytes)}", ) else: return "No Input Image" -def build_plugins(args): - global plugins - if args.enable_interactive_seg: +def build_plugins( + enable_interactive_seg: bool, + interactive_seg_model: InteractiveSegModel, + interactive_seg_device: Device, + enable_remove_bg: bool, + enable_anime_seg: bool, + enable_realesrgan: bool, + realesrgan_device: Device, + realesrgan_model: str, + enable_gfpgan: bool, + gfpgan_device: Device, + enable_restoreformer: bool, + restoreformer_device: Device, + no_half: bool, +): + if enable_interactive_seg: logger.info(f"Initialize {InteractiveSeg.name} plugin") - plugins[InteractiveSeg.name] = InteractiveSeg( - args.interactive_seg_model, args.interactive_seg_device + global_config.plugins[InteractiveSeg.name] = InteractiveSeg( + interactive_seg_model, interactive_seg_device ) - if args.enable_remove_bg: + if enable_remove_bg: logger.info(f"Initialize {RemoveBG.name} plugin") - plugins[RemoveBG.name] = RemoveBG() + global_config.plugins[RemoveBG.name] = RemoveBG() - if args.enable_anime_seg: + if enable_anime_seg: logger.info(f"Initialize {AnimeSeg.name} plugin") - plugins[AnimeSeg.name] = AnimeSeg() + global_config.plugins[AnimeSeg.name] = AnimeSeg() - if args.enable_realesrgan: + if enable_realesrgan: logger.info( - f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}" + f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}" ) - plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( - args.realesrgan_model, - args.realesrgan_device, - no_half=args.realesrgan_no_half, + global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + realesrgan_model, + realesrgan_device, + no_half=no_half, ) - if args.enable_gfpgan: + if enable_gfpgan: logger.info(f"Initialize {GFPGANPlugin.name} plugin") - if args.enable_realesrgan: + if enable_realesrgan: logger.info("Use realesrgan as GFPGAN background upscaler") else: logger.info( f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" ) - plugins[GFPGANPlugin.name] = GFPGANPlugin( - args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None) + global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin( + gfpgan_device, + upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), ) - if args.enable_restoreformer: + if enable_restoreformer: logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") - plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( - args.restoreformer_device, - upscaler=plugins.get(RealESRGANUpscaler.name, None), + global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( + restoreformer_device, + upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), ) -def main(args): - global model - global device - global input_image_path - global is_disable_model_switch - global enable_file_manager - global is_desktop - global thumb - global output_dir - global image_quality - global enable_auto_saving - build_plugins(args) +@typer_app.command(help="Install all plugins dependencies") +def install_plugins_packages(): + from lama_cleaner.installer import install_plugins_package - image_quality = args.quality - output_dir = args.output_dir + install_plugins_package() + + +@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace") +def download( + model: str = Option( + ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting" + ), + model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), +): + cli_download_model(model, model_dir) + + +@typer_app.command(help="List downloaded models") +def list_model( + model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), +): + setup_model_dir(model_dir) + scanned_models = scan_models() + for it in scanned_models: + print(it.name) + + +@typer_app.command(help="Start lama cleaner server") +def start( + host: str = Option("127.0.0.1"), + port: int = Option(8080), + model: str = Option( + DEFAULT_MODEL, + help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. " + f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface", + ), + model_dir: Path = Option( + DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False + ), + no_half: bool = Option(False, help=NO_HALF_HELP), + cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP), + disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP), + cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP), + local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), + device: Device = Option(Device.cpu), + gui: bool = Option(False, help=GUI_HELP), + disable_model_switch: bool = Option(False), + input: Path = Option(None, help=INPUT_HELP), + output_dir: Path = Option( + None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False + ), + quality: int = Option(95, help=QUALITY_HELP), + enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP), + interactive_seg_model: InteractiveSegModel = Option( + InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP + ), + interactive_seg_device: Device = Option(Device.cpu), + enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), + enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), + enable_realesrgan: bool = Option(False), + realesrgan_device: Device = Option(Device.cpu), + realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3), + enable_gfpgan: bool = Option(False), + gfpgan_device: Device = Option(Device.cpu), + enable_restoreformer: bool = Option(False), + restoreformer_device: Device = Option(Device.cpu), +): + global global_config + dump_environment_info() + + if input: + if not input.exists(): + logger.error(f"invalid --input: {input} not exists") + exit() + if input.is_dir(): + logger.info(f"Initialize file manager") + file_manager = FileManager(app) + app.config["THUMBNAIL_MEDIA_ROOT"] = input + app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( + output_dir, "lama_cleaner_thumbnails" + ) + file_manager.output_dir = output_dir + else: + global_config.input_image_path = input + + device = check_device(device) + setup_model_dir(model_dir) + + if local_files_only: + os.environ["TRANSFORMERS_OFFLINE"] = "1" + os.environ["HF_HUB_OFFLINE"] = "1" + + if model not in AVAILABLE_MODELS: + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.error( + f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}" + ) + exit() + + global_config.image_quality = quality + global_config.disable_model_switch = disable_model_switch + global_config.is_desktop = gui + build_plugins( + enable_interactive_seg, + interactive_seg_model, + interactive_seg_device, + enable_remove_bg, + enable_anime_seg, + enable_realesrgan, + realesrgan_device, + realesrgan_model, + enable_gfpgan, + gfpgan_device, + enable_restoreformer, + restoreformer_device, + no_half, + ) if output_dir: - output_dir = os.path.abspath(output_dir) - logger.info(f"Output dir: {output_dir}") - enable_auto_saving = True + output_dir = output_dir.expanduser().absolute() + logger.info(f"Image will auto save to output dir: {output_dir}") + global_config.output_dir = output_dir - device = torch.device(args.device) - is_disable_model_switch = args.disable_model_switch - is_desktop = args.gui - if is_disable_model_switch: - logger.info( - f"Start with --disable-model-switch, model switch on frontend is disable" - ) - - if args.input and os.path.isdir(args.input): - logger.info(f"Initialize file manager") - thumb = FileManager(app) - enable_file_manager = True - app.config["THUMBNAIL_MEDIA_ROOT"] = args.input - app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( - output_dir, "lama_cleaner_thumbnails" - ) - thumb.output_dir = Path(output_dir) - # thumb.start() - # try: - # while True: - # time.sleep(1) - # finally: - # thumb.image_dir_observer.stop() - # thumb.image_dir_observer.join() - # thumb.output_dir_observer.stop() - # thumb.output_dir_observer.join() - - else: - input_image_path = args.input - - # 为了兼容性 - model_name_map = { - "sd1.5": "runwayml/stable-diffusion-inpainting", - "anything4": "Sanster/anything-4.0-inpainting", - "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting", - "sd2": "stabilityai/stable-diffusion-2-inpainting", - "sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", - "kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint", - "paint_by_example": "Fantasy-Studio/Paint-by-Example", - "instruct_pix2pix": "timbrooks/instruct-pix2pix", - } - - model = ModelManager( - name=model_name_map.get(args.model, args.model), - sd_controlnet=args.sd_controlnet, - sd_controlnet_method=args.sd_controlnet_method, - device=device, - no_half=args.no_half, - hf_access_token=args.hf_access_token, - disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, - sd_cpu_textencoder=args.sd_cpu_textencoder, - cpu_offload=args.cpu_offload, - enable_xformers=args.sd_enable_xformers or args.enable_xformers, + global_config.model_manager = ModelManager( + name=model, + device=torch.device(device), + no_half=no_half, + disable_nsfw=disable_nsfw_checker, + sd_cpu_textencoder=cpu_textencoder, + cpu_offload=cpu_offload, callback=diffuser_callback, ) - if args.gui: - app_width, app_height = args.gui_size + if gui: from flaskwebgui import FlaskUI ui = FlaskUI( app, socketio=socketio, - width=app_width, - height=app_height, - host=args.host, - port=args.port, - close_server_on_exit=not args.no_gui_auto_close, + width=1200, + height=800, + host=host, + port=port, + close_server_on_exit=True, + idle_interval=60, ) ui.run() else: socketio.run( app, - host=args.host, - port=args.port, - debug=args.debug, + host=host, + port=port, allow_unsafe_werkzeug=True, ) diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py index 5c4571b..b869a19 100644 --- a/lama_cleaner/tests/test_controlnet.py +++ b/lama_cleaner/tests/test_controlnet.py @@ -39,7 +39,6 @@ def test_runway_sd_1_5( name=model_name, sd_controlnet=True, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=disable_nsfw, sd_cpu_textencoder=cpu_textencoder, sd_controlnet_method=sd_controlnet_method, @@ -87,7 +86,6 @@ def test_local_file_path(sd_device, sampler): name=model_name, sd_controlnet=True, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True, @@ -125,7 +123,6 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): name=model_name, sd_controlnet=True, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True, @@ -166,7 +163,6 @@ def test_controlnet_switch(sd_device, sampler): name=model_name, sd_controlnet=True, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True, diff --git a/lama_cleaner/tests/test_instruct_pix2pix.py b/lama_cleaner/tests/test_instruct_pix2pix.py index 7d633e3..b53eb8f 100644 --- a/lama_cleaner/tests/test_instruct_pix2pix.py +++ b/lama_cleaner/tests/test_instruct_pix2pix.py @@ -21,7 +21,6 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload): model = ModelManager( name=model_name, device=torch.device(device), - hf_access_token="", disable_nsfw=disable_nsfw, sd_cpu_textencoder=False, cpu_offload=cpu_offload, @@ -52,7 +51,6 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload): model = ModelManager( name=model_name, device=torch.device(device), - hf_access_token="", disable_nsfw=disable_nsfw, sd_cpu_textencoder=False, cpu_offload=cpu_offload, diff --git a/lama_cleaner/tests/test_model_md5.py b/lama_cleaner/tests/test_model_md5.py index 8811307..0713cdc 100644 --- a/lama_cleaner/tests/test_model_md5.py +++ b/lama_cleaner/tests/test_model_md5.py @@ -17,11 +17,9 @@ def test_load_model(): name=m, device="cpu", no_half=False, - hf_access_token="", disable_nsfw=False, sd_cpu_textencoder=True, cpu_offload=True, - enable_xformers=False, ) diff --git a/lama_cleaner/tests/test_model_switch.py b/lama_cleaner/tests/test_model_switch.py index 6da4343..a1fbec1 100644 --- a/lama_cleaner/tests/test_model_switch.py +++ b/lama_cleaner/tests/test_model_switch.py @@ -16,11 +16,9 @@ def test_model_switch(): sd_controlnet=True, sd_controlnet_method="lllyasviel/control_v11p_sd15_canny", device=torch.device("mps"), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - enable_xformers=False, callback=None, ) @@ -34,11 +32,9 @@ def test_controlnet_switch_onoff(caplog): sd_controlnet=True, sd_controlnet_method="lllyasviel/control_v11p_sd15_canny", device=torch.device("mps"), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - enable_xformers=False, callback=None, ) @@ -61,11 +57,9 @@ def test_controlnet_switch_method(caplog): sd_controlnet=True, sd_controlnet_method=old_method, device=torch.device("mps"), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - enable_xformers=False, callback=None, ) diff --git a/lama_cleaner/tests/test_outpainting.py b/lama_cleaner/tests/test_outpainting.py index f8ca7c5..17dc323 100644 --- a/lama_cleaner/tests/test_outpainting.py +++ b/lama_cleaner/tests/test_outpainting.py @@ -41,7 +41,6 @@ def test_outpainting(name, sd_device, rect): model = ModelManager( name=name, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, callback=callback, @@ -86,7 +85,6 @@ def test_kandinsky_outpainting(name, sd_device, rect): model = ModelManager( name=name, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, callback=callback, diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 477f9da..0b477dd 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -38,7 +38,6 @@ def test_runway_sd_1_5_all_samplers( model = ModelManager( name="runwayml/stable-diffusion-inpainting", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) @@ -69,7 +68,6 @@ def test_runway_sd_lcm_lora(sd_device, strategy, sampler): model = ModelManager( name="runwayml/stable-diffusion-inpainting", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) @@ -102,7 +100,6 @@ def test_runway_sd_freeu(sd_device, strategy, sampler): model = ModelManager( name="runwayml/stable-diffusion-inpainting", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) @@ -136,7 +133,6 @@ def test_runway_sd_sd_strength(sd_device, strategy, sampler): model = ModelManager( name="runwayml/stable-diffusion-inpainting", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) @@ -165,7 +161,6 @@ def test_runway_norm_sd_model(sd_device, strategy, sampler): model = ModelManager( name="runwayml/stable-diffusion-v1-5", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) @@ -192,7 +187,6 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): model = ModelManager( name="runwayml/stable-diffusion-inpainting", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True, @@ -229,7 +223,6 @@ def test_local_file_path(sd_device, sampler, name): model = ModelManager( name=name, device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=False, diff --git a/lama_cleaner/tests/test_sdxl.py b/lama_cleaner/tests/test_sdxl.py index 8ca0010..41ac20a 100644 --- a/lama_cleaner/tests/test_sdxl.py +++ b/lama_cleaner/tests/test_sdxl.py @@ -29,7 +29,6 @@ def test_sdxl(sd_device, strategy, sampler): model = ModelManager( name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, callback=callback, @@ -70,7 +69,6 @@ def test_sdxl_lcm_lora_and_freeu(sd_device, strategy, sampler): model = ModelManager( name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, callback=callback, @@ -131,7 +129,6 @@ def test_sdxl_outpainting(sd_device, rect): model = ModelManager( name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", device=torch.device(sd_device), - hf_access_token="", disable_nsfw=True, sd_cpu_textencoder=False, ) diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index 25d53c5..2f52d62 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -24,7 +24,6 @@ def save_config( cpu_offload, disable_nsfw, sd_cpu_textencoder, - enable_xformers, local_files_only, model_dir, input, @@ -102,9 +101,6 @@ def main(config_file: str): with gr.Column(): gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") - no_gui_auto_close = gr.Checkbox( - init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}" - ) with gr.Column(): model_dir = gr.Textbox( @@ -193,14 +189,11 @@ def main(config_file: str): init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}" ) sd_cpu_textencoder = gr.Checkbox( - init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}" + init_config.sd_cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}" ) disable_nsfw = gr.Checkbox( init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}" ) - enable_xformers = gr.Checkbox( - init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}" - ) local_files_only = gr.Checkbox( init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}" ) @@ -221,7 +214,6 @@ def main(config_file: str): cpu_offload, disable_nsfw, sd_cpu_textencoder, - enable_xformers, local_files_only, model_dir, input, diff --git a/web_app/src/components/Cropper.tsx b/web_app/src/components/Cropper.tsx index 87f7b94..2d4554b 100644 --- a/web_app/src/components/Cropper.tsx +++ b/web_app/src/components/Cropper.tsx @@ -71,6 +71,8 @@ const Cropper = (props: Props) => { setY, setWidth, setHeight, + isResizing, + setIsResizing, ] = useStore((state) => [ state.imageWidth, state.imageHeight, @@ -80,9 +82,11 @@ const Cropper = (props: Props) => { state.setCropperY, state.setCropperWidth, state.setCropperHeight, + state.isCropperExtenderResizing, + state.setIsCropperExtenderResizing, ]) - const [isResizing, setIsResizing] = useState(false) + // const [isResizing, setIsResizing] = useState(false) const [isMoving, setIsMoving] = useState(false) useEffect(() => { diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 0de861a..9ef276e 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -65,6 +65,7 @@ export default function Editor(props: EditorProps) { updateAppState, runMannually, runInpainting, + isCropperExtenderResizing, ] = useStore((state) => [ state.disableShortCuts, state.windowSize, @@ -87,6 +88,7 @@ export default function Editor(props: EditorProps) { state.updateAppState, state.runMannually(), state.runInpainting, + state.isCropperExtenderResizing, ]) const baseBrushSize = useStore((state) => state.editorState.baseBrushSize) const brushSize = useStore((state) => state.getBrushSize()) @@ -537,7 +539,7 @@ export default function Editor(props: EditorProps) { } const toggleShowBrush = (newState: boolean) => { - if (newState !== showBrush && !isPanning) { + if (newState !== showBrush && !isPanning && !isCropperExtenderResizing) { setShowBrush(newState) } } @@ -693,7 +695,7 @@ export default function Editor(props: EditorProps) { limitToBounds={false} doubleClick={{ disabled: true }} initialScale={minScale} - minScale={minScale * 0.6} + minScale={minScale * 0.3} onPanning={(ref) => { if (!panned) { setPanned(true) diff --git a/web_app/src/components/Extender.tsx b/web_app/src/components/Extender.tsx index 14f71fd..1c37090 100644 --- a/web_app/src/components/Extender.tsx +++ b/web_app/src/components/Extender.tsx @@ -54,6 +54,8 @@ const Extender = (props: Props) => { setWidth, setHeight, extenderDirection, + isResizing, + setIsResizing, ] = useStore((state) => [ state.isInpainting, state.imageHeight, @@ -64,10 +66,10 @@ const Extender = (props: Props) => { state.setExtenderWidth, state.setExtenderHeight, state.settings.extenderDirection, + state.isCropperExtenderResizing, + state.setIsCropperExtenderResizing, ]) - const [isResizing, setIsResizing] = useState(false) - const [evData, setEVData] = useState({ initX: 0, initY: 0, @@ -122,10 +124,9 @@ const Extender = (props: Props) => { const moveBottom = () => { const newHeight = evData.initHeight + offsetY let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) - if (extenderDirection === EXTENDER_ALL) { - if (clampedY + clampedHeight < imageHeight) { - clampedHeight = imageHeight + if (clampedHeight < Math.abs(clampedY) + imageHeight) { + clampedHeight = Math.abs(clampedY) + imageHeight } } setHeight(clampedHeight) @@ -155,8 +156,8 @@ const Extender = (props: Props) => { const newWidth = evData.initWidth + offsetX let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) if (extenderDirection === EXTENDER_ALL) { - if (clampedX + clampedWidth < imageWdith) { - clampedWidth = imageWdith + if (clampedWidth < Math.abs(clampedX) + imageWdith) { + clampedWidth = Math.abs(clampedX) + imageWdith } } setWidth(clampedWidth) diff --git a/web_app/src/components/SidePanel.tsx b/web_app/src/components/SidePanel.tsx index e563540..535f1e7 100644 --- a/web_app/src/components/SidePanel.tsx +++ b/web_app/src/components/SidePanel.tsx @@ -105,18 +105,22 @@ const LabelTitle = ({ {text} - -

{toolTip}

- {url ? ( - - ) : ( - <> - )} -
+ {toolTip ? ( + +

{toolTip}

+ {url ? ( + + ) : ( + <> + )} +
+ ) : ( + <> + )} ) } @@ -172,7 +176,11 @@ const SidePanel = () => {
- + { return (
- + { return (
-
Example Image
+ { @@ -450,8 +465,9 @@ const SidePanel = () => { return (
{ } const renderExtender = () => { + if (!settings.model.support_outpainting) { + return null + } return ( <>
- + { >
- + { {renderExtender()}
- + void setExtenderWidth: (newValue: number) => void setExtenderHeight: (newValue: number) => void + setIsCropperExtenderResizing: (newValue: boolean) => void updateExtenderDirection: (newValue: string) => void resetExtender: (width: number, height: number) => void updateExtenderByBuiltIn: (direction: string, scale: number) => void @@ -261,6 +265,7 @@ const defaultValues: AppState = { width: 512, height: 512, }, + isCropperExtenderResizing: false, fileManagerState: { sortBy: SortBy.CTIME, @@ -889,6 +894,11 @@ export const useStore = createWithEqualityFn()( state.extenderState.height = newValue }), + setIsCropperExtenderResizing: (newValue: boolean) => + set((state) => { + state.isCropperExtenderResizing = newValue + }), + updateExtenderDirection: (newValue: string) => { console.log( `updateExtenderDirection: ${JSON.stringify(get().extenderState)}`