import PIL
import PIL.Image
import cv2
import torch
from loguru import logger

from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config


class PaintByExample(DiffusionInpaintModel):
    name = "Fantasy-Studio/Paint-by-Example"
    # Padding constraints applied by the inpaint-model base class: image sides
    # are padded to a multiple of pad_mod and to at least min_size pixels.
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from diffusers import DiffusionPipeline

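        # Use float16 only when half precision is not disabled and a CUDA
        # device is actually available; otherwise fall back to float32.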
        fp16 = not kwargs.get("no_half", False)
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
        model_kwargs = {}

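        # The safety (NSFW) checker is removed when it is explicitly disabled
        # or when CPU offload is requested.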
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Paint By Example Model NSFW checker")
            model_kwargs.update(
                dict(safety_checker=None, requires_safety_checker=False)
            )

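        # Load the Paint-by-Example pipeline from the Hugging Face Hub (or the
        # local cache) with the selected dtype and safety-checker settings.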
        self.model = DiffusionPipeline.from_pretrained(
            self.name, torch_dtype=torch_dtype, **model_kwargs
        )

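        # With CPU offload enabled, keep the image encoder on the GPU and let
        # the rest of the pipeline be moved onto the GPU submodule by submodule
        # as it is needed; otherwise move the whole pipeline to the device.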
        # TODO: gpu_id
        if kwargs.get("cpu_offload", False) and use_gpu:
            self.model.image_encoder = self.model.image_encoder.to(device)
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)

    def forward(self, image, mask, config: Config):
        """Input image and output image have the same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
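        # Run the Paint-by-Example pipeline: the example image guides what is
        # painted into the masked region of the input image.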
        output = self.model(
            image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            example_image=config.paint_by_example_example_image,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
            output_type="np.array",
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]

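        # The pipeline returns a float array in [0, 1]; convert it to uint8 and
        # from RGB to BGR, the colour order promised by the docstring.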
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
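
# A minimal usage sketch, kept as a comment so it does not affect the module.
# The constructor call below (device as first argument plus keyword options
# forwarded to init_model) is an assumption based on this file, not a verbatim
# lama_cleaner API reference; the Config fields are the ones referenced above.
#
#   import numpy as np
#   from PIL import Image
#
#   model = PaintByExample(torch.device("cuda"), disable_nsfw=True)
#   config = Config(
#       paint_by_example_example_image=Image.open("example.jpg"),
#       sd_steps=50,
#       sd_guidance_scale=7.5,
#       sd_seed=42,
#   )
#   image = np.array(Image.open("input.jpg"))           # [H, W, 3] RGB
#   mask = np.zeros((*image.shape[:2], 1), np.uint8)    # [H, W, 1], 255 = repaint
#   bgr_result = model.forward(image, mask, config)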