IOPaint/lama_cleaner/model_manager.py

43 lines
1.2 KiB
Python
Raw Normal View History

2022-04-15 18:11:51 +02:00
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config
class ModelManager:
    """Hold one inpainting model and allow swapping it by name.

    Supported model names are the ``LAMA`` and ``LDM`` class constants. An
    instance is callable and forwards ``(image, mask, config)`` to the model
    that is currently loaded.
    """

    # NOTE: inside this class body the name ``LDM`` is the string 'ldm' and
    # shadows the imported ``LDM`` model class. Method bodies do not see class
    # scope, so a bare ``LDM`` inside a method still resolves to the imported
    # model class, while the constant must be spelled ``self.LDM``. Do not
    # build a ``{LAMA: LaMa, LDM: LDM}`` table at class level: there ``LDM``
    # would map to the string, not the class.
    LAMA = 'lama'
    LDM = 'ldm'

    def __init__(self, name: str, device):
        """Load the model called *name* on *device*.

        Args:
            name: One of ``ModelManager.LAMA`` or ``ModelManager.LDM``.
            device: Passed unchanged to the model constructor (presumably a
                torch device or device string -- TODO confirm).

        Raises:
            NotImplementedError: If *name* is not a supported model.
        """
        self.name = name
        self.device = device
        self.model = self.init_model(name, device)

    def _model_class(self, name: str):
        """Return the model class registered under *name*.

        Raises:
            NotImplementedError: If *name* is not a supported model.
        """
        # Built inside a method on purpose; see the shadowing note on the class.
        models = {self.LAMA: LaMa, self.LDM: LDM}
        model_cls = models.get(name)
        if model_cls is None:
            raise NotImplementedError(f"Not supported model: {name}")
        return model_cls

    def init_model(self, name: str, device):
        """Construct and return the model called *name* on *device*.

        This does not modify ``self.model`` or ``self.name``; the caller
        assigns the result.

        Raises:
            NotImplementedError: If *name* is not a supported model.
        """
        return self._model_class(name)(device)

    def is_downloaded(self, name: str) -> bool:
        """Return the model class's own ``is_downloaded()`` result for *name*.

        Raises:
            NotImplementedError: If *name* is not a supported model.
        """
        return self._model_class(name).is_downloaded()

    def __call__(self, image, mask, config: Config):
        """Run the currently loaded model on *image* and *mask* with *config*."""
        return self.model(image, mask, config)

    def switch(self, new_name: str):
        """Replace the loaded model with the one called *new_name*.

        Does nothing when *new_name* is already the current model. The new
        model is fully constructed before anything is reassigned. If
        construction raises (e.g. ``NotImplementedError`` for an unknown
        name), the manager keeps its previous model and name, and the
        exception propagates to the caller.
        """
        if new_name == self.name:
            return
        # NOTE(review): the old model stays referenced until this assignment
        # completes, so both models are alive during construction.
        self.model = self.init_model(new_name, self.device)
        self.name = new_name