IOPaint/iopaint/cli.py

174 lines
6.4 KiB
Python
Raw Normal View History

2023-12-25 04:31:49 +01:00
from pathlib import Path
2024-01-04 14:39:34 +01:00
from typing import Dict
2023-12-25 04:31:49 +01:00
import typer
2023-12-30 16:36:44 +01:00
from fastapi import FastAPI
2023-12-25 04:31:49 +01:00
from loguru import logger
from typer import Option
2024-01-05 08:19:23 +01:00
from iopaint.const import *
from iopaint.download import cli_download_model, scan_models
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
2023-12-25 04:31:49 +01:00
# Root Typer application for this module; the functions decorated with
# ``@typer_app.command`` register themselves on it as sub-commands.
# Shell-completion options are not added and locals are not shown in
# pretty exception output.
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
@typer_app.command(help="Install all plugins dependencies")
def install_plugins_packages():
    """Run the plugin dependency installer from ``iopaint.installer``."""
    # The installer module is imported only when this command actually runs.
    from iopaint import installer

    installer.install_plugins_package()
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
def download(
    model: str = Option(
        ...,
        help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting",
    ),
    model_dir: Path = Option(
        DEFAULT_MODEL_DIR,
        help=MODEL_DIR_HELP,
        file_okay=False,
    ),
):
    """Download ``model`` into ``model_dir``.

    Thin CLI wrapper: both options are forwarded unchanged to
    ``cli_download_model``.
    """
    cli_download_model(model, model_dir)
@typer_app.command(name="list", help="List downloaded models")
def list_model(
    model_dir: Path = Option(
        DEFAULT_MODEL_DIR,
        help=MODEL_DIR_HELP,
        file_okay=False,
    ),
):
    """Print the name of every model returned by ``scan_models``, one per line."""
    # The model directory is configured before scanning.
    setup_model_dir(model_dir)
    for found in scan_models():
        print(found.name)
2024-01-04 14:39:34 +01:00
@typer_app.command(help="Batch processing images")
def run(
    model: str = Option("lama"),
    device: Device = Option(Device.cpu),
    image: Path = Option(..., help="Image folders or file path"),
    mask: Path = Option(
        ...,
        # Adjacent literals are concatenated verbatim, so each sentence must
        # carry its own trailing space or the help text runs together.
        help="Mask folders or file path. "
        "If it is a directory, the mask images in the directory should have the same name as the original image. "
        "If it is a file, all images will use this mask. "
        "Mask will automatically resize to the same size as the original image.",
    ),
    output: Path = Option(..., help="Output directory or file path"),
    config: Path = Option(
        None, help="Config file path. You can use dump command to create a base config."
    ),
    concat: bool = Option(
        False, help="Concat original image, mask and output images into one image"
    ),
    model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
    """Inpaint a file or folder of images in batch mode.

    Sets up ``model_dir``, downloads ``model`` when it is not among the
    scanned models, then hands every option to
    ``iopaint.batch_processing.batch_inpaint``.
    """
    setup_model_dir(model_dir)

    known_models = {it.name for it in scan_models()}
    if model not in known_models:
        logger.info(f"{model} not found in {model_dir}, try to downloading")
        cli_download_model(model, model_dir)

    # Imported only once the command is actually executed.
    from iopaint.batch_processing import batch_inpaint

    batch_inpaint(model, device, image, mask, output, config, concat)
2023-12-30 16:36:44 +01:00
2024-01-05 08:19:23 +01:00
@typer_app.command(help="Start IOPaint server")
def start(
    host: str = Option("127.0.0.1"),
    port: int = Option(8080),
    model: str = Option(
        DEFAULT_MODEL,
        help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. "
        f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
    ),
    model_dir: Path = Option(
        DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
    ),
    no_half: bool = Option(False, help=NO_HALF_HELP),
    cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
    disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
    cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
    local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
    device: Device = Option(Device.cpu),
    gui: bool = Option(False, help=GUI_HELP),
    disable_model_switch: bool = Option(False),
    input: Path = Option(None, help=INPUT_HELP),
    output_dir: Path = Option(
        None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
    ),
    quality: int = Option(95, help=QUALITY_HELP),
    enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
    interactive_seg_model: InteractiveSegModel = Option(
        InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
    ),
    interactive_seg_device: Device = Option(Device.cpu),
    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
    enable_realesrgan: bool = Option(False),
    realesrgan_device: Device = Option(Device.cpu),
    realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
    enable_gfpgan: bool = Option(False),
    gfpgan_device: Device = Option(Device.cpu),
    enable_restoreformer: bool = Option(False),
    restoreformer_device: Device = Option(Device.cpu),
):
    """Validate the CLI options and launch the IOPaint API server.

    Steps, in order: dump environment info, pass ``device`` through
    ``check_device``, validate ``--input``, create ``output_dir`` if missing,
    set up ``model_dir``, optionally set the offline environment variables,
    download ``model`` when it is not among the scanned models, then build
    an ``ApiConfig`` and call ``Api.launch``.

    Raises:
        SystemExit: with status 1 when ``--input`` is given but does not
            exist on disk.
    """
    # Import ``os`` explicitly instead of relying on the star import from
    # ``iopaint.const`` happening to re-export it.
    import os

    dump_environment_info()
    device = check_device(device)

    if input and not input.exists():
        logger.error(f"invalid --input: {input} not exists")
        # Non-zero status so shells/scripts can detect the failure; raising
        # SystemExit directly also avoids the ``site``-provided ``exit`` helper.
        raise SystemExit(1)

    if output_dir:
        output_dir = output_dir.expanduser().absolute()
        logger.info(f"Image will be saved to {output_dir}")
        if not output_dir.exists():
            logger.info(f"Create output directory {output_dir}")
            output_dir.mkdir(parents=True)

    model_dir = model_dir.expanduser().absolute()
    setup_model_dir(model_dir)

    if local_files_only:
        # Offline switches read by transformers / huggingface_hub.
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        os.environ["HF_HUB_OFFLINE"] = "1"

    known_models = {it.name for it in scan_models()}
    if model not in known_models:
        logger.info(f"{model} not found in {model_dir}, try to downloading")
        cli_download_model(model, model_dir)

    # Imported only once the command is actually executed.
    from iopaint.api import Api
    from iopaint.schema import ApiConfig

    app = FastAPI()
    api = Api(
        app,
        ApiConfig(
            host=host,
            port=port,
            model=model,
            no_half=no_half,
            cpu_offload=cpu_offload,
            disable_nsfw_checker=disable_nsfw_checker,
            cpu_textencoder=cpu_textencoder,
            device=device,
            gui=gui,
            disable_model_switch=disable_model_switch,
            input=input,
            output_dir=output_dir,
            quality=quality,
            enable_interactive_seg=enable_interactive_seg,
            interactive_seg_model=interactive_seg_model,
            interactive_seg_device=interactive_seg_device,
            enable_remove_bg=enable_remove_bg,
            enable_anime_seg=enable_anime_seg,
            enable_realesrgan=enable_realesrgan,
            realesrgan_device=realesrgan_device,
            realesrgan_model=realesrgan_model,
            enable_gfpgan=enable_gfpgan,
            gfpgan_device=gfpgan_device,
            enable_restoreformer=enable_restoreformer,
            restoreformer_device=restoreformer_device,
        ),
    )
    api.launch()