diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index 775bc46..e399b60 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -10,6 +10,6 @@ warnings.simplefilter("ignore", UserWarning) def entry_point(): # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 - from lama_cleaner.server import typer_app + from lama_cleaner.cli import typer_app typer_app() diff --git a/lama_cleaner/cli.py b/lama_cleaner/cli.py new file mode 100644 index 0000000..c9d1cb3 --- /dev/null +++ b/lama_cleaner/cli.py @@ -0,0 +1,123 @@ +from pathlib import Path + +import typer +from loguru import logger +from typer import Option + +from lama_cleaner.const import * +from lama_cleaner.download import cli_download_model, scan_models +from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device + +typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) + + +@typer_app.command(help="Install all plugins dependencies") +def install_plugins_packages(): + from lama_cleaner.installer import install_plugins_package + + install_plugins_package() + + +@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace") +def download( + model: str = Option( + ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting" + ), + model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), +): + cli_download_model(model, model_dir) + + +@typer_app.command(name="list", help="List downloaded models") +def list_model( + model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), +): + setup_model_dir(model_dir) + scanned_models = scan_models() + for it in scanned_models: + print(it.name) + + +@typer_app.command(help="Start lama cleaner server") +def start( + host: str = Option("127.0.0.1"), + port: int = Option(8080), + model: str = Option( + DEFAULT_MODEL, + help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. " + f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface", + ), + model_dir: Path = Option( + DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False + ), + no_half: bool = Option(False, help=NO_HALF_HELP), + cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP), + disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP), + cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP), + local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), + device: Device = Option(Device.cpu), + gui: bool = Option(False, help=GUI_HELP), + disable_model_switch: bool = Option(False), + input: Path = Option(None, help=INPUT_HELP), + output_dir: Path = Option( + None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False + ), + quality: int = Option(95, help=QUALITY_HELP), + enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP), + interactive_seg_model: InteractiveSegModel = Option( + InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP + ), + interactive_seg_device: Device = Option(Device.cpu), + enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), + enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), + enable_realesrgan: bool = Option(False), + realesrgan_device: Device = Option(Device.cpu), + realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3), + enable_gfpgan: bool = Option(False), + gfpgan_device: Device = Option(Device.cpu), + enable_restoreformer: bool = Option(False), + restoreformer_device: Device = Option(Device.cpu), +): + dump_environment_info() + device = check_device(device) + model_dir = model_dir.expanduser().absolute() + setup_model_dir(model_dir) + + if local_files_only: + os.environ["TRANSFORMERS_OFFLINE"] = "1" + os.environ["HF_HUB_OFFLINE"] = "1" + + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.info(f"{model} not found in {model_dir}, try to downloading") + cli_download_model(model, model_dir) + + from lama_cleaner.server import start + + start( + host=host, + port=port, + model=model, + no_half=no_half, + cpu_offload=cpu_offload, + disable_nsfw_checker=disable_nsfw_checker, + cpu_textencoder=cpu_textencoder, + device=device, + gui=gui, + disable_model_switch=disable_model_switch, + input=input, + output_dir=output_dir, + quality=quality, + enable_interactive_seg=enable_interactive_seg, + interactive_seg_model=interactive_seg_model, + interactive_seg_device=interactive_seg_device, + enable_remove_bg=enable_remove_bg, + enable_anime_seg=enable_anime_seg, + enable_realesrgan=enable_realesrgan, + realesrgan_device=realesrgan_device, + realesrgan_model=realesrgan_model, + enable_gfpgan=enable_gfpgan, + gfpgan_device=gfpgan_device, + enable_restoreformer=enable_restoreformer, + restoreformer_device=restoreformer_device, + ) diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index 698f187..66b14f9 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -43,6 +43,7 @@ def folder_name_to_show_name(name: str) -> str: def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) + # logger.info(f"Scanning single file sd/sdxl models in {cache_dir}") res = [] for it in cache_dir.glob(f"*.*"): if it.suffix not in [".safetensors", ".ckpt"]: @@ -68,10 +69,12 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: return res -def scan_inpaint_models() -> List[ModelInfo]: +def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]: res = [] from lama_cleaner.model import models + # logger.info(f"Scanning inpaint models in {model_dir}") + for name, m in models.items(): if m.is_erase_model and m.is_downloaded(): res.append( @@ -87,10 +90,12 @@ def scan_inpaint_models() -> List[ModelInfo]: def scan_models() -> List[ModelInfo]: from diffusers.utils import DIFFUSERS_CACHE + model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR) available_models = [] - available_models.extend(scan_inpaint_models()) - available_models.extend(scan_single_file_diffusion_models(DEFAULT_MODEL_DIR)) + available_models.extend(scan_inpaint_models(model_dir)) + available_models.extend(scan_single_file_diffusion_models(model_dir)) cache_dir = Path(DIFFUSERS_CACHE) + # logger.info(f"Scanning diffusers models in {cache_dir}") diffusers_model_names = [] for it in cache_dir.glob("**/*/model_index.json"): with open(it, "r", encoding="utf-8") as f: diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py index fb63d81..247dce9 100644 --- a/lama_cleaner/plugins/__init__.py +++ b/lama_cleaner/plugins/__init__.py @@ -1,6 +1,70 @@ +from loguru import logger + from .interactive_seg import InteractiveSeg from .remove_bg import RemoveBG from .realesrgan import RealESRGANUpscaler from .gfpgan_plugin import GFPGANPlugin from .restoreformer import RestoreFormerPlugin from .anime_seg import AnimeSeg +from ..const import InteractiveSegModel, Device + + +def build_plugins( + global_config, + enable_interactive_seg: bool, + interactive_seg_model: InteractiveSegModel, + interactive_seg_device: Device, + enable_remove_bg: bool, + enable_anime_seg: bool, + enable_realesrgan: bool, + realesrgan_device: Device, + realesrgan_model: str, + enable_gfpgan: bool, + gfpgan_device: Device, + enable_restoreformer: bool, + restoreformer_device: Device, + no_half: bool, +): + if enable_interactive_seg: + logger.info(f"Initialize {InteractiveSeg.name} plugin") + global_config.plugins[InteractiveSeg.name] = InteractiveSeg( + interactive_seg_model, interactive_seg_device + ) + + if enable_remove_bg: + logger.info(f"Initialize {RemoveBG.name} plugin") + global_config.plugins[RemoveBG.name] = RemoveBG() + + if enable_anime_seg: + logger.info(f"Initialize {AnimeSeg.name} plugin") + global_config.plugins[AnimeSeg.name] = AnimeSeg() + + if enable_realesrgan: + logger.info( + f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}" + ) + global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + realesrgan_model, + realesrgan_device, + no_half=no_half, + ) + + if enable_gfpgan: + logger.info(f"Initialize {GFPGANPlugin.name} plugin") + if enable_realesrgan: + logger.info("Use realesrgan as GFPGAN background upscaler") + else: + logger.info( + f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" + ) + global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin( + gfpgan_device, + upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), + ) + + if enable_restoreformer: + logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") + global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( + restoreformer_device, + upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), + ) diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 052cf06..adc9f2c 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,13 +1,6 @@ #!/usr/bin/env python3 -import json import os -import typer -from typer import Option - -from lama_cleaner.download import cli_download_model, scan_models -from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import hashlib import traceback @@ -35,14 +28,11 @@ from lama_cleaner.model_manager import ModelManager from lama_cleaner.plugins import ( InteractiveSeg, RemoveBG, - RealESRGANUpscaler, - GFPGANPlugin, - RestoreFormerPlugin, AnimeSeg, + build_plugins, ) from lama_cleaner.schema import Config -typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False) try: torch._C._jit_override_can_fuse_on_cpu(False) @@ -492,7 +482,20 @@ def get_cli_input_image(): return "No Input Image" -def build_plugins( +def start( + host: str, + port: int, + model: str, + no_half: bool, + cpu_offload: bool, + disable_nsfw_checker, + cpu_textencoder: bool, + device: Device, + gui: bool, + disable_model_switch: bool, + input: Path, + output_dir: Path, + quality: int, enable_interactive_seg: bool, interactive_seg_model: InteractiveSegModel, interactive_seg_device: Device, @@ -505,123 +508,7 @@ def build_plugins( gfpgan_device: Device, enable_restoreformer: bool, restoreformer_device: Device, - no_half: bool, ): - if enable_interactive_seg: - logger.info(f"Initialize {InteractiveSeg.name} plugin") - global_config.plugins[InteractiveSeg.name] = InteractiveSeg( - interactive_seg_model, interactive_seg_device - ) - - if enable_remove_bg: - logger.info(f"Initialize {RemoveBG.name} plugin") - global_config.plugins[RemoveBG.name] = RemoveBG() - - if enable_anime_seg: - logger.info(f"Initialize {AnimeSeg.name} plugin") - global_config.plugins[AnimeSeg.name] = AnimeSeg() - - if enable_realesrgan: - logger.info( - f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}" - ) - global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( - realesrgan_model, - realesrgan_device, - no_half=no_half, - ) - - if enable_gfpgan: - logger.info(f"Initialize {GFPGANPlugin.name} plugin") - if enable_realesrgan: - logger.info("Use realesrgan as GFPGAN background upscaler") - else: - logger.info( - f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" - ) - global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin( - gfpgan_device, - upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), - ) - - if enable_restoreformer: - logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") - global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( - restoreformer_device, - upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), - ) - - -@typer_app.command(help="Install all plugins dependencies") -def install_plugins_packages(): - from lama_cleaner.installer import install_plugins_package - - install_plugins_package() - - -@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace") -def download( - model: str = Option( - ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting" - ), - model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), -): - cli_download_model(model, model_dir) - - -@typer_app.command(help="List downloaded models") -def list_model( - model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), -): - setup_model_dir(model_dir) - scanned_models = scan_models() - for it in scanned_models: - print(it.name) - - -@typer_app.command(help="Start lama cleaner server") -def start( - host: str = Option("127.0.0.1"), - port: int = Option(8080), - model: str = Option( - DEFAULT_MODEL, - help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. " - f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface", - ), - model_dir: Path = Option( - DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False - ), - no_half: bool = Option(False, help=NO_HALF_HELP), - cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP), - disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP), - cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP), - local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), - device: Device = Option(Device.cpu), - gui: bool = Option(False, help=GUI_HELP), - disable_model_switch: bool = Option(False), - input: Path = Option(None, help=INPUT_HELP), - output_dir: Path = Option( - None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False - ), - quality: int = Option(95, help=QUALITY_HELP), - enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP), - interactive_seg_model: InteractiveSegModel = Option( - InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP - ), - interactive_seg_device: Device = Option(Device.cpu), - enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), - enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), - enable_realesrgan: bool = Option(False), - realesrgan_device: Device = Option(Device.cpu), - realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3), - enable_gfpgan: bool = Option(False), - gfpgan_device: Device = Option(Device.cpu), - enable_restoreformer: bool = Option(False), - restoreformer_device: Device = Option(Device.cpu), -): - global global_config - dump_environment_info() - if input: if not input.exists(): logger.error(f"invalid --input: {input} not exists") @@ -637,24 +524,11 @@ def start( else: global_config.input_image_path = input - device = check_device(device) - setup_model_dir(model_dir) - - if local_files_only: - os.environ["TRANSFORMERS_OFFLINE"] = "1" - os.environ["HF_HUB_OFFLINE"] = "1" - - scanned_models = scan_models() - if model not in [it.name for it in scanned_models]: - logger.error( - f"invalid model: {model} not exists. Available models: {[it.name for it in scanned_models]}" - ) - exit() - global_config.image_quality = quality global_config.disable_model_switch = disable_model_switch global_config.is_desktop = gui build_plugins( + global_config, enable_interactive_seg, interactive_seg_model, interactive_seg_device,