#!/usr/bin/env python3
import argparse
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union

import cv2
import torch
import numpy as np
from loguru import logger

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config

try:
    # Disable TorchScript fusers; these private hooks may be missing in some
    # torch builds, hence the broad except.
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from flask import Flask, request, send_file
from flask_cors import CORS

from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
)

# Pin math-library thread pools to the number of CPU cores.
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    # Allow overriding where torch caches downloaded model weights.
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")


class InterceptHandler(logging.Handler):
    """Route Flask's standard logging records through loguru."""

    def emit(self, record):
        logger_opt = logger.opt(depth=6, exception=record.exc_info)
        logger_opt.log(record.levelno, record.getMessage())


app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])

model: ModelManager = None
device = None
input_image_path: str = None


def get_image_ext(img_bytes):
    # Guess the image format from the raw bytes; fall back to JPEG.
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w


@app.route("/inpaint", methods=["POST"])
def process():
    input = request.files
    # RGB
    origin_image_bytes = input["image"].read()

    image, alpha_channel = load_img(origin_image_bytes)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    form = request.form
    # "Original" keeps the full resolution; otherwise treat the value as a pixel limit.
    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    config = Config(
        ldm_steps=form["ldmSteps"],
        hd_strategy=form["hdStrategy"],
        hd_strategy_crop_margin=form["hdStrategyCropMargin"],
        hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
        hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
    )

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")

    mask, _ = load_img(input["mask"].read(), gray=True)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    start = time.time()
    res_np_img = model(image, mask, config)
    logger.info(f"process time: {(time.time() - start) * 1000}ms")
    # Release cached GPU memory after each request (no-op on CPU).
    torch.cuda.empty_cache()

    if alpha_channel is not None:
        # Re-attach the original alpha channel to the inpainted RGB result.
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    ext = get_image_ext(origin_image_bytes)
    return send_file(
        io.BytesIO(numpy_to_bytes(res_np_img, ext)),
        mimetype=f"image/{ext}",
    )


@app.route("/model")
def current_model():
    return model.name, 200


@app.route("/model_downloaded/<name>")
def model_downloaded(name):
    return str(model.is_downloaded(name)), 200


@app.route("/model", methods=["POST"])
def switch_model():
    new_name = request.form.get("name")
    if new_name == model.name:
        return "Same model", 200

    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200


@app.route("/")
def index():
    return send_file(os.path.join(BUILD_DIR, "index.html"))


@app.route("/inputimage")
def set_input_photo():
    # Serve the image passed via --input so the frontend can preload it.
    if input_image_path:
        with open(input_image_path, "rb") as f:
            image_in_bytes = f.read()
        return send_file(
            input_image_path,
            as_attachment=True,
            download_name=Path(input_image_path).name,
            mimetype=f"image/{get_image_ext(image_in_bytes)}",
        )
    else:
        return "No Input Image"


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input", type=str, help="Path to image you want to load by default"
    )
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
    parser.add_argument(
        "--gui-size",
        default=[1600, 1000],
        nargs=2,
        type=int,
        help="Set window size for GUI",
    )
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()
    if args.input is not None:
        if not os.path.exists(args.input):
            parser.error(f"invalid --input: {args.input} does not exist")
        if imghdr.what(args.input) is None:
            parser.error(f"invalid --input: {args.input} is not a valid image file")

    return args


def main():
    global model
    global device
    global input_image_path

    args = get_args_parser()
    device = torch.device(args.device)
    input_image_path = args.input
    model = ModelManager(name=args.model, device=device)

    if args.gui:
        app_width, app_height = args.gui_size
        from flaskwebgui import FlaskUI

        ui = FlaskUI(app, width=app_width, height=app_height)
        ui.run()
    else:
        app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == "__main__":
    main()