IOPaint/lama_cleaner/batch_processing.py
2024-01-04 21:39:34 +08:00

121 lines
3.7 KiB
Python

import json
import cv2
from pathlib import Path
from typing import Dict, Optional
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TimeElapsedColumn,
MofNCompleteColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
)
from lama_cleaner.helper import pil_to_bytes
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import InpaintRequest
def glob_images(path: Path) -> Dict[str, Path]:
    """Map image file stems to their paths.

    If *path* is a file, it is returned as ``{path.stem: path}`` without any
    extension check. If *path* is a directory, its immediate children that are
    regular files with a ``.png``, ``.jpg`` or ``.jpeg`` suffix
    (case-insensitive) are collected; sub-directories are not searched.
    Entries are visited in sorted order, so when several files share a stem
    the result is deterministic (the last one in sorted order wins).

    Returns an empty dict when *path* is neither a file nor a directory
    (e.g. it does not exist), so callers can treat "nothing found" uniformly
    instead of receiving ``None``.
    """
    if path.is_file():
        return {path.stem: path}
    if not path.is_dir():
        return {}

    image_suffixes = (".png", ".jpg", ".jpeg")
    res = {}
    for it in sorted(path.glob("*.*")):
        # is_file() excludes directories that merely look like images (e.g. "x.png/").
        if it.is_file() and it.suffix.lower() in image_suffixes:
            res[it.stem] = it
    return res
def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint one image or a folder of images with a single model and save PNGs.

    Args:
        model: Model name passed to ``ModelManager(name=...)``.
        device: Device passed to ``ModelManager(device=...)``; its type is not
            constrained here.
        image: An image file, or a directory scanned with ``glob_images``.
        mask: A single mask file used for every image, or a directory whose
            masks are matched to images by file stem. When ``mask`` is a
            directory, images without a same-stem mask are logged and skipped.
        output: Directory that receives ``<stem>.png`` for each processed
            image; created (with parents) if it does not exist.
        config: Optional JSON file whose top-level object is expanded into
            ``InpaintRequest(**...)``; ``InpaintRequest()`` is used otherwise.
        concat: If true, save ``original | mask | result`` concatenated
            horizontally instead of the result alone.

    Raises:
        SystemExit: with status -1 when ``output`` is an existing file or when
            no images / masks are found.

    Images or masks that OpenCV cannot read are logged and skipped so one bad
    file does not abort the whole batch.
    """
    # Results are always written as `output / "<stem>.png"`, so `output` must be
    # (or become) a directory whether `image` is a single file or a folder.
    if output.is_file():
        if image.is_dir():
            logger.error(
                "invalid --output: when image is a directory, output should be a directory"
            )
        else:
            logger.error("invalid --output: output should be a directory")
        raise SystemExit(-1)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    # `not` also covers a None result for a path that is neither file nor dir.
    if not image_paths:
        logger.error("invalid --image: empty image folder")
        raise SystemExit(-1)
    if not mask_paths:
        logger.error("invalid --mask: empty mask folder")
        raise SystemExit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    # Create the destination before the (potentially slow) model load so a bad
    # path fails fast instead of after the first image has been inpainted.
    output.mkdir(parents=True, exist_ok=True)

    model_manager = ModelManager(name=model, device=device)
    # When `mask` is a single file, glob_images returns exactly one entry and
    # this mask is reused for every image.
    first_mask = next(iter(mask_paths.values()))

    console = Console()
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if stem not in mask_paths and mask.is_dir():
                progress.log(f"mask for {image_p} not found")
                progress.update(task, advance=1)
                continue
            mask_p = mask_paths.get(stem, first_mask)

            # Use a context manager so PIL's lazily opened file handle is closed
            # instead of leaking once per image; `info` stays valid afterwards.
            with Image.open(image_p) as pil_img:
                infos = pil_img.info

            img = cv2.imread(str(image_p))
            mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
            # cv2.imread reports failure by returning None rather than raising.
            if img is None or mask_img is None:
                unreadable = image_p if img is None else mask_p
                progress.log(f"failed to read {unreadable}, skipped")
                progress.update(task, advance=1)
                continue
            # imread's default (IMREAD_COLOR) always yields 3-channel BGR.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if mask_img.shape[:2] != img.shape[:2]:
                progress.log(
                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
                )
                # Nearest-neighbour avoids introducing blended intermediate values.
                mask_img = cv2.resize(
                    mask_img,
                    (img.shape[1], img.shape[0]),  # cv2 takes (width, height)
                    interpolation=cv2.INTER_NEAREST,
                )
            # Binarize the mask: values >= 127 become 255, everything else 0.
            mask_img[mask_img >= 127] = 255
            mask_img[mask_img < 127] = 0

            # The model manager's result is BGR; convert it to RGB so it matches
            # `img` for concatenation and for PIL encoding below.
            inpaint_result = model_manager(img, mask_img, inpaint_request)
            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
            if concat:
                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])

            # Encoded as PNG (quality 100) together with the source image's PIL
            # `info` metadata.
            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
            (output / f"{stem}.png").write_bytes(img_bytes)
            progress.update(task, advance=1)