auto switch mps device to cpu device
This commit is contained in:
parent
f9b5dcbfd7
commit
8f8bcfe0f4
@ -8,10 +8,18 @@ import cv2
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from lama_cleaner.const import MPS_SUPPORT_MODELS
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch.hub import download_url_to_file, get_dir
|
from torch.hub import download_url_to_file, get_dir
|
||||||
|
|
||||||
|
|
||||||
|
def switch_mps_device(model_name, device):
|
||||||
|
if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')):
|
||||||
|
logger.info(f"{model_name} not support mps, switch to cpu")
|
||||||
|
return torch.device('cpu')
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
def get_cache_path_by_url(url):
|
def get_cache_path_by_url(url):
|
||||||
parts = urlparse(url)
|
parts = urlparse(url)
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
|
@ -6,11 +6,12 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
|
||||||
from lama_cleaner.schema import Config, HDStrategy
|
from lama_cleaner.schema import Config, HDStrategy
|
||||||
|
|
||||||
|
|
||||||
class InpaintModel:
|
class InpaintModel:
|
||||||
|
name = "base"
|
||||||
min_size: Optional[int] = None
|
min_size: Optional[int] = None
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
pad_to_square = False
|
pad_to_square = False
|
||||||
@ -21,6 +22,7 @@ class InpaintModel:
|
|||||||
Args:
|
Args:
|
||||||
device:
|
device:
|
||||||
"""
|
"""
|
||||||
|
device = switch_mps_device(self.name, device)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.init_model(device, **kwargs)
|
self.init_model(device, **kwargs)
|
||||||
|
|
||||||
|
@ -1131,6 +1131,7 @@ FCF_MODEL_URL = os.environ.get(
|
|||||||
|
|
||||||
|
|
||||||
class FcF(InpaintModel):
|
class FcF(InpaintModel):
|
||||||
|
name = "fcf"
|
||||||
min_size = 512
|
min_size = 512
|
||||||
pad_mod = 512
|
pad_mod = 512
|
||||||
pad_to_square = True
|
pad_to_square = True
|
||||||
|
@ -9,6 +9,7 @@ from lama_cleaner.schema import Config
|
|||||||
|
|
||||||
|
|
||||||
class InstructPix2Pix(DiffusionInpaintModel):
|
class InstructPix2Pix(DiffusionInpaintModel):
|
||||||
|
name = "instruct_pix2pix"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
@ -16,6 +16,7 @@ LAMA_MODEL_URL = os.environ.get(
|
|||||||
|
|
||||||
|
|
||||||
class LaMa(InpaintModel):
|
class LaMa(InpaintModel):
|
||||||
|
name = "lama"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
|
|
||||||
def init_model(self, device, **kwargs):
|
def init_model(self, device, **kwargs):
|
||||||
|
@ -225,6 +225,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
|
|
||||||
class LDM(InpaintModel):
|
class LDM(InpaintModel):
|
||||||
|
name = "ldm"
|
||||||
pad_mod = 32
|
pad_mod = 32
|
||||||
|
|
||||||
def __init__(self, device, fp16: bool = True, **kwargs):
|
def __init__(self, device, fp16: bool = True, **kwargs):
|
||||||
|
@ -76,6 +76,7 @@ MANGA_LINE_MODEL_URL = os.environ.get(
|
|||||||
|
|
||||||
|
|
||||||
class Manga(InpaintModel):
|
class Manga(InpaintModel):
|
||||||
|
name = "manga"
|
||||||
pad_mod = 16
|
pad_mod = 16
|
||||||
|
|
||||||
def init_model(self, device, **kwargs):
|
def init_model(self, device, **kwargs):
|
||||||
|
@ -1401,6 +1401,7 @@ MAT_MODEL_URL = os.environ.get(
|
|||||||
|
|
||||||
|
|
||||||
class MAT(InpaintModel):
|
class MAT(InpaintModel):
|
||||||
|
name = "mat"
|
||||||
min_size = 512
|
min_size = 512
|
||||||
pad_mod = 512
|
pad_mod = 512
|
||||||
pad_to_square = True
|
pad_to_square = True
|
||||||
|
@ -2,12 +2,11 @@ import cv2
|
|||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
flag_map = {
|
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
||||||
"INPAINT_NS": cv2.INPAINT_NS,
|
|
||||||
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
|
||||||
}
|
|
||||||
|
|
||||||
class OpenCV2(InpaintModel):
|
class OpenCV2(InpaintModel):
|
||||||
|
name = "cv2"
|
||||||
pad_mod = 1
|
pad_mod = 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -20,5 +19,10 @@ class OpenCV2(InpaintModel):
|
|||||||
mask: [H, W, 1]
|
mask: [H, W, 1]
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
cur_res = cv2.inpaint(
|
||||||
|
image[:, :, ::-1],
|
||||||
|
mask,
|
||||||
|
inpaintRadius=config.cv2_radius,
|
||||||
|
flags=flag_map[config.cv2_flag],
|
||||||
|
)
|
||||||
return cur_res
|
return cur_res
|
||||||
|
@ -11,6 +11,7 @@ from lama_cleaner.schema import Config
|
|||||||
|
|
||||||
|
|
||||||
class PaintByExample(DiffusionInpaintModel):
|
class PaintByExample(DiffusionInpaintModel):
|
||||||
|
name = "paint_by_example"
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
|
@ -160,8 +160,10 @@ class SD(DiffusionInpaintModel):
|
|||||||
|
|
||||||
|
|
||||||
class SD15(SD):
|
class SD15(SD):
|
||||||
|
name = "sd1.5"
|
||||||
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
||||||
|
|
||||||
|
|
||||||
class SD2(SD):
|
class SD2(SD):
|
||||||
|
name = "sd2"
|
||||||
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
|
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
|
||||||
|
@ -203,6 +203,7 @@ def to_device(data, device):
|
|||||||
|
|
||||||
|
|
||||||
class ZITS(InpaintModel):
|
class ZITS(InpaintModel):
|
||||||
|
name = "zits"
|
||||||
min_size = 256
|
min_size = 256
|
||||||
pad_mod = 32
|
pad_mod = 32
|
||||||
pad_to_square = True
|
pad_to_square = True
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
from lama_cleaner.helper import switch_mps_device
|
||||||
from lama_cleaner.model.fcf import FcF
|
from lama_cleaner.model.fcf import FcF
|
||||||
from lama_cleaner.model.lama import LaMa
|
from lama_cleaner.model.lama import LaMa
|
||||||
from lama_cleaner.model.ldm import LDM
|
from lama_cleaner.model.ldm import LDM
|
||||||
@ -13,8 +14,19 @@ from lama_cleaner.model.zits import ZITS
|
|||||||
from lama_cleaner.model.opencv2 import OpenCV2
|
from lama_cleaner.model.opencv2 import OpenCV2
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
models = {
|
||||||
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
|
"lama": LaMa,
|
||||||
|
"ldm": LDM,
|
||||||
|
"zits": ZITS,
|
||||||
|
"mat": MAT,
|
||||||
|
"fcf": FcF,
|
||||||
|
"sd1.5": SD15,
|
||||||
|
"cv2": OpenCV2,
|
||||||
|
"manga": Manga,
|
||||||
|
"sd2": SD2,
|
||||||
|
"paint_by_example": PaintByExample,
|
||||||
|
"instruct_pix2pix": InstructPix2Pix,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
@ -44,13 +56,15 @@ class ModelManager:
|
|||||||
if new_name == self.name:
|
if new_name == self.name:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if (torch.cuda.memory_allocated() > 0):
|
if torch.cuda.memory_allocated() > 0:
|
||||||
# Clear current loaded model from memory
|
# Clear current loaded model from memory
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
del self.model
|
del self.model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
self.model = self.init_model(new_name, self.device, **self.kwargs)
|
self.model = self.init_model(
|
||||||
|
new_name, switch_mps_device(new_name, self.device), **self.kwargs
|
||||||
|
)
|
||||||
self.name = new_name
|
self.name = new_name
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -143,12 +143,6 @@ def parse_args():
|
|||||||
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
|
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.device == "mps":
|
|
||||||
if args.model not in MPS_SUPPORT_MODELS:
|
|
||||||
parser.error(
|
|
||||||
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.model_dir and args.model_dir is not None:
|
if args.model_dir and args.model_dir is not None:
|
||||||
if os.path.isfile(args.model_dir):
|
if os.path.isfile(args.model_dir):
|
||||||
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
||||||
|
@ -22,7 +22,8 @@ from lama_cleaner.const import (
|
|||||||
DEFAULT_MODEL,
|
DEFAULT_MODEL,
|
||||||
DEFAULT_DEVICE,
|
DEFAULT_DEVICE,
|
||||||
NO_GUI_AUTO_CLOSE_HELP,
|
NO_GUI_AUTO_CLOSE_HELP,
|
||||||
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS,
|
DEFAULT_MODEL_DIR,
|
||||||
|
MPS_SUPPORT_MODELS,
|
||||||
)
|
)
|
||||||
|
|
||||||
_config_file = None
|
_config_file = None
|
||||||
@ -115,7 +116,7 @@ def main(config_file: str):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
|
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
|
||||||
device = gr.Radio(
|
device = gr.Radio(
|
||||||
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
|
AVAILABLE_DEVICES, label="Device", value=init_config.device
|
||||||
)
|
)
|
||||||
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
|
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
|
||||||
no_gui_auto_close = gr.Checkbox(
|
no_gui_auto_close = gr.Checkbox(
|
||||||
|
Loading…
Reference in New Issue
Block a user