This commit is contained in:
Qing 2023-02-06 22:00:47 +08:00
parent 24bff09534
commit 3f6bc8fada
9 changed files with 307 additions and 91 deletions

View File

@ -2,7 +2,7 @@
"name": "lama-cleaner", "name": "lama-cleaner",
"version": "0.1.0", "version": "0.1.0",
"private": true, "private": true,
"proxy": "http://localhost:8080", "proxy": "http://127.0.0.1:8080",
"dependencies": { "dependencies": {
"@babel/core": "^7.16.0", "@babel/core": "^7.16.0",
"@heroicons/react": "^2.0.0", "@heroicons/react": "^2.0.0",

View File

@ -260,7 +260,9 @@ export default function Editor() {
if ( if (
(maskImage === undefined || maskImage === null) && (maskImage === undefined || maskImage === null) &&
_lineGroups.length === 0 _lineGroups.length === 1 &&
_lineGroups[0].length === 0 &&
isPix2Pix
) { ) {
// For InstructPix2Pix without mask // For InstructPix2Pix without mask
drawLines( drawLines(
@ -270,7 +272,9 @@ export default function Editor() {
size: 9999999999, size: 9999999999,
pts: [ pts: [
{ x: 0, y: 0 }, { x: 0, y: 0 },
{ x: 99999999, y: 99999999 }, { x: original.naturalWidth, y: 0 },
{ x: original.naturalWidth, y: original.naturalHeight },
{ x: 0, y: original.naturalHeight },
], ],
}, },
], ],
@ -278,7 +282,7 @@ export default function Editor() {
) )
} }
}, },
[context, maskCanvas] [context, maskCanvas, isPix2Pix]
) )
const hadDrawSomething = useCallback(() => { const hadDrawSomething = useCallback(() => {

View File

@ -1,5 +1,12 @@
import os import os
MPS_SUPPORT_MODELS = [
"instruct_pix2pix",
"sd1.5",
"sd2",
"paint_by_example"
]
DEFAULT_MODEL = "lama" DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = [ AVAILABLE_MODELS = [
"lama", "lama",

View File

@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
model_path = download_model(url_or_path) model_path = download_model(url_or_path)
try: try:
state_dict = torch.load(model_path, map_location='cpu') state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
model.to(device) model.to(device)
logger.info(f"Load model from: {model_path}") logger.info(f"Load model from: {model_path}")
@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
return image_bytes return image_bytes
def load_img(img_bytes, gray: bool = False): def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
with io.BytesIO() as output:
pil_img.save(output, format=ext, exif=exif)
image_bytes = output.getvalue()
return image_bytes
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
alpha_channel = None alpha_channel = None
image = Image.open(io.BytesIO(img_bytes)) image = Image.open(io.BytesIO(img_bytes))
try:
if return_exif:
exif = image.getexif()
except:
exif = None
logger.error("Failed to extract exif from image")
try: try:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
except: except:
pass pass
if gray: if gray:
image = image.convert('L') image = image.convert("L")
np_img = np.array(image) np_img = np.array(image)
else: else:
if image.mode == 'RGBA': if image.mode == "RGBA":
np_img = np.array(image) np_img = np.array(image)
alpha_channel = np_img[:, :, -1] alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else: else:
image = image.convert('RGB') image = image.convert("RGB")
np_img = np.array(image) np_img = np.array(image)
if return_exif:
return np_img, alpha_channel, exif
return np_img, alpha_channel return np_img, alpha_channel

View File

@ -5,9 +5,25 @@ from pathlib import Path
from loguru import logger from loguru import logger
from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, DISABLE_NSFW_HELP, \ from lama_cleaner.const import (
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \ AVAILABLE_MODELS,
OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR NO_HALF_HELP,
CPU_OFFLOAD_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
AVAILABLE_DEVICES,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR,
DEFAULT_MODEL,
MPS_SUPPORT_MODELS,
)
from lama_cleaner.runtime import dump_environment_info from lama_cleaner.runtime import dump_environment_info
@ -16,22 +32,41 @@ def parse_args():
parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--config-installer", action="store_true", parser.add_argument(
help="Open config web page, mainly for windows installer") "--config-installer",
parser.add_argument("--load-installer-config", action="store_true", action="store_true",
help="Load all cmd args from installer config file") help="Open config web page, mainly for windows installer",
parser.add_argument("--installer-config", default=None, help="Config file for windows installer") )
parser.add_argument(
"--load-installer-config",
action="store_true",
help="Load all cmd args from installer config file",
)
parser.add_argument(
"--installer-config", default=None, help="Config file for windows installer"
)
parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS) parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS)
parser.add_argument("--cuda-visible-device", default="")
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP) parser.add_argument(
parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP) "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP) )
parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES) parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
)
parser.add_argument(
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
)
parser.add_argument(
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
)
parser.add_argument("--gui", action="store_true", help=GUI_HELP) parser.add_argument("--gui", action="store_true", help=GUI_HELP)
parser.add_argument("--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP) parser.add_argument(
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
)
parser.add_argument( parser.add_argument(
"--gui-size", "--gui-size",
default=[1600, 1000], default=[1600, 1000],
@ -41,8 +76,14 @@ def parse_args():
) )
parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) parser.add_argument("--input", type=str, default=None, help=INPUT_HELP)
parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP)
parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP) parser.add_argument(
parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend") "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP
)
parser.add_argument(
"--disable-model-switch",
action="store_true",
help="Disable model switch in frontend",
)
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
# useless args # useless args
@ -64,7 +105,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--sd-enable-xformers", "--sd-enable-xformers",
action="store_true", action="store_true",
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers",
) )
args = parser.parse_args() args = parser.parse_args()
@ -74,14 +115,18 @@ def parse_args():
if args.config_installer: if args.config_installer:
if args.installer_config is None: if args.installer_config is None:
parser.error(f"args.config_installer==True, must set args.installer_config to store config file") parser.error(
f"args.config_installer==True, must set args.installer_config to store config file"
)
from lama_cleaner.web_config import main from lama_cleaner.web_config import main
logger.info(f"Launching installer web config page") logger.info(f"Launching installer web config page")
main(args.installer_config) main(args.installer_config)
exit() exit()
if args.load_installer_config: if args.load_installer_config:
from lama_cleaner.web_config import load_config from lama_cleaner.web_config import load_config
if args.installer_config and not os.path.exists(args.installer_config): if args.installer_config and not os.path.exists(args.installer_config):
parser.error(f"args.installer_config={args.installer_config} not exists") parser.error(f"args.installer_config={args.installer_config} not exists")
@ -93,9 +138,25 @@ def parse_args():
if args.device == "cuda": if args.device == "cuda":
import torch import torch
if torch.cuda.is_available() is False: if torch.cuda.is_available() is False:
parser.error( parser.error(
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
)
if args.cuda_visible_device:
try:
int(args.cuda_visible_device)
except:
parser.error(
f"invalid --cuda-visible-device: {args.cuda_visible_device}, must be int"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_device
if args.device == "mps":
if args.model not in MPS_SUPPORT_MODELS:
parser.error(
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
)
if args.model_dir and args.model_dir is not None: if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
@ -115,7 +176,9 @@ def parse_args():
parser.error(f"invalid --input: {args.input} is not a valid image file") parser.error(f"invalid --input: {args.input} is not a valid image file")
else: else:
if args.output_dir is None: if args.output_dir is None:
parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required") parser.error(
f"invalid --input: {args.input} is a directory, --output-dir is required"
)
else: else:
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
if not output_dir.exists(): if not output_dir.exists():
@ -123,6 +186,8 @@ def parse_args():
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
else: else:
if not output_dir.is_dir(): if not output_dir.is_dir():
parser.error(f"invalid --output-dir: {output_dir} is not a directory") parser.error(
f"invalid --output-dir: {output_dir} is not a directory"
)
return args return args

View File

@ -32,7 +32,15 @@ try:
except: except:
pass pass
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify from flask import (
Flask,
request,
send_file,
cli,
make_response,
send_from_directory,
jsonify,
)
# Disable ability for Flask to display warning about using a development server in a production environment. # Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 # https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@ -43,6 +51,7 @@ from lama_cleaner.helper import (
load_img, load_img,
numpy_to_bytes, numpy_to_bytes,
resize_max_size, resize_max_size,
pil_to_bytes,
) )
NUM_THREADS = str(multiprocessing.cpu_count()) NUM_THREADS = str(multiprocessing.cpu_count())
@ -103,14 +112,13 @@ def make_gif():
origin_image, _ = load_img(origin_image_bytes) origin_image, _ = load_img(origin_image_bytes)
clean_image, _ = load_img(clean_image_bytes) clean_image, _ = load_img(clean_image_bytes)
gif_bytes = make_compare_gif( gif_bytes = make_compare_gif(
Image.fromarray(origin_image), Image.fromarray(origin_image), Image.fromarray(clean_image)
Image.fromarray(clean_image)
) )
return send_file( return send_file(
io.BytesIO(gif_bytes), io.BytesIO(gif_bytes),
mimetype='image/gif', mimetype="image/gif",
as_attachment=True, as_attachment=True,
attachment_filename=filename attachment_filename=filename,
) )
@ -121,12 +129,12 @@ def save_image():
origin_image_bytes = input["image"].read() # RGB origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes) image, _ = load_img(origin_image_bytes)
thumb.save_to_output_directory(image, request.form["filename"]) thumb.save_to_output_directory(image, request.form["filename"])
return 'ok', 200 return "ok", 200
@app.route("/medias/<tab>") @app.route("/medias/<tab>")
def medias(tab): def medias(tab):
if tab == 'image': if tab == "image":
response = make_response(jsonify(thumb.media_names), 200) response = make_response(jsonify(thumb.media_names), 200)
else: else:
response = make_response(jsonify(thumb.output_media_names), 200) response = make_response(jsonify(thumb.output_media_names), 200)
@ -137,18 +145,18 @@ def medias(tab):
return response return response
@app.route('/media/<tab>/<filename>') @app.route("/media/<tab>/<filename>")
def media_file(tab, filename): def media_file(tab, filename):
if tab == 'image': if tab == "image":
return send_from_directory(thumb.root_directory, filename) return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename) return send_from_directory(thumb.output_dir, filename)
@app.route('/media_thumbnail/<tab>/<filename>') @app.route("/media_thumbnail/<tab>/<filename>")
def media_thumbnail_file(tab, filename): def media_thumbnail_file(tab, filename):
args = request.args args = request.args
width = args.get('width') width = args.get("width")
height = args.get('height') height = args.get("height")
if width is None and height is None: if width is None and height is None:
width = 256 width = 256
if width: if width:
@ -157,9 +165,11 @@ def media_thumbnail_file(tab, filename):
height = int(float(height)) height = int(float(height))
directory = thumb.root_directory directory = thumb.root_directory
if tab == 'output': if tab == "output":
directory = thumb.output_dir directory = thumb.output_dir
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height) thumb_filename, (width, height) = thumb.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath)) response = make_response(send_file(thumb_filepath))
@ -173,13 +183,16 @@ def process():
input = request.files input = request.files
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes) image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]: if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400 return (
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
400,
)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -192,7 +205,9 @@ def process():
size_limit = int(size_limit) size_limit = int(size_limit)
if "paintByExampleImage" in input: if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read()) paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else: else:
paint_by_example_example_image = None paint_by_example_example_image = None
@ -221,7 +236,7 @@ def process():
sd_seed=form["sdSeed"], sd_seed=form["sdSeed"],
sd_match_histograms=form["sdMatchHistograms"], sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"], cv2_flag=form["cv2Flag"],
cv2_radius=form['cv2Radius'], cv2_radius=form["cv2Radius"],
paint_by_example_steps=form["paintByExampleSteps"], paint_by_example_steps=form["paintByExampleSteps"],
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"], paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
paint_by_example_mask_blur=form["paintByExampleMaskBlur"], paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
@ -259,6 +274,7 @@ def process():
logger.info(f"process time: {(time.time() - start) * 1000}ms") logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache() torch.cuda.empty_cache()
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
if alpha_channel is not None: if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]: if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize( alpha_channel = cv2.resize(
@ -270,9 +286,15 @@ def process():
ext = get_image_ext(origin_image_bytes) ext = get_image_ext(origin_image_bytes)
if exif is not None:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
else:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
response = make_response( response = make_response(
send_file( send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)), # io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
mimetype=f"image/{ext}", mimetype=f"image/{ext}",
) )
) )
@ -285,7 +307,7 @@ def interactive_seg():
input = request.files input = request.files
origin_image_bytes = input["image"].read() # RGB origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes) image, _ = load_img(origin_image_bytes)
if 'mask' in input: if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
else: else:
mask = None mask = None
@ -293,14 +315,16 @@ def interactive_seg():
_clicks = json.loads(request.form["clicks"]) _clicks = json.loads(request.form["clicks"])
clicks = [] clicks = []
for i, click in enumerate(_clicks): for i, click in enumerate(_clicks):
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)) clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
start = time.time() start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask) new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms") logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
response = make_response( response = make_response(
send_file( send_file(
io.BytesIO(numpy_to_bytes(new_mask, 'png')), io.BytesIO(numpy_to_bytes(new_mask, "png")),
mimetype=f"image/png", mimetype=f"image/png",
) )
) )
@ -314,13 +338,13 @@ def current_model():
@app.route("/is_disable_model_switch") @app.route("/is_disable_model_switch")
def get_is_disable_model_switch(): def get_is_disable_model_switch():
res = 'true' if is_disable_model_switch else 'false' res = "true" if is_disable_model_switch else "false"
return res, 200 return res, 200
@app.route("/is_enable_file_manager") @app.route("/is_enable_file_manager")
def get_is_enable_file_manager(): def get_is_enable_file_manager():
res = 'true' if is_enable_file_manager else 'false' res = "true" if is_enable_file_manager else "false"
return res, 200 return res, 200
@ -389,14 +413,18 @@ def main(args):
is_disable_model_switch = args.disable_model_switch is_disable_model_switch = args.disable_model_switch
is_desktop = args.gui is_desktop = args.gui
if is_disable_model_switch: if is_disable_model_switch:
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable") logger.info(
f"Start with --disable-model-switch, model switch on frontend is disable"
)
if args.input and os.path.isdir(args.input): if args.input and os.path.isdir(args.input):
logger.info(f"Initialize file manager") logger.info(f"Initialize file manager")
thumb = FileManager(app) thumb = FileManager(app)
is_enable_file_manager = True is_enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails') app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
args.output_dir, "lama_cleaner_thumbnails"
)
thumb.output_dir = Path(args.output_dir) thumb.output_dir = Path(args.output_dir)
# thumb.start() # thumb.start()
# try: # try:
@ -432,8 +460,12 @@ def main(args):
from flaskwebgui import FlaskUI from flaskwebgui import FlaskUI
ui = FlaskUI( ui = FlaskUI(
app, width=app_width, height=app_height, host=args.host, port=args.port, app,
close_server_on_exit=not args.no_gui_auto_close width=app_width,
height=app_height,
host=args.host,
port=args.port,
close_server_on_exit=not args.no_gui_auto_close,
) )
ui.run() ui.run()
else: else:

