From a219da27f71be2ac8eb47fc5b56a06f56a1da398 Mon Sep 17 00:00:00 2001 From: Sanster Date: Mon, 18 Apr 2022 15:01:10 +0800 Subject: [PATCH] add setup.py --- .gitignore | 3 + lama_cleaner/__init__.py | 7 ++ lama_cleaner/model/__init__.py | 0 lama_cleaner/parse_args.py | 32 +++++ lama_cleaner/server.py | 186 +++++++++++++++++++++++++++ main.py | 221 +-------------------------------- publish.sh | 10 ++ requirements-dev.txt | 2 + requirements.txt | 2 +- setup.py | 42 ++++++- 10 files changed, 284 insertions(+), 221 deletions(-) create mode 100644 lama_cleaner/__init__.py create mode 100644 lama_cleaner/model/__init__.py create mode 100644 lama_cleaner/parse_args.py create mode 100644 lama_cleaner/server.py create mode 100644 publish.sh create mode 100644 requirements-dev.txt diff --git a/.gitignore b/.gitignore index 9b4139b..bc9822f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ examples/ .idea/ .vscode/ +build/ +dist/ +lama_cleaner.egg-info/ diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py new file mode 100644 index 0000000..5e72328 --- /dev/null +++ b/lama_cleaner/__init__.py @@ -0,0 +1,7 @@ +from lama_cleaner.parse_args import parse_args +from lama_cleaner.server import main + + +def entry_point(): + args = parse_args() + main(args) diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py new file mode 100644 index 0000000..026577e --- /dev/null +++ b/lama_cleaner/parse_args.py @@ -0,0 +1,32 @@ +import os +import imghdr +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument("--model", default="lama", choices=["lama", "ldm"]) + parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) + parser.add_argument("--gui", action="store_true", help="Launch as desktop app") + parser.add_argument( + "--gui-size", + default=[1600, 1000], + nargs=2, + type=int, + help="Set window size for GUI", + ) + parser.add_argument( + "--input", type=str, help="Path to image you want to load by default" + ) + parser.add_argument("--debug", action="store_true") + + args = parser.parse_args() + if args.input is not None: + if not os.path.exists(args.input): + parser.error(f"invalid --input: {args.input} not exists") + if imghdr.what(args.input) is None: + parser.error(f"invalid --input: {args.input} is not a valid image file") + + return args diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py new file mode 100644 index 0000000..50f4150 --- /dev/null +++ b/lama_cleaner/server.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 + +import argparse +import io +import logging +import multiprocessing +import os +import time +import imghdr +from pathlib import Path +from typing import Union + +import cv2 +import torch +import numpy as np +from loguru import logger + +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.schema import Config + +try: + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) +except: + pass + +from flask import Flask, request, send_file +from flask_cors import CORS + +from lama_cleaner.helper import ( + load_img, + numpy_to_bytes, + resize_max_size, +) + +NUM_THREADS = str(multiprocessing.cpu_count()) + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS +if os.environ.get("CACHE_DIR"): + os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] + +BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") + + +class InterceptHandler(logging.Handler): + def emit(self, record): + logger_opt = logger.opt(depth=6, exception=record.exc_info) + logger_opt.log(record.levelno, record.getMessage()) + + +app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) +app.config["JSON_AS_ASCII"] = False +app.logger.addHandler(InterceptHandler()) +CORS(app, expose_headers=["Content-Disposition"]) + +model: ModelManager = None +device = None +input_image_path: str = None + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +@app.route("/inpaint", methods=["POST"]) +def process(): + input = request.files + # RGB + origin_image_bytes = input["image"].read() + + image, alpha_channel = load_img(origin_image_bytes) + original_shape = image.shape + interpolation = cv2.INTER_CUBIC + + form = request.form + size_limit: Union[int, str] = form.get("sizeLimit", "1080") + if size_limit == "Original": + size_limit = max(image.shape) + else: + size_limit = int(size_limit) + + config = Config( + ldm_steps=form['ldmSteps'], + hd_strategy=form['hdStrategy'], + hd_strategy_crop_margin=form['hdStrategyCropMargin'], + hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'], + hd_strategy_resize_limit=form['hdStrategyResizeLimit'], + ) + + logger.info(f"Origin image shape: {original_shape}") + image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) + logger.info(f"Resized image shape: {image.shape}") + + mask, _ = load_img(input["mask"].read(), gray=True) + mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) + + start = time.time() + res_np_img = model(image, mask, config) + logger.info(f"process time: {(time.time() - start) * 1000}ms") + + torch.cuda.empty_cache() + + if alpha_channel is not None: + if alpha_channel.shape[:2] != res_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) + ) + res_np_img = np.concatenate( + (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + + ext = get_image_ext(origin_image_bytes) + return send_file( + io.BytesIO(numpy_to_bytes(res_np_img, ext)), + mimetype=f"image/{ext}", + ) + + +@app.route("/model") +def current_model(): + return model.name, 200 + + +@app.route("/model_downloaded/") +def model_downloaded(name): + return str(model.is_downloaded(name)), 200 + + +@app.route("/model", methods=["POST"]) +def switch_model(): + new_name = request.form.get("name") + if new_name == model.name: + return "Same model", 200 + + try: + model.switch(new_name) + except NotImplementedError: + return f"{new_name} not implemented", 403 + return f"ok, switch to {new_name}", 200 + + +@app.route("/") +def index(): + return send_file(os.path.join(BUILD_DIR, "index.html")) + + +@app.route("/inputimage") +def set_input_photo(): + if input_image_path: + with open(input_image_path, "rb") as f: + image_in_bytes = f.read() + return send_file( + input_image_path, + as_attachment=True, + download_name=Path(input_image_path).name, + mimetype=f"image/{get_image_ext(image_in_bytes)}", + ) + else: + return "No Input Image" + + +def main(args): + global model + global device + global input_image_path + + device = torch.device(args.device) + input_image_path = args.input + + model = ModelManager(name=args.model, device=device) + + if args.gui: + app_width, app_height = args.gui_size + from flaskwebgui import FlaskUI + return FlaskUI(app, width=app_width, height=app_height, host=args.host, port=args.port) + else: + app.run(host=args.host, port=args.port, debug=args.debug) diff --git a/main.py b/main.py index 708238b..f57b35f 100644 --- a/main.py +++ b/main.py @@ -1,221 +1,4 @@ -#!/usr/bin/env python3 - -import argparse -import io -import logging -import multiprocessing -import os -import time -import imghdr -from pathlib import Path -from typing import Union - -import cv2 -import torch -import numpy as np -from loguru import logger - -from lama_cleaner.model_manager import ModelManager -from lama_cleaner.schema import Config - -try: - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(False) -except: - pass - -from flask import Flask, request, send_file -from flask_cors import CORS - -from lama_cleaner.helper import ( - load_img, - numpy_to_bytes, - resize_max_size, -) - -NUM_THREADS = str(multiprocessing.cpu_count()) - -os.environ["OMP_NUM_THREADS"] = NUM_THREADS -os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS -os.environ["MKL_NUM_THREADS"] = NUM_THREADS -os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS -os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS -if os.environ.get("CACHE_DIR"): - os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] - -BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build") - - -class InterceptHandler(logging.Handler): - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info) - logger_opt.log(record.levelno, record.getMessage()) - - -app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) -app.config["JSON_AS_ASCII"] = False -app.logger.addHandler(InterceptHandler()) -CORS(app, expose_headers=["Content-Disposition"]) - -model: ModelManager = None -device = None -input_image_path: str = None - - -def get_image_ext(img_bytes): - w = imghdr.what("", img_bytes) - if w is None: - w = "jpeg" - return w - - -@app.route("/inpaint", methods=["POST"]) -def process(): - input = request.files - # RGB - origin_image_bytes = input["image"].read() - - image, alpha_channel = load_img(origin_image_bytes) - original_shape = image.shape - interpolation = cv2.INTER_CUBIC - - form = request.form - size_limit: Union[int, str] = form.get("sizeLimit", "1080") - if size_limit == "Original": - size_limit = max(image.shape) - else: - size_limit = int(size_limit) - - config = Config( - ldm_steps=form['ldmSteps'], - hd_strategy=form['hdStrategy'], - hd_strategy_crop_margin=form['hdStrategyCropMargin'], - hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'], - hd_strategy_resize_limit=form['hdStrategyResizeLimit'], - ) - - logger.info(f"Origin image shape: {original_shape}") - image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) - logger.info(f"Resized image shape: {image.shape}") - - mask, _ = load_img(input["mask"].read(), gray=True) - mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) - - start = time.time() - res_np_img = model(image, mask, config) - logger.info(f"process time: {(time.time() - start) * 1000}ms") - - torch.cuda.empty_cache() - - if alpha_channel is not None: - if alpha_channel.shape[:2] != res_np_img.shape[:2]: - alpha_channel = cv2.resize( - alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) - ) - res_np_img = np.concatenate( - (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 - ) - - ext = get_image_ext(origin_image_bytes) - return send_file( - io.BytesIO(numpy_to_bytes(res_np_img, ext)), - mimetype=f"image/{ext}", - ) - - -@app.route("/model") -def current_model(): - return model.name, 200 - - -@app.route("/model_downloaded/") -def model_downloaded(name): - return str(model.is_downloaded(name)), 200 - - -@app.route("/model", methods=["POST"]) -def switch_model(): - new_name = request.form.get("name") - if new_name == model.name: - return "Same model", 200 - - try: - model.switch(new_name) - except NotImplementedError: - return f"{new_name} not implemented", 403 - return f"ok, switch to {new_name}", 200 - - -@app.route("/") -def index(): - return send_file(os.path.join(BUILD_DIR, "index.html")) - - -@app.route("/inputimage") -def set_input_photo(): - if input_image_path: - with open(input_image_path, "rb") as f: - image_in_bytes = f.read() - return send_file( - input_image_path, - as_attachment=True, - download_name=Path(input_image_path).name, - mimetype=f"image/{get_image_ext(image_in_bytes)}", - ) - else: - return "No Input Image" - - -def get_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--input", type=str, help="Path to image you want to load by default" - ) - parser.add_argument("--host", default="127.0.0.1") - parser.add_argument("--port", default=8080, type=int) - parser.add_argument("--model", default="lama", choices=["lama", "ldm"]) - parser.add_argument("--device", default="cuda", type=str) - parser.add_argument("--gui", action="store_true", help="Launch as desktop app") - parser.add_argument( - "--gui-size", - default=[1600, 1000], - nargs=2, - type=int, - help="Set window size for GUI", - ) - parser.add_argument("--debug", action="store_true") - - args = parser.parse_args() - if args.input is not None: - if not os.path.exists(args.input): - parser.error(f"invalid --input: {args.input} not exists") - if imghdr.what(args.input) is None: - parser.error(f"invalid --input: {args.input} is not a valid image file") - - return args - - -def main(): - global model - global device - global input_image_path - - args = get_args_parser() - device = torch.device(args.device) - input_image_path = args.input - - model = ModelManager(name=args.model, device=device) - - if args.gui: - app_width, app_height = args.gui_size - from flaskwebgui import FlaskUI - ui = FlaskUI(app, width=app_width, height=app_height) - ui.run() - else: - app.run(host=args.host, port=args.port, debug=args.debug) - +from lama_cleaner import entry_point if __name__ == "__main__": - main() + entry_point() diff --git a/publish.sh b/publish.sh new file mode 100644 index 0000000..8a2e1d1 --- /dev/null +++ b/publish.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -e + +pushd ./lama_cleaner/app +yarn run build +popd + +rm -r -f dist +python3 setup.py sdist bdist_wheel +#twine upload dist/* diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..d5ba964 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +wheel +twine \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c0e911e..d6f0730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch>=1.8.2 opencv-python flask_cors -flask==2.1.1 +flask>=2.1.1 flaskwebgui tqdm pydantic diff --git a/setup.py b/setup.py index 702178b..7454744 100644 --- a/setup.py +++ b/setup.py @@ -1 +1,41 @@ -# TODO: make this a python package +import setuptools +from pathlib import Path + +web_files = Path("lama_cleaner/app/build/").glob("**/*") +web_files = [str(it).replace("lama_cleaner/", "") for it in web_files] + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +def load_requirements(): + requirements_file_name = "requirements.txt" + requires = [] + with open(requirements_file_name) as f: + for line in f: + if line: + requires.append(line.strip()) + return requires + + +# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files +setuptools.setup( + name="lama-cleaner", + version="0.9.0", + author="PanicByte", + author_email="cwq1913@gmail.com", + description="Image inpainting tool powered by SOTA AI Model", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/Sanster/lama-cleaner", + packages=setuptools.find_packages("./"), + package_data={"lama_cleaner": web_files}, + install_requires=load_requirements(), + python_requires=">=3.6", + entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]}, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], +)