add paint by example
This commit is contained in:
parent
6e9d3d8442
commit
203f2bc9c7
@ -12,7 +12,8 @@ export default async function inpaint(
|
|||||||
sizeLimit?: string,
|
sizeLimit?: string,
|
||||||
seed?: number,
|
seed?: number,
|
||||||
maskBase64?: string,
|
maskBase64?: string,
|
||||||
customMask?: File
|
customMask?: File,
|
||||||
|
paintByExampleImage?: File
|
||||||
) {
|
) {
|
||||||
// 1080, 2000, Original
|
// 1080, 2000, Original
|
||||||
const fd = new FormData()
|
const fd = new FormData()
|
||||||
@ -48,6 +49,7 @@ export default async function inpaint(
|
|||||||
fd.append('croperHeight', croperRect.height.toString())
|
fd.append('croperHeight', croperRect.height.toString())
|
||||||
fd.append('croperWidth', croperRect.width.toString())
|
fd.append('croperWidth', croperRect.width.toString())
|
||||||
fd.append('useCroper', settings.showCroper ? 'true' : 'false')
|
fd.append('useCroper', settings.showCroper ? 'true' : 'false')
|
||||||
|
|
||||||
fd.append('sdMaskBlur', settings.sdMaskBlur.toString())
|
fd.append('sdMaskBlur', settings.sdMaskBlur.toString())
|
||||||
fd.append('sdStrength', settings.sdStrength.toString())
|
fd.append('sdStrength', settings.sdStrength.toString())
|
||||||
fd.append('sdSteps', settings.sdSteps.toString())
|
fd.append('sdSteps', settings.sdSteps.toString())
|
||||||
@ -59,6 +61,26 @@ export default async function inpaint(
|
|||||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||||
|
|
||||||
|
fd.append('paintByExampleSteps', settings.paintByExampleSteps.toString())
|
||||||
|
fd.append(
|
||||||
|
'paintByExampleGuidanceScale',
|
||||||
|
settings.paintByExampleGuidanceScale.toString()
|
||||||
|
)
|
||||||
|
fd.append('paintByExampleSeed', seed ? seed.toString() : '-1')
|
||||||
|
fd.append(
|
||||||
|
'paintByExampleMaskBlur',
|
||||||
|
settings.paintByExampleMaskBlur.toString()
|
||||||
|
)
|
||||||
|
fd.append(
|
||||||
|
'paintByExampleMatchHistograms',
|
||||||
|
settings.paintByExampleMatchHistograms ? 'true' : 'false'
|
||||||
|
)
|
||||||
|
// TODO: resize image's shortest_edge to 224 before pass to backend, save network time?
|
||||||
|
// https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
|
||||||
|
if (paintByExampleImage) {
|
||||||
|
fd.append('paintByExampleImage', paintByExampleImage)
|
||||||
|
}
|
||||||
|
|
||||||
if (sizeLimit === undefined) {
|
if (sizeLimit === undefined) {
|
||||||
fd.append('sizeLimit', '1080')
|
fd.append('sizeLimit', '1080')
|
||||||
} else {
|
} else {
|
||||||
|
@ -39,6 +39,7 @@ import {
|
|||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
isInteractiveSegRunningState,
|
isInteractiveSegRunningState,
|
||||||
isInteractiveSegState,
|
isInteractiveSegState,
|
||||||
|
isPaintByExampleState,
|
||||||
isSDState,
|
isSDState,
|
||||||
negativePropmtState,
|
negativePropmtState,
|
||||||
propmtState,
|
propmtState,
|
||||||
@ -53,6 +54,7 @@ import emitter, {
|
|||||||
EVENT_PROMPT,
|
EVENT_PROMPT,
|
||||||
EVENT_CUSTOM_MASK,
|
EVENT_CUSTOM_MASK,
|
||||||
CustomMaskEventData,
|
CustomMaskEventData,
|
||||||
|
EVENT_PAINT_BY_EXAMPLE,
|
||||||
} from '../../event'
|
} from '../../event'
|
||||||
import FileSelect from '../FileSelect/FileSelect'
|
import FileSelect from '../FileSelect/FileSelect'
|
||||||
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
||||||
@ -108,6 +110,7 @@ export default function Editor() {
|
|||||||
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
||||||
const runMannually = useRecoilValue(runManuallyState)
|
const runMannually = useRecoilValue(runManuallyState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
|
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||||
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
||||||
isInteractiveSegState
|
isInteractiveSegState
|
||||||
)
|
)
|
||||||
@ -262,8 +265,11 @@ export default function Editor() {
|
|||||||
async (
|
async (
|
||||||
useLastLineGroup?: boolean,
|
useLastLineGroup?: boolean,
|
||||||
customMask?: File,
|
customMask?: File,
|
||||||
maskImage?: HTMLImageElement | null
|
maskImage?: HTMLImageElement | null,
|
||||||
|
paintByExampleImage?: File
|
||||||
) => {
|
) => {
|
||||||
|
// customMask: mask uploaded by user
|
||||||
|
// maskImage: mask from interactive segmentation
|
||||||
if (file === undefined) {
|
if (file === undefined) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -328,9 +334,6 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
|
|
||||||
|
|
||||||
console.log({ useCustomMask })
|
|
||||||
try {
|
try {
|
||||||
const res = await inpaint(
|
const res = await inpaint(
|
||||||
targetFile,
|
targetFile,
|
||||||
@ -339,15 +342,16 @@ export default function Editor() {
|
|||||||
promptVal,
|
promptVal,
|
||||||
negativePromptVal,
|
negativePromptVal,
|
||||||
sizeLimit.toString(),
|
sizeLimit.toString(),
|
||||||
sdSeed,
|
seedVal,
|
||||||
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
||||||
useCustomMask ? customMask : undefined
|
useCustomMask ? customMask : undefined,
|
||||||
|
paintByExampleImage
|
||||||
)
|
)
|
||||||
if (!res) {
|
if (!res) {
|
||||||
throw new Error('Something went wrong on server side.')
|
throw new Error('Something went wrong on server side.')
|
||||||
}
|
}
|
||||||
const { blob, seed } = res
|
const { blob, seed } = res
|
||||||
if (seed && !settings.sdSeedFixed) {
|
if (seed) {
|
||||||
setSeed(parseInt(seed, 10))
|
setSeed(parseInt(seed, 10))
|
||||||
}
|
}
|
||||||
const newRender = new Image()
|
const newRender = new Image()
|
||||||
@ -395,6 +399,7 @@ export default function Editor() {
|
|||||||
drawOnCurrentRender,
|
drawOnCurrentRender,
|
||||||
hadDrawSomething,
|
hadDrawSomething,
|
||||||
drawLinesOnMask,
|
drawLinesOnMask,
|
||||||
|
seedVal,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -439,6 +444,31 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
}, [runInpainting])
|
}, [runInpainting])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
|
||||||
|
if (hadDrawSomething() || interactiveSegMask) {
|
||||||
|
runInpainting(false, undefined, interactiveSegMask, data.image)
|
||||||
|
} else if (lastLineGroup.length !== 0) {
|
||||||
|
// 使用上一次手绘的 mask 生成
|
||||||
|
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
|
||||||
|
} else if (prevInteractiveSegMask) {
|
||||||
|
// 使用上一次 IS 的 mask 生成
|
||||||
|
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
|
||||||
|
} else {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: 'Please draw mask on picture',
|
||||||
|
state: 'error',
|
||||||
|
duration: 1500,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
emitter.off(EVENT_PAINT_BY_EXAMPLE)
|
||||||
|
}
|
||||||
|
}, [runInpainting])
|
||||||
|
|
||||||
const hadRunInpainting = () => {
|
const hadRunInpainting = () => {
|
||||||
return renders.length !== 0
|
return renders.length !== 0
|
||||||
}
|
}
|
||||||
@ -793,7 +823,11 @@ export default function Editor() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) {
|
if (
|
||||||
|
(isSD || isPaintByExample) &&
|
||||||
|
settings.showCroper &&
|
||||||
|
isOutsideCroper(mouseXY(ev))
|
||||||
|
) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -876,7 +910,12 @@ export default function Editor() {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
|
useKey(undoPredicate, undo, undefined, [
|
||||||
|
undoStroke,
|
||||||
|
undoRender,
|
||||||
|
runMannually,
|
||||||
|
curLineGroup,
|
||||||
|
])
|
||||||
|
|
||||||
const disableUndo = () => {
|
const disableUndo = () => {
|
||||||
if (isInteractiveSeg) {
|
if (isInteractiveSeg) {
|
||||||
@ -955,7 +994,12 @@ export default function Editor() {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
|
useKey(redoPredicate, redo, undefined, [
|
||||||
|
redoStroke,
|
||||||
|
redoRender,
|
||||||
|
runMannually,
|
||||||
|
redoCurLines,
|
||||||
|
])
|
||||||
|
|
||||||
const disableRedo = () => {
|
const disableRedo = () => {
|
||||||
if (isInteractiveSeg) {
|
if (isInteractiveSeg) {
|
||||||
@ -1295,7 +1339,7 @@ export default function Editor() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{isSD && settings.showCroper ? (
|
{(isSD || isPaintByExample) && settings.showCroper ? (
|
||||||
<Croper
|
<Croper
|
||||||
maxHeight={original.naturalHeight}
|
maxHeight={original.naturalHeight}
|
||||||
maxWidth={original.naturalWidth}
|
maxWidth={original.naturalWidth}
|
||||||
@ -1358,7 +1402,7 @@ export default function Editor() {
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="editor-toolkit-panel">
|
<div className="editor-toolkit-panel">
|
||||||
{isSD || file === undefined ? (
|
{isSD || isPaintByExample || file === undefined ? (
|
||||||
<></>
|
<></>
|
||||||
) : (
|
) : (
|
||||||
<SizeSelector
|
<SizeSelector
|
||||||
@ -1466,7 +1510,7 @@ export default function Editor() {
|
|||||||
onClick={download}
|
onClick={download}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{settings.runInpaintingManually && !isSD && (
|
{settings.runInpaintingManually && !isSD && !isPaintByExample && (
|
||||||
<Button
|
<Button
|
||||||
toolTip="Run Inpainting"
|
toolTip="Run Inpainting"
|
||||||
tooltipPosition="top"
|
tooltipPosition="top"
|
||||||
|
@ -193,6 +193,8 @@ function ModelSettingBlock() {
|
|||||||
return undefined
|
return undefined
|
||||||
case AIModel.SD2:
|
case AIModel.SD2:
|
||||||
return undefined
|
return undefined
|
||||||
|
case AIModel.PAINT_BY_EXAMPLE:
|
||||||
|
return undefined
|
||||||
case AIModel.Mange:
|
case AIModel.Mange:
|
||||||
return undefined
|
return undefined
|
||||||
case AIModel.CV2:
|
case AIModel.CV2:
|
||||||
@ -258,6 +260,12 @@ function ModelSettingBlock() {
|
|||||||
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html',
|
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html',
|
||||||
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html'
|
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html'
|
||||||
)
|
)
|
||||||
|
case AIModel.PAINT_BY_EXAMPLE:
|
||||||
|
return renderModelDesc(
|
||||||
|
'Paint by Example',
|
||||||
|
'https://arxiv.org/abs/2211.13227',
|
||||||
|
'https://github.com/Fantasy-Studio/Paint-by-Example'
|
||||||
|
)
|
||||||
default:
|
default:
|
||||||
return <></>
|
return <></>
|
||||||
}
|
}
|
||||||
@ -270,7 +278,6 @@ function ModelSettingBlock() {
|
|||||||
titleSuffix={renderPaperCodeBadge()}
|
titleSuffix={renderPaperCodeBadge()}
|
||||||
input={
|
input={
|
||||||
<Selector
|
<Selector
|
||||||
width={80}
|
|
||||||
value={setting.model as string}
|
value={setting.model as string}
|
||||||
options={Object.values(AIModel)}
|
options={Object.values(AIModel)}
|
||||||
onChange={val => onModelChange(val as AIModel)}
|
onChange={val => onModelChange(val as AIModel)}
|
||||||
|
231
lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
Normal file
231
lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
|
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||||
|
import { useToggle } from 'react-use'
|
||||||
|
import { UploadIcon } from '@radix-ui/react-icons'
|
||||||
|
import {
|
||||||
|
isInpaintingState,
|
||||||
|
paintByExampleImageState,
|
||||||
|
settingState,
|
||||||
|
} from '../../store/Atoms'
|
||||||
|
import NumberInputSetting from '../Settings/NumberInputSetting'
|
||||||
|
import SettingBlock from '../Settings/SettingBlock'
|
||||||
|
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||||
|
import Button from '../shared/Button'
|
||||||
|
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
|
||||||
|
import { useImage } from '../../utils'
|
||||||
|
|
||||||
|
const INPUT_WIDTH = 30
|
||||||
|
|
||||||
|
const PESidePanel = () => {
|
||||||
|
const [open, toggleOpen] = useToggle(true)
|
||||||
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
const [paintByExampleImage, setPaintByExampleImage] = useRecoilState(
|
||||||
|
paintByExampleImageState
|
||||||
|
)
|
||||||
|
const [uploadElemId] = useState(
|
||||||
|
`example-file-upload-${Math.random().toString()}`
|
||||||
|
)
|
||||||
|
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleImage)
|
||||||
|
const isInpainting = useRecoilValue(isInpaintingState)
|
||||||
|
|
||||||
|
const renderUploadIcon = () => {
|
||||||
|
return (
|
||||||
|
<label htmlFor={uploadElemId}>
|
||||||
|
<Button
|
||||||
|
border
|
||||||
|
toolTip="Upload example image"
|
||||||
|
tooltipPosition="top"
|
||||||
|
icon={<UploadIcon />}
|
||||||
|
style={{ padding: '0.3rem', gap: 0 }}
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
style={{ display: 'none' }}
|
||||||
|
id={uploadElemId}
|
||||||
|
name={uploadElemId}
|
||||||
|
type="file"
|
||||||
|
onChange={ev => {
|
||||||
|
const newFile = ev.currentTarget.files?.[0]
|
||||||
|
if (newFile) {
|
||||||
|
setPaintByExampleImage(newFile)
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
accept="image/png, image/jpeg"
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
</label>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="side-panel">
|
||||||
|
<PopoverPrimitive.Root open={open}>
|
||||||
|
<PopoverPrimitive.Trigger
|
||||||
|
className="btn-primary side-panel-trigger"
|
||||||
|
onClick={() => toggleOpen()}
|
||||||
|
>
|
||||||
|
Configurations
|
||||||
|
</PopoverPrimitive.Trigger>
|
||||||
|
<PopoverPrimitive.Portal>
|
||||||
|
<PopoverPrimitive.Content className="side-panel-content">
|
||||||
|
<SettingBlock
|
||||||
|
title="Croper"
|
||||||
|
input={
|
||||||
|
<Switch
|
||||||
|
checked={setting.showCroper}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, showCroper: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Steps"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
value={`${setting.paintByExampleSteps}`}
|
||||||
|
desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleSteps: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Guidance Scale"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
allowFloat
|
||||||
|
value={`${setting.paintByExampleGuidanceScale}`}
|
||||||
|
desc="Higher guidance scale encourages to generate images that are close to the example image"
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleGuidanceScale: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Mask Blur"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
value={`${setting.paintByExampleMaskBlur}`}
|
||||||
|
desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleMaskBlur: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
title="Match Histograms"
|
||||||
|
desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
|
||||||
|
input={
|
||||||
|
<Switch
|
||||||
|
checked={setting.paintByExampleMatchHistograms}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleMatchHistograms: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
title="Seed"
|
||||||
|
input={
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
gap: 0,
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{/* 每次会从服务器返回更新该值 */}
|
||||||
|
<NumberInputSetting
|
||||||
|
title=""
|
||||||
|
width={80}
|
||||||
|
value={`${setting.paintByExampleSeed}`}
|
||||||
|
desc=""
|
||||||
|
disable={!setting.paintByExampleSeedFixed}
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleSeed: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
checked={setting.paintByExampleSeedFixed}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, paintByExampleSeedFixed: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
style={{ marginLeft: '8px' }}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div style={{ display: 'flex', flexDirection: 'column' }}>
|
||||||
|
<SettingBlock title="Example Image" input={renderUploadIcon()} />
|
||||||
|
|
||||||
|
{paintByExampleImage ? (
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
src={exampleImage.src}
|
||||||
|
alt="example"
|
||||||
|
style={{
|
||||||
|
maxWidth: 200,
|
||||||
|
maxHeight: 200,
|
||||||
|
margin: 12,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<></>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
border
|
||||||
|
disabled={!isExampleImageLoaded || isInpainting}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
onClick={() => {
|
||||||
|
if (isExampleImageLoaded) {
|
||||||
|
emitter.emit(EVENT_PAINT_BY_EXAMPLE, {
|
||||||
|
image: paintByExampleImage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Paint
|
||||||
|
</Button>
|
||||||
|
</PopoverPrimitive.Content>
|
||||||
|
</PopoverPrimitive.Portal>
|
||||||
|
</PopoverPrimitive.Root>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default PESidePanel
|
@ -7,6 +7,7 @@ import Toast from './shared/Toast'
|
|||||||
import {
|
import {
|
||||||
AIModel,
|
AIModel,
|
||||||
fileState,
|
fileState,
|
||||||
|
isPaintByExampleState,
|
||||||
isSDState,
|
isSDState,
|
||||||
settingState,
|
settingState,
|
||||||
toastState,
|
toastState,
|
||||||
@ -17,12 +18,14 @@ import {
|
|||||||
switchModel,
|
switchModel,
|
||||||
} from '../adapters/inpainting'
|
} from '../adapters/inpainting'
|
||||||
import SidePanel from './SidePanel/SidePanel'
|
import SidePanel from './SidePanel/SidePanel'
|
||||||
|
import PESidePanel from './SidePanel/PESidePanel'
|
||||||
|
|
||||||
const Workspace = () => {
|
const Workspace = () => {
|
||||||
const [file, setFile] = useRecoilState(fileState)
|
const [file, setFile] = useRecoilState(fileState)
|
||||||
const [settings, setSettingState] = useRecoilState(settingState)
|
const [settings, setSettingState] = useRecoilState(settingState)
|
||||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
|
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||||
|
|
||||||
const onSettingClose = async () => {
|
const onSettingClose = async () => {
|
||||||
const curModel = await currentModel().then(res => res.text())
|
const curModel = await currentModel().then(res => res.text())
|
||||||
@ -88,6 +91,7 @@ const Workspace = () => {
|
|||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{isSD ? <SidePanel /> : <></>}
|
{isSD ? <SidePanel /> : <></>}
|
||||||
|
{isPaintByExample ? <PESidePanel /> : <></>}
|
||||||
<Editor />
|
<Editor />
|
||||||
<SettingModal onClose={onSettingClose} />
|
<SettingModal onClose={onSettingClose} />
|
||||||
<ShortcutsModal />
|
<ShortcutsModal />
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
import mitt from 'mitt'
|
import mitt from 'mitt'
|
||||||
|
|
||||||
export const EVENT_PROMPT = 'prompt'
|
export const EVENT_PROMPT = 'prompt'
|
||||||
|
|
||||||
export const EVENT_CUSTOM_MASK = 'custom_mask'
|
export const EVENT_CUSTOM_MASK = 'custom_mask'
|
||||||
export interface CustomMaskEventData {
|
export interface CustomMaskEventData {
|
||||||
mask: File
|
mask: File
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const EVENT_PAINT_BY_EXAMPLE = 'paint_by_example'
|
||||||
|
export interface PaintByExampleEventData {
|
||||||
|
image: File
|
||||||
|
}
|
||||||
|
|
||||||
const emitter = mitt()
|
const emitter = mitt()
|
||||||
|
|
||||||
export default emitter
|
export default emitter
|
||||||
|
@ -13,6 +13,7 @@ export enum AIModel {
|
|||||||
SD2 = 'sd2',
|
SD2 = 'sd2',
|
||||||
CV2 = 'cv2',
|
CV2 = 'cv2',
|
||||||
Mange = 'manga',
|
Mange = 'manga',
|
||||||
|
PAINT_BY_EXAMPLE = 'paint_by_example',
|
||||||
}
|
}
|
||||||
|
|
||||||
export const maskState = atom<File | undefined>({
|
export const maskState = atom<File | undefined>({
|
||||||
@ -20,6 +21,11 @@ export const maskState = atom<File | undefined>({
|
|||||||
default: undefined,
|
default: undefined,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const paintByExampleImageState = atom<File | undefined>({
|
||||||
|
key: 'paintByExampleImageState',
|
||||||
|
default: undefined,
|
||||||
|
})
|
||||||
|
|
||||||
export interface Rect {
|
export interface Rect {
|
||||||
x: number
|
x: number
|
||||||
y: number
|
y: number
|
||||||
@ -252,6 +258,14 @@ export interface Settings {
|
|||||||
// For OpenCV2
|
// For OpenCV2
|
||||||
cv2Radius: number
|
cv2Radius: number
|
||||||
cv2Flag: CV2Flag
|
cv2Flag: CV2Flag
|
||||||
|
|
||||||
|
// Paint by Example
|
||||||
|
paintByExampleSteps: number
|
||||||
|
paintByExampleGuidanceScale: number
|
||||||
|
paintByExampleSeed: number
|
||||||
|
paintByExampleSeedFixed: boolean
|
||||||
|
paintByExampleMaskBlur: number
|
||||||
|
paintByExampleMatchHistograms: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultHDSettings: ModelsHDSettings = {
|
const defaultHDSettings: ModelsHDSettings = {
|
||||||
@ -304,6 +318,13 @@ const defaultHDSettings: ModelsHDSettings = {
|
|||||||
hdStrategyCropMargin: 128,
|
hdStrategyCropMargin: 128,
|
||||||
enabled: false,
|
enabled: false,
|
||||||
},
|
},
|
||||||
|
[AIModel.PAINT_BY_EXAMPLE]: {
|
||||||
|
hdStrategy: HDStrategy.ORIGINAL,
|
||||||
|
hdStrategyResizeLimit: 768,
|
||||||
|
hdStrategyCropTrigerSize: 512,
|
||||||
|
hdStrategyCropMargin: 128,
|
||||||
|
enabled: false,
|
||||||
|
},
|
||||||
[AIModel.Mange]: {
|
[AIModel.Mange]: {
|
||||||
hdStrategy: HDStrategy.CROP,
|
hdStrategy: HDStrategy.CROP,
|
||||||
hdStrategyResizeLimit: 1280,
|
hdStrategyResizeLimit: 1280,
|
||||||
@ -364,6 +385,14 @@ export const settingStateDefault: Settings = {
|
|||||||
// CV2
|
// CV2
|
||||||
cv2Radius: 5,
|
cv2Radius: 5,
|
||||||
cv2Flag: CV2Flag.INPAINT_NS,
|
cv2Flag: CV2Flag.INPAINT_NS,
|
||||||
|
|
||||||
|
// Paint by Example
|
||||||
|
paintByExampleSteps: 50,
|
||||||
|
paintByExampleGuidanceScale: 7.5,
|
||||||
|
paintByExampleSeed: 42,
|
||||||
|
paintByExampleMaskBlur: 5,
|
||||||
|
paintByExampleSeedFixed: false,
|
||||||
|
paintByExampleMatchHistograms: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
const localStorageEffect =
|
const localStorageEffect =
|
||||||
@ -401,11 +430,28 @@ export const seedState = selector({
|
|||||||
key: 'seed',
|
key: 'seed',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
const settings = get(settingState)
|
const settings = get(settingState)
|
||||||
return settings.sdSeed
|
switch (settings.model) {
|
||||||
|
case AIModel.PAINT_BY_EXAMPLE:
|
||||||
|
return settings.paintByExampleSeedFixed
|
||||||
|
? settings.paintByExampleSeed
|
||||||
|
: -1
|
||||||
|
default:
|
||||||
|
return settings.sdSeedFixed ? settings.sdSeed : -1
|
||||||
|
}
|
||||||
},
|
},
|
||||||
set: ({ get, set }, newValue: any) => {
|
set: ({ get, set }, newValue: any) => {
|
||||||
const settings = get(settingState)
|
const settings = get(settingState)
|
||||||
set(settingState, { ...settings, sdSeed: newValue })
|
switch (settings.model) {
|
||||||
|
case AIModel.PAINT_BY_EXAMPLE:
|
||||||
|
if (!settings.paintByExampleSeedFixed) {
|
||||||
|
set(settingState, { ...settings, paintByExampleSeed: newValue })
|
||||||
|
}
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
if (!settings.sdSeedFixed) {
|
||||||
|
set(settingState, { ...settings, sdSeed: newValue })
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -435,11 +481,20 @@ export const isSDState = selector({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const isPaintByExampleState = selector({
|
||||||
|
key: 'isPaintByExampleState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const settings = get(settingState)
|
||||||
|
return settings.model === AIModel.PAINT_BY_EXAMPLE
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const runManuallyState = selector({
|
export const runManuallyState = selector({
|
||||||
key: 'runManuallyState',
|
key: 'runManuallyState',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
const settings = get(settingState)
|
const settings = get(settingState)
|
||||||
const isSD = get(isSDState)
|
const isSD = get(isSDState)
|
||||||
return settings.runInpaintingManually || isSD
|
const isPaintByExample = get(isPaintByExampleState)
|
||||||
|
return settings.runInpaintingManually || isSD || isPaintByExample
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -211,6 +211,26 @@ class InpaintModel:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _apply_cropper(self, image, mask, config: Config):
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
l, t, w, h = (
|
||||||
|
config.croper_x,
|
||||||
|
config.croper_y,
|
||||||
|
config.croper_width,
|
||||||
|
config.croper_height,
|
||||||
|
)
|
||||||
|
r = l + w
|
||||||
|
b = t + h
|
||||||
|
|
||||||
|
l = max(l, 0)
|
||||||
|
r = min(r, img_w)
|
||||||
|
t = max(t, 0)
|
||||||
|
b = min(b, img_h)
|
||||||
|
|
||||||
|
crop_img = image[t:b, l:r, :]
|
||||||
|
crop_mask = mask[t:b, l:r]
|
||||||
|
return crop_img, crop_mask, (l, t, r, b)
|
||||||
|
|
||||||
def _run_box(self, image, mask, box, config: Config):
|
def _run_box(self, image, mask, box, config: Config):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -0,0 +1,80 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
|
||||||
|
class PaintByExample(InpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 512
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
|
torch_dtype = torch.float16 if use_gpu else torch.float32
|
||||||
|
self.model = DiffusionPipeline.from_pretrained(
|
||||||
|
"Fantasy-Studio/Paint-by-Example",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
self.model.enable_attention_slicing()
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: Config):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
seed = config.paint_by_example_seed
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
output = self.model(
|
||||||
|
image=PIL.Image.fromarray(image),
|
||||||
|
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||||
|
example_image=config.paint_by_example_example_image,
|
||||||
|
num_inference_steps=config.paint_by_example_steps,
|
||||||
|
output_type='np.array',
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, image, mask, config: Config):
|
||||||
|
"""
|
||||||
|
images: [H, W, C] RGB, not normalized
|
||||||
|
masks: [H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
if config.use_croper:
|
||||||
|
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||||
|
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||||
|
inpaint_result = image[:, :, ::-1]
|
||||||
|
inpaint_result[t:b, l:r, :] = crop_image
|
||||||
|
else:
|
||||||
|
inpaint_result = self._pad_forward(image, mask, config)
|
||||||
|
|
||||||
|
return inpaint_result
|
||||||
|
|
||||||
|
def forward_post_process(self, result, image, mask, config):
|
||||||
|
if config.paint_by_example_match_histograms:
|
||||||
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
|
||||||
|
if config.paint_by_example_mask_blur != 0:
|
||||||
|
k = 2 * config.paint_by_example_mask_blur + 1
|
||||||
|
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||||
|
return result, image, mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_downloaded() -> bool:
|
||||||
|
# model will be downloaded when app start, and can't switch in frontend settings
|
||||||
|
return True
|
@ -12,31 +12,6 @@ from lama_cleaner.model.base import InpaintModel
|
|||||||
from lama_cleaner.schema import Config, SDSampler
|
from lama_cleaner.schema import Config, SDSampler
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# def preprocess_image(image):
|
|
||||||
# w, h = image.size
|
|
||||||
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
|
||||||
# image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
|
||||||
# image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
# image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
# image = torch.from_numpy(image)
|
|
||||||
# # [-1, 1]
|
|
||||||
# return 2.0 * image - 1.0
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# def preprocess_mask(mask):
|
|
||||||
# mask = mask.convert("L")
|
|
||||||
# w, h = mask.size
|
|
||||||
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
|
||||||
# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
|
|
||||||
# mask = np.array(mask).astype(np.float32) / 255.0
|
|
||||||
# mask = np.tile(mask, (4, 1, 1))
|
|
||||||
# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
|
||||||
# mask = 1 - mask # repaint white, keep black
|
|
||||||
# mask = torch.from_numpy(mask)
|
|
||||||
# return mask
|
|
||||||
|
|
||||||
class CPUTextEncoderWrapper:
|
class CPUTextEncoderWrapper:
|
||||||
def __init__(self, text_encoder, torch_dtype):
|
def __init__(self, text_encoder, torch_dtype):
|
||||||
self.config = text_encoder.config
|
self.config = text_encoder.config
|
||||||
@ -92,17 +67,6 @@ class SD(InpaintModel):
|
|||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# image = norm_img(image) # [0, 1]
|
|
||||||
# image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
|
||||||
|
|
||||||
# resize to latent feature map size
|
|
||||||
# h, w = mask.shape[:2]
|
|
||||||
# mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
|
|
||||||
# mask = norm_img(mask)
|
|
||||||
#
|
|
||||||
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
|
||||||
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
|
||||||
|
|
||||||
scheduler_config = self.model.scheduler.config
|
scheduler_config = self.model.scheduler.config
|
||||||
|
|
||||||
if config.sd_sampler == SDSampler.ddim:
|
if config.sd_sampler == SDSampler.ddim:
|
||||||
@ -139,7 +103,6 @@ class SD(InpaintModel):
|
|||||||
prompt=config.prompt,
|
prompt=config.prompt,
|
||||||
negative_prompt=config.negative_prompt,
|
negative_prompt=config.negative_prompt,
|
||||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||||
strength=config.sd_strength,
|
|
||||||
num_inference_steps=config.sd_steps,
|
num_inference_steps=config.sd_steps,
|
||||||
guidance_scale=config.sd_guidance_scale,
|
guidance_scale=config.sd_guidance_scale,
|
||||||
output_type="np.array",
|
output_type="np.array",
|
||||||
@ -159,30 +122,10 @@ class SD(InpaintModel):
|
|||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
img_h, img_w = image.shape[:2]
|
|
||||||
|
|
||||||
# boxes = boxes_from_mask(mask)
|
# boxes = boxes_from_mask(mask)
|
||||||
if config.use_croper:
|
if config.use_croper:
|
||||||
logger.info("use croper")
|
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||||
l, t, w, h = (
|
|
||||||
config.croper_x,
|
|
||||||
config.croper_y,
|
|
||||||
config.croper_width,
|
|
||||||
config.croper_height,
|
|
||||||
)
|
|
||||||
r = l + w
|
|
||||||
b = t + h
|
|
||||||
|
|
||||||
l = max(l, 0)
|
|
||||||
r = min(r, img_w)
|
|
||||||
t = max(t, 0)
|
|
||||||
b = min(b, img_h)
|
|
||||||
|
|
||||||
crop_img = image[t:b, l:r, :]
|
|
||||||
crop_mask = mask[t:b, l:r]
|
|
||||||
|
|
||||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||||
|
|
||||||
inpaint_result = image[:, :, ::-1]
|
inpaint_result = image[:, :, ::-1]
|
||||||
inpaint_result[t:b, l:r, :] = crop_image
|
inpaint_result[t:b, l:r, :] = crop_image
|
||||||
else:
|
else:
|
||||||
|
@ -5,13 +5,14 @@ from lama_cleaner.model.lama import LaMa
|
|||||||
from lama_cleaner.model.ldm import LDM
|
from lama_cleaner.model.ldm import LDM
|
||||||
from lama_cleaner.model.manga import Manga
|
from lama_cleaner.model.manga import Manga
|
||||||
from lama_cleaner.model.mat import MAT
|
from lama_cleaner.model.mat import MAT
|
||||||
|
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||||
from lama_cleaner.model.sd import SD15, SD2
|
from lama_cleaner.model.sd import SD15, SD2
|
||||||
from lama_cleaner.model.zits import ZITS
|
from lama_cleaner.model.zits import ZITS
|
||||||
from lama_cleaner.model.opencv2 import OpenCV2
|
from lama_cleaner.model.opencv2 import OpenCV2
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
||||||
"sd2": SD2}
|
"sd2": SD2, "paint_by_example": PaintByExample}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
|
@ -10,7 +10,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="lama",
|
default="lama",
|
||||||
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2"],
|
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf_access_token",
|
"--hf_access_token",
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@ -29,6 +30,9 @@ class SDSampler(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
# Configs for ldm model
|
# Configs for ldm model
|
||||||
ldm_steps: int
|
ldm_steps: int
|
||||||
ldm_sampler: str = LDMSampler.plms
|
ldm_sampler: str = LDMSampler.plms
|
||||||
@ -73,3 +77,11 @@ class Config(BaseModel):
|
|||||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||||
cv2_flag: str = 'INPAINT_NS'
|
cv2_flag: str = 'INPAINT_NS'
|
||||||
cv2_radius: int = 4
|
cv2_radius: int = 4
|
||||||
|
|
||||||
|
# Paint by Example
|
||||||
|
paint_by_example_steps: int = 50
|
||||||
|
paint_by_example_guidance_scale: float = 7.5
|
||||||
|
paint_by_example_mask_blur: int = 0
|
||||||
|
paint_by_example_seed: int = 42
|
||||||
|
paint_by_example_match_histograms: bool = False
|
||||||
|
paint_by_example_example_image: Image = None
|
||||||
|
@ -10,6 +10,7 @@ import time
|
|||||||
import imghdr
|
import imghdr
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
@ -97,8 +98,8 @@ def process():
|
|||||||
input = request.files
|
input = request.files
|
||||||
# RGB
|
# RGB
|
||||||
origin_image_bytes = input["image"].read()
|
origin_image_bytes = input["image"].read()
|
||||||
|
|
||||||
image, alpha_channel = load_img(origin_image_bytes)
|
image, alpha_channel = load_img(origin_image_bytes)
|
||||||
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||||
|
|
||||||
@ -115,6 +116,12 @@ def process():
|
|||||||
else:
|
else:
|
||||||
size_limit = int(size_limit)
|
size_limit = int(size_limit)
|
||||||
|
|
||||||
|
if "paintByExampleImage" in input:
|
||||||
|
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
|
||||||
|
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
||||||
|
else:
|
||||||
|
paint_by_example_example_image = None
|
||||||
|
|
||||||
config = Config(
|
config = Config(
|
||||||
ldm_steps=form["ldmSteps"],
|
ldm_steps=form["ldmSteps"],
|
||||||
ldm_sampler=form["ldmSampler"],
|
ldm_sampler=form["ldmSampler"],
|
||||||
@ -138,11 +145,19 @@ def process():
|
|||||||
sd_seed=form["sdSeed"],
|
sd_seed=form["sdSeed"],
|
||||||
sd_match_histograms=form["sdMatchHistograms"],
|
sd_match_histograms=form["sdMatchHistograms"],
|
||||||
cv2_flag=form["cv2Flag"],
|
cv2_flag=form["cv2Flag"],
|
||||||
cv2_radius=form['cv2Radius']
|
cv2_radius=form['cv2Radius'],
|
||||||
|
paint_by_example_steps=form["paintByExampleSteps"],
|
||||||
|
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
|
||||||
|
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
|
||||||
|
paint_by_example_seed=form["paintByExampleSeed"],
|
||||||
|
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
|
||||||
|
paint_by_example_example_image=paint_by_example_example_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.sd_seed == -1:
|
if config.sd_seed == -1:
|
||||||
config.sd_seed = random.randint(1, 999999999)
|
config.sd_seed = random.randint(1, 999999999)
|
||||||
|
if config.paint_by_example_seed == -1:
|
||||||
|
config.paint_by_example_seed = random.randint(1, 999999999)
|
||||||
|
|
||||||
logger.info(f"Origin image shape: {original_shape}")
|
logger.info(f"Origin image shape: {original_shape}")
|
||||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||||
|
BIN
lama_cleaner/tests/bunny.jpeg
Normal file
BIN
lama_cleaner/tests/bunny.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 51 KiB |
BIN
lama_cleaner/tests/result/paint_by_example_Original.png
Normal file
BIN
lama_cleaner/tests/result/paint_by_example_Original.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 517 KiB |
@ -0,0 +1,50 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import HDStrategy
|
||||||
|
from lama_cleaner.tests.test_model import get_config, get_data
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / 'result'
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
device = torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_equal(
|
||||||
|
model, config, gt_name,
|
||||||
|
fx: float = 1, fy: float = 1,
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
example_p=current_dir / "rabbit.jpeg",
|
||||||
|
):
|
||||||
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
|
|
||||||
|
example_image = cv2.imread(str(example_p))
|
||||||
|
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
|
||||||
|
example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
|
||||||
|
config.paint_by_example_example_image = Image.fromarray(example_image)
|
||||||
|
res = model(img, mask, config)
|
||||||
|
cv2.imwrite(str(save_dir / gt_name), res)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
def test_paint_by_example(strategy):
|
||||||
|
model = ModelManager(name="paint_by_example", device=device)
|
||||||
|
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"paint_by_example_{strategy.capitalize()}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fy=0.9,
|
||||||
|
fx=1.3
|
||||||
|
)
|
@ -10,5 +10,5 @@ pytest
|
|||||||
yacs
|
yacs
|
||||||
markupsafe==2.0.1
|
markupsafe==2.0.1
|
||||||
scikit-image==0.19.3
|
scikit-image==0.19.3
|
||||||
diffusers[torch]==0.9
|
diffusers[torch]==0.10.2
|
||||||
transformers==4.21.0
|
transformers>=4.25.1
|
||||||
|
Loading…
Reference in New Issue
Block a user