diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index e14736e..549b307 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -1,4 +1,5 @@ import torch +import gc from lama_cleaner.model.fcf import FcF from lama_cleaner.model.lama import LaMa @@ -42,6 +43,12 @@ class ModelManager: if new_name == self.name: return try: + if (torch.cuda.memory_allocated() > 0): + # Clear current loaded model from memory + torch.cuda.empty_cache() + gc.collect() + del self.model + self.model = self.init_model(new_name, self.device, **self.kwargs) self.name = new_name except NotImplementedError as e: