IOPaint/lama_cleaner/server.py

416 lines
13 KiB
Python
Raw Normal View History

2022-04-18 09:01:10 +02:00
#!/usr/bin/env python3
import io
2022-11-27 14:25:27 +01:00
import json
2022-04-18 09:01:10 +02:00
import logging
import multiprocessing
import os
2022-09-15 16:21:27 +02:00
import random
2022-04-18 09:01:10 +02:00
import time
import imghdr
from pathlib import Path
from typing import Union
2022-12-10 15:06:15 +01:00
from PIL import Image
2022-04-18 09:01:10 +02:00
import cv2
import torch
import numpy as np
from loguru import logger
2023-01-08 13:53:55 +01:00
from watchdog.events import FileSystemEventHandler
2022-04-18 09:01:10 +02:00
2022-11-27 14:25:27 +01:00
from lama_cleaner.interactive_seg import InteractiveSeg, Click
2022-04-18 09:01:10 +02:00
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
2022-12-31 14:07:08 +01:00
from lama_cleaner.file_manager import FileManager
2022-04-18 09:01:10 +02:00
def disable_torch_jit_fusers(torch_c=None):
    """Turn off the TorchScript kernel fusers, best effort.

    Each toggle is attempted independently. A toggle that is missing from, or
    rejected by, the installed torch build therefore no longer prevents the
    remaining toggles from being applied.

    Args:
        torch_c: object exposing the private ``_jit_*`` setters. Defaults to
            ``torch._C``; injectable so the logic can be exercised without
            touching the real torch runtime.

    Returns:
        list[str]: names of the setters that were called successfully.
    """
    if torch_c is None:
        torch_c = torch._C
    applied = []
    for setter_name in (
        "_jit_override_can_fuse_on_cpu",
        "_jit_override_can_fuse_on_gpu",
        "_jit_set_texpr_fuser_enabled",
        "_jit_set_nvfuser_enabled",
    ):
        setter = getattr(torch_c, setter_name, None)
        if setter is None:
            # Private API: not every torch release ships every toggle.
            continue
        try:
            setter(False)
        except Exception:
            # Deliberately best effort: a rejected toggle must not stop start-up,
            # but (unlike a bare except) Ctrl-C / SystemExit still propagate.
            continue
        applied.append(setter_name)
    return applied


disable_torch_jit_fusers()
2022-12-31 14:07:08 +01:00
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
2022-04-18 16:54:34 +02:00
2022-04-18 15:30:49 +02:00
# Flask prints a start-up banner, including its warning about running the
# development server in production. Replace the banner hook with a no-op so
# nothing is printed.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
def _silent_server_banner(*_ignored):
    """Accept whatever Flask passes to the banner hook and print nothing."""
    return None


cli.show_server_banner = _silent_server_banner
2022-04-18 09:01:10 +02:00
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())

# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Size the native thread pools (OpenMP, OpenBLAS, MKL, vecLib, numexpr) to the
# CPU count reported by multiprocessing.cpu_count().
# NOTE(review): torch, cv2 and numpy are imported above this point, so some of
# these libraries may already have read their environment -- confirm the
# settings still take effect.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)

# When CACHE_DIR is set to a non-empty value, point TORCH_HOME at it.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

# Directory of the built frontend: index.html and the static/ asset folder.
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
2022-04-18 09:29:29 +02:00
class NoFlaskwebgui(logging.Filter):
    """Log filter that drops records mentioning flaskwebgui's keep-alive route."""

    def filter(self, record):
        text = record.getMessage()
        # Keep the record only when the keep-alive marker is absent.
        return text.find("flaskwebgui-keep-server-alive") == -1


# Applied to werkzeug's logger, which emits the per-request access lines.
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
2022-04-18 09:01:10 +02:00
# The Flask app also serves the prebuilt frontend's static assets.
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config.update(JSON_AS_ASCII=False)  # jsonify emits non-ASCII text unescaped
# Let cross-origin browser code read the Content-Disposition header.
CORS(app, expose_headers=["Content-Disposition"])

