diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 948957d..761116b 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -1,7 +1,12 @@ +import os + +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + from datetime import datetime from json import JSONDecodeError import gradio as gr +from iopaint.download import scan_models from loguru import logger from iopaint.const import * @@ -59,7 +64,7 @@ def save_config( config.input = None if str(config.output_dir) == ".": config.output_dir = None - + config.model = config.model.strip() print(config.model_dump_json(indent=4)) if config.input and not os.path.exists(config.input): return "[Error] Input file or directory does not exist" @@ -75,11 +80,16 @@ def save_config( return msg +def change_current_model(new_model): + return new_model + + def main(config_file: Path): global _config_file _config_file = config_file init_config = load_config(config_file) + downloaded_models = [it.name for it in scan_models()] with gr.Blocks() as demo: with gr.Row(): @@ -95,11 +105,22 @@ def main(config_file: Path): host = gr.Textbox(init_config.host, label="Host") port = gr.Number(init_config.port, label="Port", precision=0) - model = gr.Radio( - AVAILABLE_MODELS + DIFFUSION_MODELS, - label="Models (https://www.iopaint.com/models)", - value=init_config.model, - ) + with gr.Column(): + model = gr.Textbox( + init_config.model, + label="Current Model. This is the model that will be used when the service starts. " + "If the model has not been downloaded before, it will be automatically downloaded. " + "You can select a model from the dropdown box below or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.", + ) + with gr.Row(): + recommend_model = gr.Dropdown( + ["lama", "mat", "migan"] + DIFFUSION_MODELS, + label="Recommend Models", + ) + downloaded_model = gr.Dropdown( + downloaded_models, label="Downloaded Models" + ) + device = gr.Radio( Device.values(), label="Device", value=init_config.device ) @@ -196,6 +217,9 @@ def main(config_file: Path): value=init_config.restoreformer_device, ) + downloaded_model.change(change_current_model, [downloaded_model], model) + recommend_model.change(change_current_model, [recommend_model], model) + save_btn.click( save_config, [