add batch_processing
This commit is contained in:
parent
cd82b21cd9
commit
4a9f2ab03c
120
lama_cleaner/batch_processing.py
Normal file
120
lama_cleaner/batch_processing.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import json
|
||||||
|
import cv2
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import (
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
TextColumn,
|
||||||
|
BarColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lama_cleaner.helper import pil_to_bytes
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
|
def glob_images(path: Path) -> Dict[str, Path]:
|
||||||
|
# png/jpg/jpeg
|
||||||
|
if path.is_file():
|
||||||
|
return {path.stem: path}
|
||||||
|
elif path.is_dir():
|
||||||
|
res = {}
|
||||||
|
for it in path.glob("*.*"):
|
||||||
|
if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
|
||||||
|
res[it.stem] = it
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def batch_inpaint(
|
||||||
|
model: str,
|
||||||
|
device,
|
||||||
|
image: Path,
|
||||||
|
mask: Path,
|
||||||
|
output: Path,
|
||||||
|
config: Optional[Path] = None,
|
||||||
|
concat: bool = False,
|
||||||
|
):
|
||||||
|
if image.is_dir() and output.is_file():
|
||||||
|
logger.error(
|
||||||
|
f"invalid --output: when image is a directory, output should be a directory"
|
||||||
|
)
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
image_paths = glob_images(image)
|
||||||
|
mask_paths = glob_images(mask)
|
||||||
|
if len(image_paths) == 0:
|
||||||
|
logger.error(f"invalid --image: empty image folder")
|
||||||
|
exit(-1)
|
||||||
|
if len(mask_paths) == 0:
|
||||||
|
logger.error(f"invalid --mask: empty mask folder")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
inpaint_request = InpaintRequest()
|
||||||
|
logger.info(f"Using default config: {inpaint_request}")
|
||||||
|
else:
|
||||||
|
with open(config, "r", encoding="utf-8") as f:
|
||||||
|
inpaint_request = InpaintRequest(**json.load(f))
|
||||||
|
|
||||||
|
model_manager = ModelManager(name=model, device=device)
|
||||||
|
first_mask = list(mask_paths.values())[0]
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
console=console,
|
||||||
|
transient=False,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("Batch processing...", total=len(image_paths))
|
||||||
|
for stem, image_p in image_paths.items():
|
||||||
|
if stem not in mask_paths and mask.is_dir():
|
||||||
|
progress.log(f"mask for {image_p} not found")
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
continue
|
||||||
|
mask_p = mask_paths.get(stem, first_mask)
|
||||||
|
|
||||||
|
infos = Image.open(image_p).info
|
||||||
|
|
||||||
|
img = cv2.imread(str(image_p))
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||||
|
mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||||
|
if mask_img.shape[:2] != img.shape[:2]:
|
||||||
|
progress.log(
|
||||||
|
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
|
||||||
|
)
|
||||||
|
mask_img = cv2.resize(
|
||||||
|
mask_img,
|
||||||
|
(img.shape[1], img.shape[0]),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
mask_img[mask_img >= 127] = 255
|
||||||
|
mask_img[mask_img < 127] = 0
|
||||||
|
|
||||||
|
# bgr
|
||||||
|
inpaint_result = model_manager(img, mask_img, inpaint_request)
|
||||||
|
inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
|
||||||
|
if concat:
|
||||||
|
mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
|
||||||
|
inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
|
||||||
|
|
||||||
|
img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
|
||||||
|
save_p = output / f"{stem}.png"
|
||||||
|
with open(save_p, "wb") as fw:
|
||||||
|
fw.write(img_bytes)
|
||||||
|
|
||||||
|
progress.update(task, advance=1)
|
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@ -39,15 +40,36 @@ def list_model(
|
|||||||
print(it.name)
|
print(it.name)
|
||||||
|
|
||||||
|
|
||||||
@typer_app.command(help="Processing image with lama cleaner")
|
@typer_app.command(help="Batch processing images")
|
||||||
def run(
|
def run(
|
||||||
input: Path = Option(..., help="Image file or folder containing images"),
|
model: str = Option("lama"),
|
||||||
output_dir: Path = Option(..., help="Output directory"),
|
device: Device = Option(Device.cpu),
|
||||||
config_path: Path = Option(..., help="Config file path"),
|
image: Path = Option(..., help="Image folders or file path"),
|
||||||
|
mask: Path = Option(
|
||||||
|
...,
|
||||||
|
help="Mask folders or file path. "
|
||||||
|
"If it is a directory, the mask images in the directory should have the same name as the original image."
|
||||||
|
"If it is a file, all images will use this mask."
|
||||||
|
"Mask will automatically resize to the same size as the original image.",
|
||||||
|
),
|
||||||
|
output: Path = Option(..., help="Output directory or file path"),
|
||||||
|
config: Path = Option(
|
||||||
|
None, help="Config file path. You can use dump command to create a base config."
|
||||||
|
),
|
||||||
|
concat: bool = Option(
|
||||||
|
False, help="Concat original image, mask and output images into one image"
|
||||||
|
),
|
||||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||||
):
|
):
|
||||||
setup_model_dir(model_dir)
|
setup_model_dir(model_dir)
|
||||||
pass
|
scanned_models = scan_models()
|
||||||
|
if model not in [it.name for it in scanned_models]:
|
||||||
|
logger.info(f"{model} not found in {model_dir}, try to downloading")
|
||||||
|
cli_download_model(model, model_dir)
|
||||||
|
|
||||||
|
from lama_cleaner.batch_processing import batch_inpaint
|
||||||
|
|
||||||
|
batch_inpaint(model, device, image, mask, output, config, concat)
|
||||||
|
|
||||||
|
|
||||||
@typer_app.command(help="Start lama cleaner server")
|
@typer_app.command(help="Start lama cleaner server")
|
||||||
|
@ -65,7 +65,7 @@ class InpaintModel:
|
|||||||
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"final forward pad size: {pad_image.shape}")
|
# logger.info(f"final forward pad size: {pad_image.shape}")
|
||||||
|
|
||||||
image, mask = self.forward_pre_process(image, mask, config)
|
image, mask = self.forward_pre_process(image, mask, config)
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ class InpaintModel:
|
|||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
inpaint_result = None
|
inpaint_result = None
|
||||||
logger.info(f"hd_strategy: {config.hd_strategy}")
|
# logger.info(f"hd_strategy: {config.hd_strategy}")
|
||||||
if config.hd_strategy == HDStrategy.CROP:
|
if config.hd_strategy == HDStrategy.CROP:
|
||||||
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
||||||
logger.info(f"Run crop strategy")
|
logger.info(f"Run crop strategy")
|
||||||
@ -189,7 +189,7 @@ class InpaintModel:
|
|||||||
crop_img = image[t:b, l:r, :]
|
crop_img = image[t:b, l:r, :]
|
||||||
crop_mask = mask[t:b, l:r]
|
crop_mask = mask[t:b, l:r]
|
||||||
|
|
||||||
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
# logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
||||||
|
|
||||||
return crop_img, crop_mask, [l, t, r, b]
|
return crop_img, crop_mask, [l, t, r, b]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user