This commit is contained in:
Qing 2023-12-01 10:15:35 +08:00
parent 973987dfbb
commit 9a9eb8abfd
55 changed files with 2596 additions and 1251 deletions

View File

@ -4,16 +4,14 @@ from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
MPS_SUPPORT_MODELS = [ MPS_UNSUPPORT_MODELS = [
"instruct_pix2pix", "lama",
"sd1.5", "ldm",
"anything4", "zits",
"realisticVision1.4", "mat",
"sd2", "fcf",
"paint_by_example", "cv2",
"controlnet", "manga",
"kandinsky2.2",
"sdxl",
] ]
DEFAULT_MODEL = "lama" DEFAULT_MODEL = "lama"
@ -36,18 +34,13 @@ AVAILABLE_MODELS = [
"sdxl", "sdxl",
] ]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
MODELS_SUPPORT_FREEU = SD15_MODELS + ["sd2", "sdxl", "instruct_pix2pix"] DIFFUSERS_MODEL_FP16_REVERSION = [
MODELS_SUPPORT_LCM_LORA = SD15_MODELS + ["sdxl"] "runwayml/stable-diffusion-inpainting",
"Sanster/anything-4.0-inpainting",
FREEU_DEFAULT_CONFIGS = { "Sanster/Realistic_Vision_V1.4-inpainting",
"sd2": dict(s1=0.9, s2=0.2, b1=1.1, b2=1.2), "stabilityai/stable-diffusion-2-inpainting",
"sdxl": dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2), "timbrooks/instruct-pix2pix",
"sd1.5": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4), ]
"anything4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
"realisticVision1.4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
"instruct_pix2pix": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
}
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda" DEFAULT_DEVICE = "cuda"
@ -70,14 +63,29 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
""" """
SD_CONTROLNET_HELP = """ SD_CONTROLNET_HELP = """
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. Run Stable Diffusion normal or inpainting model with ControlNet.
""" """
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny" DEFAULT_SD_CONTROLNET_METHOD = "thibaud/controlnet-sd21-openpose-diffusers"
SD_CONTROLNET_CHOICES = [ SD_CONTROLNET_CHOICES = [
"control_v11p_sd15_canny", "lllyasviel/control_v11p_sd15_canny",
"control_v11p_sd15_openpose", # "lllyasviel/control_v11p_sd15_seg",
"control_v11p_sd15_inpaint", "lllyasviel/control_v11p_sd15_openpose",
"control_v11f1p_sd15_depth", "lllyasviel/control_v11p_sd15_inpaint",
"lllyasviel/control_v11f1p_sd15_depth",
]
DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers"
SD2_CONTROLNET_CHOICES = [
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-openpose-diffusers",
]
DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
SDXL_CONTROLNET_CHOICES = [
"thibaud/controlnet-openpose-sdxl-1.0",
"diffusers/controlnet-canny-sdxl-1.0",
"diffusers/controlnet-depth-sdxl-1.0",
] ]
SD_LOCAL_MODEL_HELP = """ SD_LOCAL_MODEL_HELP = """
@ -152,7 +160,7 @@ class Config(BaseModel):
model: str = DEFAULT_MODEL model: str = DEFAULT_MODEL
sd_local_model_path: str = None sd_local_model_path: str = None
sd_controlnet: bool = False sd_controlnet: bool = False
sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD
device: str = DEFAULT_DEVICE device: str = DEFAULT_DEVICE
gui: bool = False gui: bool = False
no_gui_auto_close: bool = False no_gui_auto_close: bool = False

View File

@ -1,41 +0,0 @@
import json
from pathlib import Path
from typing import Dict, List
def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/")
def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
cache_dir = Path(cache_dir)
res = []
for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f:
data = json.load(f)
if data["_class_name"] in class_name:
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name not in res:
res.append(name)
return res
def scan_models(cache_dir) -> Dict[str, List[str]]:
return {
"sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
"sd_inpaint": _scan_models(
cache_dir,
[
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"KandinskyV22InpaintPipeline",
],
),
"other": _scan_models(
cache_dir,
[
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
],
),
}

View File

@ -1,8 +1,20 @@
import json
import os import os
from typing import List
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.schema import (
ModelInfo,
ModelType,
DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
DIFFUSERS_SD_CLASS_NAME,
DIFFUSERS_SDXL_CLASS_NAME,
)
def cli_download_model(model: str, model_dir: str): def cli_download_model(model: str, model_dir: str):
if os.path.isfile(model_dir): if os.path.isfile(model_dir):
@ -14,7 +26,7 @@ def cli_download_model(model: str, model_dir: str):
os.environ["XDG_CACHE_HOME"] = model_dir os.environ["XDG_CACHE_HOME"] = model_dir
from lama_cleaner.model_manager import models from lama_cleaner.model import models
if model in models: if model in models:
logger.info(f"Downloading {model}...") logger.info(f"Downloading {model}...")
@ -22,3 +34,127 @@ def cli_download_model(model: str, model_dir: str):
logger.info(f"Done.") logger.info(f"Done.")
else: else:
logger.info(f"Downloading model from Huggingface: {model}") logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline
downloaded_path = DiffusionPipeline.download(
pretrained_model_name=model,
revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
resume_download=True,
)
logger.info(f"Done. Downloaded to {downloaded_path}")
def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/")
def scan_diffusers_models(
cache_dir, class_name: List[str], model_type: ModelType
) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
res = []
for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f:
data = json.load(f)
if data["_class_name"] in class_name:
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name not in res:
res.append(
ModelInfo(
name=name,
path=name,
model_type=model_type,
)
)
return res
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
res = []
for it in cache_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
if "inpaint" in str(it).lower():
if "sdxl" in str(it).lower():
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
model_type = ModelType.DIFFUSERS_SD_INPAINT
else:
if "sdxl" in str(it).lower():
model_type = ModelType.DIFFUSERS_SDXL
else:
model_type = ModelType.DIFFUSERS_SD
res.append(
ModelInfo(
name=it.name,
path=str(it.absolute()),
model_type=model_type,
is_single_file_diffusers=True,
)
)
return res
def scan_inpaint_models() -> List[ModelInfo]:
res = []
from lama_cleaner.model import models
for name, m in models.items():
if m.is_erase_model:
res.append(
ModelInfo(
name=name,
path=name,
model_type=ModelType.INPAINT,
)
)
return res
def scan_models() -> List[ModelInfo]:
from diffusers.utils import DIFFUSERS_CACHE
available_models = []
available_models.extend(scan_inpaint_models())
available_models.extend(
scan_single_file_diffusion_models(os.environ["XDG_CACHE_HOME"])
)
cache_dir = Path(DIFFUSERS_CACHE)
diffusers_model_names = []
for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f:
data = json.load(f)
_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
if _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
elif _class_name in [
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
"KandinskyV22InpaintPipeline",
]:
model_type = ModelType.DIFFUSERS_OTHER
else:
continue
diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=name,
model_type=model_type,
)
)
return available_models

View File

@ -7,6 +7,7 @@ import time
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
# from watchdog.events import FileSystemEventHandler # from watchdog.events import FileSystemEventHandler
# from watchdog.observers import Observer # from watchdog.observers import Observer
@ -149,6 +150,7 @@ class FileManager:
def get_thumbnail( def get_thumbnail(
self, directory: Path, original_filename: str, width, height, **options self, directory: Path, original_filename: str, width, height, **options
): ):
directory = Path(directory)
storage = FilesystemStorageBackend(self.app) storage = FilesystemStorageBackend(self.app)
crop = options.get("crop", "fit") crop = options.get("crop", "fit")
background = options.get("background") background = options.get("background")
@ -167,6 +169,7 @@ class FileManager:
thumbnail_size = (width, height) thumbnail_size = (width, height)
thumbnail_filename = generate_filename( thumbnail_filename = generate_filename(
directory,
original_filename, original_filename,
aspect_to_string(thumbnail_size), aspect_to_string(thumbnail_size),
crop, crop,

View File

@ -1,19 +1,17 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import importlib import hashlib
import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
def generate_filename(original_filename, *options): def generate_filename(directory: Path, original_filename, *options) -> str:
name, ext = os.path.splitext(original_filename) text = str(directory.absolute()) + original_filename
for v in options: for v in options:
if v: text += "%s" % v
name += "_%s" % v md5_hash = hashlib.md5()
name += ext md5_hash.update(text.encode("utf-8"))
return md5_hash.hexdigest() + ".jpg"
return name
def parse_size(size): def parse_size(size):
@ -48,7 +46,7 @@ def aspect_to_string(size):
return "x".join(map(str, size)) return "x".join(map(str, size))
IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'} IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
def glob_img(p: Union[Path, str], recursive: bool = False): def glob_img(p: Union[Path, str], recursive: bool = False):

View File

@ -8,7 +8,7 @@ import cv2
from PIL import Image, ImageOps, PngImagePlugin from PIL import Image, ImageOps, PngImagePlugin
import numpy as np import numpy as np
import torch import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS from lama_cleaner.const import MPS_UNSUPPORT_MODELS
from loguru import logger from loguru import logger
from torch.hub import download_url_to_file, get_dir from torch.hub import download_url_to_file, get_dir
import hashlib import hashlib
@ -23,7 +23,7 @@ def md5sum(filename):
def switch_mps_device(model_name, device): def switch_mps_device(model_name, device):
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps": if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
logger.info(f"{model_name} not support mps, switch to cpu") logger.info(f"{model_name} not support mps, switch to cpu")
return torch.device("cpu") return torch.device("cpu")
return device return device

View File

@ -0,0 +1,33 @@
from .controlnet import ControlNet
from .fcf import FcF
from .instruct_pix2pix import InstructPix2Pix
from .kandinsky import Kandinsky22
from .lama import LaMa
from .ldm import LDM
from .manga import Manga
from .mat import MAT
from .mi_gan import MIGAN
from .opencv2 import OpenCV2
from .paint_by_example import PaintByExample
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
from .sdxl import SDXL
from .zits import ZITS
models = {
LaMa.name: LaMa,
LDM.name: LDM,
ZITS.name: ZITS,
MAT.name: MAT,
FcF.name: FcF,
OpenCV2.name: OpenCV2,
Manga.name: Manga,
MIGAN.name: MIGAN,
SD15.name: SD15,
Anything4.name: Anything4,
RealisticVision14.name: RealisticVision14,
SD2.name: SD2,
PaintByExample.name: PaintByExample,
InstructPix2Pix.name: InstructPix2Pix,
Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL,
}

View File

@ -12,7 +12,7 @@ from lama_cleaner.helper import (
pad_img_to_modulo, pad_img_to_modulo,
switch_mps_device, switch_mps_device,
) )
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler from lama_cleaner.schema import Config, HDStrategy, SDSampler
@ -22,6 +22,7 @@ class InpaintModel:
min_size: Optional[int] = None min_size: Optional[int] = None
pad_mod = 8 pad_mod = 8
pad_to_square = False pad_to_square = False
is_erase_model = False
def __init__(self, device, **kwargs): def __init__(self, device, **kwargs):
""" """
@ -264,6 +265,12 @@ class InpaintModel:
class DiffusionInpaintModel(InpaintModel): class DiffusionInpaintModel(InpaintModel):
def __init__(self, device, **kwargs):
if kwargs.get("model_id_or_path"):
# 用于自定义 diffusers 模型
self.model_id_or_path = kwargs["model_id_or_path"]
super().__init__(device, **kwargs)
@torch.no_grad() @torch.no_grad()
def __call__(self, image, mask, config: Config): def __call__(self, image, mask, config: Config):
""" """

View File

@ -1,5 +1,3 @@
import gc
import PIL.Image import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
@ -7,107 +5,26 @@ import torch
from diffusers import ControlNetModel from diffusers import ControlNetModel
from loguru import logger from loguru import logger
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler from lama_cleaner.model.helper.controlnet_preprocess import (
from lama_cleaner.schema import Config make_canny_control_image,
make_openpose_control_image,
make_depth_control_image,
make_inpaint_control_image,
)
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, ModelInfo, ModelType
# 为了兼容性
class CPUTextEncoderWrapper(torch.nn.Module): controlnet_name_map = {
def __init__(self, text_encoder, torch_dtype): "control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny",
super().__init__() "control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose",
self.config = text_encoder.config "control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint",
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) "control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs):
input_device = x.device
return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype
NAMES_MAP = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
} }
NATIVE_NAMES_MAP = {
"sd1.5": "runwayml/stable-diffusion-v1-5",
"anything4": "andite/anything-v4.0",
"realisticVision1.4": "SG161222/Realistic_Vision_V1.4",
}
def make_inpaint_condition(image, image_mask):
"""
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
"""
image = image.astype(np.float32) / 255.0
image[image_mask[:, :, -1] > 128] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
def load_from_local_model(
local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint
):
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
try:
pipe = download_from_original_stable_diffusion_ckpt(
local_model_path,
num_in_channels=4 if is_native_control_inpaint else 9,
from_safetensors=local_model_path.endswith("safetensors"),
device="cpu",
load_safety_checker=False,
)
except Exception as e:
err_msg = str(e)
logger.exception(e)
if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg:
logger.error(
"control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model"
)
if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg:
logger.error(
f"{controlnet.config['_name_or_path']} method requires inpainting SD model, "
f"you can convert any SD model to inpainting model in AUTO1111: \n"
f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/"
)
exit(-1)
inpaint_pipe = pipe_class(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
controlnet=controlnet,
scheduler=pipe.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
del pipe
gc.collect()
return inpaint_pipe.to(torch_dtype=torch_dtype)
class ControlNet(DiffusionInpaintModel): class ControlNet(DiffusionInpaintModel):
name = "controlnet" name = "controlnet"
@ -116,10 +33,16 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False) fp16 = not kwargs.get("no_half", False)
model_info: ModelInfo = kwargs["model_info"]
sd_controlnet_method = kwargs["sd_controlnet_method"]
sd_controlnet_method = controlnet_name_map.get(
sd_controlnet_method, sd_controlnet_method
)
model_kwargs = { self.model_info = model_info
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) self.sd_controlnet_method = sd_controlnet_method
}
model_kwargs = {}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
@ -133,41 +56,39 @@ class ControlNet(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
sd_controlnet_method = kwargs["sd_controlnet_method"] if model_info.model_type in [
self.sd_controlnet_method = sd_controlnet_method ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT,
if sd_controlnet_method == "control_v11p_sd15_inpaint": ]:
from diffusers import StableDiffusionControlNetPipeline as PipeClass from diffusers import (
StableDiffusionControlNetInpaintPipeline as PipeClass,
self.is_native_control_inpaint = True )
else: elif model_info.model_type in [
from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
self.is_native_control_inpaint = False ]:
from diffusers import (
if self.is_native_control_inpaint: StableDiffusionXLControlNetInpaintPipeline as PipeClass,
model_id = NATIVE_NAMES_MAP[kwargs["name"]] )
else:
model_id = NAMES_MAP[kwargs["name"]]
controlnet = ControlNetModel.from_pretrained( controlnet = ControlNetModel.from_pretrained(
f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype sd_controlnet_method, torch_dtype=torch_dtype
) )
self.is_local_sd_model = False if model_info.is_single_file_diffusers:
if kwargs.get("sd_local_model_path", None): self.model = PipeClass.from_single_file(
self.is_local_sd_model = True model_info.path, controlnet=controlnet
self.model = load_from_local_model( ).to(torch_dtype)
kwargs["sd_local_model_path"],
torch_dtype=torch_dtype,
controlnet=controlnet,
pipe_class=PipeClass,
is_native_control_inpaint=self.is_native_control_inpaint,
)
else: else:
self.model = PipeClass.from_pretrained( self.model = PipeClass.from_pretrained(
model_id, model_info.path,
controlnet=controlnet, controlnet=controlnet,
revision="fp16" if use_gpu and fp16 else "main", revision="fp16"
if (
model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
and use_gpu
and fp16
)
else "main",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
**model_kwargs, **model_kwargs,
) )
@ -191,6 +112,19 @@ class ControlNet(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
def _get_control_image(self, image, mask):
if "canny" in self.sd_controlnet_method:
control_image = make_canny_control_image(image)
elif "openpose" in self.sd_controlnet_method:
control_image = make_openpose_control_image(image)
elif "depth" in self.sd_controlnet_method:
control_image = make_depth_control_image(image)
elif "inpaint" in self.sd_controlnet_method:
control_image = make_inpaint_control_image(image, mask)
else:
raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
return control_image
def forward(self, image, mask, config: Config): def forward(self, image, mask, config: Config):
"""Input image and output image have same size """Input image and output image have same size
image: [H, W, C] RGB image: [H, W, C] RGB
@ -206,84 +140,30 @@ class ControlNet(DiffusionInpaintModel):
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
img_h, img_w = image.shape[:2] img_h, img_w = image.shape[:2]
control_image = self._get_control_image(image, mask)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image)
if self.is_native_control_inpaint: output = self.model(
control_image = make_inpaint_condition(image, mask) image=image,
output = self.model( mask_image=mask_image,
prompt=config.prompt, control_image=control_image,
image=control_image, prompt=config.prompt,
height=img_h, negative_prompt=config.negative_prompt,
width=img_w, num_inference_steps=config.sd_steps,
num_inference_steps=config.sd_steps, guidance_scale=config.sd_guidance_scale,
guidance_scale=config.sd_guidance_scale, output_type="np",
controlnet_conditioning_scale=config.controlnet_conditioning_scale, callback=self.callback,
negative_prompt=config.negative_prompt, height=img_h,
generator=torch.manual_seed(config.sd_seed), width=img_w,
output_type="np", generator=torch.manual_seed(config.sd_seed),
callback=self.callback, controlnet_conditioning_scale=config.controlnet_conditioning_scale,
).images[0] ).images[0]
else:
if "canny" in self.sd_controlnet_method:
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate(
[canny_image, canny_image, canny_image], axis=2
)
canny_image = PIL.Image.fromarray(canny_image)
control_image = canny_image
elif "openpose" in self.sd_controlnet_method:
from controlnet_aux import OpenposeDetector
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
control_image = processor(image, hand_and_face=True)
elif "depth" in self.sd_controlnet_method:
from transformers import pipeline
depth_estimator = pipeline("depth-estimation")
depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
depth_image = np.array(depth_image)
depth_image = depth_image[:, :, None]
depth_image = np.concatenate(
[depth_image, depth_image, depth_image], axis=2
)
control_image = PIL.Image.fromarray(depth_image)
else:
raise NotImplementedError(
f"{self.sd_controlnet_method} not implemented"
)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image)
output = self.model(
image=image,
control_image=control_image,
prompt=config.prompt,
negative_prompt=config.negative_prompt,
mask_image=mask_image,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output return output
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
@staticmethod @staticmethod
def is_downloaded() -> bool: def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings # model will be downloaded when app start, and can't switch in frontend settings

View File

@ -1626,6 +1626,7 @@ class FcF(InpaintModel):
min_size = 512 min_size = 512
pad_mod = 512 pad_mod = 512
pad_to_square = True pad_to_square = True
is_erase_model = True
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
seed = 0 seed = 0

View File

@ -0,0 +1,46 @@
import torch
import PIL
import cv2
from PIL import Image
import numpy as np
def make_canny_control_image(image: np.ndarray) -> Image:
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
canny_image = PIL.Image.fromarray(canny_image)
control_image = canny_image
return control_image
def make_openpose_control_image(image: np.ndarray) -> Image:
from controlnet_aux import OpenposeDetector
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
control_image = processor(image, hand_and_face=True)
return control_image
def make_depth_control_image(image: np.ndarray) -> Image:
from transformers import pipeline
depth_estimator = pipeline("depth-estimation")
depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
depth_image = np.array(depth_image)
depth_image = depth_image[:, :, None]
depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
control_image = PIL.Image.fromarray(depth_image)
return control_image
def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
"""
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
"""
image = image.astype(np.float32) / 255.0
image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image

View File

@ -0,0 +1,25 @@
import torch
from lama_cleaner.model.utils import torch_gc
class CPUTextEncoderWrapper(torch.nn.Module):
def __init__(self, text_encoder, torch_dtype):
super().__init__()
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs):
input_device = x.device
return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype

