Merge branch 'add_paint_by_example'
This commit is contained in:
commit
c79778f492
@ -1,17 +1,17 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.bb67386a.chunk.css",
|
||||
"main.js": "/static/js/main.5cf6948e.chunk.js",
|
||||
"main.css": "/static/css/main.b30da02b.chunk.css",
|
||||
"main.js": "/static/js/main.994b5b32.chunk.js",
|
||||
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
|
||||
"static/js/2.ee9dcc6c.chunk.js": "/static/js/2.ee9dcc6c.chunk.js",
|
||||
"static/js/2.ada71d88.chunk.js": "/static/js/2.ada71d88.chunk.js",
|
||||
"index.html": "/index.html",
|
||||
"static/js/2.ee9dcc6c.chunk.js.LICENSE.txt": "/static/js/2.ee9dcc6c.chunk.js.LICENSE.txt",
|
||||
"static/js/2.ada71d88.chunk.js.LICENSE.txt": "/static/js/2.ada71d88.chunk.js.LICENSE.txt",
|
||||
"static/media/_index.scss": "/static/media/WorkSans-SemiBold.1e98db4e.ttf"
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/js/runtime-main.5e86ac81.js",
|
||||
"static/js/2.ee9dcc6c.chunk.js",
|
||||
"static/css/main.bb67386a.chunk.css",
|
||||
"static/js/main.5cf6948e.chunk.js"
|
||||
"static/js/2.ada71d88.chunk.js",
|
||||
"static/css/main.b30da02b.chunk.css",
|
||||
"static/js/main.994b5b32.chunk.js"
|
||||
]
|
||||
}
|
@ -1 +1 @@
|
||||
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><link href="/static/css/main.bb67386a.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.ee9dcc6c.chunk.js"></script><script src="/static/js/main.5cf6948e.chunk.js"></script></body></html>
|
||||
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><link href="/static/css/main.b30da02b.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.ada71d88.chunk.js"></script><script src="/static/js/main.994b5b32.chunk.js"></script></body></html>
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
lama_cleaner/app/build/static/js/2.ada71d88.chunk.js
Normal file
2
lama_cleaner/app/build/static/js/2.ada71d88.chunk.js
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
lama_cleaner/app/build/static/js/main.994b5b32.chunk.js
Normal file
1
lama_cleaner/app/build/static/js/main.994b5b32.chunk.js
Normal file
File diff suppressed because one or more lines are too long
@ -12,7 +12,8 @@ export default async function inpaint(
|
||||
sizeLimit?: string,
|
||||
seed?: number,
|
||||
maskBase64?: string,
|
||||
customMask?: File
|
||||
customMask?: File,
|
||||
paintByExampleImage?: File
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
const fd = new FormData()
|
||||
@ -48,6 +49,7 @@ export default async function inpaint(
|
||||
fd.append('croperHeight', croperRect.height.toString())
|
||||
fd.append('croperWidth', croperRect.width.toString())
|
||||
fd.append('useCroper', settings.showCroper ? 'true' : 'false')
|
||||
|
||||
fd.append('sdMaskBlur', settings.sdMaskBlur.toString())
|
||||
fd.append('sdStrength', settings.sdStrength.toString())
|
||||
fd.append('sdSteps', settings.sdSteps.toString())
|
||||
@ -59,6 +61,26 @@ export default async function inpaint(
|
||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||
|
||||
fd.append('paintByExampleSteps', settings.paintByExampleSteps.toString())
|
||||
fd.append(
|
||||
'paintByExampleGuidanceScale',
|
||||
settings.paintByExampleGuidanceScale.toString()
|
||||
)
|
||||
fd.append('paintByExampleSeed', seed ? seed.toString() : '-1')
|
||||
fd.append(
|
||||
'paintByExampleMaskBlur',
|
||||
settings.paintByExampleMaskBlur.toString()
|
||||
)
|
||||
fd.append(
|
||||
'paintByExampleMatchHistograms',
|
||||
settings.paintByExampleMatchHistograms ? 'true' : 'false'
|
||||
)
|
||||
// TODO: resize image's shortest_edge to 224 before pass to backend, save network time?
|
||||
// https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
|
||||
if (paintByExampleImage) {
|
||||
fd.append('paintByExampleImage', paintByExampleImage)
|
||||
}
|
||||
|
||||
if (sizeLimit === undefined) {
|
||||
fd.append('sizeLimit', '1080')
|
||||
} else {
|
||||
|
@ -30,6 +30,7 @@ interface Props {
|
||||
scale: number
|
||||
minHeight: number
|
||||
minWidth: number
|
||||
show: boolean
|
||||
}
|
||||
|
||||
const clamp = (
|
||||
@ -66,7 +67,7 @@ const clamp = (
|
||||
}
|
||||
|
||||
const Croper = (props: Props) => {
|
||||
const { minHeight, minWidth, maxHeight, maxWidth, scale } = props
|
||||
const { minHeight, minWidth, maxHeight, maxWidth, scale, show } = props
|
||||
const [x, setX] = useRecoilState(croperX)
|
||||
const [y, setY] = useRecoilState(croperY)
|
||||
const [height, setHeight] = useRecoilState(croperHeight)
|
||||
@ -79,7 +80,7 @@ const Croper = (props: Props) => {
|
||||
useEffect(() => {
|
||||
setX(Math.round((maxWidth - 512) / 2))
|
||||
setY(Math.round((maxHeight - 512) / 2))
|
||||
}, [maxHeight, maxWidth, minHeight, minWidth])
|
||||
}, [maxHeight, maxWidth])
|
||||
|
||||
const [evData, setEVData] = useState<EVData>({
|
||||
initX: 0,
|
||||
@ -391,7 +392,10 @@ const Croper = (props: Props) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="croper-wrapper">
|
||||
<div
|
||||
className="croper-wrapper"
|
||||
style={{ visibility: show ? 'visible' : 'hidden' }}
|
||||
>
|
||||
<div className="croper" style={{ height, width, left: x, top: y }}>
|
||||
{createBorder()}
|
||||
{createInfoBar()}
|
||||
|
@ -39,6 +39,7 @@ import {
|
||||
isInpaintingState,
|
||||
isInteractiveSegRunningState,
|
||||
isInteractiveSegState,
|
||||
isPaintByExampleState,
|
||||
isSDState,
|
||||
negativePropmtState,
|
||||
propmtState,
|
||||
@ -53,6 +54,7 @@ import emitter, {
|
||||
EVENT_PROMPT,
|
||||
EVENT_CUSTOM_MASK,
|
||||
CustomMaskEventData,
|
||||
EVENT_PAINT_BY_EXAMPLE,
|
||||
} from '../../event'
|
||||
import FileSelect from '../FileSelect/FileSelect'
|
||||
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
||||
@ -108,6 +110,7 @@ export default function Editor() {
|
||||
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
||||
const runMannually = useRecoilValue(runManuallyState)
|
||||
const isSD = useRecoilValue(isSDState)
|
||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
||||
isInteractiveSegState
|
||||
)
|
||||
@ -262,8 +265,11 @@ export default function Editor() {
|
||||
async (
|
||||
useLastLineGroup?: boolean,
|
||||
customMask?: File,
|
||||
maskImage?: HTMLImageElement | null
|
||||
maskImage?: HTMLImageElement | null,
|
||||
paintByExampleImage?: File
|
||||
) => {
|
||||
// customMask: mask uploaded by user
|
||||
// maskImage: mask from interactive segmentation
|
||||
if (file === undefined) {
|
||||
return
|
||||
}
|
||||
@ -328,9 +334,6 @@ export default function Editor() {
|
||||
}
|
||||
}
|
||||
|
||||
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
|
||||
|
||||
console.log({ useCustomMask })
|
||||
try {
|
||||
const res = await inpaint(
|
||||
targetFile,
|
||||
@ -339,15 +342,16 @@ export default function Editor() {
|
||||
promptVal,
|
||||
negativePromptVal,
|
||||
sizeLimit.toString(),
|
||||
sdSeed,
|
||||
seedVal,
|
||||
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
||||
useCustomMask ? customMask : undefined
|
||||
useCustomMask ? customMask : undefined,
|
||||
paintByExampleImage
|
||||
)
|
||||
if (!res) {
|
||||
throw new Error('Something went wrong on server side.')
|
||||
}
|
||||
const { blob, seed } = res
|
||||
if (seed && !settings.sdSeedFixed) {
|
||||
if (seed) {
|
||||
setSeed(parseInt(seed, 10))
|
||||
}
|
||||
const newRender = new Image()
|
||||
@ -395,6 +399,7 @@ export default function Editor() {
|
||||
drawOnCurrentRender,
|
||||
hadDrawSomething,
|
||||
drawLinesOnMask,
|
||||
seedVal,
|
||||
]
|
||||
)
|
||||
|
||||
@ -431,6 +436,7 @@ export default function Editor() {
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
|
||||
// TODO: not work with paint by example
|
||||
runInpainting(false, data.mask)
|
||||
})
|
||||
|
||||
@ -439,6 +445,31 @@ export default function Editor() {
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
|
||||
if (hadDrawSomething() || interactiveSegMask) {
|
||||
runInpainting(false, undefined, interactiveSegMask, data.image)
|
||||
} else if (lastLineGroup.length !== 0) {
|
||||
// 使用上一次手绘的 mask 生成
|
||||
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
|
||||
} else if (prevInteractiveSegMask) {
|
||||
// 使用上一次 IS 的 mask 生成
|
||||
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
|
||||
} else {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: 'Please draw mask on picture',
|
||||
state: 'error',
|
||||
duration: 1500,
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_PAINT_BY_EXAMPLE)
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
const hadRunInpainting = () => {
|
||||
return renders.length !== 0
|
||||
}
|
||||
@ -793,7 +824,11 @@ export default function Editor() {
|
||||
return
|
||||
}
|
||||
|
||||
if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) {
|
||||
if (
|
||||
(isSD || isPaintByExample) &&
|
||||
settings.showCroper &&
|
||||
isOutsideCroper(mouseXY(ev))
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -876,7 +911,12 @@ export default function Editor() {
|
||||
return false
|
||||
}
|
||||
|
||||
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
|
||||
useKey(undoPredicate, undo, undefined, [
|
||||
undoStroke,
|
||||
undoRender,
|
||||
runMannually,
|
||||
curLineGroup,
|
||||
])
|
||||
|
||||
const disableUndo = () => {
|
||||
if (isInteractiveSeg) {
|
||||
@ -955,7 +995,12 @@ export default function Editor() {
|
||||
return false
|
||||
}
|
||||
|
||||
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
|
||||
useKey(redoPredicate, redo, undefined, [
|
||||
redoStroke,
|
||||
redoRender,
|
||||
runMannually,
|
||||
redoCurLines,
|
||||
])
|
||||
|
||||
const disableRedo = () => {
|
||||
if (isInteractiveSeg) {
|
||||
@ -1295,17 +1340,14 @@ export default function Editor() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isSD && settings.showCroper ? (
|
||||
<Croper
|
||||
maxHeight={original.naturalHeight}
|
||||
maxWidth={original.naturalWidth}
|
||||
minHeight={Math.min(256, original.naturalHeight)}
|
||||
minWidth={Math.min(256, original.naturalWidth)}
|
||||
scale={scale}
|
||||
show={(isSD || isPaintByExample) && settings.showCroper}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
||||
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
|
||||
</TransformComponent>
|
||||
@ -1358,7 +1400,7 @@ export default function Editor() {
|
||||
)}
|
||||
|
||||
<div className="editor-toolkit-panel">
|
||||
{isSD || file === undefined ? (
|
||||
{isSD || isPaintByExample || file === undefined ? (
|
||||
<></>
|
||||
) : (
|
||||
<SizeSelector
|
||||
@ -1466,7 +1508,7 @@ export default function Editor() {
|
||||
onClick={download}
|
||||
/>
|
||||
|
||||
{settings.runInpaintingManually && !isSD && (
|
||||
{settings.runInpaintingManually && !isSD && !isPaintByExample && (
|
||||
<Button
|
||||
toolTip="Run Inpainting"
|
||||
tooltipPosition="top"
|
||||
|
@ -32,3 +32,12 @@ header {
|
||||
gap: 6px;
|
||||
justify-self: end;
|
||||
}
|
||||
|
||||
.mask-preview {
|
||||
max-height: 400px;
|
||||
max-width: 400px;
|
||||
margin-top: 30px;
|
||||
margin-left: 20px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
@ -2,12 +2,13 @@ import { ArrowUpTrayIcon } from '@heroicons/react/24/outline'
|
||||
import { PlayIcon } from '@radix-ui/react-icons'
|
||||
import React, { useState } from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||
import {
|
||||
fileState,
|
||||
interactiveSegClicksState,
|
||||
isInpaintingState,
|
||||
isSDState,
|
||||
maskState,
|
||||
runManuallyState,
|
||||
} from '../../store/Atoms'
|
||||
import Button from '../shared/Button'
|
||||
import Shortcuts from '../Shortcuts/Shortcuts'
|
||||
@ -16,14 +17,18 @@ import SettingIcon from '../Settings/SettingIcon'
|
||||
import PromptInput from './PromptInput'
|
||||
import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
|
||||
import emitter, { EVENT_CUSTOM_MASK } from '../../event'
|
||||
import { useImage } from '../../utils'
|
||||
|
||||
const Header = () => {
|
||||
const isInpainting = useRecoilValue(isInpaintingState)
|
||||
const [file, setFile] = useRecoilState(fileState)
|
||||
const [mask, setMask] = useRecoilState(maskState)
|
||||
const [maskImage, maskImageLoaded] = useImage(mask)
|
||||
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
||||
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
|
||||
const isSD = useRecoilValue(isSDState)
|
||||
const runManually = useRecoilValue(runManuallyState)
|
||||
const [openMaskPopover, setOpenMaskPopover] = useState(false)
|
||||
|
||||
const renderHeader = () => {
|
||||
return (
|
||||
@ -88,10 +93,11 @@ const Header = () => {
|
||||
onChange={ev => {
|
||||
const newFile = ev.currentTarget.files?.[0]
|
||||
if (newFile) {
|
||||
// TODO: check mask size
|
||||
console.info('Send custom mask')
|
||||
emitter.emit(EVENT_CUSTOM_MASK, { mask: newFile })
|
||||
setMask(newFile)
|
||||
console.info('Send custom mask')
|
||||
if (!runManually) {
|
||||
emitter.emit(EVENT_CUSTOM_MASK, { mask: newFile })
|
||||
}
|
||||
}
|
||||
}}
|
||||
accept="image/png, image/jpeg"
|
||||
@ -99,17 +105,42 @@ const Header = () => {
|
||||
Mask
|
||||
</Button>
|
||||
</label>
|
||||
<Button
|
||||
|
||||
<PopoverPrimitive.Root open={openMaskPopover}>
|
||||
<PopoverPrimitive.Trigger
|
||||
className="btn-primary side-panel-trigger"
|
||||
onMouseEnter={() => setOpenMaskPopover(true)}
|
||||
onMouseLeave={() => setOpenMaskPopover(false)}
|
||||
style={{
|
||||
visibility: mask ? 'visible' : 'hidden',
|
||||
outline: 'none',
|
||||
}}
|
||||
icon={<PlayIcon />}
|
||||
onClick={() => {
|
||||
if (mask) {
|
||||
emitter.emit(EVENT_CUSTOM_MASK, { mask })
|
||||
}
|
||||
}}
|
||||
>
|
||||
<PlayIcon />
|
||||
</PopoverPrimitive.Trigger>
|
||||
<PopoverPrimitive.Portal>
|
||||
<PopoverPrimitive.Content
|
||||
style={{
|
||||
outline: 'none',
|
||||
}}
|
||||
>
|
||||
{maskImageLoaded ? (
|
||||
<img
|
||||
src={maskImage.src}
|
||||
alt="mask"
|
||||
className="mask-preview"
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</PopoverPrimitive.Content>
|
||||
</PopoverPrimitive.Portal>
|
||||
</PopoverPrimitive.Root>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
@ -193,6 +193,8 @@ function ModelSettingBlock() {
|
||||
return undefined
|
||||
case AIModel.SD2:
|
||||
return undefined
|
||||
case AIModel.PAINT_BY_EXAMPLE:
|
||||
return undefined
|
||||
case AIModel.Mange:
|
||||
return undefined
|
||||
case AIModel.CV2:
|
||||
@ -258,6 +260,12 @@ function ModelSettingBlock() {
|
||||
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html',
|
||||
'https://docs.opencv.org/4.6.0/df/d3d/tutorial_py_inpainting.html'
|
||||
)
|
||||
case AIModel.PAINT_BY_EXAMPLE:
|
||||
return renderModelDesc(
|
||||
'Paint by Example',
|
||||
'https://arxiv.org/abs/2211.13227',
|
||||
'https://github.com/Fantasy-Studio/Paint-by-Example'
|
||||
)
|
||||
default:
|
||||
return <></>
|
||||
}
|
||||
@ -270,7 +278,6 @@ function ModelSettingBlock() {
|
||||
titleSuffix={renderPaperCodeBadge()}
|
||||
input={
|
||||
<Selector
|
||||
width={80}
|
||||
value={setting.model as string}
|
||||
options={Object.values(AIModel)}
|
||||
onChange={val => onModelChange(val as AIModel)}
|
||||
|
@ -1,12 +1,15 @@
|
||||
import React from 'react'
|
||||
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { isSDState, settingState } from '../../store/Atoms'
|
||||
import {
|
||||
isPaintByExampleState,
|
||||
isSDState,
|
||||
settingState,
|
||||
} from '../../store/Atoms'
|
||||
import Modal from '../shared/Modal'
|
||||
import ManualRunInpaintingSettingBlock from './ManualRunInpaintingSettingBlock'
|
||||
import HDSettingBlock from './HDSettingBlock'
|
||||
import ModelSettingBlock from './ModelSettingBlock'
|
||||
import GraduallyInpaintingSettingBlock from './GraduallyInpaintingSettingBlock'
|
||||
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
|
||||
import useHotKey from '../../hooks/useHotkey'
|
||||
|
||||
@ -17,6 +20,7 @@ export default function SettingModal(props: SettingModalProps) {
|
||||
const { onClose } = props
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
const isSD = useRecoilValue(isSDState)
|
||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||
|
||||
const handleOnClose = () => {
|
||||
setSettingState(old => {
|
||||
@ -43,7 +47,7 @@ export default function SettingModal(props: SettingModalProps) {
|
||||
className="modal-setting"
|
||||
show={setting.show}
|
||||
>
|
||||
{isSD ? <></> : <ManualRunInpaintingSettingBlock />}
|
||||
{isSD || isPaintByExample ? <></> : <ManualRunInpaintingSettingBlock />}
|
||||
|
||||
{/* <GraduallyInpaintingSettingBlock /> */}
|
||||
<DownloadMaskSettingBlock />
|
||||
|
231
lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
Normal file
231
lama_cleaner/app/src/components/SidePanel/PESidePanel.tsx
Normal file
@ -0,0 +1,231 @@
|
||||
import React, { useState } from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||
import { useToggle } from 'react-use'
|
||||
import { UploadIcon } from '@radix-ui/react-icons'
|
||||
import {
|
||||
isInpaintingState,
|
||||
paintByExampleImageState,
|
||||
settingState,
|
||||
} from '../../store/Atoms'
|
||||
import NumberInputSetting from '../Settings/NumberInputSetting'
|
||||
import SettingBlock from '../Settings/SettingBlock'
|
||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||
import Button from '../shared/Button'
|
||||
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
|
||||
import { useImage } from '../../utils'
|
||||
|
||||
const INPUT_WIDTH = 30
|
||||
|
||||
const PESidePanel = () => {
|
||||
const [open, toggleOpen] = useToggle(true)
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
const [paintByExampleImage, setPaintByExampleImage] = useRecoilState(
|
||||
paintByExampleImageState
|
||||
)
|
||||
const [uploadElemId] = useState(
|
||||
`example-file-upload-${Math.random().toString()}`
|
||||
)
|
||||
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleImage)
|
||||
const isInpainting = useRecoilValue(isInpaintingState)
|
||||
|
||||
const renderUploadIcon = () => {
|
||||
return (
|
||||
<label htmlFor={uploadElemId}>
|
||||
<Button
|
||||
border
|
||||
toolTip="Upload example image"
|
||||
tooltipPosition="top"
|
||||
icon={<UploadIcon />}
|
||||
style={{ padding: '0.3rem', gap: 0 }}
|
||||
>
|
||||
<input
|
||||
style={{ display: 'none' }}
|
||||
id={uploadElemId}
|
||||
name={uploadElemId}
|
||||
type="file"
|
||||
onChange={ev => {
|
||||
const newFile = ev.currentTarget.files?.[0]
|
||||
if (newFile) {
|
||||
setPaintByExampleImage(newFile)
|
||||
}
|
||||
}}
|
||||
accept="image/png, image/jpeg"
|
||||
/>
|
||||
</Button>
|
||||
</label>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="side-panel">
|
||||
<PopoverPrimitive.Root open={open}>
|
||||
<PopoverPrimitive.Trigger
|
||||
className="btn-primary side-panel-trigger"
|
||||
onClick={() => toggleOpen()}
|
||||
>
|
||||
Configurations
|
||||
</PopoverPrimitive.Trigger>
|
||||
<PopoverPrimitive.Portal>
|
||||
<PopoverPrimitive.Content className="side-panel-content">
|
||||
<SettingBlock
|
||||
title="Croper"
|
||||
input={
|
||||
<Switch
|
||||
checked={setting.showCroper}
|
||||
onCheckedChange={value => {
|
||||
setSettingState(old => {
|
||||
return { ...old, showCroper: value }
|
||||
})
|
||||
}}
|
||||
>
|
||||
<SwitchThumb />
|
||||
</Switch>
|
||||
}
|
||||
/>
|
||||
|
||||
<NumberInputSetting
|
||||
title="Steps"
|
||||
width={INPUT_WIDTH}
|
||||
value={`${setting.paintByExampleSteps}`}
|
||||
desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleSteps: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
|
||||
<NumberInputSetting
|
||||
title="Guidance Scale"
|
||||
width={INPUT_WIDTH}
|
||||
allowFloat
|
||||
value={`${setting.paintByExampleGuidanceScale}`}
|
||||
desc="Higher guidance scale encourages to generate images that are close to the example image"
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleGuidanceScale: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
|
||||
<NumberInputSetting
|
||||
title="Mask Blur"
|
||||
width={INPUT_WIDTH}
|
||||
value={`${setting.paintByExampleMaskBlur}`}
|
||||
desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleMaskBlur: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
|
||||
<SettingBlock
|
||||
title="Match Histograms"
|
||||
desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
|
||||
input={
|
||||
<Switch
|
||||
checked={setting.paintByExampleMatchHistograms}
|
||||
onCheckedChange={value => {
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleMatchHistograms: value }
|
||||
})
|
||||
}}
|
||||
>
|
||||
<SwitchThumb />
|
||||
</Switch>
|
||||
}
|
||||
/>
|
||||
|
||||
<SettingBlock
|
||||
title="Seed"
|
||||
input={
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
gap: 0,
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
{/* 每次会从服务器返回更新该值 */}
|
||||
<NumberInputSetting
|
||||
title=""
|
||||
width={80}
|
||||
value={`${setting.paintByExampleSeed}`}
|
||||
desc=""
|
||||
disable={!setting.paintByExampleSeedFixed}
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleSeed: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
<Switch
|
||||
checked={setting.paintByExampleSeedFixed}
|
||||
onCheckedChange={value => {
|
||||
setSettingState(old => {
|
||||
return { ...old, paintByExampleSeedFixed: value }
|
||||
})
|
||||
}}
|
||||
style={{ marginLeft: '8px' }}
|
||||
>
|
||||
<SwitchThumb />
|
||||
</Switch>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
|
||||
<div style={{ display: 'flex', flexDirection: 'column' }}>
|
||||
<SettingBlock title="Example Image" input={renderUploadIcon()} />
|
||||
|
||||
{paintByExampleImage ? (
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<img
|
||||
src={exampleImage.src}
|
||||
alt="example"
|
||||
style={{
|
||||
maxWidth: 200,
|
||||
maxHeight: 200,
|
||||
margin: 12,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Button
|
||||
border
|
||||
disabled={!isExampleImageLoaded || isInpainting}
|
||||
style={{ width: '100%' }}
|
||||
onClick={() => {
|
||||
if (isExampleImageLoaded) {
|
||||
emitter.emit(EVENT_PAINT_BY_EXAMPLE, {
|
||||
image: paintByExampleImage,
|
||||
})
|
||||
}
|
||||
}}
|
||||
>
|
||||
Paint
|
||||
</Button>
|
||||
</PopoverPrimitive.Content>
|
||||
</PopoverPrimitive.Portal>
|
||||
</PopoverPrimitive.Root>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default PESidePanel
|
@ -7,6 +7,7 @@ import Toast from './shared/Toast'
|
||||
import {
|
||||
AIModel,
|
||||
fileState,
|
||||
isPaintByExampleState,
|
||||
isSDState,
|
||||
settingState,
|
||||
toastState,
|
||||
@ -17,12 +18,14 @@ import {
|
||||
switchModel,
|
||||
} from '../adapters/inpainting'
|
||||
import SidePanel from './SidePanel/SidePanel'
|
||||
import PESidePanel from './SidePanel/PESidePanel'
|
||||
|
||||
const Workspace = () => {
|
||||
const [file, setFile] = useRecoilState(fileState)
|
||||
const [settings, setSettingState] = useRecoilState(settingState)
|
||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||
const isSD = useRecoilValue(isSDState)
|
||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||
|
||||
const onSettingClose = async () => {
|
||||
const curModel = await currentModel().then(res => res.text())
|
||||
@ -88,6 +91,7 @@ const Workspace = () => {
|
||||
return (
|
||||
<>
|
||||
{isSD ? <SidePanel /> : <></>}
|
||||
{isPaintByExample ? <PESidePanel /> : <></>}
|
||||
<Editor />
|
||||
<SettingModal onClose={onSettingClose} />
|
||||
<ShortcutsModal />
|
||||
|
@ -1,11 +1,17 @@
|
||||
import mitt from 'mitt'
|
||||
|
||||
export const EVENT_PROMPT = 'prompt'
|
||||
|
||||
export const EVENT_CUSTOM_MASK = 'custom_mask'
|
||||
export interface CustomMaskEventData {
|
||||
mask: File
|
||||
}
|
||||
|
||||
export const EVENT_PAINT_BY_EXAMPLE = 'paint_by_example'
|
||||
export interface PaintByExampleEventData {
|
||||
image: File
|
||||
}
|
||||
|
||||
const emitter = mitt()
|
||||
|
||||
export default emitter
|
||||
|
@ -13,6 +13,7 @@ export enum AIModel {
|
||||
SD2 = 'sd2',
|
||||
CV2 = 'cv2',
|
||||
Mange = 'manga',
|
||||
PAINT_BY_EXAMPLE = 'paint_by_example',
|
||||
}
|
||||
|
||||
export const maskState = atom<File | undefined>({
|
||||
@ -20,6 +21,11 @@ export const maskState = atom<File | undefined>({
|
||||
default: undefined,
|
||||
})
|
||||
|
||||
export const paintByExampleImageState = atom<File | undefined>({
|
||||
key: 'paintByExampleImageState',
|
||||
default: undefined,
|
||||
})
|
||||
|
||||
export interface Rect {
|
||||
x: number
|
||||
y: number
|
||||
@ -252,6 +258,14 @@ export interface Settings {
|
||||
// For OpenCV2
|
||||
cv2Radius: number
|
||||
cv2Flag: CV2Flag
|
||||
|
||||
// Paint by Example
|
||||
paintByExampleSteps: number
|
||||
paintByExampleGuidanceScale: number
|
||||
paintByExampleSeed: number
|
||||
paintByExampleSeedFixed: boolean
|
||||
paintByExampleMaskBlur: number
|
||||
paintByExampleMatchHistograms: boolean
|
||||
}
|
||||
|
||||
const defaultHDSettings: ModelsHDSettings = {
|
||||
@ -304,6 +318,13 @@ const defaultHDSettings: ModelsHDSettings = {
|
||||
hdStrategyCropMargin: 128,
|
||||
enabled: false,
|
||||
},
|
||||
[AIModel.PAINT_BY_EXAMPLE]: {
|
||||
hdStrategy: HDStrategy.ORIGINAL,
|
||||
hdStrategyResizeLimit: 768,
|
||||
hdStrategyCropTrigerSize: 512,
|
||||
hdStrategyCropMargin: 128,
|
||||
enabled: false,
|
||||
},
|
||||
[AIModel.Mange]: {
|
||||
hdStrategy: HDStrategy.CROP,
|
||||
hdStrategyResizeLimit: 1280,
|
||||
@ -364,6 +385,14 @@ export const settingStateDefault: Settings = {
|
||||
// CV2
|
||||
cv2Radius: 5,
|
||||
cv2Flag: CV2Flag.INPAINT_NS,
|
||||
|
||||
// Paint by Example
|
||||
paintByExampleSteps: 50,
|
||||
paintByExampleGuidanceScale: 7.5,
|
||||
paintByExampleSeed: 42,
|
||||
paintByExampleMaskBlur: 5,
|
||||
paintByExampleSeedFixed: false,
|
||||
paintByExampleMatchHistograms: false,
|
||||
}
|
||||
|
||||
const localStorageEffect =
|
||||
@ -401,11 +430,28 @@ export const seedState = selector({
|
||||
key: 'seed',
|
||||
get: ({ get }) => {
|
||||
const settings = get(settingState)
|
||||
return settings.sdSeed
|
||||
switch (settings.model) {
|
||||
case AIModel.PAINT_BY_EXAMPLE:
|
||||
return settings.paintByExampleSeedFixed
|
||||
? settings.paintByExampleSeed
|
||||
: -1
|
||||
default:
|
||||
return settings.sdSeedFixed ? settings.sdSeed : -1
|
||||
}
|
||||
},
|
||||
set: ({ get, set }, newValue: any) => {
|
||||
const settings = get(settingState)
|
||||
switch (settings.model) {
|
||||
case AIModel.PAINT_BY_EXAMPLE:
|
||||
if (!settings.paintByExampleSeedFixed) {
|
||||
set(settingState, { ...settings, paintByExampleSeed: newValue })
|
||||
}
|
||||
break
|
||||
default:
|
||||
if (!settings.sdSeedFixed) {
|
||||
set(settingState, { ...settings, sdSeed: newValue })
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
@ -435,11 +481,20 @@ export const isSDState = selector({
|
||||
},
|
||||
})
|
||||
|
||||
export const isPaintByExampleState = selector({
|
||||
key: 'isPaintByExampleState',
|
||||
get: ({ get }) => {
|
||||
const settings = get(settingState)
|
||||
return settings.model === AIModel.PAINT_BY_EXAMPLE
|
||||
},
|
||||
})
|
||||
|
||||
export const runManuallyState = selector({
|
||||
key: 'runManuallyState',
|
||||
get: ({ get }) => {
|
||||
const settings = get(settingState)
|
||||
const isSD = get(isSDState)
|
||||
return settings.runInpaintingManually || isSD
|
||||
const isPaintByExample = get(isPaintByExampleState)
|
||||
return settings.runInpaintingManually || isSD || isPaintByExample
|
||||
},
|
||||
})
|
||||
|
@ -211,6 +211,26 @@ class InpaintModel:
|
||||
|
||||
return result
|
||||
|
||||
def _apply_cropper(self, image, mask, config: Config):
|
||||
img_h, img_w = image.shape[:2]
|
||||
l, t, w, h = (
|
||||
config.croper_x,
|
||||
config.croper_y,
|
||||
config.croper_width,
|
||||
config.croper_height,
|
||||
)
|
||||
r = l + w
|
||||
b = t + h
|
||||
|
||||
l = max(l, 0)
|
||||
r = min(r, img_w)
|
||||
t = max(t, 0)
|
||||
b = min(b, img_h)
|
||||
|
||||
crop_img = image[t:b, l:r, :]
|
||||
crop_mask = mask[t:b, l:r]
|
||||
return crop_img, crop_mask, (l, t, r, b)
|
||||
|
||||
def _run_box(self, image, mask, box, config: Config):
|
||||
"""
|
||||
|
||||
|
80
lama_cleaner/model/paint_by_example.py
Normal file
80
lama_cleaner/model/paint_by_example.py
Normal file
@ -0,0 +1,80 @@
|
||||
import random
|
||||
|
||||
import PIL
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class PaintByExample(InpaintModel):
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu else torch.float32
|
||||
self.model = DiffusionPipeline.from_pretrained(
|
||||
"Fantasy-Studio/Paint-by-Example",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
self.model.enable_attention_slicing()
|
||||
self.model = self.model.to(device)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
seed = config.paint_by_example_seed
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
example_image=config.paint_by_example_example_image,
|
||||
num_inference_steps=config.paint_by_example_steps,
|
||||
output_type='np.array',
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
images: [H, W, C] RGB, not normalized
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
if config.use_croper:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
else:
|
||||
inpaint_result = self._pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
if config.paint_by_example_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
if config.paint_by_example_mask_blur != 0:
|
||||
k = 2 * config.paint_by_example_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
return result, image, mask
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
@ -12,31 +12,6 @@ from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config, SDSampler
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
# def preprocess_image(image):
|
||||
# w, h = image.size
|
||||
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
# image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
# image = np.array(image).astype(np.float32) / 255.0
|
||||
# image = image[None].transpose(0, 3, 1, 2)
|
||||
# image = torch.from_numpy(image)
|
||||
# # [-1, 1]
|
||||
# return 2.0 * image - 1.0
|
||||
#
|
||||
#
|
||||
# def preprocess_mask(mask):
|
||||
# mask = mask.convert("L")
|
||||
# w, h = mask.size
|
||||
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
|
||||
# mask = np.array(mask).astype(np.float32) / 255.0
|
||||
# mask = np.tile(mask, (4, 1, 1))
|
||||
# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
# mask = 1 - mask # repaint white, keep black
|
||||
# mask = torch.from_numpy(mask)
|
||||
# return mask
|
||||
|
||||
class CPUTextEncoderWrapper:
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
self.config = text_encoder.config
|
||||
@ -75,7 +50,7 @@ class SD(InpaintModel):
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs['sd_enable_xformers']:
|
||||
if kwargs.get('sd_enable_xformers', False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
self.model = self.model.to(device)
|
||||
|
||||
@ -92,17 +67,6 @@ class SD(InpaintModel):
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
|
||||
# image = norm_img(image) # [0, 1]
|
||||
# image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
||||
|
||||
# resize to latent feature map size
|
||||
# h, w = mask.shape[:2]
|
||||
# mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
|
||||
# mask = norm_img(mask)
|
||||
#
|
||||
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
|
||||
scheduler_config = self.model.scheduler.config
|
||||
|
||||
if config.sd_sampler == SDSampler.ddim:
|
||||
@ -139,7 +103,6 @@ class SD(InpaintModel):
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
strength=config.sd_strength,
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np.array",
|
||||
@ -159,30 +122,10 @@ class SD(InpaintModel):
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
# boxes = boxes_from_mask(mask)
|
||||
if config.use_croper:
|
||||
logger.info("use croper")
|
||||
l, t, w, h = (
|
||||
config.croper_x,
|
||||
config.croper_y,
|
||||
config.croper_width,
|
||||
config.croper_height,
|
||||
)
|
||||
r = l + w
|
||||
b = t + h
|
||||
|
||||
l = max(l, 0)
|
||||
r = min(r, img_w)
|
||||
t = max(t, 0)
|
||||
b = min(b, img_h)
|
||||
|
||||
crop_img = image[t:b, l:r, :]
|
||||
crop_mask = mask[t:b, l:r]
|
||||
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
else:
|
||||
|
@ -5,13 +5,14 @@ from lama_cleaner.model.lama import LaMa
|
||||
from lama_cleaner.model.ldm import LDM
|
||||
from lama_cleaner.model.manga import Manga
|
||||
from lama_cleaner.model.mat import MAT
|
||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||
from lama_cleaner.model.sd import SD15, SD2
|
||||
from lama_cleaner.model.zits import ZITS
|
||||
from lama_cleaner.model.opencv2 import OpenCV2
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
||||
"sd2": SD2}
|
||||
"sd2": SD2, "paint_by_example": PaintByExample}
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
@ -10,7 +10,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="lama",
|
||||
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2"],
|
||||
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_access_token",
|
||||
|
@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -29,6 +30,9 @@ class SDSampler(str, Enum):
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# Configs for ldm model
|
||||
ldm_steps: int
|
||||
ldm_sampler: str = LDMSampler.plms
|
||||
@ -73,3 +77,11 @@ class Config(BaseModel):
|
||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
||||
cv2_flag: str = 'INPAINT_NS'
|
||||
cv2_radius: int = 4
|
||||
|
||||
# Paint by Example
|
||||
paint_by_example_steps: int = 50
|
||||
paint_by_example_guidance_scale: float = 7.5
|
||||
paint_by_example_mask_blur: int = 0
|
||||
paint_by_example_seed: int = 42
|
||||
paint_by_example_match_histograms: bool = False
|
||||
paint_by_example_example_image: Image = None
|
||||
|
@ -10,6 +10,7 @@ import time
|
||||
import imghdr
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
@ -97,8 +98,8 @@ def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
|
||||
@ -115,6 +116,12 @@ def process():
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
if "paintByExampleImage" in input:
|
||||
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
|
||||
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
||||
else:
|
||||
paint_by_example_example_image = None
|
||||
|
||||
config = Config(
|
||||
ldm_steps=form["ldmSteps"],
|
||||
ldm_sampler=form["ldmSampler"],
|
||||
@ -138,11 +145,19 @@ def process():
|
||||
sd_seed=form["sdSeed"],
|
||||
sd_match_histograms=form["sdMatchHistograms"],
|
||||
cv2_flag=form["cv2Flag"],
|
||||
cv2_radius=form['cv2Radius']
|
||||
cv2_radius=form['cv2Radius'],
|
||||
paint_by_example_steps=form["paintByExampleSteps"],
|
||||
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
|
||||
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
|
||||
paint_by_example_seed=form["paintByExampleSeed"],
|
||||
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
|
||||
paint_by_example_example_image=paint_by_example_example_image,
|
||||
)
|
||||
|
||||
if config.sd_seed == -1:
|
||||
config.sd_seed = random.randint(1, 999999999)
|
||||
if config.paint_by_example_seed == -1:
|
||||
config.paint_by_example_seed = random.randint(1, 999999999)
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
|
BIN
lama_cleaner/tests/bunny.jpeg
Normal file
BIN
lama_cleaner/tests/bunny.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 51 KiB |
50
lama_cleaner/tests/test_paint_by_example.py
Normal file
50
lama_cleaner/tests/test_paint_by_example.py
Normal file
@ -0,0 +1,50 @@
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import HDStrategy
|
||||
from lama_cleaner.tests.test_model import get_config, get_data
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / 'result'
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
device = torch.device(device)
|
||||
|
||||
|
||||
def assert_equal(
|
||||
model, config, gt_name,
|
||||
fx: float = 1, fy: float = 1,
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
example_p=current_dir / "bunny.jpeg",
|
||||
):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
|
||||
example_image = cv2.imread(str(example_p))
|
||||
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
|
||||
example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||
|
||||
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
|
||||
config.paint_by_example_example_image = Image.fromarray(example_image)
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(str(save_dir / gt_name), res)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30 if device == 'cuda' else 1)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"paint_by_example_{strategy.capitalize()}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3
|
||||
)
|
@ -1,12 +1,10 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
||||
from lama_cleaner.schema import HDStrategy, SDSampler
|
||||
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
@ -96,7 +94,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 50
|
||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
||||
model = ModelManager(name="sd1.5",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
|
@ -10,5 +10,5 @@ pytest
|
||||
yacs
|
||||
markupsafe==2.0.1
|
||||
scikit-image==0.19.3
|
||||
diffusers[torch]==0.9
|
||||
transformers==4.21.0
|
||||
diffusers[torch]==0.10.2
|
||||
transformers>=4.25.1
|
||||
|
10
setup.py
10
setup.py
@ -31,11 +31,15 @@ setuptools.setup(
|
||||
packages=setuptools.find_packages("./"),
|
||||
package_data={"lama_cleaner": web_files},
|
||||
install_requires=load_requirements(),
|
||||
python_requires=">=3.6",
|
||||
python_requires=">=3.7",
|
||||
entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user