IOPaint/lama_cleaner/server.py
2022-11-30 21:53:56 +08:00

312 lines
9.0 KiB
Python

#!/usr/bin/env python3
import io
import json
import logging
import multiprocessing
import os
import random
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from loguru import logger
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
# Best-effort: switch off TorchScript's fusers (CPU/GPU fusion, the texpr
# fuser and nvFuser) via private ``torch._C`` hooks. These hooks presumably
# do not exist in every torch version, so any ordinary error is ignored.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
    # raised during import are no longer swallowed.
    pass
from flask import Flask, request, send_file, cli, make_response
# Flask's CLI prints a banner warning against running the development server
# in production. Replace that hook with a no-op so nothing is printed.
# Background: https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
def _suppress_server_banner(*_ignored):
    return None


cli.show_server_banner = _suppress_server_banner
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
# Thread count given to the native math/threading libraries configured below.
NUM_THREADS = str(multiprocessing.cpu_count())

# Work around the duplicate-libomp problem on Windows:
# https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Ask every common native thread pool to use all CPU cores.
# NOTE(review): torch/numpy/cv2 are imported earlier in this module; libraries
# that read these variables only when first loaded may ignore them — confirm.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)

# A non-empty CACHE_DIR is forwarded as TORCH_HOME (torch.hub's cache root).
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

# Directory of the prebuilt frontend; index.html and static/ are served from it.
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class NoFlaskwebgui(logging.Filter):
    """Drop log records that mention ``flaskwebgui-keep-server-alive``.

    The record's fully formatted message is checked; matching records are
    rejected and everything else passes through. These are presumably the
    periodic keep-alive requests issued by flaskwebgui in GUI mode.
    """

    _KEEP_ALIVE_MARKER = "flaskwebgui-keep-server-alive"

    def filter(self, record):
        rendered = record.getMessage()
        return rendered.find(self._KEEP_ALIVE_MARKER) == -1


# Keep werkzeug's request log free of the keep-alive noise.
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
# Emit non-ASCII characters in JSON responses as-is instead of \u escapes.
app.config.update(JSON_AS_ASCII=False)
# Enable CORS and let browser scripts read the Content-Disposition header.
CORS(app, expose_headers=["Content-Disposition"])

# Socket.IO support is currently disabled; kept for reference:
# MAX_BUFFER_SIZE = 50 * 1000 * 1000 # 50 MB
# async_mode priority: eventlet/gevent_uwsgi/gevent/threading
# only threading works on macOS
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')

# Process-wide state; main() assigns all of these before the server starts.
model: ModelManager = None  # inpainting model manager
interactive_seg_model: InteractiveSeg = None  # model behind /interactive_seg
device = None  # torch.device built from args.device
input_image_path: str = None  # optional image served by /inputimage
is_disable_model_switch: bool = False  # when truthy, POST /model is refused
is_desktop: bool = False  # copied from args.gui
def get_image_ext(img_bytes):
    """Guess an image's format name from the leading bytes of ``img_bytes``.

    Returns the name reported by ``imghdr`` (e.g. ``"png"``, ``"gif"``,
    ``"webp"``), or ``"jpeg"`` when the header is not recognised.

    NOTE(review): ``imghdr`` is deprecated and removed in Python 3.13.
    """
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
def diffuser_callback(i, t, latents):
    """Per-step hook handed to ModelManager as ``callback``; currently a no-op.

    The arguments are presumably the step index, the timestep and the current
    latents — confirm against ModelManager. The function returns None.
    """
    # socketio.emit('diffusion_step', {'diffusion_step': i})
def _config_from_form(form):
    """Build the inpainting ``Config`` from the /inpaint form fields.

    Every field is required. When ``form`` is ``request.form``, a missing key
    raises werkzeug's ``BadRequestKeyError``, just as direct indexing does.
    """
    return Config(
        ldm_steps=form["ldmSteps"],
        ldm_sampler=form["ldmSampler"],
        hd_strategy=form["hdStrategy"],
        zits_wireframe=form["zitsWireframe"],
        hd_strategy_crop_margin=form["hdStrategyCropMargin"],
        hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
        hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
        prompt=form["prompt"],
        negative_prompt=form["negativePrompt"],
        use_croper=form["useCroper"],
        croper_x=form["croperX"],
        croper_y=form["croperY"],
        croper_height=form["croperHeight"],
        croper_width=form["croperWidth"],
        sd_mask_blur=form["sdMaskBlur"],
        sd_strength=form["sdStrength"],
        sd_steps=form["sdSteps"],
        sd_guidance_scale=form["sdGuidanceScale"],
        sd_sampler=form["sdSampler"],
        sd_seed=form["sdSeed"],
        sd_match_histograms=form["sdMatchHistograms"],
        cv2_flag=form["cv2Flag"],
        cv2_radius=form["cv2Radius"],
    )


