better error handle

This commit is contained in:
Qing 2023-02-14 09:08:56 +08:00
parent 8f8bcfe0f4
commit b1ec157467
2 changed files with 16 additions and 22 deletions

View File

@ -14,9 +14,11 @@ from torch.hub import download_url_to_file, get_dir
def switch_mps_device(model_name, device): def switch_mps_device(model_name, device):
if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')): if model_name not in MPS_SUPPORT_MODELS and (
device == "mps" or device == torch.device("mps")
):
logger.info(f"{model_name} not support mps, switch to cpu") logger.info(f"{model_name} not support mps, switch to cpu")
return torch.device('cpu') return torch.device("cpu")
return device return device
@ -51,12 +53,14 @@ def load_jit_model(url_or_path, device):
model_path = url_or_path model_path = url_or_path
else: else:
model_path = download_model(url_or_path) model_path = download_model(url_or_path)
logger.info(f"Load model from: {model_path}") logger.info(f"Loading model from: {model_path}")
try: try:
model = torch.jit.load(model_path).to(device) model = torch.jit.load(model_path, map_location="cpu").to(device)
except: except Exception as e:
logger.error( logger.error(
f"Failed to load {model_path}, delete model and restart lama-cleaner" f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n"
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
f"If all above operations doesn't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
) )
exit(-1) exit(-1)
model.eval() model.eval()

View File

@ -3,9 +3,12 @@ import os
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from loguru import logger
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.helper import (
norm_img,
get_cache_path_by_url,
load_jit_model,
)
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@ -20,20 +23,7 @@ class LaMa(InpaintModel):
pad_mod = 8 pad_mod = 8
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
if os.environ.get("LAMA_MODEL"): self.model = load_jit_model(LAMA_MODEL_URL, device).eval()
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
logger.info(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
self.model_path = model_path
@staticmethod @staticmethod
def is_downloaded() -> bool: def is_downloaded() -> bool: