update user_scripts

This commit is contained in:
Sanster 2022-10-27 22:38:00 +08:00
parent 6921a13a83
commit af914e2086
2 changed files with 49 additions and 18 deletions

View File

@ -20,20 +20,24 @@ log = logging.getLogger("lama-cleaner")
def find_free_port() -> int: def find_free_port() -> int:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0)) s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1] return s.getsockname()[1]
CONFIG_PATH = "config.json" CONFIG_PATH = "config.json"
class MODEL(str, Enum): class MODEL(str, Enum):
SD15 = "sd1.5" SD15 = "sd1.5"
LAMA = 'lama' LAMA = "lama"
class DEVICE(str, Enum): class DEVICE(str, Enum):
CUDA = "cuda" CUDA = "cuda"
CPU = "cpu" CPU = "cpu"
@task @task
def info(c): def info(c):
print("Environment information".center(60, "-")) print("Environment information".center(60, "-"))
@ -44,57 +48,83 @@ def info(c):
c.run("python --version") c.run("python --version")
c.run("which pip") c.run("which pip")
c.run("pip --version") c.run("pip --version")
c.run("pip list | grep lama") c.run('pip list | grep "torch\|lama\|diffusers\|opencv\|cuda"')
except: except:
pass pass
print("-"*60) print("-" * 60)
@task(pre=[info]) @task(pre=[info])
def config(c, disable_device_choice=False): def config(c, disable_device_choice=False):
# TODO: 提示选择模型选择设备端口host # TODO: 提示选择模型选择设备端口host
# 如果是 sd 模型,提示接受条款和输入 huggingface token # 如果是 sd 模型,提示接受条款和输入 huggingface token
model = Prompt.ask("Choice model", choices=[MODEL.SD15, MODEL.LAMA], default=MODEL.SD15) model = Prompt.ask(
"Choice model", choices=[MODEL.SD15, MODEL.LAMA], default=MODEL.SD15
)
hf_access_token = "" hf_access_token = ""
if model == MODEL.SD15: if model == MODEL.SD15:
while True: while True:
hf_access_token = Prompt.ask("Huggingface access token (https://huggingface.co/docs/hub/security-tokens)") hf_access_token = Prompt.ask(
"Huggingface access token (https://huggingface.co/docs/hub/security-tokens)"
)
if hf_access_token == "": if hf_access_token == "":
log.warning("Access token is required to download model") log.warning("Access token is required to download model")
else: else:
break break
if disable_device_choice: if disable_device_choice:
device = DEVICE.CPU device = DEVICE.CPU
else: else:
device = Prompt.ask("Choice device", choices=[DEVICE.CUDA, DEVICE.CPU], default=DEVICE.CUDA) device = Prompt.ask(
"Choice device", choices=[DEVICE.CUDA, DEVICE.CPU], default=DEVICE.CUDA
)
if device == DEVICE.CUDA: if device == DEVICE.CUDA:
import torch import torch
if not torch.cuda.is_available(): if not torch.cuda.is_available():
log.warning("Did not find CUDA device on your computer, fallback to cpu") log.warning(
"Did not find CUDA device on your computer, fallback to cpu"
)
device = DEVICE.CPU device = DEVICE.CPU
configs = {"model": model, "device": device, "hf_access_token": hf_access_token} desktop = Confirm.ask("Start as desktop app?")
configs = {
"model": model,
"device": device,
"hf_access_token": hf_access_token,
"desktop": desktop,
}
log.info(f"Save config to {CONFIG_PATH}") log.info(f"Save config to {CONFIG_PATH}")
with open(CONFIG_PATH, 'w', encoding='utf-8') as f: with open(CONFIG_PATH, "w", encoding="utf-8") as f:
json.dump(configs, f, indent=2, ensure_ascii=False) json.dump(configs, f, indent=2, ensure_ascii=False)
log.info(f"Config finish, you can close this window.") log.info(f"Config finish, you can close this window.")
@task(pre=[info]) @task(pre=[info])
def start(c): def start(c):
if not os.path.exists(CONFIG_PATH): if not os.path.exists(CONFIG_PATH):
log.info("Config file not exists, please run config.sh first") log.info("Config file not exists, please run config.sh first")
exit() exit()
log.info(f"Load config from {CONFIG_PATH}") log.info(f"Load config from {CONFIG_PATH}")
with open(CONFIG_PATH, 'r', encoding='utf-8') as f: with open(CONFIG_PATH, "r", encoding="utf-8") as f:
configs = json.load(f) configs = json.load(f)
model = configs['model'] model = configs["model"]
device = configs['device'] device = configs["device"]
hf_access_token = configs['hf_access_token'] hf_access_token = configs["hf_access_token"]
desktop = configs["desktop"]
port = find_free_port() port = find_free_port()
log.info(f"Using random port: {port}") log.info(f"Using random port: {port}")
c.run(f"lama-cleaner --model {model} --device {device} --hf_access_token={hf_access_token} --port {port} --gui --gui-size 1400 900") if desktop:
c.run(
f"lama-cleaner --model {model} --device {device} --hf_access_token={hf_access_token} --port {port} --gui --gui-size 1400 900"
)
else:
c.run(
f"lama-cleaner --model {model} --device {device} --hf_access_token={hf_access_token} --port {port}"
)

View File

@ -6,7 +6,8 @@ set PATH=C:\Windows\System32;%PATH%
@call conda-unpack @call conda-unpack
@call pip3 install -U torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 @call conda install -y cudatoolkit=11.3
@call pip3 install torch --extra-index-url https://download.pytorch.org/whl/cu113
@call pip3 install -U lama-cleaner @call pip3 install -U lama-cleaner
@call invoke config @call invoke config