From a5c241ac02dcf429be7195000d7e4e9d0816bfbb Mon Sep 17 00:00:00 2001
From: Qing
Date: Mon, 20 Nov 2023 13:05:28 +0800
Subject: [PATCH] add migan

---
 lama_cleaner/model/mi_gan.py     | 109 +++++++++++++++++++++++++++++++
 lama_cleaner/model_manager.py    |   9 ++-
 lama_cleaner/tests/test_model.py |  20 +++++-
 3 files changed, 134 insertions(+), 4 deletions(-)
 create mode 100644 lama_cleaner/model/mi_gan.py

diff --git a/lama_cleaner/model/mi_gan.py b/lama_cleaner/model/mi_gan.py
new file mode 100644
index 0000000..1b2ba1d
--- /dev/null
+++ b/lama_cleaner/model/mi_gan.py
@@ -0,0 +1,109 @@
+import os
+
+import cv2
+import torch
+
+from lama_cleaner.const import Config
+from lama_cleaner.helper import (
+    load_jit_model,
+    download_model,
+    get_cache_path_by_url,
+    boxes_from_mask,
+    resize_max_size,
+    norm_img,
+)
+from lama_cleaner.model.base import InpaintModel
+
+MIGAN_MODEL_URL = os.environ.get(
+    "MIGAN_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
+)
+MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
+
+
+class MIGAN(InpaintModel):
+    name = "migan"
+    min_size = 512
+    pad_mod = 512
+    pad_to_square = True
+
+    def init_model(self, device, **kwargs):
+        self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()
+
+    @staticmethod
+    def download():
+        download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)
+
+    @staticmethod
+    def is_downloaded() -> bool:
+        return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
+
+    @torch.no_grad()
+    def __call__(self, image, mask, config: Config):
+        """
+        image: [H, W, C] RGB, not normalized
+        mask: [H, W]
+        return: BGR IMAGE
+        """
+        if image.shape[0] == 512 and image.shape[1] == 512:
+            return self._pad_forward(image, mask, config)
+
+        boxes = boxes_from_mask(mask)
+        crop_result = []
+        config.hd_strategy_crop_margin = 128  # MI-GAN works on 512x512 crops
+        for box in boxes:
+            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
+            origin_size = crop_image.shape[:2]
+            resize_image = resize_max_size(crop_image, size_limit=512)
+            resize_mask = resize_max_size(crop_mask, size_limit=512)
+            inpaint_result = self._pad_forward(resize_image, resize_mask, config)
+
+            # resize back, then only paste the masked-area result
+            inpaint_result = cv2.resize(
+                inpaint_result,
+                (origin_size[1], origin_size[0]),
+                interpolation=cv2.INTER_CUBIC,
+            )
+
+            original_pixel_indices = crop_mask < 127
+            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
+                original_pixel_indices
+            ]
+
+            crop_result.append((inpaint_result, crop_box))
+
+        inpaint_result = image[:, :, ::-1]
+        for crop_image, crop_box in crop_result:
+            x1, y1, x2, y2 = crop_box
+            inpaint_result[y1:y2, x1:x2, :] = crop_image
+
+        return inpaint_result
+
+    def forward(self, image, mask, config: Config):
+        """Input image and output image have the same size
+        image: [H, W, C] RGB
+        mask: [H, W] mask area == 255
+        return: BGR IMAGE
+        """
+
+        image = norm_img(image)  # [0, 1]
+        image = image * 2 - 1  # [0, 1] -> [-1, 1]
+        mask = (mask > 120) * 255
+        mask = norm_img(mask)
+
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+
+        erased_img = image * (1 - mask)
+        input_image = torch.cat([0.5 - mask, erased_img], dim=1)  # 4-ch: mask + erased RGB
+
+        output = self.model(input_image)
+        output = (
+            (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
+            .round()
+            .clamp(0, 255)
+            .to(torch.uint8)
+        )
+        output = output[0].cpu().numpy()
+        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+        return cur_res
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
index 43d9ced..8461c95 100644
--- a/lama_cleaner/model_manager.py
+++ b/lama_cleaner/model_manager.py
@@ -3,7 +3,11 @@ import gc
 
 from loguru import logger
 
-from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU, MODELS_SUPPORT_LCM_LORA
+from lama_cleaner.const import (
+    SD15_MODELS,
+    MODELS_SUPPORT_FREEU,
+    MODELS_SUPPORT_LCM_LORA,
+)
 from lama_cleaner.helper import switch_mps_device
 from lama_cleaner.model.controlnet import ControlNet
 from lama_cleaner.model.fcf import FcF
@@ -12,6 +16,7 @@ from lama_cleaner.model.lama import LaMa
 from lama_cleaner.model.ldm import LDM
 from lama_cleaner.model.manga import Manga
 from lama_cleaner.model.mat import MAT
+from lama_cleaner.model.mi_gan import MIGAN
 from lama_cleaner.model.paint_by_example import PaintByExample
 from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
 from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
@@ -37,6 +42,7 @@ models = {
     "instruct_pix2pix": InstructPix2Pix,
     Kandinsky22.name: Kandinsky22,
     SDXL.name: SDXL,
+    MIGAN.name: MIGAN,
 }
 
 
@@ -146,4 +152,3 @@ class ModelManager:
             self.model.model.load_lora_weights(self.model.lcm_lora_id)
         else:
             self.model.model.disable_lora()
-
diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py
index be253b5..f6931ce 100644
--- a/lama_cleaner/tests/test_model.py
+++ b/lama_cleaner/tests/test_model.py
@@ -170,7 +170,7 @@ def test_cv2(strategy, cv2_flag, cv2_radius):
     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
+        f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
@@ -188,7 +188,23 @@ def test_manga(strategy):
     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}.png",
+        f"manga_{strategy.capitalize()}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
+
+
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+def test_mi_gan(strategy):
+    model = ModelManager(
+        name="migan",
+        device=torch.device(device),
+    )
+    cfg = get_config(strategy)
+    assert_equal(
+        model,
+        cfg,
+        f"migan_{strategy.capitalize()}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
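Usage note (not part of the change set): below is a minimal sketch of exercising the new backend end to end, mirroring how the tests drive models. The HDStrategy import path, the Config fields, and the image/mask file names are assumptions for illustration, not confirmed by this patch.

    # Minimal sketch: run the new MIGAN backend on an RGB image + mask pair.
    import cv2
    import torch

    from lama_cleaner.const import Config  # same location this patch imports from
    from lama_cleaner.model_manager import ModelManager
    from lama_cleaner.schema import HDStrategy  # assumed import path

    model = ModelManager(name="migan", device=torch.device("cpu"))

    # Placeholder inputs: model expects RGB [H, W, C] and a [H, W] mask
    # where 255 marks the area to inpaint.
    image = cv2.cvtColor(cv2.imread("image.png"), cv2.COLOR_BGR2RGB)
    mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)

    config = Config(
        ldm_steps=25,  # required by Config, unused by MIGAN
        hd_strategy=HDStrategy.ORIGINAL,
        hd_strategy_crop_margin=128,
        hd_strategy_crop_trigger_size=512,
        hd_strategy_resize_limit=512,
    )
    result_bgr = model(image, mask, config)  # MIGAN returns a BGR image
    cv2.imwrite("result.png", result_bgr)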