IOPaint/lama_cleaner/model/paint_by_example.py

76 lines
2.9 KiB
Python
Raw Normal View History

2022-12-10 15:06:15 +01:00
import PIL
import PIL.Image
import cv2
import torch
2023-01-05 15:07:39 +01:00
from loguru import logger
2023-01-27 13:59:22 +01:00
from lama_cleaner.model.base import DiffusionInpaintModel
2022-12-10 15:06:15 +01:00
from lama_cleaner.schema import Config
2023-01-27 13:59:22 +01:00
class PaintByExample(DiffusionInpaintModel):
    """Inpainting guided by an example image, backed by the diffusers
    ``Fantasy-Studio/Paint-by-Example`` pipeline.

    The masked area of the input image is regenerated, conditioned on
    ``config.paint_by_example_example_image``.
    """

    name = "Fantasy-Studio/Paint-by-Example"
    # NOTE(review): pad_mod / min_size are read by the base class (not visible
    # here); presumably inputs are padded to a multiple of 8 and at least 512px.
    pad_mod = 8
    min_size = 512

    # Negative prompt passed to every pipeline call.
    _NEGATIVE_PROMPT = "out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"

    def init_model(self, device: torch.device, **kwargs):
        """Load the diffusers pipeline and place it on ``device``.

        Args:
            device: target device, a ``torch.device`` or a device string such
                as ``"cpu"``, ``"cuda"`` or ``"cuda:1"``.
            **kwargs: optional boolean flags, all defaulting to False:
                ``no_half`` - keep float32 weights even on GPU;
                ``disable_nsfw`` - load without the safety checker;
                ``cpu_offload`` - enable sequential CPU offload on GPU (this
                also loads without the safety checker);
                ``enable_xformers`` - use xformers memory-efficient attention.
        """
        from diffusers import DiffusionPipeline

        device = torch.device(device)
        fp16 = not kwargs.get("no_half", False)
        # Compare the device *type*: torch.device("cuda:0") != torch.device("cuda"),
        # so an equality check would treat any indexed GPU as "not a GPU".
        use_gpu = device.type == "cuda" and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        model_kwargs = {}
        if kwargs.get("disable_nsfw", False) or kwargs.get("cpu_offload", False):
            logger.info("Disable Paint By Example Model NSFW checker")
            model_kwargs.update(
                dict(safety_checker=None, requires_safety_checker=False)
            )

        self.model = DiffusionPipeline.from_pretrained(
            "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
        )

        self.model.enable_attention_slicing()
        if kwargs.get("enable_xformers", False):
            self.model.enable_xformers_memory_efficient_attention()

        if kwargs.get("cpu_offload", False) and use_gpu:
            # The image encoder is moved to the target GPU explicitly before
            # sequential CPU offload is enabled for the pipeline.
            self.model.image_encoder = self.model.image_encoder.to(device)
            # A bare "cuda" device has no index; fall back to GPU 0 as before.
            gpu_id = device.index if device.index is not None else 0
            self.model.enable_sequential_cpu_offload(gpu_id=gpu_id)
        else:
            self.model = self.model.to(device)

    @staticmethod
    def download():
        """Fetch the pipeline weights (``from_pretrained`` caches them)."""
        from diffusers import DiffusionPipeline

        DiffusionPipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")

    def forward(self, image, mask, config: Config):
        """Repaint the masked area of ``image`` guided by the example image.

        Input image and output image have the same size.

        Args:
            image: [H, W, C] RGB array.
            mask: [H, W, 1] (or [H, W]) array; 255 means area to repaint.
            config: supplies ``paint_by_example_example_image``, ``sd_steps``,
                ``sd_guidance_scale`` and ``sd_seed``.

        Returns:
            uint8 BGR image.
        """
        # Accept both [H, W, 1] and plain [H, W] masks.
        if mask.ndim == 3:
            mask = mask[:, :, -1]

        # A private generator makes the run reproducible from sd_seed without
        # reseeding torch's process-wide RNGs (which torch.manual_seed does).
        generator = torch.Generator(device="cpu").manual_seed(config.sd_seed)

        output = self.model(
            image=PIL.Image.fromarray(image),
            mask_image=PIL.Image.fromarray(mask, mode="L"),
            example_image=config.paint_by_example_example_image,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            negative_prompt=self._NEGATIVE_PROMPT,
            # "np" is the documented value; other strings are deprecated and
            # only coerced to "np" with a warning.
            output_type="np",
            generator=generator,
        ).images[0]
        # The pipeline returns floats in [0, 1]; convert to 8-bit BGR.
        output = (output * 255).round().astype("uint8")
        return cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
        return True