update web_config

This commit is contained in:
Qing 2024-02-04 21:38:41 +08:00
parent fcd8254205
commit 7321232b78

View File

@ -1,7 +1,12 @@
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
from datetime import datetime from datetime import datetime
from json import JSONDecodeError from json import JSONDecodeError
import gradio as gr import gradio as gr
from iopaint.download import scan_models
from loguru import logger from loguru import logger
from iopaint.const import * from iopaint.const import *
@ -59,7 +64,7 @@ def save_config(
config.input = None config.input = None
if str(config.output_dir) == ".": if str(config.output_dir) == ".":
config.output_dir = None config.output_dir = None
config.model = config.model.strip()
print(config.model_dump_json(indent=4)) print(config.model_dump_json(indent=4))
if config.input and not os.path.exists(config.input): if config.input and not os.path.exists(config.input):
return "[Error] Input file or directory does not exist" return "[Error] Input file or directory does not exist"
@ -75,11 +80,16 @@ def save_config(
return msg return msg
def change_current_model(new_model):
return new_model
def main(config_file: Path): def main(config_file: Path):
global _config_file global _config_file
_config_file = config_file _config_file = config_file
init_config = load_config(config_file) init_config = load_config(config_file)
downloaded_models = [it.name for it in scan_models()]
with gr.Blocks() as demo: with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
@ -95,11 +105,22 @@ def main(config_file: Path):
host = gr.Textbox(init_config.host, label="Host") host = gr.Textbox(init_config.host, label="Host")
port = gr.Number(init_config.port, label="Port", precision=0) port = gr.Number(init_config.port, label="Port", precision=0)
model = gr.Radio( with gr.Column():
AVAILABLE_MODELS + DIFFUSION_MODELS, model = gr.Textbox(
label="Models (https://www.iopaint.com/models)", init_config.model,
value=init_config.model, label="Current Model. This is the model that will be used when the service starts. "
) "If the model has not been downloaded before, it will be automatically downloaded. "
"You can select a model from the dropdown box below or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
)
with gr.Row():
recommend_model = gr.Dropdown(
["lama", "mat", "migan"] + DIFFUSION_MODELS,
label="Recommend Models",
)
downloaded_model = gr.Dropdown(
downloaded_models, label="Downloaded Models"
)
device = gr.Radio( device = gr.Radio(
Device.values(), label="Device", value=init_config.device Device.values(), label="Device", value=init_config.device
) )
@ -196,6 +217,9 @@ def main(config_file: Path):
value=init_config.restoreformer_device, value=init_config.restoreformer_device,
) )
downloaded_model.change(change_current_model, [downloaded_model], model)
recommend_model.change(change_current_model, [recommend_model], model)
save_btn.click( save_btn.click(
save_config, save_config,
[ [