IOPaint/iopaint/model/paint_by_example.py

69 lines
2.7 KiB
Python
Raw Normal View History

2022-12-10 15:06:15 +01:00
import PIL
import PIL.Image
import cv2
import torch
2023-01-05 15:07:39 +01:00
from loguru import logger
2024-01-05 08:19:23 +01:00
from iopaint.helper import decode_base64_to_image
2024-01-05 09:40:06 +01:00
from .base import DiffusionInpaintModel
2024-01-05 08:19:23 +01:00
from iopaint.schema import InpaintRequest
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
2022-12-10 15:06:15 +01:00
2023-01-27 13:59:22 +01:00
class PaintByExample(DiffusionInpaintModel):
    """Exemplar-guided inpainting using the Paint-by-Example diffusion pipeline.

    The masked area of the input image is repainted using a user-supplied
    example image (``config.paint_by_example_example_image``) as guidance.
    """

    name = "Fantasy-Studio/Paint-by-Example"
    pad_mod = 8
    min_size = 512

    # Fixed negative prompt passed to every pipeline call.
    _NEGATIVE_PROMPT = (
        "out of frame, lowres, error, cropped, worst quality, low quality, "
        "jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, "
        "mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, "
        "extra limbs, disfigured, gross proportions, malformed limbs, watermark, "
        "signature"
    )

    def init_model(self, device: torch.device, **kwargs):
        """Load the Paint-by-Example pipeline and place it on ``device``.

        Options read from ``kwargs`` (all optional, default ``False``):
            no_half: forwarded to ``get_torch_dtype`` to choose the weight dtype.
            disable_nsfw: load the pipeline without its safety checker.
            cpu_offload: when running on GPU, use sequential CPU offload
                instead of moving the whole pipeline to ``device``; this also
                disables the safety checker.
            low_mem: forwarded to ``enable_low_mem``.

        The full ``kwargs`` mapping is also passed to ``is_local_files_only``.
        """
        from diffusers import DiffusionPipeline

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        cpu_offload = kwargs.get("cpu_offload", False)
        model_kwargs = {
            "local_files_only": is_local_files_only(**kwargs),
        }

        # Use .get() like the other options so that omitting "disable_nsfw"
        # does not raise KeyError.
        # NOTE(review): the safety checker is also dropped under cpu_offload;
        # presumably it does not work with sequential offload — confirm.
        if kwargs.get("disable_nsfw", False) or cpu_offload:
            logger.info("Disable Paint By Example Model NSFW checker")
            model_kwargs.update(
                dict(safety_checker=None, requires_safety_checker=False)
            )

        self.model = DiffusionPipeline.from_pretrained(
            self.name, torch_dtype=torch_dtype, **model_kwargs
        )
        enable_low_mem(self.model, kwargs.get("low_mem", False))

        # TODO: gpu_id
        if cpu_offload and use_gpu:
            # The image encoder is moved to the device explicitly; the rest of
            # the pipeline is left to enable_sequential_cpu_offload.
            # NOTE(review): reason for special-casing the encoder not visible
            # here — confirm.
            self.model.image_encoder = self.model.image_encoder.to(device)
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)

    def forward(self, image, mask, config: InpaintRequest):
        """Repaint the masked area of ``image`` guided by the example image.

        Input image and output image have the same size.

        Args:
            image: [H, W, C] RGB image.
            mask: [H, W, 1] mask; 255 means area to repaint. Only the last
                channel is used.
            config: request carrying the base64-encoded
                ``paint_by_example_example_image`` plus ``sd_steps``,
                ``sd_guidance_scale`` and ``sd_seed``, which are forwarded
                to the pipeline.

        Returns:
            BGR uint8 image.

        Raises:
            ValueError: if ``config.paint_by_example_example_image`` is None.
        """
        if config.paint_by_example_example_image is None:
            raise ValueError("paint_by_example_example_image is required")
        example_image, _, _ = decode_base64_to_image(
            config.paint_by_example_example_image
        )

        output = self.model(
            image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            example_image=PIL.Image.fromarray(example_image),
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            negative_prompt=self._NEGATIVE_PROMPT,
            # "np" is the supported value; the previous "np.array" only worked
            # through a deprecated fallback in diffusers' image post-processing.
            output_type="np",
            # torch.manual_seed reseeds the global default generator and
            # returns it; the pipeline draws its noise from that generator.
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]

        # The numpy output is scaled by 255 here, i.e. it is expected to be
        # floats in [0, 1]; convert to uint8 and from RGB to BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output