commit 371db2d771 (parent 0e5e16ba20)
Qing 2023-12-24 15:32:27 +08:00
31 changed files with 441 additions and 439 deletions

View File

@@ -10,11 +10,8 @@ from lama_cleaner.parse_args import parse_args

 def entry_point():
-    args = parse_args()
-    if args is None:
-        return
     # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
     # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
-    from lama_cleaner.server import main
-    main(args)
+    from lama_cleaner.server import typer_app
+    typer_app()
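The console entry point no longer builds an argparse namespace and hands it to server.main; it just invokes the Typer application defined in lama_cleaner.server (typer_app, added later in this commit). A minimal sketch of the new wiring, with a hypothetical stand-in command since the real sub-commands live in server.py:

import typer

typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)

@typer_app.command()
def start(host: str = "127.0.0.1", port: int = 8080):
    # Stand-in for lama_cleaner.server.start; Typer parses sys.argv itself.
    print(f"would listen on {host}:{port}")

def entry_point():
    typer_app()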

View File

@@ -103,6 +103,5 @@ if __name__ == "__main__":
         device=device,
         disable_nsfw=True,
         sd_cpu_textencoder=True,
-        hf_access_token="123"
     )
     benchmark(model, args.times, args.empty_cache)

View File

@@ -21,19 +21,17 @@ AVAILABLE_MODELS = [
     "zits",
     "mat",
     "fcf",
-    "manga",
-    "cv2",
     "sd1.5",
-    "sdxl",
     "anything4",
     "realisticVision1.4",
+    "cv2",
+    "manga",
     "sd2",
+    "sdxl",
     "paint_by_example",
     "instruct_pix2pix",
     "kandinsky2.2",
-    "sdxl",
 ]
-SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
 DIFFUSERS_MODEL_FP16_REVERSION = [
     "runwayml/stable-diffusion-inpainting",
     "Sanster/anything-4.0-inpainting",
@@ -46,26 +44,22 @@ AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
 DEFAULT_DEVICE = "cuda"
 NO_HALF_HELP = """
-Using full precision model.
-If your generate result is always black or green, use this argument. (sd/paint_by_exmaple)
+Using full precision(fp32) model.
+If your diffusion model generate result is always black or green, use this argument.
 """
 CPU_OFFLOAD_HELP = """
-Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
+Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage.
 """
 DISABLE_NSFW_HELP = """
-Disable NSFW checker. (sd/paint_by_example)
+Disable NSFW checker for diffusion model.
 """
-SD_CPU_TEXTENCODER_HELP = """
-Run Stable Diffusion text encoder model on CPU to save GPU memory.
+CPU_TEXTENCODER_HELP = """
+Run diffusion models text encoder on CPU to reduce vRAM usage.
 """
-SD_CONTROLNET_HELP = """
-Run Stable Diffusion normal or inpainting model with ControlNet.
-"""
-DEFAULT_SD_CONTROLNET_METHOD = "lllyasviel/control_v11p_sd15_canny"
 SD_CONTROLNET_CHOICES = [
     "lllyasviel/control_v11p_sd15_canny",
     # "lllyasviel/control_v11p_sd15_seg",
@@ -74,46 +68,36 @@ SD_CONTROLNET_CHOICES = [
     "lllyasviel/control_v11f1p_sd15_depth",
 ]
-DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers"
 SD2_CONTROLNET_CHOICES = [
     "thibaud/controlnet-sd21-canny-diffusers",
     "thibaud/controlnet-sd21-depth-diffusers",
     "thibaud/controlnet-sd21-openpose-diffusers",
 ]
-DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
 SDXL_CONTROLNET_CHOICES = [
     "thibaud/controlnet-openpose-sdxl-1.0",
-    "destitech/controlnet-inpaint-dreamer-sdxl"
+    "destitech/controlnet-inpaint-dreamer-sdxl",
     "diffusers/controlnet-canny-sdxl-1.0",
     "diffusers/controlnet-canny-sdxl-1.0-mid",
-    "diffusers/controlnet-canny-sdxl-1.0-small"
+    "diffusers/controlnet-canny-sdxl-1.0-small",
     "diffusers/controlnet-depth-sdxl-1.0",
     "diffusers/controlnet-depth-sdxl-1.0-mid",
     "diffusers/controlnet-depth-sdxl-1.0-small",
 ]
-SD_LOCAL_MODEL_HELP = """
-Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
-"""
 LOCAL_FILES_ONLY_HELP = """
-Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
-"""
-ENABLE_XFORMERS_HELP = """
-Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
+When loading diffusion models, using local files only, not connect to HuggingFace server.
 """
 DEFAULT_MODEL_DIR = os.getenv(
     "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
 )
-MODEL_DIR_HELP = """
-Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
+MODEL_DIR_HELP = f"""
+Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR}
 """
 OUTPUT_DIR_HELP = """
-Result images will be saved to output directory automatically without confirmation.
+Result images will be saved to output directory automatically.
 """
 INPUT_HELP = """
INPUT_HELP = """ INPUT_HELP = """
@@ -125,37 +109,45 @@ GUI_HELP = """
 Launch Lama Cleaner as desktop app
 """
-NO_GUI_AUTO_CLOSE_HELP = """
-Prevent backend auto close after the GUI window closed.
-"""
 QUALITY_HELP = """
 Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
 """

-class RealESRGANModelName(str, Enum):
+class Choices(str, Enum):
+    @classmethod
+    def values(cls):
+        return [member.value for member in cls]
+
+
+class RealESRGANModel(Choices):
     realesr_general_x4v3 = "realesr-general-x4v3"
     RealESRGAN_x4plus = "RealESRGAN_x4plus"
     RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"

-RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
+class Device(Choices):
+    cpu = "cpu"
+    cuda = "cuda"
+    mps = "mps"
+
+
+class InteractiveSegModel(Choices):
+    vit_b = "vit_b"
+    vit_l = "vit_l"
+    vit_h = "vit_h"
+    mobile_sam = "mobile_sam"

 INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
 INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
-AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h", "vit_t"]
-AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
 REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
 ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
 REALESRGAN_HELP = "Enable realesrgan super resolution"
-REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
 GFPGAN_HELP = (
     "Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
 )
-GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
 RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
-RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
 GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
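Choices combines str and Enum so each member is its own string value: members compare equal to plain strings (existing call sites keep working) and the subclasses can be handed to Typer directly as CLI choice types. A standalone sketch of the behavior, mirroring the Device subclass defined above:

from enum import Enum

class Choices(str, Enum):
    @classmethod
    def values(cls):
        return [member.value for member in cls]

class Device(Choices):
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"

# str subclassing means members compare equal to raw strings,
# so code that passes "cuda" around keeps working unchanged:
assert Device.cuda == "cuda"
assert Device.values() == ["cpu", "cuda", "mps"]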
@@ -164,8 +156,6 @@ class Config(BaseModel):
     port: int = 8080
     model: str = DEFAULT_MODEL
     sd_local_model_path: str = None
-    sd_controlnet: bool = False
-    sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD
     device: str = DEFAULT_DEVICE
     gui: bool = False
     no_gui_auto_close: bool = False
@@ -173,7 +163,6 @@ class Config(BaseModel):
     cpu_offload: bool = False
     disable_nsfw: bool = False
     sd_cpu_textencoder: bool = False
-    enable_xformers: bool = False
     local_files_only: bool = False
     model_dir: str = DEFAULT_MODEL_DIR
     input: str = None
@@ -186,7 +175,7 @@ class Config(BaseModel):
     enable_anime_seg: bool = False
     enable_realesrgan: bool = False
     realesrgan_device: str = "cpu"
-    realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
+    realesrgan_model: str = RealESRGANModel.realesr_general_x4v3.value
     realesrgan_no_half: bool = False
     enable_gfpgan: bool = False
     gfpgan_device: str = "cpu"

View File

@@ -6,6 +6,7 @@ from loguru import logger
 from pathlib import Path

 from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
+from lama_cleaner.runtime import setup_model_dir
 from lama_cleaner.schema import (
     ModelInfo,
     ModelType,
@@ -16,16 +17,8 @@ from lama_cleaner.schema import (
 )

-def cli_download_model(model: str, model_dir: str):
-    if os.path.isfile(model_dir):
-        raise ValueError(f"invalid --model-dir: {model_dir} is a file")
-
-    if not os.path.exists(model_dir):
-        logger.info(f"Create model cache directory: {model_dir}")
-        Path(model_dir).mkdir(exist_ok=True, parents=True)
-
-    os.environ["XDG_CACHE_HOME"] = model_dir
+def cli_download_model(model: str, model_dir: Path):
+    setup_model_dir(model_dir)

     from lama_cleaner.model import models

     if model in models:
@@ -38,7 +31,7 @@ def cli_download_model(model: str, model_dir: str):
         downloaded_path = DiffusionPipeline.download(
             pretrained_model_name=model,
-            revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
+            variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
             resume_download=True,
         )
         logger.info(f"Done. Downloaded to {downloaded_path}")
@@ -101,7 +94,7 @@ def scan_inpaint_models() -> List[ModelInfo]:
     from lama_cleaner.model import models

     for name, m in models.items():
-        if m.is_erase_model:
+        if m.is_erase_model and m.is_downloaded():
             res.append(
                 ModelInfo(
                     name=name,

View File

@@ -41,7 +41,7 @@ class InpaintModel:
     @staticmethod
     @abc.abstractmethod
     def is_downloaded() -> bool:
-        ...
+        return False

     @abc.abstractmethod
     def forward(self, image, mask, config: Config):
@@ -67,6 +67,8 @@ class InpaintModel:
             logger.info(f"final forward pad size: {pad_image.shape}")

+        image, mask = self.forward_pre_process(image, mask, config)
+
         result = self.forward(pad_image, pad_mask, config)
         result = result[0:origin_height, 0:origin_width, :]
@@ -77,6 +79,9 @@ class InpaintModel:
         result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
         return result

+    def forward_pre_process(self, image, mask, config):
+        return image, mask
+
     def forward_post_process(self, result, image, mask, config):
         return result, image, mask
@@ -400,6 +405,13 @@ class DiffusionInpaintModel(InpaintModel):
         scheduler = get_scheduler(sd_sampler, scheduler_config)
         self.model.scheduler = scheduler

+    def forward_pre_process(self, image, mask, config):
+        if config.sd_mask_blur != 0:
+            k = 2 * config.sd_mask_blur + 1
+            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
+
+        return image, mask
+
     def forward_post_process(self, result, image, mask, config):
         if config.sd_match_histograms:
             result = self._match_histograms(result, image[:, :, ::-1], mask)
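The per-model mask-blur blocks deleted from controlnet.py, kandinsky.py, sd.py, and sdxl.py below all move into this single forward_pre_process hook. The expression k = 2 * sd_mask_blur + 1 turns the user-facing blur radius into the odd kernel size that cv2.GaussianBlur requires. A standalone illustration, assuming an 8-bit single-channel mask:

import cv2
import numpy as np

sd_mask_blur = 4                     # blur radius from Config
mask = np.zeros((64, 64), np.uint8)  # hypothetical hard-edged mask
mask[16:48, 16:48] = 255

k = 2 * sd_mask_blur + 1             # 9: ksize must be odd and positive
soft = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
print(soft.shape)                    # (64, 64, 1), channel axis kept for later broadcasting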

View File

@@ -17,14 +17,6 @@ from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
 from lama_cleaner.model.utils import get_scheduler
 from lama_cleaner.schema import Config, ModelInfo, ModelType

-# For backward compatibility
-controlnet_name_map = {
-    "control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny",
-    "control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose",
-    "control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint",
-    "control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
-}

 class ControlNet(DiffusionInpaintModel):
     name = "controlnet"
@@ -49,9 +41,6 @@ class ControlNet(DiffusionInpaintModel):
         fp16 = not kwargs.get("no_half", False)
         model_info: ModelInfo = kwargs["model_info"]
         sd_controlnet_method = kwargs["sd_controlnet_method"]
-        sd_controlnet_method = controlnet_name_map.get(
-            sd_controlnet_method, sd_controlnet_method
-        )

         self.model_info = model_info
         self.sd_controlnet_method = sd_controlnet_method
@@ -113,12 +102,6 @@ class ControlNet(DiffusionInpaintModel):
             **model_kwargs,
         )

-        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
-        self.model.enable_attention_slicing()
-        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
-        if kwargs.get("enable_xformers", False):
-            self.model.enable_xformers_memory_efficient_attention()
-
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -162,10 +145,6 @@ class ControlNet(DiffusionInpaintModel):
         scheduler = get_scheduler(config.sd_sampler, scheduler_config)
         self.model.scheduler = scheduler

-        if config.sd_mask_blur != 0:
-            k = 2 * config.sd_mask_blur + 1
-            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
-
         img_h, img_w = image.shape[:2]
         control_image = self._get_control_image(image, mask)
         mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
@@ -190,8 +169,3 @@ class ControlNet(DiffusionInpaintModel):
         output = (output * 255).round().astype("uint8")
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
-
-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True

View File

@@ -31,30 +31,15 @@ class InstructPix2Pix(DiffusionInpaintModel):
         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
         self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
-            "timbrooks/instruct-pix2pix",
-            revision="fp16" if use_gpu and fp16 else "main",
-            torch_dtype=torch_dtype,
-            **model_kwargs
+            self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
         )
-        self.model.enable_attention_slicing()
-        if kwargs.get("enable_xformers", False):
-            self.model.enable_xformers_memory_efficient_attention()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
         else:
             self.model = self.model.to(device)

-    @staticmethod
-    def download():
-        from diffusers import StableDiffusionInstructPix2PixPipeline
-
-        StableDiffusionInstructPix2PixPipeline.from_pretrained(
-            "timbrooks/instruct-pix2pix", revision="fp16"
-        )
-
     def forward(self, image, mask, config: Config):
         """Input image and output image have same size
         image: [H, W, C] RGB
@@ -76,8 +61,3 @@ class InstructPix2Pix(DiffusionInpaintModel):
         output = (output * 255).round().astype("uint8")
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
-
-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True

View File

@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
         }

         self.model = AutoPipelineForInpainting.from_pretrained(
-            self.model_name, **model_kwargs
+            self.model_id_or_path, **model_kwargs
         ).to(device)

         self.callback = kwargs.pop("callback", None)
@@ -40,9 +40,6 @@ class Kandinsky(DiffusionInpaintModel):
         self.model.scheduler = scheduler

         generator = torch.manual_seed(config.sd_seed)
-        if config.sd_mask_blur != 0:
-            k = 2 * config.sd_mask_blur + 1
-            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
         mask = mask.astype(np.float32) / 255
         img_h, img_w = image.shape[:2]
@@ -66,20 +63,7 @@ class Kandinsky(DiffusionInpaintModel):
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output

-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True

 class Kandinsky22(Kandinsky):
-    name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
-    model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
-
-    @staticmethod
-    def download():
-        from diffusers import AutoPipelineForInpainting
-
-        AutoPipelineForInpainting.from_pretrained(
-            "kandinsky-community/kandinsky-2-2-decoder-inpaint"
-        )
+    name = "kandinsky2.2"
+    model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"

View File

@@ -31,10 +31,6 @@ class PaintByExample(DiffusionInpaintModel):
             "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
         )

-        self.model.enable_attention_slicing()
-        if kwargs.get("enable_xformers", False):
-            self.model.enable_xformers_memory_efficient_attention()
-
         # TODO: gpu_id
         if kwargs.get("cpu_offload", False) and use_gpu:
             self.model.image_encoder = self.model.image_encoder.to(device)
@@ -68,8 +64,3 @@ class PaintByExample(DiffusionInpaintModel):
         output = (output * 255).round().astype("uint8")
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
-
-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True

View File

@@ -1,8 +1,5 @@
-import os
-
 import PIL.Image
 import cv2
-import numpy as np
 import torch
 from loguru import logger
@@ -49,23 +46,12 @@ class SD(DiffusionInpaintModel):
         self.model = StableDiffusionInpaintPipeline.from_pretrained(
             self.model_id_or_path,
             revision="fp16"
-            if (
-                self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
-                and use_gpu
-                and fp16
-            )
+            if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
             else "main",
             torch_dtype=torch_dtype,
-            use_auth_token=kwargs["hf_access_token"],
             **model_kwargs,
         )

-        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
-        self.model.enable_attention_slicing()
-        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
-        if kwargs.get("enable_xformers", False):
-            self.model.enable_xformers_memory_efficient_attention()
-
         if kwargs.get("cpu_offload", False) and use_gpu:
             # TODO: gpu_id
             logger.info("Enable sequential cpu offload")
@@ -88,10 +74,6 @@ class SD(DiffusionInpaintModel):
         """
         self.set_scheduler(config)

-        if config.sd_mask_blur != 0:
-            k = 2 * config.sd_mask_blur + 1
-            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
-
         img_h, img_w = image.shape[:2]

         output = self.model(
@@ -114,17 +96,6 @@ class SD(DiffusionInpaintModel):
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output

-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True
-
-    @classmethod
-    def download(cls):
-        from diffusers import StableDiffusionInpaintPipeline
-
-        StableDiffusionInpaintPipeline.from_pretrained(cls.model_id_or_path)
-

 class SD15(SD):
     name = "sd1.5"

View File

@@ -45,16 +45,9 @@ class SDXL(DiffusionInpaintModel):
             self.model_id_or_path,
             revision="main",
             torch_dtype=torch_dtype,
-            use_auth_token=kwargs["hf_access_token"],
             vae=vae,
         )

-        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
-        self.model.enable_attention_slicing()
-        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
-        if kwargs.get("enable_xformers", False):
-            self.model.enable_xformers_memory_efficient_attention()
-
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -65,14 +58,6 @@ class SDXL(DiffusionInpaintModel):

         self.callback = kwargs.pop("callback", None)

-    @staticmethod
-    def download():
-        from diffusers import AutoPipelineForInpainting
-
-        AutoPipelineForInpainting.from_pretrained(
-            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
-        )
-
     def forward(self, image, mask, config: Config):
         """Input image and output image have same size
         image: [H, W, C] RGB
@@ -81,10 +66,6 @@ class SDXL(DiffusionInpaintModel):
         """
         self.set_scheduler(config)

-        if config.sd_mask_blur != 0:
-            k = 2 * config.sd_mask_blur + 1
-            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
-
         img_h, img_w = image.shape[:2]

         output = self.model(
@@ -106,8 +87,3 @@ class SDXL(DiffusionInpaintModel):
         output = (output * 255).round().astype("uint8")
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
         return output
-
-    @staticmethod
-    def is_downloaded() -> bool:
-        # model will be downloaded when app start, and can't switch in frontend settings
-        return True

View File

@@ -3,7 +3,6 @@ from typing import List, Dict
 import torch
 from loguru import logger

-from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
 from lama_cleaner.download import scan_models
 from lama_cleaner.helper import switch_mps_device
 from lama_cleaner.model import models, ControlNet, SD, SDXL
@@ -19,16 +18,25 @@ class ModelManager:
         self.available_models: Dict[str, ModelInfo] = {}
         self.scan_models()

-        self.sd_controlnet = kwargs.get("sd_controlnet", False)
-        self.sd_controlnet_method = kwargs.get(
-            "sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
-        )
+        self.sd_controlnet = False
+        self.sd_controlnet_method = ""

         self.model = self.init_model(name, device, **kwargs)

-    def init_model(self, name: str, device, **kwargs):
+    def _map_old_name(self, name: str) -> str:
         for old_name, model_cls in models.items():
             if name == old_name and hasattr(model_cls, "model_id_or_path"):
                 name = model_cls.model_id_or_path
+                break
+        return name
+
+    @property
+    def current_model(self) -> Dict:
+        name = self._map_old_name(self.name)
+        return self.available_models[name].model_dump()
+
+    def init_model(self, name: str, device, **kwargs):
+        name = self._map_old_name(name)
+        logger.info(f"Loading model: {name}")
         if name not in self.available_models:
             raise NotImplementedError(f"Unsupported model: {name}")
@@ -86,6 +94,7 @@ class ModelManager:
         ):
             self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
         try:
+            # TODO: enable/disable controlnet without reload model
             del self.model
             torch_gc()
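_map_old_name keeps configs written against the pre-refactor short names working by rewriting them to the HuggingFace ids the registry now uses; the same pairs appear in the model_name_map deleted from server.py below. An illustrative reduction of the behavior:

# Illustrative only: a few short-name -> HF-id pairs, taken from the
# model_name_map removed from server.py in this same commit.
legacy_map = {
    "sd1.5": "runwayml/stable-diffusion-inpainting",
    "sd2": "stabilityai/stable-diffusion-2-inpainting",
    "sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    "kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
}

def map_old_name(name: str) -> str:
    # Mirrors ModelManager._map_old_name: unknown names pass through untouched.
    return legacy_map.get(name, name)

assert map_old_name("sd1.5") == "runwayml/stable-diffusion-inpainting"
assert map_old_name("lama") == "lama"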

View File

@@ -55,7 +55,7 @@ def parse_args():
     parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
     parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
     parser.add_argument(
-        "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
+        "--sd-cpu-textencoder", action="store_true", help=CPU_TEXTENCODER_HELP
     )
     parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
     parser.add_argument(
@@ -66,16 +66,10 @@ def parse_args():
     parser.add_argument(
         "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
     )
-    parser.add_argument(
-        "--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
-    )
     parser.add_argument(
         "--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
     )
     parser.add_argument("--gui", action="store_true", help=GUI_HELP)
-    parser.add_argument(
-        "--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
-    )
     parser.add_argument(
         "--gui-size",
         default=[1600, 1000],

View File

@@ -22,7 +22,7 @@ SEGMENT_ANYTHING_MODELS = {
         "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
         "md5": "4b8939a88964f0f4ff5f5b2642c598a6",
     },
-    "vit_t": {
+    "mobile_sam": {
         "url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
         "md5": "f3c0d8cda613564d499310dab6c812cd",
     },

View File

@@ -3,7 +3,7 @@ from enum import Enum
 import cv2
 from loguru import logger

-from lama_cleaner.const import RealESRGANModelName
+from lama_cleaner.const import RealESRGANModel
 from lama_cleaner.helper import download_model
 from lama_cleaner.plugins.base_plugin import BasePlugin
@@ -18,7 +18,7 @@ class RealESRGANUpscaler(BasePlugin):
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact

         REAL_ESRGAN_MODELS = {
-            RealESRGANModelName.realesr_general_x4v3: {
+            RealESRGANModel.realesr_general_x4v3: {
                 "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                 "scale": 4,
                 "model": lambda: SRVGGNetCompact(
@@ -31,7 +31,7 @@ class RealESRGANUpscaler(BasePlugin):
                 ),
                 "model_md5": "91a7644643c884ee00737db24e478156",
             },
-            RealESRGANModelName.RealESRGAN_x4plus: {
+            RealESRGANModel.RealESRGAN_x4plus: {
                 "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                 "scale": 4,
                 "model": lambda: RRDBNet(
@@ -44,7 +44,7 @@ class RealESRGANUpscaler(BasePlugin):
                 ),
                 "model_md5": "99ec365d4afad750833258a1a24f44ca",
             },
-            RealESRGANModelName.RealESRGAN_x4plus_anime_6B: {
+            RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
                 "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                 "scale": 4,
                 "model": lambda: RRDBNet(

View File

@@ -109,7 +109,7 @@ sam_model_registry = {
     "vit_h": build_sam,
     "vit_l": build_sam_vit_l,
     "vit_b": build_sam_vit_b,
-    "vit_t": build_sam_vit_t,
+    "mobile_sam": build_sam_vit_t,
 }

View File

@@ -1,10 +1,16 @@
 # https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py
+import os
 import platform
 import sys
+from pathlib import Path

 import packaging.version
+from loguru import logger
 from rich import print
 from typing import Dict, Any

+from lama_cleaner.const import Device
+
 _PY_VERSION: str = sys.version.split()[0].rstrip("+")
 if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"):
@@ -21,7 +27,6 @@ _CANDIDATES = [
     "diffusers",
     "transformers",
     "opencv-python",
-    "xformers",
     "accelerate",
     "lama-cleaner",
     "rembg",
@@ -38,7 +43,7 @@ for name in _CANDIDATES:

 def dump_environment_info() -> Dict[str, str]:
-    """Dump information about the machine to help debugging issues. """
+    """Dump information about the machine to help debugging issues."""

     # Generic machine info
     info: Dict[str, Any] = {
@@ -48,3 +53,34 @@ def dump_environment_info() -> Dict[str, str]:
     info.update(_package_versions)
     print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
     return info
+
+
+def check_device(device: Device) -> Device:
+    if device == Device.cuda:
+        import platform
+
+        if platform.system() == "Darwin":
+            logger.warning("MacOS does not support cuda, use cpu instead")
+            return Device.cpu
+        else:
+            import torch
+
+            if not torch.cuda.is_available():
+                logger.warning("CUDA is not available, use cpu instead")
+                return Device.cpu
+    elif device == Device.mps:
+        import torch
+
+        if not torch.backends.mps.is_available():
+            logger.warning("mps is not available, use cpu instead")
+            return Device.cpu
+    return device
+
+
+def setup_model_dir(model_dir: Path):
+    model_dir = model_dir.expanduser().absolute()
+    os.environ["U2NET_HOME"] = str(model_dir)
+    os.environ["XDG_CACHE_HOME"] = str(model_dir)
+    if not model_dir.exists():
+        logger.info(f"Create model directory: {model_dir}")
+        model_dir.mkdir(exist_ok=True, parents=True)
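check_device downgrades impossible device requests to CPU up front instead of letting torch fail later during model loading, and setup_model_dir becomes the single owner of cache-directory plumbing (it also replaces the ad-hoc directory handling deleted from download.py above). A hedged usage sketch, assuming the package is installed and importable:

from pathlib import Path

from lama_cleaner.const import Device
from lama_cleaner.runtime import check_device, setup_model_dir

# On a Mac, or any box without CUDA, the request falls back to cpu
# with a logged warning instead of raising deep inside model loading.
device = check_device(Device.cuda)
print(device)  # Device.cpu on machines without CUDA

# One call now owns all cache plumbing: XDG_CACHE_HOME for diffusers/HF,
# U2NET_HOME for rembg, and it creates the directory if needed.
setup_model_dir(Path("~/.cache/lama-cleaner"))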

View File

@@ -1,10 +1,18 @@
 #!/usr/bin/env python3
 import json
 import os
-import hashlib
-import traceback
+
+import typer
+from typer import Option
+
+from lama_cleaner.download import cli_download_model, scan_models
+from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device

 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

+import hashlib
+import traceback
+from dataclasses import dataclass
+
 import imghdr
 import io
@@ -20,12 +28,7 @@ import torch
 from PIL import Image
 from loguru import logger

-from lama_cleaner.const import (
-    SD15_MODELS,
-    SD_CONTROLNET_CHOICES,
-    SDXL_CONTROLNET_CHOICES,
-    SD2_CONTROLNET_CHOICES,
-)
+from lama_cleaner.const import *
 from lama_cleaner.file_manager import FileManager
 from lama_cleaner.model.utils import torch_gc
 from lama_cleaner.model_manager import ModelManager
@@ -39,6 +42,8 @@ from lama_cleaner.plugins import (
 )
 from lama_cleaner.schema import Config

+typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
+
 try:
     torch._C._jit_override_can_fuse_on_cpu(False)
     torch._C._jit_override_can_fuse_on_gpu(False)
@@ -103,23 +108,34 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
 app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
 app.config["JSON_AS_ASCII"] = False
-CORS(app, expose_headers=["Content-Disposition", "X-seed"])
+CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])

 sio_logger = logging.getLogger("sio-logger")
 sio_logger.setLevel(logging.ERROR)
 socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")

-model: ModelManager = None
-thumb: FileManager = None
-output_dir: str = None
-device = None
-input_image_path: str = None
-is_disable_model_switch: bool = False
-enable_file_manager: bool = False
-enable_auto_saving: bool = False
-is_desktop: bool = False
-image_quality: int = 95
-plugins = {}
+
+@dataclass
+class GlobalConfig:
+    model_manager: ModelManager = None
+    file_manager: FileManager = None
+    output_dir: Path = None
+    input_image_path: Path = None
+    disable_model_switch: bool = False
+    is_desktop: bool = False
+    image_quality: int = 95
+    plugins = {}
+
+    @property
+    def enable_auto_saving(self) -> bool:
+        return self.output_dir is not None
+
+    @property
+    def enable_file_manager(self) -> bool:
+        return self.file_manager is not None
+
+
+global_config = GlobalConfig()
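The scattered module-level globals (and the global statements the old main needed to mutate them) collapse into one GlobalConfig dataclass; enable_auto_saving and enable_file_manager become properties derived from the fields they used to shadow, so they can no longer drift out of sync. A standalone sketch of the pattern, with hypothetical names:

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

@dataclass
class AppState:
    # Hypothetical reduced version of server.GlobalConfig.
    output_dir: Optional[Path] = None
    plugins: dict = field(default_factory=dict)

    @property
    def enable_auto_saving(self) -> bool:
        # Derived, so there is no separate boolean to forget to update.
        return self.output_dir is not None

state = AppState()
assert not state.enable_auto_saving
state.output_dir = Path("/tmp/out")
assert state.enable_auto_saving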
def get_image_ext(img_bytes): def get_image_ext(img_bytes):
@@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents):

 @app.route("/save_image", methods=["POST"])
 def save_image():
-    if output_dir is None:
+    if global_config.output_dir is None:
         return "--output-dir is None", 500

     input = request.files
@@ -143,7 +159,7 @@ def save_image():
     origin_image_bytes = input["image"].read()  # RGB
     ext = get_image_ext(origin_image_bytes)
     image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
-    save_path = os.path.join(output_dir, filename)
+    save_path = str(global_config.output_dir / filename)

     if alpha_channel is not None:
         if alpha_channel.shape[:2] != image.shape[:2]:
@@ -157,7 +173,7 @@ def save_image():
     img_bytes = pil_to_bytes(
         pil_image,
         ext,
-        quality=image_quality,
+        quality=global_config.image_quality,
         exif_infos=exif_infos,
     )
     with open(save_path, "wb") as fw:
@@ -169,9 +185,11 @@ def save_image():
 @app.route("/medias/<tab>")
 def medias(tab):
     if tab == "image":
-        response = make_response(jsonify(thumb.media_names), 200)
+        response = make_response(jsonify(global_config.file_manager.media_names), 200)
     else:
-        response = make_response(jsonify(thumb.output_media_names), 200)
+        response = make_response(
+            jsonify(global_config.file_manager.output_media_names), 200
+        )
     # response.last_modified = thumb.modified_time[tab]
     # response.cache_control.no_cache = True
     # response.cache_control.max_age = 0
@@ -182,8 +200,8 @@ def medias(tab):
 @app.route("/media/<tab>/<filename>")
 def media_file(tab, filename):
     if tab == "image":
-        return send_from_directory(thumb.root_directory, filename)
-    return send_from_directory(thumb.output_dir, filename)
+        return send_from_directory(global_config.file_manager.root_directory, filename)
+    return send_from_directory(global_config.file_manager.output_dir, filename)

 @app.route("/media_thumbnail/<tab>/<filename>")
@@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename):
     if height:
         height = int(float(height))

-    directory = thumb.root_directory
+    directory = global_config.file_manager.root_directory
     if tab == "output":
-        directory = thumb.output_dir
-    thumb_filename, (width, height) = thumb.get_thumbnail(
+        directory = global_config.file_manager.output_dir
+    thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
         directory, filename, width, height
     )
     thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
@@ -257,13 +275,11 @@ def process():
         croper_y=form["croperY"],
         croper_height=form["croperHeight"],
         croper_width=form["croperWidth"],
         use_extender=form["useExtender"],
         extender_x=form["extenderX"],
         extender_y=form["extenderY"],
         extender_height=form["extenderHeight"],
         extender_width=form["extenderWidth"],
         sd_scale=form["sdScale"],
         sd_mask_blur=form["sdMaskBlur"],
         sd_strength=form["sdStrength"],
@@ -294,7 +310,7 @@ def process():
     start = time.time()
     try:
-        res_np_img = model(image, mask, config)
+        res_np_img = global_config.model_manager(image, mask, config)
     except RuntimeError as e:
         if "CUDA out of memory. " in str(e):
             # NOTE: the string may change?
@@ -322,7 +338,7 @@ def process():
             pil_to_bytes(
                 Image.fromarray(res_np_img),
                 ext,
-                quality=image_quality,
+                quality=global_config.image_quality,
                 exif_infos=exif_infos,
             )
         )
@@ -345,7 +361,7 @@ def run_plugin():
     form = request.form
     files = request.files
     name = form["name"]
-    if name not in plugins:
+    if name not in global_config.plugins:
         return "Plugin not found", 500

     origin_image_bytes = files["image"].read()  # RGB
@@ -359,7 +375,7 @@ def run_plugin():
         if name == InteractiveSeg.name:
             img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
             form["img_md5"] = img_md5
-        bgr_res = plugins[name](rgb_np_img, files, form)
+        bgr_res = global_config.plugins[name](rgb_np_img, files, form)
     except RuntimeError as e:
         torch.cuda.empty_cache()
         if "CUDA out of memory. " in str(e):
@@ -401,7 +417,7 @@ def run_plugin():
             pil_to_bytes(
                 Image.fromarray(rgb_res),
                 ext,
-                quality=image_quality,
+                quality=global_config.image_quality,
                 exif_infos=exif_infos,
             )
         ),
@@ -414,41 +430,40 @@ def run_plugin():
 @app.route("/server_config", methods=["GET"])
 def get_server_config():
     return {
-        "plugins": list(plugins.keys()),
-        "enableFileManager": enable_file_manager,
-        "enableAutoSaving": enable_auto_saving,
-        "enableControlnet": model.sd_controlnet,
-        "controlnetMethod": model.sd_controlnet_method,
-        "disableModelSwitch": is_disable_model_switch,
+        "plugins": list(global_config.plugins.keys()),
+        "enableFileManager": global_config.enable_file_manager,
+        "enableAutoSaving": global_config.enable_auto_saving,
+        "enableControlnet": global_config.model_manager.sd_controlnet,
+        "controlnetMethod": global_config.model_manager.sd_controlnet_method,
+        "disableModelSwitch": global_config.disable_model_switch,
+        "isDesktop": global_config.is_desktop,
     }, 200

 @app.route("/models", methods=["GET"])
 def get_models():
-    return [it.model_dump() for it in model.scan_models()]
+    return [it.model_dump() for it in global_config.model_manager.scan_models()]

 @app.route("/model")
 def current_model():
-    return model.available_models[model.name].model_dump(), 200
-
-
-@app.route("/is_desktop")
-def get_is_desktop():
-    return str(is_desktop), 200
+    return (
+        global_config.model_manager.current_model,
+        200,
+    )

 @app.route("/model", methods=["POST"])
 def switch_model():
-    if is_disable_model_switch:
+    if global_config.disable_model_switch:
         return "Switch model is disabled", 400

     new_name = request.form.get("name")
-    if new_name == model.name:
+    if new_name == global_config.model_manager.name:
         return "Same model", 200

     try:
-        model.switch(new_name)
+        global_config.model_manager.switch(new_name)
     except Exception as e:
         traceback.print_exc()
         error_message = f"{type(e).__name__} - {str(e)}"
@@ -464,160 +479,230 @@ def index():

 @app.route("/inputimage")
 def get_cli_input_image():
-    if input_image_path:
-        with open(input_image_path, "rb") as f:
+    if global_config.input_image_path:
+        with open(global_config.input_image_path, "rb") as f:
             image_in_bytes = f.read()
         return send_file(
-            input_image_path,
+            global_config.input_image_path,
             as_attachment=True,
-            download_name=Path(input_image_path).name,
+            download_name=Path(global_config.input_image_path).name,
             mimetype=f"image/{get_image_ext(image_in_bytes)}",
         )
     else:
         return "No Input Image"

-def build_plugins(args):
-    global plugins
-    if args.enable_interactive_seg:
+def build_plugins(
+    enable_interactive_seg: bool,
+    interactive_seg_model: InteractiveSegModel,
+    interactive_seg_device: Device,
+    enable_remove_bg: bool,
+    enable_anime_seg: bool,
+    enable_realesrgan: bool,
+    realesrgan_device: Device,
+    realesrgan_model: str,
+    enable_gfpgan: bool,
+    gfpgan_device: Device,
+    enable_restoreformer: bool,
+    restoreformer_device: Device,
+    no_half: bool,
+):
+    if enable_interactive_seg:
         logger.info(f"Initialize {InteractiveSeg.name} plugin")
-        plugins[InteractiveSeg.name] = InteractiveSeg(
-            args.interactive_seg_model, args.interactive_seg_device
+        global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
+            interactive_seg_model, interactive_seg_device
         )
-    if args.enable_remove_bg:
+
+    if enable_remove_bg:
         logger.info(f"Initialize {RemoveBG.name} plugin")
-        plugins[RemoveBG.name] = RemoveBG()
-    if args.enable_anime_seg:
+        global_config.plugins[RemoveBG.name] = RemoveBG()
+
+    if enable_anime_seg:
         logger.info(f"Initialize {AnimeSeg.name} plugin")
-        plugins[AnimeSeg.name] = AnimeSeg()
-    if args.enable_realesrgan:
+        global_config.plugins[AnimeSeg.name] = AnimeSeg()
+
+    if enable_realesrgan:
         logger.info(
-            f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
+            f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
         )
-        plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
-            args.realesrgan_model,
-            args.realesrgan_device,
-            no_half=args.realesrgan_no_half,
+        global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
+            realesrgan_model,
+            realesrgan_device,
+            no_half=no_half,
         )
-    if args.enable_gfpgan:
+
+    if enable_gfpgan:
         logger.info(f"Initialize {GFPGANPlugin.name} plugin")
-        if args.enable_realesrgan:
+        if enable_realesrgan:
             logger.info("Use realesrgan as GFPGAN background upscaler")
         else:
             logger.info(
                 f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
             )
-        plugins[GFPGANPlugin.name] = GFPGANPlugin(
-            args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
+        global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
+            gfpgan_device,
+            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
        )
-    if args.enable_restoreformer:
+
+    if enable_restoreformer:
         logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
-        plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
-            args.restoreformer_device,
-            upscaler=plugins.get(RealESRGANUpscaler.name, None),
+        global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
+            restoreformer_device,
+            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
         )
-def main(args):
-    global model
-    global device
-    global input_image_path
-    global is_disable_model_switch
-    global enable_file_manager
-    global is_desktop
-    global thumb
-    global output_dir
-    global image_quality
-    global enable_auto_saving
-
-    build_plugins(args)
-
-    image_quality = args.quality
-    output_dir = args.output_dir
-    if output_dir:
-        output_dir = os.path.abspath(output_dir)
-        logger.info(f"Output dir: {output_dir}")
-        enable_auto_saving = True
-
-    device = torch.device(args.device)
-    is_disable_model_switch = args.disable_model_switch
-    is_desktop = args.gui
-    if is_disable_model_switch:
-        logger.info(
-            f"Start with --disable-model-switch, model switch on frontend is disable"
-        )
-
-    if args.input and os.path.isdir(args.input):
-        logger.info(f"Initialize file manager")
-        thumb = FileManager(app)
-        enable_file_manager = True
-        app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
-        app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
-            output_dir, "lama_cleaner_thumbnails"
-        )
-        thumb.output_dir = Path(output_dir)
-        # thumb.start()
-        # try:
-        #     while True:
-        #         time.sleep(1)
-        # finally:
-        #     thumb.image_dir_observer.stop()
-        #     thumb.image_dir_observer.join()
-        #     thumb.output_dir_observer.stop()
-        #     thumb.output_dir_observer.join()
-    else:
-        input_image_path = args.input
-
-    # For backward compatibility
-    model_name_map = {
-        "sd1.5": "runwayml/stable-diffusion-inpainting",
-        "anything4": "Sanster/anything-4.0-inpainting",
-        "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
-        "sd2": "stabilityai/stable-diffusion-2-inpainting",
-        "sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
-        "kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
-        "paint_by_example": "Fantasy-Studio/Paint-by-Example",
-        "instruct_pix2pix": "timbrooks/instruct-pix2pix",
-    }
-
-    model = ModelManager(
-        name=model_name_map.get(args.model, args.model),
-        sd_controlnet=args.sd_controlnet,
-        sd_controlnet_method=args.sd_controlnet_method,
-        device=device,
-        no_half=args.no_half,
-        hf_access_token=args.hf_access_token,
-        disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
-        sd_cpu_textencoder=args.sd_cpu_textencoder,
-        cpu_offload=args.cpu_offload,
-        enable_xformers=args.sd_enable_xformers or args.enable_xformers,
-        callback=diffuser_callback,
-    )
-
-    if args.gui:
-        app_width, app_height = args.gui_size
-        from flaskwebgui import FlaskUI
-
-        ui = FlaskUI(
-            app,
-            socketio=socketio,
-            width=app_width,
-            height=app_height,
-            host=args.host,
-            port=args.port,
-            close_server_on_exit=not args.no_gui_auto_close,
-        )
-        ui.run()
-    else:
-        socketio.run(
-            app,
-            host=args.host,
-            port=args.port,
-            debug=args.debug,
-            allow_unsafe_werkzeug=True,
-        )
+@typer_app.command(help="Install all plugins dependencies")
+def install_plugins_packages():
+    from lama_cleaner.installer import install_plugins_package
+
+    install_plugins_package()
+
+
+@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
+def download(
+    model: str = Option(
+        ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
+    ),
+    model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
+):
+    cli_download_model(model, model_dir)
+
+
+@typer_app.command(help="List downloaded models")
+def list_model(
+    model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
+):
+    setup_model_dir(model_dir)
+    scanned_models = scan_models()
+    for it in scanned_models:
+        print(it.name)
+
+
+@typer_app.command(help="Start lama cleaner server")
+def start(
+    host: str = Option("127.0.0.1"),
+    port: int = Option(8080),
+    model: str = Option(
+        DEFAULT_MODEL,
+        help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. "
+        f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    no_half: bool = Option(False, help=NO_HALF_HELP),
+    cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
+    disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
+    cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
+    local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
+    device: Device = Option(Device.cpu),
+    gui: bool = Option(False, help=GUI_HELP),
+    disable_model_switch: bool = Option(False),
+    input: Path = Option(None, help=INPUT_HELP),
+    output_dir: Path = Option(
+        None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    quality: int = Option(95, help=QUALITY_HELP),
+    enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
+    interactive_seg_model: InteractiveSegModel = Option(
+        InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
+    ),
+    interactive_seg_device: Device = Option(Device.cpu),
+    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
+    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
+    enable_realesrgan: bool = Option(False),
+    realesrgan_device: Device = Option(Device.cpu),
+    realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
+    enable_gfpgan: bool = Option(False),
+    gfpgan_device: Device = Option(Device.cpu),
+    enable_restoreformer: bool = Option(False),
+    restoreformer_device: Device = Option(Device.cpu),
+):
+    global global_config
+    dump_environment_info()
+
+    if input:
+        if not input.exists():
+            logger.error(f"invalid --input: {input} not exists")
+            exit()
+        if input.is_dir():
+            logger.info(f"Initialize file manager")
+            file_manager = FileManager(app)
+            app.config["THUMBNAIL_MEDIA_ROOT"] = input
+            app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
+                output_dir, "lama_cleaner_thumbnails"
+            )
+            file_manager.output_dir = output_dir
+        else:
+            global_config.input_image_path = input
+
+    device = check_device(device)
+    setup_model_dir(model_dir)
+
+    if local_files_only:
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+        os.environ["HF_HUB_OFFLINE"] = "1"
+
+    if model not in AVAILABLE_MODELS:
+        scanned_models = scan_models()
+        if model not in [it.name for it in scanned_models]:
+            logger.error(
+                f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
+            )
+            exit()
+
+    global_config.image_quality = quality
+    global_config.disable_model_switch = disable_model_switch
+    global_config.is_desktop = gui
+    build_plugins(
+        enable_interactive_seg,
+        interactive_seg_model,
+        interactive_seg_device,
+        enable_remove_bg,
+        enable_anime_seg,
+        enable_realesrgan,
+        realesrgan_device,
+        realesrgan_model,
+        enable_gfpgan,
+        gfpgan_device,
+        enable_restoreformer,
+        restoreformer_device,
+        no_half,
+    )
+
+    if output_dir:
+        output_dir = output_dir.expanduser().absolute()
+        logger.info(f"Image will auto save to output dir: {output_dir}")
+        global_config.output_dir = output_dir
+
+    global_config.model_manager = ModelManager(
+        name=model,
+        device=torch.device(device),
+        no_half=no_half,
+        disable_nsfw=disable_nsfw_checker,
+        sd_cpu_textencoder=cpu_textencoder,
+        cpu_offload=cpu_offload,
+        callback=diffuser_callback,
+    )
+
+    if gui:
+        from flaskwebgui import FlaskUI
+
+        ui = FlaskUI(
+            app,
+            socketio=socketio,
+            width=1200,
+            height=800,
+            host=host,
+            port=port,
+            close_server_on_exit=True,
+            idle_interval=60,
+        )
+        ui.run()
+    else:
+        socketio.run(
+            app,
+            host=host,
+            port=port,
+            allow_unsafe_werkzeug=True,
+        )
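With the old Flask main() replaced by Typer sub-commands, the CLI surface becomes lama-cleaner start|download|list-model|install-plugins-packages (Typer converts the underscores in the function names to dashes). A hedged smoke-test sketch using Typer's built-in test runner, exercising only argument wiring so no model weights are loaded:

from typer.testing import CliRunner

from lama_cleaner.server import typer_app

runner = CliRunner()

# --help parses the full option set without starting the server.
result = runner.invoke(typer_app, ["start", "--help"])
assert result.exit_code == 0
assert "--model" in result.stdout

result = runner.invoke(typer_app, ["list-model", "--help"])
assert result.exit_code == 0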

View File

@@ -39,7 +39,6 @@ def test_runway_sd_1_5(
         name=model_name,
         sd_controlnet=True,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=cpu_textencoder,
         sd_controlnet_method=sd_controlnet_method,
@@ -87,7 +86,6 @@ def test_local_file_path(sd_device, sampler):
         name=model_name,
         sd_controlnet=True,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
@@ -125,7 +123,6 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
         name=model_name,
         sd_controlnet=True,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
@@ -166,7 +163,6 @@ def test_controlnet_switch(sd_device, sampler):
         name=model_name,
         sd_controlnet=True,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,

View File

@@ -21,7 +21,6 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
     model = ModelManager(
         name=model_name,
         device=torch.device(device),
-        hf_access_token="",
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=False,
         cpu_offload=cpu_offload,
@@ -52,7 +51,6 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
     model = ModelManager(
         name=model_name,
         device=torch.device(device),
-        hf_access_token="",
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=False,
         cpu_offload=cpu_offload,

View File

@@ -17,11 +17,9 @@ def test_load_model():
             name=m,
             device="cpu",
             no_half=False,
-            hf_access_token="",
             disable_nsfw=False,
             sd_cpu_textencoder=True,
             cpu_offload=True,
-            enable_xformers=False,
         )

View File

@@ -16,11 +16,9 @@ def test_model_switch():
         sd_controlnet=True,
         sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
         device=torch.device("mps"),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=True,
         cpu_offload=False,
-        enable_xformers=False,
         callback=None,
     )
@@ -34,11 +32,9 @@ def test_controlnet_switch_onoff(caplog):
         sd_controlnet=True,
         sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
         device=torch.device("mps"),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=True,
         cpu_offload=False,
-        enable_xformers=False,
         callback=None,
     )
@@ -61,11 +57,9 @@ def test_controlnet_switch_method(caplog):
         sd_controlnet=True,
         sd_controlnet_method=old_method,
         device=torch.device("mps"),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=True,
         cpu_offload=False,
-        enable_xformers=False,
         callback=None,
     )

View File

@@ -41,7 +41,6 @@ def test_outpainting(name, sd_device, rect):
     model = ModelManager(
         name=name,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,
@@ -86,7 +85,6 @@ def test_kandinsky_outpainting(name, sd_device, rect):
     model = ModelManager(
         name=name,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,

View File

@@ -38,7 +38,6 @@ def test_runway_sd_1_5_all_samplers(
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
@@ -69,7 +68,6 @@ def test_runway_sd_lcm_lora(sd_device, strategy, sampler):
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
@@ -102,7 +100,6 @@ def test_runway_sd_freeu(sd_device, strategy, sampler):
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
@@ -136,7 +133,6 @@ def test_runway_sd_sd_strength(sd_device, strategy, sampler):
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
@@ -165,7 +161,6 @@ def test_runway_norm_sd_model(sd_device, strategy, sampler):
     model = ModelManager(
         name="runwayml/stable-diffusion-v1-5",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
@@ -192,7 +187,6 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
@@ -229,7 +223,6 @@ def test_local_file_path(sd_device, sampler, name):
     model = ModelManager(
         name=name,
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=False,

View File

@@ -29,7 +29,6 @@ def test_sdxl(sd_device, strategy, sampler):
     model = ModelManager(
         name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,
@@ -70,7 +69,6 @@ def test_sdxl_lcm_lora_and_freeu(sd_device, strategy, sampler):
     model = ModelManager(
         name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,
@@ -131,7 +129,6 @@ def test_sdxl_outpainting(sd_device, rect):
     model = ModelManager(
         name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
         device=torch.device(sd_device),
-        hf_access_token="",
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )

View File

@@ -24,7 +24,6 @@ def save_config(
     cpu_offload,
     disable_nsfw,
     sd_cpu_textencoder,
-    enable_xformers,
     local_files_only,
     model_dir,
     input,
@@ -102,9 +101,6 @@ def main(config_file: str):
         with gr.Column():
             gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
-            no_gui_auto_close = gr.Checkbox(
-                init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
-            )
         with gr.Column():
             model_dir = gr.Textbox(
@@ -193,14 +189,11 @@ def main(config_file: str):
                 init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
             )
             sd_cpu_textencoder = gr.Checkbox(
-                init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
+                init_config.sd_cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
             )
             disable_nsfw = gr.Checkbox(
                 init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
             )
-            enable_xformers = gr.Checkbox(
-                init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
-            )
             local_files_only = gr.Checkbox(
                 init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
             )
@@ -221,7 +214,6 @@ def main(config_file: str):
             cpu_offload,
             disable_nsfw,
             sd_cpu_textencoder,
-            enable_xformers,
             local_files_only,
             model_dir,
             input,

View File

@@ -71,6 +71,8 @@ const Cropper = (props: Props) => {
     setY,
     setWidth,
     setHeight,
+    isResizing,
+    setIsResizing,
   ] = useStore((state) => [
     state.imageWidth,
     state.imageHeight,
@@ -80,9 +82,11 @@ const Cropper = (props: Props) => {
     state.setCropperY,
     state.setCropperWidth,
     state.setCropperHeight,
+    state.isCropperExtenderResizing,
+    state.setIsCropperExtenderResizing,
   ])
-  const [isResizing, setIsResizing] = useState(false)
+  // const [isResizing, setIsResizing] = useState(false)
   const [isMoving, setIsMoving] = useState(false)
   useEffect(() => {

View File

@@ -65,6 +65,7 @@ export default function Editor(props: EditorProps) {
     updateAppState,
     runMannually,
     runInpainting,
+    isCropperExtenderResizing,
   ] = useStore((state) => [
     state.disableShortCuts,
     state.windowSize,
@@ -87,6 +88,7 @@ export default function Editor(props: EditorProps) {
     state.updateAppState,
     state.runMannually(),
     state.runInpainting,
+    state.isCropperExtenderResizing,
   ])
   const baseBrushSize = useStore((state) => state.editorState.baseBrushSize)
   const brushSize = useStore((state) => state.getBrushSize())
@@ -537,7 +539,7 @@ export default function Editor(props: EditorProps) {
   }
   const toggleShowBrush = (newState: boolean) => {
-    if (newState !== showBrush && !isPanning) {
+    if (newState !== showBrush && !isPanning && !isCropperExtenderResizing) {
       setShowBrush(newState)
     }
   }
@@ -693,7 +695,7 @@ export default function Editor(props: EditorProps) {
           limitToBounds={false}
           doubleClick={{ disabled: true }}
           initialScale={minScale}
-          minScale={minScale * 0.6}
+          minScale={minScale * 0.3}
           onPanning={(ref) => {
             if (!panned) {
               setPanned(true)

View File

@@ -54,6 +54,8 @@ const Extender = (props: Props) => {
     setWidth,
     setHeight,
     extenderDirection,
+    isResizing,
+    setIsResizing,
   ] = useStore((state) => [
     state.isInpainting,
     state.imageHeight,
@@ -64,10 +66,10 @@ const Extender = (props: Props) => {
     state.setExtenderWidth,
     state.setExtenderHeight,
     state.settings.extenderDirection,
+    state.isCropperExtenderResizing,
+    state.setIsCropperExtenderResizing,
   ])
-  const [isResizing, setIsResizing] = useState(false)
   const [evData, setEVData] = useState<EVData>({
     initX: 0,
     initY: 0,
@@ -122,10 +124,9 @@ const Extender = (props: Props) => {
   const moveBottom = () => {
     const newHeight = evData.initHeight + offsetY
     let [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight)
     if (extenderDirection === EXTENDER_ALL) {
-      if (clampedY + clampedHeight < imageHeight) {
-        clampedHeight = imageHeight
+      if (clampedHeight < Math.abs(clampedY) + imageHeight) {
+        clampedHeight = Math.abs(clampedY) + imageHeight
       }
     }
     setHeight(clampedHeight)
@@ -155,8 +156,8 @@ const Extender = (props: Props) => {
     const newWidth = evData.initWidth + offsetX
     let [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth)
     if (extenderDirection === EXTENDER_ALL) {
-      if (clampedX + clampedWidth < imageWdith) {
-        clampedWidth = imageWdith
+      if (clampedWidth < Math.abs(clampedX) + imageWdith) {
+        clampedWidth = Math.abs(clampedX) + imageWdith
      }
     }
     setWidth(clampedWidth)
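The reworked clamp in moveBottom/moveRight keeps the extender covering the whole image when extenderDirection is EXTENDER_ALL, even after the box origin has been dragged to a negative offset; the old check (clampedY + clampedHeight < imageHeight) missed that case. A small Python sketch of the same rule, with hypothetical names, purely to illustrate the arithmetic:

def clamp_extender_size(clamped_origin: float, new_size: float, image_size: float) -> float:
    # The box must reach from its (possibly negative) origin past the far
    # edge of the image, so its minimum size is |origin| + image size.
    min_size = abs(clamped_origin) + image_size
    return max(new_size, min_size)

assert clamp_extender_size(-50, 400, 512) == 562  # grown to keep the image covered
assert clamp_extender_size(0, 600, 512) == 600    # already large enough, unchanged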

View File

@@ -105,18 +105,22 @@ const LabelTitle = ({
         {text}
       </Label>
     </TooltipTrigger>
-    <TooltipContent className="flex flex-col max-w-xs text-sm" side="left">
-      <p>{toolTip}</p>
-      {url ? (
-        <Button variant="link" className="justify-end">
-          <a href={url} target="_blank">
-            More info
-          </a>
-        </Button>
-      ) : (
-        <></>
-      )}
-    </TooltipContent>
+    {toolTip ? (
+      <TooltipContent className="flex flex-col max-w-xs text-sm" side="left">
+        <p>{toolTip}</p>
+        {url ? (
+          <Button variant="link" className="justify-end">
+            <a href={url} target="_blank">
+              More info
+            </a>
+          </Button>
+        ) : (
+          <></>
+        )}
+      </TooltipContent>
+    ) : (
+      <></>
+    )}
   </Tooltip>
 )
 }
@@ -172,7 +176,11 @@ const SidePanel = () => {
       <div className="flex flex-col gap-4">
         <div className="flex flex-col gap-4">
           <div className="flex justify-between items-center pr-2">
-            <LabelTitle text="Controlnet" />
+            <LabelTitle
+              text="ControlNet"
+              toolTip="Using an additional conditioning image to control how an image is generated"
+              url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
+            />
             <Switch
               id="controlnet"
               checked={settings.enableControlnet}
@@ -271,7 +279,11 @@ const SidePanel = () => {
     return (
       <div className="flex flex-col gap-4">
         <div className="flex justify-between items-center pr-2">
-          <LabelTitle text="Freeu" />
+          <LabelTitle
+            text="FreeU"
+            toolTip="FreeU is a technique for improving image quality. Different models may require different FreeU-specific hyperparameters, which can be viewed in the more info section."
+            url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu"
+          />
          <Switch
            id="freeu"
            checked={settings.enableFreeu}
@@ -408,7 +420,10 @@ const SidePanel = () => {
     return (
       <div>
         <RowContainer>
-          <div>Example Image</div>
+          <LabelTitle
+            text="Example Image"
+            toolTip="An example image to guide image generation."
+          />
           <ImageUploadButton
             tooltip="Upload example image"
             onFileUpload={(file) => {
@@ -450,8 +465,9 @@ const SidePanel = () => {
     return (
       <div className="flex flex-col gap-1">
         <LabelTitle
-          htmlFor="image-guidance-scale"
           text="Image guidance scale"
+          toolTip="Push the generated image towards the initial image. Higher image guidance scale encourages generated images that are closely linked to the source image, usually at the expense of lower image quality."
+          url="https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix"
         />
         <RowContainer>
           <Slider
@@ -518,11 +534,17 @@ const SidePanel = () => {
   }
   const renderExtender = () => {
+    if (!settings.model.support_outpainting) {
+      return null
+    }
     return (
       <>
         <div className="flex flex-col gap-4">
           <RowContainer>
-            <LabelTitle text="Extender" />
+            <LabelTitle
+              text="Extender"
+              toolTip="Perform outpainting on images to expand their content."
+            />
             <Switch
               id="extender"
               checked={settings.showExtender}
@@ -709,7 +731,10 @@ const SidePanel = () => {
         >
           <div className="flex flex-col gap-4 mt-4">
             <RowContainer>
-              <LabelTitle text="Cropper" />
+              <LabelTitle
+                text="Cropper"
+                toolTip="Inpaint on part of the image to improve inference speed and reduce memory usage."
+              />
               <Switch
                 id="cropper"
                 checked={settings.showCropper}
@@ -725,7 +750,11 @@ const SidePanel = () => {
           {renderExtender()}
           <div className="flex flex-col gap-1">
-            <LabelTitle htmlFor="steps" text="Steps" />
+            <LabelTitle
+              htmlFor="steps"
+              text="Steps"
+              toolTip="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
+            />
             <RowContainer>
               <Slider
                 className="w-[180px]"

View File

@@ -150,8 +150,11 @@ type AppState = {
   interactiveSegState: InteractiveSegState
   fileManagerState: FileManagerState
   cropperState: CropperState
   extenderState: CropperState
+  isCropperExtenderResizing: boolean
   serverConfig: ServerConfig
   settings: Settings
@@ -177,6 +180,7 @@ type AppAction = {
   setExtenderY: (newValue: number) => void
   setExtenderWidth: (newValue: number) => void
   setExtenderHeight: (newValue: number) => void
+  setIsCropperExtenderResizing: (newValue: boolean) => void
   updateExtenderDirection: (newValue: string) => void
   resetExtender: (width: number, height: number) => void
   updateExtenderByBuiltIn: (direction: string, scale: number) => void
@@ -261,6 +265,7 @@ const defaultValues: AppState = {
     width: 512,
     height: 512,
   },
+  isCropperExtenderResizing: false,
   fileManagerState: {
     sortBy: SortBy.CTIME,
@@ -889,6 +894,11 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
         state.extenderState.height = newValue
       }),
+    setIsCropperExtenderResizing: (newValue: boolean) =>
+      set((state) => {
+        state.isCropperExtenderResizing = newValue
+      }),
     updateExtenderDirection: (newValue: string) => {
       console.log(
         `updateExtenderDirection: ${JSON.stringify(get().extenderState)}`