IOPaint/scripts/user_scripts/tasks.py

import os
import json
from enum import Enum
import socket
import logging
from contextlib import closing
from invoke import task
from rich import print
from rich.prompt import Prompt, Confirm
from rich.logging import RichHandler
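
# Invoke tasks used to configure and launch lama-cleaner. With a standard
# invoke setup, they are run from this directory, e.g.:
#   invoke config                          # interactive setup, writes config.json
#   invoke config --disable-device-choice  # force CPU, skip the device prompt
#   invoke start                           # launch lama-cleaner with the saved config
# (invoke maps the disable_device_choice parameter to the --disable-device-choice flag)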

FORMAT = "%(message)s"
logging.basicConfig(
    level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
log = logging.getLogger("lama-cleaner")
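

# find_free_port binds to port 0 so the OS assigns an unused ephemeral port,
# then reports which port was assigned. The socket is closed again before the
# server starts, so the port could in theory be taken in between; for a local
# helper script this race is harmless.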
def find_free_port() -> int:
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
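

# config.json is written by the `config` task below and read back by `start`.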
CONFIG_PATH = "config.json"


class MODEL(str, Enum):
    SD15 = "sd1.5"
    LAMA = "lama"


class DEVICE(str, Enum):
    CUDA = "cuda"
    CPU = "cpu"
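

# Because MODEL and DEVICE subclass str, their members compare equal to plain
# strings and serialize as their values ("sd1.5", "cuda", ...) in the prompts,
# f-strings, and json.dump calls below.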
@task
def info(c):
    """Print environment information: git, conda, python, pip and key packages."""
    print("Environment information".center(60, "-"))
    try:
        c.run("git --version")
        c.run("conda --version")
        c.run("which python")
        c.run("python --version")
        c.run("which pip")
        c.run("pip --version")
        c.run('pip list | grep "torch\\|lama\\|diffusers\\|opencv\\|cuda"')
    except Exception:  # best effort: a missing tool should not abort the task
        pass
    print("-" * 60)


@task(pre=[info])
def config(c, disable_device_choice=False):
    # TODO: prompt to choose model, device, port and host
    # If an sd model is chosen, prompt the user to accept the terms and enter
    # a huggingface token.
    model = Prompt.ask(
        "Choose model", choices=[MODEL.SD15, MODEL.LAMA], default=MODEL.SD15
    )

    hf_access_token = ""
    if model == MODEL.SD15:
        while True:
            hf_access_token = Prompt.ask(
                "Huggingface access token (https://huggingface.co/docs/hub/security-tokens)"
            )
            if hf_access_token == "":
                log.warning("Access token is required to download model")
            else:
                break

    if disable_device_choice:
        device = DEVICE.CPU
    else:
        device = Prompt.ask(
            "Choose device", choices=[DEVICE.CUDA, DEVICE.CPU], default=DEVICE.CUDA
        )
        if device == DEVICE.CUDA:
            # Imported here so torch is only loaded when the CUDA check is needed
            import torch

            if not torch.cuda.is_available():
                log.warning(
                    "Did not find a CUDA device on your computer, falling back to CPU"
                )
                device = DEVICE.CPU

    desktop = Confirm.ask("Start as desktop app?", default=True)

    configs = {
        "model": model,
        "device": device,
        "hf_access_token": hf_access_token,
        "desktop": desktop,
    }
    log.info(f"Save config to {CONFIG_PATH}")
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(configs, f, indent=2, ensure_ascii=False)
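    # The saved config.json looks like this (exact values depend on the answers):
    # {
    #   "model": "sd1.5",
    #   "device": "cuda",
    #   "hf_access_token": "hf_...",
    #   "desktop": true
    # }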

    Confirm.ask("Config finished, you can close this window")


@task(pre=[info])
2022-10-24 12:32:35 +02:00
def start(c):
if not os.path.exists(CONFIG_PATH):
2022-11-21 04:56:34 +01:00
Confirm.ask("Config file not exists, please run config scritp first")
2022-10-24 12:32:35 +02:00
exit()
log.info(f"Load config from {CONFIG_PATH}")
2022-10-27 16:38:00 +02:00
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
2022-10-24 12:32:35 +02:00
configs = json.load(f)

    model = configs["model"]
    device = configs["device"]
    hf_access_token = configs["hf_access_token"]
    desktop = configs["desktop"]

    port = find_free_port()
    log.info(f"Using random port: {port}")

    if desktop:
        c.run(
            f"lama-cleaner --model {model} --device {device} --hf_access_token={hf_access_token} --port {port} --gui --gui-size 1400 900"
        )
    else:
        c.run(
            f"lama-cleaner --model {model} --device {device} --hf_access_token={hf_access_token} --port {port}"
        )
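
# With the defaults above, the desktop branch expands to something like:
#   lama-cleaner --model sd1.5 --device cuda --hf_access_token=hf_... --port 51234 --gui --gui-size 1400 900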