View File

@ -0,0 +1,36 @@
import io
from PIL import Image
from lama_cleaner.helper import pil_to_bytes
def print_exif(exif):
for k, v in exif.items():
print(f"{k}: {v}")
def test_png():
img = Image.open("image.png")
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="png", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)
def test_jpeg():
img = Image.open("bunny.jpeg")
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)

View File

@ -6,9 +6,24 @@ import gradio as gr
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \ from lama_cleaner.const import (
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \ AVAILABLE_MODELS,
GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR AVAILABLE_DEVICES,
CPU_OFFLOAD_HELP,
NO_HALF_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_MODEL,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS,
)
_config_file = None _config_file = None
@ -18,6 +33,7 @@ class Config(BaseModel):
port: int = 8080 port: int = 8080
model: str = DEFAULT_MODEL model: str = DEFAULT_MODEL
device: str = DEFAULT_DEVICE device: str = DEFAULT_DEVICE
cuda_visible_device: str = ""
gui: bool = False gui: bool = False
no_gui_auto_close: bool = False no_gui_auto_close: bool = False
no_half: bool = False no_half: bool = False
@ -33,16 +49,29 @@ class Config(BaseModel):
def load_config(installer_config: str): def load_config(installer_config: str):
if os.path.exists(installer_config): if os.path.exists(installer_config):
with open(installer_config, "r", encoding='utf-8') as f: with open(installer_config, "r", encoding="utf-8") as f:
return Config(**json.load(f)) return Config(**json.load(f))
else: else:
return Config() return Config()
def save_config( def save_config(
host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload, host,
disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only, port,
model_dir, input, output_dir model,
device,
cuda_visible_device,
gui,
no_gui_auto_close,
no_half,
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
): ):
config = Config(**locals()) config = Config(**locals())
print(config) print(config)
@ -63,6 +92,7 @@ def save_config(
def close_server(*args): def close_server(*args):
# TODO: make close both browser and server works # TODO: make close both browser and server works
import os, signal import os, signal
pid = os.getpid() pid = os.getpid()
os.kill(pid, signal.SIGUSR1) os.kill(pid, signal.SIGUSR1)
@ -86,23 +116,45 @@ def main(config_file: str):
port = gr.Number(init_config.port, label="Port", precision=0) port = gr.Number(init_config.port, label="Port", precision=0)
with gr.Row(): with gr.Row():
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
device = gr.Radio(AVAILABLE_DEVICES, label="Device", value=init_config.device) device = gr.Radio(
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
)
cuda_visible_device = gr.Textbox(
"", label="CUDA visible device. (0/1/2...)"
)
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}") no_gui_auto_close = gr.Checkbox(
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}") cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}") disable_nsfw = gr.Checkbox(
sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}") init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}") )
local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}") sd_cpu_textencoder = gr.Checkbox(
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
)
enable_xformers = gr.Checkbox(
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
)
local_files_only = gr.Checkbox(
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
)
model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}") model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}") input = gr.Textbox(
output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}") init_config.input, label=f"Input file or directory. {INPUT_HELP}"
save_btn.click(save_config, [ )
output_dir = gr.Textbox(
init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}"
)
save_btn.click(
save_config,
[
host, host,
port, port,
model, model,
device, device,
cuda_visible_device,
gui, gui,
no_gui_auto_close, no_gui_auto_close,
no_half, no_half,
@ -114,5 +166,7 @@ def main(config_file: str):
model_dir, model_dir,
input, input,
output_dir, output_dir,
], message) ],
message,
)
demo.launch(inbrowser=True, show_api=False) demo.launch(inbrowser=True, show_api=False)

View File

@ -16,3 +16,4 @@ diffusers[torch]==0.12.1
transformers>=4.25.1 transformers>=4.25.1
watchdog==2.2.1 watchdog==2.2.1
gradio gradio
piexif==1.1.3