diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 2deaf9c..8097c0d 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -8,10 +8,18 @@ import cv2 from PIL import Image, ImageOps import numpy as np import torch +from lama_cleaner.const import MPS_SUPPORT_MODELS from loguru import logger from torch.hub import download_url_to_file, get_dir +def switch_mps_device(model_name, device): + if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')): + logger.info(f"{model_name} not support mps, switch to cpu") + return torch.device('cpu') + return device + + def get_cache_path_by_url(url): parts = urlparse(url) hub_dir = get_dir() diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 60eccb4..3ff29b8 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -6,11 +6,12 @@ import torch import numpy as np from loguru import logger -from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo +from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device from lama_cleaner.schema import Config, HDStrategy class InpaintModel: + name = "base" min_size: Optional[int] = None pad_mod = 8 pad_to_square = False @@ -21,6 +22,7 @@ class InpaintModel: Args: device: """ + device = switch_mps_device(self.name, device) self.device = device self.init_model(device, **kwargs) diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index 9e9e8c0..0b17775 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -1131,6 +1131,7 @@ FCF_MODEL_URL = os.environ.get( class FcF(InpaintModel): + name = "fcf" min_size = 512 pad_mod = 512 pad_to_square = True diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index b343800..dc57763 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -9,6 +9,7 @@ from lama_cleaner.schema import Config class InstructPix2Pix(DiffusionInpaintModel): + name = "instruct_pix2pix" pad_mod = 8 min_size = 512 diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index 7f85e3b..adaec35 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -5,7 +5,7 @@ import numpy as np import torch from loguru import logger -from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url +from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config @@ -16,6 +16,7 @@ LAMA_MODEL_URL = os.environ.get( class LaMa(InpaintModel): + name = "lama" pad_mod = 8 def init_model(self, device, **kwargs): diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 27f6d5a..c224069 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -225,6 +225,7 @@ class LatentDiffusion(DDPM): class LDM(InpaintModel): + name = "ldm" pad_mod = 32 def __init__(self, device, fp16: bool = True, **kwargs): diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py index cae4766..e1ece6e 100644 --- a/lama_cleaner/model/manga.py +++ b/lama_cleaner/model/manga.py @@ -76,6 +76,7 @@ MANGA_LINE_MODEL_URL = os.environ.get( class Manga(InpaintModel): + name = "manga" pad_mod = 16 def init_model(self, device, **kwargs): diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index 67020bc..58b55bf 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -1401,6 +1401,7 @@ MAT_MODEL_URL = os.environ.get( class MAT(InpaintModel): + name = "mat" min_size = 512 pad_mod = 512 pad_to_square = True diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py index 1802ccd..e0618dd 100644 --- a/lama_cleaner/model/opencv2.py +++ b/lama_cleaner/model/opencv2.py @@ -2,12 +2,11 @@ import cv2 from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config -flag_map = { - "INPAINT_NS": cv2.INPAINT_NS, - "INPAINT_TELEA": cv2.INPAINT_TELEA -} +flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} + class OpenCV2(InpaintModel): + name = "cv2" pad_mod = 1 @staticmethod @@ -20,5 +19,10 @@ class OpenCV2(InpaintModel): mask: [H, W, 1] return: BGR IMAGE """ - cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag]) + cur_res = cv2.inpaint( + image[:, :, ::-1], + mask, + inpaintRadius=config.cv2_radius, + flags=flag_map[config.cv2_flag], + ) return cur_res diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index 29ffa2a..d28b275 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -11,6 +11,7 @@ from lama_cleaner.schema import Config class PaintByExample(DiffusionInpaintModel): + name = "paint_by_example" pad_mod = 8 min_size = 512 diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 598ce30..7239988 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -160,8 +160,10 @@ class SD(DiffusionInpaintModel): class SD15(SD): + name = "sd1.5" model_id_or_path = "runwayml/stable-diffusion-inpainting" class SD2(SD): + name = "sd2" model_id_or_path = "stabilityai/stable-diffusion-2-inpainting" diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index b26981b..9ccdcf4 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -203,6 +203,7 @@ def to_device(data, device): class ZITS(InpaintModel): + name = "zits" min_size = 256 pad_mod = 32 pad_to_square = True diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 9ecdc12..7b5720a 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -1,6 +1,7 @@ import torch import gc +from lama_cleaner.helper import switch_mps_device from lama_cleaner.model.fcf import FcF from lama_cleaner.model.lama import LaMa from lama_cleaner.model.ldm import LDM @@ -13,8 +14,19 @@ from lama_cleaner.model.zits import ZITS from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.schema import Config -models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga, - "sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix} +models = { + "lama": LaMa, + "ldm": LDM, + "zits": ZITS, + "mat": MAT, + "fcf": FcF, + "sd1.5": SD15, + "cv2": OpenCV2, + "manga": Manga, + "sd2": SD2, + "paint_by_example": PaintByExample, + "instruct_pix2pix": InstructPix2Pix, +} class ModelManager: @@ -44,13 +56,15 @@ class ModelManager: if new_name == self.name: return try: - if (torch.cuda.memory_allocated() > 0): + if torch.cuda.memory_allocated() > 0: # Clear current loaded model from memory torch.cuda.empty_cache() del self.model gc.collect() - self.model = self.init_model(new_name, self.device, **self.kwargs) + self.model = self.init_model( + new_name, switch_mps_device(new_name, self.device), **self.kwargs + ) self.name = new_name except NotImplementedError as e: raise e diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index aaaddf3..cb6900a 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -143,12 +143,6 @@ def parse_args(): "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation" ) - if args.device == "mps": - if args.model not in MPS_SUPPORT_MODELS: - parser.error( - f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}" - ) - if args.model_dir and args.model_dir is not None: if os.path.isfile(args.model_dir): parser.error(f"invalid --model-dir: {args.model_dir} is a file") diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index b5882dc..a01cbf3 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -22,7 +22,8 @@ from lama_cleaner.const import ( DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, - DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS, + DEFAULT_MODEL_DIR, + MPS_SUPPORT_MODELS, ) _config_file = None @@ -115,7 +116,7 @@ def main(config_file: str): with gr.Row(): model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) device = gr.Radio( - AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device + AVAILABLE_DEVICES, label="Device", value=init_config.device ) gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") no_gui_auto_close = gr.Checkbox(