add paint by example

This commit is contained in:
Qing 2022-12-10 22:06:15 +08:00
parent 6e9d3d8442
commit 203f2bc9c7
18 changed files with 572 additions and 82 deletions

View File

@ -12,7 +12,8 @@ export default async function inpaint(
sizeLimit?: string, sizeLimit?: string,
seed?: number, seed?: number,
maskBase64?: string, maskBase64?: string,
customMask?: File customMask?: File,
paintByExampleImage?: File
) { ) {
// 1080, 2000, Original // 1080, 2000, Original
const fd = new FormData() const fd = new FormData()
@ -48,6 +49,7 @@ export default async function inpaint(
fd.append('croperHeight', croperRect.height.toString()) fd.append('croperHeight', croperRect.height.toString())
fd.append('croperWidth', croperRect.width.toString()) fd.append('croperWidth', croperRect.width.toString())
fd.append('useCroper', settings.showCroper ? 'true' : 'false') fd.append('useCroper', settings.showCroper ? 'true' : 'false')
fd.append('sdMaskBlur', settings.sdMaskBlur.toString()) fd.append('sdMaskBlur', settings.sdMaskBlur.toString())
fd.append('sdStrength', settings.sdStrength.toString()) fd.append('sdStrength', settings.sdStrength.toString())
fd.append('sdSteps', settings.sdSteps.toString()) fd.append('sdSteps', settings.sdSteps.toString())
@ -59,6 +61,26 @@ export default async function inpaint(
fd.append('cv2Radius', settings.cv2Radius.toString()) fd.append('cv2Radius', settings.cv2Radius.toString())
fd.append('cv2Flag', settings.cv2Flag.toString()) fd.append('cv2Flag', settings.cv2Flag.toString())
fd.append('paintByExampleSteps', settings.paintByExampleSteps.toString())
fd.append(
'paintByExampleGuidanceScale',
settings.paintByExampleGuidanceScale.toString()
)
fd.append('paintByExampleSeed', seed ? seed.toString() : '-1')
fd.append(
'paintByExampleMaskBlur',
settings.paintByExampleMaskBlur.toString()
)
fd.append(
'paintByExampleMatchHistograms',
settings.paintByExampleMatchHistograms ? 'true' : 'false'
)
// TODO: resize image's shortest_edge to 224 before pass to backend, save network time?
// https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
if (paintByExampleImage) {
fd.append('paintByExampleImage', paintByExampleImage)
}
if (sizeLimit === undefined) { if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080') fd.append('sizeLimit', '1080')
} else { } else {

View File

@ -39,6 +39,7 @@ import {
isInpaintingState, isInpaintingState,
isInteractiveSegRunningState, isInteractiveSegRunningState,
isInteractiveSegState, isInteractiveSegState,
isPaintByExampleState,
isSDState, isSDState,
negativePropmtState, negativePropmtState,
propmtState, propmtState,
@ -53,6 +54,7 @@ import emitter, {
EVENT_PROMPT, EVENT_PROMPT,
EVENT_CUSTOM_MASK, EVENT_CUSTOM_MASK,
CustomMaskEventData, CustomMaskEventData,
EVENT_PAINT_BY_EXAMPLE,
} from '../../event' } from '../../event'
import FileSelect from '../FileSelect/FileSelect' import FileSelect from '../FileSelect/FileSelect'
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg' import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
@ -108,6 +110,7 @@ export default function Editor() {
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const runMannually = useRecoilValue(runManuallyState) const runMannually = useRecoilValue(runManuallyState)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState( const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState isInteractiveSegState
) )
@ -262,8 +265,11 @@ export default function Editor() {
async ( async (
useLastLineGroup?: boolean, useLastLineGroup?: boolean,
customMask?: File, customMask?: File,
maskImage?: HTMLImageElement | null maskImage?: HTMLImageElement | null,
paintByExampleImage?: File
) => { ) => {
// customMask: mask uploaded by user
// maskImage: mask from interactive segmentation
if (file === undefined) { if (file === undefined) {
return return
} }
@ -328,9 +334,6 @@ export default function Editor() {
} }
} }
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
console.log({ useCustomMask })
try { try {
const res = await inpaint( const res = await inpaint(
targetFile, targetFile,
@ -339,15 +342,16 @@ export default function Editor() {
promptVal, promptVal,
negativePromptVal, negativePromptVal,
sizeLimit.toString(), sizeLimit.toString(),
sdSeed, seedVal,
useCustomMask ? undefined : maskCanvas.toDataURL(), useCustomMask ? undefined : maskCanvas.toDataURL(),
useCustomMask ? customMask : undefined useCustomMask ? customMask : undefined,
paintByExampleImage
) )
if (!res) { if (!res) {
throw new Error('Something went wrong on server side.') throw new Error('Something went wrong on server side.')
} }
const { blob, seed } = res const { blob, seed } = res
if (seed && !settings.sdSeedFixed) { if (seed) {
setSeed(parseInt(seed, 10)) setSeed(parseInt(seed, 10))
} }
const newRender = new Image() const newRender = new Image()
@ -395,6 +399,7 @@ export default function Editor() {
drawOnCurrentRender, drawOnCurrentRender,
hadDrawSomething, hadDrawSomething,
drawLinesOnMask, drawLinesOnMask,
seedVal,
] ]
) )
@ -439,6 +444,31 @@ export default function Editor() {
} }
}, [runInpainting]) }, [runInpainting])
useEffect(() => {
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
if (hadDrawSomething() || interactiveSegMask) {
runInpainting(false, undefined, interactiveSegMask, data.image)
} else if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
} else {
setToastState({
open: true,
desc: 'Please draw mask on picture',
state: 'error',
duration: 1500,
})
}
})
return () => {
emitter.off(EVENT_PAINT_BY_EXAMPLE)
}
}, [runInpainting])
const hadRunInpainting = () => { const hadRunInpainting = () => {
return renders.length !== 0 return renders.length !== 0
} }
@ -793,7 +823,11 @@ export default function Editor() {
return return
} }
if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) { if (
(isSD || isPaintByExample) &&
settings.showCroper &&
isOutsideCroper(mouseXY(ev))
) {
return return
} }
@ -876,7 +910,12 @@ export default function Editor() {
return false return false
} }
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD]) useKey(undoPredicate, undo, undefined, [
undoStroke,
undoRender,
runMannually,
curLineGroup,
])
const disableUndo = () => { const disableUndo = () => {
if (isInteractiveSeg) { if (isInteractiveSeg) {
@ -955,7 +994,12 @@ export default function Editor() {
return false return false
} }
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD]) useKey(redoPredicate, redo, undefined, [
redoStroke,
redoRender,
runMannually,
redoCurLines,
])
const disableRedo = () => { const disableRedo = () => {
if (isInteractiveSeg) { if (isInteractiveSeg) {
@ -1295,7 +1339,7 @@ export default function Editor() {
</div> </div>
</div> </div>
{isSD && settings.showCroper ? ( {(isSD || isPaintByExample) && settings.showCroper ? (
<Croper <Croper
maxHeight={original.naturalHeight} maxHeight={original.naturalHeight}
maxWidth={original.naturalWidth} maxWidth={original.naturalWidth}
@ -1358,7 +1402,7 @@ export default function Editor() {
)} )}
<div className="editor-toolkit-panel"> <div className="editor-toolkit-panel">
{isSD || file === undefined ? ( {isSD || isPaintByExample || file === undefined ? (
<></> <></>
) : ( ) : (
<SizeSelector <SizeSelector
@ -1466,7 +1510,7 @@ export default function Editor() {
onClick={download} onClick={download}
/> />
{settings.runInpaintingManually && !isSD && ( {settings.runInpaintingManually && !isSD && !isPaintByExample && (
<Button <Button
toolTip="Run Inpainting" toolTip="Run Inpainting"
tooltipPosition="top" tooltipPosition="top"

View File

@ -193,6 +193,8 @@ function ModelSettingBlock() {
return undefined return undefined
case AIModel.SD2: case AIModel.SD2:
return undefined return undefined
case AIModel.PAINT_BY_EXAMPLE:
return undefined
case AIModel.Mange: case AIModel.Mange:
return undefined return undefined
case AIModel.CV2: case AIModel.CV2:
@ -258,6 +260,12 @@ function ModelSettingBlock() {
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html', 'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html',
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html' 'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html'
) )
case AIModel.PAINT_BY_EXAMPLE:
return renderModelDesc(
'Paint by Example',
'https://arxiv.org/abs/2211.13227',
'https://github.com/Fantasy-Studio/Paint-by-Example'
)
default: default:
return <></> return <></>
} }
@ -270,7 +278,6 @@ function ModelSettingBlock() {
titleSuffix={renderPaperCodeBadge()} titleSuffix={renderPaperCodeBadge()}
input={ input={
<Selector <Selector
width={80}
value={setting.model as string} value={setting.model as string}
options={Object.values(AIModel)} options={Object.values(AIModel)}
onChange={val => onModelChange(val as AIModel)} onChange={val => onModelChange(val as AIModel)}

View File

@ -0,0 +1,231 @@
import React, { useState } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import { UploadIcon } from '@radix-ui/react-icons'
import {
isInpaintingState,
paintByExampleImageState,
settingState,
} from '../../store/Atoms'
import NumberInputSetting from '../Settings/NumberInputSetting'
import SettingBlock from '../Settings/SettingBlock'
import { Switch, SwitchThumb } from '../shared/Switch'
import Button from '../shared/Button'
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
import { useImage } from '../../utils'
const INPUT_WIDTH = 30
const PESidePanel = () => {
const [open, toggleOpen] = useToggle(true)
const [setting, setSettingState] = useRecoilState(settingState)
const [paintByExampleImage, setPaintByExampleImage] = useRecoilState(
paintByExampleImageState
)
const [uploadElemId] = useState(
`example-file-upload-${Math.random().toString()}`
)
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleImage)
const isInpainting = useRecoilValue(isInpaintingState)
const renderUploadIcon = () => {
return (
<label htmlFor={uploadElemId}>
<Button
border
toolTip="Upload example image"
tooltipPosition="top"
icon={<UploadIcon />}
style={{ padding: '0.3rem', gap: 0 }}
>
<input
style={{ display: 'none' }}
id={uploadElemId}
name={uploadElemId}
type="file"
onChange={ev => {
const newFile = ev.currentTarget.files?.[0]
if (newFile) {
setPaintByExampleImage(newFile)
}
}}
accept="image/png, image/jpeg"
/>
</Button>
</label>
)
}
return (
<div className="side-panel">
<PopoverPrimitive.Root open={open}>
<PopoverPrimitive.Trigger
className="btn-primary side-panel-trigger"
onClick={() => toggleOpen()}
>
Configurations
</PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content">
<SettingBlock
title="Croper"
input={
<Switch
checked={setting.showCroper}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, showCroper: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/>
<NumberInputSetting
title="Steps"
width={INPUT_WIDTH}
value={`${setting.paintByExampleSteps}`}
desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, paintByExampleSteps: val }
})
}}
/>
<NumberInputSetting
title="Guidance Scale"
width={INPUT_WIDTH}
allowFloat
value={`${setting.paintByExampleGuidanceScale}`}
desc="Higher guidance scale encourages to generate images that are close to the example image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, paintByExampleGuidanceScale: val }
})
}}
/>
<NumberInputSetting
title="Mask Blur"
width={INPUT_WIDTH}
value={`${setting.paintByExampleMaskBlur}`}
desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, paintByExampleMaskBlur: val }
})
}}
/>
<SettingBlock
title="Match Histograms"
desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
input={
<Switch
checked={setting.paintByExampleMatchHistograms}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, paintByExampleMatchHistograms: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/>
<SettingBlock
title="Seed"
input={
<div
style={{
display: 'flex',
gap: 0,
justifyContent: 'center',
alignItems: 'center',
}}
>
{/* 每次会从服务器返回更新该值 */}
<NumberInputSetting
title=""
width={80}
value={`${setting.paintByExampleSeed}`}
desc=""
disable={!setting.paintByExampleSeedFixed}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, paintByExampleSeed: val }
})
}}
/>
<Switch
checked={setting.paintByExampleSeedFixed}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, paintByExampleSeedFixed: value }
})
}}
style={{ marginLeft: '8px' }}
>
<SwitchThumb />
</Switch>
</div>
}
/>
<div style={{ display: 'flex', flexDirection: 'column' }}>
<SettingBlock title="Example Image" input={renderUploadIcon()} />
{paintByExampleImage ? (
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
}}
>
<img
src={exampleImage.src}
alt="example"
style={{
maxWidth: 200,
maxHeight: 200,
margin: 12,
}}
/>
</div>
) : (
<></>
)}
</div>
<Button
border
disabled={!isExampleImageLoaded || isInpainting}
style={{ width: '100%' }}
onClick={() => {
if (isExampleImageLoaded) {
emitter.emit(EVENT_PAINT_BY_EXAMPLE, {
image: paintByExampleImage,
})
}
}}
>
Paint
</Button>
</PopoverPrimitive.Content>
</PopoverPrimitive.Portal>
</PopoverPrimitive.Root>
</div>
)
}
export default PESidePanel