# Process-wide state shared by the route handlers; populated by main().
model: ModelManager = None  # the active inpainting model manager
thumb: FileManager = None  # only created when the input is a directory
interactive_seg_model: InteractiveSeg = None
device = None
input_image_path: str = None  # single-image mode: the file served by /inputimage
is_disable_model_switch: bool = False
is_enable_file_manager: bool = False
is_desktop: bool = False
2022-04-18 09:01:10 +02:00
2022-11-18 14:40:12 +01:00
2022-04-18 09:01:10 +02:00
def get_image_ext(img_bytes):
    """Guess the image format of ``img_bytes`` from its header bytes.

    Returns the ``imghdr`` format name (e.g. ``"png"``, ``"gif"``). Falls back
    to ``"jpeg"`` when the header is not recognised.
    """
    detected = imghdr.what("", img_bytes)
    return "jpeg" if detected is None else detected
2022-10-15 16:32:25 +02:00
def diffuser_callback(i, t, latents):
    """Per-step hook handed to ModelManager; currently a no-op returning None.

    The arguments are presumably the step index, the timestep and the current
    latents (unverified here). Emitting a socket.io ``diffusion_step`` progress
    event from this hook was sketched but is not wired up.
    """
    return None
2023-01-05 15:07:39 +01:00
@app.route("/save_image", methods=["POST"])
def save_image():
    """Save an uploaded image into the file manager's output directory.

    Expects a multipart POST with an ``image`` file and a ``filename`` form
    field. The file manager (``thumb``) only exists when the server was started
    with a directory as input, so this responds 400 otherwise. It also responds
    400 when ``filename`` has no usable final path component.
    """
    if thumb is None:
        # Without the file manager, this previously crashed with AttributeError (500).
        return "File manager is not enabled", 400
    files = request.files
    origin_image_bytes = files["image"].read()  # RGB
    image, _ = load_img(origin_image_bytes)
    # The client-supplied name is untrusted: keep only its last component so it
    # cannot address a location outside the output directory ("../../x.png").
    filename = Path(request.form["filename"]).name
    if filename in ("", ".."):
        return "Invalid filename", 400
    thumb.save_to_output_directory(image, filename)
    return 'ok', 200
2023-01-07 13:51:05 +01:00
@app.route("/medias/<tab>")
def medias(tab):
    """Return the file manager's media names for ``tab`` as a JSON list.

    ``tab == "image"`` lists ``thumb.media_names``; any other value lists
    ``thumb.output_media_names``.
    """
    names = thumb.media_names if tab == 'image' else thumb.output_media_names
    # NOTE: no HTTP caching or conditional-request handling is applied here.
    return make_response(jsonify(names), 200)
2022-12-31 14:07:08 +01:00
2023-01-07 13:51:05 +01:00
@app.route('/media/<tab>/<filename>')
def media_file(tab, filename):
    """Serve one file from the input ("image" tab) or the output directory."""
    directory = thumb.root_directory if tab == 'image' else thumb.output_dir
    return send_from_directory(directory, filename)
2022-12-31 14:07:08 +01:00
2023-01-07 13:51:05 +01:00
def _parse_thumbnail_dimension(raw):
    """Parse a thumbnail width/height query value; None or "" means not given.

    Fractional values are truncated toward zero. Raises ValueError /
    OverflowError for non-numeric, NaN, infinite or non-positive input.
    """
    if raw is None or raw == "":
        return None
    value = int(float(raw))
    if value <= 0:
        raise ValueError(f"thumbnail dimension must be positive, got {raw!r}")
    return value


