IOPaint/scripts/user_scripts/tasks.py

117 lines
2.9 KiB
Python
Raw Normal View History

2022-10-24 12:32:35 +02:00
import os
import json
from enum import Enum
import socket
import logging
from contextlib import closing
from invoke import task
from rich import print
from rich.prompt import IntPrompt, Prompt, Confirm
from rich.logging import RichHandler
# Rich-powered console logging: RichHandler renders the level and timestamp
# itself, so the format string carries only the message; "[%X]" is HH:MM:SS.
FORMAT = "%(message)s"
logging.basicConfig(
level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
# Shared module-level logger used by all tasks below.
log = logging.getLogger("lama-cleaner")
def find_free_port() -> int:
    """Ask the OS for an ephemeral TCP port and return its number.

    The probe socket is closed on return, so the port is only *likely*
    free — another process could claim it before the caller binds.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(("", 0))  # port 0 = let the kernel pick a free port
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.getsockname()[1]
2022-10-27 16:38:00 +02:00
2022-10-24 12:32:35 +02:00
# JSON file (relative to the working directory) where the `config` task
# persists the user's choices and from which the `start` task reads them.
CONFIG_PATH = "config.json"
2022-10-27 16:38:00 +02:00
2022-10-24 12:32:35 +02:00
class MODEL(str, Enum):
    """Inpainting models selectable at config time.

    Values are the exact strings passed to ``lama-cleaner --model``;
    inheriting from ``str`` keeps them JSON-serializable as plain strings.
    """

    SD15 = "sd1.5"
    LAMA = "lama"
    PAINT_BY_EXAMPLE = "paint_by_example"
2022-10-27 16:38:00 +02:00
2022-10-24 12:32:35 +02:00
class DEVICE(str, Enum):
    """Compute devices selectable at config time.

    Values are the exact strings passed to ``lama-cleaner --device``.
    """

    CUDA = "cuda"
    CPU = "cpu"
2022-10-27 16:38:00 +02:00
2022-10-24 12:32:35 +02:00
@task
def info(c):
    """Print tool and interpreter versions to help debug environment issues.

    Every probe is best-effort: a missing tool (e.g. conda) must not abort
    the report, so failures are logged and the task continues.
    """
    print("Environment information".center(60, "-"))
    try:
        c.run("git --version")
        c.run("conda --version")
        c.run("which python")
        c.run("python --version")
        c.run("which pip")
        c.run("pip --version")
        # Raw string: "\|" is an invalid escape in a plain string literal
        # (SyntaxWarning on modern Python); grep needs the literal backslash.
        c.run(r'pip list | grep "torch\|lama\|diffusers\|opencv\|cuda\|xformers\|accelerate"')
    # Narrowed from a bare `except:` so Ctrl-C (KeyboardInterrupt) still
    # interrupts, and the failure is visible instead of silently swallowed.
    except Exception:
        log.exception("Failed to collect some environment information")
    print("-" * 60)
2022-10-24 12:32:35 +02:00
@task(pre=[info])
def config(c, disable_device_choice=False):
    """Interactively collect model/device/desktop choices and save them.

    Writes the resulting settings as JSON to ``CONFIG_PATH`` for the
    ``start`` task to consume. With ``disable_device_choice`` the device
    is forced to CPU and the device prompt (and CUDA probe) is skipped.
    """
    chosen_model = Prompt.ask(
        "Choice model",
        choices=[MODEL.SD15, MODEL.LAMA, MODEL.PAINT_BY_EXAMPLE],
        default=MODEL.SD15,
    )

    chosen_device = DEVICE.CPU
    if not disable_device_choice:
        chosen_device = Prompt.ask(
            "Choice device", choices=[DEVICE.CUDA, DEVICE.CPU], default=DEVICE.CUDA
        )
        if chosen_device == DEVICE.CUDA:
            # Lazy import: torch is heavy and only needed for the CUDA probe.
            import torch

            if not torch.cuda.is_available():
                log.warning(
                    "Did not find CUDA device on your computer, fallback to cpu"
                )
                chosen_device = DEVICE.CPU

    run_as_desktop = Confirm.ask("Start as desktop app?", default=True)

    settings = {
        "model": chosen_model,
        "device": chosen_device,
        "desktop": run_as_desktop,
    }
    log.info(f"Save config to {CONFIG_PATH}")
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(settings, f, indent=2, ensure_ascii=False)
    # Keeps the console window open until the user acknowledges.
    Confirm.ask("Config finish, you can close this window")
2022-10-24 12:32:35 +02:00
2022-10-27 16:38:00 +02:00
@task(pre=[info])
def start(c):
    """Load the saved config and launch lama-cleaner on a random free port.

    If no config file exists, shows an acknowledgement prompt (so a
    double-clicked console window does not vanish instantly) and exits
    with a non-zero status, since the ``config`` task must run first.
    """
    if not os.path.exists(CONFIG_PATH):
        # Fixed message typos ("not exists" / "scritp") in the user prompt.
        Confirm.ask("Config file does not exist, please run the config script first")
        # `exit()` is the interactive-only site builtin and exits with
        # status 0; raise SystemExit(1) to signal the error explicitly.
        raise SystemExit(1)
    log.info(f"Load config from {CONFIG_PATH}")
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        configs = json.load(f)
    model = configs["model"]
    device = configs["device"]
    desktop = configs["desktop"]
    port = find_free_port()
    log.info(f"Using random port: {port}")
    if desktop:
        c.run(
            f"lama-cleaner --model {model} --device {device} --port {port} --gui --gui-size 1400 900"
        )
    else:
        c.run(f"lama-cleaner --model {model} --device {device} --port {port}")