View File

@ -7,6 +7,7 @@ import Toast from './shared/Toast'
import { import {
AIModel, AIModel,
fileState, fileState,
isPaintByExampleState,
isSDState, isSDState,
settingState, settingState,
toastState, toastState,
@ -17,12 +18,14 @@ import {
switchModel, switchModel,
} from '../adapters/inpainting' } from '../adapters/inpainting'
import SidePanel from './SidePanel/SidePanel' import SidePanel from './SidePanel/SidePanel'
import PESidePanel from './SidePanel/PESidePanel'
const Workspace = () => { const Workspace = () => {
const [file, setFile] = useRecoilState(fileState) const [file, setFile] = useRecoilState(fileState)
const [settings, setSettingState] = useRecoilState(settingState) const [settings, setSettingState] = useRecoilState(settingState)
const [toastVal, setToastState] = useRecoilState(toastState) const [toastVal, setToastState] = useRecoilState(toastState)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState)
const onSettingClose = async () => { const onSettingClose = async () => {
const curModel = await currentModel().then(res => res.text()) const curModel = await currentModel().then(res => res.text())
@ -88,6 +91,7 @@ const Workspace = () => {
return ( return (
<> <>
{isSD ? <SidePanel /> : <></>} {isSD ? <SidePanel /> : <></>}
{isPaintByExample ? <PESidePanel /> : <></>}
<Editor /> <Editor />
<SettingModal onClose={onSettingClose} /> <SettingModal onClose={onSettingClose} />
<ShortcutsModal /> <ShortcutsModal />

View File

@ -1,11 +1,17 @@
import mitt from 'mitt' import mitt from 'mitt'
export const EVENT_PROMPT = 'prompt' export const EVENT_PROMPT = 'prompt'
export const EVENT_CUSTOM_MASK = 'custom_mask' export const EVENT_CUSTOM_MASK = 'custom_mask'
export interface CustomMaskEventData { export interface CustomMaskEventData {
mask: File mask: File
} }
export const EVENT_PAINT_BY_EXAMPLE = 'paint_by_example'
export interface PaintByExampleEventData {
image: File
}
const emitter = mitt() const emitter = mitt()
export default emitter export default emitter

View File

@ -13,6 +13,7 @@ export enum AIModel {
SD2 = 'sd2', SD2 = 'sd2',
CV2 = 'cv2', CV2 = 'cv2',
Mange = 'manga', Mange = 'manga',
PAINT_BY_EXAMPLE = 'paint_by_example',
} }
export const maskState = atom<File | undefined>({ export const maskState = atom<File | undefined>({
@ -20,6 +21,11 @@ export const maskState = atom<File | undefined>({
default: undefined, default: undefined,
}) })
export const paintByExampleImageState = atom<File | undefined>({
key: 'paintByExampleImageState',
default: undefined,
})
export interface Rect { export interface Rect {
x: number x: number
y: number y: number
@ -252,6 +258,14 @@ export interface Settings {
// For OpenCV2 // For OpenCV2
cv2Radius: number cv2Radius: number
cv2Flag: CV2Flag cv2Flag: CV2Flag
// Paint by Example
paintByExampleSteps: number
paintByExampleGuidanceScale: number
paintByExampleSeed: number
paintByExampleSeedFixed: boolean
paintByExampleMaskBlur: number
paintByExampleMatchHistograms: boolean
} }
const defaultHDSettings: ModelsHDSettings = { const defaultHDSettings: ModelsHDSettings = {
@ -304,6 +318,13 @@ const defaultHDSettings: ModelsHDSettings = {
hdStrategyCropMargin: 128, hdStrategyCropMargin: 128,
enabled: false, enabled: false,
}, },
[AIModel.PAINT_BY_EXAMPLE]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.Mange]: { [AIModel.Mange]: {
hdStrategy: HDStrategy.CROP, hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1280, hdStrategyResizeLimit: 1280,
@ -364,6 +385,14 @@ export const settingStateDefault: Settings = {
// CV2 // CV2
cv2Radius: 5, cv2Radius: 5,
cv2Flag: CV2Flag.INPAINT_NS, cv2Flag: CV2Flag.INPAINT_NS,
// Paint by Example
paintByExampleSteps: 50,
paintByExampleGuidanceScale: 7.5,
paintByExampleSeed: 42,
paintByExampleMaskBlur: 5,
paintByExampleSeedFixed: false,
paintByExampleMatchHistograms: false,
} }
const localStorageEffect = const localStorageEffect =
@ -401,11 +430,28 @@ export const seedState = selector({
key: 'seed', key: 'seed',
get: ({ get }) => { get: ({ get }) => {
const settings = get(settingState) const settings = get(settingState)
return settings.sdSeed switch (settings.model) {
case AIModel.PAINT_BY_EXAMPLE:
return settings.paintByExampleSeedFixed
? settings.paintByExampleSeed
: -1
default:
return settings.sdSeedFixed ? settings.sdSeed : -1
}
}, },
set: ({ get, set }, newValue: any) => { set: ({ get, set }, newValue: any) => {
const settings = get(settingState) const settings = get(settingState)
set(settingState, { ...settings, sdSeed: newValue }) switch (settings.model) {
case AIModel.PAINT_BY_EXAMPLE:
if (!settings.paintByExampleSeedFixed) {
set(settingState, { ...settings, paintByExampleSeed: newValue })
}
break
default:
if (!settings.sdSeedFixed) {
set(settingState, { ...settings, sdSeed: newValue })
}
}
}, },
}) })
@ -435,11 +481,20 @@ export const isSDState = selector({
}, },
}) })
export const isPaintByExampleState = selector({
key: 'isPaintByExampleState',
get: ({ get }) => {
const settings = get(settingState)
return settings.model === AIModel.PAINT_BY_EXAMPLE
},
})
export const runManuallyState = selector({ export const runManuallyState = selector({
key: 'runManuallyState', key: 'runManuallyState',
get: ({ get }) => { get: ({ get }) => {
const settings = get(settingState) const settings = get(settingState)
const isSD = get(isSDState) const isSD = get(isSDState)
return settings.runInpaintingManually || isSD const isPaintByExample = get(isPaintByExampleState)
return settings.runInpaintingManually || isSD || isPaintByExample
}, },
}) })

View File

@ -211,6 +211,26 @@ class InpaintModel:
return result return result
def _apply_cropper(self, image, mask, config: Config):
img_h, img_w = image.shape[:2]
l, t, w, h = (
config.croper_x,
config.croper_y,
config.croper_width,
config.croper_height,
)
r = l + w
b = t + h
l = max(l, 0)
r = min(r, img_w)
t = max(t, 0)
b = min(b, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
return crop_img, crop_mask, (l, t, r, b)
def _run_box(self, image, mask, box, config: Config): def _run_box(self, image, mask, box, config: Config):
""" """

View File

@ -0,0 +1,80 @@
import random
import PIL
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import DiffusionPipeline
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
class PaintByExample(InpaintModel):
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu else torch.float32
self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype,
)
self.model.enable_attention_slicing()
self.model = self.model.to(device)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
seed = config.paint_by_example_seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
output = self.model(
image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image,
num_inference_steps=config.paint_by_example_steps,
output_type='np.array',
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
def forward_post_process(self, result, image, mask, config):
if config.paint_by_example_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.paint_by_example_mask_blur != 0:
k = 2 * config.paint_by_example_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings
return True

View File

@ -12,31 +12,6 @@ from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config, SDSampler from lama_cleaner.schema import Config, SDSampler
#
#
# def preprocess_image(image):
# w, h = image.size
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
# image = image.resize((w, h), resample=PIL.Image.LANCZOS)
# image = np.array(image).astype(np.float32) / 255.0
# image = image[None].transpose(0, 3, 1, 2)
# image = torch.from_numpy(image)
# # [-1, 1]
# return 2.0 * image - 1.0
#
#
# def preprocess_mask(mask):
# mask = mask.convert("L")
# w, h = mask.size
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
# mask = np.array(mask).astype(np.float32) / 255.0
# mask = np.tile(mask, (4, 1, 1))
# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
# mask = 1 - mask # repaint white, keep black
# mask = torch.from_numpy(mask)
# return mask
class CPUTextEncoderWrapper: class CPUTextEncoderWrapper:
def __init__(self, text_encoder, torch_dtype): def __init__(self, text_encoder, torch_dtype):
self.config = text_encoder.config self.config = text_encoder.config
@ -92,17 +67,6 @@ class SD(InpaintModel):
return: BGR IMAGE return: BGR IMAGE
""" """
# image = norm_img(image) # [0, 1]
# image = image * 2 - 1 # [0, 1] -> [-1, 1]
# resize to latent feature map size
# h, w = mask.shape[:2]
# mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
# mask = norm_img(mask)
#
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
scheduler_config = self.model.scheduler.config scheduler_config = self.model.scheduler.config
if config.sd_sampler == SDSampler.ddim: if config.sd_sampler == SDSampler.ddim:
@ -139,7 +103,6 @@ class SD(InpaintModel):
prompt=config.prompt, prompt=config.prompt,
negative_prompt=config.negative_prompt, negative_prompt=config.negative_prompt,
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
strength=config.sd_strength,
num_inference_steps=config.sd_steps, num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np.array", output_type="np.array",
@ -159,30 +122,10 @@ class SD(InpaintModel):
masks: [H, W] masks: [H, W]
return: BGR IMAGE return: BGR IMAGE
""" """
img_h, img_w = image.shape[:2]
# boxes = boxes_from_mask(mask) # boxes = boxes_from_mask(mask)
if config.use_croper: if config.use_croper:
logger.info("use croper") crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
l, t, w, h = (
config.croper_x,
config.croper_y,
config.croper_width,
config.croper_height,
)
r = l + w
b = t + h
l = max(l, 0)
r = min(r, img_w)
t = max(t, 0)
b = min(b, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
crop_image = self._pad_forward(crop_img, crop_mask, config) crop_image = self._pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1] inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image inpaint_result[t:b, l:r, :] = crop_image
else: else:

View File

@ -5,13 +5,14 @@ from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga from lama_cleaner.model.manga import Manga
from lama_cleaner.model.mat import MAT from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.sd import SD15, SD2 from lama_cleaner.model.sd import SD15, SD2
from lama_cleaner.model.zits import ZITS from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga, models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
"sd2": SD2} "sd2": SD2, "paint_by_example": PaintByExample}
class ModelManager: class ModelManager:

View File

@ -10,7 +10,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--model", "--model",
default="lama", default="lama",
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2"], choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
) )
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",

View File

@ -1,5 +1,6 @@
from enum import Enum from enum import Enum
from PIL.Image import Image
from pydantic import BaseModel from pydantic import BaseModel
@ -29,6 +30,9 @@ class SDSampler(str, Enum):
class Config(BaseModel): class Config(BaseModel):
class Config:
arbitrary_types_allowed = True
# Configs for ldm model # Configs for ldm model
ldm_steps: int ldm_steps: int
ldm_sampler: str = LDMSampler.plms ldm_sampler: str = LDMSampler.plms
@ -73,3 +77,11 @@ class Config(BaseModel):
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
cv2_flag: str = 'INPAINT_NS' cv2_flag: str = 'INPAINT_NS'
cv2_radius: int = 4 cv2_radius: int = 4
# Paint by Example
paint_by_example_steps: int = 50
paint_by_example_guidance_scale: float = 7.5
paint_by_example_mask_blur: int = 0
paint_by_example_seed: int = 42
paint_by_example_match_histograms: bool = False
paint_by_example_example_image: Image = None

View File

@ -10,6 +10,7 @@ import time
import imghdr import imghdr
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from PIL import Image
import cv2 import cv2
import torch import torch
@ -97,8 +98,8 @@ def process():
input = request.files input = request.files
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes) image, alpha_channel = load_img(origin_image_bytes)
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
@ -115,6 +116,12 @@ def process():
else: else:
size_limit = int(size_limit) size_limit = int(size_limit)
if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else:
paint_by_example_example_image = None
config = Config( config = Config(
ldm_steps=form["ldmSteps"], ldm_steps=form["ldmSteps"],
ldm_sampler=form["ldmSampler"], ldm_sampler=form["ldmSampler"],
@ -138,11 +145,19 @@ def process():
sd_seed=form["sdSeed"], sd_seed=form["sdSeed"],
sd_match_histograms=form["sdMatchHistograms"], sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"], cv2_flag=form["cv2Flag"],
cv2_radius=form['cv2Radius'] cv2_radius=form['cv2Radius'],
paint_by_example_steps=form["paintByExampleSteps"],
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
paint_by_example_seed=form["paintByExampleSeed"],
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
paint_by_example_example_image=paint_by_example_example_image,
) )
if config.sd_seed == -1: if config.sd_seed == -1:
config.sd_seed = random.randint(1, 999999999) config.sd_seed = random.randint(1, 999999999)
if config.paint_by_example_seed == -1:
config.paint_by_example_seed = random.randint(1, 999999999)
logger.info(f"Origin image shape: {original_shape}") logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 517 KiB

View File

@ -0,0 +1,50 @@
from pathlib import Path
import cv2
import pytest
import torch
from PIL import Image
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import HDStrategy
from lama_cleaner.tests.test_model import get_config, get_data
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)
def assert_equal(
model, config, gt_name,
fx: float = 1, fy: float = 1,
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
example_p=current_dir / "rabbit.jpeg",
):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
example_image = cv2.imread(str(example_p))
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
config.paint_by_example_example_image = Image.fromarray(example_image)
res = model(img, mask, config)
cv2.imwrite(str(save_dir / gt_name), res)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example(strategy):
model = ModelManager(name="paint_by_example", device=device)
cfg = get_config(strategy, paint_by_example_steps=30)
assert_equal(
model,
cfg,
f"paint_by_example_{strategy.capitalize()}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fy=0.9,
fx=1.3
)

View File

@ -10,5 +10,5 @@ pytest
yacs yacs
markupsafe==2.0.1 markupsafe==2.0.1
scikit-image==0.19.3 scikit-image==0.19.3
diffusers[torch]==0.9 diffusers[torch]==0.10.2
transformers==4.21.0 transformers>=4.25.1