add setup.py
This commit is contained in:
parent
f7e1e073dc
commit
a219da27f7
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,3 +3,6 @@
|
||||
examples/
|
||||
.idea/
|
||||
.vscode/
|
||||
build/
|
||||
dist/
|
||||
lama_cleaner.egg-info/
|
||||
|
7
lama_cleaner/__init__.py
Normal file
7
lama_cleaner/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from lama_cleaner.parse_args import parse_args
|
||||
from lama_cleaner.server import main
|
||||
|
||||
|
||||
def entry_point():
|
||||
args = parse_args()
|
||||
main(args)
|
0
lama_cleaner/model/__init__.py
Normal file
0
lama_cleaner/model/__init__.py
Normal file
32
lama_cleaner/parse_args.py
Normal file
32
lama_cleaner/parse_args.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
import imghdr
|
||||
import argparse
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
"--gui-size",
|
||||
default=[1600, 1000],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="Set window size for GUI",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="Path to image you want to load by default"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
parser.error(f"invalid --input: {args.input} not exists")
|
||||
if imghdr.what(args.input) is None:
|
||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||
|
||||
return args
|
186
lama_cleaner/server.py
Normal file
186
lama_cleaner/server.py
Normal file
@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import imghdr
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
if os.environ.get("CACHE_DIR"):
|
||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
logger_opt = logger.opt(depth=6, exception=record.exc_info)
|
||||
logger_opt.log(record.levelno, record.getMessage())
|
||||
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
app.logger.addHandler(InterceptHandler())
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
|
||||
model: ModelManager = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
w = imghdr.what("", img_bytes)
|
||||
if w is None:
|
||||
w = "jpeg"
|
||||
return w
|
||||
|
||||
|
||||
@app.route("/inpaint", methods=["POST"])
|
||||
def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
form = request.form
|
||||
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
|
||||
if size_limit == "Original":
|
||||
size_limit = max(image.shape)
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
config = Config(
|
||||
ldm_steps=form['ldmSteps'],
|
||||
hd_strategy=form['hdStrategy'],
|
||||
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
|
||||
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
|
||||
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
|
||||
)
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
|
||||
start = time.time()
|
||||
res_np_img = model(image, mask, config)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
||||
)
|
||||
res_np_img = np.concatenate(
|
||||
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
return send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
mimetype=f"image/{ext}",
|
||||
)
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
|
||||
|
||||
@app.route("/model_downloaded/<name>")
|
||||
def model_downloaded(name):
|
||||
return str(model.is_downloaded(name)), 200
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
except NotImplementedError:
|
||||
return f"{new_name} not implemented", 403
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.route("/inputimage")
|
||||
def set_input_photo():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
|
||||
device = torch.device(args.device)
|
||||
input_image_path = args.input
|
||||
|
||||
model = ModelManager(name=args.model, device=device)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
return FlaskUI(app, width=app_width, height=app_height, host=args.host, port=args.port)
|
||||
else:
|
||||
app.run(host=args.host, port=args.port, debug=args.debug)
|
221
main.py
221
main.py
@ -1,221 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import imghdr
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
if os.environ.get("CACHE_DIR"):
|
||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
logger_opt = logger.opt(depth=6, exception=record.exc_info)
|
||||
logger_opt.log(record.levelno, record.getMessage())
|
||||
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
app.logger.addHandler(InterceptHandler())
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
|
||||
model: ModelManager = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
w = imghdr.what("", img_bytes)
|
||||
if w is None:
|
||||
w = "jpeg"
|
||||
return w
|
||||
|
||||
|
||||
@app.route("/inpaint", methods=["POST"])
|
||||
def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
form = request.form
|
||||
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
|
||||
if size_limit == "Original":
|
||||
size_limit = max(image.shape)
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
config = Config(
|
||||
ldm_steps=form['ldmSteps'],
|
||||
hd_strategy=form['hdStrategy'],
|
||||
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
|
||||
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
|
||||
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
|
||||
)
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
|
||||
start = time.time()
|
||||
res_np_img = model(image, mask, config)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
||||
)
|
||||
res_np_img = np.concatenate(
|
||||
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
return send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
mimetype=f"image/{ext}",
|
||||
)
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
|
||||
|
||||
@app.route("/model_downloaded/<name>")
|
||||
def model_downloaded(name):
|
||||
return str(model.is_downloaded(name)), 200
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
except NotImplementedError:
|
||||
return f"{new_name} not implemented", 403
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.route("/inputimage")
|
||||
def set_input_photo():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="Path to image you want to load by default"
|
||||
)
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
"--gui-size",
|
||||
default=[1600, 1000],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="Set window size for GUI",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
parser.error(f"invalid --input: {args.input} not exists")
|
||||
if imghdr.what(args.input) is None:
|
||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
|
||||
args = get_args_parser()
|
||||
device = torch.device(args.device)
|
||||
input_image_path = args.input
|
||||
|
||||
model = ModelManager(name=args.model, device=device)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
ui = FlaskUI(app, width=app_width, height=app_height)
|
||||
ui.run()
|
||||
else:
|
||||
app.run(host=args.host, port=args.port, debug=args.debug)
|
||||
|
||||
from lama_cleaner import entry_point
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
entry_point()
|
||||
|
10
publish.sh
Normal file
10
publish.sh
Normal file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
pushd ./lama_cleaner/app
|
||||
yarn run build
|
||||
popd
|
||||
|
||||
rm -r -f dist
|
||||
python3 setup.py sdist bdist_wheel
|
||||
#twine upload dist/*
|
2
requirements-dev.txt
Normal file
2
requirements-dev.txt
Normal file
@ -0,0 +1,2 @@
|
||||
wheel
|
||||
twine
|
@ -1,7 +1,7 @@
|
||||
torch>=1.8.2
|
||||
opencv-python
|
||||
flask_cors
|
||||
flask==2.1.1
|
||||
flask>=2.1.1
|
||||
flaskwebgui
|
||||
tqdm
|
||||
pydantic
|
||||
|
42
setup.py
42
setup.py
@ -1 +1,41 @@
|
||||
# TODO: make this a python package
|
||||
import setuptools
|
||||
from pathlib import Path
|
||||
|
||||
web_files = Path("lama_cleaner/app/build/").glob("**/*")
|
||||
web_files = [str(it).replace("lama_cleaner/", "") for it in web_files]
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
def load_requirements():
|
||||
requirements_file_name = "requirements.txt"
|
||||
requires = []
|
||||
with open(requirements_file_name) as f:
|
||||
for line in f:
|
||||
if line:
|
||||
requires.append(line.strip())
|
||||
return requires
|
||||
|
||||
|
||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||
setuptools.setup(
|
||||
name="lama-cleaner",
|
||||
version="0.9.0",
|
||||
author="PanicByte",
|
||||
author_email="cwq1913@gmail.com",
|
||||
description="Image inpainting tool powered by SOTA AI Model",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/Sanster/lama-cleaner",
|
||||
packages=setuptools.find_packages("./"),
|
||||
package_data={"lama_cleaner": web_files},
|
||||
install_requires=load_requirements(),
|
||||
python_requires=">=3.6",
|
||||
entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user