@app.route('/media_thumbnail/<tab>/<filename>')
def media_thumbnail_file(tab, filename):
    """Serve a thumbnail of a file-manager image.

    Optional query args ``width`` / ``height`` are positive numbers. When
    neither is given, the width defaults to 256. ``tab == "output"`` reads from
    the output directory; anything else reads from the input directory. The
    size returned by ``thumb.get_thumbnail`` is reported in the ``X-Width`` /
    ``X-Height`` headers.

    Responds 400 when ``width`` or ``height`` is not a positive number
    (previously such values raised and surfaced as a 500).
    """
    try:
        width = _parse_thumbnail_dimension(request.args.get('width'))
        height = _parse_thumbnail_dimension(request.args.get('height'))
    except (ValueError, OverflowError):
        return "width and height must be positive numbers", 400
    if width is None and height is None:
        width = 256

    directory = thumb.output_dir if tab == 'output' else thumb.root_directory
    thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
    # thumb_filename is appended to the thumbnail root with no separator.
    # NOTE(review): this relies on FileManager.get_thumbnail returning a name
    # that starts with a path separator -- confirm before changing.
    thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
    response = make_response(send_file(thumb_filepath))
    response.headers["X-Width"] = str(width)
    response.headers["X-Height"] = str(height)
    return response
2022-04-18 09:01:10 +02:00
@app.route("/inpaint", methods=["POST"])
def process():
    """Inpaint the masked region of an uploaded image.

    Expects a multipart POST with:
      * files ``image`` and ``mask`` (loaded as grayscale and binarized at 127),
        plus an optional ``paintByExampleImage``;
      * form fields for every ``Config`` option (``ldmSteps``, ``sdSeed``, ...)
        and ``sizeLimit``. ``sizeLimit`` is an integer (default 1080), or
        ``"Original"`` to use the image's own largest dimension as the limit.

    Returns the inpainted image encoded in the upload's detected format. The
    input's alpha channel (if any) is re-attached, and ``config.sd_seed`` is
    sent in the ``X-Seed`` header.

    Error responses: 400 when image and mask sizes differ; 500 on CUDA
    out-of-memory or any other ``RuntimeError`` raised by the model.
    """
    files = request.files
    # RGB
    origin_image_bytes = files["image"].read()
    image, alpha_channel = load_img(origin_image_bytes)
    mask, _ = load_img(files["mask"].read(), gray=True)
    # Binarize: pixels above 127 become 255, everything else 0.
    mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
    if image.shape[:2] != mask.shape[:2]:
        return (
            f"Mask shape {mask.shape[:2]} not equal to Image shape {image.shape[:2]}",
            400,
        )

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    form = request.form
    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    if "paintByExampleImage" in files:
        paint_by_example_example_image, _ = load_img(files["paintByExampleImage"].read())
        paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
    else:
        paint_by_example_example_image = None

    config = Config(
        ldm_steps=form["ldmSteps"],
        ldm_sampler=form["ldmSampler"],
        hd_strategy=form["hdStrategy"],
        zits_wireframe=form["zitsWireframe"],
        hd_strategy_crop_margin=form["hdStrategyCropMargin"],
        hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
        hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
        prompt=form["prompt"],
        negative_prompt=form["negativePrompt"],
        use_croper=form["useCroper"],
        croper_x=form["croperX"],
        croper_y=form["croperY"],
        croper_height=form["croperHeight"],
        croper_width=form["croperWidth"],
        sd_scale=form["sdScale"],
        sd_mask_blur=form["sdMaskBlur"],
        sd_strength=form["sdStrength"],
        sd_steps=form["sdSteps"],
        sd_guidance_scale=form["sdGuidanceScale"],
        sd_sampler=form["sdSampler"],
        sd_seed=form["sdSeed"],
        sd_match_histograms=form["sdMatchHistograms"],
        cv2_flag=form["cv2Flag"],
        cv2_radius=form['cv2Radius'],
        paint_by_example_steps=form["paintByExampleSteps"],
        paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
        paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
        paint_by_example_seed=form["paintByExampleSeed"],
        paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
        paint_by_example_example_image=paint_by_example_example_image,
    )

    # A seed of -1 means "pick one at random".
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    if config.paint_by_example_seed == -1:
        config.paint_by_example_seed = random.randint(1, 999999999)

    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    logger.info(f"Resized image shape: {image.shape}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    start = time.time()
    try:
        res_np_img = model(image, mask, config)
    except RuntimeError as e:
        torch.cuda.empty_cache()
        if "CUDA out of memory. " in str(e):
            # NOTE: matching on the message text is fragile; it may change
            # between torch releases.
            return "CUDA out of memory", 500
        else:
            logger.exception(e)
            return "Internal Server Error", 500
    finally:
        logger.info(f"process time: {(time.time() - start) * 1000}ms")
        torch.cuda.empty_cache()

    if alpha_channel is not None:
        # Re-attach the input's alpha channel, resized to the result if needed.
        if alpha_channel.shape[:2] != res_np_img.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
            )
        res_np_img = np.concatenate(
            (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )

    ext = get_image_ext(origin_image_bytes)
    response = make_response(
        send_file(
            io.BytesIO(numpy_to_bytes(res_np_img, ext)),
            mimetype=f"image/{ext}",
        )
    )
    response.headers["X-Seed"] = str(config.sd_seed)
    return response
2022-04-18 09:01:10 +02:00
2022-11-27 14:25:27 +01:00
@app.route("/interactive_seg", methods=["POST"])
def interactive_seg():
    """Run click-based interactive segmentation and return the mask as a PNG.

    Expects a multipart POST with an ``image`` file, an optional ``mask`` file
    (the previous mask, loaded as grayscale) and a ``clicks`` form field. The
    ``clicks`` field holds a JSON list of ``[a, b, flag]`` entries, where
    ``flag == 1`` marks a positive click.
    """
    files = request.files
    origin_image_bytes = files["image"].read()  # RGB
    image, _ = load_img(origin_image_bytes)
    if 'mask' in files:
        mask, _ = load_img(files["mask"].read(), gray=True)
    else:
        mask = None

    raw_clicks = json.loads(request.form["clicks"])
    # The first two entries are swapped into Click.coords, i.e. (b, a) --
    # presumably converting x/y to row/col; confirm against InteractiveSeg.
    clicks = [
        Click(coords=(entry[1], entry[0]), indx=idx, is_positive=entry[2] == 1)
        for idx, entry in enumerate(raw_clicks)
    ]

    started_at = time.time()
    new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
    logger.info(f"interactive seg process time: {(time.time() - started_at) * 1000}ms")

    png_buffer = io.BytesIO(numpy_to_bytes(new_mask, 'png'))
    return make_response(send_file(png_buffer, mimetype="image/png"))
2022-04-18 09:01:10 +02:00
@app.route("/model")
def current_model():
    """Report the name of the currently loaded model."""
    active_name = model.name
    return active_name, 200
2022-11-18 14:40:12 +01:00
2022-11-13 06:14:37 +01:00
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
    """Tell the frontend ('true' / 'false') whether model switching is disabled."""
    if is_disable_model_switch:
        return 'true', 200
    return 'false', 200
2022-04-18 09:01:10 +02:00
2023-01-01 15:36:11 +01:00
@app.route("/is_enable_file_manager")
def get_is_enable_file_manager():
    """Tell the frontend ('true' / 'false') whether the file manager is enabled."""
    if is_enable_file_manager:
        return 'true', 200
    return 'false', 200
2022-04-18 09:01:10 +02:00
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
    """Report ``str(model.is_downloaded(name))`` for the given model name."""
    downloaded = model.is_downloaded(name)
    return str(downloaded), 200
2022-11-18 14:40:12 +01:00
2022-11-16 10:59:39 +01:00
@app.route("/is_desktop")
def get_is_desktop():
    """Report ``str(is_desktop)``, i.e. whether the GUI (desktop) mode was requested."""
    answer = str(is_desktop)
    return answer, 200
2022-04-18 09:01:10 +02:00
2022-11-18 14:40:12 +01:00
2022-04-18 09:01:10 +02:00
@app.route("/model", methods=["POST"])
def switch_model():
    """Switch the active model to the one named by the ``name`` form field.

    Responses:
      * 400 when switching is disabled (``--disable-model-switch``) or no
        ``name`` was supplied;
      * 200 "Same model" when that model is already active;
      * 403 when ``model.switch`` raises NotImplementedError for the name;
      * 200 on a successful switch.
    """
    if is_disable_model_switch:
        return "Switch model is disabled", 400

    new_name = request.form.get("name")
    if not new_name:
        # Previously None / "" was handed straight to model.switch().
        return "Missing model name", 400
    if new_name == model.name:
        return "Same model", 200
    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
    """Serve the frontend's index.html from BUILD_DIR with caching disabled."""
    index_html = os.path.join(BUILD_DIR, "index.html")
    return send_file(index_html, cache_timeout=0)
2022-04-18 09:01:10 +02:00
@app.route("/inputimage")
def set_input_photo():
    """Send the single input image given at start-up as an attachment.

    Returns the plain text "No Input Image" when no input image path is set.
    """
    if not input_image_path:
        return "No Input Image"
    # Only the header is needed to sniff the format (imghdr inspects at most
    # the first 32 bytes), so don't read the whole image into memory.
    with open(input_image_path, "rb") as f:
        header = f.read(32)
    return send_file(
        input_image_path,
        as_attachment=True,
        attachment_filename=Path(input_image_path).name,
        mimetype=f"image/{get_image_ext(header)}",
    )
2023-01-08 13:53:55 +01:00
class FSHandler(FileSystemEventHandler):
    """Watchdog event handler that prints the path of each modified file."""

    def on_modified(self, event):
        message = "File modified: %s" % event.src_path
        print(message)
2022-04-18 09:01:10 +02:00
def main(args):
    """Configure global server state from parsed CLI ``args`` and run the server.

    Sets the module-level ``device``, ``model``, ``interactive_seg_model`` and
    feature flags read by the route handlers.

    When ``args.input`` is a directory, the file manager (``thumb``) is enabled
    over it, with output and thumbnail paths under ``args.output_dir``.
    Otherwise ``args.input`` is kept as the single image served by /inputimage.

    Finally runs either the flaskwebgui window (``args.gui``) or the plain
    Flask development server.
    """
    global model
    global interactive_seg_model
    global device
    global input_image_path
    global is_disable_model_switch
    global is_enable_file_manager
    global is_desktop
    global thumb

    device = torch.device(args.device)
    is_disable_model_switch = args.disable_model_switch
    is_desktop = args.gui
    if is_disable_model_switch:
        logger.info("Start with --disable-model-switch, model switch on frontend is disabled")

    if args.input and os.path.isdir(args.input):
        logger.info("Initialize file manager")
        thumb = FileManager(app)
        is_enable_file_manager = True
        # NOTE(review): FileManager is constructed before these config keys are
        # set, so it presumably reads them lazily -- keep this order unless confirmed.
        app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
        app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
            args.output_dir, 'lama_cleaner_thumbnails'
        )
        thumb.output_dir = Path(args.output_dir)
        # TODO: directory watching (thumb.start(), then stopping and joining the
        # image/output observers on shutdown) is not enabled yet.
    else:
        input_image_path = args.input

    model = ModelManager(
        name=args.model,
        device=device,
        no_half=args.no_half,
        hf_access_token=args.hf_access_token,
        sd_disable_nsfw=args.sd_disable_nsfw,
        sd_cpu_textencoder=args.sd_cpu_textencoder,
        sd_run_local=args.sd_run_local,
        local_files_only=args.local_files_only,
        cpu_offload=args.cpu_offload,
        sd_enable_xformers=args.sd_enable_xformers,
        callback=diffuser_callback,
    )

    interactive_seg_model = InteractiveSeg()

    if args.gui:
        app_width, app_height = args.gui_size
        # Imported lazily so flaskwebgui is only needed in GUI mode.
        from flaskwebgui import FlaskUI

        ui = FlaskUI(
            app, width=app_width, height=app_height, host=args.host, port=args.port
        )
        ui.run()
    else:
        app.run(host=args.host, port=args.port, debug=args.debug)