def _merge_alpha(res_np_img, alpha_channel):
    """Append ``alpha_channel`` to ``res_np_img`` as its last channel.

    The alpha plane is first resized to the result's height/width when their
    sizes differ (the result may have been produced at a resized resolution).
    """
    if alpha_channel.shape[:2] != res_np_img.shape[:2]:
        alpha_channel = cv2.resize(
            alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
        )
    return np.concatenate((res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1)


@app.route("/inpaint", methods=["POST"])
def process():
    """Inpaint the uploaded ``image`` under the uploaded ``mask``.

    Reads multipart files ``image`` and ``mask``, the settings fields consumed
    by ``_config_from_form``, and an optional ``sizeLimit`` (an integer, or
    ``"Original"`` to use the largest image dimension; default ``"1080"``).
    Responds with the result encoded in the format detected from the uploaded
    image (``jpeg`` when unknown), plus the seed used in the ``X-Seed`` header.

    Error responses:
      * 400 when mask and image sizes differ, or ``sizeLimit`` is not an int;
      * 500 when the model raises ``RuntimeError`` (CUDA OOM gets its own text).
    """
    files = request.files
    origin_image_bytes = files["image"].read()
    image, alpha_channel = load_img(origin_image_bytes)  # RGB
    mask, _ = load_img(files["mask"].read(), gray=True)
    # Binarize: pixels above 127 become 255, everything else 0.
    mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

    if image.shape[:2] != mask.shape[:2]:
        return (
            f"Mask shape{mask.shape[:2]} not equal to Image shape{image.shape[:2]}",
            400,
        )

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    form = request.form
    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        try:
            size_limit = int(size_limit)
        except ValueError:
            # Client-supplied value; report it instead of crashing with a 500.
            return f"Invalid sizeLimit: {size_limit}", 400

    config = _config_from_form(form)
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    start = time.time()
    try:
        res_np_img = model(image, mask, config)
    except RuntimeError as e:
        # The `finally` clause below releases cached CUDA memory on every path.
        # Match without the trailing ". " so minor wording changes still match.
        if "CUDA out of memory" in str(e):
            return "CUDA out of memory", 500
        logger.exception(e)
        return "Internal Server Error", 500
    finally:
        logger.info(f"process time: {(time.time() - start) * 1000}ms")
        torch.cuda.empty_cache()

    if alpha_channel is not None:
        res_np_img = _merge_alpha(res_np_img, alpha_channel)

    ext = get_image_ext(origin_image_bytes)
    response = make_response(
        send_file(
            io.BytesIO(numpy_to_bytes(res_np_img, ext)),
            mimetype=f"image/{ext}",
        )
    )
    response.headers["X-Seed"] = str(config.sd_seed)
    return response
@app.route("/interactive_seg", methods=["POST"])
def interactive_seg():
    """Run click-driven segmentation and return the new mask as a PNG.

    Form data:
      * ``image`` file (loaded as RGB);
      * optional ``mask`` file, passed to the model as ``prev_mask``;
      * ``clicks``: JSON list whose entries are indexed as ``[0]``, ``[1]``,
        ``[2]``. Entries become ``Click(coords=(c[1], c[0]))``, presumably
        converting ``[x, y]`` to ``(row, col)``; ``c[2] == 1`` marks a
        positive click.

    Returns 400 if ``clicks`` is not valid JSON or has malformed entries.
    """
    files = request.files
    image, _ = load_img(files["image"].read())  # RGB
    mask = load_img(files["mask"].read(), gray=True)[0] if "mask" in files else None

    try:
        raw_clicks = json.loads(request.form["clicks"])
        clicks = [
            Click(coords=(c[1], c[0]), indx=i, is_positive=c[2] == 1)
            for i, c in enumerate(raw_clicks)
        ]
    except (ValueError, TypeError, IndexError, KeyError) as e:
        # Malformed client input: report 400 rather than an opaque 500.
        # (json.JSONDecodeError is a ValueError subclass.)
        return f"Invalid clicks: {e}", 400

    start = time.time()
    new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
    logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")

    return make_response(
        send_file(
            io.BytesIO(numpy_to_bytes(new_mask, "png")),
            mimetype="image/png",
        )
    )
@app.route("/model")
def current_model():
    """Return the name of the currently loaded inpainting model."""
    active_name = model.name
    return active_name, 200
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
    """Report whether model switching is disabled, as ``'true'``/``'false'``."""
    if is_disable_model_switch:
        return "true", 200
    return "false", 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
    """Return ``str(model.is_downloaded(name))`` with status 200."""
    downloaded = model.is_downloaded(name)
    return str(downloaded), 200
@app.route("/is_desktop")
def get_is_desktop():
    """Return the string form of ``is_desktop`` (set from ``args.gui``).

    For a bool this is ``'True'``/``'False'`` with a capital first letter.
    """
    return f"{is_desktop}", 200
@app.route("/model", methods=["POST"])
def switch_model():
    """Switch the active inpainting model to the form field ``name``.

    Responses:
      * 400 when switching is disabled or ``name`` is missing/empty;
      * 200 ``"Same model"`` when ``name`` is already loaded;
      * 403 when ``model.switch`` raises ``NotImplementedError``;
      * 200 ``"ok, switch to <name>"`` on success.
    """
    if is_disable_model_switch:
        return "Switch model is disabled", 400

    new_name = request.form.get("name")
    if not new_name:
        # Without this guard a missing field would reach model.switch(None).
        return "Missing model name", 400
    if new_name == model.name:
        return "Same model", 200

    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
    """Serve the frontend's index.html from BUILD_DIR with a zero cache timeout."""
    index_html = os.path.join(BUILD_DIR, "index.html")
    # NOTE(review): ``cache_timeout`` was renamed ``max_age`` in Flask 2.0 and
    # later removed; this call assumes an older Flask — confirm pinned version.
    return send_file(index_html, cache_timeout=0)
@app.route("/inputimage")
def set_input_photo():
    """Send the image at ``input_image_path`` (from ``args.input``) as a download.

    The MIME type is sniffed from the file header. When no input image was
    configured, the plain body ``"No Input Image"`` is returned instead.
    """
    if not input_image_path:
        return "No Input Image"

    # Format sniffing only looks at the first bytes of the file (imghdr itself
    # reads 32), so avoid loading the whole image into memory just for that.
    with open(input_image_path, "rb") as f:
        header = f.read(32)

    # NOTE(review): ``attachment_filename`` was renamed ``download_name`` in
    # Flask 2.0; keep the old name only while an older Flask is pinned.
    return send_file(
        input_image_path,
        as_attachment=True,
        attachment_filename=Path(input_image_path).name,
        mimetype=f"image/{get_image_ext(header)}",
    )
def main(args):
    """Initialise module-level state from the parsed CLI ``args`` and serve.

    Builds the inpainting ModelManager and the InteractiveSeg model, then runs
    either a flaskwebgui window (when ``args.gui`` is truthy) or Flask's
    built-in server.
    """
    global model, interactive_seg_model, device
    global input_image_path, is_disable_model_switch, is_desktop

    device = torch.device(args.device)
    input_image_path = args.input
    is_disable_model_switch = args.disable_model_switch
    is_desktop = args.gui

    if is_disable_model_switch:
        logger.info(
            "Start with --disable-model-switch, model switch on frontend is disable"
        )

    model = ModelManager(
        name=args.model,
        device=device,
        hf_access_token=args.hf_access_token,
        sd_disable_nsfw=args.sd_disable_nsfw,
        sd_cpu_textencoder=args.sd_cpu_textencoder,
        sd_run_local=args.sd_run_local,
        sd_enable_xformers=args.sd_enable_xformers,
        callback=diffuser_callback,
    )
    interactive_seg_model = InteractiveSeg()

    if not args.gui:
        # TODO: socketio
        app.run(host=args.host, port=args.port, debug=args.debug)
        return

    app_width, app_height = args.gui_size
    # Imported lazily: flaskwebgui is only needed for the desktop window.
    from flaskwebgui import FlaskUI

    desktop_ui = FlaskUI(
        app, width=app_width, height=app_height, host=args.host, port=args.port
    )
    desktop_ui.run()