#!/usr/bin/env python3
import json
import os
import typer
from typer import Option
from lama_cleaner.download import cli_download_model, scan_models
from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
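
# Allow torch to fall back to CPU for ops not implemented on Apple's MPS
# backend. Must be set before torch is imported below.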
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import hashlib
import traceback
from dataclasses import dataclass, field
import imghdr
import io
import logging
import multiprocessing
import random
import time
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from loguru import logger

from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import (
    InteractiveSeg,
    RemoveBG,
    RealESRGANUpscaler,
    GFPGANPlugin,
    RestoreFormerPlugin,
    AnimeSeg,
)
from lama_cleaner.schema import Config
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
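
# Disable the TorchScript JIT fusers; these are private torch APIs that vary
# across torch versions, hence the bare try/except around them.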
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except:
    pass

from flask import (
    Flask,
    request,
    send_file,
    cli,
    make_response,
    send_from_directory,
    jsonify,
)
from flask_socketio import SocketIO

# Prevent Flask from printing its development-server warning banner.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
cli.show_server_banner = lambda *_: None
from flask_cors import CORS

from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
    pil_to_bytes,
)

NUM_THREADS = str(multiprocessing.cpu_count())

# Fix libomp problem on Windows: https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
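
# Pin the math-library (OpenMP/BLAS/MKL) thread pools to the detected CPU count.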
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")


class NoFlaskwebgui(logging.Filter):
    def filter(self, record):
        msg = record.getMessage()
        if "Running on http:" in msg:
            print(msg[msg.index("Running on http:") :])

        return (
            "flaskwebgui-keep-server-alive" not in msg
            and "socket.io" not in msg
            and "This is a development server." not in msg
        )

logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
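
# Socket.IO is used to push diffusion progress events to the frontend.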
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")


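# Process-wide runtime state, populated by the `start` command below and read
# by the Flask routes.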
@dataclass
class GlobalConfig:
    model_manager: ModelManager = None
    file_manager: FileManager = None
    output_dir: Path = None
    input_image_path: Path = None
    disable_model_switch: bool = False
    is_desktop: bool = False
    image_quality: int = 95
    # default_factory avoids a mutable default shared at class level
    plugins: dict = field(default_factory=dict)

    @property
    def enable_auto_saving(self) -> bool:
        return self.output_dir is not None

    @property
    def enable_file_manager(self) -> bool:
        return self.file_manager is not None


global_config = GlobalConfig()


def get_image_ext(img_bytes):
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w


def diffuser_callback(i, t, latents):
    socketio.emit("diffusion_progress", {"step": i})


@app.route("/save_image", methods=["POST"])
def save_image():
2023-12-24 08:32:27 +01:00
if global_config.output_dir is None:
return "--output-dir is None", 500
2023-01-05 15:07:39 +01:00
input = request.files
filename = request.form["filename"]
2023-01-05 15:07:39 +01:00
origin_image_bytes = input["image"].read() # RGB
ext = get_image_ext(origin_image_bytes)
2023-06-27 04:08:36 +02:00
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
2023-12-24 08:32:27 +01:00
save_path = str(global_config.output_dir / filename)
2023-06-27 04:08:36 +02:00
if alpha_channel is not None:
if alpha_channel.shape[:2] != image.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(image.shape[1], image.shape[0])
)
image = np.concatenate((image, alpha_channel[:, :, np.newaxis]), axis=-1)
pil_image = Image.fromarray(image)
img_bytes = pil_to_bytes(
2023-06-27 04:08:36 +02:00
pil_image,
ext,
2023-12-24 08:32:27 +01:00
quality=global_config.image_quality,
exif_infos=exif_infos,
)
with open(save_path, "wb") as fw:
fw.write(img_bytes)
2023-02-06 15:00:47 +01:00
return "ok", 200
2023-01-05 15:07:39 +01:00
2023-01-07 13:51:05 +01:00
@app.route("/medias/<tab>")
def medias(tab):
2023-02-06 15:00:47 +01:00
if tab == "image":
2023-12-24 08:32:27 +01:00
response = make_response(jsonify(global_config.file_manager.media_names), 200)
2023-01-08 13:53:55 +01:00
else:
2023-12-24 08:32:27 +01:00
response = make_response(
jsonify(global_config.file_manager.output_media_names), 200
)
2023-01-08 14:59:26 +01:00
# response.last_modified = thumb.modified_time[tab]
2023-01-08 13:53:55 +01:00
# response.cache_control.no_cache = True
# response.cache_control.max_age = 0
2023-01-08 14:59:26 +01:00
# response.make_conditional(request)
2023-01-08 13:53:55 +01:00
return response
2022-12-31 14:07:08 +01:00
2023-02-06 15:00:47 +01:00
@app.route("/media/<tab>/<filename>")
2023-01-07 13:51:05 +01:00
def media_file(tab, filename):
2023-02-06 15:00:47 +01:00
if tab == "image":
2023-12-24 08:32:27 +01:00
return send_from_directory(global_config.file_manager.root_directory, filename)
return send_from_directory(global_config.file_manager.output_dir, filename)
2022-12-31 14:07:08 +01:00
2023-02-06 15:00:47 +01:00
@app.route("/media_thumbnail/<tab>/<filename>")
2023-01-07 13:51:05 +01:00
def media_thumbnail_file(tab, filename):
2022-12-31 14:07:08 +01:00
args = request.args
2023-02-06 15:00:47 +01:00
width = args.get("width")
height = args.get("height")
2022-12-31 14:07:08 +01:00
if width is None and height is None:
width = 256
if width:
width = int(float(width))
if height:
height = int(float(height))
2023-12-24 08:32:27 +01:00
directory = global_config.file_manager.root_directory
2023-02-06 15:00:47 +01:00
if tab == "output":
2023-12-24 08:32:27 +01:00
directory = global_config.file_manager.output_dir
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
2023-02-06 15:00:47 +01:00
directory, filename, width, height
)
2022-12-31 14:07:08 +01:00
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath))
response.headers["X-Width"] = str(width)
response.headers["X-Height"] = str(height)
return response
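

# /inpaint runs the currently loaded model on an image + mask pair and returns
# the inpainted result.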
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
2023-05-07 10:58:55 +02:00
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
2022-12-10 15:06:15 +01:00
mask, _ = load_img(input["mask"].read(), gray=True)
2022-11-27 14:25:27 +01:00
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
2023-02-06 15:00:47 +01:00
return (
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
400,
)
2022-04-18 09:01:10 +02:00
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
form = request.form
2023-03-25 07:11:00 +01:00
size_limit = max(image.shape)
2022-04-18 09:01:10 +02:00
2022-12-10 15:06:15 +01:00
if "paintByExampleImage" in input:
2023-02-06 15:00:47 +01:00
paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
2022-12-10 15:06:15 +01:00
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else:
paint_by_example_example_image = None
2022-04-18 09:01:10 +02:00
config = Config(
2022-04-18 16:54:34 +02:00
ldm_steps=form["ldmSteps"],
2022-06-12 07:14:17 +02:00
ldm_sampler=form["ldmSampler"],
2022-04-18 16:54:34 +02:00
hd_strategy=form["hdStrategy"],
2022-07-13 03:04:28 +02:00
zits_wireframe=form["zitsWireframe"],
2022-04-18 16:54:34 +02:00
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
2022-09-20 16:43:20 +02:00
prompt=form["prompt"],
2022-11-08 14:58:48 +01:00
negative_prompt=form["negativePrompt"],
2022-09-20 16:43:20 +02:00
use_croper=form["useCroper"],
croper_x=form["croperX"],
croper_y=form["croperY"],
croper_height=form["croperHeight"],
croper_width=form["croperWidth"],
2023-12-22 07:00:30 +01:00
use_extender=form["useExtender"],
extender_x=form["extenderX"],
extender_y=form["extenderY"],
extender_height=form["extenderHeight"],
extender_width=form["extenderWidth"],
2023-01-05 15:07:39 +01:00
sd_scale=form["sdScale"],
2022-09-22 15:50:41 +02:00
sd_mask_blur=form["sdMaskBlur"],
2022-09-15 16:21:27 +02:00
sd_strength=form["sdStrength"],
sd_steps=form["sdSteps"],
sd_guidance_scale=form["sdGuidanceScale"],
sd_sampler=form["sdSampler"],
sd_seed=form["sdSeed"],
2023-12-19 06:16:30 +01:00
sd_freeu=form["enableFreeu"],
sd_freeu_config=json.loads(form["freeuConfig"]),
sd_lcm_lora=form["enableLCMLora"],
sd_match_histograms=form["sdMatchHistograms"],
2022-10-09 15:32:13 +02:00
cv2_flag=form["cv2Flag"],
2023-02-06 15:00:47 +01:00
cv2_radius=form["cv2Radius"],
2022-12-10 15:06:15 +01:00
paint_by_example_example_image=paint_by_example_example_image,
2023-01-28 14:13:21 +01:00
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
2023-12-19 06:16:30 +01:00
controlnet_enabled=form["controlnet_enabled"],
2023-03-19 15:40:23 +01:00
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
2023-05-13 07:45:27 +02:00
controlnet_method=form["controlnet_method"],
2022-04-18 09:01:10 +02:00
)
2022-09-15 16:21:27 +02:00
if config.sd_seed == -1:
2023-12-19 06:16:30 +01:00
config.sd_seed = random.randint(1, 99999999)
2022-09-15 16:21:27 +02:00
2022-04-18 09:01:10 +02:00
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time()
try:
2023-12-24 08:32:27 +01:00
res_np_img = global_config.model_manager(image, mask, config)
except RuntimeError as e:
if "CUDA out of memory. " in str(e):
2022-11-18 14:40:12 +01:00
# NOTE: the string may change?
return "CUDA out of memory", 500
2022-11-18 14:40:12 +01:00
else:
logger.exception(e)
2023-05-18 07:07:12 +02:00
return f"{str(e)}", 500
finally:
logger.info(f"process time: {(time.time() - start) * 1000}ms")
2023-05-20 06:35:36 +02:00
torch_gc()
2022-04-18 09:01:10 +02:00
2023-02-06 15:00:47 +01:00
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
2022-04-18 09:01:10 +02:00
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
2022-09-20 16:43:20 +02:00
2023-05-07 10:58:55 +02:00
bytes_io = io.BytesIO(
pil_to_bytes(
Image.fromarray(res_np_img),
ext,
2023-12-24 08:32:27 +01:00
quality=global_config.image_quality,
2023-05-07 10:58:55 +02:00
exif_infos=exif_infos,
)
)
2023-02-06 15:00:47 +01:00
2022-09-20 16:43:20 +02:00
response = make_response(
send_file(
2023-02-06 15:00:47 +01:00
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
2022-09-20 16:43:20 +02:00
mimetype=f"image/{ext}",
)
2022-04-18 09:01:10 +02:00
)
2022-09-20 16:43:20 +02:00
response.headers["X-Seed"] = str(config.sd_seed)
2023-05-07 07:25:46 +02:00
2023-05-07 10:58:55 +02:00
socketio.emit("diffusion_finish")
2022-09-20 16:43:20 +02:00
return response
2022-04-18 09:01:10 +02:00
2023-03-25 02:53:22 +01:00
@app.route("/run_plugin", methods=["POST"])
2023-03-22 05:57:18 +01:00
def run_plugin():
form = request.form
files = request.files
name = form["name"]
2023-12-24 08:32:27 +01:00
if name not in global_config.plugins:
2023-03-22 05:57:18 +01:00
return "Plugin not found", 500
origin_image_bytes = files["image"].read() # RGB
2023-05-07 10:58:55 +02:00
rgb_np_img, alpha_channel, exif_infos = load_img(
origin_image_bytes, return_exif=True
)
2022-11-27 14:25:27 +01:00
start = time.time()
2023-03-30 15:02:18 +02:00
try:
2023-04-06 15:55:20 +02:00
form = dict(form)
if name == InteractiveSeg.name:
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
form["img_md5"] = img_md5
2023-12-24 08:32:27 +01:00
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
2023-03-30 15:02:18 +02:00
except RuntimeError as e:
torch.cuda.empty_cache()
if "CUDA out of memory. " in str(e):
# NOTE: the string may change?
return "CUDA out of memory", 500
else:
logger.exception(e)
return "Internal Server Error", 500
2023-03-22 05:57:18 +01:00
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
2023-03-26 15:02:51 +02:00
if name == InteractiveSeg.name:
return make_response(
send_file(
io.BytesIO(numpy_to_bytes(bgr_res, "png")),
mimetype="image/png",
)
)
2023-03-26 14:42:31 +02:00
2023-05-09 13:07:12 +02:00
if name in [RemoveBG.name, AnimeSeg.name]:
rgb_res = bgr_res
2023-03-26 14:42:31 +02:00
ext = "png"
2023-03-25 02:53:22 +01:00
else:
2023-03-26 14:42:31 +02:00
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
ext = get_image_ext(origin_image_bytes)
if alpha_channel is not None:
if alpha_channel.shape[:2] != rgb_res.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0])
)
rgb_res = np.concatenate(
(rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1
2023-03-25 02:53:22 +01:00
)
2023-03-26 14:42:31 +02:00
response = make_response(
send_file(
io.BytesIO(
pil_to_bytes(
2023-05-07 10:58:55 +02:00
Image.fromarray(rgb_res),
ext,
2023-12-24 08:32:27 +01:00
quality=global_config.image_quality,
2023-05-07 10:58:55 +02:00
exif_infos=exif_infos,
2023-03-26 14:42:31 +02:00
)
),
mimetype=f"image/{ext}",
2022-11-27 14:25:27 +01:00
)
2023-03-26 14:42:31 +02:00
)
2022-11-27 14:25:27 +01:00
return response
2023-03-25 02:53:22 +01:00
@app.route("/server_config", methods=["GET"])
def get_server_config():
return {
2023-12-24 08:32:27 +01:00
"plugins": list(global_config.plugins.keys()),
"enableFileManager": global_config.enable_file_manager,
"enableAutoSaving": global_config.enable_auto_saving,
"enableControlnet": global_config.model_manager.sd_controlnet,
"controlnetMethod": global_config.model_manager.sd_controlnet_method,
"disableModelSwitch": global_config.disable_model_switch,
"isDesktop": global_config.is_desktop,
2023-03-25 02:53:22 +01:00
}, 200
2023-03-22 05:57:18 +01:00
2023-12-01 03:15:35 +01:00
@app.route("/models", methods=["GET"])
def get_models():
2023-12-24 08:32:27 +01:00
return [it.model_dump() for it in global_config.model_manager.scan_models()]
2023-11-16 14:45:55 +01:00
2022-04-18 09:01:10 +02:00
@app.route("/model")
def current_model():
2023-12-24 08:32:27 +01:00
return (
global_config.model_manager.current_model,
200,
)
2022-04-18 09:01:10 +02:00
2022-11-18 14:40:12 +01:00
2022-04-18 09:01:10 +02:00
@app.route("/model", methods=["POST"])
def switch_model():
2023-12-24 08:32:27 +01:00
if global_config.disable_model_switch:
2022-11-25 01:55:10 +01:00
return "Switch model is disabled", 400
2022-04-18 09:01:10 +02:00
new_name = request.form.get("name")
2023-12-24 08:32:27 +01:00
if new_name == global_config.model_manager.name:
2022-04-18 09:01:10 +02:00
return "Same model", 200
try:
2023-12-24 08:32:27 +01:00
global_config.model_manager.switch(new_name)
2023-12-01 03:15:35 +01:00
except Exception as e:
2023-12-15 05:40:29 +01:00
traceback.print_exc()
error_message = f"{type(e).__name__} - {str(e)}"
2023-12-01 03:15:35 +01:00
logger.error(error_message)
return f"Switch model failed: {error_message}", 500
2022-04-18 09:01:10 +02:00
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
2023-04-16 07:45:34 +02:00
return send_file(os.path.join(BUILD_DIR, "index.html"))
2022-04-18 09:01:10 +02:00
@app.route("/inputimage")
2023-12-01 03:15:35 +01:00
def get_cli_input_image():
2023-12-24 08:32:27 +01:00
if global_config.input_image_path:
with open(global_config.input_image_path, "rb") as f:
2022-04-18 09:01:10 +02:00
image_in_bytes = f.read()
return send_file(
2023-12-24 08:32:27 +01:00
global_config.input_image_path,
2022-04-18 09:01:10 +02:00
as_attachment=True,
2023-12-24 08:32:27 +01:00
download_name=Path(global_config.input_image_path).name,
2022-04-18 09:01:10 +02:00
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def build_plugins(
    enable_interactive_seg: bool,
    interactive_seg_model: InteractiveSegModel,
    interactive_seg_device: Device,
    enable_remove_bg: bool,
    enable_anime_seg: bool,
    enable_realesrgan: bool,
    realesrgan_device: Device,
    realesrgan_model: str,
    enable_gfpgan: bool,
    gfpgan_device: Device,
    enable_restoreformer: bool,
    restoreformer_device: Device,
    no_half: bool,
):
    if enable_interactive_seg:
        logger.info(f"Initialize {InteractiveSeg.name} plugin")
        global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
            interactive_seg_model, interactive_seg_device
        )

    if enable_remove_bg:
        logger.info(f"Initialize {RemoveBG.name} plugin")
        global_config.plugins[RemoveBG.name] = RemoveBG()

    if enable_anime_seg:
        logger.info(f"Initialize {AnimeSeg.name} plugin")
        global_config.plugins[AnimeSeg.name] = AnimeSeg()

    if enable_realesrgan:
        logger.info(
            f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
        )
        global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
            realesrgan_model,
            realesrgan_device,
            no_half=no_half,
        )

    if enable_gfpgan:
        logger.info(f"Initialize {GFPGANPlugin.name} plugin")
        if enable_realesrgan:
            logger.info("Use RealESRGAN as GFPGAN background upscaler")
        else:
            logger.info(
                "GFPGAN has no background upscaler; use --enable-realesrgan to add one"
            )
        global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
            gfpgan_device,
            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
        )

    if enable_restoreformer:
        logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
        global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
            restoreformer_device,
            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
        )
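

# Typer CLI commands: install plugin dependencies, download models, list
# downloaded models, and start the server.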
@typer_app.command(help="Install all plugins dependencies")
def install_plugins_packages():
    from lama_cleaner.installer import install_plugins_package

    install_plugins_package()


@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
def download(
    model: str = Option(
        ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
    ),
    model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
    cli_download_model(model, model_dir)


@typer_app.command(help="List downloaded models")
def list_model(
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
setup_model_dir(model_dir)
scanned_models = scan_models()
for it in scanned_models:
print(it.name)
@typer_app.command(help="Start lama cleaner server")
def start(
host: str = Option("127.0.0.1"),
port: int = Option(8080),
model: str = Option(
DEFAULT_MODEL,
2023-12-25 03:41:28 +01:00
help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. "
2023-12-24 08:32:27 +01:00
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
),
no_half: bool = Option(False, help=NO_HALF_HELP),
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu),
gui: bool = Option(False, help=GUI_HELP),
disable_model_switch: bool = Option(False),
input: Path = Option(None, help=INPUT_HELP),
output_dir: Path = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
),
quality: int = Option(95, help=QUALITY_HELP),
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
interactive_seg_model: InteractiveSegModel = Option(
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu),
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
enable_gfpgan: bool = Option(False),
gfpgan_device: Device = Option(Device.cpu),
enable_restoreformer: bool = Option(False),
restoreformer_device: Device = Option(Device.cpu),
):
global global_config
dump_environment_info()
if input:
if not input.exists():
logger.error(f"invalid --input: {input} not exists")
exit()
if input.is_dir():
logger.info(f"Initialize file manager")
file_manager = FileManager(app)
app.config["THUMBNAIL_MEDIA_ROOT"] = input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
output_dir, "lama_cleaner_thumbnails"
)
file_manager.output_dir = output_dir
else:
global_config.input_image_path = input
2023-01-08 13:53:55 +01:00
2023-12-24 08:32:27 +01:00
device = check_device(device)
setup_model_dir(model_dir)
if local_files_only:
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
2023-12-25 03:41:28 +01:00
scanned_models = scan_models()
if model not in [it.name for it in scanned_models]:
logger.error(
f"invalid model: {model} not exists. Available models: {[it.name for it in scanned_models]}"
)
exit()
2023-12-24 08:32:27 +01:00
global_config.image_quality = quality
global_config.disable_model_switch = disable_model_switch
global_config.is_desktop = gui
build_plugins(
enable_interactive_seg,
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
realesrgan_model,
enable_gfpgan,
gfpgan_device,
enable_restoreformer,
restoreformer_device,
no_half,
)
if output_dir:
output_dir = output_dir.expanduser().absolute()
logger.info(f"Image will auto save to output dir: {output_dir}")
global_config.output_dir = output_dir
global_config.model_manager = ModelManager(
name=model,
device=torch.device(device),
no_half=no_half,
disable_nsfw=disable_nsfw_checker,
sd_cpu_textencoder=cpu_textencoder,
cpu_offload=cpu_offload,
2022-10-15 16:32:25 +02:00
callback=diffuser_callback,
2022-09-20 16:43:20 +02:00
)
2022-04-18 09:01:10 +02:00
2023-12-24 08:32:27 +01:00
if gui:
2022-04-18 09:01:10 +02:00
from flaskwebgui import FlaskUI
2022-04-18 16:54:34 +02:00
ui = FlaskUI(
2023-02-06 15:00:47 +01:00
app,
2023-05-07 07:25:46 +02:00
socketio=socketio,
2023-12-24 08:32:27 +01:00
width=1200,
height=800,
host=host,
port=port,
close_server_on_exit=True,
idle_interval=60,
2022-04-18 16:54:34 +02:00
)
2022-04-18 16:28:47 +02:00
ui.run()
2022-04-18 09:01:10 +02:00
else:
2023-05-13 07:45:27 +02:00
socketio.run(
app,
2023-12-24 08:32:27 +01:00
host=host,
port=port,
2023-05-13 07:45:27 +02:00
allow_unsafe_werkzeug=True,
)