add migan

commit a5c241ac02 (parent 53aea791c5)

lama_cleaner/model/mi_gan.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+import os
+
+import cv2
+import torch
+
+from lama_cleaner.const import Config
+from lama_cleaner.helper import (
+    load_jit_model,
+    download_model,
+    get_cache_path_by_url,
+    boxes_from_mask,
+    resize_max_size,
+    norm_img,
+)
+from lama_cleaner.model.base import InpaintModel
+
+MIGAN_MODEL_URL = os.environ.get(
+    "MIGAN_MODEL_URL",
+    "/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt",
+)
+MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
+
+
+class MIGAN(InpaintModel):
+    name = "migan"
+    min_size = 512
+    pad_mod = 512
+    pad_to_square = True
+
+    def init_model(self, device, **kwargs):
+        self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()
+
+    @staticmethod
+    def download():
+        download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)
+
+    @staticmethod
+    def is_downloaded() -> bool:
+        return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
+
+    @torch.no_grad()
+    def __call__(self, image, mask, config: Config):
+        """
+        images: [H, W, C] RGB, not normalized
+        masks: [H, W]
+        return: BGR IMAGE
+        """
+        if image.shape[0] == 512 and image.shape[1] == 512:
+            return self._pad_forward(image, mask, config)
+
+        boxes = boxes_from_mask(mask)
+        crop_result = []
+        config.hd_strategy_crop_margin = 128
+        for box in boxes:
+            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
+            origin_size = crop_image.shape[:2]
+            resize_image = resize_max_size(crop_image, size_limit=512)
+            resize_mask = resize_max_size(crop_mask, size_limit=512)
+            inpaint_result = self._pad_forward(resize_image, resize_mask, config)
+
+            # only paste masked area result
+            inpaint_result = cv2.resize(
+                inpaint_result,
+                (origin_size[1], origin_size[0]),
+                interpolation=cv2.INTER_CUBIC,
+            )
+
+            original_pixel_indices = crop_mask < 127
+            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
+                original_pixel_indices
+            ]
+
+            crop_result.append((inpaint_result, crop_box))
+
+        inpaint_result = image[:, :, ::-1]
+        for crop_image, crop_box in crop_result:
+            x1, y1, x2, y2 = crop_box
+            inpaint_result[y1:y2, x1:x2, :] = crop_image
+
+        return inpaint_result
+
+    def forward(self, image, mask, config: Config):
+        """Input images and output images have same size
+        images: [H, W, C] RGB
+        masks: [H, W] mask area == 255
+        return: BGR IMAGE
+        """
+
+        image = norm_img(image)  # [0, 1]
+        image = image * 2 - 1  # [0, 1] -> [-1, 1]
+        mask = (mask > 120) * 255
+        mask = norm_img(mask)
+
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+
+        erased_img = image * (1 - mask)
+        input_image = torch.cat([0.5 - mask, erased_img], dim=1)
+
+        output = self.model(input_image)
+        output = (
+            (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
+            .round()
+            .clamp(0, 255)
+            .to(torch.uint8)
+        )
+        output = output[0].cpu().numpy()
+        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+        return cur_res
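A note on the new file above: __call__ overrides the usual HD strategy by cropping each mask box, inpainting at 512px, and pasting only the masked pixels back, while forward() packs the network input as a single 4-channel tensor (a mask channel shifted into [-0.5, 0.5], concatenated with the masked RGB image in [-1, 1]). Below is a minimal sketch of that packing with dummy data; it assumes norm_img converts an HWC uint8 array to a CHW float32 array in [0, 1], which matches how forward() uses it, and it does not load the real model.

import numpy as np
import torch

# Dummy 512x512 inputs standing in for the image/mask that forward() receives.
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # HWC RGB
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:256, 128:256] = 255  # area to inpaint

img = image.transpose(2, 0, 1).astype(np.float32) / 255.0  # norm_img: CHW, [0, 1]
img = img * 2 - 1                                          # [0, 1] -> [-1, 1]
m = (mask > 120).astype(np.float32)[None, ...]             # 1xHxW in {0, 1}

img_t = torch.from_numpy(img).unsqueeze(0)   # 1x3xHxW
mask_t = torch.from_numpy(m).unsqueeze(0)    # 1x1xHxW

erased = img_t * (1 - mask_t)                           # zero out the hole
model_input = torch.cat([0.5 - mask_t, erased], dim=1)  # 1x4xHxW, as in forward()
assert model_input.shape == (1, 4, 512, 512)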
lama_cleaner/model_manager.py
@@ -3,7 +3,11 @@ import gc

 from loguru import logger

-from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU, MODELS_SUPPORT_LCM_LORA
+from lama_cleaner.const import (
+    SD15_MODELS,
+    MODELS_SUPPORT_FREEU,
+    MODELS_SUPPORT_LCM_LORA,
+)
 from lama_cleaner.helper import switch_mps_device
 from lama_cleaner.model.controlnet import ControlNet
 from lama_cleaner.model.fcf import FcF
@@ -12,6 +16,7 @@ from lama_cleaner.model.lama import LaMa
 from lama_cleaner.model.ldm import LDM
 from lama_cleaner.model.manga import Manga
 from lama_cleaner.model.mat import MAT
+from lama_cleaner.model.mi_gan import MIGAN
 from lama_cleaner.model.paint_by_example import PaintByExample
 from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
 from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
@@ -37,6 +42,7 @@ models = {
     "instruct_pix2pix": InstructPix2Pix,
     Kandinsky22.name: Kandinsky22,
     SDXL.name: SDXL,
+    MIGAN.name: MIGAN,
 }


@@ -146,4 +152,3 @@ class ModelManager:
             self.model.model.load_lora_weights(self.model.lcm_lora_id)
         else:
             self.model.model.disable_lora()
-
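How the registry entry above gets consumed, as a hedged sketch: ModelManager appears to resolve the model class by name from the models dict and instantiate it with a device. The lookup-and-construct shape below is an assumption based on this diff, not a verified ModelManager internal.

import torch
from lama_cleaner.model_manager import models

cls = models["migan"]                    # resolved via MIGAN.name == "migan"
if not cls.is_downloaded():              # staticmethod added in mi_gan.py
    cls.download()
model = cls(device=torch.device("cpu"))  # assumed InpaintModel constructor shape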
lama_cleaner/tests/test_model.py
@@ -170,7 +170,7 @@ def test_cv2(strategy, cv2_flag, cv2_radius):
     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
+        f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
@@ -188,7 +188,23 @@ def test_manga(strategy):
     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}.png",
+        f"manga_{strategy.capitalize()}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+def test_mi_gan(strategy):
+    model = ModelManager(
+        name="migan",
+        device=torch.device(device),
+    )
+    cfg = get_config(strategy)
+    assert_equal(
+        model,
+        cfg,
+        f"migan_{strategy.capitalize()}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
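Assuming the test hunks above belong to lama_cleaner/tests/test_model.py as shown, the new case can be run on its own with pytest's -k filter:

    pytest lama_cleaner/tests/test_model.py -k test_mi_gan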