From 3f6bc8fada994697c1bfd6f0eb547b2c56816c20 Mon Sep 17 00:00:00 2001
From: Qing
Date: Mon, 6 Feb 2023 22:00:47 +0800
Subject: [PATCH] Keep EXIF metadata on save; add --cuda-visible-device and
 MPS model check

- Preserve EXIF metadata when re-encoding processed images (new
  pil_to_bytes helper, covered by tests/test_save_exif.py)
- For InstructPix2Pix without a mask, draw the full-image mask at the
  image's natural size instead of a huge fixed rectangle, and only when
  the InstructPix2Pix model is active
- Add a --cuda-visible-device flag and reject --device mps for models
  outside MPS_SUPPORT_MODELS
- Normalize quoting and argument wrapping across the Python sources
---
 lama_cleaner/app/package.json                 |   2 +-
 .../app/src/components/Editor/Editor.tsx      |  10 +-
 lama_cleaner/const.py                         |   7 ++
 lama_cleaner/helper.py                        |  27 +++-
 lama_cleaner/parse_args.py                    | 107 ++++++++++++----
 lama_cleaner/server.py                        |  88 ++++++++-----
 lama_cleaner/tests/test_save_exif.py          |  36 ++++++
 lama_cleaner/web_config.py                    | 118 +++++++++++++-----
 requirements.txt                              |   3 +-
 9 files changed, 307 insertions(+), 91 deletions(-)
 create mode 100644 lama_cleaner/tests/test_save_exif.py

diff --git a/lama_cleaner/app/package.json b/lama_cleaner/app/package.json
index 321aa7c..c456638 100644
--- a/lama_cleaner/app/package.json
+++ b/lama_cleaner/app/package.json
@@ -2,7 +2,7 @@
   "name": "lama-cleaner",
   "version": "0.1.0",
   "private": true,
-  "proxy": "http://localhost:8080",
+  "proxy": "http://127.0.0.1:8080",
   "dependencies": {
     "@babel/core": "^7.16.0",
     "@heroicons/react": "^2.0.0",
diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx
index bcf2042..724d3f1 100644
--- a/lama_cleaner/app/src/components/Editor/Editor.tsx
+++ b/lama_cleaner/app/src/components/Editor/Editor.tsx
@@ -260,7 +260,9 @@ export default function Editor() {
 
       if (
         (maskImage === undefined || maskImage === null) &&
-        _lineGroups.length === 0
+        _lineGroups.length === 1 &&
+        _lineGroups[0].length === 0 &&
+        isPix2Pix
       ) {
         // For InstructPix2Pix without mask
         drawLines(
@@ -270,7 +272,9 @@ export default function Editor() {
               size: 9999999999,
               pts: [
                 { x: 0, y: 0 },
-                { x: 99999999, y: 99999999 },
+                { x: original.naturalWidth, y: 0 },
+                { x: original.naturalWidth, y: original.naturalHeight },
+                { x: 0, y: original.naturalHeight },
               ],
             },
           ],
@@ -278,7 +282,7 @@ export default function Editor() {
         )
       }
     },
-    [context, maskCanvas]
+    [context, maskCanvas, isPix2Pix]
   )
 
   const hadDrawSomething = useCallback(() => {
diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py
index e0444e9..32370d9 100644
--- a/lama_cleaner/const.py
+++ b/lama_cleaner/const.py
@@ -1,5 +1,12 @@
 import os
 
+MPS_SUPPORT_MODELS = [
+    "instruct_pix2pix",
+    "sd1.5",
+    "sd2",
+    "paint_by_example"
+]
+
 DEFAULT_MODEL = "lama"
 AVAILABLE_MODELS = [
     "lama",
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index a6b67b8..74a75ba 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
     model_path = download_model(url_or_path)
     try:
-        state_dict = torch.load(model_path, map_location='cpu')
+        state_dict = torch.load(model_path, map_location="cpu")
         model.load_state_dict(state_dict, strict=True)
         model.to(device)
         logger.info(f"Load model from: {model_path}")
@@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
     return image_bytes
 
 
-def load_img(img_bytes, gray: bool = False):
+def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
+    with io.BytesIO() as output:
+        pil_img.save(output, format=ext, exif=exif)
+        image_bytes = output.getvalue()
+    return image_bytes
+
+
+def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
     alpha_channel = None
     image = Image.open(io.BytesIO(img_bytes))
+
+    exif = None
+    if return_exif:
+        try:
+            exif = image.getexif()
+        except Exception:
+            logger.error("Failed to extract exif from image")
+
     try:
         image = ImageOps.exif_transpose(image)
     except:
         pass
 
     if gray:
-        image = image.convert('L')
+        image = image.convert("L")
         np_img = 
np.array(image) else: - if image.mode == 'RGBA': + if image.mode == "RGBA": np_img = np.array(image) alpha_channel = np_img[:, :, -1] np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) else: - image = image.convert('RGB') + image = image.convert("RGB") np_img = np.array(image) + if return_exif: + return np_img, alpha_channel, exif return np_img, alpha_channel diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index dc0bb9b..db84908 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -5,9 +5,25 @@ from pathlib import Path from loguru import logger -from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, DISABLE_NSFW_HELP, \ - SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \ - OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR +from lama_cleaner.const import ( + AVAILABLE_MODELS, + NO_HALF_HELP, + CPU_OFFLOAD_HELP, + DISABLE_NSFW_HELP, + SD_CPU_TEXTENCODER_HELP, + LOCAL_FILES_ONLY_HELP, + AVAILABLE_DEVICES, + ENABLE_XFORMERS_HELP, + MODEL_DIR_HELP, + OUTPUT_DIR_HELP, + INPUT_HELP, + GUI_HELP, + DEFAULT_DEVICE, + NO_GUI_AUTO_CLOSE_HELP, + DEFAULT_MODEL_DIR, + DEFAULT_MODEL, + MPS_SUPPORT_MODELS, +) from lama_cleaner.runtime import dump_environment_info @@ -16,22 +32,41 @@ def parse_args(): parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", default=8080, type=int) - parser.add_argument("--config-installer", action="store_true", - help="Open config web page, mainly for windows installer") - parser.add_argument("--load-installer-config", action="store_true", - help="Load all cmd args from installer config file") - parser.add_argument("--installer-config", default=None, help="Config file for windows installer") + parser.add_argument( + "--config-installer", + action="store_true", + help="Open config web page, mainly for windows installer", + ) + parser.add_argument( + "--load-installer-config", + action="store_true", + help="Load all cmd args from installer config file", + ) + parser.add_argument( + "--installer-config", default=None, help="Config file for windows installer" + ) - parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS) + parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS) + parser.add_argument("--cuda-visible-device", default="") parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) - parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP) - parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP) - parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP) - parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES) + parser.add_argument( + "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP + ) + parser.add_argument( + "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP + ) + parser.add_argument( + "--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP + ) + parser.add_argument( + "--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES + ) parser.add_argument("--gui", action="store_true", help=GUI_HELP) - parser.add_argument("--no-gui-auto-close", 
action="store_true", help=NO_GUI_AUTO_CLOSE_HELP) + parser.add_argument( + "--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP + ) parser.add_argument( "--gui-size", default=[1600, 1000], @@ -41,8 +76,14 @@ def parse_args(): ) parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) - parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP) - parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend") + parser.add_argument( + "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP + ) + parser.add_argument( + "--disable-model-switch", + action="store_true", + help="Disable model switch in frontend", + ) parser.add_argument("--debug", action="store_true") # useless args @@ -64,7 +105,7 @@ def parse_args(): parser.add_argument( "--sd-enable-xformers", action="store_true", - help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" + help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers", ) args = parser.parse_args() @@ -74,14 +115,18 @@ def parse_args(): if args.config_installer: if args.installer_config is None: - parser.error(f"args.config_installer==True, must set args.installer_config to store config file") + parser.error( + f"args.config_installer==True, must set args.installer_config to store config file" + ) from lama_cleaner.web_config import main + logger.info(f"Launching installer web config page") main(args.installer_config) exit() if args.load_installer_config: from lama_cleaner.web_config import load_config + if args.installer_config and not os.path.exists(args.installer_config): parser.error(f"args.installer_config={args.installer_config} not exists") @@ -93,9 +138,25 @@ def parse_args(): if args.device == "cuda": import torch + if torch.cuda.is_available() is False: parser.error( - "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") + "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation" + ) + if args.cuda_visible_device: + try: + int(args.cuda_visible_device) + except: + parser.error( + f"invalid --cuda-visible-device: {args.cuda_visible_device}, must be int" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_device + + if args.device == "mps": + if args.model not in MPS_SUPPORT_MODELS: + parser.error( + f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}" + ) if args.model_dir and args.model_dir is not None: if os.path.isfile(args.model_dir): @@ -115,7 +176,9 @@ def parse_args(): parser.error(f"invalid --input: {args.input} is not a valid image file") else: if args.output_dir is None: - parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required") + parser.error( + f"invalid --input: {args.input} is a directory, --output-dir is required" + ) else: output_dir = Path(args.output_dir) if not output_dir.exists(): @@ -123,6 +186,8 @@ def parse_args(): output_dir.mkdir(parents=True) else: if not output_dir.is_dir(): - parser.error(f"invalid --output-dir: {output_dir} is not a directory") + parser.error( + f"invalid --output-dir: {output_dir} is not a directory" + ) return args diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index a55e1df..6ff0806 100644 --- 
a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -32,7 +32,15 @@ try:
 except:
     pass
 
-from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
+from flask import (
+    Flask,
+    request,
+    send_file,
+    cli,
+    make_response,
+    send_from_directory,
+    jsonify,
+)
 
 # Disable ability for Flask to display warning about using a development server in a production environment.
 # https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@@ -43,6 +51,7 @@ from lama_cleaner.helper import (
     load_img,
     numpy_to_bytes,
     resize_max_size,
+    pil_to_bytes,
 )
 
 NUM_THREADS = str(multiprocessing.cpu_count())
@@ -103,14 +112,13 @@ def make_gif():
     origin_image, _ = load_img(origin_image_bytes)
     clean_image, _ = load_img(clean_image_bytes)
     gif_bytes = make_compare_gif(
-        Image.fromarray(origin_image),
-        Image.fromarray(clean_image)
+        Image.fromarray(origin_image), Image.fromarray(clean_image)
     )
     return send_file(
         io.BytesIO(gif_bytes),
-        mimetype='image/gif',
+        mimetype="image/gif",
         as_attachment=True,
-        attachment_filename=filename
+        attachment_filename=filename,
     )
@@ -121,12 +129,12 @@ def save_image():
     origin_image_bytes = input["image"].read()  # RGB
     image, _ = load_img(origin_image_bytes)
     thumb.save_to_output_directory(image, request.form["filename"])
-    return 'ok', 200
+    return "ok", 200
 
 
 @app.route("/medias/<tab>")
 def medias(tab):
-    if tab == 'image':
+    if tab == "image":
         response = make_response(jsonify(thumb.media_names), 200)
     else:
         response = make_response(jsonify(thumb.output_media_names), 200)
@@ -137,18 +145,18 @@ def medias(tab):
     return response
 
 
-@app.route('/media/<tab>/<filename>')
+@app.route("/media/<tab>/<filename>")
 def media_file(tab, filename):
-    if tab == 'image':
+    if tab == "image":
         return send_from_directory(thumb.root_directory, filename)
     return send_from_directory(thumb.output_dir, filename)
 
 
-@app.route('/media_thumbnail/<tab>/<filename>')
+@app.route("/media_thumbnail/<tab>/<filename>")
 def media_thumbnail_file(tab, filename):
     args = request.args
-    width = args.get('width')
-    height = args.get('height')
+    width = args.get("width")
+    height = args.get("height")
     if width is None and height is None:
         width = 256
     if width:
@@ -157,9 +165,11 @@ def media_thumbnail_file(tab, filename):
         height = int(float(height))
 
     directory = thumb.root_directory
-    if tab == 'output':
+    if tab == "output":
         directory = thumb.output_dir
-    thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
+    thumb_filename, (width, height) = thumb.get_thumbnail(
+        directory, filename, width, height
+    )
     thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
 
     response = make_response(send_file(thumb_filepath))
@@ -173,13 +183,16 @@ def process():
     input = request.files
     # RGB
     origin_image_bytes = input["image"].read()
-    image, alpha_channel = load_img(origin_image_bytes)
+    image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
 
     mask, _ = load_img(input["mask"].read(), gray=True)
     mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
 
     if image.shape[:2] != mask.shape[:2]:
-        return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
+        return (
+            f"Mask shape {mask.shape[:2]} is not equal to Image shape {image.shape[:2]}",
+            400,
+        )
 
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC
@@ -192,7 +205,9 @@ def process():
         size_limit = int(size_limit)
 
     if "paintByExampleImage" in input:
-        paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
+        paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read() + ) paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) else: paint_by_example_example_image = None @@ -221,7 +236,7 @@ def process(): sd_seed=form["sdSeed"], sd_match_histograms=form["sdMatchHistograms"], cv2_flag=form["cv2Flag"], - cv2_radius=form['cv2Radius'], + cv2_radius=form["cv2Radius"], paint_by_example_steps=form["paintByExampleSteps"], paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"], paint_by_example_mask_blur=form["paintByExampleMaskBlur"], @@ -259,6 +274,7 @@ def process(): logger.info(f"process time: {(time.time() - start) * 1000}ms") torch.cuda.empty_cache() + res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) if alpha_channel is not None: if alpha_channel.shape[:2] != res_np_img.shape[:2]: alpha_channel = cv2.resize( @@ -270,9 +286,15 @@ def process(): ext = get_image_ext(origin_image_bytes) + if exif is not None: + bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif)) + else: + bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext)) + response = make_response( send_file( - io.BytesIO(numpy_to_bytes(res_np_img, ext)), + # io.BytesIO(numpy_to_bytes(res_np_img, ext)), + bytes_io, mimetype=f"image/{ext}", ) ) @@ -285,7 +307,7 @@ def interactive_seg(): input = request.files origin_image_bytes = input["image"].read() # RGB image, _ = load_img(origin_image_bytes) - if 'mask' in input: + if "mask" in input: mask, _ = load_img(input["mask"].read(), gray=True) else: mask = None @@ -293,14 +315,16 @@ def interactive_seg(): _clicks = json.loads(request.form["clicks"]) clicks = [] for i, click in enumerate(_clicks): - clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)) + clicks.append( + Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1) + ) start = time.time() new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask) logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms") response = make_response( send_file( - io.BytesIO(numpy_to_bytes(new_mask, 'png')), + io.BytesIO(numpy_to_bytes(new_mask, "png")), mimetype=f"image/png", ) ) @@ -314,13 +338,13 @@ def current_model(): @app.route("/is_disable_model_switch") def get_is_disable_model_switch(): - res = 'true' if is_disable_model_switch else 'false' + res = "true" if is_disable_model_switch else "false" return res, 200 @app.route("/is_enable_file_manager") def get_is_enable_file_manager(): - res = 'true' if is_enable_file_manager else 'false' + res = "true" if is_enable_file_manager else "false" return res, 200 @@ -389,14 +413,18 @@ def main(args): is_disable_model_switch = args.disable_model_switch is_desktop = args.gui if is_disable_model_switch: - logger.info(f"Start with --disable-model-switch, model switch on frontend is disable") + logger.info( + f"Start with --disable-model-switch, model switch on frontend is disable" + ) if args.input and os.path.isdir(args.input): logger.info(f"Initialize file manager") thumb = FileManager(app) is_enable_file_manager = True app.config["THUMBNAIL_MEDIA_ROOT"] = args.input - app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails') + app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( + args.output_dir, "lama_cleaner_thumbnails" + ) thumb.output_dir = Path(args.output_dir) # thumb.start() # try: @@ -432,8 +460,12 @@ def main(args): from flaskwebgui import FlaskUI ui = FlaskUI( - app, width=app_width, 
height=app_height, host=args.host, port=args.port, - close_server_on_exit=not args.no_gui_auto_close + app, + width=app_width, + height=app_height, + host=args.host, + port=args.port, + close_server_on_exit=not args.no_gui_auto_close, ) ui.run() else: diff --git a/lama_cleaner/tests/test_save_exif.py b/lama_cleaner/tests/test_save_exif.py new file mode 100644 index 0000000..8b26363 --- /dev/null +++ b/lama_cleaner/tests/test_save_exif.py @@ -0,0 +1,36 @@ +import io + +from PIL import Image + +from lama_cleaner.helper import pil_to_bytes + + +def print_exif(exif): + for k, v in exif.items(): + print(f"{k}: {v}") + + +def test_png(): + img = Image.open("image.png") + exif = img.getexif() + print_exif(exif) + + pil_bytes = pil_to_bytes(img, ext="png", exif=exif) + + res_img = Image.open(io.BytesIO(pil_bytes)) + res_exif = res_img.getexif() + + assert dict(exif) == dict(res_exif) + + +def test_jpeg(): + img = Image.open("bunny.jpeg") + exif = img.getexif() + print_exif(exif) + + pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif) + + res_img = Image.open(io.BytesIO(pil_bytes)) + res_exif = res_img.getexif() + + assert dict(exif) == dict(res_exif) diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index c4f9730..7ca8d85 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -6,9 +6,24 @@ import gradio as gr from loguru import logger from pydantic import BaseModel -from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \ - SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \ - GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR +from lama_cleaner.const import ( + AVAILABLE_MODELS, + AVAILABLE_DEVICES, + CPU_OFFLOAD_HELP, + NO_HALF_HELP, + DISABLE_NSFW_HELP, + SD_CPU_TEXTENCODER_HELP, + LOCAL_FILES_ONLY_HELP, + ENABLE_XFORMERS_HELP, + MODEL_DIR_HELP, + OUTPUT_DIR_HELP, + INPUT_HELP, + GUI_HELP, + DEFAULT_MODEL, + DEFAULT_DEVICE, + NO_GUI_AUTO_CLOSE_HELP, + DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS, +) _config_file = None @@ -18,6 +33,7 @@ class Config(BaseModel): port: int = 8080 model: str = DEFAULT_MODEL device: str = DEFAULT_DEVICE + cuda_visible_device: str = "" gui: bool = False no_gui_auto_close: bool = False no_half: bool = False @@ -33,16 +49,29 @@ class Config(BaseModel): def load_config(installer_config: str): if os.path.exists(installer_config): - with open(installer_config, "r", encoding='utf-8') as f: + with open(installer_config, "r", encoding="utf-8") as f: return Config(**json.load(f)) else: return Config() def save_config( - host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload, - disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only, - model_dir, input, output_dir + host, + port, + model, + device, + cuda_visible_device, + gui, + no_gui_auto_close, + no_half, + cpu_offload, + disable_nsfw, + sd_cpu_textencoder, + enable_xformers, + local_files_only, + model_dir, + input, + output_dir, ): config = Config(**locals()) print(config) @@ -63,6 +92,7 @@ def save_config( def close_server(*args): # TODO: make close both browser and server works import os, signal + pid = os.getpid() os.kill(pid, signal.SIGUSR1) @@ -86,33 +116,57 @@ def main(config_file: str): port = gr.Number(init_config.port, label="Port", precision=0) with gr.Row(): model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) - device = gr.Radio(AVAILABLE_DEVICES, label="Device", 
value=init_config.device)
+        device = gr.Radio(
+            AVAILABLE_DEVICES, label=f"Device (mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
+        )
+        cuda_visible_device = gr.Textbox(
+            "", label="CUDA visible device. (0/1/2...)"
+        )
         gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
-        no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}")
+        no_gui_auto_close = gr.Checkbox(
+            init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
+        )
         no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
         cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
-        disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}")
-        sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}")
-        enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}")
-        local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}")
+        disable_nsfw = gr.Checkbox(
+            init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
+        )
+        sd_cpu_textencoder = gr.Checkbox(
+            init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
+        )
+        enable_xformers = gr.Checkbox(
+            init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
+        )
+        local_files_only = gr.Checkbox(
+            init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
+        )
 
         model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
-        input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}")
-        output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}")
-        save_btn.click(save_config, [
-            host,
-            port,
-            model,
-            device,
-            gui,
-            no_gui_auto_close,
-            no_half,
-            cpu_offload,
-            disable_nsfw,
-            sd_cpu_textencoder,
-            enable_xformers,
-            local_files_only,
-            model_dir,
-            input,
-            output_dir,
-        ], message)
+        input = gr.Textbox(
+            init_config.input, label=f"Input file or directory. {INPUT_HELP}"
+        )
+        output_dir = gr.Textbox(
+            init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}"
+        )
+        save_btn.click(
+            save_config,
+            [
+                host,
+                port,
+                model,
+                device,
+                cuda_visible_device,
+                gui,
+                no_gui_auto_close,
+                no_half,
+                cpu_offload,
+                disable_nsfw,
+                sd_cpu_textencoder,
+                enable_xformers,
+                local_files_only,
+                model_dir,
+                input,
+                output_dir,
+            ],
+            message,
+        )
 
     demo.launch(inbrowser=True, show_api=False)
diff --git a/requirements.txt b/requirements.txt
index 11cf54a..83f8f7b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,5 @@ scikit-image==0.19.3
 diffusers[torch]==0.12.1
 transformers>=4.25.1
 watchdog==2.2.1
-gradio
\ No newline at end of file
+gradio
+piexif==1.1.3
\ No newline at end of file
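
---

Usage sketch (editor's note, not part of the diff): a minimal example of the
EXIF round-trip that the new pil_to_bytes helper enables, mirroring
lama_cleaner/tests/test_save_exif.py; "bunny.jpeg" here stands in for any
JPEG that carries EXIF metadata:

    import io

    from PIL import Image

    from lama_cleaner.helper import pil_to_bytes

    img = Image.open("bunny.jpeg")  # placeholder input file
    exif = img.getexif()  # EXIF read from the source image

    # Re-encode through the new helper, passing the original EXIF along
    data = pil_to_bytes(img, ext="jpeg", exif=exif)

    # The metadata survives the round-trip
    assert dict(Image.open(io.BytesIO(data)).getexif()) == dict(exif)

This is the same path /process now takes: load_img(..., return_exif=True)
captures the EXIF up front, and the response is built with pil_to_bytes
instead of numpy_to_bytes so the metadata is carried into the saved result.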