add setup.py

This commit is contained in:
Sanster 2022-04-18 15:01:10 +08:00
parent f7e1e073dc
commit a219da27f7
10 changed files with 284 additions and 221 deletions

3
.gitignore vendored
View File

@ -3,3 +3,6 @@
examples/
.idea/
.vscode/
build/
dist/
lama_cleaner.egg-info/

7
lama_cleaner/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from lama_cleaner.parse_args import parse_args
from lama_cleaner.server import main
def entry_point():
args = parse_args()
main(args)

View File

View File

@ -0,0 +1,32 @@
import os
import imghdr
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
"--gui-size",
default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument(
"--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
if args.input is not None:
if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists")
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
return args

186
lama_cleaner/server.py Normal file
View File

@ -0,0 +1,186 @@
#!/usr/bin/env python3
import argparse
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
except:
pass
from flask import Flask, request, send_file
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class InterceptHandler(logging.Handler):
def emit(self, record):
logger_opt = logger.opt(depth=6, exception=record.exc_info)
logger_opt.log(record.levelno, record.getMessage())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None
device = None
input_image_path: str = None
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time()
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
return send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype=f"image/{ext}",
)
@app.route("/model")
def current_model():
return model.name, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route("/inputimage")
def set_input_photo():
if input_image_path:
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def main(args):
global model
global device
global input_image_path
device = torch.device(args.device)
input_image_path = args.input
model = ModelManager(name=args.model, device=device)
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI
return FlaskUI(app, width=app_width, height=app_height, host=args.host, port=args.port)
else:
app.run(host=args.host, port=args.port, debug=args.debug)

221
main.py
View File

@ -1,221 +1,4 @@
#!/usr/bin/env python3
import argparse
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
except:
pass
from flask import Flask, request, send_file
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
class InterceptHandler(logging.Handler):
def emit(self, record):
logger_opt = logger.opt(depth=6, exception=record.exc_info)
logger_opt.log(record.levelno, record.getMessage())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None
device = None
input_image_path: str = None
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time()
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
return send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype=f"image/{ext}",
)
@app.route("/model")
def current_model():
return model.name, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route("/inputimage")
def set_input_photo():
if input_image_path:
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
"--gui-size",
default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
if args.input is not None:
if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists")
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
return args
def main():
global model
global device
global input_image_path
args = get_args_parser()
device = torch.device(args.device)
input_image_path = args.input
model = ModelManager(name=args.model, device=device)
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI
ui = FlaskUI(app, width=app_width, height=app_height)
ui.run()
else:
app.run(host=args.host, port=args.port, debug=args.debug)
from lama_cleaner import entry_point
if __name__ == "__main__":
main()
entry_point()

10
publish.sh Normal file
View File

@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -e
pushd ./lama_cleaner/app
yarn run build
popd
rm -r -f dist
python3 setup.py sdist bdist_wheel
#twine upload dist/*

2
requirements-dev.txt Normal file
View File

@ -0,0 +1,2 @@
wheel
twine

View File

@ -1,7 +1,7 @@
torch>=1.8.2
opencv-python
flask_cors
flask==2.1.1
flask>=2.1.1
flaskwebgui
tqdm
pydantic

View File

@ -1 +1,41 @@
# TODO: make this a python package
import setuptools
from pathlib import Path
web_files = Path("lama_cleaner/app/build/").glob("**/*")
web_files = [str(it).replace("lama_cleaner/", "") for it in web_files]
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
def load_requirements():
requirements_file_name = "requirements.txt"
requires = []
with open(requirements_file_name) as f:
for line in f:
if line:
requires.append(line.strip())
return requires
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup(
name="lama-cleaner",
version="0.9.0",
author="PanicByte",
author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Sanster/lama-cleaner",
packages=setuptools.find_packages("./"),
package_data={"lama_cleaner": web_files},
install_requires=load_requirements(),
python_requires=">=3.6",
entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
)