From 9a9eb8abfdc9ad86031608021322038802ea58f2 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 1 Dec 2023 10:15:35 +0800 Subject: [PATCH] wip --- lama_cleaner/const.py | 66 ++- lama_cleaner/diffusers_utils.py | 41 -- lama_cleaner/download.py | 138 ++++- lama_cleaner/file_manager/file_manager.py | 3 + lama_cleaner/file_manager/utils.py | 18 +- lama_cleaner/helper.py | 4 +- lama_cleaner/model/__init__.py | 33 ++ lama_cleaner/model/base.py | 9 +- lama_cleaner/model/controlnet.py | 286 +++------- lama_cleaner/model/fcf.py | 1 + .../model/helper/controlnet_preprocess.py | 46 ++ lama_cleaner/model/helper/cpu_text_encoder.py | 25 + .../model/{ => helper}/g_diffuser_bot.py | 0 lama_cleaner/model/instruct_pix2pix.py | 12 +- lama_cleaner/model/kandinsky.py | 1 - lama_cleaner/model/lama.py | 1 + lama_cleaner/model/ldm.py | 1 + lama_cleaner/model/manga.py | 1 + lama_cleaner/model/mat.py | 1 + lama_cleaner/model/mi_gan.py | 1 + lama_cleaner/model/opencv2.py | 1 + ...ine_stable_diffusion_controlnet_inpaint.py | 33 ++ lama_cleaner/model/sd.py | 44 +- lama_cleaner/model/sdxl.py | 32 +- lama_cleaner/model/zits.py | 1 + lama_cleaner/model_manager.py | 124 ++-- lama_cleaner/parse_args.py | 26 +- lama_cleaner/schema.py | 55 ++ lama_cleaner/server.py | 72 +-- web_app/package-lock.json | 384 +++++++++++++ web_app/package.json | 9 + web_app/src/App.tsx | 13 +- web_app/src/components/Cropper.tsx | 147 ++--- web_app/src/components/Editor.tsx | 536 ++++++++---------- web_app/src/components/FileManager.tsx | 29 +- web_app/src/components/Header.tsx | 99 ++-- web_app/src/components/ImageSize.tsx | 2 +- web_app/src/components/InteractiveSeg.tsx | 136 +++++ web_app/src/components/Plugins.tsx | 29 +- web_app/src/components/PromptInput.tsx | 8 +- web_app/src/components/Settings.tsx | 435 ++++++++++++++ web_app/src/components/Workspace.tsx | 76 +-- web_app/src/components/ui/alert-dialog.tsx | 139 +++++ web_app/src/components/ui/button.tsx | 2 +- web_app/src/components/ui/dialog.tsx | 2 +- web_app/src/components/ui/form.tsx | 176 ++++++ web_app/src/components/ui/separator.tsx | 29 + web_app/src/components/ui/slider.tsx | 5 +- web_app/src/globals.css | 4 +- web_app/src/lib/api.ts | 52 +- web_app/src/lib/states.ts | 277 +++++++-- web_app/src/lib/store.ts | 123 +--- web_app/src/lib/types.ts | 35 +- web_app/src/main.tsx | 12 +- web_app/tailwind.config.js | 12 +- 55 files changed, 2596 insertions(+), 1251 deletions(-) delete mode 100644 lama_cleaner/diffusers_utils.py create mode 100644 lama_cleaner/model/helper/controlnet_preprocess.py create mode 100644 lama_cleaner/model/helper/cpu_text_encoder.py rename lama_cleaner/model/{ => helper}/g_diffuser_bot.py (100%) create mode 100644 web_app/src/components/InteractiveSeg.tsx create mode 100644 web_app/src/components/Settings.tsx create mode 100644 web_app/src/components/ui/alert-dialog.tsx create mode 100644 web_app/src/components/ui/form.tsx create mode 100644 web_app/src/components/ui/separator.tsx diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 1b749a8..27a1ecd 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -4,16 +4,14 @@ from enum import Enum from pydantic import BaseModel -MPS_SUPPORT_MODELS = [ - "instruct_pix2pix", - "sd1.5", - "anything4", - "realisticVision1.4", - "sd2", - "paint_by_example", - "controlnet", - "kandinsky2.2", - "sdxl", +MPS_UNSUPPORT_MODELS = [ + "lama", + "ldm", + "zits", + "mat", + "fcf", + "cv2", + "manga", ] DEFAULT_MODEL = "lama" @@ -36,18 +34,13 @@ AVAILABLE_MODELS = [ "sdxl", ] SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] -MODELS_SUPPORT_FREEU = SD15_MODELS + ["sd2", "sdxl", "instruct_pix2pix"] -MODELS_SUPPORT_LCM_LORA = SD15_MODELS + ["sdxl"] - -FREEU_DEFAULT_CONFIGS = { - "sd2": dict(s1=0.9, s2=0.2, b1=1.1, b2=1.2), - "sdxl": dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2), - "sd1.5": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), - "anything4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), - "realisticVision1.4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), - "instruct_pix2pix": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), -} - +DIFFUSERS_MODEL_FP16_REVERSION = [ + "runwayml/stable-diffusion-inpainting", + "Sanster/anything-4.0-inpainting", + "Sanster/Realistic_Vision_V1.4-inpainting", + "stabilityai/stable-diffusion-2-inpainting", + "timbrooks/instruct-pix2pix", +] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] DEFAULT_DEVICE = "cuda" @@ -70,14 +63,29 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory. """ SD_CONTROLNET_HELP = """ -Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. +Run Stable Diffusion normal or inpainting model with ControlNet. """ -DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny" +DEFAULT_SD_CONTROLNET_METHOD = "thibaud/controlnet-sd21-openpose-diffusers" SD_CONTROLNET_CHOICES = [ - "control_v11p_sd15_canny", - "control_v11p_sd15_openpose", - "control_v11p_sd15_inpaint", - "control_v11f1p_sd15_depth", + "lllyasviel/control_v11p_sd15_canny", + # "lllyasviel/control_v11p_sd15_seg", + "lllyasviel/control_v11p_sd15_openpose", + "lllyasviel/control_v11p_sd15_inpaint", + "lllyasviel/control_v11f1p_sd15_depth", +] + +DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers" +SD2_CONTROLNET_CHOICES = [ + "thibaud/controlnet-sd21-canny-diffusers", + "thibaud/controlnet-sd21-depth-diffusers", + "thibaud/controlnet-sd21-openpose-diffusers", +] + +DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0" +SDXL_CONTROLNET_CHOICES = [ + "thibaud/controlnet-openpose-sdxl-1.0", + "diffusers/controlnet-canny-sdxl-1.0", + "diffusers/controlnet-depth-sdxl-1.0", ] SD_LOCAL_MODEL_HELP = """ @@ -152,7 +160,7 @@ class Config(BaseModel): model: str = DEFAULT_MODEL sd_local_model_path: str = None sd_controlnet: bool = False - sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD + sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD device: str = DEFAULT_DEVICE gui: bool = False no_gui_auto_close: bool = False diff --git a/lama_cleaner/diffusers_utils.py b/lama_cleaner/diffusers_utils.py deleted file mode 100644 index d67c80c..0000000 --- a/lama_cleaner/diffusers_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import json -from pathlib import Path -from typing import Dict, List - - -def folder_name_to_show_name(name: str) -> str: - return name.replace("models--", "").replace("--", "/") - - -def _scan_models(cache_dir, class_name: List[str]) -> List[str]: - cache_dir = Path(cache_dir) - res = [] - for it in cache_dir.glob("**/*/model_index.json"): - with open(it, "r", encoding="utf-8") as f: - data = json.load(f) - if data["_class_name"] in class_name: - name = folder_name_to_show_name(it.parent.parent.parent.name) - if name not in res: - res.append(name) - return res - - -def scan_models(cache_dir) -> Dict[str, List[str]]: - return { - "sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]), - "sd_inpaint": _scan_models( - cache_dir, - [ - "StableDiffusionInpaintPipeline", - "StableDiffusionXLInpaintPipeline", - "KandinskyV22InpaintPipeline", - ], - ), - "other": _scan_models( - cache_dir, - [ - "StableDiffusionInstructPix2PixPipeline", - "PaintByExamplePipeline", - ], - ), - } diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index f0f7cb2..7bb2d42 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -1,8 +1,20 @@ +import json import os +from typing import List from loguru import logger from pathlib import Path +from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION +from lama_cleaner.schema import ( + ModelInfo, + ModelType, + DIFFUSERS_SD_INPAINT_CLASS_NAME, + DIFFUSERS_SDXL_INPAINT_CLASS_NAME, + DIFFUSERS_SD_CLASS_NAME, + DIFFUSERS_SDXL_CLASS_NAME, +) + def cli_download_model(model: str, model_dir: str): if os.path.isfile(model_dir): @@ -14,7 +26,7 @@ def cli_download_model(model: str, model_dir: str): os.environ["XDG_CACHE_HOME"] = model_dir - from lama_cleaner.model_manager import models + from lama_cleaner.model import models if model in models: logger.info(f"Downloading {model}...") @@ -22,3 +34,127 @@ def cli_download_model(model: str, model_dir: str): logger.info(f"Done.") else: logger.info(f"Downloading model from Huggingface: {model}") + from diffusers import DiffusionPipeline + + downloaded_path = DiffusionPipeline.download( + pretrained_model_name=model, + revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main", + resume_download=True, + ) + logger.info(f"Done. Downloaded to {downloaded_path}") + + +def folder_name_to_show_name(name: str) -> str: + return name.replace("models--", "").replace("--", "/") + + +def scan_diffusers_models( + cache_dir, class_name: List[str], model_type: ModelType +) -> List[ModelInfo]: + cache_dir = Path(cache_dir) + res = [] + for it in cache_dir.glob("**/*/model_index.json"): + with open(it, "r", encoding="utf-8") as f: + data = json.load(f) + if data["_class_name"] in class_name: + name = folder_name_to_show_name(it.parent.parent.parent.name) + if name not in res: + res.append( + ModelInfo( + name=name, + path=name, + model_type=model_type, + ) + ) + return res + + +def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: + cache_dir = Path(cache_dir) + res = [] + for it in cache_dir.glob(f"*.*"): + if it.suffix not in [".safetensors", ".ckpt"]: + continue + if "inpaint" in str(it).lower(): + if "sdxl" in str(it).lower(): + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + model_type = ModelType.DIFFUSERS_SD_INPAINT + else: + if "sdxl" in str(it).lower(): + model_type = ModelType.DIFFUSERS_SDXL + else: + model_type = ModelType.DIFFUSERS_SD + res.append( + ModelInfo( + name=it.name, + path=str(it.absolute()), + model_type=model_type, + is_single_file_diffusers=True, + ) + ) + return res + + +def scan_inpaint_models() -> List[ModelInfo]: + res = [] + from lama_cleaner.model import models + + for name, m in models.items(): + if m.is_erase_model: + res.append( + ModelInfo( + name=name, + path=name, + model_type=ModelType.INPAINT, + ) + ) + return res + + +def scan_models() -> List[ModelInfo]: + from diffusers.utils import DIFFUSERS_CACHE + + available_models = [] + available_models.extend(scan_inpaint_models()) + available_models.extend( + scan_single_file_diffusion_models(os.environ["XDG_CACHE_HOME"]) + ) + + cache_dir = Path(DIFFUSERS_CACHE) + diffusers_model_names = [] + for it in cache_dir.glob("**/*/model_index.json"): + with open(it, "r", encoding="utf-8") as f: + data = json.load(f) + _class_name = data["_class_name"] + name = folder_name_to_show_name(it.parent.parent.parent.name) + if name in diffusers_model_names: + continue + + if _class_name == DIFFUSERS_SD_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD + elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD_INPAINT + elif _class_name == DIFFUSERS_SDXL_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL + elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + elif _class_name in [ + "StableDiffusionInstructPix2PixPipeline", + "PaintByExamplePipeline", + "KandinskyV22InpaintPipeline", + ]: + model_type = ModelType.DIFFUSERS_OTHER + else: + continue + + diffusers_model_names.append(name) + available_models.append( + ModelInfo( + name=name, + path=name, + model_type=model_type, + ) + ) + + return available_models diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py index 73d3897..0b513c1 100644 --- a/lama_cleaner/file_manager/file_manager.py +++ b/lama_cleaner/file_manager/file_manager.py @@ -7,6 +7,7 @@ import time from io import BytesIO from pathlib import Path import numpy as np + # from watchdog.events import FileSystemEventHandler # from watchdog.observers import Observer @@ -149,6 +150,7 @@ class FileManager: def get_thumbnail( self, directory: Path, original_filename: str, width, height, **options ): + directory = Path(directory) storage = FilesystemStorageBackend(self.app) crop = options.get("crop", "fit") background = options.get("background") @@ -167,6 +169,7 @@ class FileManager: thumbnail_size = (width, height) thumbnail_filename = generate_filename( + directory, original_filename, aspect_to_string(thumbnail_size), crop, diff --git a/lama_cleaner/file_manager/utils.py b/lama_cleaner/file_manager/utils.py index 2a05671..f6890af 100644 --- a/lama_cleaner/file_manager/utils.py +++ b/lama_cleaner/file_manager/utils.py @@ -1,19 +1,17 @@ # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py -import importlib -import os +import hashlib from pathlib import Path from typing import Union -def generate_filename(original_filename, *options): - name, ext = os.path.splitext(original_filename) +def generate_filename(directory: Path, original_filename, *options) -> str: + text = str(directory.absolute()) + original_filename for v in options: - if v: - name += "_%s" % v - name += ext - - return name + text += "%s" % v + md5_hash = hashlib.md5() + md5_hash.update(text.encode("utf-8")) + return md5_hash.hexdigest() + ".jpg" def parse_size(size): @@ -48,7 +46,7 @@ def aspect_to_string(size): return "x".join(map(str, size)) -IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'} +IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"} def glob_img(p: Union[Path, str], recursive: bool = False): diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index babbeac..1c12128 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -8,7 +8,7 @@ import cv2 from PIL import Image, ImageOps, PngImagePlugin import numpy as np import torch -from lama_cleaner.const import MPS_SUPPORT_MODELS +from lama_cleaner.const import MPS_UNSUPPORT_MODELS from loguru import logger from torch.hub import download_url_to_file, get_dir import hashlib @@ -23,7 +23,7 @@ def md5sum(filename): def switch_mps_device(model_name, device): - if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps": + if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps": logger.info(f"{model_name} not support mps, switch to cpu") return torch.device("cpu") return device diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py index e69de29..1892ab7 100644 --- a/lama_cleaner/model/__init__.py +++ b/lama_cleaner/model/__init__.py @@ -0,0 +1,33 @@ +from .controlnet import ControlNet +from .fcf import FcF +from .instruct_pix2pix import InstructPix2Pix +from .kandinsky import Kandinsky22 +from .lama import LaMa +from .ldm import LDM +from .manga import Manga +from .mat import MAT +from .mi_gan import MIGAN +from .opencv2 import OpenCV2 +from .paint_by_example import PaintByExample +from .sd import SD15, SD2, Anything4, RealisticVision14, SD +from .sdxl import SDXL +from .zits import ZITS + +models = { + LaMa.name: LaMa, + LDM.name: LDM, + ZITS.name: ZITS, + MAT.name: MAT, + FcF.name: FcF, + OpenCV2.name: OpenCV2, + Manga.name: Manga, + MIGAN.name: MIGAN, + SD15.name: SD15, + Anything4.name: Anything4, + RealisticVision14.name: RealisticVision14, + SD2.name: SD2, + PaintByExample.name: PaintByExample, + InstructPix2Pix.name: InstructPix2Pix, + Kandinsky22.name: Kandinsky22, + SDXL.name: SDXL, +} diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 81cdf0c..193746a 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -12,7 +12,7 @@ from lama_cleaner.helper import ( pad_img_to_modulo, switch_mps_device, ) -from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb +from lama_cleaner.model.helper.g_diffuser_bot import expand_image from lama_cleaner.model.utils import get_scheduler from lama_cleaner.schema import Config, HDStrategy, SDSampler @@ -22,6 +22,7 @@ class InpaintModel: min_size: Optional[int] = None pad_mod = 8 pad_to_square = False + is_erase_model = False def __init__(self, device, **kwargs): """ @@ -264,6 +265,12 @@ class InpaintModel: class DiffusionInpaintModel(InpaintModel): + def __init__(self, device, **kwargs): + if kwargs.get("model_id_or_path"): + # 用于自定义 diffusers 模型 + self.model_id_or_path = kwargs["model_id_or_path"] + super().__init__(device, **kwargs) + @torch.no_grad() def __call__(self, image, mask, config: Config): """ diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index 9becdcf..485585e 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -1,5 +1,3 @@ -import gc - import PIL.Image import cv2 import numpy as np @@ -7,107 +5,26 @@ import torch from diffusers import ControlNetModel from loguru import logger +from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import torch_gc, get_scheduler -from lama_cleaner.schema import Config +from lama_cleaner.model.helper.controlnet_preprocess import ( + make_canny_control_image, + make_openpose_control_image, + make_depth_control_image, + make_inpaint_control_image, +) +from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper +from lama_cleaner.model.utils import get_scheduler +from lama_cleaner.schema import Config, ModelInfo, ModelType - -class CPUTextEncoderWrapper(torch.nn.Module): - def __init__(self, text_encoder, torch_dtype): - super().__init__() - self.config = text_encoder.config - self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) - self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) - self.torch_dtype = torch_dtype - del text_encoder - torch_gc() - - def __call__(self, x, **kwargs): - input_device = x.device - return [ - self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] - .to(input_device) - .to(self.torch_dtype) - ] - - @property - def dtype(self): - return self.torch_dtype - - -NAMES_MAP = { - "sd1.5": "runwayml/stable-diffusion-inpainting", - "anything4": "Sanster/anything-4.0-inpainting", - "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting", +# 为了兼容性 +controlnet_name_map = { + "control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny", + "control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose", + "control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint", + "control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth", } -NATIVE_NAMES_MAP = { - "sd1.5": "runwayml/stable-diffusion-v1-5", - "anything4": "andite/anything-v4.0", - "realisticVision1.4": "SG161222/Realistic_Vision_V1.4", -} - - -def make_inpaint_condition(image, image_mask): - """ - image: [H, W, C] RGB - mask: [H, W, 1] 255 means area to repaint - """ - image = image.astype(np.float32) / 255.0 - image[image_mask[:, :, -1] > 128] = -1.0 # set as masked pixel - image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return image - - -def load_from_local_model( - local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint -): - from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - download_from_original_stable_diffusion_ckpt, - ) - - logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline") - - try: - pipe = download_from_original_stable_diffusion_ckpt( - local_model_path, - num_in_channels=4 if is_native_control_inpaint else 9, - from_safetensors=local_model_path.endswith("safetensors"), - device="cpu", - load_safety_checker=False, - ) - except Exception as e: - err_msg = str(e) - logger.exception(e) - if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg: - logger.error( - "control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model" - ) - if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg: - logger.error( - f"{controlnet.config['_name_or_path']} method requires inpainting SD model, " - f"you can convert any SD model to inpainting model in AUTO1111: \n" - f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/" - ) - exit(-1) - - inpaint_pipe = pipe_class( - vae=pipe.vae, - text_encoder=pipe.text_encoder, - tokenizer=pipe.tokenizer, - unet=pipe.unet, - controlnet=controlnet, - scheduler=pipe.scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - - del pipe - gc.collect() - return inpaint_pipe.to(torch_dtype=torch_dtype) - class ControlNet(DiffusionInpaintModel): name = "controlnet" @@ -116,10 +33,16 @@ class ControlNet(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): fp16 = not kwargs.get("no_half", False) + model_info: ModelInfo = kwargs["model_info"] + sd_controlnet_method = kwargs["sd_controlnet_method"] + sd_controlnet_method = controlnet_name_map.get( + sd_controlnet_method, sd_controlnet_method + ) - model_kwargs = { - "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) - } + self.model_info = model_info + self.sd_controlnet_method = sd_controlnet_method + + model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( @@ -133,41 +56,39 @@ class ControlNet(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - sd_controlnet_method = kwargs["sd_controlnet_method"] - self.sd_controlnet_method = sd_controlnet_method - - if sd_controlnet_method == "control_v11p_sd15_inpaint": - from diffusers import StableDiffusionControlNetPipeline as PipeClass - - self.is_native_control_inpaint = True - else: - from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass - - self.is_native_control_inpaint = False - - if self.is_native_control_inpaint: - model_id = NATIVE_NAMES_MAP[kwargs["name"]] - else: - model_id = NAMES_MAP[kwargs["name"]] + if model_info.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SD_INPAINT, + ]: + from diffusers import ( + StableDiffusionControlNetInpaintPipeline as PipeClass, + ) + elif model_info.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + from diffusers import ( + StableDiffusionXLControlNetInpaintPipeline as PipeClass, + ) controlnet = ControlNetModel.from_pretrained( - f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype + sd_controlnet_method, torch_dtype=torch_dtype ) - self.is_local_sd_model = False - if kwargs.get("sd_local_model_path", None): - self.is_local_sd_model = True - self.model = load_from_local_model( - kwargs["sd_local_model_path"], - torch_dtype=torch_dtype, - controlnet=controlnet, - pipe_class=PipeClass, - is_native_control_inpaint=self.is_native_control_inpaint, - ) + if model_info.is_single_file_diffusers: + self.model = PipeClass.from_single_file( + model_info.path, controlnet=controlnet + ).to(torch_dtype) else: self.model = PipeClass.from_pretrained( - model_id, + model_info.path, controlnet=controlnet, - revision="fp16" if use_gpu and fp16 else "main", + revision="fp16" + if ( + model_info.path in DIFFUSERS_MODEL_FP16_REVERSION + and use_gpu + and fp16 + ) + else "main", torch_dtype=torch_dtype, **model_kwargs, ) @@ -191,6 +112,19 @@ class ControlNet(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) + def _get_control_image(self, image, mask): + if "canny" in self.sd_controlnet_method: + control_image = make_canny_control_image(image) + elif "openpose" in self.sd_controlnet_method: + control_image = make_openpose_control_image(image) + elif "depth" in self.sd_controlnet_method: + control_image = make_depth_control_image(image) + elif "inpaint" in self.sd_controlnet_method: + control_image = make_inpaint_control_image(image, mask) + else: + raise NotImplementedError(f"{self.sd_controlnet_method} not implemented") + return control_image + def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB @@ -206,84 +140,30 @@ class ControlNet(DiffusionInpaintModel): mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] img_h, img_w = image.shape[:2] + control_image = self._get_control_image(image, mask) + mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") + image = PIL.Image.fromarray(image) - if self.is_native_control_inpaint: - control_image = make_inpaint_condition(image, mask) - output = self.model( - prompt=config.prompt, - image=control_image, - height=img_h, - width=img_w, - num_inference_steps=config.sd_steps, - guidance_scale=config.sd_guidance_scale, - controlnet_conditioning_scale=config.controlnet_conditioning_scale, - negative_prompt=config.negative_prompt, - generator=torch.manual_seed(config.sd_seed), - output_type="np", - callback=self.callback, - ).images[0] - else: - if "canny" in self.sd_controlnet_method: - canny_image = cv2.Canny(image, 100, 200) - canny_image = canny_image[:, :, None] - canny_image = np.concatenate( - [canny_image, canny_image, canny_image], axis=2 - ) - canny_image = PIL.Image.fromarray(canny_image) - control_image = canny_image - elif "openpose" in self.sd_controlnet_method: - from controlnet_aux import OpenposeDetector - - processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") - control_image = processor(image, hand_and_face=True) - elif "depth" in self.sd_controlnet_method: - from transformers import pipeline - - depth_estimator = pipeline("depth-estimation") - depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"] - depth_image = np.array(depth_image) - depth_image = depth_image[:, :, None] - depth_image = np.concatenate( - [depth_image, depth_image, depth_image], axis=2 - ) - control_image = PIL.Image.fromarray(depth_image) - else: - raise NotImplementedError( - f"{self.sd_controlnet_method} not implemented" - ) - - mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") - image = PIL.Image.fromarray(image) - - output = self.model( - image=image, - control_image=control_image, - prompt=config.prompt, - negative_prompt=config.negative_prompt, - mask_image=mask_image, - num_inference_steps=config.sd_steps, - guidance_scale=config.sd_guidance_scale, - output_type="np", - callback=self.callback, - height=img_h, - width=img_w, - generator=torch.manual_seed(config.sd_seed), - controlnet_conditioning_scale=config.controlnet_conditioning_scale, - ).images[0] + output = self.model( + image=image, + mask_image=mask_image, + control_image=control_image, + prompt=config.prompt, + negative_prompt=config.negative_prompt, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + controlnet_conditioning_scale=config.controlnet_conditioning_scale, + ).images[0] output = (output * 255).round().astype("uint8") output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - def forward_post_process(self, result, image, mask, config): - if config.sd_match_histograms: - result = self._match_histograms(result, image[:, :, ::-1], mask) - - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0) - return result, image, mask - @staticmethod def is_downloaded() -> bool: # model will be downloaded when app start, and can't switch in frontend settings diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index a64885a..9cfc2be 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -1626,6 +1626,7 @@ class FcF(InpaintModel): min_size = 512 pad_mod = 512 pad_to_square = True + is_erase_model = True def init_model(self, device, **kwargs): seed = 0 diff --git a/lama_cleaner/model/helper/controlnet_preprocess.py b/lama_cleaner/model/helper/controlnet_preprocess.py new file mode 100644 index 0000000..1ab1c80 --- /dev/null +++ b/lama_cleaner/model/helper/controlnet_preprocess.py @@ -0,0 +1,46 @@ +import torch +import PIL +import cv2 +from PIL import Image +import numpy as np + + +def make_canny_control_image(image: np.ndarray) -> Image: + canny_image = cv2.Canny(image, 100, 200) + canny_image = canny_image[:, :, None] + canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) + canny_image = PIL.Image.fromarray(canny_image) + control_image = canny_image + return control_image + + +def make_openpose_control_image(image: np.ndarray) -> Image: + from controlnet_aux import OpenposeDetector + + processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + control_image = processor(image, hand_and_face=True) + return control_image + + +def make_depth_control_image(image: np.ndarray) -> Image: + from transformers import pipeline + + depth_estimator = pipeline("depth-estimation") + depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"] + depth_image = np.array(depth_image) + depth_image = depth_image[:, :, None] + depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2) + control_image = PIL.Image.fromarray(depth_image) + return control_image + + +def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor: + """ + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + """ + image = image.astype(np.float32) / 255.0 + image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image diff --git a/lama_cleaner/model/helper/cpu_text_encoder.py b/lama_cleaner/model/helper/cpu_text_encoder.py new file mode 100644 index 0000000..66cc86c --- /dev/null +++ b/lama_cleaner/model/helper/cpu_text_encoder.py @@ -0,0 +1,25 @@ +import torch +from lama_cleaner.model.utils import torch_gc + + +class CPUTextEncoderWrapper(torch.nn.Module): + def __init__(self, text_encoder, torch_dtype): + super().__init__() + self.config = text_encoder.config + self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + return [ + self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] + .to(input_device) + .to(self.torch_dtype) + ] + + @property + def dtype(self): + return self.torch_dtype diff --git a/lama_cleaner/model/g_diffuser_bot.py b/lama_cleaner/model/helper/g_diffuser_bot.py similarity index 100% rename from lama_cleaner/model/g_diffuser_bot.py rename to lama_cleaner/model/helper/g_diffuser_bot.py diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index 27476b7..b569092 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -17,7 +17,7 @@ class InstructPix2Pix(DiffusionInpaintModel): fp16 = not kwargs.get("no_half", False) - model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)} + model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( @@ -77,16 +77,6 @@ class InstructPix2Pix(DiffusionInpaintModel): output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) return output - # - # def forward_post_process(self, result, image, mask, config): - # if config.sd_match_histograms: - # result = self._match_histograms(result, image[:, :, ::-1], mask) - # - # if config.sd_mask_blur != 0: - # k = 2 * config.sd_mask_blur + 1 - # mask = cv2.GaussianBlur(mask, (k, k), 0) - # return result, image, mask - @staticmethod def is_downloaded() -> bool: # model will be downloaded when app start, and can't switch in frontend settings diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 12d8209..38dfc33 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -20,7 +20,6 @@ class Kandinsky(DiffusionInpaintModel): torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 model_kwargs = { - "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]), "torch_dtype": torch_dtype, } diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index ebbb6c9..f1dd239 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -23,6 +23,7 @@ LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e class LaMa(InpaintModel): name = "lama" pad_mod = 8 + is_erase_model = True @staticmethod def download(): diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 29956ad..4066ad3 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -237,6 +237,7 @@ class LatentDiffusion(DDPM): class LDM(InpaintModel): name = "ldm" pad_mod = 32 + is_erase_model = True def __init__(self, device, fp16: bool = True, **kwargs): self.fp16 = fp16 diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py index d61f408..76e4c86 100644 --- a/lama_cleaner/model/manga.py +++ b/lama_cleaner/model/manga.py @@ -32,6 +32,7 @@ MANGA_LINE_MODEL_MD5 = os.environ.get( class Manga(InpaintModel): name = "manga" pad_mod = 16 + is_erase_model = True def init_model(self, device, **kwargs): self.inpaintor_model = load_jit_model( diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index 2fbe11c..49ad54e 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -1880,6 +1880,7 @@ class MAT(InpaintModel): min_size = 512 pad_mod = 512 pad_to_square = True + is_erase_model = True def init_model(self, device, **kwargs): seed = 240 # pick up a random number diff --git a/lama_cleaner/model/mi_gan.py b/lama_cleaner/model/mi_gan.py index 1b2ba1d..3e3f200 100644 --- a/lama_cleaner/model/mi_gan.py +++ b/lama_cleaner/model/mi_gan.py @@ -26,6 +26,7 @@ class MIGAN(InpaintModel): min_size = 512 pad_mod = 512 pad_to_square = True + is_erase_model = True def init_model(self, device, **kwargs): self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval() diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py index e0618dd..cfbde9e 100644 --- a/lama_cleaner/model/opencv2.py +++ b/lama_cleaner/model/opencv2.py @@ -8,6 +8,7 @@ flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} class OpenCV2(InpaintModel): name = "cv2" pad_mod = 1 + is_erase_model = True @staticmethod def is_downloaded() -> bool: diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py index 01d73d0..f65e95d 100644 --- a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py +++ b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc from typing import Union, List, Optional, Callable, Dict, Any # Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py @@ -217,6 +218,38 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + ) + + controlnet = kwargs.pop("controlnet", None) + + pipe = download_from_original_stable_diffusion_ckpt( + pretrained_model_link_or_path, + num_in_channels=9, + from_safetensors=pretrained_model_link_or_path.endswith("safetensors"), + device="cpu", + load_safety_checker=False, + ) + + inpaint_pipe = cls( + vae=pipe.vae, + text_encoder=pipe.text_encoder, + tokenizer=pipe.tokenizer, + unet=pipe.unet, + controlnet=controlnet, + scheduler=pipe.scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + + del pipe + gc.collect() + return inpaint_pipe + def prepare_mask_latents( self, mask, diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 0c0b190..a81a849 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -1,4 +1,4 @@ -import gc +import os import PIL.Image import cv2 @@ -6,34 +6,12 @@ import numpy as np import torch from loguru import logger +from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import torch_gc +from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.schema import Config -class CPUTextEncoderWrapper(torch.nn.Module): - def __init__(self, text_encoder, torch_dtype): - super().__init__() - self.config = text_encoder.config - self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) - self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) - self.torch_dtype = torch_dtype - del text_encoder - torch_gc() - - def __call__(self, x, **kwargs): - input_device = x.device - return [ - self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] - .to(input_device) - .to(self.torch_dtype) - ] - - @property - def dtype(self): - return self.torch_dtype - - class SD(DiffusionInpaintModel): pad_mod = 8 min_size = 512 @@ -44,9 +22,7 @@ class SD(DiffusionInpaintModel): fp16 = not kwargs.get("no_half", False) - model_kwargs = { - "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) - } + model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( @@ -60,14 +36,20 @@ class SD(DiffusionInpaintModel): use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - if kwargs.get("sd_local_model_path", None): + if os.path.isfile(self.model_id_or_path): self.model = StableDiffusionInpaintPipeline.from_single_file( - kwargs["sd_local_model_path"], torch_dtype=torch_dtype, **model_kwargs + self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs ) else: self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model_id_or_path, - revision="fp16" if use_gpu and fp16 else "main", + revision="fp16" + if ( + self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION + and use_gpu + and fp16 + ) + else "main", torch_dtype=torch_dtype, use_auth_token=kwargs["hf_access_token"], **model_kwargs, diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index f64bfa8..af04941 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -1,7 +1,10 @@ +import os + import PIL.Image import cv2 import numpy as np import torch +from diffusers import AutoencoderKL from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel @@ -13,26 +16,31 @@ class SDXL(DiffusionInpaintModel): pad_mod = 8 min_size = 512 lcm_lora_id = "latent-consistency/lcm-lora-sdxl" + model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" def init_model(self, device: torch.device, **kwargs): - from diffusers.pipelines import AutoPipelineForInpainting + from diffusers.pipelines import StableDiffusionXLInpaintPipeline fp16 = not kwargs.get("no_half", False) - model_kwargs = { - "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) - } - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - self.model = AutoPipelineForInpainting.from_pretrained( - "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", - revision="main", - torch_dtype=torch_dtype, - use_auth_token=kwargs["hf_access_token"], - **model_kwargs, - ) + if os.path.isfile(self.model_id_or_path): + self.model = StableDiffusionXLInpaintPipeline.from_single_file( + self.model_id_or_path, torch_dtype=torch_dtype + ) + else: + vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 + ) + self.model = StableDiffusionXLInpaintPipeline.from_pretrained( + self.model_id_or_path, + revision="main", + torch_dtype=torch_dtype, + use_auth_token=kwargs["hf_access_token"], + vae=vae, + ) # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing self.model.enable_attention_slicing() diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 748623e..22fae23 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -226,6 +226,7 @@ class ZITS(InpaintModel): min_size = 256 pad_mod = 32 pad_to_square = True + is_erase_model = True def __init__(self, device, **kwargs): """ diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 8461c95..90e0f12 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -1,49 +1,14 @@ -import torch import gc +from typing import List, Dict +import torch from loguru import logger -from lama_cleaner.const import ( - SD15_MODELS, - MODELS_SUPPORT_FREEU, - MODELS_SUPPORT_LCM_LORA, -) +from lama_cleaner.download import scan_models from lama_cleaner.helper import switch_mps_device -from lama_cleaner.model.controlnet import ControlNet -from lama_cleaner.model.fcf import FcF -from lama_cleaner.model.kandinsky import Kandinsky22 -from lama_cleaner.model.lama import LaMa -from lama_cleaner.model.ldm import LDM -from lama_cleaner.model.manga import Manga -from lama_cleaner.model.mat import MAT -from lama_cleaner.model.mi_gan import MIGAN -from lama_cleaner.model.paint_by_example import PaintByExample -from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix -from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 -from lama_cleaner.model.sdxl import SDXL +from lama_cleaner.model import models, ControlNet, SD, SDXL from lama_cleaner.model.utils import torch_gc -from lama_cleaner.model.zits import ZITS -from lama_cleaner.model.opencv2 import OpenCV2 -from lama_cleaner.schema import Config - -models = { - "lama": LaMa, - "ldm": LDM, - "zits": ZITS, - "mat": MAT, - "fcf": FcF, - SD15.name: SD15, - Anything4.name: Anything4, - RealisticVision14.name: RealisticVision14, - "cv2": OpenCV2, - "manga": Manga, - "sd2": SD2, - "paint_by_example": PaintByExample, - "instruct_pix2pix": InstructPix2Pix, - Kandinsky22.name: Kandinsky22, - SDXL.name: SDXL, - MIGAN.name: MIGAN, -} +from lama_cleaner.schema import Config, ModelInfo, ModelType class ModelManager: @@ -51,23 +16,39 @@ class ModelManager: self.name = name self.device = device self.kwargs = kwargs + self.available_models: Dict[str, ModelInfo] = {} + self.scan_models() self.model = self.init_model(name, device, **kwargs) def init_model(self, name: str, device, **kwargs): - if name in SD15_MODELS and kwargs.get("sd_controlnet", False): - return ControlNet(device, **{**kwargs, "name": name}) + for old_name, model_cls in models.items(): + if name == old_name and hasattr(model_cls, "model_id_or_path"): + name = model_cls.model_id_or_path + if name not in self.available_models: + raise NotImplementedError(f"Unsupported model: {name}") - if name in models: - model = models[name](device, **kwargs) - else: - raise NotImplementedError(f"Not supported model: {name}") - return model + sd_controlnet_enabled = kwargs.get("sd_controlnet", False) + model_info = self.available_models[name] + if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]: + return models[name](device, **kwargs) - def is_downloaded(self, name: str) -> bool: - if name in models: - return models[name].is_downloaded() + if sd_controlnet_enabled: + return ControlNet(device, **{**kwargs, "model_info": model_info}) else: - raise NotImplementedError(f"Not supported model: {name}") + if model_info.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ]: + raise NotImplementedError( + f"When using non inpaint Stable Diffusion model, you must enable controlnet" + ) + if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT: + return SD(device, model_id_or_path=model_info.path, **kwargs) + + if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT: + return SDXL(device, model_id_or_path=model_info.path, **kwargs) + + raise NotImplementedError(f"Unsupported model: {name}") def __call__(self, image, mask, config: Config): self.switch_controlnet_method(control_method=config.controlnet_method) @@ -75,9 +56,18 @@ class ModelManager: self.enable_disable_lcm_lora(config) return self.model(image, mask, config) - def switch(self, new_name: str, **kwargs): + def scan_models(self) -> List[ModelInfo]: + available_models = scan_models() + self.available_models = {it.name: it for it in available_models} + return available_models + + def switch(self, new_name: str): if new_name == self.name: return + + old_name = self.name + self.name = new_name + try: if torch.cuda.memory_allocated() > 0: # Clear current loaded model from memory @@ -88,8 +78,8 @@ class ModelManager: self.model = self.init_model( new_name, switch_mps_device(new_name, self.device), **self.kwargs ) - self.name = new_name - except NotImplementedError as e: + except Exception as e: + self.name = old_name raise e def switch_controlnet_method(self, control_method: str): @@ -97,27 +87,9 @@ class ModelManager: return if self.kwargs["sd_controlnet_method"] == control_method: return - if not hasattr(self.model, "is_local_sd_model"): - return - if self.model.is_local_sd_model: - # is_native_control_inpaint 表示加载了普通 SD 模型 - if ( - self.model.is_native_control_inpaint - and control_method != "control_v11p_sd15_inpaint" - ): - raise RuntimeError( - f"--sd-local-model-path load a normal SD model, " - f"to use {control_method} you should load an inpainting SD model" - ) - elif ( - not self.model.is_native_control_inpaint - and control_method == "control_v11p_sd15_inpaint" - ): - raise RuntimeError( - f"--sd-local-model-path load an inpainting SD model, " - f"to use {control_method} you should load a norml SD model" - ) + if not self.available_models[self.name].support_controlnet(): + return del self.model torch_gc() @@ -133,7 +105,7 @@ class ModelManager: if str(self.model.device) == "mps": return - if self.name in MODELS_SUPPORT_FREEU: + if self.available_models[self.name].support_freeu(): if config.sd_freeu: freeu_config = config.sd_freeu_config self.model.model.enable_freeu( @@ -146,7 +118,7 @@ class ModelManager: self.model.model.disable_freeu() def enable_disable_lcm_lora(self, config: Config): - if self.name in MODELS_SUPPORT_LCM_LORA: + if self.available_models[self.name].support_lcm_lora(): if config.sd_lcm_lora: if not self.model.model.pipe.get_list_adapters(): self.model.model.load_lora_weights(self.model.lcm_lora_id) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index f672fb7..40f126c 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -6,7 +6,7 @@ from pathlib import Path from loguru import logger from lama_cleaner.const import * -from lama_cleaner.download import cli_download_model +from lama_cleaner.download import cli_download_model, scan_models from lama_cleaner.runtime import dump_environment_info DOWNLOAD_SUBCOMMAND = "download" @@ -46,7 +46,11 @@ def parse_args(): "--installer-config", default=None, help="Config file for windows installer" ) - parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS) + parser.add_argument( + "--model", + default=DEFAULT_MODEL, + help=f"Available models: [{', '.join(AVAILABLE_MODELS)}], or model id on huggingface", + ) parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) @@ -56,10 +60,9 @@ def parse_args(): parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument( "--sd-controlnet-method", - default=DEFAULT_CONTROLNET_METHOD, + default=DEFAULT_SD_CONTROLNET_METHOD, choices=SD_CONTROLNET_CHOICES, ) - parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) parser.add_argument( "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP ) @@ -170,7 +173,8 @@ def parse_args(): ) ######### - # useless args + ### useless args ### + parser.add_argument("--sd-local-model-path", default=None, help=argparse.SUPPRESS) parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS) parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS) parser.add_argument( @@ -180,6 +184,7 @@ def parse_args(): parser.add_argument( "--sd-enable-xformers", action="store_true", help=argparse.SUPPRESS ) + ### end useless args ### args = parser.parse_args() # collect system info to help debug @@ -251,6 +256,17 @@ def parse_args(): os.environ["XDG_CACHE_HOME"] = args.model_dir os.environ["U2NET_HOME"] = args.model_dir + if args.sd_run_local or args.local_files_only: + os.environ["TRANSFORMERS_OFFLINE"] = "1" + os.environ["HF_HUB_OFFLINE"] = "1" + + if args.model not in AVAILABLE_MODELS: + scanned_models = scan_models() + if args.model not in [it.name for it in scanned_models]: + parser.error( + f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {scanned_models}" + ) + if args.input and args.input is not None: if not os.path.exists(args.input): parser.error(f"invalid --input: {args.input} not exists") diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 8b7e172..f7b253e 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -4,6 +4,61 @@ from enum import Enum from PIL.Image import Image from pydantic import BaseModel +DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" +DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" +DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline" +DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline" + + +class ModelType(str, Enum): + INPAINT = "inpaint" # LaMa, MAT... + DIFFUSERS_SD = "diffusers_sd" + DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint" + DIFFUSERS_SDXL = "diffusers_sdxl" + DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint" + DIFFUSERS_OTHER = "diffusers_other" + + +FREEU_DEFAULT_CONFIGS = { + ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), + ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2), +} + + +class ModelInfo(BaseModel): + name: str + path: str + model_type: ModelType + is_single_file_diffusers: bool = False + + def support_lcm_lora(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + def support_controlnet(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + + def support_freeu(self) -> bool: + return ( + self.model_type + in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] + or "instruct-pix2pix" in self.name + ) + class HDStrategy(str, Enum): # Use original image size diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 3088e7f..c302aeb 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -2,8 +2,6 @@ import os import hashlib -from lama_cleaner.diffusers_utils import scan_models - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import imghdr @@ -22,9 +20,9 @@ from loguru import logger from lama_cleaner.const import ( SD15_MODELS, - FREEU_DEFAULT_CONFIGS, - MODELS_SUPPORT_FREEU, - MODELS_SUPPORT_LCM_LORA, + SD_CONTROLNET_CHOICES, + SDXL_CONTROLNET_CHOICES, + SD2_CONTROLNET_CHOICES, ) from lama_cleaner.file_manager import FileManager from lama_cleaner.model.utils import torch_gc @@ -118,8 +116,8 @@ input_image_path: str = None is_disable_model_switch: bool = False is_controlnet: bool = False controlnet_method: str = "control_v11p_sd15_canny" -is_enable_file_manager: bool = False -is_enable_auto_saving: bool = False +enable_file_manager: bool = False +enable_auto_saving: bool = False is_desktop: bool = False image_quality: int = 95 plugins = {} @@ -421,34 +419,35 @@ def run_plugin(): @app.route("/server_config", methods=["GET"]) def get_server_config(): + controlnet = { + "SD": SD_CONTROLNET_CHOICES, + "SD2": SD2_CONTROLNET_CHOICES, + "SDXL": SDXL_CONTROLNET_CHOICES, + } return { - "isControlNet": is_controlnet, - "controlNetMethod": controlnet_method, - "isDisableModelSwitchState": is_disable_model_switch, - "isEnableAutoSaving": is_enable_auto_saving, - "enableFileManager": is_enable_file_manager, "plugins": list(plugins.keys()), - "freeSupportedModels": MODELS_SUPPORT_FREEU, - "freeuDefaultConfigs": FREEU_DEFAULT_CONFIGS, - "lcmLoraSupportedModels": MODELS_SUPPORT_LCM_LORA, + "availableControlNet": controlnet, + "enableFileManager": enable_file_manager, + "enableAutoSaving": enable_auto_saving, }, 200 -@app.route("/sd_models", methods=["GET"]) -def get_diffusers_models(): - from diffusers.utils import DIFFUSERS_CACHE - - return scan_models(DIFFUSERS_CACHE) +@app.route("/models", methods=["GET"]) +def get_models(): + return [ + { + **it.dict(), + "support_lcm_lora": it.support_lcm_lora(), + "support_controlnet": it.support_controlnet(), + "support_freeu": it.support_freeu(), + } + for it in model.scan_models() + ] @app.route("/model") def current_model(): - return model.name, 200 - - -@app.route("/model_downloaded/") -def model_downloaded(name): - return str(model.is_downloaded(name)), 200 + return model.available_models[model.name].dict(), 200 @app.route("/is_desktop") @@ -467,8 +466,10 @@ def switch_model(): try: model.switch(new_name) - except NotImplementedError: - return f"{new_name} not implemented", 403 + except Exception as e: + error_message = str(e) + logger.error(error_message) + return f"Switch model failed: {error_message}", 500 return f"ok, switch to {new_name}", 200 @@ -478,7 +479,7 @@ def index(): @app.route("/inputimage") -def set_input_photo(): +def get_cli_input_image(): if input_image_path: with open(input_image_path, "rb") as f: image_in_bytes = f.read() @@ -547,11 +548,10 @@ def main(args): global device global input_image_path global is_disable_model_switch - global is_enable_file_manager + global enable_file_manager global is_desktop global thumb global output_dir - global is_enable_auto_saving global is_controlnet global controlnet_method global image_quality @@ -566,7 +566,9 @@ def main(args): output_dir = args.output_dir if output_dir: - is_enable_auto_saving = True + output_dir = os.path.abspath(output_dir) + logger.info(f"Output dir: {output_dir}") + enable_auto_saving = True device = torch.device(args.device) is_disable_model_switch = args.disable_model_switch @@ -579,12 +581,12 @@ def main(args): if args.input and os.path.isdir(args.input): logger.info(f"Initialize file manager") thumb = FileManager(app) - is_enable_file_manager = True + enable_file_manager = True app.config["THUMBNAIL_MEDIA_ROOT"] = args.input app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( - args.output_dir, "lama_cleaner_thumbnails" + output_dir, "lama_cleaner_thumbnails" ) - thumb.output_dir = Path(args.output_dir) + thumb.output_dir = Path(output_dir) # thumb.start() # try: # while True: diff --git a/web_app/package-lock.json b/web_app/package-lock.json index dfc6f7f..bfa6edd 100644 --- a/web_app/package-lock.json +++ b/web_app/package-lock.json @@ -9,7 +9,9 @@ "version": "0.0.0", "dependencies": { "@heroicons/react": "^2.0.18", + "@hookform/resolvers": "^3.3.2", "@radix-ui/react-accordion": "^1.1.2", + "@radix-ui/react-alert-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-icons": "^1.3.0", @@ -17,6 +19,7 @@ "@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", + "@radix-ui/react-separator": "^1.0.3", "@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-switch": "^1.0.3", @@ -24,7 +27,9 @@ "@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-toggle": "^1.0.3", "@radix-ui/react-tooltip": "^1.0.7", + "@tanstack/react-query": "^5.8.7", "@uidotdev/usehooks": "^2.4.1", + "axios": "^1.6.2", "class-variance-authority": "^0.7.0", "clsx": "^2.0.0", "flexsearch": "^0.7.21", @@ -35,6 +40,7 @@ "next-themes": "^0.2.1", "react": "^18.2.0", "react-dom": "^18.2.0", + "react-hook-form": "^7.48.2", "react-hotkeys-hook": "^4.4.1", "react-photo-album": "^2.3.0", "react-use": "^17.4.0", @@ -42,9 +48,12 @@ "recoil": "^0.7.7", "tailwind-merge": "^2.0.0", "tailwindcss-animate": "^1.0.7", + "zod": "^3.22.4", "zustand": "^4.4.6" }, "devDependencies": { + "@tanstack/eslint-plugin-query": "^5.8.4", + "@types/axios": "^0.14.0", "@types/flexsearch": "^0.7.3", "@types/lodash": "^4.14.201", "@types/node": "^20.9.2", @@ -1069,6 +1078,14 @@ "react": ">= 16" } }, + "node_modules/@hookform/resolvers": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/@hookform/resolvers/-/resolvers-3.3.2.tgz", + "integrity": "sha512-Tw+GGPnBp+5DOsSg4ek3LCPgkBOuOgS5DsDV7qsWNH9LZc433kgsWICjlsh2J9p04H2K66hsXPPb9qn9ILdUtA==", + "peerDependencies": { + "react-hook-form": "^7.0.0" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.11.13", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", @@ -1374,6 +1391,34 @@ } } }, + "node_modules/@radix-ui/react-alert-dialog": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.0.5.tgz", + "integrity": "sha512-OrVIOcZL0tl6xibeuGt5/+UxoT2N27KCFOPjFyfXMnchxSHZ/OW7cCX2nGlIYJrbHK/fczPcFzAwvNBB6XBNMA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dialog": "1.0.5", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-arrow": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.0.3.tgz", @@ -1971,6 +2016,29 @@ } } }, + "node_modules/@radix-ui/react-separator": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.0.3.tgz", + "integrity": "sha512-itYmTy/kokS21aiV5+Z56MZB54KrhPgn6eHDKkFeOLR34HMN2s8PaN47qZZAGnvupcjxHaFZnW4pQEh0BvvVuw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-slider": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.1.2.tgz", @@ -2703,6 +2771,188 @@ "integrity": "sha512-myfUej5naTBWnqOCc/MdVOLVjXUXtIA+NpDrDBKJtLLg2shUjBu3cZmB/85RyitKc55+lUUyl7oRfLOvkr2hsw==", "dev": true }, + "node_modules/@tanstack/eslint-plugin-query": { + "version": "5.8.4", + "resolved": "https://registry.npmjs.org/@tanstack/eslint-plugin-query/-/eslint-plugin-query-5.8.4.tgz", + "integrity": "sha512-KVgcMc+Bn1qbwkxYVWQoiVSNEIN4IAiLj3cUH/SAHT8m8E59Y97o8ON1syp0Rcw094ItG8pEVZFyQuOaH6PDgQ==", + "dev": true, + "dependencies": { + "@typescript-eslint/utils": "^5.54.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "eslint": "^8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/scope-manager": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.62.0.tgz", + "integrity": "sha512-VXuvVvZeQCQb5Zgf4HAxc04q5j+WrNAtNh9OwCsCgpKqESMTu3tF/jhZ3xG6T4NZwWl65Bg8KuS2uEvhSfLl0w==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/types": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.62.0.tgz", + "integrity": "sha512-87NVngcbVXUahrRTqIK27gD2t5Cu1yuCXxbLcFtCzZGlfyVWWh8mLHkoxzjsB6DDNnvdL+fW8MiwPEJyGJQDgQ==", + "dev": true, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/typescript-estree": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.62.0.tgz", + "integrity": "sha512-CmcQ6uY7b9y694lKdRB8FEel7JbU/40iSAPomu++SjLMntB+2Leay2LO6i8VnJk58MtE9/nQSFIH6jpyRWyYzA==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0", + "debug": "^4.3.4", + "globby": "^11.1.0", + "is-glob": "^4.0.3", + "semver": "^7.3.7", + "tsutils": "^3.21.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/utils": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.62.0.tgz", + "integrity": "sha512-n8oxjeb5aIbPFEtmQxQYOLI0i9n5ySBEY/ZEHHZqKQSFnxio1rv6dthascc9dLuwrL0RC5mPCxB7vnAVGAYWAQ==", + "dev": true, + "dependencies": { + "@eslint-community/eslint-utils": "^4.2.0", + "@types/json-schema": "^7.0.9", + "@types/semver": "^7.3.12", + "@typescript-eslint/scope-manager": "5.62.0", + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/typescript-estree": "5.62.0", + "eslint-scope": "^5.1.1", + "semver": "^7.3.7" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/visitor-keys": { + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.62.0.tgz", + "integrity": "sha512-07ny+LHRzQXepkGg6w0mFY41fVUNBrL2Roj/++7V1txKugfjm/Ci/qSND03r2RhlJhJYMcTn9AhhSSqQp0Ysyw==", + "dev": true, + "dependencies": { + "@typescript-eslint/types": "5.62.0", + "eslint-visitor-keys": "^3.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@tanstack/eslint-plugin-query/node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/@tanstack/query-core": { + "version": "5.8.7", + "resolved": "https://registry.npmjs.org/@tanstack/query-core/-/query-core-5.8.7.tgz", + "integrity": "sha512-58xOSkxxZK4SGQ/uzX8MDZHLGZCkxlgkPxnfhxUOL2uchnNHyay2UVcR3mQNMgaMwH1e2l+0n+zfS7+UJ/MAJw==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, + "node_modules/@tanstack/react-query": { + "version": "5.8.7", + "resolved": "https://registry.npmjs.org/@tanstack/react-query/-/react-query-5.8.7.tgz", + "integrity": "sha512-RYSSMmkhbJ7tPkf8w+MSRIXQLoUCm7DRnTLDcdf+uampupnriEsob3fVWTt9oaEj+AJWEKeCErDBdZeNcAzURQ==", + "dependencies": { + "@tanstack/query-core": "5.8.7" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": "^18.0.0", + "react-dom": "^18.0.0", + "react-native": "*" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + }, + "react-native": { + "optional": true + } + } + }, + "node_modules/@types/axios": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@types/axios/-/axios-0.14.0.tgz", + "integrity": "sha512-KqQnQbdYE54D7oa/UmYVMZKq7CO4l8DEENzOKc4aBRwxCXSlJXGz83flFx5L7AWrOQnmuN3kVsRdt+GZPPjiVQ==", + "deprecated": "This is a stub types definition for axios (https://github.com/mzabriskie/axios). axios provides its own type definitions, so you don't need @types/axios installed!", + "dev": true, + "dependencies": { + "axios": "*" + } + }, "node_modules/@types/babel__core": { "version": "7.20.4", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.4.tgz", @@ -3166,6 +3416,11 @@ "node": ">=8" } }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" + }, "node_modules/autoprefixer": { "version": "10.4.16", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz", @@ -3203,6 +3458,16 @@ "postcss": "^8.1.0" } }, + "node_modules/axios": { + "version": "1.6.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.2.tgz", + "integrity": "sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==", + "dependencies": { + "follow-redirects": "^1.15.0", + "form-data": "^4.0.0", + "proxy-from-env": "^1.1.0" + } + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -3412,6 +3677,17 @@ "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "dev": true }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/commander": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", @@ -3512,6 +3788,14 @@ "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", "dev": true }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/detect-node-es": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", @@ -3916,6 +4200,38 @@ "resolved": "https://registry.npmjs.org/flexsearch/-/flexsearch-0.7.31.tgz", "integrity": "sha512-XGozTsMPYkm+6b5QL3Z9wQcJjNYxp0CYn3U1gO7dwD6PAqU1SVWZxI9CCg3z+ml3YfqdPnrBehaBrnH2AGKbNA==" }, + "node_modules/follow-redirects": { + "version": "1.15.3", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", + "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/form-data": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", + "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, "node_modules/fraction.js": { "version": "4.3.7", "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", @@ -4413,6 +4729,25 @@ "node": ">=8.6" } }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -4868,6 +5203,11 @@ "node": ">= 0.8.0" } }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==" + }, "node_modules/punycode": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", @@ -4919,6 +5259,21 @@ "react": "^18.2.0" } }, + "node_modules/react-hook-form": { + "version": "7.48.2", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.48.2.tgz", + "integrity": "sha512-H0T2InFQb1hX7qKtDIZmvpU1Xfn/bdahWBN1fH19gSe4bBEqTfmlr7H3XWTaVtiK4/tpPaI1F3355GPMZYge+A==", + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/react-hook-form" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17 || ^18" + } + }, "node_modules/react-hotkeys-hook": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/react-hotkeys-hook/-/react-hotkeys-hook-4.4.1.tgz", @@ -5610,6 +5965,27 @@ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" }, + "node_modules/tsutils": { + "version": "3.21.0", + "resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz", + "integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==", + "dev": true, + "dependencies": { + "tslib": "^1.8.1" + }, + "engines": { + "node": ">= 6" + }, + "peerDependencies": { + "typescript": ">=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta" + } + }, + "node_modules/tsutils/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", + "dev": true + }, "node_modules/type-check": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", @@ -5860,6 +6236,14 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/zod": { + "version": "3.22.4", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.22.4.tgz", + "integrity": "sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, "node_modules/zustand": { "version": "4.4.6", "resolved": "https://registry.npmjs.org/zustand/-/zustand-4.4.6.tgz", diff --git a/web_app/package.json b/web_app/package.json index 78752a2..9b07e24 100644 --- a/web_app/package.json +++ b/web_app/package.json @@ -11,7 +11,9 @@ }, "dependencies": { "@heroicons/react": "^2.0.18", + "@hookform/resolvers": "^3.3.2", "@radix-ui/react-accordion": "^1.1.2", + "@radix-ui/react-alert-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-icons": "^1.3.0", @@ -19,6 +21,7 @@ "@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", + "@radix-ui/react-separator": "^1.0.3", "@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-switch": "^1.0.3", @@ -26,7 +29,9 @@ "@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-toggle": "^1.0.3", "@radix-ui/react-tooltip": "^1.0.7", + "@tanstack/react-query": "^5.8.7", "@uidotdev/usehooks": "^2.4.1", + "axios": "^1.6.2", "class-variance-authority": "^0.7.0", "clsx": "^2.0.0", "flexsearch": "^0.7.21", @@ -37,6 +42,7 @@ "next-themes": "^0.2.1", "react": "^18.2.0", "react-dom": "^18.2.0", + "react-hook-form": "^7.48.2", "react-hotkeys-hook": "^4.4.1", "react-photo-album": "^2.3.0", "react-use": "^17.4.0", @@ -44,9 +50,12 @@ "recoil": "^0.7.7", "tailwind-merge": "^2.0.0", "tailwindcss-animate": "^1.0.7", + "zod": "^3.22.4", "zustand": "^4.4.6" }, "devDependencies": { + "@tanstack/eslint-plugin-query": "^5.8.4", + "@types/axios": "^0.14.0", "@types/flexsearch": "^0.7.3", "@types/lodash": "^4.14.201", "@types/node": "^20.9.2", diff --git a/web_app/src/App.tsx b/web_app/src/App.tsx index 54cb8e0..567a061 100644 --- a/web_app/src/App.tsx +++ b/web_app/src/App.tsx @@ -1,7 +1,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { nanoid } from "nanoid" -import { useSetRecoilState } from "recoil" -import { serverConfigState } from "@/lib/store" + import useInputImage from "@/hooks/useInputImage" import { keepGUIAlive } from "@/lib/utils" import { getServerConfig, isDesktop } from "@/lib/api" @@ -19,10 +18,13 @@ const SUPPORTED_FILE_TYPE = [ "image/tiff", ] function Home() { - const [file, setFile] = useStore((state) => [state.file, state.setFile]) + const [file, setServerConfig, setFile] = useStore((state) => [ + state.file, + state.setServerConfig, + state.setFile, + ]) const userInputImage = useInputImage() - const setServerConfigState = useSetRecoilState(serverConfigState) useEffect(() => { if (userInputImage) { @@ -44,8 +46,7 @@ function Home() { useEffect(() => { const fetchServerConfig = async () => { const serverConfig = await getServerConfig().then((res) => res.json()) - console.log(serverConfig) - setServerConfigState(serverConfig) + setServerConfig(serverConfig) } fetchServerConfig() }, []) diff --git a/web_app/src/components/Cropper.tsx b/web_app/src/components/Cropper.tsx index f0be764..3a10420 100644 --- a/web_app/src/components/Cropper.tsx +++ b/web_app/src/components/Cropper.tsx @@ -1,5 +1,7 @@ import { useStore } from "@/lib/states" +import { cn } from "@/lib/utils" import React, { useEffect, useState } from "react" +import { twMerge } from "tailwind-merge" const DOC_MOVE_OPTS = { capture: true, passive: false } @@ -75,11 +77,6 @@ const Cropper = (props: Props) => { state.setCropperWidth, state.setCropperHeight, ]) - // const [x, setX] = useRecoilState(croperX) - // const [y, setY] = useRecoilState(croperY) - // const [height, setHeight] = useRecoilState(croperHeight) - // const [width, setWidth] = useRecoilState(croperWidth) - // const isInpainting = useRecoilValue(isInpaintingState) const [isResizing, setIsResizing] = useState(false) const [isMoving, setIsMoving] = useState(false) @@ -100,7 +97,7 @@ const Cropper = (props: Props) => { }) const onDragFocus = () => { - console.log("focus") + // console.log("focus") } const clampLeftRight = (newX: number, newWidth: number) => { @@ -254,102 +251,64 @@ const Cropper = (props: Props) => { } } - const createCropSelection = () => { + const createDragHandle = (cursor: string, side1: string, side2: string) => { + const sideLength = 12 + const draghandleCls = `w-[${sideLength}px] h-[${sideLength}px] z-4 absolute block border-2 border-primary borde pointer-events-auto hover:bg-primary` + + let side2Cls = `${side2}-[-${sideLength / 2}px]` + if (side2 === "") { + if (side1 === "top" || side1 === "bottom") { + side2Cls = `left-[calc(50%-${sideLength / 2}px)]` + } else if (side1 === "left" || side1 === "right") { + side2Cls = `top-[calc(50%-${sideLength / 2}px)]` + } + } + return (
+ className={cn( + draghandleCls, + `${cursor}`, + side1 ? `${side1}-[-${sideLength / 2}px]` : "", + side2Cls + )} + data-ord={side1 + side2} + aria-label={side1 + side2} + tabIndex={-1} + role="button" + /> + ) + } + + const createCropSelection = () => { + return ( +
-
+ {createDragHandle("cursor-nw-resize", "top", "left")} + {createDragHandle("cursor-ne-resize", "top", "right")} + {createDragHandle("cursor-se-resize", "bottom", "left")} + {createDragHandle("cursor-sw-resize", "bottom", "right")} -
- -
- -
- -
-
-
-
+ {createDragHandle("cursor-ns-resize", "top", "")} + {createDragHandle("cursor-ns-resize", "bottom", "")} + {createDragHandle("cursor-ew-resize", "left", "")} + {createDragHandle("cursor-ew-resize", "right", "")}
) } @@ -370,17 +329,17 @@ const Cropper = (props: Props) => { const createInfoBar = () => { return (
-
- {width} x {height} -
+ {/* TODO: 移动的时候会显示 brush */} + {width} x {height}
) } diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 3bf8c1e..1a3998e 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -3,13 +3,10 @@ import { CursorArrowRaysIcon } from "@heroicons/react/24/outline" import { useToast } from "@/components/ui/use-toast" import { ReactZoomPanPinchContentRef, - ReactZoomPanPinchRef, TransformComponent, TransformWrapper, } from "react-zoom-pan-pinch" -import { useRecoilState, useRecoilValue, useSetRecoilState } from "recoil" -import { useWindowSize } from "react-use" -// import { useWindowSize, useKey, useKeyPressEvent } from "@uidotdev/usehooks" +import { useKeyPressEvent, useWindowSize } from "react-use" import inpaint, { downloadToOutput, runPlugin } from "@/lib/api" import { IconButton } from "@/components/ui/button" import { @@ -22,23 +19,6 @@ import { srcToFile, } from "@/lib/utils" import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" -import { - croperState, - enableFileManagerState, - interactiveSegClicksState, - isDiffusionModelsState, - isEnableAutoSavingState, - isInteractiveSegRunningState, - isInteractiveSegState, - isPix2PixState, - isPluginRunningState, - isProcessingState, - negativePropmtState, - runManuallyState, - seedState, - settingState, -} from "@/lib/store" -// import Croper from "../Croper/Croper" import emitter, { EVENT_PROMPT, EVENT_CUSTOM_MASK, @@ -49,19 +29,15 @@ import emitter, { } from "@/lib/event" import { useImage } from "@/hooks/useImage" import { Slider } from "./ui/slider" -// import FileSelect from "../FileSelect/FileSelect" -// import InteractiveSeg from "../InteractiveSeg/InteractiveSeg" -// import InteractiveSegConfirmActions from "../InteractiveSeg/ConfirmActions" -// import InteractiveSegReplaceModal from "../InteractiveSeg/ReplaceModal" import { PluginName } from "@/lib/types" import { useHotkeys } from "react-hotkeys-hook" import { useStore } from "@/lib/states" import Cropper from "./Cropper" -import { HotkeysEvent } from "react-hotkeys-hook/dist/types" const TOOLBAR_HEIGHT = 200 const MIN_BRUSH_SIZE = 10 const MAX_BRUSH_SIZE = 200 +const COMPARE_SLIDER_DURATION_MS = 300 const BRUSH_COLOR = "#ffcc00bb" interface Line { @@ -110,48 +86,55 @@ export default function Editor(props: EditorProps) { imageWidth, imageHeight, baseBrushSize, - brushScale, - promptVal, + brushSizeScale, + settings, + enableAutoSaving, + cropperRect, + enableManualInpainting, setImageSize, setBrushSize, setIsInpainting, + setSeed, + interactiveSegState, + updateInteractiveSegState, + resetInteractiveSegState, + isPluginRunning, + setIsPluginRunning, ] = useStore((state) => [ state.isInpainting, state.imageWidth, state.imageHeight, state.brushSize, state.brushSizeScale, - state.prompt, + state.settings, + state.serverConfig.enableAutoSaving, + state.cropperState, + state.settings.enableManualInpainting, state.setImageSize, state.setBrushSize, state.setIsInpainting, + state.setSeed, + state.interactiveSegState, + state.updateInteractiveSegState, + state.resetInteractiveSegState, + state.isPluginRunning, + state.setIsPluginRunning, ]) - const brushSize = baseBrushSize * brushScale + const brushSize = baseBrushSize * brushSizeScale // 纯 local state const [showOriginal, setShowOriginal] = useState(false) // - const negativePromptVal = useRecoilValue(negativePropmtState) - const settings = useRecoilValue(settingState) - const [seedVal, setSeed] = useRecoilState(seedState) - const croperRect = useRecoilValue(croperState) - const setIsPluginRunning = useSetRecoilState(isPluginRunningState) - const isProcessing = useRecoilValue(isProcessingState) - const runMannually = useRecoilValue(runManuallyState) - const isDiffusionModels = useRecoilValue(isDiffusionModelsState) - const isPix2Pix = useRecoilValue(isPix2PixState) - const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState( - isInteractiveSegState - ) - const setIsInteractiveSegRunning = useSetRecoilState( - isInteractiveSegRunningState - ) + const isProcessing = isInpainting + const isDiffusionModels = false + const isPix2Pix = false const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false) const [interactiveSegMask, setInteractiveSegMask] = useState< HTMLImageElement | null | undefined >(null) + // only used while interactive segmentation is on const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState< HTMLImageElement | null | undefined @@ -167,8 +150,6 @@ export default function Editor(props: EditorProps) { const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] = useState([]) - const [clicks, setClicks] = useRecoilState(interactiveSegClicksState) - const [original, isOriginalLoaded] = useImage(file) const [renders, setRenders] = useState([]) const [context, setContext] = useState() @@ -201,7 +182,6 @@ export default function Editor(props: EditorProps) { const [initialCentered, setInitialCentered] = useState(false) const [isDraging, setIsDraging] = useState(false) - const [isMultiStrokeKeyPressed, setIsMultiStrokeKeyPressed] = useState(false) const [sliderPos, setSliderPos] = useState(0) @@ -209,8 +189,6 @@ export default function Editor(props: EditorProps) { const [redoRenders, setRedoRenders] = useState([]) const [redoCurLines, setRedoCurLines] = useState([]) const [redoLineGroups, setRedoLineGroups] = useState([]) - const enableFileManager = useRecoilValue(enableFileManagerState) - const isEnableAutoSaving = useRecoilValue(isEnableAutoSavingState) const draw = useCallback( (render: HTMLImageElement, lineGroup: LineGroup) => { @@ -223,10 +201,10 @@ export default function Editor(props: EditorProps) { context.clearRect(0, 0, context.canvas.width, context.canvas.height) context.drawImage(render, 0, 0, imageWidth, imageHeight) - if (isInteractiveSeg && tmpInteractiveSegMask) { + if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) { context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight) } - if (!isInteractiveSeg && interactiveSegMask) { + if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) { context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight) } if (dreamButtonHoverSegMask) { @@ -243,7 +221,7 @@ export default function Editor(props: EditorProps) { }, [ context, - isInteractiveSeg, + interactiveSegState, tmpInteractiveSegMask, dreamButtonHoverSegMask, interactiveSegMask, @@ -363,34 +341,31 @@ export default function Editor(props: EditorProps) { setCurLineGroup([]) setIsDraging(false) setIsInpainting(true) - if (settings.graduallyInpainting) { - drawLinesOnMask([maskLineGroup], maskImage) - } else { - drawLinesOnMask(newLineGroups) - } + drawLinesOnMask([maskLineGroup], maskImage) let targetFile = file - if (settings.graduallyInpainting === true) { - if (useLastLineGroup === true) { - // renders.length == 1 还是用原来的 - if (renders.length > 1) { - const lastRender = renders[renders.length - 2] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - } else if (renders.length > 0) { - console.info("gradually inpainting on last result") - - const lastRender = renders[renders.length - 1] + console.log( + `randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}` + ) + if (useLastLineGroup === true) { + // renders.length == 1 还是用原来的 + if (renders.length > 1) { + const lastRender = renders[renders.length - 2] targetFile = await srcToFile( lastRender.currentSrc, file.name, file.type ) } + } else if (renders.length > 0) { + console.info("gradually inpainting on last result") + + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type + ) } try { @@ -398,10 +373,7 @@ export default function Editor(props: EditorProps) { const res = await inpaint( targetFile, settings, - croperRect, - promptVal, - negativePromptVal, - seedVal, + cropperRect, useCustomMask ? undefined : maskCanvas.toDataURL(), useCustomMask ? customMask : undefined, paintByExampleImage @@ -445,18 +417,15 @@ export default function Editor(props: EditorProps) { setInteractiveSegMask(null) }, [ + renders, lineGroups, curLineGroup, maskCanvas, - settings.graduallyInpainting, settings, - croperRect, - promptVal, - negativePromptVal, + cropperRect, drawOnCurrentRender, hadDrawSomething, drawLinesOnMask, - seedVal, ] ) @@ -487,7 +456,6 @@ export default function Editor(props: EditorProps) { }, [ hadDrawSomething, runInpainting, - promptVal, interactiveSegMask, prevInteractiveSegMask, ]) @@ -604,7 +572,7 @@ export default function Editor(props: EditorProps) { useEffect(() => { emitter.on(PluginName.InteractiveSeg, () => { - setIsInteractiveSeg(true) + // setIsInteractiveSeg(true) if (interactiveSegMask !== null) { setShowInteractiveSegModal(true) } @@ -807,8 +775,8 @@ export default function Editor(props: EditorProps) { const offsetX = (windowSize.width - imageWidth * minScale) / 2 const offsetY = (windowSize.height - imageHeight * minScale) / 2 viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad") - if (viewport.state) { - viewport.state.scale = minScale + if (viewport.instance.transformState.scale) { + viewport.instance.transformState.scale = minScale } setScale(minScale) @@ -850,24 +818,12 @@ export default function Editor(props: EditorProps) { } }, []) - const onInteractiveCancel = useCallback(() => { - setIsInteractiveSeg(false) - setIsInteractiveSegRunning(false) - setClicks([]) - setTmpInteractiveSegMask(null) - }, []) - const handleEscPressed = () => { if (isProcessing) { return } - if (isInteractiveSeg) { - onInteractiveCancel() - return - } - - if (isDraging || isMultiStrokeKeyPressed) { + if (isDraging) { setIsDraging(false) setCurLineGroup([]) drawOnCurrentRender([]) @@ -879,9 +835,6 @@ export default function Editor(props: EditorProps) { useHotkeys("Escape", handleEscPressed, [ isDraging, isInpainting, - isMultiStrokeKeyPressed, - isInteractiveSeg, - onInteractiveCancel, resetZoom, drawOnCurrentRender, ]) @@ -901,7 +854,7 @@ export default function Editor(props: EditorProps) { } return } - if (isInteractiveSeg) { + if (interactiveSegState.isInteractiveSeg) { return } if (isPanning) { @@ -924,7 +877,7 @@ export default function Editor(props: EditorProps) { return } - setIsInteractiveSegRunning(true) + // setIsInteractiveSegRunning(true) const targetFile = await getCurrentRender() const prevMask = null try { @@ -950,14 +903,14 @@ export default function Editor(props: EditorProps) { description: e.message ? e.message : e.toString(), }) } - setIsInteractiveSegRunning(false) + // setIsInteractiveSegRunning(false) } const onPointerUp = (ev: SyntheticEvent) => { if (isMidClick(ev)) { setIsPanning(false) } - if (isInteractiveSeg) { + if (interactiveSegState.isInteractiveSeg) { return } @@ -978,12 +931,7 @@ export default function Editor(props: EditorProps) { return } - if (isMultiStrokeKeyPressed) { - setIsDraging(false) - return - } - - if (runMannually) { + if (enableManualInpainting) { setIsDraging(false) } else { runInpainting() @@ -991,34 +939,34 @@ export default function Editor(props: EditorProps) { } const isOutsideCroper = (clickPnt: { x: number; y: number }) => { - if (clickPnt.x < croperRect.x) { + if (clickPnt.x < cropperRect.x) { return true } - if (clickPnt.y < croperRect.y) { + if (clickPnt.y < cropperRect.y) { return true } - if (clickPnt.x > croperRect.x + croperRect.width) { + if (clickPnt.x > cropperRect.x + cropperRect.width) { return true } - if (clickPnt.y > croperRect.y + croperRect.height) { + if (clickPnt.y > cropperRect.y + cropperRect.height) { return true } return false } const onCanvasMouseUp = (ev: SyntheticEvent) => { - if (isInteractiveSeg) { + if (interactiveSegState.isInteractiveSeg) { const xy = mouseXY(ev) const isX = xy.x const isY = xy.y - const newClicks: number[][] = [...clicks] + const newClicks: number[][] = [...interactiveSegState.clicks] if (isRightClick(ev)) { newClicks.push([isX, isY, 0, newClicks.length]) } else { newClicks.push([isX, isY, 1, newClicks.length]) } // runInteractiveSeg(newClicks) - setClicks(newClicks) + updateInteractiveSegState({ clicks: newClicks }) } } @@ -1026,7 +974,7 @@ export default function Editor(props: EditorProps) { if (isProcessing) { return } - if (isInteractiveSeg) { + if (interactiveSegState.isInteractiveSeg) { return } if (isChangingBrushSizeByMouse) { @@ -1063,7 +1011,7 @@ export default function Editor(props: EditorProps) { setIsDraging(true) let lineGroup: LineGroup = [] - if (isMultiStrokeKeyPressed || runMannually) { + if (enableManualInpainting) { lineGroup = [...curLineGroup] } lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] }) @@ -1122,9 +1070,9 @@ export default function Editor(props: EditorProps) { context, ]) - const undo = (keyboardEvent: KeyboardEvent, hotkeysEvent: HotkeysEvent) => { + const undo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { keyboardEvent.preventDefault() - if (runMannually && curLineGroup.length !== 0) { + if (enableManualInpainting && curLineGroup.length !== 0) { undoStroke() } else { undoRender() @@ -1134,7 +1082,7 @@ export default function Editor(props: EditorProps) { useHotkeys("meta+z,ctrl+z", undo, undefined, [ undoStroke, undoRender, - runMannually, + enableManualInpainting, curLineGroup, context?.canvas, renders, @@ -1148,7 +1096,7 @@ export default function Editor(props: EditorProps) { return false } - if (runMannually) { + if (enableManualInpainting) { if (curLineGroup.length === 0) { return true } @@ -1188,9 +1136,9 @@ export default function Editor(props: EditorProps) { // draw(newRenders[newRenders.length - 1], []) }, [draw, renders, redoRenders, redoLineGroups, lineGroups, original]) - const redo = (keyboardEvent: KeyboardEvent, hotkeysEvent: HotkeysEvent) => { + const redo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { keyboardEvent.preventDefault() - if (runMannually && redoCurLines.length !== 0) { + if (enableManualInpainting && redoCurLines.length !== 0) { redoStroke() } else { redoRender() @@ -1200,7 +1148,7 @@ export default function Editor(props: EditorProps) { useHotkeys("shift+ctrl+z,shift+meta+z", redo, undefined, [ redoStroke, redoRender, - runMannually, + enableManualInpainting, redoCurLines, ]) @@ -1212,7 +1160,7 @@ export default function Editor(props: EditorProps) { return false } - if (runMannually) { + if (enableManualInpainting) { if (redoCurLines.length === 0) { return true } @@ -1223,37 +1171,39 @@ export default function Editor(props: EditorProps) { return false } - // useKeyPressEvent( - // "Tab", - // (ev) => { - // ev?.preventDefault() - // ev?.stopPropagation() - // if (hadRunInpainting()) { - // setShowOriginal(() => { - // window.setTimeout(() => { - // setSliderPos(100) - // }, 10) - // return true - // }) - // } - // }, - // (ev) => { - // ev?.preventDefault() - // ev?.stopPropagation() - // if (hadRunInpainting()) { - // setSliderPos(0) - // window.setTimeout(() => { - // setShowOriginal(false) - // }, 350) - // } - // } - // ) + useKeyPressEvent( + "Tab", + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + if (hadRunInpainting()) { + setShowOriginal(() => { + window.setTimeout(() => { + setSliderPos(100) + }, 10) + return true + }) + } + }, + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + if (hadRunInpainting()) { + window.setTimeout(() => { + setSliderPos(0) + }, 10) + window.setTimeout(() => { + setShowOriginal(false) + }, COMPARE_SLIDER_DURATION_MS) + } + } + ) function download() { if (file === undefined) { return } - if ((enableFileManager || isEnableAutoSaving) && renders.length > 0) { + if (enableAutoSaving && renders.length > 0) { try { downloadToOutput(renders[renders.length - 1], file.name, file.type) toast({ @@ -1273,7 +1223,7 @@ export default function Editor(props: EditorProps) { const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1") const curRender = renders[renders.length - 1] downloadImage(curRender.currentSrc, name) - if (settings.downloadMask) { + if (settings.enableDownloadMask) { let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") @@ -1305,104 +1255,98 @@ export default function Editor(props: EditorProps) { return undefined }, [showBrush, isPanning]) - // Standard Hotkeys for Brush Size - // useHotKey("[", () => { - // setBrushSize((currentBrushSize: number) => { - // if (currentBrushSize > 10) { - // return currentBrushSize - 10 - // } - // if (currentBrushSize <= 10 && currentBrushSize > 0) { - // return currentBrushSize - 5 - // } - // return currentBrushSize - // }) - // }) + useHotkeys( + "[", + () => { + let newBrushSize = baseBrushSize + if (baseBrushSize > 10) { + newBrushSize = baseBrushSize - 10 + } + if (baseBrushSize <= 10 && baseBrushSize > 0) { + newBrushSize = baseBrushSize - 5 + } + setBrushSize(newBrushSize) + }, + [baseBrushSize] + ) - // useHotKey("]", () => { - // setBrushSize((currentBrushSize: number) => { - // return currentBrushSize + 10 - // }) - // }) + useHotkeys( + "]", + () => { + setBrushSize(baseBrushSize + 10) + }, + [baseBrushSize] + ) - // // Manual Inpainting Hotkey - // useHotKey( - // "shift+r", - // () => { - // if (runMannually && hadDrawSomething()) { - // runInpainting() - // } - // }, - // {}, - // [runMannually, runInpainting, hadDrawSomething] - // ) + // Manual Inpainting Hotkey + useHotkeys( + "shift+r", + () => { + if (enableManualInpainting && hadDrawSomething()) { + runInpainting() + } + }, + [enableManualInpainting, runInpainting, hadDrawSomething] + ) - // useHotKey( - // "ctrl+c, cmd+c", - // async () => { - // const hasPermission = await askWritePermission() - // if (hasPermission && renders.length > 0) { - // if (context?.canvas) { - // await copyCanvasImage(context?.canvas) - // setToastState({ - // open: true, - // desc: "Copy inpainting result to clipboard", - // state: "success", - // duration: 3000, - // }) - // } - // } - // }, - // {}, - // [renders, context] - // ) + useHotkeys( + "ctrl+c, cmd+c", + async () => { + const hasPermission = await askWritePermission() + if (hasPermission && renders.length > 0) { + if (context?.canvas) { + await copyCanvasImage(context?.canvas) + toast({ + title: "Copy inpainting result to clipboard", + }) + } + } + }, + [renders, context] + ) // Toggle clean/zoom tool on spacebar. - // useKeyPressEvent( - // " ", - // (ev) => { - // if (!app.disableShortCuts) { - // ev?.preventDefault() - // ev?.stopPropagation() - // setShowBrush(false) - // setIsPanning(true) - // } - // }, - // (ev) => { - // if (!app.disableShortCuts) { - // ev?.preventDefault() - // ev?.stopPropagation() - // setShowBrush(true) - // setIsPanning(false) - // } - // } - // ) + useKeyPressEvent( + " ", + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + setShowBrush(false) + setIsPanning(true) + }, + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + setShowBrush(true) + setIsPanning(false) + } + ) - // useKeyPressEvent( - // "Alt", - // (ev) => { - // ev?.preventDefault() - // ev?.stopPropagation() - // setIsChangingBrushSizeByMouse(true) - // setChangeBrushSizeByMouseInit({ x, y, brushSize }) - // }, - // (ev) => { - // ev?.preventDefault() - // ev?.stopPropagation() - // setIsChangingBrushSizeByMouse(false) - // } - // ) + useKeyPressEvent( + "Alt", + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + setIsChangingBrushSizeByMouse(true) + setChangeBrushSizeByMouseInit({ x, y, brushSize }) + }, + (ev) => { + ev?.preventDefault() + ev?.stopPropagation() + setIsChangingBrushSizeByMouse(false) + } + ) const getCurScale = (): number => { let s = minScale - if (viewportRef.current?.state?.scale !== undefined) { - s = viewportRef.current?.state.scale - console.log("!!!!!!") + if (viewportRef.current?.instance?.transformState.scale !== undefined) { + s = viewportRef.current?.instance?.transformState.scale } return s! } const getBrushStyle = (_x: number, _y: number) => { - const curScale = scale + const curScale = getCurScale() return { width: `${brushSize * curScale}px`, height: `${brushSize * curScale}px`, @@ -1435,7 +1379,7 @@ export default function Editor(props: EditorProps) { const renderInteractiveSegCursor = () => { return (
{ e.preventDefault() @@ -1519,9 +1465,10 @@ export default function Editor(props: EditorProps) { {showOriginal && ( <>
- {/* {isInteractiveSeg ? : <>} */} + {/* {interactiveSegState.isInteractiveSeg ? : <>} */} ) @@ -1558,7 +1505,7 @@ export default function Editor(props: EditorProps) { setInteractiveSegMask(tmpInteractiveSegMask) setTmpInteractiveSegMask(null) - if (!runMannually && tmpInteractiveSegMask) { + if (!enableManualInpainting && tmpInteractiveSegMask) { runInpainting(false, undefined, tmpInteractiveSegMask) } } @@ -1570,16 +1517,12 @@ export default function Editor(props: EditorProps) { onMouseMove={onMouseMove} onMouseUp={onPointerUp} > - {/* */} {renderCanvas()} {showBrush && !isInpainting && !isPanning && - (isInteractiveSeg + (interactiveSegState.isInteractiveSeg ? renderInteractiveSegCursor() : renderBrush( getBrushStyle( @@ -1590,20 +1533,21 @@ export default function Editor(props: EditorProps) { {showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))} -
+
handleSliderChange(vals[0])} onClick={() => setShowRefBrush(false)} />
@@ -1616,23 +1560,26 @@ export default function Editor(props: EditorProps) { { - // ev.preventDefault() - // setShowOriginal(() => { - // window.setTimeout(() => { - // setSliderPos(100) - // }, 10) - // return true - // }) - // }} - // onUp={() => { - // setSliderPos(0) - // window.setTimeout(() => { - // setShowOriginal(false) - // }, 300) - // }} + tooltip="Show original image" + onPointerDown={(ev) => { + ev.preventDefault() + setShowOriginal(() => { + window.setTimeout(() => { + setSliderPos(100) + }, 10) + return true + }) + }} + onPointerUp={() => { + window.setTimeout(() => { + // 防止快速点击 show original image 按钮时图片消失 + setSliderPos(0) + }, 10) + + window.setTimeout(() => { + setShowOriginal(false) + }, COMPARE_SLIDER_DURATION_MS) + }} disabled={renders.length === 0} > @@ -1645,36 +1592,25 @@ export default function Editor(props: EditorProps) { - { - // ensured by disabled - runInpainting(false, undefined, interactiveSegMask) - }} - > - - + {settings.enableManualInpainting ? ( + { + // ensured by disabled + runInpainting(false, undefined, interactiveSegMask) + }} + > + + + ) : ( + <> + )}
- {/* { - onInteractiveCancel() - setShowInteractiveSegModal(false) - }} - onCleanClick={() => { - onInteractiveCancel() - setInteractiveSegMask(null) - }} - onReplaceClick={() => { - setShowInteractiveSegModal(false) - setIsInteractiveSeg(true) - }} - /> */}
) } diff --git a/web_app/src/components/FileManager.tsx b/web_app/src/components/FileManager.tsx index 145b51b..89282cf 100644 --- a/web_app/src/components/FileManager.tsx +++ b/web_app/src/components/FileManager.tsx @@ -74,18 +74,9 @@ export default function FileManager(props: Props) { const { onPhotoClick, photoWidth } = props const [open, toggleOpen] = useToggle(false) - const [ - fileManagerState, - setFileManagerLayout, - setFileManagerSortBy, - setFileManagerSortOrder, - setFileManagerSearchText, - ] = useStore((state) => [ + const [fileManagerState, updateFileManagerState] = useStore((state) => [ state.fileManagerState, - state.setFileManagerLayout, - state.setFileManagerSortBy, - state.setFileManagerSortOrder, - state.setFileManagerSearchText, + state.updateFileManagerState, ]) useHotkeys("f", () => { @@ -185,7 +176,7 @@ export default function FileManager(props: Props) { { - setFileManagerLayout("rows") + updateFileManagerState({ layout: "rows" }) }} > { - setFileManagerLayout("masonry") + updateFileManagerState({ layout: "masonry" }) }} > @@ -250,13 +241,13 @@ export default function FileManager(props: Props) { onValueChange={(val) => { switch (val) { case SORT_BY_NAME: - setFileManagerSortBy(SortBy.NAME) + updateFileManagerState({ sortBy: SortBy.NAME }) break case SORT_BY_CREATED_TIME: - setFileManagerSortBy(SortBy.CTIME) + updateFileManagerState({ sortBy: SortBy.CTIME }) break case SORT_BY_MODIFIED_TIME: - setFileManagerSortBy(SortBy.MTIME) + updateFileManagerState({ sortBy: SortBy.MTIME }) break default: break @@ -281,7 +272,7 @@ export default function FileManager(props: Props) { { - setFileManagerSortOrder(SortOrder.ASCENDING) + updateFileManagerState({ sortOrder: SortOrder.ASCENDING }) }} > @@ -290,7 +281,7 @@ export default function FileManager(props: Props) { { - setFileManagerSortOrder(SortOrder.DESCENDING) + updateFileManagerState({ sortOrder: SortOrder.DESCENDING }) }} > diff --git a/web_app/src/components/Header.tsx b/web_app/src/components/Header.tsx index bade609..81fa016 100644 --- a/web_app/src/components/Header.tsx +++ b/web_app/src/components/Header.tsx @@ -1,19 +1,8 @@ import { PlayIcon } from "@radix-ui/react-icons" -import React, { useCallback, useState } from "react" -import { useRecoilState, useRecoilValue } from "recoil" +import { useCallback, useState } from "react" import { useHotkeys } from "react-hotkeys-hook" -import { - enableFileManagerState, - isPix2PixState, - isSDState, - maskState, - runManuallyState, -} from "@/lib/store" import { IconButton, ImageUploadButton } from "@/components/ui/button" import Shortcuts from "@/components/Shortcuts" -// import SettingIcon from "../Settings/SettingIcon" -// import PromptInput from "./PromptInput" -// import CoffeeIcon from '../CoffeeIcon/CoffeeIcon' import emitter, { DREAM_BUTTON_MOUSE_ENTER, DREAM_BUTTON_MOUSE_LEAVE, @@ -24,24 +13,37 @@ import { useImage } from "@/hooks/useImage" import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" import PromptInput from "./PromptInput" -import { RotateCw, Image } from "lucide-react" +import { RotateCw, Image, Upload } from "lucide-react" import FileManager from "./FileManager" import { getMediaFile } from "@/lib/api" import { useStore } from "@/lib/states" +import SettingsDialog from "./Settings" +import { cn } from "@/lib/utils" const Header = () => { - const [file, isInpainting, setFile] = useStore((state) => [ + const [ + file, + customMask, + isInpainting, + enableFileManager, + enableManualInpainting, + enableUploadMask, + shouldShowPromptInput, + setFile, + setCustomFile, + ] = useStore((state) => [ state.file, + state.customMask, state.isInpainting, + state.serverConfig.enableFileManager, + state.settings.enableManualInpainting, + state.settings.enableUploadMask, + state.shouldShowPromptInput(), state.setFile, + state.setCustomFile, ]) - const [mask, setMask] = useRecoilState(maskState) - // const [maskImage, maskImageLoaded] = useImage(mask) - const isSD = useRecoilValue(isSDState) - const isPix2Pix = useRecoilValue(isPix2PixState) - const runManually = useRecoilValue(runManuallyState) + const [maskImage, maskImageLoaded] = useImage(customMask) const [openMaskPopover, setOpenMaskPopover] = useState(false) - const enableFileManager = useRecoilValue(enableFileManagerState) const handleRerunLastMask = useCallback(() => { emitter.emit(RERUN_LAST_MASK) @@ -68,7 +70,7 @@ const Header = () => { return (
-
+
{enableFileManager ? ( {
{ - setMask(file) - console.info("Send custom mask") - if (!runManually) { + setCustomFile(file) + if (!enableManualInpainting) { emitter.emit(EVENT_CUSTOM_MASK, { mask: file }) } }} > -
M
+
- {mask ? ( + {customMask ? ( setOpenMaskPopover(true)} onMouseLeave={() => setOpenMaskPopover(false)} style={{ - visibility: mask ? "visible" : "hidden", + visibility: customMask ? "visible" : "hidden", outline: "none", }} onClick={() => { - if (mask) { - emitter.emit(EVENT_CUSTOM_MASK, { mask }) + if (customMask) { + emitter.emit(EVENT_CUSTOM_MASK, { mask: customMask }) } }} > @@ -131,36 +132,36 @@ const Header = () => { - {/* + {maskImageLoaded ? ( Custom mask ) : ( <> )} - */} + ) : ( <> )} - - - -
+ + + +
- {isSD ? : <>} + {shouldShowPromptInput ? : <>} - {/* */} -
+
+ {/* */} - {/* */} +
) diff --git a/web_app/src/components/ImageSize.tsx b/web_app/src/components/ImageSize.tsx index 2129780..17314d4 100644 --- a/web_app/src/components/ImageSize.tsx +++ b/web_app/src/components/ImageSize.tsx @@ -11,7 +11,7 @@ const ImageSize = () => { } return ( -
+
{imageWidth}x{imageHeight}
) diff --git a/web_app/src/components/InteractiveSeg.tsx b/web_app/src/components/InteractiveSeg.tsx new file mode 100644 index 0000000..d57539e --- /dev/null +++ b/web_app/src/components/InteractiveSeg.tsx @@ -0,0 +1,136 @@ +import { useStore } from "@/lib/states" +import { Button } from "./ui/button" +import { Dialog, DialogContent, DialogTitle } from "./ui/dialog" +import { MousePointerClick } from "lucide-react" +import { DropdownMenuItem } from "./ui/dropdown-menu" + +interface InteractiveSegReplaceModal { + show: boolean + onClose: () => void + onCleanClick: () => void + onReplaceClick: () => void +} + +const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => { + const { show, onClose, onCleanClick, onReplaceClick } = props + + const onOpenChange = (open: boolean) => { + if (!open) { + onClose() + } + } + + return ( + + + Do you want to remove it or create a new one? +
+ + +
+
+
+ ) +} + +const InteractiveSegConfirmActions = () => { + const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [ + state.interactiveSegState, + state.resetInteractiveSegState, + ]) + + if (!interactiveSegState.isInteractiveSeg) { + return null + } + + const onAcceptClick = () => { + resetInteractiveSegState() + } + + return ( +
+ + +
+ ) +} + +interface ItemProps { + x: number + y: number + positive: boolean +} + +const Item = (props: ItemProps) => { + const { x, y, positive } = props + const name = positive + ? "bg-[rgba(21,_215,_121,_0.936)] outline-[6px_solid_rgba(98,_255,_179,_0.31)]" + : "bg-[rgba(237,_49,_55,_0.942)] outline-[6px_solid_rgba(255,_89,_95,_0.31)]" + return ( +
+ ) +} + +const InteractiveSegPoints = () => { + const clicks = useStore((state) => state.interactiveSegState.clicks) + + return ( +
+ {clicks.map((click) => { + return ( + + ) + })} +
+ ) +} + +const InteractiveSeg = () => { + const [interactiveSegState, updateInteractiveSegState] = useStore((state) => [ + state.interactiveSegState, + state.updateInteractiveSegState, + ]) + + return ( +
+ + {/* */} +
+ ) +} + +export { InteractiveSeg, InteractiveSegPoints } diff --git a/web_app/src/components/Plugins.tsx b/web_app/src/components/Plugins.tsx index 2765044..122b69c 100644 --- a/web_app/src/components/Plugins.tsx +++ b/web_app/src/components/Plugins.tsx @@ -10,6 +10,8 @@ import { import { Button } from "./ui/button" import { Fullscreen, MousePointerClick, Slice, Smile } from "lucide-react" import { MixIcon } from "@radix-ui/react-icons" +import { useStore } from "@/lib/states" +import { InteractiveSeg } from "./InteractiveSeg" export enum PluginName { RemoveBG = "RemoveBG", @@ -48,17 +50,10 @@ const pluginMap = { } const Plugins = () => { - // const [open, toggleOpen] = useToggle(true) - // const serverConfig = useRecoilValue(serverConfigState) - // const isProcessing = useRecoilValue(isProcessingState) - const plugins = [ - PluginName.RemoveBG, - PluginName.AnimeSeg, - PluginName.RealESRGAN, - PluginName.GFPGAN, - PluginName.RestoreFormer, - PluginName.InteractiveSeg, - ] + const [plugins, updateInteractiveSegState] = useStore((state) => [ + state.serverConfig.plugins, + state.updateInteractiveSegState, + ]) if (plugins.length === 0) { return null @@ -68,6 +63,9 @@ const Plugins = () => { // if (!disabled) { // emitter.emit(pluginName) // } + if (pluginName === PluginName.InteractiveSeg) { + updateInteractiveSegState({ isInteractiveSeg: true }) + } } const onRealESRGANClick = (upscale: number) => { @@ -98,8 +96,8 @@ const Plugins = () => { } const renderPlugins = () => { - return plugins.map((plugin: PluginName) => { - const { IconClass, showName } = pluginMap[plugin] + return plugins.map((plugin: string) => { + const { IconClass, showName } = pluginMap[plugin as PluginName] if (plugin === PluginName.RealESRGAN) { return renderRealESRGANPlugin() } @@ -116,7 +114,10 @@ const Plugins = () => { return ( - + diff --git a/web_app/src/components/PromptInput.tsx b/web_app/src/components/PromptInput.tsx index b3442f6..4d162cd 100644 --- a/web_app/src/components/PromptInput.tsx +++ b/web_app/src/components/PromptInput.tsx @@ -9,17 +9,17 @@ import { Input } from "./ui/input" import { useStore } from "@/lib/states" const PromptInput = () => { - const [isInpainting, prompt, setPrompt] = useStore((state) => [ + const [isInpainting, prompt, updateSettings] = useStore((state) => [ state.isInpainting, - state.prompt, - state.setPrompt, + state.settings.prompt, + state.updateSettings, ]) const handleOnInput = (evt: FormEvent) => { evt.preventDefault() evt.stopPropagation() const target = evt.target as HTMLInputElement - setPrompt(target.value) + updateSettings({ prompt: target.value }) } const handleRepaintClick = () => { diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx new file mode 100644 index 0000000..86aac32 --- /dev/null +++ b/web_app/src/components/Settings.tsx @@ -0,0 +1,435 @@ +import { IconButton } from "@/components/ui/button" +import { useToggle } from "@uidotdev/usehooks" +import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog" +import { useHotkeys } from "react-hotkeys-hook" +import { Info, Settings } from "lucide-react" +import { zodResolver } from "@hookform/resolvers/zod" +import { useForm } from "react-hook-form" +import * as z from "zod" +import { Button } from "@/components/ui/button" +import { Separator } from "@/components/ui/separator" +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form" +import { Input } from "@/components/ui/input" +import { Switch } from "./ui/switch" +import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs" +import { useState } from "react" +import { cn } from "@/lib/utils" +import { useQuery } from "@tanstack/react-query" +import { fetchModelInfos, switchModel } from "@/lib/api" +import { ModelInfo } from "@/lib/types" +import { useStore } from "@/lib/states" +import { ScrollArea } from "./ui/scroll-area" +import { useToast } from "./ui/use-toast" +import { + AlertDialog, + AlertDialogContent, + AlertDialogDescription, + AlertDialogHeader, +} from "./ui/alert-dialog" + +const formSchema = z.object({ + enableFileManager: z.boolean(), + inputDirectory: z.string().refine(async (id) => { + // verify that ID exists in database + return true + }), + outputDirectory: z.string().refine(async (id) => { + // verify that ID exists in database + return true + }), + enableDownloadMask: z.boolean(), + enableManualInpainting: z.boolean(), + enableUploadMask: z.boolean(), +}) + +const TAB_GENERAL = "General" +const TAB_MODEL = "Model" +const TAB_FILE_MANAGER = "File Manager" + +const TAB_NAMES = [TAB_MODEL, TAB_GENERAL] + +export function SettingsDialog() { + const [open, toggleOpen] = useToggle(false) + const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false) + const [tab, setTab] = useState(TAB_GENERAL) + const [settings, updateSettings, fileManagerState, updateFileManagerState] = + useStore((state) => [ + state.settings, + state.updateSettings, + state.fileManagerState, + state.updateFileManagerState, + ]) + const { toast } = useToast() + const [model, setModel] = useState(settings.model) + + const { data: modelInfos, isSuccess } = useQuery({ + queryKey: ["modelInfos"], + queryFn: fetchModelInfos, + }) + + // 1. Define your form. + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + enableDownloadMask: settings.enableDownloadMask, + enableManualInpainting: settings.enableManualInpainting, + enableUploadMask: settings.enableUploadMask, + enableFileManager: fileManagerState.enabled, + inputDirectory: fileManagerState.inputDirectory, + outputDirectory: fileManagerState.outputDirectory, + }, + }) + + function onSubmit(values: z.infer) { + // Do something with the form values. ✅ This will be type-safe and validated. + updateSettings({ + enableDownloadMask: values.enableDownloadMask, + enableManualInpainting: values.enableManualInpainting, + enableUploadMask: values.enableUploadMask, + }) + + // TODO: validate input/output Directory + updateFileManagerState({ + enabled: values.enableFileManager, + inputDirectory: values.inputDirectory, + outputDirectory: values.outputDirectory, + }) + + if (model.name !== settings.model.name) { + toggleOpenModelSwitching() + switchModel(model.name) + .then((res) => { + if (res.ok) { + toast({ + title: `Switch to ${model.name} success`, + }) + updateSettings({ model: model }) + } else { + throw new Error("Server error") + } + }) + .catch(() => { + toast({ + variant: "destructive", + title: `Switch to ${model.name} failed`, + }) + }) + .finally(() => { + toggleOpenModelSwitching() + }) + } + } + useHotkeys("s", () => { + toggleOpen() + form.handleSubmit(onSubmit)() + }) + + function onOpenChange(value: boolean) { + toggleOpen() + if (!value) { + form.handleSubmit(onSubmit)() + } + } + + function onModelSelect(info: ModelInfo) { + setModel(info) + } + + function renderModelList(model_types: string[]) { + if (!modelInfos) { + return
Please download model first
+ } + return modelInfos + .filter((info) => model_types.includes(info.model_type)) + .map((info: ModelInfo) => { + return ( +
onModelSelect(info)}> +
+
{info.name}
+
+ +
+ ) + }) + } + + function renderModelSettings() { + if (!isSuccess) { + return <> + } + + let defaultTab = "inpaint" + for (let info of modelInfos) { + if (model.name === info.name) { + defaultTab = info.model_type + break + } + } + + return ( +
+
+
Current Model
+
{model.name}
+
+ + + +
+
+
Available models
+ + + +
+ + + Inpaint + Diffusion + + Diffusion inpaint + + Diffusion other + + + + {renderModelList(["inpaint"])} + + + {renderModelList(["diffusers_sd", "diffusers_sdxl"])} + + + {renderModelList([ + "diffusers_sd_inpaint", + "diffusers_sdxl_inpaint", + ])} + + + {renderModelList(["diffusers_other"])} + + + +
+
+ ) + } + + function renderGeneralSettings() { + return ( +
+ ( + +
+ Enable manual inpainting + + Click a button to trigger inpainting after draw mask. + +
+ + + +
+ )} + /> + + + + ( + +
+ Enable download mask + + Also download the mask after save the inpainting result. + +
+ + + +
+ )} + /> + + + + ( + +
+ Enable upload mask + + Enable upload custom mask to perform inpainting. + +
+ + + +
+ )} + /> + +
+ ) + } + + function renderFileManagerSettings() { + return ( +
+ ( + +
+ Enable file manger + + Browser images + +
+ + + +
+ )} + /> + + + + ( + + Input directory + + + + + Browser images from this directory. + + + + )} + /> + + ( + + Save directory + + + + + Result images will be saved to this directory. + + + + )} + /> +
+ ) + } + + return ( + <> + + + + + TODO: 添加加载动画 Switching to {model.name} + + + + + + + + + + + event.preventDefault()} + onOpenAutoFocus={(event) => event.preventDefault()} + // onPointerDownOutside={(event) => event.preventDefault()} + > + Settings + + +
+
+ {TAB_NAMES.map((item) => ( + + ))} +
+ +
+
+ + {tab === TAB_MODEL ? renderModelSettings() : <>} + {tab === TAB_GENERAL ? renderGeneralSettings() : <>} + {/* {tab === TAB_FILE_MANAGER ? ( + renderFileManagerSettings() + ) : ( + <> + )} */} + + {/*
+ +
*/} + +
+ +
+
+
+ + ) +} + +export default SettingsDialog diff --git a/web_app/src/components/Workspace.tsx b/web_app/src/components/Workspace.tsx index a869a7c..64e09e4 100644 --- a/web_app/src/components/Workspace.tsx +++ b/web_app/src/components/Workspace.tsx @@ -1,18 +1,16 @@ import { useEffect } from "react" -import { useRecoilState, useRecoilValue, useSetRecoilState } from "recoil" import Editor from "./Editor" -// import SettingModal from "./Settings/SettingsModal" import { AIModel, isPaintByExampleState, isPix2PixState, isSDState, - settingState, } from "@/lib/store" -import { currentModel, modelDownloaded, switchModel } from "@/lib/api" +import { currentModel } from "@/lib/api" import { useStore } from "@/lib/states" import ImageSize from "./ImageSize" import Plugins from "./Plugins" +import { InteractiveSeg } from "./InteractiveSeg" // import SidePanel from "./SidePanel/SidePanel" // import PESidePanel from "./SidePanel/PESidePanel" // import P2PSidePanel from "./SidePanel/P2PSidePanel" @@ -21,73 +19,18 @@ import Plugins from "./Plugins" // import ImageSize from "./ImageSize/ImageSize" const Workspace = () => { - const file = useStore((state) => state.file) - const [settings, setSettingState] = useRecoilState(settingState) - const isSD = useRecoilValue(isSDState) - const isPaintByExample = useRecoilValue(isPaintByExampleState) - const isPix2Pix = useRecoilValue(isPix2PixState) - - const onSettingClose = async () => { - const curModel = await currentModel().then((res) => res.text()) - if (curModel === settings.model) { - return - } - const downloaded = await modelDownloaded(settings.model).then((res) => - res.text() - ) - - const { model } = settings - - let loadingMessage = `Switching to ${model} model` - let loadingDuration = 3000 - if (downloaded === "False") { - loadingMessage = `Downloading ${model} model, this may take a while` - loadingDuration = 9999999999 - } - - // TODO 修改成 Modal - // setToastState({ - // open: true, - // desc: loadingMessage, - // state: "loading", - // duration: loadingDuration, - // }) - - switchModel(model) - .then((res) => { - if (res.ok) { - // setToastState({ - // open: true, - // desc: `Switch to ${model} model success`, - // state: "success", - // duration: 3000, - // }) - } else { - throw new Error("Server error") - } - }) - .catch(() => { - // setToastState({ - // open: true, - // desc: `Switch to ${model} model failed`, - // state: "error", - // duration: 3000, - // }) - setSettingState((old) => { - return { ...old, model: curModel as AIModel } - }) - }) - } + const [file, updateSettings] = useStore((state) => [ + state.file, + state.updateSettings, + ]) useEffect(() => { currentModel() - .then((res) => res.text()) + .then((res) => res.json()) .then((model) => { - setSettingState((old) => { - return { ...old, model: model as AIModel } - }) + updateSettings({ model }) }) - }, [setSettingState]) + }, []) return ( <> @@ -99,6 +42,7 @@ const Workspace = () => {
+ {file ? : <>} ) diff --git a/web_app/src/components/ui/alert-dialog.tsx b/web_app/src/components/ui/alert-dialog.tsx new file mode 100644 index 0000000..cc49f39 --- /dev/null +++ b/web_app/src/components/ui/alert-dialog.tsx @@ -0,0 +1,139 @@ +import * as React from "react" +import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog" + +import { cn } from "@/lib/utils" +import { buttonVariants } from "@/components/ui/button" + +const AlertDialog = AlertDialogPrimitive.Root + +const AlertDialogTrigger = AlertDialogPrimitive.Trigger + +const AlertDialogPortal = AlertDialogPrimitive.Portal + +const AlertDialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName + +const AlertDialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + +)) +AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName + +const AlertDialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogHeader.displayName = "AlertDialogHeader" + +const AlertDialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogFooter.displayName = "AlertDialogFooter" + +const AlertDialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName + +const AlertDialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogDescription.displayName = + AlertDialogPrimitive.Description.displayName + +const AlertDialogAction = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogAction.displayName = AlertDialogPrimitive.Action.displayName + +const AlertDialogCancel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogCancel.displayName = AlertDialogPrimitive.Cancel.displayName + +export { + AlertDialog, + AlertDialogPortal, + AlertDialogOverlay, + AlertDialogTrigger, + AlertDialogContent, + AlertDialogHeader, + AlertDialogFooter, + AlertDialogTitle, + AlertDialogDescription, + AlertDialogAction, + AlertDialogCancel, +} diff --git a/web_app/src/components/ui/button.tsx b/web_app/src/components/ui/button.tsx index 94b07a8..009ad3d 100644 --- a/web_app/src/components/ui/button.tsx +++ b/web_app/src/components/ui/button.tsx @@ -78,7 +78,7 @@ const IconButton = React.forwardRef( {...rest} ref={ref} tabIndex={-1} - className="cursor-default" + className="cursor-default bg-background" >
{children}
diff --git a/web_app/src/components/ui/dialog.tsx b/web_app/src/components/ui/dialog.tsx index 4ad69d4..2817d93 100644 --- a/web_app/src/components/ui/dialog.tsx +++ b/web_app/src/components/ui/dialog.tsx @@ -87,7 +87,7 @@ const DialogTitle = React.forwardRef< = FieldPath +> = { + name: TName +} + +const FormFieldContext = React.createContext( + {} as FormFieldContextValue +) + +const FormField = < + TFieldValues extends FieldValues = FieldValues, + TName extends FieldPath = FieldPath +>({ + ...props +}: ControllerProps) => { + return ( + + + + ) +} + +const useFormField = () => { + const fieldContext = React.useContext(FormFieldContext) + const itemContext = React.useContext(FormItemContext) + const { getFieldState, formState } = useFormContext() + + const fieldState = getFieldState(fieldContext.name, formState) + + if (!fieldContext) { + throw new Error("useFormField should be used within ") + } + + const { id } = itemContext + + return { + id, + name: fieldContext.name, + formItemId: `${id}-form-item`, + formDescriptionId: `${id}-form-item-description`, + formMessageId: `${id}-form-item-message`, + ...fieldState, + } +} + +type FormItemContextValue = { + id: string +} + +const FormItemContext = React.createContext( + {} as FormItemContextValue +) + +const FormItem = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => { + const id = React.useId() + + return ( + +
+ + ) +}) +FormItem.displayName = "FormItem" + +const FormLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => { + const { error, formItemId } = useFormField() + + return ( +