diff --git a/lama_cleaner/batch_processing.py b/lama_cleaner/batch_processing.py new file mode 100644 index 0000000..4a3727c --- /dev/null +++ b/lama_cleaner/batch_processing.py @@ -0,0 +1,120 @@ +import json +import cv2 +from pathlib import Path +from typing import Dict, Optional +from PIL import Image + +from loguru import logger +from rich.console import Console +from rich.progress import ( + Progress, + SpinnerColumn, + TimeElapsedColumn, + MofNCompleteColumn, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, +) + +from lama_cleaner.helper import pil_to_bytes +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import InpaintRequest + + +def glob_images(path: Path) -> Dict[str, Path]: + # png/jpg/jpeg + if path.is_file(): + return {path.stem: path} + elif path.is_dir(): + res = {} + for it in path.glob("*.*"): + if it.suffix.lower() in [".png", ".jpg", ".jpeg"]: + res[it.stem] = it + return res + + +def batch_inpaint( + model: str, + device, + image: Path, + mask: Path, + output: Path, + config: Optional[Path] = None, + concat: bool = False, +): + if image.is_dir() and output.is_file(): + logger.error( + f"invalid --output: when image is a directory, output should be a directory" + ) + exit(-1) + + image_paths = glob_images(image) + mask_paths = glob_images(mask) + if len(image_paths) == 0: + logger.error(f"invalid --image: empty image folder") + exit(-1) + if len(mask_paths) == 0: + logger.error(f"invalid --mask: empty mask folder") + exit(-1) + + if config is None: + inpaint_request = InpaintRequest() + logger.info(f"Using default config: {inpaint_request}") + else: + with open(config, "r", encoding="utf-8") as f: + inpaint_request = InpaintRequest(**json.load(f)) + + model_manager = ModelManager(name=model, device=device) + first_mask = list(mask_paths.values())[0] + + console = Console() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + console=console, + transient=False, + ) as progress: + task = progress.add_task("Batch processing...", total=len(image_paths)) + for stem, image_p in image_paths.items(): + if stem not in mask_paths and mask.is_dir(): + progress.log(f"mask for {image_p} not found") + progress.update(task, advance=1) + continue + mask_p = mask_paths.get(stem, first_mask) + + infos = Image.open(image_p).info + + img = cv2.imread(str(image_p)) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + if mask_img.shape[:2] != img.shape[:2]: + progress.log( + f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}" + ) + mask_img = cv2.resize( + mask_img, + (img.shape[1], img.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + mask_img[mask_img >= 127] = 255 + mask_img[mask_img < 127] = 0 + + # bgr + inpaint_result = model_manager(img, mask_img, inpaint_request) + inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB) + if concat: + mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB) + inpaint_result = cv2.hconcat([img, mask_img, inpaint_result]) + + img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos) + save_p = output / f"{stem}.png" + with open(save_p, "wb") as fw: + fw.write(img_bytes) + + progress.update(task, advance=1) diff --git a/lama_cleaner/cli.py b/lama_cleaner/cli.py index fbdd167..0615a1e 100644 --- a/lama_cleaner/cli.py +++ b/lama_cleaner/cli.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Dict import typer from fastapi import FastAPI @@ -39,15 +40,36 @@ def list_model( print(it.name) -@typer_app.command(help="Processing image with lama cleaner") +@typer_app.command(help="Batch processing images") def run( - input: Path = Option(..., help="Image file or folder containing images"), - output_dir: Path = Option(..., help="Output directory"), - config_path: Path = Option(..., help="Config file path"), + model: str = Option("lama"), + device: Device = Option(Device.cpu), + image: Path = Option(..., help="Image folders or file path"), + mask: Path = Option( + ..., + help="Mask folders or file path. " + "If it is a directory, the mask images in the directory should have the same name as the original image." + "If it is a file, all images will use this mask." + "Mask will automatically resize to the same size as the original image.", + ), + output: Path = Option(..., help="Output directory or file path"), + config: Path = Option( + None, help="Config file path. You can use dump command to create a base config." + ), + concat: bool = Option( + False, help="Concat original image, mask and output images into one image" + ), model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), ): setup_model_dir(model_dir) - pass + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.info(f"{model} not found in {model_dir}, try to downloading") + cli_download_model(model, model_dir) + + from lama_cleaner.batch_processing import batch_inpaint + + batch_inpaint(model, device, image, mask, output, config, concat) @typer_app.command(help="Start lama cleaner server") diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 82c60c6..683ac0e 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -65,7 +65,7 @@ class InpaintModel: mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size ) - logger.info(f"final forward pad size: {pad_image.shape}") + # logger.info(f"final forward pad size: {pad_image.shape}") image, mask = self.forward_pre_process(image, mask, config) @@ -93,7 +93,7 @@ class InpaintModel: return: BGR IMAGE """ inpaint_result = None - logger.info(f"hd_strategy: {config.hd_strategy}") + # logger.info(f"hd_strategy: {config.hd_strategy}") if config.hd_strategy == HDStrategy.CROP: if max(image.shape) > config.hd_strategy_crop_trigger_size: logger.info(f"Run crop strategy") @@ -189,7 +189,7 @@ class InpaintModel: crop_img = image[t:b, l:r, :] crop_mask = mask[t:b, l:r] - logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}") + # logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}") return crop_img, crop_mask, [l, t, r, b]