IOPaint/iopaint/model/paint_by_example.py

import PIL
import PIL.Image
import cv2
import torch
from loguru import logger
from iopaint.helper import decode_base64_to_image
from .base import DiffusionInpaintModel
from iopaint.schema import InpaintRequest


class PaintByExample(DiffusionInpaintModel):
    name = "Fantasy-Studio/Paint-by-Example"
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from diffusers import DiffusionPipeline

        fp16 = not kwargs.get("no_half", False)
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        # Use half precision only on CUDA and only when it has not been disabled
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        model_kwargs = {}
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Paint By Example Model NSFW checker")
            model_kwargs.update(
                dict(safety_checker=None, requires_safety_checker=False)
            )

        self.model = DiffusionPipeline.from_pretrained(
            self.name, torch_dtype=torch_dtype, **model_kwargs
        )
        if torch.backends.mps.is_available():
            # Attention slicing reduces peak memory usage on Apple Silicon (MPS)
            self.model.enable_attention_slicing()

        # TODO: gpu_id
        if kwargs.get("cpu_offload", False) and use_gpu:
            # Keep the image encoder on the GPU; offload the other components sequentially
            self.model.image_encoder = self.model.image_encoder.to(device)
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have the same size.

        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means the area to repaint
        return: BGR IMAGE
        """
        if config.paint_by_example_example_image is None:
            raise ValueError("paint_by_example_example_image is required")
        example_image, _, _ = decode_base64_to_image(
            config.paint_by_example_example_image
        )
        output = self.model(
            image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            example_image=PIL.Image.fromarray(example_image),
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
            output_type="np.array",
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]
        # Pipeline output is float RGB in [0, 1]; convert to uint8 BGR for OpenCV
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
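
For context, the call this wrapper issues can also be reproduced against the diffusers pipeline directly. The following is a minimal standalone sketch, not part of the module above; the file names scene.png, mask.png, and reference.png and the CUDA device are assumptions for illustration.

import PIL.Image
import torch
from diffusers import DiffusionPipeline

# Load the same checkpoint the class above wraps (standalone sketch, CUDA assumed)
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
).to("cuda")

# Hypothetical input files; inputs are resized to 512x512, white mask pixels are repainted
init_image = PIL.Image.open("scene.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("mask.png").convert("L").resize((512, 512))
example_image = PIL.Image.open("reference.png").convert("RGB").resize((512, 512))

result = pipe(
    image=init_image,
    mask_image=mask_image,
    example_image=example_image,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.manual_seed(42),
).images[0]  # default output_type is "pil"
result.save("paint_by_example_out.png")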