Better error handling
parent 8f8bcfe0f4
commit b1ec157467
@@ -14,9 +14,11 @@ from torch.hub import download_url_to_file, get_dir
 
 
 def switch_mps_device(model_name, device):
-    if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')):
+    if model_name not in MPS_SUPPORT_MODELS and (
+        device == "mps" or device == torch.device("mps")
+    ):
         logger.info(f"{model_name} not support mps, switch to cpu")
-        return torch.device('cpu')
+        return torch.device("cpu")
     return device
 
 
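
For orientation, a hedged, self-contained sketch of what the reformatted helper does; the `MPS_SUPPORT_MODELS` value and the model names below are illustrative, and the `logger` call is dropped to keep the snippet standalone.

```python
import torch

# Illustrative allow-list; the real MPS_SUPPORT_MODELS is defined alongside
# this helper and lists the models known to work on Apple's MPS backend.
MPS_SUPPORT_MODELS = ["lama"]

def switch_mps_device(model_name, device):
    # An unsupported model on an MPS device (passed as a string or as a
    # torch.device) falls back to CPU; everything else passes through.
    if model_name not in MPS_SUPPORT_MODELS and (
        device == "mps" or device == torch.device("mps")
    ):
        return torch.device("cpu")
    return device

print(switch_mps_device("zits", "mps"))  # cpu: "zits" (illustrative) is not in the list
print(switch_mps_device("lama", "mps"))  # mps: supported, device passed through
```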
@@ -51,12 +53,14 @@ def load_jit_model(url_or_path, device):
         model_path = url_or_path
     else:
         model_path = download_model(url_or_path)
-    logger.info(f"Load model from: {model_path}")
+    logger.info(f"Loading model from: {model_path}")
     try:
-        model = torch.jit.load(model_path).to(device)
-    except:
+        model = torch.jit.load(model_path, map_location="cpu").to(device)
+    except Exception as e:
         logger.error(
-            f"Failed to load {model_path}, delete model and restart lama-cleaner"
+            f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n"
+            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
+            f"If all above operations doesn't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
         )
         exit(-1)
     model.eval()
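
Two fixes land in this loader, sketched below under stated assumptions (the function name and log text here are illustrative, not the project's): `map_location="cpu"` lets a TorchScript file serialized on a GPU machine deserialize on a CPU-only host before `.to(device)` moves it, and `except Exception as e` binds the root cause so it can be logged, where the old bare `except:` silently swallowed everything, including `KeyboardInterrupt`.

```python
import sys

import torch
from loguru import logger

def load_jit_model_sketch(model_path: str, device: str):
    """Hedged stand-in for the hardened loader; model_path is assumed to
    point at an existing TorchScript file on disk."""
    try:
        # map_location="cpu" avoids "Attempting to deserialize object on a
        # CUDA device" on CPU-only hosts; .to(device) moves the module after.
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        # Binding the exception keeps the root cause in the log, which the
        # longer user-facing error message above relies on.
        logger.error(f"Failed to load {model_path}: {e}")
        sys.exit(-1)
    model.eval()
    return model
```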
@@ -3,9 +3,12 @@ import os
 import cv2
 import numpy as np
 import torch
-from loguru import logger
 
-from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
+from lama_cleaner.helper import (
+    norm_img,
+    get_cache_path_by_url,
+    load_jit_model,
+)
 from lama_cleaner.model.base import InpaintModel
 from lama_cleaner.schema import Config
 
@@ -20,20 +23,7 @@ class LaMa(InpaintModel):
     pad_mod = 8
 
     def init_model(self, device, **kwargs):
-        if os.environ.get("LAMA_MODEL"):
-            model_path = os.environ.get("LAMA_MODEL")
-            if not os.path.exists(model_path):
-                raise FileNotFoundError(
-                    f"lama torchscript model not found: {model_path}"
-                )
-        else:
-            model_path = download_model(LAMA_MODEL_URL)
-        logger.info(f"Load LaMa model from: {model_path}")
-        model = torch.jit.load(model_path, map_location="cpu")
-        model = model.to(device)
-        model.eval()
-        self.model = model
-        self.model_path = model_path
+        self.model = load_jit_model(LAMA_MODEL_URL, device).eval()
 
     @staticmethod
     def is_downloaded() -> bool:
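
The net effect of this hunk, in a hedged sketch: the per-model download / jit.load / .to(device) / eval boilerplate moves behind the shared load_jit_model helper, so each model only declares its weights URL. Here `InpaintModel` is stubbed, the URL is a placeholder, and the local `load_jit_model` stand-in mirrors only the load-and-move step of the real helper. Note that the diff also drops the `LAMA_MODEL` environment-variable override from this file; the commit does not show where, or whether, that override is handled now.

```python
import torch

LAMA_MODEL_URL = "https://example.com/big-lama.pt"  # placeholder, not the real URL

def load_jit_model(url_or_path, device):
    # Stand-in for lama_cleaner.helper.load_jit_model (see the earlier hunk);
    # the real helper also downloads the file and reports failures.
    return torch.jit.load(url_or_path, map_location="cpu").to(device)

class InpaintModel:
    """Stub for lama_cleaner.model.base.InpaintModel."""
    def __init__(self, device, **kwargs):
        self.init_model(device, **kwargs)

class LaMa(InpaintModel):
    pad_mod = 8

    def init_model(self, device, **kwargs):
        # Fourteen lines of inline loading logic collapse into one call.
        self.model = load_jit_model(LAMA_MODEL_URL, device).eval()
```

Centralizing the load path also means the map_location fix and the richer error message from the earlier hunk now apply to every model built on the helper, not just LaMa.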