diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts
index bb7c292..73aae70 100644
--- a/lama_cleaner/app/src/adapters/inpainting.ts
+++ b/lama_cleaner/app/src/adapters/inpainting.ts
@@ -12,7 +12,8 @@ export default async function inpaint(
   sizeLimit?: string,
   seed?: number,
   maskBase64?: string,
-  customMask?: File
+  customMask?: File,
+  paintByExampleImage?: File
 ) {
   // 1080, 2000, Original
   const fd = new FormData()
@@ -48,6 +49,7 @@ export default async function inpaint(
   fd.append('croperHeight', croperRect.height.toString())
   fd.append('croperWidth', croperRect.width.toString())
   fd.append('useCroper', settings.showCroper ? 'true' : 'false')
+  fd.append('sdMaskBlur', settings.sdMaskBlur.toString())

   fd.append('sdStrength', settings.sdStrength.toString())
   fd.append('sdSteps', settings.sdSteps.toString())
@@ -59,6 +61,26 @@ export default async function inpaint(
   fd.append('cv2Radius', settings.cv2Radius.toString())
   fd.append('cv2Flag', settings.cv2Flag.toString())

+  fd.append('paintByExampleSteps', settings.paintByExampleSteps.toString())
+  fd.append(
+    'paintByExampleGuidanceScale',
+    settings.paintByExampleGuidanceScale.toString()
+  )
+  fd.append('paintByExampleSeed', seed ? seed.toString() : '-1')
+  fd.append(
+    'paintByExampleMaskBlur',
+    settings.paintByExampleMaskBlur.toString()
+  )
+  fd.append(
+    'paintByExampleMatchHistograms',
+    settings.paintByExampleMatchHistograms ? 'true' : 'false'
+  )
+  // TODO: resize the image's shortest edge to 224 before sending it to the
+  // backend, to save network time?
+  // https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
+  if (paintByExampleImage) {
+    fd.append('paintByExampleImage', paintByExampleImage)
+  }
+
   if (sizeLimit === undefined) {
     fd.append('sizeLimit', '1080')
   } else {
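Note on the request shape: everything above lands in a single multipart/form-data POST. The sketch below replays it from Python with `requests`; the `/inpaint` path, the port-8080 default, and the `X-Seed` response header are assumptions based on lama-cleaner's defaults rather than part of this diff, and a real request must also carry the other model fields (`ldmSteps`, `sdSteps`, ...) that `server.py` reads unconditionally.

    # Hedged sketch: replay the frontend's multipart request with `requests`.
    # URL, port, and the X-Seed header are assumptions; the field names are
    # taken verbatim from the diff above.
    import requests

    def paint_by_example(image_path, mask_path, example_path,
                         url="http://127.0.0.1:8080/inpaint"):
        form = {
            "paintByExampleSteps": "50",
            "paintByExampleGuidanceScale": "7.5",
            "paintByExampleSeed": "-1",  # -1 asks the server to roll a random seed
            "paintByExampleMaskBlur": "5",
            "paintByExampleMatchHistograms": "false",
            "useCroper": "false",
            "sizeLimit": "1080",
            # ...plus the ldm*/zits*/sd*/cv2* fields server.py always reads
        }
        files = {
            "image": open(image_path, "rb"),
            "mask": open(mask_path, "rb"),
            "paintByExampleImage": open(example_path, "rb"),
        }
        resp = requests.post(url, data=form, files=files)
        resp.raise_for_status()
        # the seed actually used is echoed back so the UI can display it
        return resp.content, resp.headers.get("X-Seed")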
diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx
index e37c3cd..54b6907 100644
--- a/lama_cleaner/app/src/components/Editor/Editor.tsx
+++ b/lama_cleaner/app/src/components/Editor/Editor.tsx
@@ -39,6 +39,7 @@ import {
   isInpaintingState,
   isInteractiveSegRunningState,
   isInteractiveSegState,
+  isPaintByExampleState,
   isSDState,
   negativePropmtState,
   propmtState,
@@ -53,6 +54,7 @@ import emitter, {
   EVENT_PROMPT,
   EVENT_CUSTOM_MASK,
   CustomMaskEventData,
+  EVENT_PAINT_BY_EXAMPLE,
 } from '../../event'
 import FileSelect from '../FileSelect/FileSelect'
 import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
@@ -108,6 +110,7 @@ export default function Editor() {
   const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
   const runMannually = useRecoilValue(runManuallyState)
   const isSD = useRecoilValue(isSDState)
+  const isPaintByExample = useRecoilValue(isPaintByExampleState)
   const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
     isInteractiveSegState
   )
@@ -262,8 +265,11 @@ export default function Editor() {
     async (
       useLastLineGroup?: boolean,
       customMask?: File,
-      maskImage?: HTMLImageElement | null
+      maskImage?: HTMLImageElement | null,
+      paintByExampleImage?: File
     ) => {
+      // customMask: a mask image uploaded by the user
+      // maskImage: a mask produced by interactive segmentation
      if (file === undefined) {
        return
      }
@@ -328,9 +334,6 @@ export default function Editor() {
        }
      }

-      const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
-
-      console.log({ useCustomMask })
      try {
        const res = await inpaint(
          targetFile,
@@ -339,15 +342,16 @@ export default function Editor() {
          promptVal,
          negativePromptVal,
          sizeLimit.toString(),
-          sdSeed,
+          seedVal,
          useCustomMask ? undefined : maskCanvas.toDataURL(),
-          useCustomMask ? customMask : undefined
+          useCustomMask ? customMask : undefined,
+          paintByExampleImage
        )
        if (!res) {
          throw new Error('Something went wrong on server side.')
        }
        const { blob, seed } = res
-        if (seed && !settings.sdSeedFixed) {
+        if (seed) {
          setSeed(parseInt(seed, 10))
        }
        const newRender = new Image()
@@ -395,6 +399,7 @@ export default function Editor() {
      drawOnCurrentRender,
      hadDrawSomething,
      drawLinesOnMask,
+      seedVal,
    ]
  )
@@ -439,6 +444,31 @@ export default function Editor() {
    }
  }, [runInpainting])

+  useEffect(() => {
+    emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
+      if (hadDrawSomething() || interactiveSegMask) {
+        runInpainting(false, undefined, interactiveSegMask, data.image)
+      } else if (lastLineGroup.length !== 0) {
+        // generate with the hand-drawn mask from the previous run
+        runInpainting(true, undefined, prevInteractiveSegMask, data.image)
+      } else if (prevInteractiveSegMask) {
+        // generate with the interactive-seg mask from the previous run
+        runInpainting(false, undefined, prevInteractiveSegMask, data.image)
+      } else {
+        setToastState({
+          open: true,
+          desc: 'Please draw mask on picture',
+          state: 'error',
+          duration: 1500,
+        })
+      }
+    })
+
+    return () => {
+      emitter.off(EVENT_PAINT_BY_EXAMPLE)
+    }
+  }, [runInpainting])
+
  const hadRunInpainting = () => {
    return renders.length !== 0
  }
@@ -793,7 +823,11 @@ export default function Editor() {
      return
    }

-    if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) {
+    if (
+      (isSD || isPaintByExample) &&
+      settings.showCroper &&
+      isOutsideCroper(mouseXY(ev))
+    ) {
      return
    }
@@ -876,7 +910,12 @@ export default function Editor() {
    return false
  }

-  useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
+  useKey(undoPredicate, undo, undefined, [
+    undoStroke,
+    undoRender,
+    runMannually,
+    curLineGroup,
+  ])

  const disableUndo = () => {
    if (isInteractiveSeg) {
@@ -955,7 +994,12 @@ export default function Editor() {
    return false
  }

-  useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
+  useKey(redoPredicate, redo, undefined, [
+    redoStroke,
+    redoRender,
+    runMannually,
+    redoCurLines,
+  ])

  const disableRedo = () => {
    if (isInteractiveSeg) {
@@ -1295,7 +1339,7 @@ export default function Editor() {

-      {isSD && settings.showCroper ? (
+      {(isSD || isPaintByExample) && settings.showCroper ? (

-      {isSD || file === undefined ? (
+      {isSD || isPaintByExample || file === undefined ? (
        <>
      ) : (

-      {settings.runInpaintingManually && !isSD && (
+      {settings.runInpaintingManually && !isSD && !isPaintByExample && (
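The EVENT_PAINT_BY_EXAMPLE listener above picks a mask in priority order: a mask currently drawn (or a live interactive-seg mask) wins; otherwise the previous hand-drawn line group is replayed; otherwise the previous interactive-seg mask is reused; with none available it shows an error toast. The example image itself always arrives through the event payload.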
diff --git a/lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx b/lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
new file mode 100644
--- /dev/null
+++ b/lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
+    )
+  }
+
+  return (
+    <div className="side-panel">
+      <PopoverPrimitive.Root open={open}>
+        <PopoverPrimitive.Trigger
+          className="btn-primary side-panel-trigger"
+          onClick={() => toggleOpen()}
+        >
+          Configurations
+        </PopoverPrimitive.Trigger>
+        <PopoverPrimitive.Portal>
+          <PopoverPrimitive.Content className="side-panel-content">
+            <SettingBlock
+              title="Croper"
+              input={
+                <Switch
+                  checked={setting.showCroper}
+                  onCheckedChange={value => {
+                    setSettingState(old => {
+                      return { ...old, showCroper: value }
+                    })
+                  }}
+                >
+                  <SwitchThumb />
+                </Switch>
+              }
+            />
+            <NumberInputSetting
+              title="Steps"
+              value={`${setting.paintByExampleSteps}`}
+              onValue={value => {
+                const val = value.length === 0 ? 0 : parseInt(value, 10)
+                setSettingState(old => {
+                  return { ...old, paintByExampleSteps: val }
+                })
+              }}
+            />
+            <NumberInputSetting
+              title="Guidance Scale"
+              value={`${setting.paintByExampleGuidanceScale}`}
+              onValue={value => {
+                const val = value.length === 0 ? 0 : parseFloat(value)
+                setSettingState(old => {
+                  return { ...old, paintByExampleGuidanceScale: val }
+                })
+              }}
+            />
+            <NumberInputSetting
+              title="Mask Blur"
+              value={`${setting.paintByExampleMaskBlur}`}
+              onValue={value => {
+                const val = value.length === 0 ? 0 : parseInt(value, 10)
+                setSettingState(old => {
+                  return { ...old, paintByExampleMaskBlur: val }
+                })
+              }}
+            />
+            <SettingBlock
+              title="Match Histograms"
+              input={
+                <Switch
+                  checked={setting.paintByExampleMatchHistograms}
+                  onCheckedChange={value => {
+                    setSettingState(old => {
+                      return { ...old, paintByExampleMatchHistograms: value }
+                    })
+                  }}
+                >
+                  <SwitchThumb />
+                </Switch>
+              }
+            />
+            <SettingBlock
+              title="Seed"
+              input={
+                <div className="flex-center">
+                  {/* this value is refreshed from the server after every run */}
+                  <NumberInput
+                    value={`${setting.paintByExampleSeed}`}
+                    onValue={value => {
+                      const val = value.length === 0 ? 0 : parseInt(value, 10)
+                      setSettingState(old => {
+                        return { ...old, paintByExampleSeed: val }
+                      })
+                    }}
+                  />
+                  <Switch
+                    checked={setting.paintByExampleSeedFixed}
+                    onCheckedChange={value => {
+                      setSettingState(old => {
+                        return { ...old, paintByExampleSeedFixed: value }
+                      })
+                    }}
+                    style={{ marginLeft: '8px' }}
+                  >
+                    <SwitchThumb />
+                  </Switch>
+                </div>
+              }
+            />
+
+            <SettingBlock
+              title="Example Image"
+              input={
+                <div className="paint-by-example-image">
+                  {paintByExampleImage ? (
+                    <img
+                      src={URL.createObjectURL(paintByExampleImage)}
+                      alt="example"
+                    />
+                  ) : (
+                    <></>
+                  )}
+                </div>
+              }
+            />
+            <Button
+              disabled={!paintByExampleImage}
+              onClick={() => {
+                if (paintByExampleImage) {
+                  emitter.emit(EVENT_PAINT_BY_EXAMPLE, {
+                    image: paintByExampleImage,
+                  })
+                }
+              }}
+            >
+              Paint
+            </Button>
+          </PopoverPrimitive.Content>
+        </PopoverPrimitive.Portal>
+      </PopoverPrimitive.Root>
+    </div>
+  )
+}
+
+export default PESidePanel
diff --git a/lama_cleaner/app/src/components/Workspace.tsx b/lama_cleaner/app/src/components/Workspace.tsx
index 8f1d834..e33f5a1 100644
--- a/lama_cleaner/app/src/components/Workspace.tsx
+++ b/lama_cleaner/app/src/components/Workspace.tsx
@@ -7,6 +7,7 @@ import Toast from './shared/Toast'
 import {
   AIModel,
   fileState,
+  isPaintByExampleState,
   isSDState,
   settingState,
   toastState,
@@ -17,12 +18,14 @@ import {
   switchModel,
 } from '../adapters/inpainting'
 import SidePanel from './SidePanel/SidePanel'
+import PESidePanel from './SidePanel/PESidePanel'

 const Workspace = () => {
   const [file, setFile] = useRecoilState(fileState)
   const [settings, setSettingState] = useRecoilState(settingState)
   const [toastVal, setToastState] = useRecoilState(toastState)
   const isSD = useRecoilValue(isSDState)
+  const isPaintByExample = useRecoilValue(isPaintByExampleState)

   const onSettingClose = async () => {
     const curModel = await currentModel().then(res => res.text())
@@ -88,6 +91,7 @@ const Workspace = () => {
   return (
     <>
       {isSD ? <SidePanel /> : <></>}
+      {isPaintByExample ? <PESidePanel /> : <></>}
diff --git a/lama_cleaner/app/src/event.ts b/lama_cleaner/app/src/event.ts
index cdfc70a..37d78e5 100644
--- a/lama_cleaner/app/src/event.ts
+++ b/lama_cleaner/app/src/event.ts
@@ -1,11 +1,17 @@
 import mitt from 'mitt'

 export const EVENT_PROMPT = 'prompt'
+
 export const EVENT_CUSTOM_MASK = 'custom_mask'
 export interface CustomMaskEventData {
   mask: File
 }

+export const EVENT_PAINT_BY_EXAMPLE = 'paint_by_example'
+export interface PaintByExampleEventData {
+  image: File
+}
+
 const emitter = mitt()
 export default emitter
diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx
index 353c7d6..65558d0 100644
--- a/lama_cleaner/app/src/store/Atoms.tsx
+++ b/lama_cleaner/app/src/store/Atoms.tsx
@@ -13,6 +13,7 @@ export enum AIModel {
   SD2 = 'sd2',
   CV2 = 'cv2',
   Mange = 'manga',
+  PAINT_BY_EXAMPLE = 'paint_by_example',
 }

 export const maskState = atom<File | undefined>({
   key: 'maskState',
   default: undefined,
 })

+export const paintByExampleImageState = atom<File | undefined>({
+  key: 'paintByExampleImageState',
+  default: undefined,
+})
+
 export interface Rect {
   x: number
   y: number
@@ -252,6 +258,14 @@ export interface Settings {
   // For OpenCV2
   cv2Radius: number
   cv2Flag: CV2Flag
+
+  // Paint by Example
+  paintByExampleSteps: number
+  paintByExampleGuidanceScale: number
+  paintByExampleSeed: number
+  paintByExampleSeedFixed: boolean
+  paintByExampleMaskBlur: number
+  paintByExampleMatchHistograms: boolean
 }

 const defaultHDSettings: ModelsHDSettings = {
@@ -304,6 +318,13 @@ const defaultHDSettings: ModelsHDSettings = {
     hdStrategyCropMargin: 128,
     enabled: false,
   },
+  [AIModel.PAINT_BY_EXAMPLE]: {
+    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategyResizeLimit: 768,
+    hdStrategyCropTrigerSize: 512,
+    hdStrategyCropMargin: 128,
+    enabled: false,
+  },
   [AIModel.Mange]: {
     hdStrategy: HDStrategy.CROP,
     hdStrategyResizeLimit: 1280,
@@ -364,6 +385,14 @@ export const settingStateDefault: Settings = {
   // CV2
   cv2Radius: 5,
   cv2Flag: CV2Flag.INPAINT_NS,
+
+  // Paint by Example
+  paintByExampleSteps: 50,
+  paintByExampleGuidanceScale: 7.5,
+  paintByExampleSeed: 42,
+  paintByExampleMaskBlur: 5,
+  paintByExampleSeedFixed: false,
+  paintByExampleMatchHistograms: false,
 }

 const localStorageEffect =
@@ -401,11 +430,28 @@ export const seedState = selector({
   key: 'seed',
   get: ({ get }) => {
     const settings = get(settingState)
-    return settings.sdSeed
+    switch (settings.model) {
+      case AIModel.PAINT_BY_EXAMPLE:
+        return settings.paintByExampleSeedFixed
+          ? settings.paintByExampleSeed
+          : -1
+      default:
+        return settings.sdSeedFixed ? settings.sdSeed : -1
+    }
   },
   set: ({ get, set }, newValue: any) => {
     const settings = get(settingState)
-    set(settingState, { ...settings, sdSeed: newValue })
+    switch (settings.model) {
+      case AIModel.PAINT_BY_EXAMPLE:
+        if (!settings.paintByExampleSeedFixed) {
+          set(settingState, { ...settings, paintByExampleSeed: newValue })
+        }
+        break
+      default:
+        if (!settings.sdSeedFixed) {
+          set(settingState, { ...settings, sdSeed: newValue })
+        }
+    }
   },
 })
@@ -435,11 +481,20 @@ export const isSDState = selector({
   },
 })

+export const isPaintByExampleState = selector({
+  key: 'isPaintByExampleState',
+  get: ({ get }) => {
+    const settings = get(settingState)
+    return settings.model === AIModel.PAINT_BY_EXAMPLE
+  },
+})
+
 export const runManuallyState = selector({
   key: 'runManuallyState',
   get: ({ get }) => {
     const settings = get(settingState)
     const isSD = get(isSDState)
-    return settings.runInpaintingManually || isSD
+    const isPaintByExample = get(isPaintByExampleState)
+    return settings.runInpaintingManually || isSD || isPaintByExample
   },
 })
diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py
index 882aa42..05d27f1 100644
--- a/lama_cleaner/model/base.py
+++ b/lama_cleaner/model/base.py
@@ -211,6 +211,26 @@ class InpaintModel:

         return result

+    def _apply_cropper(self, image, mask, config: Config):
+        img_h, img_w = image.shape[:2]
+        l, t, w, h = (
+            config.croper_x,
+            config.croper_y,
+            config.croper_width,
+            config.croper_height,
+        )
+        r = l + w
+        b = t + h
+
+        # clamp the crop rectangle to the image bounds
+        l = max(l, 0)
+        r = min(r, img_w)
+        t = max(t, 0)
+        b = min(b, img_h)
+
+        crop_img = image[t:b, l:r, :]
+        crop_mask = mask[t:b, l:r]
+        return crop_img, crop_mask, (l, t, r, b)
+
     def _run_box(self, image, mask, box, config: Config):
         """
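The new `_apply_cropper` helper is shared by SD and Paint by Example below: it clamps the user's crop rectangle to the image and returns the crops plus the clamped box so the processed crop can be pasted back. A small NumPy sketch of that round trip (shapes and rectangle invented for illustration, with the inpainting step stubbed out):

    # Illustrative only: mirrors _apply_cropper plus the paste-back step.
    import numpy as np

    image = np.zeros((600, 800, 3), dtype=np.uint8)   # H=600, W=800, RGB
    mask = np.zeros((600, 800), dtype=np.uint8)

    l, t, w, h = 700, 500, 200, 200                   # crop rect from config
    r, b = min(l + w, image.shape[1]), min(t + h, image.shape[0])
    l, t = max(l, 0), max(t, 0)                       # clamp to bounds

    crop_img, crop_mask = image[t:b, l:r, :], mask[t:b, l:r]
    assert crop_img.shape == (100, 100, 3)            # clipped at the border

    processed = crop_img[:, :, ::-1]                  # stand-in for _pad_forward
    result = image[:, :, ::-1].copy()                 # full frame, RGB -> BGR
    result[t:b, l:r, :] = processed                   # paste the crop back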
""" + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + if config.use_croper: + crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) + crop_image = self._pad_forward(crop_img, crop_mask, config) + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + else: + inpaint_result = self._pad_forward(image, mask, config) + + return inpaint_result + + def forward_post_process(self, result, image, mask, config): + if config.paint_by_example_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.paint_by_example_mask_blur != 0: + k = 2 * config.paint_by_example_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index fced353..a10af76 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -12,31 +12,6 @@ from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config, SDSampler -# -# -# def preprocess_image(image): -# w, h = image.size -# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 -# image = image.resize((w, h), resample=PIL.Image.LANCZOS) -# image = np.array(image).astype(np.float32) / 255.0 -# image = image[None].transpose(0, 3, 1, 2) -# image = torch.from_numpy(image) -# # [-1, 1] -# return 2.0 * image - 1.0 -# -# -# def preprocess_mask(mask): -# mask = mask.convert("L") -# w, h = mask.size -# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 -# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) -# mask = np.array(mask).astype(np.float32) / 255.0 -# mask = np.tile(mask, (4, 1, 1)) -# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 
diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py
index fced353..a10af76 100644
--- a/lama_cleaner/model/sd.py
+++ b/lama_cleaner/model/sd.py
@@ -12,31 +12,6 @@ from lama_cleaner.model.base import InpaintModel
 from lama_cleaner.schema import Config, SDSampler


-#
-#
-# def preprocess_image(image):
-#     w, h = image.size
-#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-#     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-#     image = np.array(image).astype(np.float32) / 255.0
-#     image = image[None].transpose(0, 3, 1, 2)
-#     image = torch.from_numpy(image)
-#     # [-1, 1]
-#     return 2.0 * image - 1.0
-#
-#
-# def preprocess_mask(mask):
-#     mask = mask.convert("L")
-#     w, h = mask.size
-#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-#     mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-#     mask = np.array(mask).astype(np.float32) / 255.0
-#     mask = np.tile(mask, (4, 1, 1))
-#     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-#     mask = 1 - mask  # repaint white, keep black
-#     mask = torch.from_numpy(mask)
-#     return mask

 class CPUTextEncoderWrapper:
     def __init__(self, text_encoder, torch_dtype):
         self.config = text_encoder.config
@@ -92,17 +67,6 @@ class SD(InpaintModel):
         return: BGR IMAGE
         """
-        # image = norm_img(image)  # [0, 1]
-        # image = image * 2 - 1  # [0, 1] -> [-1, 1]
-
-        # resize to latent feature map size
-        # h, w = mask.shape[:2]
-        # mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
-        # mask = norm_img(mask)
-        #
-        # image = torch.from_numpy(image).unsqueeze(0).to(self.device)
-        # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
-
         scheduler_config = self.model.scheduler.config

         if config.sd_sampler == SDSampler.ddim:
@@ -139,7 +103,6 @@ class SD(InpaintModel):
             prompt=config.prompt,
             negative_prompt=config.negative_prompt,
             mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
-            strength=config.sd_strength,
             num_inference_steps=config.sd_steps,
             guidance_scale=config.sd_guidance_scale,
             output_type="np.array",
@@ -159,30 +122,10 @@ class SD(InpaintModel):
         masks: [H, W]
         return: BGR IMAGE
         """
-        img_h, img_w = image.shape[:2]
-        # boxes = boxes_from_mask(mask)
         if config.use_croper:
-            logger.info("use croper")
-            l, t, w, h = (
-                config.croper_x,
-                config.croper_y,
-                config.croper_width,
-                config.croper_height,
-            )
-            r = l + w
-            b = t + h
-
-            l = max(l, 0)
-            r = min(r, img_w)
-            t = max(t, 0)
-            b = min(b, img_h)
-
-            crop_img = image[t:b, l:r, :]
-            crop_mask = mask[t:b, l:r]
-
+            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
             crop_image = self._pad_forward(crop_img, crop_mask, config)
-
             inpaint_result = image[:, :, ::-1]
             inpaint_result[t:b, l:r, :] = crop_image
         else:
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
index c9f2b9b..e14736e 100644
--- a/lama_cleaner/model_manager.py
+++ b/lama_cleaner/model_manager.py
@@ -5,13 +5,14 @@ from lama_cleaner.model.lama import LaMa
 from lama_cleaner.model.ldm import LDM
 from lama_cleaner.model.manga import Manga
 from lama_cleaner.model.mat import MAT
+from lama_cleaner.model.paint_by_example import PaintByExample
 from lama_cleaner.model.sd import SD15, SD2
 from lama_cleaner.model.zits import ZITS
 from lama_cleaner.model.opencv2 import OpenCV2
 from lama_cleaner.schema import Config

 models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
-          "sd2": SD2}
+          "sd2": SD2, "paint_by_example": PaintByExample}


 class ModelManager:
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index 7508bd9..e01867d 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -10,7 +10,7 @@ def parse_args():
     parser.add_argument(
         "--model",
         default="lama",
-        choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2"],
+        choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
     )
     parser.add_argument(
         "--hf_access_token",
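With the new choice registered, the model can be selected at startup. Assuming lama-cleaner's usual entry point and its existing `--device`/`--port` flags:

    lama-cleaner --model paint_by_example --device cuda --port 8080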
diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
index 47baaa7..dba1042 100644
--- a/lama_cleaner/schema.py
+++ b/lama_cleaner/schema.py
@@ -1,5 +1,6 @@
 from enum import Enum

+from PIL.Image import Image
 from pydantic import BaseModel


@@ -29,6 +30,9 @@ class SDSampler(str, Enum):


 class Config(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
     # Configs for ldm model
     ldm_steps: int
     ldm_sampler: str = LDMSampler.plms
@@ -73,3 +77,11 @@ class Config(BaseModel):
     # opencv document: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
     cv2_flag: str = 'INPAINT_NS'
     cv2_radius: int = 4
+
+    # Paint by Example
+    paint_by_example_steps: int = 50
+    paint_by_example_guidance_scale: float = 7.5
+    paint_by_example_mask_blur: int = 0
+    paint_by_example_seed: int = 42
+    paint_by_example_match_histograms: bool = False
+    paint_by_example_example_image: Image = None
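The nested `class Config: arbitrary_types_allowed = True` is what lets a PIL `Image` ride along in the pydantic model: pydantic v1 refuses fields of classes it has no validator for unless that flag is set. A minimal sketch of the mechanism (class names are illustrative):

    # Hedged sketch of why arbitrary_types_allowed is needed (pydantic v1).
    from PIL import Image
    from pydantic import BaseModel

    try:
        class WithoutFlag(BaseModel):
            example_image: Image.Image = None
    except RuntimeError as e:
        # "no validator found for <class 'PIL.Image.Image'>, see
        #  arbitrary_types_allowed in Config"
        print(e)

    class WithFlag(BaseModel):
        class Config:
            arbitrary_types_allowed = True

        example_image: Image.Image = None  # now only isinstance-checked

    print(WithFlag(example_image=Image.new("RGB", (8, 8))).example_image.size)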
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index fbd35a1..f447e00 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -10,6 +10,7 @@ import time
 import imghdr
 from pathlib import Path
 from typing import Union
+from PIL import Image

 import cv2
 import torch
@@ -97,8 +98,8 @@ def process():
     input = request.files
     # RGB
     origin_image_bytes = input["image"].read()
-
     image, alpha_channel = load_img(origin_image_bytes)
+
     mask, _ = load_img(input["mask"].read(), gray=True)
     mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
@@ -115,6 +116,12 @@ def process():
     else:
         size_limit = int(size_limit)

+    if "paintByExampleImage" in input:
+        paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
+        paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
+    else:
+        paint_by_example_example_image = None
+
     config = Config(
         ldm_steps=form["ldmSteps"],
         ldm_sampler=form["ldmSampler"],
@@ -138,11 +145,19 @@ def process():
         sd_seed=form["sdSeed"],
         sd_match_histograms=form["sdMatchHistograms"],
         cv2_flag=form["cv2Flag"],
-        cv2_radius=form['cv2Radius']
+        cv2_radius=form['cv2Radius'],
+        paint_by_example_steps=form["paintByExampleSteps"],
+        paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
+        paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
+        paint_by_example_seed=form["paintByExampleSeed"],
+        paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
+        paint_by_example_example_image=paint_by_example_example_image,
     )

     if config.sd_seed == -1:
         config.sd_seed = random.randint(1, 999999999)
+    if config.paint_by_example_seed == -1:
+        config.paint_by_example_seed = random.randint(1, 999999999)

     logger.info(f"Origin image shape: {original_shape}")
     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
diff --git a/lama_cleaner/tests/bunny.jpeg b/lama_cleaner/tests/bunny.jpeg
new file mode 100644
index 0000000..3727a45
Binary files /dev/null and b/lama_cleaner/tests/bunny.jpeg differ
diff --git a/lama_cleaner/tests/result/paint_by_example_Original.png b/lama_cleaner/tests/result/paint_by_example_Original.png
new file mode 100644
index 0000000..9fe40e4
Binary files /dev/null and b/lama_cleaner/tests/result/paint_by_example_Original.png differ
diff --git a/lama_cleaner/tests/test_paint_by_example.py b/lama_cleaner/tests/test_paint_by_example.py
index e69de29..a8ce09a 100644
--- a/lama_cleaner/tests/test_paint_by_example.py
+++ b/lama_cleaner/tests/test_paint_by_example.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+import cv2
+import pytest
+import torch
+from PIL import Image
+
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import HDStrategy
+from lama_cleaner.tests.test_model import get_config, get_data
+
+current_dir = Path(__file__).parent.absolute().resolve()
+save_dir = current_dir / 'result'
+save_dir.mkdir(exist_ok=True, parents=True)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+device = torch.device(device)
+
+
+def assert_equal(
+    model, config, gt_name,
+    fx: float = 1,
+    fy: float = 1,
+    img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+    mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    example_p=current_dir / "bunny.jpeg",
+):
+    img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
+
+    example_image = cv2.imread(str(example_p))
+    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
+    example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
+
+    print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
+    config.paint_by_example_example_image = Image.fromarray(example_image)
+    res = model(img, mask, config)
+    cv2.imwrite(str(save_dir / gt_name), res)
+
+
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+def test_paint_by_example(strategy):
+    model = ModelManager(name="paint_by_example", device=device)
+    cfg = get_config(strategy, paint_by_example_steps=30)
+    assert_equal(
+        model,
+        cfg,
+        f"paint_by_example_{strategy.capitalize()}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+        fy=0.9,
+        fx=1.3,
+    )
diff --git a/requirements.txt b/requirements.txt
index b9d1fa2..ae1083d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,5 @@ pytest
 yacs
 markupsafe==2.0.1
 scikit-image==0.19.3
-diffusers[torch]==0.9
-transformers==4.21.0
+diffusers[torch]==0.10.2
+transformers>=4.25.1
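To exercise the new path end to end, assuming the pinned requirements are installed and the checkpoint download is permitted (the first run fetches several GB from the Hugging Face Hub):

    pip install -r requirements.txt
    pytest lama_cleaner/tests/test_paint_by_example.py -s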