auto switch mps device to cpu device

This commit is contained in:
Qing 2023-02-11 13:30:09 +08:00
parent f9b5dcbfd7
commit 8f8bcfe0f4
15 changed files with 52 additions and 19 deletions

View File

@ -8,10 +8,18 @@ import cv2
from PIL import Image, ImageOps from PIL import Image, ImageOps
import numpy as np import numpy as np
import torch import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS
from loguru import logger from loguru import logger
from torch.hub import download_url_to_file, get_dir from torch.hub import download_url_to_file, get_dir
def switch_mps_device(model_name, device):
if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')):
logger.info(f"{model_name} not support mps, switch to cpu")
return torch.device('cpu')
return device
def get_cache_path_by_url(url): def get_cache_path_by_url(url):
parts = urlparse(url) parts = urlparse(url)
hub_dir = get_dir() hub_dir = get_dir()

View File

@ -6,11 +6,12 @@ import torch
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
from lama_cleaner.schema import Config, HDStrategy from lama_cleaner.schema import Config, HDStrategy
class InpaintModel: class InpaintModel:
name = "base"
min_size: Optional[int] = None min_size: Optional[int] = None
pad_mod = 8 pad_mod = 8
pad_to_square = False pad_to_square = False
@ -21,6 +22,7 @@ class InpaintModel:
Args: Args:
device: device:
""" """
device = switch_mps_device(self.name, device)
self.device = device self.device = device
self.init_model(device, **kwargs) self.init_model(device, **kwargs)

View File

@ -1131,6 +1131,7 @@ FCF_MODEL_URL = os.environ.get(
class FcF(InpaintModel): class FcF(InpaintModel):
name = "fcf"
min_size = 512 min_size = 512
pad_mod = 512 pad_mod = 512
pad_to_square = True pad_to_square = True

View File

@ -9,6 +9,7 @@ from lama_cleaner.schema import Config
class InstructPix2Pix(DiffusionInpaintModel): class InstructPix2Pix(DiffusionInpaintModel):
name = "instruct_pix2pix"
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512

View File

@ -5,7 +5,7 @@ import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@ -16,6 +16,7 @@ LAMA_MODEL_URL = os.environ.get(
class LaMa(InpaintModel): class LaMa(InpaintModel):
name = "lama"
pad_mod = 8 pad_mod = 8
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):

View File

@ -225,6 +225,7 @@ class LatentDiffusion(DDPM):
class LDM(InpaintModel): class LDM(InpaintModel):
name = "ldm"
pad_mod = 32 pad_mod = 32
def __init__(self, device, fp16: bool = True, **kwargs): def __init__(self, device, fp16: bool = True, **kwargs):

View File

@ -76,6 +76,7 @@ MANGA_LINE_MODEL_URL = os.environ.get(
class Manga(InpaintModel): class Manga(InpaintModel):
name = "manga"
pad_mod = 16 pad_mod = 16
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):

View File

@ -1401,6 +1401,7 @@ MAT_MODEL_URL = os.environ.get(
class MAT(InpaintModel): class MAT(InpaintModel):
name = "mat"
min_size = 512 min_size = 512
pad_mod = 512 pad_mod = 512
pad_to_square = True pad_to_square = True

View File

@ -2,12 +2,11 @@ import cv2
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
flag_map = { flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
"INPAINT_NS": cv2.INPAINT_NS,
"INPAINT_TELEA": cv2.INPAINT_TELEA
}
class OpenCV2(InpaintModel): class OpenCV2(InpaintModel):
name = "cv2"
pad_mod = 1 pad_mod = 1
@staticmethod @staticmethod
@ -20,5 +19,10 @@ class OpenCV2(InpaintModel):
mask: [H, W, 1] mask: [H, W, 1]
return: BGR IMAGE return: BGR IMAGE
""" """
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag]) cur_res = cv2.inpaint(
image[:, :, ::-1],
mask,
inpaintRadius=config.cv2_radius,
flags=flag_map[config.cv2_flag],
)
return cur_res return cur_res

View File

@ -11,6 +11,7 @@ from lama_cleaner.schema import Config
class PaintByExample(DiffusionInpaintModel): class PaintByExample(DiffusionInpaintModel):
name = "paint_by_example"
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512

View File

@ -160,8 +160,10 @@ class SD(DiffusionInpaintModel):
class SD15(SD): class SD15(SD):
name = "sd1.5"
model_id_or_path = "runwayml/stable-diffusion-inpainting" model_id_or_path = "runwayml/stable-diffusion-inpainting"
class SD2(SD): class SD2(SD):
name = "sd2"
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"

View File

@ -203,6 +203,7 @@ def to_device(data, device):
class ZITS(InpaintModel): class ZITS(InpaintModel):
name = "zits"
min_size = 256 min_size = 256
pad_mod = 32 pad_mod = 32
pad_to_square = True pad_to_square = True

View File

@ -1,6 +1,7 @@
import torch import torch
import gc import gc
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.fcf import FcF from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM from lama_cleaner.model.ldm import LDM
@ -13,8 +14,19 @@ from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga, models = {
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix} "lama": LaMa,
"ldm": LDM,
"zits": ZITS,
"mat": MAT,
"fcf": FcF,
"sd1.5": SD15,
"cv2": OpenCV2,
"manga": Manga,
"sd2": SD2,
"paint_by_example": PaintByExample,
"instruct_pix2pix": InstructPix2Pix,
}
class ModelManager: class ModelManager:
@ -44,13 +56,15 @@ class ModelManager:
if new_name == self.name: if new_name == self.name:
return return
try: try:
if (torch.cuda.memory_allocated() > 0): if torch.cuda.memory_allocated() > 0:
# Clear current loaded model from memory # Clear current loaded model from memory
torch.cuda.empty_cache() torch.cuda.empty_cache()
del self.model del self.model
gc.collect() gc.collect()
self.model = self.init_model(new_name, self.device, **self.kwargs) self.model = self.init_model(
new_name, switch_mps_device(new_name, self.device), **self.kwargs
)
self.name = new_name self.name = new_name
except NotImplementedError as e: except NotImplementedError as e:
raise e raise e

View File

@ -143,12 +143,6 @@ def parse_args():
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation" "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
) )
if args.device == "mps":
if args.model not in MPS_SUPPORT_MODELS:
parser.error(
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
)
if args.model_dir and args.model_dir is not None: if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
parser.error(f"invalid --model-dir: {args.model_dir} is a file") parser.error(f"invalid --model-dir: {args.model_dir} is a file")

View File

@ -22,7 +22,8 @@ from lama_cleaner.const import (
DEFAULT_MODEL, DEFAULT_MODEL,
DEFAULT_DEVICE, DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP, NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS, DEFAULT_MODEL_DIR,
MPS_SUPPORT_MODELS,
) )
_config_file = None _config_file = None
@ -115,7 +116,7 @@ def main(config_file: str):
with gr.Row(): with gr.Row():
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
device = gr.Radio( device = gr.Radio(
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device AVAILABLE_DEVICES, label="Device", value=init_config.device
) )
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
no_gui_auto_close = gr.Checkbox( no_gui_auto_close = gr.Checkbox(