View File

@ -17,7 +17,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
fp16 = not kwargs.get("no_half", False) fp16 = not kwargs.get("no_half", False)
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)} model_kwargs = {}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
@ -77,16 +77,6 @@ class InstructPix2Pix(DiffusionInpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output return output
#
# def forward_post_process(self, result, image, mask, config):
# if config.sd_match_histograms:
# result = self._match_histograms(result, image[:, :, ::-1], mask)
#
# if config.sd_mask_blur != 0:
# k = 2 * config.sd_mask_blur + 1
# mask = cv2.GaussianBlur(mask, (k, k), 0)
# return result, image, mask
@staticmethod @staticmethod
def is_downloaded() -> bool: def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings # model will be downloaded when app start, and can't switch in frontend settings

View File

@ -20,7 +20,6 @@ class Kandinsky(DiffusionInpaintModel):
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = { model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]),
"torch_dtype": torch_dtype, "torch_dtype": torch_dtype,
} }

View File

@ -23,6 +23,7 @@ LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e
class LaMa(InpaintModel): class LaMa(InpaintModel):
name = "lama" name = "lama"
pad_mod = 8 pad_mod = 8
is_erase_model = True
@staticmethod @staticmethod
def download(): def download():

View File

@ -237,6 +237,7 @@ class LatentDiffusion(DDPM):
class LDM(InpaintModel): class LDM(InpaintModel):
name = "ldm" name = "ldm"
pad_mod = 32 pad_mod = 32
is_erase_model = True
def __init__(self, device, fp16: bool = True, **kwargs): def __init__(self, device, fp16: bool = True, **kwargs):
self.fp16 = fp16 self.fp16 = fp16

View File

@ -32,6 +32,7 @@ MANGA_LINE_MODEL_MD5 = os.environ.get(
class Manga(InpaintModel): class Manga(InpaintModel):
name = "manga" name = "manga"
pad_mod = 16 pad_mod = 16
is_erase_model = True
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
self.inpaintor_model = load_jit_model( self.inpaintor_model = load_jit_model(

View File

@ -1880,6 +1880,7 @@ class MAT(InpaintModel):
min_size = 512 min_size = 512
pad_mod = 512 pad_mod = 512
pad_to_square = True pad_to_square = True
is_erase_model = True
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
seed = 240 # pick up a random number seed = 240 # pick up a random number

View File

@ -26,6 +26,7 @@ class MIGAN(InpaintModel):
min_size = 512 min_size = 512
pad_mod = 512 pad_mod = 512
pad_to_square = True pad_to_square = True
is_erase_model = True
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval() self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()

View File

@ -8,6 +8,7 @@ flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
class OpenCV2(InpaintModel): class OpenCV2(InpaintModel):
name = "cv2" name = "cv2"
pad_mod = 1 pad_mod = 1
is_erase_model = True
@staticmethod @staticmethod
def is_downloaded() -> bool: def is_downloaded() -> bool:

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
from typing import Union, List, Optional, Callable, Dict, Any from typing import Union, List, Optional, Callable, Dict, Any
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py # Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
@ -217,6 +218,38 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
controlnet = kwargs.pop("controlnet", None)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
num_in_channels=9,
from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
device="cpu",
load_safety_checker=False,
)
inpaint_pipe = cls(
vae=pipe.vae,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
unet=pipe.unet,
controlnet=controlnet,
scheduler=pipe.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
del pipe
gc.collect()
return inpaint_pipe
def prepare_mask_latents( def prepare_mask_latents(
self, self,
mask, mask,

View File

@ -1,4 +1,4 @@
import gc import os
import PIL.Image import PIL.Image
import cv2 import cv2
@ -6,34 +6,12 @@ import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
class CPUTextEncoderWrapper(torch.nn.Module):
def __init__(self, text_encoder, torch_dtype):
super().__init__()
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs):
input_device = x.device
return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype
class SD(DiffusionInpaintModel): class SD(DiffusionInpaintModel):
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
@ -44,9 +22,7 @@ class SD(DiffusionInpaintModel):
fp16 = not kwargs.get("no_half", False) fp16 = not kwargs.get("no_half", False)
model_kwargs = { model_kwargs = {}
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
@ -60,14 +36,20 @@ class SD(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
if kwargs.get("sd_local_model_path", None): if os.path.isfile(self.model_id_or_path):
self.model = StableDiffusionInpaintPipeline.from_single_file( self.model = StableDiffusionInpaintPipeline.from_single_file(
kwargs["sd_local_model_path"], torch_dtype=torch_dtype, **model_kwargs self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
) )
else: else:
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
revision="fp16" if use_gpu and fp16 else "main", revision="fp16"
if (
self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
and use_gpu
and fp16
)
else "main",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs, **model_kwargs,

View File

@ -1,7 +1,10 @@
import os
import PIL.Image import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from diffusers import AutoencoderKL
from loguru import logger from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.base import DiffusionInpaintModel
@ -13,26 +16,31 @@ class SDXL(DiffusionInpaintModel):
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdxl" lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from diffusers.pipelines import AutoPipelineForInpainting from diffusers.pipelines import StableDiffusionXLInpaintPipeline
fp16 = not kwargs.get("no_half", False) fp16 = not kwargs.get("no_half", False)
model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
}
use_gpu = device == torch.device("cuda") and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = AutoPipelineForInpainting.from_pretrained( if os.path.isfile(self.model_id_or_path):
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1", self.model = StableDiffusionXLInpaintPipeline.from_single_file(
revision="main", self.model_id_or_path, torch_dtype=torch_dtype
torch_dtype=torch_dtype, )
use_auth_token=kwargs["hf_access_token"], else:
**model_kwargs, vae = AutoencoderKL.from_pretrained(
) "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="main",
torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"],
vae=vae,
)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()

View File

@ -226,6 +226,7 @@ class ZITS(InpaintModel):
min_size = 256 min_size = 256
pad_mod = 32 pad_mod = 32
pad_to_square = True pad_to_square = True
is_erase_model = True
def __init__(self, device, **kwargs): def __init__(self, device, **kwargs):
""" """

View File

@ -1,49 +1,14 @@
import torch
import gc import gc
from typing import List, Dict
import torch
from loguru import logger from loguru import logger
from lama_cleaner.const import ( from lama_cleaner.download import scan_models
SD15_MODELS,
MODELS_SUPPORT_FREEU,
MODELS_SUPPORT_LCM_LORA,
)
from lama_cleaner.helper import switch_mps_device from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet from lama_cleaner.model import models, ControlNet, SD, SDXL
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.kandinsky import Kandinsky22
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.mi_gan import MIGAN
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
from lama_cleaner.model.sdxl import SDXL
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.zits import ZITS from lama_cleaner.schema import Config, ModelInfo, ModelType
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config
models = {
"lama": LaMa,
"ldm": LDM,
"zits": ZITS,
"mat": MAT,
"fcf": FcF,
SD15.name: SD15,
Anything4.name: Anything4,
RealisticVision14.name: RealisticVision14,
"cv2": OpenCV2,
"manga": Manga,
"sd2": SD2,
"paint_by_example": PaintByExample,
"instruct_pix2pix": InstructPix2Pix,
Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL,
MIGAN.name: MIGAN,
}
class ModelManager: class ModelManager:
@ -51,23 +16,39 @@ class ModelManager:
self.name = name self.name = name
self.device = device self.device = device
self.kwargs = kwargs self.kwargs = kwargs
self.available_models: Dict[str, ModelInfo] = {}
self.scan_models()
self.model = self.init_model(name, device, **kwargs) self.model = self.init_model(name, device, **kwargs)
def init_model(self, name: str, device, **kwargs): def init_model(self, name: str, device, **kwargs):
if name in SD15_MODELS and kwargs.get("sd_controlnet", False): for old_name, model_cls in models.items():
return ControlNet(device, **{**kwargs, "name": name}) if name == old_name and hasattr(model_cls, "model_id_or_path"):
name = model_cls.model_id_or_path
if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}")
if name in models: sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
model = models[name](device, **kwargs) model_info = self.available_models[name]
else: if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
raise NotImplementedError(f"Not supported model: {name}") return models[name](device, **kwargs)
return model
def is_downloaded(self, name: str) -> bool: if sd_controlnet_enabled:
if name in models: return ControlNet(device, **{**kwargs, "model_info": model_info})
return models[name].is_downloaded()
else: else:
raise NotImplementedError(f"Not supported model: {name}") if model_info.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
]:
raise NotImplementedError(
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
)
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
return SD(device, model_id_or_path=model_info.path, **kwargs)
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}")
def __call__(self, image, mask, config: Config): def __call__(self, image, mask, config: Config):
self.switch_controlnet_method(control_method=config.controlnet_method) self.switch_controlnet_method(control_method=config.controlnet_method)
@ -75,9 +56,18 @@ class ModelManager:
self.enable_disable_lcm_lora(config) self.enable_disable_lcm_lora(config)
return self.model(image, mask, config) return self.model(image, mask, config)
def switch(self, new_name: str, **kwargs): def scan_models(self) -> List[ModelInfo]:
available_models = scan_models()
self.available_models = {it.name: it for it in available_models}
return available_models
def switch(self, new_name: str):
if new_name == self.name: if new_name == self.name:
return return
old_name = self.name
self.name = new_name
try: try:
if torch.cuda.memory_allocated() > 0: if torch.cuda.memory_allocated() > 0:
# Clear current loaded model from memory # Clear current loaded model from memory
@ -88,8 +78,8 @@ class ModelManager:
self.model = self.init_model( self.model = self.init_model(
new_name, switch_mps_device(new_name, self.device), **self.kwargs new_name, switch_mps_device(new_name, self.device), **self.kwargs
) )
self.name = new_name except Exception as e:
except NotImplementedError as e: self.name = old_name
raise e raise e
def switch_controlnet_method(self, control_method: str): def switch_controlnet_method(self, control_method: str):
@ -97,27 +87,9 @@ class ModelManager:
return return
if self.kwargs["sd_controlnet_method"] == control_method: if self.kwargs["sd_controlnet_method"] == control_method:
return return
if not hasattr(self.model, "is_local_sd_model"):
return
if self.model.is_local_sd_model: if not self.available_models[self.name].support_controlnet():
# is_native_control_inpaint 表示加载了普通 SD 模型 return
if (
self.model.is_native_control_inpaint
and control_method != "control_v11p_sd15_inpaint"
):
raise RuntimeError(
f"--sd-local-model-path load a normal SD model, "
f"to use {control_method} you should load an inpainting SD model"
)
elif (
not self.model.is_native_control_inpaint
and control_method == "control_v11p_sd15_inpaint"
):
raise RuntimeError(
f"--sd-local-model-path load an inpainting SD model, "
f"to use {control_method} you should load a norml SD model"
)
del self.model del self.model
torch_gc() torch_gc()
@ -133,7 +105,7 @@ class ModelManager:
if str(self.model.device) == "mps": if str(self.model.device) == "mps":
return return
if self.name in MODELS_SUPPORT_FREEU: if self.available_models[self.name].support_freeu():
if config.sd_freeu: if config.sd_freeu:
freeu_config = config.sd_freeu_config freeu_config = config.sd_freeu_config
self.model.model.enable_freeu( self.model.model.enable_freeu(
@ -146,7 +118,7 @@ class ModelManager:
self.model.model.disable_freeu() self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config): def enable_disable_lcm_lora(self, config: Config):
if self.name in MODELS_SUPPORT_LCM_LORA: if self.available_models[self.name].support_lcm_lora():
if config.sd_lcm_lora: if config.sd_lcm_lora:
if not self.model.model.pipe.get_list_adapters(): if not self.model.model.pipe.get_list_adapters():
self.model.model.load_lora_weights(self.model.lcm_lora_id) self.model.model.load_lora_weights(self.model.lcm_lora_id)

View File

@ -6,7 +6,7 @@ from pathlib import Path
from loguru import logger from loguru import logger
from lama_cleaner.const import * from lama_cleaner.const import *
from lama_cleaner.download import cli_download_model from lama_cleaner.download import cli_download_model, scan_models
from lama_cleaner.runtime import dump_environment_info from lama_cleaner.runtime import dump_environment_info
DOWNLOAD_SUBCOMMAND = "download" DOWNLOAD_SUBCOMMAND = "download"
@ -46,7 +46,11 @@ def parse_args():
"--installer-config", default=None, help="Config file for windows installer" "--installer-config", default=None, help="Config file for windows installer"
) )
parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS) parser.add_argument(
"--model",
default=DEFAULT_MODEL,
help=f"Available models: [{', '.join(AVAILABLE_MODELS)}], or model id on huggingface",
)
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
@ -56,10 +60,9 @@ def parse_args():
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
parser.add_argument( parser.add_argument(
"--sd-controlnet-method", "--sd-controlnet-method",
default=DEFAULT_CONTROLNET_METHOD, default=DEFAULT_SD_CONTROLNET_METHOD,
choices=SD_CONTROLNET_CHOICES, choices=SD_CONTROLNET_CHOICES,
) )
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
parser.add_argument( parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
) )
@ -170,7 +173,8 @@ def parse_args():
) )
######### #########
# useless args ### useless args ###
parser.add_argument("--sd-local-model-path", default=None, help=argparse.SUPPRESS)
parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS) parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS) parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
parser.add_argument( parser.add_argument(
@ -180,6 +184,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--sd-enable-xformers", action="store_true", help=argparse.SUPPRESS "--sd-enable-xformers", action="store_true", help=argparse.SUPPRESS
) )
### end useless args ###
args = parser.parse_args() args = parser.parse_args()
# collect system info to help debug # collect system info to help debug
@ -251,6 +256,17 @@ def parse_args():
os.environ["XDG_CACHE_HOME"] = args.model_dir os.environ["XDG_CACHE_HOME"] = args.model_dir
os.environ["U2NET_HOME"] = args.model_dir os.environ["U2NET_HOME"] = args.model_dir
if args.sd_run_local or args.local_files_only:
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
if args.model not in AVAILABLE_MODELS:
scanned_models = scan_models()
if args.model not in [it.name for it in scanned_models]:
parser.error(
f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {scanned_models}"
)
if args.input and args.input is not None: if args.input and args.input is not None:
if not os.path.exists(args.input): if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists") parser.error(f"invalid --input: {args.input} not exists")

View File

@ -4,6 +4,61 @@ from enum import Enum
from PIL.Image import Image from PIL.Image import Image
from pydantic import BaseModel from pydantic import BaseModel
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
class ModelType(str, Enum):
INPAINT = "inpaint" # LaMa, MAT...
DIFFUSERS_SD = "diffusers_sd"
DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
DIFFUSERS_SDXL = "diffusers_sdxl"
DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
DIFFUSERS_OTHER = "diffusers_other"
FREEU_DEFAULT_CONFIGS = {
ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
}
class ModelInfo(BaseModel):
name: str
path: str
model_type: ModelType
is_single_file_diffusers: bool = False
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
def support_freeu(self) -> bool:
return (
self.model_type
in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
or "instruct-pix2pix" in self.name
)
class HDStrategy(str, Enum): class HDStrategy(str, Enum):
# Use original image size # Use original image size

View File

@ -2,8 +2,6 @@
import os import os
import hashlib import hashlib
from lama_cleaner.diffusers_utils import scan_models
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import imghdr import imghdr
@ -22,9 +20,9 @@ from loguru import logger
from lama_cleaner.const import ( from lama_cleaner.const import (
SD15_MODELS, SD15_MODELS,
FREEU_DEFAULT_CONFIGS, SD_CONTROLNET_CHOICES,
MODELS_SUPPORT_FREEU, SDXL_CONTROLNET_CHOICES,
MODELS_SUPPORT_LCM_LORA, SD2_CONTROLNET_CHOICES,
) )
from lama_cleaner.file_manager import FileManager from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
@ -118,8 +116,8 @@ input_image_path: str = None
is_disable_model_switch: bool = False is_disable_model_switch: bool = False
is_controlnet: bool = False is_controlnet: bool = False
controlnet_method: str = "control_v11p_sd15_canny" controlnet_method: str = "control_v11p_sd15_canny"
is_enable_file_manager: bool = False enable_file_manager: bool = False
is_enable_auto_saving: bool = False enable_auto_saving: bool = False
is_desktop: bool = False is_desktop: bool = False
image_quality: int = 95 image_quality: int = 95
plugins = {} plugins = {}
@ -421,34 +419,35 @@ def run_plugin():
@app.route("/server_config", methods=["GET"]) @app.route("/server_config", methods=["GET"])
def get_server_config(): def get_server_config():
controlnet = {
"SD": SD_CONTROLNET_CHOICES,
"SD2": SD2_CONTROLNET_CHOICES,
"SDXL": SDXL_CONTROLNET_CHOICES,
}
return { return {
"isControlNet": is_controlnet,
"controlNetMethod": controlnet_method,
"isDisableModelSwitchState": is_disable_model_switch,
"isEnableAutoSaving": is_enable_auto_saving,
"enableFileManager": is_enable_file_manager,
"plugins": list(plugins.keys()), "plugins": list(plugins.keys()),
"freeSupportedModels": MODELS_SUPPORT_FREEU, "availableControlNet": controlnet,
"freeuDefaultConfigs": FREEU_DEFAULT_CONFIGS, "enableFileManager": enable_file_manager,
"lcmLoraSupportedModels": MODELS_SUPPORT_LCM_LORA, "enableAutoSaving": enable_auto_saving,
}, 200 }, 200
@app.route("/sd_models", methods=["GET"]) @app.route("/models", methods=["GET"])
def get_diffusers_models(): def get_models():
from diffusers.utils import DIFFUSERS_CACHE return [
{
return scan_models(DIFFUSERS_CACHE) **it.dict(),
"support_lcm_lora": it.support_lcm_lora(),
"support_controlnet": it.support_controlnet(),
"support_freeu": it.support_freeu(),
}
for it in model.scan_models()
]
@app.route("/model") @app.route("/model")
def current_model(): def current_model():
return model.name, 200 return model.available_models[model.name].dict(), 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/is_desktop") @app.route("/is_desktop")
@ -467,8 +466,10 @@ def switch_model():
try: try:
model.switch(new_name) model.switch(new_name)
except NotImplementedError: except Exception as e:
return f"{new_name} not implemented", 403 error_message = str(e)
logger.error(error_message)
return f"Switch model failed: {error_message}", 500
return f"ok, switch to {new_name}", 200 return f"ok, switch to {new_name}", 200
@ -478,7 +479,7 @@ def index():
@app.route("/inputimage") @app.route("/inputimage")
def set_input_photo(): def get_cli_input_image():
if input_image_path: if input_image_path:
with open(input_image_path, "rb") as f: with open(input_image_path, "rb") as f:
image_in_bytes = f.read() image_in_bytes = f.read()
@ -547,11 +548,10 @@ def main(args):
global device global device
global input_image_path global input_image_path
global is_disable_model_switch global is_disable_model_switch
global is_enable_file_manager global enable_file_manager
global is_desktop global is_desktop
global thumb global thumb
global output_dir global output_dir
global is_enable_auto_saving
global is_controlnet global is_controlnet
global controlnet_method global controlnet_method
global image_quality global image_quality
@ -566,7 +566,9 @@ def main(args):
output_dir = args.output_dir output_dir = args.output_dir
if output_dir: if output_dir:
is_enable_auto_saving = True output_dir = os.path.abspath(output_dir)
logger.info(f"Output dir: {output_dir}")
enable_auto_saving = True
device = torch.device(args.device) device = torch.device(args.device)
is_disable_model_switch = args.disable_model_switch is_disable_model_switch = args.disable_model_switch
@ -579,12 +581,12 @@ def main(args):
if args.input and os.path.isdir(args.input): if args.input and os.path.isdir(args.input):
logger.info(f"Initialize file manager") logger.info(f"Initialize file manager")
thumb = FileManager(app) thumb = FileManager(app)
is_enable_file_manager = True enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
args.output_dir, "lama_cleaner_thumbnails" output_dir, "lama_cleaner_thumbnails"
) )
thumb.output_dir = Path(args.output_dir) thumb.output_dir = Path(output_dir)
# thumb.start() # thumb.start()
# try: # try:
# while True: # while True:

View File

@ -9,7 +9,9 @@
"version": "0.0.0", "version": "0.0.0",
"dependencies": { "dependencies": {
"@heroicons/react": "^2.0.18", "@heroicons/react": "^2.0.18",
"@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",
@ -17,6 +19,7 @@
"@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0", "@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-separator": "^1.0.3",
"@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slider": "^1.1.2",
"@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-switch": "^1.0.3", "@radix-ui/react-switch": "^1.0.3",
@ -24,7 +27,9 @@
"@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-toast": "^1.1.5",
"@radix-ui/react-toggle": "^1.0.3", "@radix-ui/react-toggle": "^1.0.3",
"@radix-ui/react-tooltip": "^1.0.7", "@radix-ui/react-tooltip": "^1.0.7",
"@tanstack/react-query": "^5.8.7",
"@uidotdev/usehooks": "^2.4.1", "@uidotdev/usehooks": "^2.4.1",
"axios": "^1.6.2",
"class-variance-authority": "^0.7.0", "class-variance-authority": "^0.7.0",
"clsx": "^2.0.0", "clsx": "^2.0.0",
"flexsearch": "^0.7.21", "flexsearch": "^0.7.21",
@ -35,6 +40,7 @@
"next-themes": "^0.2.1", "next-themes": "^0.2.1",
"react": "^18.2.0", "react": "^18.2.0",
"react-dom": "^18.2.0", "react-dom": "^18.2.0",
"react-hook-form": "^7.48.2",
"react-hotkeys-hook": "^4.4.1", "react-hotkeys-hook": "^4.4.1",
"react-photo-album": "^2.3.0", "react-photo-album": "^2.3.0",
"react-use": "^17.4.0", "react-use": "^17.4.0",
@ -42,9 +48,12 @@
"recoil": "^0.7.7", "recoil": "^0.7.7",
"tailwind-merge": "^2.0.0", "tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
"zustand": "^4.4.6" "zustand": "^4.4.6"
}, },
"devDependencies": { "devDependencies": {
"@tanstack/eslint-plugin-query": "^5.8.4",
"@types/axios": "^0.14.0",
"@types/flexsearch": "^0.7.3", "@types/flexsearch": "^0.7.3",
"@types/lodash": "^4.14.201", "@types/lodash": "^4.14.201",
"@types/node": "^20.9.2", "@types/node": "^20.9.2",
@ -1069,6 +1078,14 @@
"react": ">= 16" "react": ">= 16"
} }
}, },
"node_modules/@hookform/resolvers": {
"version": "3.3.2",
"resolved": "https://registry.npmjs.org/@hookform/resolvers/-/resolvers-3.3.2.tgz",
"integrity": "sha512-Tw+GGPnBp+5DOsSg4ek3LCPgkBOuOgS5DsDV7qsWNH9LZc433kgsWICjlsh2J9p04H2K66hsXPPb9qn9ILdUtA==",
"peerDependencies": {
"react-hook-form": "^7.0.0"
}
},
"node_modules/@humanwhocodes/config-array": { "node_modules/@humanwhocodes/config-array": {
"version": "0.11.13", "version": "0.11.13",
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz",
@ -1374,6 +1391,34 @@
} }
} }
}, },
"node_modules/@radix-ui/react-alert-dialog": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.0.5.tgz",
"integrity": "sha512-OrVIOcZL0tl6xibeuGt5/+UxoT2N27KCFOPjFyfXMnchxSHZ/OW7cCX2nGlIYJrbHK/fczPcFzAwvNBB6XBNMA==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/primitive": "1.0.1",
"@radix-ui/react-compose-refs": "1.0.1",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-dialog": "1.0.5",
"@radix-ui/react-primitive": "1.0.3",
"@radix-ui/react-slot": "1.0.2"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-arrow": { "node_modules/@radix-ui/react-arrow": {
"version": "1.0.3", "version": "1.0.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.0.3.tgz", "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.0.3.tgz",
@ -1971,6 +2016,29 @@
} }
} }
}, },
"node_modules/@radix-ui/react-separator": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.0.3.tgz",
"integrity": "sha512-itYmTy/kokS21aiV5+Z56MZB54KrhPgn6eHDKkFeOLR34HMN2s8PaN47qZZAGnvupcjxHaFZnW4pQEh0BvvVuw==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-primitive": "1.0.3"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider": { "node_modules/@radix-ui/react-slider": {
"version": "1.1.2", "version": "1.1.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.1.2.tgz", "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.1.2.tgz",
@ -2703,6 +2771,188 @@
"integrity": "sha512-myfUej5naTBWnqOCc/MdVOLVjXUXtIA+NpDrDBKJtLLg2shUjBu3cZmB/85RyitKc55+lUUyl7oRfLOvkr2hsw==", "integrity": "sha512-myfUej5naTBWnqOCc/MdVOLVjXUXtIA+NpDrDBKJtLLg2shUjBu3cZmB/85RyitKc55+lUUyl7oRfLOvkr2hsw==",
"dev": true "dev": true
}, },
"node_modules/@tanstack/eslint-plugin-query": {
"version": "5.8.4",
"resolved": "https://registry.npmjs.org/@tanstack/eslint-plugin-query/-/eslint-plugin-query-5.8.4.tgz",
"integrity": "sha512-KVgcMc+Bn1qbwkxYVWQoiVSNEIN4IAiLj3cUH/SAHT8m8E59Y97o8ON1syp0Rcw094ItG8pEVZFyQuOaH6PDgQ==",
"dev": true,
"dependencies": {
"@typescript-eslint/utils": "^5.54.0"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
},
"peerDependencies": {
"eslint": "^8.0.0"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/scope-manager": {
"version": "5.62.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.62.0.tgz",
"integrity": "sha512-VXuvVvZeQCQb5Zgf4HAxc04q5j+WrNAtNh9OwCsCgpKqESMTu3tF/jhZ3xG6T4NZwWl65Bg8KuS2uEvhSfLl0w==",
"dev": true,
"dependencies": {
"@typescript-eslint/types": "5.62.0",
"@typescript-eslint/visitor-keys": "5.62.0"
},
"engines": {
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/typescript-eslint"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/types": {
"version": "5.62.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.62.0.tgz",
"integrity": "sha512-87NVngcbVXUahrRTqIK27gD2t5Cu1yuCXxbLcFtCzZGlfyVWWh8mLHkoxzjsB6DDNnvdL+fW8MiwPEJyGJQDgQ==",
"dev": true,
"engines": {
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/typescript-eslint"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/typescript-estree": {
"version": "5.62.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.62.0.tgz",
"integrity": "sha512-CmcQ6uY7b9y694lKdRB8FEel7JbU/40iSAPomu++SjLMntB+2Leay2LO6i8VnJk58MtE9/nQSFIH6jpyRWyYzA==",
"dev": true,
"dependencies": {
"@typescript-eslint/types": "5.62.0",
"@typescript-eslint/visitor-keys": "5.62.0",
"debug": "^4.3.4",
"globby": "^11.1.0",
"is-glob": "^4.0.3",
"semver": "^7.3.7",
"tsutils": "^3.21.0"
},
"engines": {
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/typescript-eslint"
},
"peerDependenciesMeta": {
"typescript": {
"optional": true
}
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/utils": {
"version": "5.62.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.62.0.tgz",
"integrity": "sha512-n8oxjeb5aIbPFEtmQxQYOLI0i9n5ySBEY/ZEHHZqKQSFnxio1rv6dthascc9dLuwrL0RC5mPCxB7vnAVGAYWAQ==",
"dev": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@types/json-schema": "^7.0.9",
"@types/semver": "^7.3.12",
"@typescript-eslint/scope-manager": "5.62.0",
"@typescript-eslint/types": "5.62.0",
"@typescript-eslint/typescript-estree": "5.62.0",
"eslint-scope": "^5.1.1",
"semver": "^7.3.7"
},
"engines": {
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/typescript-eslint"
},
"peerDependencies": {
"eslint": "^6.0.0 || ^7.0.0 || ^8.0.0"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/@typescript-eslint/visitor-keys": {
"version": "5.62.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.62.0.tgz",
"integrity": "sha512-07ny+LHRzQXepkGg6w0mFY41fVUNBrL2Roj/++7V1txKugfjm/Ci/qSND03r2RhlJhJYMcTn9AhhSSqQp0Ysyw==",
"dev": true,
"dependencies": {
"@typescript-eslint/types": "5.62.0",
"eslint-visitor-keys": "^3.3.0"
},
"engines": {
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/typescript-eslint"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/eslint-scope": {
"version": "5.1.1",
"resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz",
"integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==",
"dev": true,
"dependencies": {
"esrecurse": "^4.3.0",
"estraverse": "^4.1.1"
},
"engines": {
"node": ">=8.0.0"
}
},
"node_modules/@tanstack/eslint-plugin-query/node_modules/estraverse": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz",
"integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==",
"dev": true,
"engines": {
"node": ">=4.0"
}
},
"node_modules/@tanstack/query-core": {
"version": "5.8.7",
"resolved": "https://registry.npmjs.org/@tanstack/query-core/-/query-core-5.8.7.tgz",
"integrity": "sha512-58xOSkxxZK4SGQ/uzX8MDZHLGZCkxlgkPxnfhxUOL2uchnNHyay2UVcR3mQNMgaMwH1e2l+0n+zfS7+UJ/MAJw==",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
}
},
"node_modules/@tanstack/react-query": {
"version": "5.8.7",
"resolved": "https://registry.npmjs.org/@tanstack/react-query/-/react-query-5.8.7.tgz",
"integrity": "sha512-RYSSMmkhbJ7tPkf8w+MSRIXQLoUCm7DRnTLDcdf+uampupnriEsob3fVWTt9oaEj+AJWEKeCErDBdZeNcAzURQ==",
"dependencies": {
"@tanstack/query-core": "5.8.7"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
},
"peerDependencies": {
"react": "^18.0.0",
"react-dom": "^18.0.0",
"react-native": "*"
},
"peerDependenciesMeta": {
"react-dom": {
"optional": true
},
"react-native": {
"optional": true
}
}
},
"node_modules/@types/axios": {
"version": "0.14.0",
"resolved": "https://registry.npmjs.org/@types/axios/-/axios-0.14.0.tgz",
"integrity": "sha512-KqQnQbdYE54D7oa/UmYVMZKq7CO4l8DEENzOKc4aBRwxCXSlJXGz83flFx5L7AWrOQnmuN3kVsRdt+GZPPjiVQ==",
"deprecated": "This is a stub types definition for axios (https://github.com/mzabriskie/axios). axios provides its own type definitions, so you don't need @types/axios installed!",
"dev": true,
"dependencies": {
"axios": "*"
}
},
"node_modules/@types/babel__core": { "node_modules/@types/babel__core": {
"version": "7.20.4", "version": "7.20.4",
"resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.4.tgz", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.4.tgz",
@ -3166,6 +3416,11 @@
"node": ">=8" "node": ">=8"
} }
}, },
"node_modules/asynckit": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q=="
},
"node_modules/autoprefixer": { "node_modules/autoprefixer": {
"version": "10.4.16", "version": "10.4.16",
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz",
@ -3203,6 +3458,16 @@
"postcss": "^8.1.0" "postcss": "^8.1.0"
} }
}, },
"node_modules/axios": {
"version": "1.6.2",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.2.tgz",
"integrity": "sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==",
"dependencies": {
"follow-redirects": "^1.15.0",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
},
"node_modules/balanced-match": { "node_modules/balanced-match": {
"version": "1.0.2", "version": "1.0.2",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
@ -3412,6 +3677,17 @@
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==",
"dev": true "dev": true
}, },
"node_modules/combined-stream": {
"version": "1.0.8",
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
"dependencies": {
"delayed-stream": "~1.0.0"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/commander": { "node_modules/commander": {
"version": "4.1.1", "version": "4.1.1",
"resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz",
@ -3512,6 +3788,14 @@
"integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==",
"dev": true "dev": true
}, },
"node_modules/delayed-stream": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/detect-node-es": { "node_modules/detect-node-es": {
"version": "1.1.0", "version": "1.1.0",
"resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz",
@ -3916,6 +4200,38 @@
"resolved": "https://registry.npmjs.org/flexsearch/-/flexsearch-0.7.31.tgz", "resolved": "https://registry.npmjs.org/flexsearch/-/flexsearch-0.7.31.tgz",
"integrity": "sha512-XGozTsMPYkm+6b5QL3Z9wQcJjNYxp0CYn3U1gO7dwD6PAqU1SVWZxI9CCg3z+ml3YfqdPnrBehaBrnH2AGKbNA==" "integrity": "sha512-XGozTsMPYkm+6b5QL3Z9wQcJjNYxp0CYn3U1gO7dwD6PAqU1SVWZxI9CCg3z+ml3YfqdPnrBehaBrnH2AGKbNA=="
}, },
"node_modules/follow-redirects": {
"version": "1.15.3",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz",
"integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==",
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/form-data": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz",
"integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==",
"dependencies": {
"asynckit": "^0.4.0",
"combined-stream": "^1.0.8",
"mime-types": "^2.1.12"
},
"engines": {
"node": ">= 6"
}
},
"node_modules/fraction.js": { "node_modules/fraction.js": {
"version": "4.3.7", "version": "4.3.7",
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz",
@ -4413,6 +4729,25 @@
"node": ">=8.6" "node": ">=8.6"
} }
}, },
"node_modules/mime-db": {
"version": "1.52.0",
"resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
"integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
"engines": {
"node": ">= 0.6"
}
},
"node_modules/mime-types": {
"version": "2.1.35",
"resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
"dependencies": {
"mime-db": "1.52.0"
},
"engines": {
"node": ">= 0.6"
}
},
"node_modules/minimatch": { "node_modules/minimatch": {
"version": "3.1.2", "version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@ -4868,6 +5203,11 @@
"node": ">= 0.8.0" "node": ">= 0.8.0"
} }
}, },
"node_modules/proxy-from-env": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg=="
},
"node_modules/punycode": { "node_modules/punycode": {
"version": "2.3.1", "version": "2.3.1",
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz",
@ -4919,6 +5259,21 @@
"react": "^18.2.0" "react": "^18.2.0"
} }
}, },
"node_modules/react-hook-form": {
"version": "7.48.2",
"resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.48.2.tgz",
"integrity": "sha512-H0T2InFQb1hX7qKtDIZmvpU1Xfn/bdahWBN1fH19gSe4bBEqTfmlr7H3XWTaVtiK4/tpPaI1F3355GPMZYge+A==",
"engines": {
"node": ">=12.22.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/react-hook-form"
},
"peerDependencies": {
"react": "^16.8.0 || ^17 || ^18"
}
},
"node_modules/react-hotkeys-hook": { "node_modules/react-hotkeys-hook": {
"version": "4.4.1", "version": "4.4.1",
"resolved": "https://registry.npmjs.org/react-hotkeys-hook/-/react-hotkeys-hook-4.4.1.tgz", "resolved": "https://registry.npmjs.org/react-hotkeys-hook/-/react-hotkeys-hook-4.4.1.tgz",
@ -5610,6 +5965,27 @@
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
}, },
"node_modules/tsutils": {
"version": "3.21.0",
"resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz",
"integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==",
"dev": true,
"dependencies": {
"tslib": "^1.8.1"
},
"engines": {
"node": ">= 6"
},
"peerDependencies": {
"typescript": ">=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta"
}
},
"node_modules/tsutils/node_modules/tslib": {
"version": "1.14.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==",
"dev": true
},
"node_modules/type-check": { "node_modules/type-check": {
"version": "0.4.0", "version": "0.4.0",
"resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz",
@ -5860,6 +6236,14 @@
"url": "https://github.com/sponsors/sindresorhus" "url": "https://github.com/sponsors/sindresorhus"
} }
}, },
"node_modules/zod": {
"version": "3.22.4",
"resolved": "https://registry.npmjs.org/zod/-/zod-3.22.4.tgz",
"integrity": "sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==",
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}
},
"node_modules/zustand": { "node_modules/zustand": {
"version": "4.4.6", "version": "4.4.6",
"resolved": "https://registry.npmjs.org/zustand/-/zustand-4.4.6.tgz", "resolved": "https://registry.npmjs.org/zustand/-/zustand-4.4.6.tgz",

View File

@ -11,7 +11,9 @@
}, },
"dependencies": { "dependencies": {
"@heroicons/react": "^2.0.18", "@heroicons/react": "^2.0.18",
"@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",
@ -19,6 +21,7 @@
"@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0", "@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-separator": "^1.0.3",
"@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slider": "^1.1.2",
"@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-switch": "^1.0.3", "@radix-ui/react-switch": "^1.0.3",
@ -26,7 +29,9 @@
"@radix-ui/react-toast": "^1.1.5", "@radix-ui/react-toast": "^1.1.5",
"@radix-ui/react-toggle": "^1.0.3", "@radix-ui/react-toggle": "^1.0.3",
"@radix-ui/react-tooltip": "^1.0.7", "@radix-ui/react-tooltip": "^1.0.7",
"@tanstack/react-query": "^5.8.7",
"@uidotdev/usehooks": "^2.4.1", "@uidotdev/usehooks": "^2.4.1",
"axios": "^1.6.2",
"class-variance-authority": "^0.7.0", "class-variance-authority": "^0.7.0",
"clsx": "^2.0.0", "clsx": "^2.0.0",
"flexsearch": "^0.7.21", "flexsearch": "^0.7.21",
@ -37,6 +42,7 @@
"next-themes": "^0.2.1", "next-themes": "^0.2.1",
"react": "^18.2.0", "react": "^18.2.0",
"react-dom": "^18.2.0", "react-dom": "^18.2.0",
"react-hook-form": "^7.48.2",
"react-hotkeys-hook": "^4.4.1", "react-hotkeys-hook": "^4.4.1",
"react-photo-album": "^2.3.0", "react-photo-album": "^2.3.0",
"react-use": "^17.4.0", "react-use": "^17.4.0",
@ -44,9 +50,12 @@
"recoil": "^0.7.7", "recoil": "^0.7.7",
"tailwind-merge": "^2.0.0", "tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
"zustand": "^4.4.6" "zustand": "^4.4.6"
}, },
"devDependencies": { "devDependencies": {
"@tanstack/eslint-plugin-query": "^5.8.4",
"@types/axios": "^0.14.0",
"@types/flexsearch": "^0.7.3", "@types/flexsearch": "^0.7.3",
"@types/lodash": "^4.14.201", "@types/lodash": "^4.14.201",
"@types/node": "^20.9.2", "@types/node": "^20.9.2",

View File

@ -1,7 +1,6 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { useCallback, useEffect, useMemo, useRef, useState } from "react"
import { nanoid } from "nanoid" import { nanoid } from "nanoid"
import { useSetRecoilState } from "recoil"
import { serverConfigState } from "@/lib/store"
import useInputImage from "@/hooks/useInputImage" import useInputImage from "@/hooks/useInputImage"
import { keepGUIAlive } from "@/lib/utils" import { keepGUIAlive } from "@/lib/utils"
import { getServerConfig, isDesktop } from "@/lib/api" import { getServerConfig, isDesktop } from "@/lib/api"
@ -19,10 +18,13 @@ const SUPPORTED_FILE_TYPE = [
"image/tiff", "image/tiff",
] ]
function Home() { function Home() {
const [file, setFile] = useStore((state) => [state.file, state.setFile]) const [file, setServerConfig, setFile] = useStore((state) => [
state.file,
state.setServerConfig,
state.setFile,
])
const userInputImage = useInputImage() const userInputImage = useInputImage()
const setServerConfigState = useSetRecoilState(serverConfigState)
useEffect(() => { useEffect(() => {
if (userInputImage) { if (userInputImage) {
@ -44,8 +46,7 @@ function Home() {
useEffect(() => { useEffect(() => {
const fetchServerConfig = async () => { const fetchServerConfig = async () => {
const serverConfig = await getServerConfig().then((res) => res.json()) const serverConfig = await getServerConfig().then((res) => res.json())
console.log(serverConfig) setServerConfig(serverConfig)
setServerConfigState(serverConfig)
} }
fetchServerConfig() fetchServerConfig()
}, []) }, [])

View File

@ -1,5 +1,7 @@
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { cn } from "@/lib/utils"
import React, { useEffect, useState } from "react" import React, { useEffect, useState } from "react"
import { twMerge } from "tailwind-merge"
const DOC_MOVE_OPTS = { capture: true, passive: false } const DOC_MOVE_OPTS = { capture: true, passive: false }
@ -75,11 +77,6 @@ const Cropper = (props: Props) => {
state.setCropperWidth, state.setCropperWidth,
state.setCropperHeight, state.setCropperHeight,
]) ])
// const [x, setX] = useRecoilState(croperX)
// const [y, setY] = useRecoilState(croperY)
// const [height, setHeight] = useRecoilState(croperHeight)
// const [width, setWidth] = useRecoilState(croperWidth)
// const isInpainting = useRecoilValue(isInpaintingState)
const [isResizing, setIsResizing] = useState(false) const [isResizing, setIsResizing] = useState(false)
const [isMoving, setIsMoving] = useState(false) const [isMoving, setIsMoving] = useState(false)
@ -100,7 +97,7 @@ const Cropper = (props: Props) => {
}) })
const onDragFocus = () => { const onDragFocus = () => {
console.log("focus") // console.log("focus")
} }
const clampLeftRight = (newX: number, newWidth: number) => { const clampLeftRight = (newX: number, newWidth: number) => {
@ -254,102 +251,64 @@ const Cropper = (props: Props) => {
} }
} }
const createCropSelection = () => { const createDragHandle = (cursor: string, side1: string, side2: string) => {
const sideLength = 12
const draghandleCls = `w-[${sideLength}px] h-[${sideLength}px] z-4 absolute block border-2 border-primary borde pointer-events-auto hover:bg-primary`
let side2Cls = `${side2}-[-${sideLength / 2}px]`
if (side2 === "") {
if (side1 === "top" || side1 === "bottom") {
side2Cls = `left-[calc(50%-${sideLength / 2}px)]`
} else if (side1 === "left" || side1 === "right") {
side2Cls = `top-[calc(50%-${sideLength / 2}px)]`
}
}
return ( return (
<div <div
className="drag-elements" className={cn(
onFocus={onDragFocus} draghandleCls,
onPointerDown={onCropPointerDown} `${cursor}`,
> side1 ? `${side1}-[-${sideLength / 2}px]` : "",
side2Cls
)}
data-ord={side1 + side2}
aria-label={side1 + side2}
tabIndex={-1}
role="button"
/>
)
}
const createCropSelection = () => {
return (
<div onFocus={onDragFocus} onPointerDown={onCropPointerDown}>
<div <div
className="drag-bar ord-top" className="absolute pointer-events-auto top-0 left-0 w-full cursor-ns-resize h-[12px] mt-[-6px]"
data-ord="top" data-ord="top"
style={{ transform: `scale(${1 / scale})` }}
/> />
<div <div
className="drag-bar ord-right" className="absolute pointer-events-auto top-0 right-0 h-full cursor-ew-resize w-[12px] mr-[-6px]"
data-ord="right" data-ord="right"
style={{ transform: `scale(${1 / scale})` }}
/> />
<div <div
className="drag-bar ord-bottom" className="absolute pointer-events-auto bottom-0 left-0 w-full cursor-ns-resize h-[12px] mb-[-6px]"
data-ord="bottom" data-ord="bottom"
style={{ transform: `scale(${1 / scale})` }}
/> />
<div <div
className="drag-bar ord-left" className="absolute pointer-events-auto top-0 left-0 h-full cursor-ew-resize w-[12px] ml-[-6px]"
data-ord="left" data-ord="left"
style={{ transform: `scale(${1 / scale})` }}
/> />
<div {createDragHandle("cursor-nw-resize", "top", "left")}
className="drag-handle ord-topleft" {createDragHandle("cursor-ne-resize", "top", "right")}
data-ord="topleft" {createDragHandle("cursor-se-resize", "bottom", "left")}
aria-label="topleft" {createDragHandle("cursor-sw-resize", "bottom", "right")}
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div {createDragHandle("cursor-ns-resize", "top", "")}
className="drag-handle ord-topright" {createDragHandle("cursor-ns-resize", "bottom", "")}
data-ord="topright" {createDragHandle("cursor-ew-resize", "left", "")}
aria-label="topright" {createDragHandle("cursor-ew-resize", "right", "")}
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottomleft"
data-ord="bottomleft"
aria-label="bottomleft"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottomright"
data-ord="bottomright"
aria-label="bottomright"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-top"
data-ord="top"
aria-label="top"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-right"
data-ord="right"
aria-label="right"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottom"
data-ord="bottom"
aria-label="bottom"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-left"
data-ord="left"
aria-label="left"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
</div> </div>
) )
} }
@ -370,17 +329,17 @@ const Cropper = (props: Props) => {
const createInfoBar = () => { const createInfoBar = () => {
return ( return (
<div <div
className="border absolute pointer-events-auto text-[1rem] px-[0.8rem] py-[0.2rem] flex items-center justify-center gap-[12px] rounded-full hover:cursor-move" className={twMerge(
onPointerDown={onInfoBarPointerDown} "border absolute pointer-events-auto px-2 py-1 rounded-full hover:cursor-move bg-background",
"origin-top-left top-0 left-0"
)}
style={{ style={{
transform: `scale(${1 / scale})`, transform: `scale(${(1 / scale) * 0.8})`,
top: `${10 / scale}px`,
left: `${10 / scale}px`,
}} }}
onPointerDown={onInfoBarPointerDown}
> >
<div> {/* TODO: 移动的时候会显示 brush */}
{width} x {height} {width} x {height}
</div>
</div> </div>
) )
} }

View File

@ -3,13 +3,10 @@ import { CursorArrowRaysIcon } from "@heroicons/react/24/outline"
import { useToast } from "@/components/ui/use-toast" import { useToast } from "@/components/ui/use-toast"
import { import {
ReactZoomPanPinchContentRef, ReactZoomPanPinchContentRef,
ReactZoomPanPinchRef,
TransformComponent, TransformComponent,
TransformWrapper, TransformWrapper,
} from "react-zoom-pan-pinch" } from "react-zoom-pan-pinch"
import { useRecoilState, useRecoilValue, useSetRecoilState } from "recoil" import { useKeyPressEvent, useWindowSize } from "react-use"
import { useWindowSize } from "react-use"
// import { useWindowSize, useKey, useKeyPressEvent } from "@uidotdev/usehooks"
import inpaint, { downloadToOutput, runPlugin } from "@/lib/api" import inpaint, { downloadToOutput, runPlugin } from "@/lib/api"
import { IconButton } from "@/components/ui/button" import { IconButton } from "@/components/ui/button"
import { import {
@ -22,23 +19,6 @@ import {
srcToFile, srcToFile,
} from "@/lib/utils" } from "@/lib/utils"
import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react"
import {
croperState,
enableFileManagerState,
interactiveSegClicksState,
isDiffusionModelsState,
isEnableAutoSavingState,
isInteractiveSegRunningState,
isInteractiveSegState,
isPix2PixState,
isPluginRunningState,
isProcessingState,
negativePropmtState,
runManuallyState,
seedState,
settingState,
} from "@/lib/store"
// import Croper from "../Croper/Croper"
import emitter, { import emitter, {
EVENT_PROMPT, EVENT_PROMPT,
EVENT_CUSTOM_MASK, EVENT_CUSTOM_MASK,
@ -49,19 +29,15 @@ import emitter, {
} from "@/lib/event" } from "@/lib/event"
import { useImage } from "@/hooks/useImage" import { useImage } from "@/hooks/useImage"
import { Slider } from "./ui/slider" import { Slider } from "./ui/slider"
// import FileSelect from "../FileSelect/FileSelect"
// import InteractiveSeg from "../InteractiveSeg/InteractiveSeg"
// import InteractiveSegConfirmActions from "../InteractiveSeg/ConfirmActions"
// import InteractiveSegReplaceModal from "../InteractiveSeg/ReplaceModal"
import { PluginName } from "@/lib/types" import { PluginName } from "@/lib/types"
import { useHotkeys } from "react-hotkeys-hook" import { useHotkeys } from "react-hotkeys-hook"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import Cropper from "./Cropper" import Cropper from "./Cropper"
import { HotkeysEvent } from "react-hotkeys-hook/dist/types"
const TOOLBAR_HEIGHT = 200 const TOOLBAR_HEIGHT = 200
const MIN_BRUSH_SIZE = 10 const MIN_BRUSH_SIZE = 10
const MAX_BRUSH_SIZE = 200 const MAX_BRUSH_SIZE = 200
const COMPARE_SLIDER_DURATION_MS = 300
const BRUSH_COLOR = "#ffcc00bb" const BRUSH_COLOR = "#ffcc00bb"
interface Line { interface Line {
@ -110,48 +86,55 @@ export default function Editor(props: EditorProps) {
imageWidth, imageWidth,
imageHeight, imageHeight,
baseBrushSize, baseBrushSize,
brushScale, brushSizeScale,
promptVal, settings,
enableAutoSaving,
cropperRect,
enableManualInpainting,
setImageSize, setImageSize,
setBrushSize, setBrushSize,
setIsInpainting, setIsInpainting,
setSeed,
interactiveSegState,
updateInteractiveSegState,
resetInteractiveSegState,
isPluginRunning,
setIsPluginRunning,
] = useStore((state) => [ ] = useStore((state) => [
state.isInpainting, state.isInpainting,
state.imageWidth, state.imageWidth,
state.imageHeight, state.imageHeight,
state.brushSize, state.brushSize,
state.brushSizeScale, state.brushSizeScale,
state.prompt, state.settings,
state.serverConfig.enableAutoSaving,
state.cropperState,
state.settings.enableManualInpainting,
state.setImageSize, state.setImageSize,
state.setBrushSize, state.setBrushSize,
state.setIsInpainting, state.setIsInpainting,
state.setSeed,
state.interactiveSegState,
state.updateInteractiveSegState,
state.resetInteractiveSegState,
state.isPluginRunning,
state.setIsPluginRunning,
]) ])
const brushSize = baseBrushSize * brushScale const brushSize = baseBrushSize * brushSizeScale
// 纯 local state // 纯 local state
const [showOriginal, setShowOriginal] = useState(false) const [showOriginal, setShowOriginal] = useState(false)
// //
const negativePromptVal = useRecoilValue(negativePropmtState) const isProcessing = isInpainting
const settings = useRecoilValue(settingState) const isDiffusionModels = false
const [seedVal, setSeed] = useRecoilState(seedState) const isPix2Pix = false
const croperRect = useRecoilValue(croperState)
const setIsPluginRunning = useSetRecoilState(isPluginRunningState)
const isProcessing = useRecoilValue(isProcessingState)
const runMannually = useRecoilValue(runManuallyState)
const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState
)
const setIsInteractiveSegRunning = useSetRecoilState(
isInteractiveSegRunningState
)
const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false) const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false)
const [interactiveSegMask, setInteractiveSegMask] = useState< const [interactiveSegMask, setInteractiveSegMask] = useState<
HTMLImageElement | null | undefined HTMLImageElement | null | undefined
>(null) >(null)
// only used while interactive segmentation is on // only used while interactive segmentation is on
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState< const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState<
HTMLImageElement | null | undefined HTMLImageElement | null | undefined
@ -167,8 +150,6 @@ export default function Editor(props: EditorProps) {
const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] = const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] =
useState<LineGroup>([]) useState<LineGroup>([])
const [clicks, setClicks] = useRecoilState(interactiveSegClicksState)
const [original, isOriginalLoaded] = useImage(file) const [original, isOriginalLoaded] = useImage(file)
const [renders, setRenders] = useState<HTMLImageElement[]>([]) const [renders, setRenders] = useState<HTMLImageElement[]>([])
const [context, setContext] = useState<CanvasRenderingContext2D>() const [context, setContext] = useState<CanvasRenderingContext2D>()
@ -201,7 +182,6 @@ export default function Editor(props: EditorProps) {
const [initialCentered, setInitialCentered] = useState(false) const [initialCentered, setInitialCentered] = useState(false)
const [isDraging, setIsDraging] = useState(false) const [isDraging, setIsDraging] = useState(false)
const [isMultiStrokeKeyPressed, setIsMultiStrokeKeyPressed] = useState(false)
const [sliderPos, setSliderPos] = useState<number>(0) const [sliderPos, setSliderPos] = useState<number>(0)
@ -209,8 +189,6 @@ export default function Editor(props: EditorProps) {
const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([]) const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([])
const [redoCurLines, setRedoCurLines] = useState<Line[]>([]) const [redoCurLines, setRedoCurLines] = useState<Line[]>([])
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([]) const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
const enableFileManager = useRecoilValue(enableFileManagerState)
const isEnableAutoSaving = useRecoilValue(isEnableAutoSavingState)
const draw = useCallback( const draw = useCallback(
(render: HTMLImageElement, lineGroup: LineGroup) => { (render: HTMLImageElement, lineGroup: LineGroup) => {
@ -223,10 +201,10 @@ export default function Editor(props: EditorProps) {
context.clearRect(0, 0, context.canvas.width, context.canvas.height) context.clearRect(0, 0, context.canvas.width, context.canvas.height)
context.drawImage(render, 0, 0, imageWidth, imageHeight) context.drawImage(render, 0, 0, imageWidth, imageHeight)
if (isInteractiveSeg && tmpInteractiveSegMask) { if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight) context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
} }
if (!isInteractiveSeg && interactiveSegMask) { if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight) context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
} }
if (dreamButtonHoverSegMask) { if (dreamButtonHoverSegMask) {
@ -243,7 +221,7 @@ export default function Editor(props: EditorProps) {
}, },
[ [
context, context,
isInteractiveSeg, interactiveSegState,
tmpInteractiveSegMask, tmpInteractiveSegMask,
dreamButtonHoverSegMask, dreamButtonHoverSegMask,
interactiveSegMask, interactiveSegMask,
@ -363,34 +341,31 @@ export default function Editor(props: EditorProps) {
setCurLineGroup([]) setCurLineGroup([])
setIsDraging(false) setIsDraging(false)
setIsInpainting(true) setIsInpainting(true)
if (settings.graduallyInpainting) { drawLinesOnMask([maskLineGroup], maskImage)
drawLinesOnMask([maskLineGroup], maskImage)
} else {
drawLinesOnMask(newLineGroups)
}
let targetFile = file let targetFile = file
if (settings.graduallyInpainting === true) { console.log(
if (useLastLineGroup === true) { `randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}`
// renders.length == 1 还是用原来的 )
if (renders.length > 1) { if (useLastLineGroup === true) {
const lastRender = renders[renders.length - 2] // renders.length == 1 还是用原来的
targetFile = await srcToFile( if (renders.length > 1) {
lastRender.currentSrc, const lastRender = renders[renders.length - 2]
file.name,
file.type
)
}
} else if (renders.length > 0) {
console.info("gradually inpainting on last result")
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile( targetFile = await srcToFile(
lastRender.currentSrc, lastRender.currentSrc,
file.name, file.name,
file.type file.type
) )
} }
} else if (renders.length > 0) {
console.info("gradually inpainting on last result")
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
} }
try { try {
@ -398,10 +373,7 @@ export default function Editor(props: EditorProps) {
const res = await inpaint( const res = await inpaint(
targetFile, targetFile,
settings, settings,
croperRect, cropperRect,
promptVal,
negativePromptVal,
seedVal,
useCustomMask ? undefined : maskCanvas.toDataURL(), useCustomMask ? undefined : maskCanvas.toDataURL(),
useCustomMask ? customMask : undefined, useCustomMask ? customMask : undefined,
paintByExampleImage paintByExampleImage
@ -445,18 +417,15 @@ export default function Editor(props: EditorProps) {
setInteractiveSegMask(null) setInteractiveSegMask(null)
}, },
[ [
renders,
lineGroups, lineGroups,
curLineGroup, curLineGroup,
maskCanvas, maskCanvas,
settings.graduallyInpainting,
settings, settings,
croperRect, cropperRect,
promptVal,
negativePromptVal,
drawOnCurrentRender, drawOnCurrentRender,
hadDrawSomething, hadDrawSomething,
drawLinesOnMask, drawLinesOnMask,
seedVal,
] ]
) )
@ -487,7 +456,6 @@ export default function Editor(props: EditorProps) {
}, [ }, [
hadDrawSomething, hadDrawSomething,
runInpainting, runInpainting,
promptVal,
interactiveSegMask, interactiveSegMask,
prevInteractiveSegMask, prevInteractiveSegMask,
]) ])
@ -604,7 +572,7 @@ export default function Editor(props: EditorProps) {
useEffect(() => { useEffect(() => {
emitter.on(PluginName.InteractiveSeg, () => { emitter.on(PluginName.InteractiveSeg, () => {
setIsInteractiveSeg(true) // setIsInteractiveSeg(true)
if (interactiveSegMask !== null) { if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true) setShowInteractiveSegModal(true)
} }
@ -807,8 +775,8 @@ export default function Editor(props: EditorProps) {
const offsetX = (windowSize.width - imageWidth * minScale) / 2 const offsetX = (windowSize.width - imageWidth * minScale) / 2
const offsetY = (windowSize.height - imageHeight * minScale) / 2 const offsetY = (windowSize.height - imageHeight * minScale) / 2
viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad") viewport.setTransform(offsetX, offsetY, minScale, 200, "easeOutQuad")
if (viewport.state) { if (viewport.instance.transformState.scale) {
viewport.state.scale = minScale viewport.instance.transformState.scale = minScale
} }
setScale(minScale) setScale(minScale)
@ -850,24 +818,12 @@ export default function Editor(props: EditorProps) {
} }
}, []) }, [])
const onInteractiveCancel = useCallback(() => {
setIsInteractiveSeg(false)
setIsInteractiveSegRunning(false)
setClicks([])
setTmpInteractiveSegMask(null)
}, [])
const handleEscPressed = () => { const handleEscPressed = () => {
if (isProcessing) { if (isProcessing) {
return return
} }
if (isInteractiveSeg) { if (isDraging) {
onInteractiveCancel()
return
}
if (isDraging || isMultiStrokeKeyPressed) {
setIsDraging(false) setIsDraging(false)
setCurLineGroup([]) setCurLineGroup([])
drawOnCurrentRender([]) drawOnCurrentRender([])
@ -879,9 +835,6 @@ export default function Editor(props: EditorProps) {
useHotkeys("Escape", handleEscPressed, [ useHotkeys("Escape", handleEscPressed, [
isDraging, isDraging,
isInpainting, isInpainting,
isMultiStrokeKeyPressed,
isInteractiveSeg,
onInteractiveCancel,
resetZoom, resetZoom,
drawOnCurrentRender, drawOnCurrentRender,
]) ])
@ -901,7 +854,7 @@ export default function Editor(props: EditorProps) {
} }
return return
} }
if (isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
return return
} }
if (isPanning) { if (isPanning) {
@ -924,7 +877,7 @@ export default function Editor(props: EditorProps) {
return return
} }
setIsInteractiveSegRunning(true) // setIsInteractiveSegRunning(true)
const targetFile = await getCurrentRender() const targetFile = await getCurrentRender()
const prevMask = null const prevMask = null
try { try {
@ -950,14 +903,14 @@ export default function Editor(props: EditorProps) {
description: e.message ? e.message : e.toString(), description: e.message ? e.message : e.toString(),
}) })
} }
setIsInteractiveSegRunning(false) // setIsInteractiveSegRunning(false)
} }
const onPointerUp = (ev: SyntheticEvent) => { const onPointerUp = (ev: SyntheticEvent) => {
if (isMidClick(ev)) { if (isMidClick(ev)) {
setIsPanning(false) setIsPanning(false)
} }
if (isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
return return
} }
@ -978,12 +931,7 @@ export default function Editor(props: EditorProps) {
return return
} }
if (isMultiStrokeKeyPressed) { if (enableManualInpainting) {
setIsDraging(false)
return
}
if (runMannually) {
setIsDraging(false) setIsDraging(false)
} else { } else {
runInpainting() runInpainting()
@ -991,34 +939,34 @@ export default function Editor(props: EditorProps) {
} }
const isOutsideCroper = (clickPnt: { x: number; y: number }) => { const isOutsideCroper = (clickPnt: { x: number; y: number }) => {
if (clickPnt.x < croperRect.x) { if (clickPnt.x < cropperRect.x) {
return true return true
} }
if (clickPnt.y < croperRect.y) { if (clickPnt.y < cropperRect.y) {
return true return true
} }
if (clickPnt.x > croperRect.x + croperRect.width) { if (clickPnt.x > cropperRect.x + cropperRect.width) {
return true return true
} }
if (clickPnt.y > croperRect.y + croperRect.height) { if (clickPnt.y > cropperRect.y + cropperRect.height) {
return true return true
} }
return false return false
} }
const onCanvasMouseUp = (ev: SyntheticEvent) => { const onCanvasMouseUp = (ev: SyntheticEvent) => {
if (isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
const xy = mouseXY(ev) const xy = mouseXY(ev)
const isX = xy.x const isX = xy.x
const isY = xy.y const isY = xy.y
const newClicks: number[][] = [...clicks] const newClicks: number[][] = [...interactiveSegState.clicks]
if (isRightClick(ev)) { if (isRightClick(ev)) {
newClicks.push([isX, isY, 0, newClicks.length]) newClicks.push([isX, isY, 0, newClicks.length])
} else { } else {
newClicks.push([isX, isY, 1, newClicks.length]) newClicks.push([isX, isY, 1, newClicks.length])
} }
// runInteractiveSeg(newClicks) // runInteractiveSeg(newClicks)
setClicks(newClicks) updateInteractiveSegState({ clicks: newClicks })
} }
} }
@ -1026,7 +974,7 @@ export default function Editor(props: EditorProps) {
if (isProcessing) { if (isProcessing) {
return return
} }
if (isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
return return
} }
if (isChangingBrushSizeByMouse) { if (isChangingBrushSizeByMouse) {
@ -1063,7 +1011,7 @@ export default function Editor(props: EditorProps) {
setIsDraging(true) setIsDraging(true)
let lineGroup: LineGroup = [] let lineGroup: LineGroup = []
if (isMultiStrokeKeyPressed || runMannually) { if (enableManualInpainting) {
lineGroup = [...curLineGroup] lineGroup = [...curLineGroup]
} }
lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] }) lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] })
@ -1122,9 +1070,9 @@ export default function Editor(props: EditorProps) {
context, context,
]) ])
const undo = (keyboardEvent: KeyboardEvent, hotkeysEvent: HotkeysEvent) => { const undo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => {
keyboardEvent.preventDefault() keyboardEvent.preventDefault()
if (runMannually && curLineGroup.length !== 0) { if (enableManualInpainting && curLineGroup.length !== 0) {
undoStroke() undoStroke()
} else { } else {
undoRender() undoRender()
@ -1134,7 +1082,7 @@ export default function Editor(props: EditorProps) {
useHotkeys("meta+z,ctrl+z", undo, undefined, [ useHotkeys("meta+z,ctrl+z", undo, undefined, [
undoStroke, undoStroke,
undoRender, undoRender,
runMannually, enableManualInpainting,
curLineGroup, curLineGroup,
context?.canvas, context?.canvas,
renders, renders,
@ -1148,7 +1096,7 @@ export default function Editor(props: EditorProps) {
return false return false
} }
if (runMannually) { if (enableManualInpainting) {
if (curLineGroup.length === 0) { if (curLineGroup.length === 0) {
return true return true
} }
@ -1188,9 +1136,9 @@ export default function Editor(props: EditorProps) {
// draw(newRenders[newRenders.length - 1], []) // draw(newRenders[newRenders.length - 1], [])
}, [draw, renders, redoRenders, redoLineGroups, lineGroups, original]) }, [draw, renders, redoRenders, redoLineGroups, lineGroups, original])
const redo = (keyboardEvent: KeyboardEvent, hotkeysEvent: HotkeysEvent) => { const redo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => {
keyboardEvent.preventDefault() keyboardEvent.preventDefault()
if (runMannually && redoCurLines.length !== 0) { if (enableManualInpainting && redoCurLines.length !== 0) {
redoStroke() redoStroke()
} else { } else {
redoRender() redoRender()
@ -1200,7 +1148,7 @@ export default function Editor(props: EditorProps) {
useHotkeys("shift+ctrl+z,shift+meta+z", redo, undefined, [ useHotkeys("shift+ctrl+z,shift+meta+z", redo, undefined, [
redoStroke, redoStroke,
redoRender, redoRender,
runMannually, enableManualInpainting,
redoCurLines, redoCurLines,
]) ])
@ -1212,7 +1160,7 @@ export default function Editor(props: EditorProps) {
return false return false
} }
if (runMannually) { if (enableManualInpainting) {
if (redoCurLines.length === 0) { if (redoCurLines.length === 0) {
return true return true
} }
@ -1223,37 +1171,39 @@ export default function Editor(props: EditorProps) {
return false return false
} }
// useKeyPressEvent( useKeyPressEvent(
// "Tab", "Tab",
// (ev) => { (ev) => {
// ev?.preventDefault() ev?.preventDefault()
// ev?.stopPropagation() ev?.stopPropagation()
// if (hadRunInpainting()) { if (hadRunInpainting()) {
// setShowOriginal(() => { setShowOriginal(() => {
// window.setTimeout(() => { window.setTimeout(() => {
// setSliderPos(100) setSliderPos(100)
// }, 10) }, 10)
// return true return true
// }) })
// } }
// }, },
// (ev) => { (ev) => {
// ev?.preventDefault() ev?.preventDefault()
// ev?.stopPropagation() ev?.stopPropagation()
// if (hadRunInpainting()) { if (hadRunInpainting()) {
// setSliderPos(0) window.setTimeout(() => {
// window.setTimeout(() => { setSliderPos(0)
// setShowOriginal(false) }, 10)
// }, 350) window.setTimeout(() => {
// } setShowOriginal(false)
// } }, COMPARE_SLIDER_DURATION_MS)
// ) }
}
)
function download() { function download() {
if (file === undefined) { if (file === undefined) {
return return
} }
if ((enableFileManager || isEnableAutoSaving) && renders.length > 0) { if (enableAutoSaving && renders.length > 0) {
try { try {
downloadToOutput(renders[renders.length - 1], file.name, file.type) downloadToOutput(renders[renders.length - 1], file.name, file.type)
toast({ toast({
@ -1273,7 +1223,7 @@ export default function Editor(props: EditorProps) {
const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1") const name = file.name.replace(/(\.[\w\d_-]+)$/i, "_cleanup$1")
const curRender = renders[renders.length - 1] const curRender = renders[renders.length - 1]
downloadImage(curRender.currentSrc, name) downloadImage(curRender.currentSrc, name)
if (settings.downloadMask) { if (settings.enableDownloadMask) {
let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1")
maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg")
@ -1305,104 +1255,98 @@ export default function Editor(props: EditorProps) {
return undefined return undefined
}, [showBrush, isPanning]) }, [showBrush, isPanning])
// Standard Hotkeys for Brush Size useHotkeys(
// useHotKey("[", () => { "[",
// setBrushSize((currentBrushSize: number) => { () => {
// if (currentBrushSize > 10) { let newBrushSize = baseBrushSize
// return currentBrushSize - 10 if (baseBrushSize > 10) {
// } newBrushSize = baseBrushSize - 10
// if (currentBrushSize <= 10 && currentBrushSize > 0) { }
// return currentBrushSize - 5 if (baseBrushSize <= 10 && baseBrushSize > 0) {
// } newBrushSize = baseBrushSize - 5
// return currentBrushSize }
// }) setBrushSize(newBrushSize)
// }) },
[baseBrushSize]
)
// useHotKey("]", () => { useHotkeys(
// setBrushSize((currentBrushSize: number) => { "]",
// return currentBrushSize + 10 () => {
// }) setBrushSize(baseBrushSize + 10)
// }) },
[baseBrushSize]
)
// // Manual Inpainting Hotkey // Manual Inpainting Hotkey
// useHotKey( useHotkeys(
// "shift+r", "shift+r",
// () => { () => {
// if (runMannually && hadDrawSomething()) { if (enableManualInpainting && hadDrawSomething()) {
// runInpainting() runInpainting()
// } }
// }, },
// {}, [enableManualInpainting, runInpainting, hadDrawSomething]
// [runMannually, runInpainting, hadDrawSomething] )
// )
// useHotKey( useHotkeys(
// "ctrl+c, cmd+c", "ctrl+c, cmd+c",
// async () => { async () => {
// const hasPermission = await askWritePermission() const hasPermission = await askWritePermission()
// if (hasPermission && renders.length > 0) { if (hasPermission && renders.length > 0) {
// if (context?.canvas) { if (context?.canvas) {
// await copyCanvasImage(context?.canvas) await copyCanvasImage(context?.canvas)
// setToastState({ toast({
// open: true, title: "Copy inpainting result to clipboard",
// desc: "Copy inpainting result to clipboard", })
// state: "success", }
// duration: 3000, }
// }) },
// } [renders, context]
// } )
// },
// {},
// [renders, context]
// )
// Toggle clean/zoom tool on spacebar. // Toggle clean/zoom tool on spacebar.
// useKeyPressEvent( useKeyPressEvent(
// " ", " ",
// (ev) => { (ev) => {
// if (!app.disableShortCuts) { ev?.preventDefault()
// ev?.preventDefault() ev?.stopPropagation()
// ev?.stopPropagation() setShowBrush(false)
// setShowBrush(false) setIsPanning(true)
// setIsPanning(true) },
// } (ev) => {
// }, ev?.preventDefault()
// (ev) => { ev?.stopPropagation()
// if (!app.disableShortCuts) { setShowBrush(true)
// ev?.preventDefault() setIsPanning(false)
// ev?.stopPropagation() }
// setShowBrush(true) )
// setIsPanning(false)
// }
// }
// )
// useKeyPressEvent( useKeyPressEvent(
// "Alt", "Alt",
// (ev) => { (ev) => {
// ev?.preventDefault() ev?.preventDefault()
// ev?.stopPropagation() ev?.stopPropagation()
// setIsChangingBrushSizeByMouse(true) setIsChangingBrushSizeByMouse(true)
// setChangeBrushSizeByMouseInit({ x, y, brushSize }) setChangeBrushSizeByMouseInit({ x, y, brushSize })
// }, },
// (ev) => { (ev) => {
// ev?.preventDefault() ev?.preventDefault()
// ev?.stopPropagation() ev?.stopPropagation()
// setIsChangingBrushSizeByMouse(false) setIsChangingBrushSizeByMouse(false)
// } }
// ) )
const getCurScale = (): number => { const getCurScale = (): number => {
let s = minScale let s = minScale
if (viewportRef.current?.state?.scale !== undefined) { if (viewportRef.current?.instance?.transformState.scale !== undefined) {
s = viewportRef.current?.state.scale s = viewportRef.current?.instance?.transformState.scale
console.log("!!!!!!")
} }
return s! return s!
} }
const getBrushStyle = (_x: number, _y: number) => { const getBrushStyle = (_x: number, _y: number) => {
const curScale = scale const curScale = getCurScale()
return { return {
width: `${brushSize * curScale}px`, width: `${brushSize * curScale}px`,
height: `${brushSize * curScale}px`, height: `${brushSize * curScale}px`,
@ -1435,7 +1379,7 @@ export default function Editor(props: EditorProps) {
const renderInteractiveSegCursor = () => { const renderInteractiveSegCursor = () => {
return ( return (
<div <div
className="interactive-seg-cursor" className="absolute h-[20px] w-[20px] pointer-events-none rounded-[50%] bg-[rgba(21,_215,_121,_0.936)] [box-shadow:0_0_0_0_rgba(21,_215,_121,_0.936)] animate-pulse"
style={{ style={{
left: `${x}px`, left: `${x}px`,
top: `${y}px`, top: `${y}px`,
@ -1475,7 +1419,9 @@ export default function Editor(props: EditorProps) {
}} }}
> >
<TransformComponent <TransformComponent
contentClass={isProcessing ? "pointer-events-none" : ""} contentClass={
isProcessing ? "pointer-events-none animate-pulse duration-700" : ""
}
contentStyle={{ contentStyle={{
visibility: initialCentered ? "visible" : "hidden", visibility: initialCentered ? "visible" : "hidden",
}} }}
@ -1486,7 +1432,7 @@ export default function Editor(props: EditorProps) {
style={{ style={{
cursor: getCursor(), cursor: getCursor(),
clipPath: `inset(0 ${sliderPos}% 0 0)`, clipPath: `inset(0 ${sliderPos}% 0 0)`,
transition: "clip-path 300ms cubic-bezier(0.4, 0, 0.2, 1)", transition: `clip-path ${COMPARE_SLIDER_DURATION_MS}ms`,
}} }}
onContextMenu={(e) => { onContextMenu={(e) => {
e.preventDefault() e.preventDefault()
@ -1519,9 +1465,10 @@ export default function Editor(props: EditorProps) {
{showOriginal && ( {showOriginal && (
<> <>
<div <div
className="[grid-area:original-image-content] h-full w-[6px] justify-self-end [transition:all_300ms_cubic-bezier(0.4,_0,_0.2,_1)]" className="[grid-area:original-image-content] z-10 bg-primary h-full w-[6px] justify-self-end"
style={{ style={{
marginRight: `${sliderPos}%`, marginRight: `${sliderPos}%`,
transition: `margin-right ${COMPARE_SLIDER_DURATION_MS}ms`,
}} }}
/> />
<img <img
@ -1543,12 +1490,12 @@ export default function Editor(props: EditorProps) {
maxWidth={imageWidth} maxWidth={imageWidth}
minHeight={Math.min(256, imageHeight)} minHeight={Math.min(256, imageHeight)}
minWidth={Math.min(256, imageWidth)} minWidth={Math.min(256, imageWidth)}
scale={scale} scale={getCurScale()}
// show={settings.showCroper}
show={true} show={true}
// show={isDiffusionModels && settings.showCroper}
/> />
{/* {isInteractiveSeg ? <InteractiveSeg /> : <></>} */} {/* {interactiveSegState.isInteractiveSeg ? <InteractiveSeg /> : <></>} */}
</TransformComponent> </TransformComponent>
</TransformWrapper> </TransformWrapper>
) )
@ -1558,7 +1505,7 @@ export default function Editor(props: EditorProps) {
setInteractiveSegMask(tmpInteractiveSegMask) setInteractiveSegMask(tmpInteractiveSegMask)
setTmpInteractiveSegMask(null) setTmpInteractiveSegMask(null)
if (!runMannually && tmpInteractiveSegMask) { if (!enableManualInpainting && tmpInteractiveSegMask) {
runInpainting(false, undefined, tmpInteractiveSegMask) runInpainting(false, undefined, tmpInteractiveSegMask)
} }
} }
@ -1570,16 +1517,12 @@ export default function Editor(props: EditorProps) {
onMouseMove={onMouseMove} onMouseMove={onMouseMove}
onMouseUp={onPointerUp} onMouseUp={onPointerUp}
> >
{/* <InteractiveSegConfirmActions
onAcceptClick={onInteractiveAccept}
onCancelClick={onInteractiveCancel}
/> */}
{renderCanvas()} {renderCanvas()}
{showBrush && {showBrush &&
!isInpainting && !isInpainting &&
!isPanning && !isPanning &&
(isInteractiveSeg (interactiveSegState.isInteractiveSeg
? renderInteractiveSegCursor() ? renderInteractiveSegCursor()
: renderBrush( : renderBrush(
getBrushStyle( getBrushStyle(
@ -1590,20 +1533,21 @@ export default function Editor(props: EditorProps) {
{showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))} {showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))}
<div className="fixed flex bottom-10 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md"> <div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/50">
<Slider <Slider
className="w-48" className="w-48"
defaultValue={[50]} defaultValue={[50]}
min={MIN_BRUSH_SIZE} min={MIN_BRUSH_SIZE}
max={MAX_BRUSH_SIZE} max={MAX_BRUSH_SIZE}
step={1} step={1}
tabIndex={-1}
value={[baseBrushSize]} value={[baseBrushSize]}
onValueChange={(vals) => handleSliderChange(vals[0])} onValueChange={(vals) => handleSliderChange(vals[0])}
onClick={() => setShowRefBrush(false)} onClick={() => setShowRefBrush(false)}
/> />
<div className="flex gap-2"> <div className="flex gap-2">
<IconButton <IconButton
tooltip="Reset Zoom & Pan" tooltip="Reset zoom & pan"
disabled={scale === minScale && panned === false} disabled={scale === minScale && panned === false}
onClick={resetZoom} onClick={resetZoom}
> >
@ -1616,23 +1560,26 @@ export default function Editor(props: EditorProps) {
<Redo /> <Redo />
</IconButton> </IconButton>
<IconButton <IconButton
tooltip="Show Original" tooltip="Show original image"
className={showOriginal ? "eyeicon-active" : ""} onPointerDown={(ev) => {
// onDown={(ev) => { ev.preventDefault()
// ev.preventDefault() setShowOriginal(() => {
// setShowOriginal(() => { window.setTimeout(() => {
// window.setTimeout(() => { setSliderPos(100)
// setSliderPos(100) }, 10)
// }, 10) return true
// return true })
// }) }}
// }} onPointerUp={() => {
// onUp={() => { window.setTimeout(() => {
// setSliderPos(0) // 防止快速点击 show original image 按钮时图片消失
// window.setTimeout(() => { setSliderPos(0)
// setShowOriginal(false) }, 10)
// }, 300)
// }} window.setTimeout(() => {
setShowOriginal(false)
}, COMPARE_SLIDER_DURATION_MS)
}}
disabled={renders.length === 0} disabled={renders.length === 0}
> >
<Eye /> <Eye />
@ -1645,36 +1592,25 @@ export default function Editor(props: EditorProps) {
<Download /> <Download />
</IconButton> </IconButton>
<IconButton {settings.enableManualInpainting ? (
tooltip="Run Inpainting" <IconButton
disabled={ tooltip="Run Inpainting"
isProcessing || disabled={
(!hadDrawSomething() && interactiveSegMask === null) isProcessing ||
} (!hadDrawSomething() && interactiveSegMask === null)
onClick={() => { }
// ensured by disabled onClick={() => {
runInpainting(false, undefined, interactiveSegMask) // ensured by disabled
}} runInpainting(false, undefined, interactiveSegMask)
> }}
<Eraser /> >
</IconButton> <Eraser />
</IconButton>
) : (
<></>
)}
</div> </div>
</div> </div>
{/* <InteractiveSegReplaceModal
show={showInteractiveSegModal}
onClose={() => {
onInteractiveCancel()
setShowInteractiveSegModal(false)
}}
onCleanClick={() => {
onInteractiveCancel()
setInteractiveSegMask(null)
}}
onReplaceClick={() => {
setShowInteractiveSegModal(false)
setIsInteractiveSeg(true)
}}
/> */}
</div> </div>
) )
} }

View File

@ -74,18 +74,9 @@ export default function FileManager(props: Props) {
const { onPhotoClick, photoWidth } = props const { onPhotoClick, photoWidth } = props
const [open, toggleOpen] = useToggle(false) const [open, toggleOpen] = useToggle(false)
const [ const [fileManagerState, updateFileManagerState] = useStore((state) => [
fileManagerState,
setFileManagerLayout,
setFileManagerSortBy,
setFileManagerSortOrder,
setFileManagerSearchText,
] = useStore((state) => [
state.fileManagerState, state.fileManagerState,
state.setFileManagerLayout, state.updateFileManagerState,
state.setFileManagerSortBy,
state.setFileManagerSortOrder,
state.setFileManagerSearchText,
]) ])
useHotkeys("f", () => { useHotkeys("f", () => {
@ -185,7 +176,7 @@ export default function FileManager(props: Props) {
<IconButton <IconButton
tooltip="Rows layout" tooltip="Rows layout"
onClick={() => { onClick={() => {
setFileManagerLayout("rows") updateFileManagerState({ layout: "rows" })
}} }}
> >
<ViewHorizontalIcon <ViewHorizontalIcon
@ -195,7 +186,7 @@ export default function FileManager(props: Props) {
<IconButton <IconButton
tooltip="Grid layout" tooltip="Grid layout"
onClick={() => { onClick={() => {
setFileManagerLayout("masonry") updateFileManagerState({ layout: "masonry" })
}} }}
> >
<ViewGridIcon <ViewGridIcon
@ -230,7 +221,7 @@ export default function FileManager(props: Props) {
evt.preventDefault() evt.preventDefault()
evt.stopPropagation() evt.stopPropagation()
const target = evt.target as HTMLInputElement const target = evt.target as HTMLInputElement
setFileManagerSearchText(target.value) updateFileManagerState({ searchText: target.value })
}} }}
placeholder="Search by file name" placeholder="Search by file name"
/> />
@ -250,13 +241,13 @@ export default function FileManager(props: Props) {
onValueChange={(val) => { onValueChange={(val) => {
switch (val) { switch (val) {
case SORT_BY_NAME: case SORT_BY_NAME:
setFileManagerSortBy(SortBy.NAME) updateFileManagerState({ sortBy: SortBy.NAME })
break break
case SORT_BY_CREATED_TIME: case SORT_BY_CREATED_TIME:
setFileManagerSortBy(SortBy.CTIME) updateFileManagerState({ sortBy: SortBy.CTIME })
break break
case SORT_BY_MODIFIED_TIME: case SORT_BY_MODIFIED_TIME:
setFileManagerSortBy(SortBy.MTIME) updateFileManagerState({ sortBy: SortBy.MTIME })
break break
default: default:
break break
@ -281,7 +272,7 @@ export default function FileManager(props: Props) {
<IconButton <IconButton
tooltip="Descending Order" tooltip="Descending Order"
onClick={() => { onClick={() => {
setFileManagerSortOrder(SortOrder.ASCENDING) updateFileManagerState({ sortOrder: SortOrder.ASCENDING })
}} }}
> >
<BarsArrowDownIcon /> <BarsArrowDownIcon />
@ -290,7 +281,7 @@ export default function FileManager(props: Props) {
<IconButton <IconButton
tooltip="Ascending Order" tooltip="Ascending Order"
onClick={() => { onClick={() => {
setFileManagerSortOrder(SortOrder.DESCENDING) updateFileManagerState({ sortOrder: SortOrder.DESCENDING })
}} }}
> >
<BarsArrowUpIcon /> <BarsArrowUpIcon />

View File

@ -1,19 +1,8 @@
import { PlayIcon } from "@radix-ui/react-icons" import { PlayIcon } from "@radix-ui/react-icons"
import React, { useCallback, useState } from "react" import { useCallback, useState } from "react"
import { useRecoilState, useRecoilValue } from "recoil"
import { useHotkeys } from "react-hotkeys-hook" import { useHotkeys } from "react-hotkeys-hook"
import {
enableFileManagerState,
isPix2PixState,
isSDState,
maskState,
runManuallyState,
} from "@/lib/store"
import { IconButton, ImageUploadButton } from "@/components/ui/button" import { IconButton, ImageUploadButton } from "@/components/ui/button"
import Shortcuts from "@/components/Shortcuts" import Shortcuts from "@/components/Shortcuts"
// import SettingIcon from "../Settings/SettingIcon"
// import PromptInput from "./PromptInput"
// import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
import emitter, { import emitter, {
DREAM_BUTTON_MOUSE_ENTER, DREAM_BUTTON_MOUSE_ENTER,
DREAM_BUTTON_MOUSE_LEAVE, DREAM_BUTTON_MOUSE_LEAVE,
@ -24,24 +13,37 @@ import { useImage } from "@/hooks/useImage"
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
import PromptInput from "./PromptInput" import PromptInput from "./PromptInput"
import { RotateCw, Image } from "lucide-react" import { RotateCw, Image, Upload } from "lucide-react"
import FileManager from "./FileManager" import FileManager from "./FileManager"
import { getMediaFile } from "@/lib/api" import { getMediaFile } from "@/lib/api"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import SettingsDialog from "./Settings"
import { cn } from "@/lib/utils"
const Header = () => { const Header = () => {
const [file, isInpainting, setFile] = useStore((state) => [ const [
file,
customMask,
isInpainting,
enableFileManager,
enableManualInpainting,
enableUploadMask,
shouldShowPromptInput,
setFile,
setCustomFile,
] = useStore((state) => [
state.file, state.file,
state.customMask,
state.isInpainting, state.isInpainting,
state.serverConfig.enableFileManager,
state.settings.enableManualInpainting,
state.settings.enableUploadMask,
state.shouldShowPromptInput(),
state.setFile, state.setFile,
state.setCustomFile,
]) ])
const [mask, setMask] = useRecoilState(maskState) const [maskImage, maskImageLoaded] = useImage(customMask)
// const [maskImage, maskImageLoaded] = useImage(mask)
const isSD = useRecoilValue(isSDState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const runManually = useRecoilValue(runManuallyState)
const [openMaskPopover, setOpenMaskPopover] = useState(false) const [openMaskPopover, setOpenMaskPopover] = useState(false)
const enableFileManager = useRecoilValue(enableFileManagerState)
const handleRerunLastMask = useCallback(() => { const handleRerunLastMask = useCallback(() => {
emitter.emit(RERUN_LAST_MASK) emitter.emit(RERUN_LAST_MASK)
@ -68,7 +70,7 @@ const Header = () => {
return ( return (
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 backdrop-filter backdrop-blur-md border-b"> <header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 backdrop-filter backdrop-blur-md border-b">
<div className="flex items-center"> <div className="flex items-center gap-1">
{enableFileManager ? ( {enableFileManager ? (
<FileManager <FileManager
photoWidth={512} photoWidth={512}
@ -92,38 +94,37 @@ const Header = () => {
</ImageUploadButton> </ImageUploadButton>
<div <div
className="flex items-center" className={cn([
style={{ "flex items-center gap-1",
visibility: file ? "visible" : "hidden", file && enableUploadMask ? "visible" : "hidden",
}} ])}
> >
<ImageUploadButton <ImageUploadButton
disabled={isInpainting} disabled={isInpainting}
tooltip="Upload custom mask" tooltip="Upload custom mask"
onFileUpload={(file) => { onFileUpload={(file) => {
setMask(file) setCustomFile(file)
console.info("Send custom mask") if (!enableManualInpainting) {
if (!runManually) {
emitter.emit(EVENT_CUSTOM_MASK, { mask: file }) emitter.emit(EVENT_CUSTOM_MASK, { mask: file })
} }
}} }}
> >
<div>M</div> <Upload />
</ImageUploadButton> </ImageUploadButton>
{mask ? ( {customMask ? (
<Popover open={openMaskPopover}> <Popover open={openMaskPopover}>
<PopoverTrigger <PopoverTrigger
className="btn-primary side-panel-trigger" className="btn-primary side-panel-trigger"
onMouseEnter={() => setOpenMaskPopover(true)} onMouseEnter={() => setOpenMaskPopover(true)}
onMouseLeave={() => setOpenMaskPopover(false)} onMouseLeave={() => setOpenMaskPopover(false)}
style={{ style={{
visibility: mask ? "visible" : "hidden", visibility: customMask ? "visible" : "hidden",
outline: "none", outline: "none",
}} }}
onClick={() => { onClick={() => {
if (mask) { if (customMask) {
emitter.emit(EVENT_CUSTOM_MASK, { mask }) emitter.emit(EVENT_CUSTOM_MASK, { mask: customMask })
} }
}} }}
> >
@ -131,36 +132,36 @@ const Header = () => {
<PlayIcon /> <PlayIcon />
</IconButton> </IconButton>
</PopoverTrigger> </PopoverTrigger>
{/* <PopoverContent> <PopoverContent>
{maskImageLoaded ? ( {maskImageLoaded ? (
<img src={maskImage.src} alt="Custom mask" /> <img src={maskImage.src} alt="Custom mask" />
) : ( ) : (
<></> <></>
)} )}
</PopoverContent> */} </PopoverContent>
</Popover> </Popover>
) : ( ) : (
<></> <></>
)} )}
<IconButton
disabled={isInpainting}
tooltip="Rerun last mask"
onClick={handleRerunLastMask}
onMouseEnter={onRerunMouseEnter}
onMouseLeave={onRerunMouseLeave}
>
<RotateCw />
</IconButton>
</div> </div>
<IconButton
disabled={isInpainting}
tooltip="Rerun last mask"
onClick={handleRerunLastMask}
onMouseEnter={onRerunMouseEnter}
onMouseLeave={onRerunMouseLeave}
>
<RotateCw className={file ? "visible" : "hidden"} />
</IconButton>
</div> </div>
{isSD ? <PromptInput /> : <></>} {shouldShowPromptInput ? <PromptInput /> : <></>}
{/* <CoffeeIcon /> */} <div className="flex gap-1">
<div> {/* <CoffeeIcon /> */}
<Shortcuts /> <Shortcuts />
{/* <SettingIcon /> */} <SettingsDialog />
</div> </div>
</header> </header>
) )

View File

@ -11,7 +11,7 @@ const ImageSize = () => {
} }
return ( return (
<div className="border rounded-lg px-2 py-[6px] z-10"> <div className="border rounded-lg px-2 py-[6px] z-10 bg-background">
{imageWidth}x{imageHeight} {imageWidth}x{imageHeight}
</div> </div>
) )

View File

@ -0,0 +1,136 @@
import { useStore } from "@/lib/states"
import { Button } from "./ui/button"
import { Dialog, DialogContent, DialogTitle } from "./ui/dialog"
import { MousePointerClick } from "lucide-react"
import { DropdownMenuItem } from "./ui/dropdown-menu"
interface InteractiveSegReplaceModal {
show: boolean
onClose: () => void
onCleanClick: () => void
onReplaceClick: () => void
}
const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => {
const { show, onClose, onCleanClick, onReplaceClick } = props
const onOpenChange = (open: boolean) => {
if (!open) {
onClose()
}
}
return (
<Dialog open={show} onOpenChange={onOpenChange}>
<DialogContent>
<DialogTitle>Do you want to remove it or create a new one?</DialogTitle>
<div className="flex gap-[12px] w-full justify-end items-center">
<Button
onClick={() => {
onClose()
onCleanClick()
}}
>
Remove
</Button>
<Button onClick={onReplaceClick}>Create new</Button>
</div>
</DialogContent>
</Dialog>
)
}
const InteractiveSegConfirmActions = () => {
const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [
state.interactiveSegState,
state.resetInteractiveSegState,
])
if (!interactiveSegState.isInteractiveSeg) {
return null
}
const onAcceptClick = () => {
resetInteractiveSegState()
}
return (
<div className="z-10 absolute top-[68px] rounded-xl border-solid border p-[8px] left-1/2 translate-x-[-50%] flex justify-center items-center gap-[8px] bg-background">
<Button
onClick={() => {
resetInteractiveSegState()
}}
size="sm"
variant="secondary"
>
Cancel
</Button>
<Button
size="sm"
onClick={() => {
onAcceptClick()
}}
>
Accept
</Button>
</div>
)
}
interface ItemProps {
x: number
y: number
positive: boolean
}
const Item = (props: ItemProps) => {
const { x, y, positive } = props
const name = positive
? "bg-[rgba(21,_215,_121,_0.936)] outline-[6px_solid_rgba(98,_255,_179,_0.31)]"
: "bg-[rgba(237,_49,_55,_0.942)] outline-[6px_solid_rgba(255,_89,_95,_0.31)]"
return (
<div
className={`absolute h-[8px] w-[8px] rounded-[50%] ${name}`}
style={{
left: x,
top: y,
transform: "translate(-50%, -50%)",
}}
/>
)
}
const InteractiveSegPoints = () => {
const clicks = useStore((state) => state.interactiveSegState.clicks)
return (
<div className="absolute h-full w-full overflow-hidden pointer-events-none">
{clicks.map((click) => {
return (
<Item
key={click[3]}
x={click[0]}
y={click[1]}
positive={click[2] === 1}
/>
)
})}
</div>
)
}
const InteractiveSeg = () => {
const [interactiveSegState, updateInteractiveSegState] = useStore((state) => [
state.interactiveSegState,
state.updateInteractiveSegState,
])
return (
<div>
<InteractiveSegConfirmActions />
{/* <InteractiveSegReplaceModal /> */}
</div>
)
}
export { InteractiveSeg, InteractiveSegPoints }

View File

@ -10,6 +10,8 @@ import {
import { Button } from "./ui/button" import { Button } from "./ui/button"
import { Fullscreen, MousePointerClick, Slice, Smile } from "lucide-react" import { Fullscreen, MousePointerClick, Slice, Smile } from "lucide-react"
import { MixIcon } from "@radix-ui/react-icons" import { MixIcon } from "@radix-ui/react-icons"
import { useStore } from "@/lib/states"
import { InteractiveSeg } from "./InteractiveSeg"
export enum PluginName { export enum PluginName {
RemoveBG = "RemoveBG", RemoveBG = "RemoveBG",
@ -48,17 +50,10 @@ const pluginMap = {
} }
const Plugins = () => { const Plugins = () => {
// const [open, toggleOpen] = useToggle(true) const [plugins, updateInteractiveSegState] = useStore((state) => [
// const serverConfig = useRecoilValue(serverConfigState) state.serverConfig.plugins,
// const isProcessing = useRecoilValue(isProcessingState) state.updateInteractiveSegState,
const plugins = [ ])
PluginName.RemoveBG,
PluginName.AnimeSeg,
PluginName.RealESRGAN,
PluginName.GFPGAN,
PluginName.RestoreFormer,
PluginName.InteractiveSeg,
]
if (plugins.length === 0) { if (plugins.length === 0) {
return null return null
@ -68,6 +63,9 @@ const Plugins = () => {
// if (!disabled) { // if (!disabled) {
// emitter.emit(pluginName) // emitter.emit(pluginName)
// } // }
if (pluginName === PluginName.InteractiveSeg) {
updateInteractiveSegState({ isInteractiveSeg: true })
}
} }
const onRealESRGANClick = (upscale: number) => { const onRealESRGANClick = (upscale: number) => {
@ -98,8 +96,8 @@ const Plugins = () => {
} }
const renderPlugins = () => { const renderPlugins = () => {
return plugins.map((plugin: PluginName) => { return plugins.map((plugin: string) => {
const { IconClass, showName } = pluginMap[plugin] const { IconClass, showName } = pluginMap[plugin as PluginName]
if (plugin === PluginName.RealESRGAN) { if (plugin === PluginName.RealESRGAN) {
return renderRealESRGANPlugin() return renderRealESRGANPlugin()
} }
@ -116,7 +114,10 @@ const Plugins = () => {
return ( return (
<DropdownMenu modal={false}> <DropdownMenu modal={false}>
<DropdownMenuTrigger className="border rounded-lg z-10"> <DropdownMenuTrigger
className="border rounded-lg z-10 bg-background"
tabIndex={-1}
>
<Button variant="ghost" size="icon" asChild> <Button variant="ghost" size="icon" asChild>
<MixIcon className="p-2" /> <MixIcon className="p-2" />
</Button> </Button>

View File

@ -9,17 +9,17 @@ import { Input } from "./ui/input"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
const PromptInput = () => { const PromptInput = () => {
const [isInpainting, prompt, setPrompt] = useStore((state) => [ const [isInpainting, prompt, updateSettings] = useStore((state) => [
state.isInpainting, state.isInpainting,
state.prompt, state.settings.prompt,
state.setPrompt, state.updateSettings,
]) ])
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => { const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
evt.preventDefault() evt.preventDefault()
evt.stopPropagation() evt.stopPropagation()
const target = evt.target as HTMLInputElement const target = evt.target as HTMLInputElement
setPrompt(target.value) updateSettings({ prompt: target.value })
} }
const handleRepaintClick = () => { const handleRepaintClick = () => {

View File

@ -0,0 +1,435 @@
import { IconButton } from "@/components/ui/button"
import { useToggle } from "@uidotdev/usehooks"
import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog"
import { useHotkeys } from "react-hotkeys-hook"
import { Info, Settings } from "lucide-react"
import { zodResolver } from "@hookform/resolvers/zod"
import { useForm } from "react-hook-form"
import * as z from "zod"
import { Button } from "@/components/ui/button"
import { Separator } from "@/components/ui/separator"
import {
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form"
import { Input } from "@/components/ui/input"
import { Switch } from "./ui/switch"
import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
import { useState } from "react"
import { cn } from "@/lib/utils"
import { useQuery } from "@tanstack/react-query"
import { fetchModelInfos, switchModel } from "@/lib/api"
import { ModelInfo } from "@/lib/types"
import { useStore } from "@/lib/states"
import { ScrollArea } from "./ui/scroll-area"
import { useToast } from "./ui/use-toast"
import {
AlertDialog,
AlertDialogContent,
AlertDialogDescription,
AlertDialogHeader,
} from "./ui/alert-dialog"
const formSchema = z.object({
enableFileManager: z.boolean(),
inputDirectory: z.string().refine(async (id) => {
// verify that ID exists in database
return true
}),
outputDirectory: z.string().refine(async (id) => {
// verify that ID exists in database
return true
}),
enableDownloadMask: z.boolean(),
enableManualInpainting: z.boolean(),
enableUploadMask: z.boolean(),
})
const TAB_GENERAL = "General"
const TAB_MODEL = "Model"
const TAB_FILE_MANAGER = "File Manager"
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
export function SettingsDialog() {
const [open, toggleOpen] = useToggle(false)
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
const [tab, setTab] = useState(TAB_GENERAL)
const [settings, updateSettings, fileManagerState, updateFileManagerState] =
useStore((state) => [
state.settings,
state.updateSettings,
state.fileManagerState,
state.updateFileManagerState,
])
const { toast } = useToast()
const [model, setModel] = useState<ModelInfo>(settings.model)
const { data: modelInfos, isSuccess } = useQuery({
queryKey: ["modelInfos"],
queryFn: fetchModelInfos,
})
// 1. Define your form.
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
enableDownloadMask: settings.enableDownloadMask,
enableManualInpainting: settings.enableManualInpainting,
enableUploadMask: settings.enableUploadMask,
enableFileManager: fileManagerState.enabled,
inputDirectory: fileManagerState.inputDirectory,
outputDirectory: fileManagerState.outputDirectory,
},
})
function onSubmit(values: z.infer<typeof formSchema>) {
// Do something with the form values. ✅ This will be type-safe and validated.
updateSettings({
enableDownloadMask: values.enableDownloadMask,
enableManualInpainting: values.enableManualInpainting,
enableUploadMask: values.enableUploadMask,
})
// TODO: validate input/output Directory
updateFileManagerState({
enabled: values.enableFileManager,
inputDirectory: values.inputDirectory,
outputDirectory: values.outputDirectory,
})
if (model.name !== settings.model.name) {
toggleOpenModelSwitching()
switchModel(model.name)
.then((res) => {
if (res.ok) {
toast({
title: `Switch to ${model.name} success`,
})
updateSettings({ model: model })
} else {
throw new Error("Server error")
}
})
.catch(() => {
toast({
variant: "destructive",
title: `Switch to ${model.name} failed`,
})
})
.finally(() => {
toggleOpenModelSwitching()
})
}
}
useHotkeys("s", () => {
toggleOpen()
form.handleSubmit(onSubmit)()
})
function onOpenChange(value: boolean) {
toggleOpen()
if (!value) {
form.handleSubmit(onSubmit)()
}
}
function onModelSelect(info: ModelInfo) {
setModel(info)
}
function renderModelList(model_types: string[]) {
if (!modelInfos) {
return <div>Please download model first</div>
}
return modelInfos
.filter((info) => model_types.includes(info.model_type))
.map((info: ModelInfo) => {
return (
<div key={info.name} onClick={() => onModelSelect(info)}>
<div
className={cn([
info.name === model.name ? "bg-muted " : "hover:bg-muted",
"rounded-md px-2 py-1 my-1",
"cursor-default",
])}
>
<div className="text-base max-w-sm">{info.name}</div>
</div>
<Separator />
</div>
)
})
}
function renderModelSettings() {
if (!isSuccess) {
return <></>
}
let defaultTab = "inpaint"
for (let info of modelInfos) {
if (model.name === info.name) {
defaultTab = info.model_type
break
}
}
return (
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4 rounded-md">
<div>Current Model</div>
<div>{model.name}</div>
</div>
<Separator />
<div className="space-y-4 rounded-md">
<div className="flex gap-4 items-center justify-start">
<div>Available models</div>
<IconButton tooltip="How to download new model" asChild>
<Info />
</IconButton>
</div>
<Tabs defaultValue={defaultTab}>
<TabsList>
<TabsTrigger value="inpaint">Inpaint</TabsTrigger>
<TabsTrigger value="diffusers_sd">Diffusion</TabsTrigger>
<TabsTrigger value="diffusers_sd_inpaint">
Diffusion inpaint
</TabsTrigger>
<TabsTrigger value="diffusers_other">Diffusion other</TabsTrigger>
</TabsList>
<ScrollArea className="h-[240px] w-full mt-2">
<TabsContent value="inpaint">
{renderModelList(["inpaint"])}
</TabsContent>
<TabsContent value="diffusers_sd">
{renderModelList(["diffusers_sd", "diffusers_sdxl"])}
</TabsContent>
<TabsContent value="diffusers_sd_inpaint">
{renderModelList([
"diffusers_sd_inpaint",
"diffusers_sdxl_inpaint",
])}
</TabsContent>
<TabsContent value="diffusers_other">
{renderModelList(["diffusers_other"])}
</TabsContent>
</ScrollArea>
</Tabs>
</div>
</div>
)
}
function renderGeneralSettings() {
return (
<div className="space-y-4 w-[400px]">
<FormField
control={form.control}
name="enableManualInpainting"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Enable manual inpainting</FormLabel>
<FormDescription>
Click a button to trigger inpainting after draw mask.
</FormDescription>
</div>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<Separator />
<FormField
control={form.control}
name="enableDownloadMask"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Enable download mask</FormLabel>
<FormDescription>
Also download the mask after save the inpainting result.
</FormDescription>
</div>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<Separator />
<FormField
control={form.control}
name="enableUploadMask"
render={({ field }) => (
<FormItem className="flex tems-center justify-between">
<div className="space-y-0.5">
<FormLabel>Enable upload mask</FormLabel>
<FormDescription>
Enable upload custom mask to perform inpainting.
</FormDescription>
</div>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<Separator />
</div>
)
}
function renderFileManagerSettings() {
return (
<div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
<FormField
control={form.control}
name="enableFileManager"
render={({ field }) => (
<FormItem className="flex items-center justify-between gap-4">
<div className="space-y-0.5">
<FormLabel>Enable file manger</FormLabel>
<FormDescription className="max-w-sm">
Browser images
</FormDescription>
</div>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<Separator />
<FormField
control={form.control}
name="inputDirectory"
render={({ field }) => (
<FormItem>
<FormLabel>Input directory</FormLabel>
<FormControl>
<Input placeholder="" {...field} />
</FormControl>
<FormDescription>
Browser images from this directory.
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="outputDirectory"
render={({ field }) => (
<FormItem>
<FormLabel>Save directory</FormLabel>
<FormControl>
<Input placeholder="" {...field} />
</FormControl>
<FormDescription>
Result images will be saved to this directory.
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
</div>
)
}
return (
<>
<AlertDialog open={openModelSwitching}>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogDescription>
TODO: 添加加载动画 Switching to {model.name}
</AlertDialogDescription>
</AlertDialogHeader>
</AlertDialogContent>
</AlertDialog>
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogTrigger asChild>
<IconButton tooltip="Settings">
<Settings />
</IconButton>
</DialogTrigger>
<DialogContent
className="max-w-3xl h-[600px]"
// onEscapeKeyDown={(event) => event.preventDefault()}
onOpenAutoFocus={(event) => event.preventDefault()}
// onPointerDownOutside={(event) => event.preventDefault()}
>
<DialogTitle>Settings</DialogTitle>
<Separator />
<div className="flex flex-row space-x-8 h-full">
<div className="flex flex-col space-y-1">
{TAB_NAMES.map((item) => (
<Button
key={item}
variant="ghost"
onClick={() => setTab(item)}
className={cn(
tab === item ? "bg-muted " : "hover:bg-muted",
"justify-start"
)}
>
{item}
</Button>
))}
</div>
<Separator orientation="vertical" />
<Form {...form}>
<div className="flex w-full justify-center">
<form onSubmit={form.handleSubmit(onSubmit)}>
{tab === TAB_MODEL ? renderModelSettings() : <></>}
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
{/* {tab === TAB_FILE_MANAGER ? (
renderFileManagerSettings()
) : (
<></>
)} */}
{/* <div className=" absolute right-">
<Button type="submit">Ok</Button>
</div> */}
</form>
</div>
</Form>
</div>
</DialogContent>
</Dialog>
</>
)
}
export default SettingsDialog

View File

@ -1,18 +1,16 @@
import { useEffect } from "react" import { useEffect } from "react"
import { useRecoilState, useRecoilValue, useSetRecoilState } from "recoil"
import Editor from "./Editor" import Editor from "./Editor"
// import SettingModal from "./Settings/SettingsModal"
import { import {
AIModel, AIModel,
isPaintByExampleState, isPaintByExampleState,
isPix2PixState, isPix2PixState,
isSDState, isSDState,
settingState,
} from "@/lib/store" } from "@/lib/store"
import { currentModel, modelDownloaded, switchModel } from "@/lib/api" import { currentModel } from "@/lib/api"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import ImageSize from "./ImageSize" import ImageSize from "./ImageSize"
import Plugins from "./Plugins" import Plugins from "./Plugins"
import { InteractiveSeg } from "./InteractiveSeg"
// import SidePanel from "./SidePanel/SidePanel" // import SidePanel from "./SidePanel/SidePanel"
// import PESidePanel from "./SidePanel/PESidePanel" // import PESidePanel from "./SidePanel/PESidePanel"
// import P2PSidePanel from "./SidePanel/P2PSidePanel" // import P2PSidePanel from "./SidePanel/P2PSidePanel"
@ -21,73 +19,18 @@ import Plugins from "./Plugins"
// import ImageSize from "./ImageSize/ImageSize" // import ImageSize from "./ImageSize/ImageSize"
const Workspace = () => { const Workspace = () => {
const file = useStore((state) => state.file) const [file, updateSettings] = useStore((state) => [
const [settings, setSettingState] = useRecoilState(settingState) state.file,
const isSD = useRecoilValue(isSDState) state.updateSettings,
const isPaintByExample = useRecoilValue(isPaintByExampleState) ])
const isPix2Pix = useRecoilValue(isPix2PixState)
const onSettingClose = async () => {
const curModel = await currentModel().then((res) => res.text())
if (curModel === settings.model) {
return
}
const downloaded = await modelDownloaded(settings.model).then((res) =>
res.text()
)
const { model } = settings
let loadingMessage = `Switching to ${model} model`
let loadingDuration = 3000
if (downloaded === "False") {
loadingMessage = `Downloading ${model} model, this may take a while`
loadingDuration = 9999999999
}
// TODO 修改成 Modal
// setToastState({
// open: true,
// desc: loadingMessage,
// state: "loading",
// duration: loadingDuration,
// })
switchModel(model)
.then((res) => {
if (res.ok) {
// setToastState({
// open: true,
// desc: `Switch to ${model} model success`,
// state: "success",
// duration: 3000,
// })
} else {
throw new Error("Server error")
}
})
.catch(() => {
// setToastState({
// open: true,
// desc: `Switch to ${model} model failed`,
// state: "error",
// duration: 3000,
// })
setSettingState((old) => {
return { ...old, model: curModel as AIModel }
})
})
}
useEffect(() => { useEffect(() => {
currentModel() currentModel()
.then((res) => res.text()) .then((res) => res.json())
.then((model) => { .then((model) => {
setSettingState((old) => { updateSettings({ model })
return { ...old, model: model as AIModel }
})
}) })
}, [setSettingState]) }, [])
return ( return (
<> <>
@ -99,6 +42,7 @@ const Workspace = () => {
<Plugins /> <Plugins />
<ImageSize /> <ImageSize />
</div> </div>
<InteractiveSeg />
{file ? <Editor file={file} /> : <></>} {file ? <Editor file={file} /> : <></>}
</> </>
) )

View File

@ -0,0 +1,139 @@
import * as React from "react"
import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog"
import { cn } from "@/lib/utils"
import { buttonVariants } from "@/components/ui/button"
const AlertDialog = AlertDialogPrimitive.Root
const AlertDialogTrigger = AlertDialogPrimitive.Trigger
const AlertDialogPortal = AlertDialogPrimitive.Portal
const AlertDialogOverlay = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Overlay
className={cn(
"fixed inset-0 z-50 bg-background/80 backdrop-blur-sm data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
className
)}
{...props}
ref={ref}
/>
))
AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName
const AlertDialogContent = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Content>
>(({ className, ...props }, ref) => (
<AlertDialogPortal>
<AlertDialogOverlay />
<AlertDialogPrimitive.Content
ref={ref}
className={cn(
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
className
)}
{...props}
/>
</AlertDialogPortal>
))
AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName
const AlertDialogHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col space-y-2 text-center sm:text-left",
className
)}
{...props}
/>
)
AlertDialogHeader.displayName = "AlertDialogHeader"
const AlertDialogFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
className
)}
{...props}
/>
)
AlertDialogFooter.displayName = "AlertDialogFooter"
const AlertDialogTitle = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Title>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Title
ref={ref}
className={cn("text-lg font-semibold", className)}
{...props}
/>
))
AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName
const AlertDialogDescription = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Description>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Description
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
AlertDialogDescription.displayName =
AlertDialogPrimitive.Description.displayName
const AlertDialogAction = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Action>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Action>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Action
ref={ref}
className={cn(buttonVariants(), className)}
{...props}
/>
))
AlertDialogAction.displayName = AlertDialogPrimitive.Action.displayName
const AlertDialogCancel = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Cancel>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Cancel>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Cancel
ref={ref}
className={cn(
buttonVariants({ variant: "outline" }),
"mt-2 sm:mt-0",
className
)}
{...props}
/>
))
AlertDialogCancel.displayName = AlertDialogPrimitive.Cancel.displayName
export {
AlertDialog,
AlertDialogPortal,
AlertDialogOverlay,
AlertDialogTrigger,
AlertDialogContent,
AlertDialogHeader,
AlertDialogFooter,
AlertDialogTitle,
AlertDialogDescription,
AlertDialogAction,
AlertDialogCancel,
}

View File

@ -78,7 +78,7 @@ const IconButton = React.forwardRef<HTMLButtonElement, IconButtonProps>(
{...rest} {...rest}
ref={ref} ref={ref}
tabIndex={-1} tabIndex={-1}
className="cursor-default" className="cursor-default bg-background"
> >
<div className="icon-button-icon-wrapper">{children}</div> <div className="icon-button-icon-wrapper">{children}</div>
</Button> </Button>

View File

@ -87,7 +87,7 @@ const DialogTitle = React.forwardRef<
<DialogPrimitive.Title <DialogPrimitive.Title
ref={ref} ref={ref}
className={cn( className={cn(
"text-lg font-semibold leading-none tracking-tight", "text-2xl font-semibold leading-none tracking-tight",
className className
)} )}
{...props} {...props}

View File

@ -0,0 +1,176 @@
import * as React from "react"
import * as LabelPrimitive from "@radix-ui/react-label"
import { Slot } from "@radix-ui/react-slot"
import {
Controller,
ControllerProps,
FieldPath,
FieldValues,
FormProvider,
useFormContext,
} from "react-hook-form"
import { cn } from "@/lib/utils"
import { Label } from "@/components/ui/label"
const Form = FormProvider
type FormFieldContextValue<
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>
> = {
name: TName
}
const FormFieldContext = React.createContext<FormFieldContextValue>(
{} as FormFieldContextValue
)
const FormField = <
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>
>({
...props
}: ControllerProps<TFieldValues, TName>) => {
return (
<FormFieldContext.Provider value={{ name: props.name }}>
<Controller {...props} />
</FormFieldContext.Provider>
)
}
const useFormField = () => {
const fieldContext = React.useContext(FormFieldContext)
const itemContext = React.useContext(FormItemContext)
const { getFieldState, formState } = useFormContext()
const fieldState = getFieldState(fieldContext.name, formState)
if (!fieldContext) {
throw new Error("useFormField should be used within <FormField>")
}
const { id } = itemContext
return {
id,
name: fieldContext.name,
formItemId: `${id}-form-item`,
formDescriptionId: `${id}-form-item-description`,
formMessageId: `${id}-form-item-message`,
...fieldState,
}
}
type FormItemContextValue = {
id: string
}
const FormItemContext = React.createContext<FormItemContextValue>(
{} as FormItemContextValue
)
const FormItem = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => {
const id = React.useId()
return (
<FormItemContext.Provider value={{ id }}>
<div ref={ref} className={cn("space-y-2", className)} {...props} />
</FormItemContext.Provider>
)
})
FormItem.displayName = "FormItem"
const FormLabel = React.forwardRef<
React.ElementRef<typeof LabelPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root>
>(({ className, ...props }, ref) => {
const { error, formItemId } = useFormField()
return (
<Label
ref={ref}
className={cn(error && "text-destructive", "text-sm", className)}
htmlFor={formItemId}
{...props}
/>
)
})
FormLabel.displayName = "FormLabel"
const FormControl = React.forwardRef<
React.ElementRef<typeof Slot>,
React.ComponentPropsWithoutRef<typeof Slot>
>(({ ...props }, ref) => {
const { error, formItemId, formDescriptionId, formMessageId } = useFormField()
return (
<Slot
ref={ref}
id={formItemId}
aria-describedby={
!error
? `${formDescriptionId}`
: `${formDescriptionId} ${formMessageId}`
}
aria-invalid={!!error}
{...props}
/>
)
})
FormControl.displayName = "FormControl"
const FormDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => {
const { formDescriptionId } = useFormField()
return (
<p
ref={ref}
id={formDescriptionId}
className={cn("text-[0.8rem] text-muted-foreground", className)}
{...props}
/>
)
})
FormDescription.displayName = "FormDescription"
const FormMessage = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, children, ...props }, ref) => {
const { error, formMessageId } = useFormField()
const body = error ? String(error?.message) : children
if (!body) {
return null
}
return (
<p
ref={ref}
id={formMessageId}
className={cn("text-[0.8rem] font-medium text-destructive", className)}
{...props}
>
{body}
</p>
)
})
FormMessage.displayName = "FormMessage"
export {
useFormField,
Form,
FormItem,
FormLabel,
FormControl,
FormDescription,
FormMessage,
FormField,
}

View File

@ -0,0 +1,29 @@
import * as React from "react"
import * as SeparatorPrimitive from "@radix-ui/react-separator"
import { cn } from "@/lib/utils"
const Separator = React.forwardRef<
React.ElementRef<typeof SeparatorPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SeparatorPrimitive.Root>
>(
(
{ className, orientation = "horizontal", decorative = true, ...props },
ref
) => (
<SeparatorPrimitive.Root
ref={ref}
decorative={decorative}
orientation={orientation}
className={cn(
"shrink-0 bg-border",
orientation === "horizontal" ? "h-[1px] w-full" : "h-full w-[1px]",
className
)}
{...props}
/>
)
)
Separator.displayName = SeparatorPrimitive.Root.displayName
export { Separator }

View File

@ -18,7 +18,10 @@ const Slider = React.forwardRef<
<SliderPrimitive.Track className="relative h-1.5 w-full grow overflow-hidden rounded-full bg-primary/20"> <SliderPrimitive.Track className="relative h-1.5 w-full grow overflow-hidden rounded-full bg-primary/20">
<SliderPrimitive.Range className="absolute h-full bg-primary" /> <SliderPrimitive.Range className="absolute h-full bg-primary" />
</SliderPrimitive.Track> </SliderPrimitive.Track>
<SliderPrimitive.Thumb className="block h-4 w-4 rounded-full border border-primary/50 bg-background shadow transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50" /> <SliderPrimitive.Thumb
tabIndex={-1}
className="block h-4 w-4 rounded-full border border-primary/60 bg-background shadow transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50"
/>
</SliderPrimitive.Root> </SliderPrimitive.Root>
)) ))
Slider.displayName = SliderPrimitive.Root.displayName Slider.displayName = SliderPrimitive.Root.displayName

View File

@ -42,7 +42,7 @@
--popover: 0 0% 100%; --popover: 0 0% 100%;
--popover-foreground: 224 71.4% 4.1%; --popover-foreground: 224 71.4% 4.1%;
--primary: 220.9 39.3% 11%; --primary: 48 100.0% 50.0%;
--primary-foreground: 210 20% 98%; --primary-foreground: 210 20% 98%;
--secondary: 220 14.3% 95.9%; --secondary: 220 14.3% 95.9%;
@ -74,7 +74,7 @@
--popover: 224 71.4% 4.1%; --popover: 224 71.4% 4.1%;
--popover-foreground: 210 20% 98%; --popover-foreground: 210 20% 98%;
--primary: 210 20% 98%; --primary: 48 100.0% 50.0%;
--primary-foreground: 220.9 39.3% 11%; --primary-foreground: 220.9 39.3% 11%;
--secondary: 215 27.9% 16.9%; --secondary: 215 27.9% 16.9%;

View File

@ -1,18 +1,20 @@
import { PluginName } from "@/lib/types" import { ModelInfo, Rect } from "@/lib/types"
import { ControlNetMethodMap, Rect, Settings } from "@/lib/store" import { Settings } from "@/lib/states"
import { dataURItoBlob, loadImage, srcToFile } from "@/lib/utils" import { dataURItoBlob, srcToFile } from "@/lib/utils"
import axios from "axios"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND ? import.meta.env.VITE_BACKEND
: "" : ""
const api = axios.create({
baseURL: API_ENDPOINT,
})
export default async function inpaint( export default async function inpaint(
imageFile: File, imageFile: File,
settings: Settings, settings: Settings,
croperRect: Rect, croperRect: Rect,
prompt?: string,
negativePrompt?: string,
seed?: number,
maskBase64?: string, maskBase64?: string,
customMask?: File, customMask?: File,
paintByExampleImage?: File paintByExampleImage?: File
@ -26,38 +28,29 @@ export default async function inpaint(
fd.append("mask", customMask) fd.append("mask", customMask)
} }
const hdSettings = settings.hdSettings[settings.model]
fd.append("ldmSteps", settings.ldmSteps.toString()) fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString()) fd.append("ldmSampler", settings.ldmSampler.toString())
fd.append("zitsWireframe", settings.zitsWireframe.toString()) fd.append("zitsWireframe", settings.zitsWireframe.toString())
fd.append("hdStrategy", hdSettings.hdStrategy) fd.append("hdStrategy", "Crop")
fd.append("hdStrategyCropMargin", hdSettings.hdStrategyCropMargin.toString()) fd.append("hdStrategyCropMargin", "128")
fd.append( fd.append("hdStrategyCropTrigerSize", "640")
"hdStrategyCropTrigerSize", fd.append("hdStrategyResizeLimit", "2048")
hdSettings.hdStrategyCropTrigerSize.toString()
)
fd.append(
"hdStrategyResizeLimit",
hdSettings.hdStrategyResizeLimit.toString()
)
fd.append("prompt", prompt === undefined ? "" : prompt) fd.append("prompt", settings.prompt)
fd.append( fd.append("negativePrompt", settings.negativePrompt)
"negativePrompt",
negativePrompt === undefined ? "" : negativePrompt
)
fd.append("croperX", croperRect.x.toString()) fd.append("croperX", croperRect.x.toString())
fd.append("croperY", croperRect.y.toString()) fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString()) fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString()) fd.append("croperWidth", croperRect.width.toString())
fd.append("useCroper", settings.showCroper ? "true" : "false") // fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("useCroper", "false")
fd.append("sdMaskBlur", settings.sdMaskBlur.toString()) fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString()) fd.append("sdStrength", settings.sdStrength.toString())
fd.append("sdSteps", settings.sdSteps.toString()) fd.append("sdSteps", settings.sdSteps.toString())
fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString()) fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString())
fd.append("sdSampler", settings.sdSampler.toString()) fd.append("sdSampler", settings.sdSampler.toString())
fd.append("sdSeed", seed ? seed.toString() : "-1") fd.append("sdSeed", settings.seed.toString())
fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false") fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false")
fd.append("sdScale", (settings.sdScale / 100).toString()) fd.append("sdScale", (settings.sdScale / 100).toString())
@ -69,7 +62,7 @@ export default async function inpaint(
"paintByExampleGuidanceScale", "paintByExampleGuidanceScale",
settings.paintByExampleGuidanceScale.toString() settings.paintByExampleGuidanceScale.toString()
) )
fd.append("paintByExampleSeed", seed ? seed.toString() : "-1") fd.append("paintByExampleSeed", settings.seed.toString())
fd.append( fd.append(
"paintByExampleMaskBlur", "paintByExampleMaskBlur",
settings.paintByExampleMaskBlur.toString() settings.paintByExampleMaskBlur.toString()
@ -94,10 +87,7 @@ export default async function inpaint(
"controlnet_conditioning_scale", "controlnet_conditioning_scale",
settings.controlnetConditioningScale.toString() settings.controlnetConditioningScale.toString()
) )
fd.append( fd.append("controlnet_method", settings.controlnetMethod.toString())
"controlnet_method",
ControlNetMethodMap[settings.controlnetMethod.toString()]
)
try { try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, { const res = await fetch(`${API_ENDPOINT}/inpaint`, {
@ -137,6 +127,10 @@ export function currentModel() {
}) })
} }
export function fetchModelInfos(): Promise<ModelInfo[]> {
return api.get("/models").then((response) => response.data)
}
export function isDesktop() { export function isDesktop() {
return fetch(`${API_ENDPOINT}/is_desktop`, { return fetch(`${API_ENDPOINT}/is_desktop`, {
method: "GET", method: "GET",

View File

@ -1,14 +1,17 @@
import { create, StoreApi, UseBoundStore } from "zustand" import { create } from "zustand"
import { persist } from "zustand/middleware" import { persist } from "zustand/middleware"
import { immer } from "zustand/middleware/immer" import { immer } from "zustand/middleware/immer"
import { SortBy, SortOrder } from "./types" import { CV2Flag, LDMSampler, ModelInfo, SortBy, SortOrder } from "./types"
import { DEFAULT_BRUSH_SIZE } from "./const" import { DEFAULT_BRUSH_SIZE } from "./const"
import { SDSampler } from "./store"
type FileManagerState = { type FileManagerState = {
sortBy: SortBy sortBy: SortBy
sortOrder: SortOrder sortOrder: SortOrder
layout: "rows" | "masonry" layout: "rows" | "masonry"
searchText: string searchText: string
inputDirectory: string
outputDirectory: string
} }
type CropperState = { type CropperState = {
@ -18,81 +21,256 @@ type CropperState = {
height: number height: number
} }
export type Settings = {
model: ModelInfo
enableDownloadMask: boolean
enableManualInpainting: boolean
enableUploadMask: boolean
showCroper: boolean
// For LDM
ldmSteps: number
ldmSampler: LDMSampler
// For ZITS
zitsWireframe: boolean
// For OpenCV2
cv2Radius: number
cv2Flag: CV2Flag
// For Diffusion moel
prompt: string
negativePrompt: string
seed: number
seedFixed: boolean
// For SD
sdMaskBlur: number
sdStrength: number
sdSteps: number
sdGuidanceScale: number
sdSampler: SDSampler
sdMatchHistograms: boolean
sdScale: number
// Paint by Example
paintByExampleSteps: number
paintByExampleGuidanceScale: number
paintByExampleMaskBlur: number
paintByExampleMatchHistograms: boolean
// InstructPix2Pix
p2pSteps: number
p2pImageGuidanceScale: number
p2pGuidanceScale: number
// ControlNet
controlnetConditioningScale: number
controlnetMethod: string
}
type ServerConfig = {
plugins: string[]
availableControlNet: Record<string, string[]>
enableFileManager: boolean
enableAutoSaving: boolean
}
type InteractiveSegState = {
isInteractiveSeg: boolean
isInteractiveSegRunning: boolean
clicks: number[][]
}
type AppState = { type AppState = {
file: File | null file: File | null
customMask: File | null
imageHeight: number imageHeight: number
imageWidth: number imageWidth: number
brushSize: number brushSize: number
brushSizeScale: number brushSizeScale: number
isInpainting: boolean isInpainting: boolean
isInteractiveSeg: boolean // 是否正处于 sam 状态 isPluginRunning: boolean
isInteractiveSegRunning: boolean
interactiveSegClicks: number[][]
prompt: string
interactiveSegState: InteractiveSegState
fileManagerState: FileManagerState fileManagerState: FileManagerState
cropperState: CropperState cropperState: CropperState
serverConfig: ServerConfig
settings: Settings
} }
type AppAction = { type AppAction = {
setFile: (file: File) => void setFile: (file: File) => void
setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
setBrushSize: (newValue: number) => void setBrushSize: (newValue: number) => void
setImageSize: (width: number, height: number) => void setImageSize: (width: number, height: number) => void
setPrompt: (newValue: string) => void
setFileManagerSortBy: (newValue: SortBy) => void
setFileManagerSortOrder: (newValue: SortOrder) => void
setFileManagerLayout: (
newValue: AppState["fileManagerState"]["layout"]
) => void
setFileManagerSearchText: (newValue: string) => void
setCropperX: (newValue: number) => void setCropperX: (newValue: number) => void
setCropperY: (newValue: number) => void setCropperY: (newValue: number) => void
setCropperWidth: (newValue: number) => void setCropperWidth: (newValue: number) => void
setCropperHeight: (newValue: number) => void setCropperHeight: (newValue: number) => void
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
updateSettings: (newSettings: Partial<Settings>) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void
shouldShowPromptInput: () => boolean
}
const defaultValues: AppState = {
file: null,
customMask: null,
imageHeight: 0,
imageWidth: 0,
brushSize: DEFAULT_BRUSH_SIZE,
brushSizeScale: 1,
isInpainting: false,
isPluginRunning: false,
interactiveSegState: {
isInteractiveSeg: false,
isInteractiveSegRunning: false,
clicks: [],
},
cropperState: {
x: 0,
y: 0,
width: 512,
height: 512,
},
fileManagerState: {
sortBy: SortBy.CTIME,
sortOrder: SortOrder.DESCENDING,
layout: "masonry",
searchText: "",
inputDirectory: "",
outputDirectory: "",
},
serverConfig: {
plugins: [],
availableControlNet: { SD: [], SD2: [], SDXL: [] },
enableFileManager: false,
enableAutoSaving: false,
},
settings: {
model: {
name: "lama",
path: "lama",
model_type: "inpaint",
support_controlnet: false,
support_freeu: false,
support_lcm_lora: false,
is_single_file_diffusers: false,
},
showCroper: false,
enableDownloadMask: false,
enableManualInpainting: false,
enableUploadMask: false,
ldmSteps: 30,
ldmSampler: LDMSampler.ddim,
zitsWireframe: true,
cv2Radius: 5,
cv2Flag: CV2Flag.INPAINT_NS,
prompt: "",
negativePrompt: "",
seed: 42,
seedFixed: false,
sdMaskBlur: 5,
sdStrength: 1.0,
sdSteps: 50,
sdGuidanceScale: 7.5,
sdSampler: SDSampler.uni_pc,
sdMatchHistograms: false,
sdScale: 100,
paintByExampleSteps: 50,
paintByExampleGuidanceScale: 7.5,
paintByExampleMaskBlur: 5,
paintByExampleMatchHistograms: false,
p2pSteps: 50,
p2pImageGuidanceScale: 1.5,
p2pGuidanceScale: 7.5,
controlnetConditioningScale: 0.4,
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
},
} }
export const useStore = create<AppState & AppAction>()( export const useStore = create<AppState & AppAction>()(
immer( immer(
persist( persist(
(set, get) => ({ (set, get) => ({
file: null, ...defaultValues,
imageHeight: 0,
imageWidth: 0, shouldShowPromptInput: (): boolean => {
brushSize: DEFAULT_BRUSH_SIZE, const model_type = get().settings.model.model_type
brushSizeScale: 1, return ["diffusers_sd"].includes(model_type)
isInpainting: false,
isInteractiveSeg: false,
isInteractiveSegRunning: false,
interactiveSegClicks: [],
prompt: "",
cropperState: {
x: 0,
y: 0,
width: 0,
height: 0,
}, },
fileManagerState: {
sortBy: SortBy.CTIME, setServerConfig: (newValue: ServerConfig) => {
sortOrder: SortOrder.DESCENDING, set((state: AppState) => {
layout: "masonry", state.serverConfig = newValue
searchText: "", })
}, },
updateSettings: (newSettings: Partial<Settings>) => {
set((state: AppState) => {
state.settings = {
...state.settings,
...newSettings,
}
})
},
updateFileManagerState: (newState: Partial<FileManagerState>) => {
set((state: AppState) => {
state.fileManagerState = {
...state.fileManagerState,
...newState,
}
})
},
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
set((state: AppState) => {
state.interactiveSegState = {
...state.interactiveSegState,
...newState,
}
})
},
resetInteractiveSegState: () => {
set((state: AppState) => {
state.interactiveSegState = defaultValues.interactiveSegState
})
},
setIsInpainting: (newValue: boolean) => setIsInpainting: (newValue: boolean) =>
set((state: AppState) => { set((state: AppState) => {
state.isInpainting = newValue state.isInpainting = newValue
}), }),
setIsPluginRunning: (newValue: boolean) =>
set((state: AppState) => {
state.isPluginRunning = newValue
}),
setFile: (file: File) => setFile: (file: File) =>
set((state: AppState) => { set((state: AppState) => {
// TODO: 清空各种状态 // TODO: 清空各种状态
state.file = file state.file = file
}), }),
setCustomFile: (file: File) =>
set((state: AppState) => {
state.customMask = file
}),
setBrushSize: (newValue: number) => setBrushSize: (newValue: number) =>
set((state: AppState) => { set((state: AppState) => {
state.brushSize = newValue state.brushSize = newValue
@ -107,11 +285,6 @@ export const useStore = create<AppState & AppAction>()(
}) })
}, },
setPrompt: (newValue: string) =>
set((state: AppState) => {
state.prompt = newValue
}),
setCropperX: (newValue: number) => setCropperX: (newValue: number) =>
set((state: AppState) => { set((state: AppState) => {
state.cropperState.x = newValue state.cropperState.x = newValue
@ -132,32 +305,18 @@ export const useStore = create<AppState & AppAction>()(
state.cropperState.height = newValue state.cropperState.height = newValue
}), }),
setFileManagerSortBy: (newValue: SortBy) => setSeed: (newValue: number) =>
set((state: AppState) => { set((state: AppState) => {
state.fileManagerState.sortBy = newValue state.settings.seed = newValue
}),
setFileManagerSortOrder: (newValue: SortOrder) =>
set((state: AppState) => {
state.fileManagerState.sortOrder = newValue
}),
setFileManagerLayout: (newValue: "rows" | "masonry") =>
set((state: AppState) => {
state.fileManagerState.layout = newValue
}),
setFileManagerSearchText: (newValue: string) =>
set((state: AppState) => {
state.fileManagerState.searchText = newValue
}), }),
}), }),
{ {
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique) name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
version: 0,
partialize: (state) => partialize: (state) =>
Object.fromEntries( Object.fromEntries(
Object.entries(state).filter(([key]) => Object.entries(state).filter(([key]) =>
["fileManagerState", "prompt"].includes(key) ["fileManagerState", "prompt", "settings"].includes(key)
) )
), ),
} }

View File

@ -1,6 +1,6 @@
import { atom, selector } from "recoil" import { atom, selector } from "recoil"
import _ from "lodash" import _ from "lodash"
import { CV2Flag, HDStrategy, LDMSampler, ModelsHDSettings } from "./types" import { CV2Flag, LDMSampler } from "./types"
export enum AIModel { export enum AIModel {
LAMA = "lama", LAMA = "lama",
@ -320,7 +320,6 @@ export interface Settings {
graduallyInpainting: boolean graduallyInpainting: boolean
runInpaintingManually: boolean runInpaintingManually: boolean
model: AIModel model: AIModel
hdSettings: ModelsHDSettings
// For LDM // For LDM
ldmSteps: number ldmSteps: number
@ -363,107 +362,6 @@ export interface Settings {
controlnetMethod: string controlnetMethod: string
} }
const defaultHDSettings: ModelsHDSettings = {
[AIModel.LAMA]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 800,
hdStrategyCropMargin: 196,
enabled: true,
},
[AIModel.LDM]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1080,
hdStrategyCropTrigerSize: 1080,
hdStrategyCropMargin: 128,
enabled: true,
},
[AIModel.ZITS]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1024,
hdStrategyCropTrigerSize: 1024,
hdStrategyCropMargin: 128,
enabled: true,
},
[AIModel.MAT]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1024,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: true,
},
[AIModel.FCF]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 512,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.SD15]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.ANYTHING4]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.REALISTIC_VISION_1_4]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.SD2]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.PAINT_BY_EXAMPLE]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.PIX2PIX]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.Mange]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1280,
hdStrategyCropTrigerSize: 1024,
hdStrategyCropMargin: 196,
enabled: true,
},
[AIModel.CV2]: {
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 1080,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: true,
},
[AIModel.KANDINSKY22]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
}
export enum SDSampler { export enum SDSampler {
ddim = "ddim", ddim = "ddim",
pndm = "pndm", pndm = "pndm",
@ -487,7 +385,6 @@ export const settingStateDefault: Settings = {
graduallyInpainting: true, graduallyInpainting: true,
runInpaintingManually: false, runInpaintingManually: false,
model: AIModel.LAMA, model: AIModel.LAMA,
hdSettings: defaultHDSettings,
ldmSteps: 25, ldmSteps: 25,
ldmSampler: LDMSampler.plms, ldmSampler: LDMSampler.plms,
@ -588,24 +485,6 @@ export const seedState = selector({
}, },
}) })
export const hdSettingsState = selector({
key: "hdSettings",
get: ({ get }) => {
const settings = get(settingState)
return settings.hdSettings[settings.model]
},
set: ({ get, set }, newValue: any) => {
const settings = get(settingState)
const hdSettings = settings.hdSettings[settings.model]
const newHDSettings = { ...hdSettings, ...newValue }
set(settingState, {
...settings,
hdSettings: { ...settings.hdSettings, [settings.model]: newHDSettings },
})
},
})
export const isSDState = selector({ export const isSDState = selector({
key: "isSD", key: "isSD",
get: ({ get }) => { get: ({ get }) => {

View File

@ -1,3 +1,19 @@
export interface ModelInfo {
name: string
path: string
model_type:
| "inpaint"
| "diffusers_sd"
| "diffusers_sdxl"
| "diffusers_sd_inpaint"
| "diffusers_sdxl_inpaint"
| "diffusers_other"
support_controlnet: boolean
support_freeu: boolean
support_lcm_lora: boolean
is_single_file_diffusers: boolean
}
export enum PluginName { export enum PluginName {
RemoveBG = "RemoveBG", RemoveBG = "RemoveBG",
AnimeSeg = "AnimeSeg", AnimeSeg = "AnimeSeg",
@ -18,12 +34,6 @@ export enum SortOrder {
ASCENDING = "asc", ASCENDING = "asc",
} }
export enum HDStrategy {
ORIGINAL = "Original",
RESIZE = "Resize",
CROP = "Crop",
}
export enum LDMSampler { export enum LDMSampler {
ddim = "ddim", ddim = "ddim",
plms = "plms", plms = "plms",
@ -34,12 +44,9 @@ export enum CV2Flag {
INPAINT_TELEA = "INPAINT_TELEA", INPAINT_TELEA = "INPAINT_TELEA",
} }
export interface HDSettings { export interface Rect {
hdStrategy: HDStrategy x: number
hdStrategyResizeLimit: number y: number
hdStrategyCropTrigerSize: number width: number
hdStrategyCropMargin: number height: number
enabled: boolean
} }
export type ModelsHDSettings = { [key in AIModel]: HDSettings }

View File

@ -1,16 +1,18 @@
import React from "react" import React from "react"
import ReactDOM from "react-dom/client" import ReactDOM from "react-dom/client"
import { RecoilRoot } from "recoil" import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import App from "./App.tsx" import App from "./App.tsx"
import "./globals.css" import "./globals.css"
import { ThemeProvider } from "./components/theme-provider.tsx" import { ThemeProvider } from "./components/theme-provider.tsx"
const queryClient = new QueryClient()
ReactDOM.createRoot(document.getElementById("root")!).render( ReactDOM.createRoot(document.getElementById("root")!).render(
<React.StrictMode> <React.StrictMode>
<ThemeProvider defaultTheme="dark" storageKey="vite-ui-theme"> <QueryClientProvider client={queryClient}>
<RecoilRoot> <ThemeProvider defaultTheme="dark" storageKey="vite-ui-theme">
<App /> <App />
</RecoilRoot> </ThemeProvider>
</ThemeProvider> </QueryClientProvider>
</React.StrictMode> </React.StrictMode>
) )

View File

@ -2,11 +2,11 @@
module.exports = { module.exports = {
darkMode: ["class"], darkMode: ["class"],
content: [ content: [
'./pages/**/*.{ts,tsx}', "./pages/**/*.{ts,tsx}",
'./components/**/*.{ts,tsx}', "./components/**/*.{ts,tsx}",
'./app/**/*.{ts,tsx}', "./app/**/*.{ts,tsx}",
'./src/**/*.{ts,tsx}', "./src/**/*.{ts,tsx}",
], ],
theme: { theme: {
container: { container: {
center: true, center: true,
@ -73,4 +73,4 @@ module.exports = {
}, },
}, },
plugins: [require("tailwindcss-animate")], plugins: [require("tailwindcss-animate")],
} }