Merge branch 'make_gif_share'

This commit is contained in:
Qing 2023-02-07 21:47:11 +08:00
commit f837e4be8a
49 changed files with 15954 additions and 12250 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ build
!lama_cleaner/app/build !lama_cleaner/app/build
dist/ dist/
lama_cleaner.egg-info/ lama_cleaner.egg-info/
venv/

View File

@ -59,6 +59,6 @@ Only needed if you plan to modify the frontend and recompile yourself.
Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures), You can experience their Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures), You can experience their
great online services [here](https://cleanup.pictures/). great online services [here](https://cleanup.pictures/).
- Install dependencies:`cd lama_cleaner/app/ && yarn` - Install dependencies:`cd lama_cleaner/app/ && pnpm install`
- Start development server: `yarn start` - Start development server: `pnpm start`
- Build: `yarn build` - Build: `pnpm build`

View File

@ -1,7 +1,8 @@
{ {
"files": { "files": {
"main.css": "/static/css/main.c28d98ca.css", "main.css": "/static/css/main.e24c9a9b.css",
"main.js": "/static/js/main.1bd455bc.js", "main.js": "/static/js/main.23732b19.js",
"static/media/coffee-machine-lineal.gif": "/static/media/coffee-machine-lineal.ee32631219cc3986f861.gif",
"static/media/WorkSans-SemiBold.ttf": "/static/media/WorkSans-SemiBold.1e98db4eb705b586728e.ttf", "static/media/WorkSans-SemiBold.ttf": "/static/media/WorkSans-SemiBold.1e98db4eb705b586728e.ttf",
"static/media/WorkSans-Bold.ttf": "/static/media/WorkSans-Bold.2bea7a7f7d052c74da25.ttf", "static/media/WorkSans-Bold.ttf": "/static/media/WorkSans-Bold.2bea7a7f7d052c74da25.ttf",
"static/media/WorkSans-Regular.ttf": "/static/media/WorkSans-Regular.bb287b894b27372d8ea7.ttf", "static/media/WorkSans-Regular.ttf": "/static/media/WorkSans-Regular.bb287b894b27372d8ea7.ttf",
@ -9,7 +10,7 @@
"index.html": "/index.html" "index.html": "/index.html"
}, },
"entrypoints": [ "entrypoints": [
"static/css/main.c28d98ca.css", "static/css/main.e24c9a9b.css",
"static/js/main.1bd455bc.js" "static/js/main.23732b19.js"
] ]
} }

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.1bd455bc.js"></script><link href="/static/css/main.c28d98ca.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html> <!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.23732b19.js"></script><link href="/static/css/main.e24c9a9b.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 KiB

View File

@ -2,7 +2,7 @@
"name": "lama-cleaner", "name": "lama-cleaner",
"version": "0.1.0", "version": "0.1.0",
"private": true, "private": true,
"proxy": "http://localhost:8080", "proxy": "http://127.0.0.1:8080",
"dependencies": { "dependencies": {
"@babel/core": "^7.16.0", "@babel/core": "^7.16.0",
"@heroicons/react": "^2.0.0", "@heroicons/react": "^2.0.0",

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
import { Rect, Settings } from '../store/Atoms' import { Rect, Settings } from '../store/Atoms'
import { dataURItoBlob, srcToFile } from '../utils' import { dataURItoBlob, loadImage, srcToFile } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -82,6 +82,11 @@ export default async function inpaint(
fd.append('paintByExampleImage', paintByExampleImage) fd.append('paintByExampleImage', paintByExampleImage)
} }
// InstructPix2Pix
fd.append('p2pSteps', settings.p2pSteps.toString())
fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString())
fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString())
if (sizeLimit === undefined) { if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080') fd.append('sizeLimit', '1080')
} else { } else {
@ -223,3 +228,34 @@ export async function downloadToOutput(
throw new Error(`Something went wrong: ${error}`) throw new Error(`Something went wrong: ${error}`)
} }
} }
/**
 * Ask the backend to build a GIF animating between the original image and
 * the cleaned result, and return it as a fully loaded HTMLImageElement.
 *
 * @param originFile - the original image file uploaded by the user
 * @param cleanImage - the inpainted result currently displayed
 * @param filename - name forwarded to the backend for the generated file
 * @param mimeType - mime type used when converting cleanImage back to a File
 * @throws Error wrapping any network/server failure, prefixed with
 *   "Something went wrong:" (same convention as downloadToOutput above)
 */
export async function makeGif(
  originFile: File,
  cleanImage: HTMLImageElement,
  filename: string,
  mimeType: string
) {
  // Convert the on-screen result back into a File so it can be posted.
  const resultFile = await srcToFile(cleanImage.src, filename, mimeType)

  const form = new FormData()
  form.append('origin_img', originFile)
  form.append('clean_img', resultFile)
  form.append('filename', filename)

  try {
    const response = await fetch(`${API_ENDPOINT}/make_gif`, {
      method: 'POST',
      body: form,
    })
    if (!response.ok) {
      throw new Error(await response.text())
    }

    // Decode the returned GIF into an <img> backed by an object URL.
    // NOTE(review): the object URL is intentionally not revoked here —
    // callers keep using img.src for display and download afterwards.
    const gifBlob = await response.blob()
    const gifImage = new Image()
    await loadImage(gifImage, URL.createObjectURL(gifBlob))
    return gifImage
  } catch (error) {
    throw new Error(`Something went wrong: ${error}`)
  }
}

View File

@ -2,6 +2,7 @@ import React, { useState } from 'react'
import { Coffee } from 'react-feather' import { Coffee } from 'react-feather'
import Button from '../shared/Button' import Button from '../shared/Button'
import Modal from '../shared/Modal' import Modal from '../shared/Modal'
import CoffeeMachineGif from '../../media/coffee-machine-lineal.gif'
const CoffeeIcon = () => { const CoffeeIcon = () => {
const [show, setShow] = useState(false) const [show, setShow] = useState(false)
@ -24,10 +25,26 @@ const CoffeeIcon = () => {
show={show} show={show}
showCloseIcon={false} showCloseIcon={false}
> >
<h4 style={{ lineHeight: '24px' }}> <div
Hi there, If you found my project is useful, and want to help keep it style={{
alive please consider donating! Thank you for your support! display: 'flex',
</h4> flexDirection: 'column',
}}
>
<h4 style={{ lineHeight: '24px' }}>
Hi, if you find my project useful, please consider buying me a
coffee to support my work. Thanks!
</h4>
<img
src={CoffeeMachineGif}
alt="coffee machine"
style={{
height: 150,
objectFit: 'contain',
}}
/>
</div>
<div <div
style={{ style={{
display: 'flex', display: 'flex',
@ -53,7 +70,6 @@ const CoffeeIcon = () => {
}} }}
> >
Sure Sure
<Coffee />
</div> </div>
</Button> </Button>
</a> </a>

View File

@ -55,10 +55,9 @@
position: fixed; position: fixed;
bottom: 0.5rem; bottom: 0.5rem;
border-radius: 3rem; border-radius: 3rem;
padding: 0.6rem 3rem; padding: 0.6rem 32px;
display: grid; display: flex;
grid-template-areas: 'toolkit-size-selector toolkit-brush-slider toolkit-btns'; gap: 16px;
column-gap: 2rem;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
backdrop-filter: blur(12px); backdrop-filter: blur(12px);
@ -96,10 +95,8 @@
} }
.editor-toolkit-btns { .editor-toolkit-btns {
grid-area: toolkit-btns; display: flex;
display: grid; gap: 12px;
grid-auto-flow: column;
column-gap: 1rem;
} }
.brush-shape { .brush-shape {

View File

@ -16,10 +16,11 @@ import {
TransformComponent, TransformComponent,
TransformWrapper, TransformWrapper,
} from 'react-zoom-pan-pinch' } from 'react-zoom-pan-pinch'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use' import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint, { import inpaint, {
downloadToOutput, downloadToOutput,
makeGif,
postInteractiveSeg, postInteractiveSeg,
} from '../../adapters/inpainting' } from '../../adapters/inpainting'
import Button from '../shared/Button' import Button from '../shared/Button'
@ -43,10 +44,12 @@ import {
imageHeightState, imageHeightState,
imageWidthState, imageWidthState,
interactiveSegClicksState, interactiveSegClicksState,
isDiffusionModelsState,
isInpaintingState, isInpaintingState,
isInteractiveSegRunningState, isInteractiveSegRunningState,
isInteractiveSegState, isInteractiveSegState,
isPaintByExampleState, isPaintByExampleState,
isPix2PixState,
isSDState, isSDState,
negativePropmtState, negativePropmtState,
propmtState, propmtState,
@ -66,6 +69,7 @@ import FileSelect from '../FileSelect/FileSelect'
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg' import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions' import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal' import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal'
import MakeGIF from './MakeGIF'
const TOOLBAR_SIZE = 200 const TOOLBAR_SIZE = 200
const MIN_BRUSH_SIZE = 10 const MIN_BRUSH_SIZE = 10
@ -112,11 +116,11 @@ export default function Editor() {
const settings = useRecoilValue(settingState) const settings = useRecoilValue(settingState)
const [seedVal, setSeed] = useRecoilState(seedState) const [seedVal, setSeed] = useRecoilState(seedState)
const croperRect = useRecoilValue(croperState) const croperRect = useRecoilValue(croperState)
const [toastVal, setToastState] = useRecoilState(toastState) const setToastState = useSetRecoilState(toastState)
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const runMannually = useRecoilValue(runManuallyState) const runMannually = useRecoilValue(runManuallyState)
const isSD = useRecoilValue(isSDState) const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const isPaintByExample = useRecoilValue(isPaintByExampleState) const isPix2Pix = useRecoilValue(isPix2PixState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState( const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState isInteractiveSegState
) )
@ -181,8 +185,8 @@ export default function Editor() {
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([]) const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
const enableFileManager = useRecoilValue(enableFileManagerState) const enableFileManager = useRecoilValue(enableFileManagerState)
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState) const setImageWidth = useSetRecoilState(imageWidthState)
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState) const setImageHeight = useSetRecoilState(imageHeightState)
const app = useRecoilValue(appState) const app = useRecoilValue(appState)
const draw = useCallback( const draw = useCallback(
@ -253,13 +257,40 @@ export default function Editor() {
_lineGroups.forEach(lineGroup => { _lineGroups.forEach(lineGroup => {
drawLines(ctx, lineGroup, 'white') drawLines(ctx, lineGroup, 'white')
}) })
if (
(maskImage === undefined || maskImage === null) &&
_lineGroups.length === 1 &&
_lineGroups[0].length === 0 &&
isPix2Pix
) {
// For InstructPix2Pix without mask
drawLines(
ctx,
[
{
size: 9999999999,
pts: [
{ x: 0, y: 0 },
{ x: original.naturalWidth, y: 0 },
{ x: original.naturalWidth, y: original.naturalHeight },
{ x: 0, y: original.naturalHeight },
],
},
],
'white'
)
}
}, },
[context, maskCanvas] [context, maskCanvas, isPix2Pix]
) )
const hadDrawSomething = useCallback(() => { const hadDrawSomething = useCallback(() => {
if (isPix2Pix) {
return true
}
return curLineGroup.length !== 0 return curLineGroup.length !== 0
}, [curLineGroup]) }, [curLineGroup, isPix2Pix])
const drawOnCurrentRender = useCallback( const drawOnCurrentRender = useCallback(
(lineGroup: LineGroup) => { (lineGroup: LineGroup) => {
@ -424,6 +455,8 @@ export default function Editor() {
} else if (prevInteractiveSegMask) { } else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成 // 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask) runInpainting(false, undefined, prevInteractiveSegMask)
} else if (isPix2Pix) {
runInpainting(false, undefined, null)
} else { } else {
setToastState({ setToastState({
open: true, open: true,
@ -839,7 +872,7 @@ export default function Editor() {
} }
if ( if (
(isSD || isPaintByExample) && isDiffusionModels &&
settings.showCroper && settings.showCroper &&
isOutsideCroper(mouseXY(ev)) isOutsideCroper(mouseXY(ev))
) { ) {
@ -1385,7 +1418,7 @@ export default function Editor() {
minHeight={Math.min(256, original.naturalHeight)} minHeight={Math.min(256, original.naturalHeight)}
minWidth={Math.min(256, original.naturalWidth)} minWidth={Math.min(256, original.naturalWidth)}
scale={scale} scale={scale}
show={(isSD || isPaintByExample) && settings.showCroper} show={isDiffusionModels && settings.showCroper}
/> />
{isInteractiveSeg ? <InteractiveSeg /> : <></>} {isInteractiveSeg ? <InteractiveSeg /> : <></>}
@ -1439,7 +1472,7 @@ export default function Editor() {
)} )}
<div className="editor-toolkit-panel"> <div className="editor-toolkit-panel">
{isSD || isPaintByExample || file === undefined ? ( {isDiffusionModels || file === undefined ? (
<></> <></>
) : ( ) : (
<SizeSelector <SizeSelector
@ -1534,6 +1567,7 @@ export default function Editor() {
}} }}
disabled={renders.length === 0} disabled={renders.length === 0}
/> />
<MakeGIF renders={renders} />
<Button <Button
toolTip="Save Image" toolTip="Save Image"
icon={<ArrowDownTrayIcon />} icon={<ArrowDownTrayIcon />}
@ -1541,7 +1575,7 @@ export default function Editor() {
onClick={download} onClick={download}
/> />
{settings.runInpaintingManually && !isSD && !isPaintByExample && ( {settings.runInpaintingManually && !isDiffusionModels && (
<Button <Button
toolTip="Run Inpainting" toolTip="Run Inpainting"
icon={ icon={

View File

@ -0,0 +1,114 @@
import React, { useState } from 'react'
import { GifIcon } from '@heroicons/react/24/outline'
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import Button from '../shared/Button'
import { fileState, gifImageState, toastState } from '../../store/Atoms'
import { makeGif } from '../../adapters/inpainting'
import Modal from '../shared/Modal'
import { LoadingIcon } from '../shared/Toast'
import { downloadImage } from '../../utils'
interface Props {
  // Inpainting renders, oldest first; the latest one is used for the GIF.
  renders: HTMLImageElement[]
}

/**
 * Toolbar button plus modal: asks the backend to build a before/after GIF
 * from the original file and the latest inpainting render, shows a spinner
 * while it is being generated, and offers a download once ready.
 */
const MakeGIF = (props: Props) => {
  const { renders } = props
  const [gifImg, setGifImg] = useRecoilState(gifImageState)
  const file = useRecoilValue(fileState)
  const setToastState = useSetRecoilState(toastState)
  const [show, setShow] = useState(false)

  const handleOnClose = () => {
    setShow(false)
  }

  const handleDownload = () => {
    if (gifImg) {
      // Reuse the source file's name, swapping its extension for .gif
      const name = file.name.replace(/\.[^/.]+$/, '.gif')
      downloadImage(gifImg.src, name)
    }
  }

  return (
    <div>
      <Button
        toolTip="Make Gif"
        icon={<GifIcon />}
        disabled={!renders.length}
        onClick={async () => {
          // Open the modal immediately so the spinner shows while waiting.
          setShow(true)
          setGifImg(null)
          try {
            const gif = await makeGif(
              file,
              renders[renders.length - 1],
              file.name,
              file.type
            )
            if (gif) {
              setGifImg(gif)
            }
          } catch (e) {
            // Close the modal so the user is not left on an endless
            // "Generating GIF..." spinner, then surface the failure.
            setShow(false)
            setToastState({
              open: true,
              desc: e instanceof Error ? e.message : String(e),
              state: 'error',
              duration: 2000,
            })
          }
        }}
      />
      <Modal
        onClose={handleOnClose}
        title="GIF"
        className="modal-setting"
        show={show}
      >
        <div
          style={{
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            flexDirection: 'column',
            gap: 16,
          }}
        >
          {/* Result once ready, loading indicator otherwise */}
          {gifImg ? (
            <img src={gifImg.src} style={{ borderRadius: 8 }} alt="gif" />
          ) : (
            <div
              style={{
                display: 'flex',
                alignItems: 'center',
                justifyContent: 'center',
                gap: 8,
              }}
            >
              <LoadingIcon />
              Generating GIF...
            </div>
          )}
          {gifImg && (
            <div
              style={{
                display: 'flex',
                width: '100%',
                justifyContent: 'flex-end',
                alignItems: 'center',
                gap: '12px',
              }}
            >
              <Button onClick={handleDownload} border>
                Download
              </Button>
            </div>
          )}
        </div>
      </Modal>
    </div>
  )
}

export default MakeGIF

View File

@ -7,6 +7,7 @@ import {
enableFileManagerState, enableFileManagerState,
fileState, fileState,
isInpaintingState, isInpaintingState,
isPix2PixState,
isSDState, isSDState,
maskState, maskState,
runManuallyState, runManuallyState,
@ -30,6 +31,7 @@ const Header = () => {
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`) const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`) const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const runManually = useRecoilValue(runManuallyState) const runManually = useRecoilValue(runManuallyState)
const [openMaskPopover, setOpenMaskPopover] = useState(false) const [openMaskPopover, setOpenMaskPopover] = useState(false)
const [showFileManager, setShowFileManager] = const [showFileManager, setShowFileManager] =
@ -172,7 +174,7 @@ const Header = () => {
</div> </div>
</div> </div>
{isSD && file ? <PromptInput /> : <></>} {(isSD || isPix2Pix) && file ? <PromptInput /> : <></>}
<div className="header-icons-wrapper"> <div className="header-icons-wrapper">
<CoffeeIcon /> <CoffeeIcon />

View File

@ -179,28 +179,16 @@ function ModelSettingBlock() {
const renderOptionDesc = (): ReactNode => { const renderOptionDesc = (): ReactNode => {
switch (setting.model) { switch (setting.model) {
case AIModel.LAMA:
return undefined
case AIModel.LDM: case AIModel.LDM:
return renderLDMModelDesc() return renderLDMModelDesc()
case AIModel.ZITS: case AIModel.ZITS:
return renderZITSModelDesc() return renderZITSModelDesc()
case AIModel.MAT:
return undefined
case AIModel.FCF: case AIModel.FCF:
return renderFCFModelDesc() return renderFCFModelDesc()
case AIModel.SD15:
return undefined
case AIModel.SD2:
return undefined
case AIModel.PAINT_BY_EXAMPLE:
return undefined
case AIModel.Mange:
return undefined
case AIModel.CV2: case AIModel.CV2:
return renderOpenCV2Desc() return renderOpenCV2Desc()
default: default:
return <></> return undefined
} }
} }
@ -266,6 +254,12 @@ function ModelSettingBlock() {
'https://arxiv.org/abs/2211.13227', 'https://arxiv.org/abs/2211.13227',
'https://github.com/Fantasy-Studio/Paint-by-Example' 'https://github.com/Fantasy-Studio/Paint-by-Example'
) )
case AIModel.PIX2PIX:
return renderModelDesc(
'InstructPix2Pix',
'https://arxiv.org/abs/2211.09800',
'https://github.com/timothybrooks/instruct-pix2pix'
)
default: default:
return <></> return <></>
} }

View File

@ -1,5 +1,6 @@
import React from 'react' import React from 'react'
import { useRecoilState } from 'recoil' import { useRecoilState } from 'recoil'
import { Cog6ToothIcon } from '@heroicons/react/24/outline'
import { settingState } from '../../store/Atoms' import { settingState } from '../../store/Atoms'
import Button from '../shared/Button' import Button from '../shared/Button'
@ -16,29 +17,7 @@ const SettingIcon = () => {
onClick={onClick} onClick={onClick}
toolTip="Settings" toolTip="Settings"
style={{ border: 0 }} style={{ border: 0 }}
icon={ icon={<Cog6ToothIcon />}
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
role="img"
width="28"
height="28"
viewBox="0 0 24 24"
stroke="currentColor"
strokeWidth="2"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
/>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
/>
</svg>
}
/> />
</div> </div>
) )

View File

@ -8,7 +8,6 @@
color: var(--modal-text-color); color: var(--modal-text-color);
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2); box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
width: 680px; width: 680px;
min-height: 420px;
@include mobile { @include mobile {
display: grid; display: grid;

View File

@ -1,6 +1,7 @@
import React from 'react' import React from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import { import {
isDiffusionModelsState,
isPaintByExampleState, isPaintByExampleState,
isSDState, isSDState,
settingState, settingState,
@ -28,7 +29,7 @@ export default function SettingModal(props: SettingModalProps) {
const { onClose } = props const { onClose } = props
const [setting, setSettingState] = useRecoilState(settingState) const [setting, setSettingState] = useRecoilState(settingState)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState) const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const handleOnClose = () => { const handleOnClose = () => {
setSettingState(old => { setSettingState(old => {
@ -56,9 +57,9 @@ export default function SettingModal(props: SettingModalProps) {
show={setting.show} show={setting.show}
> >
<DownloadMaskSettingBlock /> <DownloadMaskSettingBlock />
{isSD || isPaintByExample ? <></> : <ManualRunInpaintingSettingBlock />} {isDiffusionModels ? <></> : <ManualRunInpaintingSettingBlock />}
<ModelSettingBlock /> <ModelSettingBlock />
{isSD ? <></> : <HDSettingBlock />} {isDiffusionModels ? <></> : <HDSettingBlock />}
</Modal> </Modal>
) )
} }

View File

@ -15,13 +15,14 @@
} }
.shortcut-options { .shortcut-options {
display: grid; display: flex;
row-gap: 1rem; gap: 48px;
flex-direction: row;
.shortcut-option { .shortcut-option {
display: grid; display: grid;
grid-template-columns: repeat(2, auto); grid-template-columns: repeat(2, auto);
column-gap: 6rem; column-gap: 2rem;
align-items: center; align-items: center;
@include mobile { @include mobile {
@ -67,3 +68,10 @@
} }
} }
} }
// One vertical group of shortcut rows; several of these sit side by side
// inside .shortcut-options (see ShortcutsModal).
.shortcut-options-column {
  display: flex;
  flex-direction: column;
  gap: 12px;
  width: 320px;
}

View File

@ -50,24 +50,37 @@ export default function ShortcutsModal() {
show={shortcutsShow} show={shortcutsShow}
> >
<div className="shortcut-options"> <div className="shortcut-options">
<ShortCut <div className="shortcut-options-column">
content="Multi-Stroke Mask Drawing" <ShortCut content="Pan" keys={['Space + Drag']} />
keys={[`Hold ${CmdOrCtrl}`]} <ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
/> <ShortCut content="Decrease Brush Size" keys={['[']} />
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} /> <ShortCut content="Increase Brush Size" keys={[']']} />
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} /> <ShortCut content="View Original Image" keys={['Hold Tab']} />
<ShortCut content="Interactive Segmentation" keys={['I']} /> <ShortCut
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} /> content="Multi-Stroke Drawing"
<ShortCut content="Redo Inpainting" keys={[CmdOrCtrl, 'Shift', 'Z']} /> keys={[`Hold ${CmdOrCtrl}`]}
<ShortCut content="View Original Image" keys={['Hold Tab']} /> />
<ShortCut content="Pan" keys={['Space + Drag']} /> <ShortCut content="Cancel Drawing" keys={['Esc']} />
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} /> </div>
<ShortCut content="Decrease Brush Size" keys={['[']} />
<ShortCut content="Increase Brush Size" keys={[']']} /> <div className="shortcut-options-column">
<ShortCut content="Toggle Dark Mode" keys={['Shift', 'D']} /> <ShortCut content="Undo" keys={[CmdOrCtrl, 'Z']} />
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} /> <ShortCut content="Redo" keys={[CmdOrCtrl, 'Shift', 'Z']} />
<ShortCut content="Toggle Settings Dialog" keys={['S']} /> <ShortCut content="Copy Result" keys={[CmdOrCtrl, 'C']} />
<ShortCut content="Toggle File Manager" keys={['F']} /> <ShortCut content="Paste Image" keys={[CmdOrCtrl, 'V']} />
<ShortCut
content="Trigger Manually Inpainting"
keys={['Shift', 'R']}
/>
<ShortCut content="Trigger Interactive Segmentation" keys={['I']} />
</div>
<div className="shortcut-options-column">
<ShortCut content="Switch Theme" keys={['Shift', 'D']} />
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
<ShortCut content="Toggle File Manager" keys={['F']} />
</div>
</div> </div>
</Modal> </Modal>
) )

View File

@ -0,0 +1,224 @@
import React, { FormEvent } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import {
isInpaintingState,
negativePropmtState,
propmtState,
settingState,
} from '../../store/Atoms'
import NumberInputSetting from '../Settings/NumberInputSetting'
import SettingBlock from '../Settings/SettingBlock'
import { Switch, SwitchThumb } from '../shared/Switch'
import TextAreaInput from '../shared/Textarea'
import emitter, { EVENT_PROMPT } from '../../event'
import ImageResizeScale from './ImageResizeScale'
// Shared width for the numeric input fields in this panel.
const INPUT_WIDTH = 30

// Settings side panel shown when the InstructPix2Pix model is active
// (see Workspace). Exposes the pix2pix sampling parameters (steps, text
// and image guidance scales), the crop toggle, resize scale, seed and
// negative prompt.
const P2PSidePanel = () => {
  // Panel visibility; starts open, toggled by the trigger button.
  const [open, toggleOpen] = useToggle(true)
  const [setting, setSettingState] = useRecoilState(settingState)
  const [negativePrompt, setNegativePrompt] =
    useRecoilState(negativePropmtState)
  const isInpainting = useRecoilValue(isInpaintingState)
  const prompt = useRecoilValue(propmtState)

  // Mirror textarea edits into the negative-prompt atom.
  const handleOnInput = (evt: FormEvent<HTMLTextAreaElement>) => {
    evt.preventDefault()
    evt.stopPropagation()
    const target = evt.target as HTMLTextAreaElement
    setNegativePrompt(target.value)
  }

  // Ctrl/Cmd+Enter inside the textarea triggers a run, but only when a
  // prompt exists and no inpainting job is already in flight.
  const onKeyUp = (e: React.KeyboardEvent) => {
    if (
      e.key === 'Enter' &&
      (e.ctrlKey || e.metaKey) &&
      prompt.length !== 0 &&
      !isInpainting
    ) {
      emitter.emit(EVENT_PROMPT)
    }
  }

  return (
    <div className="side-panel">
      <PopoverPrimitive.Root open={open}>
        <PopoverPrimitive.Trigger
          className="btn-primary side-panel-trigger"
          onClick={() => toggleOpen()}
        >
          Config
        </PopoverPrimitive.Trigger>
        <PopoverPrimitive.Portal>
          <PopoverPrimitive.Content className="side-panel-content">
            {/* Toggle for the crop tool overlay on the editor canvas */}
            <SettingBlock
              title="Croper"
              input={
                <Switch
                  checked={setting.showCroper}
                  onCheckedChange={value => {
                    setSettingState(old => {
                      return { ...old, showCroper: value }
                    })
                  }}
                >
                  <SwitchThumb />
                </Switch>
              }
            />
            <ImageResizeScale />
            {/* Number of denoising steps used by InstructPix2Pix */}
            <NumberInputSetting
              title="Steps"
              width={INPUT_WIDTH}
              value={`${setting.p2pSteps}`}
              desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
              onValue={value => {
                // Treat an empty field as 0 rather than NaN.
                const val = value.length === 0 ? 0 : parseInt(value, 10)
                setSettingState(old => {
                  return { ...old, p2pSteps: val }
                })
              }}
            />
            {/* Text guidance scale (classifier-free guidance weight) */}
            <NumberInputSetting
              title="Guidance Scale"
              width={INPUT_WIDTH}
              allowFloat
              value={`${setting.p2pGuidanceScale}`}
              desc="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality."
              onValue={value => {
                const val = value.length === 0 ? 0 : parseFloat(value)
                setSettingState(old => {
                  return { ...old, p2pGuidanceScale: val }
                })
              }}
            />
            {/* Image guidance scale — presumably controls how strongly the
                result follows the input image; TODO confirm and fill desc */}
            <NumberInputSetting
              title="Image Guidance Scale"
              width={INPUT_WIDTH}
              allowFloat
              value={`${setting.p2pImageGuidanceScale}`}
              desc=""
              onValue={value => {
                const val = value.length === 0 ? 0 : parseFloat(value)
                setSettingState(old => {
                  return { ...old, p2pImageGuidanceScale: val }
                })
              }}
            />
            {/* The blocks below are carried over from the SD panel but are
                not wired up for pix2pix yet, hence commented out. */}
            {/* <NumberInputSetting
              title="Mask Blur"
              width={INPUT_WIDTH}
              value={`${setting.sdMaskBlur}`}
              desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
              onValue={value => {
                const val = value.length === 0 ? 0 : parseInt(value, 10)
                setSettingState(old => {
                  return { ...old, sdMaskBlur: val }
                })
              }}
            /> */}
            {/* <SettingBlock
              title="Match Histograms"
              desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
              input={
                <Switch
                  checked={setting.sdMatchHistograms}
                  onCheckedChange={value => {
                    setSettingState(old => {
                      return { ...old, sdMatchHistograms: value }
                    })
                  }}
                >
                  <SwitchThumb />
                </Switch>
              }
            /> */}
            {/* <SettingBlock
              className="sub-setting-block"
              title="Sampler"
              input={
                <Selector
                  width={80}
                  value={setting.sdSampler as string}
                  options={Object.values(SDSampler)}
                  onChange={val => {
                    const sampler = val as SDSampler
                    setSettingState(old => {
                      return { ...old, sdSampler: sampler }
                    })
                  }}
                />
              }
            /> */}
            {/* Seed input plus a switch that fixes/unfixes the seed.
                Reuses the SD seed settings (sdSeed / sdSeedFixed). */}
            <SettingBlock
              title="Seed"
              input={
                <div
                  style={{
                    display: 'flex',
                    gap: 0,
                    justifyContent: 'center',
                    alignItems: 'center',
                  }}
                >
                  <NumberInputSetting
                    title=""
                    width={80}
                    value={`${setting.sdSeed}`}
                    desc=""
                    disable={!setting.sdSeedFixed}
                    onValue={value => {
                      const val = value.length === 0 ? 0 : parseInt(value, 10)
                      setSettingState(old => {
                        return { ...old, sdSeed: val }
                      })
                    }}
                  />
                  <Switch
                    checked={setting.sdSeedFixed}
                    onCheckedChange={value => {
                      setSettingState(old => {
                        return { ...old, sdSeedFixed: value }
                      })
                    }}
                    style={{ marginLeft: '8px' }}
                  >
                    <SwitchThumb />
                  </Switch>
                </div>
              }
            />
            <SettingBlock
              className="sub-setting-block"
              title="Negative prompt"
              layout="v"
              input={
                <TextAreaInput
                  className="negative-prompt"
                  value={negativePrompt}
                  onInput={handleOnInput}
                  onKeyUp={onKeyUp}
                  placeholder=""
                />
              }
            />
          </PopoverPrimitive.Content>
        </PopoverPrimitive.Portal>
      </PopoverPrimitive.Root>
    </div>
  )
}

export default P2PSidePanel

View File

@ -8,6 +8,7 @@ import {
AIModel, AIModel,
fileState, fileState,
isPaintByExampleState, isPaintByExampleState,
isPix2PixState,
isSDState, isSDState,
settingState, settingState,
showFileManagerState, showFileManagerState,
@ -22,6 +23,7 @@ import {
import SidePanel from './SidePanel/SidePanel' import SidePanel from './SidePanel/SidePanel'
import PESidePanel from './SidePanel/PESidePanel' import PESidePanel from './SidePanel/PESidePanel'
import FileManager from './FileManager/FileManager' import FileManager from './FileManager/FileManager'
import P2PSidePanel from './SidePanel/P2PSidePanel'
const Workspace = () => { const Workspace = () => {
const setFile = useSetRecoilState(fileState) const setFile = useSetRecoilState(fileState)
@ -29,6 +31,7 @@ const Workspace = () => {
const [toastVal, setToastState] = useRecoilState(toastState) const [toastVal, setToastState] = useRecoilState(toastState)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState) const isPaintByExample = useRecoilValue(isPaintByExampleState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const [showFileManager, setShowFileManager] = const [showFileManager, setShowFileManager] =
useRecoilState(showFileManagerState) useRecoilState(showFileManagerState)
@ -98,6 +101,7 @@ const Workspace = () => {
<> <>
{isSD ? <SidePanel /> : <></>} {isSD ? <SidePanel /> : <></>}
{isPaintByExample ? <PESidePanel /> : <></>} {isPaintByExample ? <PESidePanel /> : <></>}
{isPix2Pix ? <P2PSidePanel /> : <></>}
<FileManager <FileManager
photoWidth={256} photoWidth={256}
show={showFileManager} show={showFileManager}

View File

@ -3,7 +3,7 @@ import * as ToastPrimitive from '@radix-ui/react-toast'
import { ToastProps } from '@radix-ui/react-toast' import { ToastProps } from '@radix-ui/react-toast'
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline' import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline'
const LoadingIcon = () => { export const LoadingIcon = () => {
return ( return (
<span className="loading-icon"> <span className="loading-icon">
<svg <svg

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 KiB

View File

@ -14,6 +14,7 @@ export enum AIModel {
CV2 = 'cv2', CV2 = 'cv2',
Mange = 'manga', Mange = 'manga',
PAINT_BY_EXAMPLE = 'paint_by_example', PAINT_BY_EXAMPLE = 'paint_by_example',
PIX2PIX = 'instruct_pix2pix',
} }
export const maskState = atom<File | undefined>({ export const maskState = atom<File | undefined>({
@ -45,6 +46,7 @@ interface AppState {
interactiveSegClicks: number[][] interactiveSegClicks: number[][]
showFileManager: boolean showFileManager: boolean
enableFileManager: boolean enableFileManager: boolean
gifImage: HTMLImageElement | undefined
} }
export const appState = atom<AppState>({ export const appState = atom<AppState>({
@ -61,6 +63,7 @@ export const appState = atom<AppState>({
interactiveSegClicks: [], interactiveSegClicks: [],
showFileManager: false, showFileManager: false,
enableFileManager: false, enableFileManager: false,
gifImage: undefined,
}, },
}) })
@ -134,6 +137,18 @@ export const enableFileManagerState = selector({
}, },
}) })
// Read/write accessor for the before/after GIF preview image kept on the
// global app state. `newValue` stays `any` because Recoil's setter may also
// receive a DefaultValue on reset.
export const gifImageState = selector({
  key: 'gifImageState',
  get: ({ get }) => get(appState).gifImage,
  set: ({ get, set }, newValue: any) =>
    set(appState, { ...get(appState), gifImage: newValue }),
})
export const fileState = selector({ export const fileState = selector({
key: 'fileState', key: 'fileState',
get: ({ get }) => { get: ({ get }) => {
@ -329,6 +344,11 @@ export interface Settings {
paintByExampleSeedFixed: boolean paintByExampleSeedFixed: boolean
paintByExampleMaskBlur: number paintByExampleMaskBlur: number
paintByExampleMatchHistograms: boolean paintByExampleMatchHistograms: boolean
// InstructPix2Pix
p2pSteps: number
p2pImageGuidanceScale: number
p2pGuidanceScale: number
} }
const defaultHDSettings: ModelsHDSettings = { const defaultHDSettings: ModelsHDSettings = {
@ -388,6 +408,13 @@ const defaultHDSettings: ModelsHDSettings = {
hdStrategyCropMargin: 128, hdStrategyCropMargin: 128,
enabled: false, enabled: false,
}, },
[AIModel.PIX2PIX]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.Mange]: { [AIModel.Mange]: {
hdStrategy: HDStrategy.CROP, hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1280, hdStrategyResizeLimit: 1280,
@ -457,6 +484,11 @@ export const settingStateDefault: Settings = {
paintByExampleMaskBlur: 5, paintByExampleMaskBlur: 5,
paintByExampleSeedFixed: false, paintByExampleSeedFixed: false,
paintByExampleMatchHistograms: false, paintByExampleMatchHistograms: false,
// InstructPix2Pix
p2pSteps: 50,
p2pImageGuidanceScale: 1.5,
p2pGuidanceScale: 7.5,
} }
const localStorageEffect = const localStorageEffect =
@ -553,12 +585,33 @@ export const isPaintByExampleState = selector({
}, },
}) })
export const isPix2PixState = selector({
key: 'isPix2PixState',
get: ({ get }) => {
const settings = get(settingState)
return settings.model === AIModel.PIX2PIX
},
})
export const runManuallyState = selector({ export const runManuallyState = selector({
key: 'runManuallyState', key: 'runManuallyState',
get: ({ get }) => { get: ({ get }) => {
const settings = get(settingState) const settings = get(settingState)
const isSD = get(isSDState) const isSD = get(isSDState)
const isPaintByExample = get(isPaintByExampleState) const isPaintByExample = get(isPaintByExampleState)
return settings.runInpaintingManually || isSD || isPaintByExample const isPix2Pix = get(isPix2PixState)
return (
settings.runInpaintingManually || isSD || isPaintByExample || isPix2Pix
)
},
})
export const isDiffusionModelsState = selector({
key: 'isDiffusionModelsState',
get: ({ get }) => {
const isSD = get(isSDState)
const isPaintByExample = get(isPaintByExampleState)
const isPix2Pix = get(isPix2PixState)
return isSD || isPaintByExample || isPix2Pix
}, },
}) })

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,12 @@
import os import os
# Models allowed on Apple's Metal (mps) device; parse_args rejects
# `--device mps` for any model not listed here.
MPS_SUPPORT_MODELS = [
    "instruct_pix2pix",
    "sd1.5",
    "sd2",
    "paint_by_example"
]
DEFAULT_MODEL = "lama" DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = [ AVAILABLE_MODELS = [
"lama", "lama",
@ -11,7 +18,8 @@ AVAILABLE_MODELS = [
"cv2", "cv2",
"manga", "manga",
"sd2", "sd2",
"paint_by_example" "paint_by_example",
"instruct_pix2pix",
] ]
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]

View File

@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
model_path = download_model(url_or_path) model_path = download_model(url_or_path)
try: try:
state_dict = torch.load(model_path, map_location='cpu') state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
model.to(device) model.to(device)
logger.info(f"Load model from: {model_path}") logger.info(f"Load model from: {model_path}")
@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
return image_bytes return image_bytes
def load_img(img_bytes, gray: bool = False): def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
with io.BytesIO() as output:
pil_img.save(output, format=ext, exif=exif, quality=95)
image_bytes = output.getvalue()
return image_bytes
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
alpha_channel = None alpha_channel = None
image = Image.open(io.BytesIO(img_bytes)) image = Image.open(io.BytesIO(img_bytes))
try:
if return_exif:
exif = image.getexif()
except:
exif = None
logger.error("Failed to extract exif from image")
try: try:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
except: except:
pass pass
if gray: if gray:
image = image.convert('L') image = image.convert("L")
np_img = np.array(image) np_img = np.array(image)
else: else:
if image.mode == 'RGBA': if image.mode == "RGBA":
np_img = np.array(image) np_img = np.array(image)
alpha_channel = np_img[:, :, -1] alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else: else:
image = image.convert('RGB') image = image.convert("RGB")
np_img = np.array(image) np_img = np.array(image)
if return_exif:
return np_img, alpha_channel, exif
return np_img, alpha_channel return np_img, alpha_channel

125
lama_cleaner/make_gif.py Normal file
View File

@ -0,0 +1,125 @@
import io
import math
from pathlib import Path
from PIL import Image, ImageDraw
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
    """Resize ``img`` so its longer side equals ``size``, preserving aspect ratio."""
    if img.width > img.height:
        new_w, new_h = size, int(img.height * size / img.width)
    else:
        new_w, new_h = int(img.width * size / img.height), size
    return img.resize((new_w, new_h), resample)
def cubic_bezier(p1, p2, duration: int, frames: int):
    """Sample a cubic Bezier curve from P0=(0, 0) to P3=(1, 1).

    Args:
        p1: first control point (x, y).
        p2: second control point (x, y).
        duration: step between sampled frame indices.
        frames: frame count; the parameter t is sampled at i / frames.

    Returns:
        List of (x, y) samples with a trailing (1, 0) point appended.
        NOTE(review): the appended endpoint's y is 0 although the curve ends
        at (1, 1); make_compare_gif consumes only x — confirm before changing.
    """
    x0, y0 = (0, 0)
    x1, y1 = p1
    x2, y2 = p2
    x3, y3 = (1, 1)

    def point_at(t):
        # Standard cubic Bezier: B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1
        #                               + 3(1-t) t^2 P2 + t^3 P3
        u = 1 - t
        x = u ** 3 * x0 + 3 * u ** 2 * t * x1 + 3 * u * t ** 2 * x2 + t ** 3 * x3
        y = u ** 3 * y0 + 3 * u ** 2 * t * y1 + 3 * u * t ** 2 * y2 + t ** 3 * y3
        return x, y

    samples = [point_at(i / frames) for i in range(0, 1 * frames, duration)]
    samples.append((1, 0))
    return samples
def make_compare_gif(
    clean_img: Image.Image,
    src_img: Image.Image,
    max_side_length: int = 600,
    splitter_width: int = 5,
    splitter_color=(255, 203, 0, int(255 * 0.73))
):
    # Build an animated before/after "wipe" GIF: a splitter line sweeps
    # right-to-left revealing src_img over clean_img, holds on src_img,
    # then sweeps back left-to-right, ending on clean_img.
    #
    # Args:
    #   clean_img: inpainted result; resized to src_img.size if they differ.
    #   src_img: original image.
    #   max_side_length: cap on the longer side of the output frames.
    #   splitter_width: width in pixels of the sweep line.
    #   splitter_color: RGBA fill for the sweep line.
    #     NOTE(review): frames are 'RGB' mode — the alpha component of this
    #     fill may be ignored by ImageDraw; confirm intended look.
    # Returns:
    #   Encoded GIF file contents as bytes.
    if clean_img.size != src_img.size:
        clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
    duration_per_frame = 20  # milliseconds per GIF frame
    num_frames = 50
    # erase-in-out
    cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames)
    # reversed so x eases from 1 down to 0 (splitter starts at the right edge)
    cubic_bezier_points.reverse()
    max_side_length = min(max_side_length, max(clean_img.size))
    src_img = keep_ratio_resize(src_img, max_side_length)
    clean_img = keep_ratio_resize(clean_img, max_side_length)
    width, height = src_img.size
    # Generate images to make Gif from right to left
    images = []
    for i in range(num_frames):
        new_frame = Image.new('RGB', (width, height))
        new_frame.paste(clean_img, (0, 0))
        # x position of the splitter, eased along the bezier curve
        left = int(cubic_bezier_points[i][0] * width)
        cropped_src_img = src_img.crop((left, 0, width, height))
        new_frame.paste(cropped_src_img, (left, 0, width, height))
        if i != num_frames - 1:
            # draw a yellow splitter on the edge of the cropped image
            draw = ImageDraw.Draw(new_frame)
            draw.line([(left, 0), (left, height)], width=splitter_width, fill=splitter_color)
        images.append(new_frame)
    # hold on the full source image before sweeping back
    for i in range(10):
        images.append(src_img)
    # un-reverse: x now eases from 0 up to 1 for the return sweep
    cubic_bezier_points.reverse()
    # Generate images to make Gif from left to right
    for i in range(num_frames):
        new_frame = Image.new('RGB', (width, height))
        new_frame.paste(src_img, (0, 0))
        right = int(cubic_bezier_points[i][0] * width)
        cropped_src_img = clean_img.crop((0, 0, right, height))
        new_frame.paste(cropped_src_img, (0, 0, right, height))
        if i != num_frames - 1:
            # draw a yellow splitter on the edge of the cropped image
            draw = ImageDraw.Draw(new_frame)
            draw.line([(right, 0), (right, height)], width=splitter_width, fill=splitter_color)
        images.append(new_frame)
    images.append(clean_img)
    img_byte_arr = io.BytesIO()
    # First GIF frame is clean_img itself; all sweep frames follow via
    # append_images. duration is per-frame display time, loop=0 loops forever.
    clean_img.save(
        img_byte_arr,
        format='GIF',
        save_all=True,
        include_color_table=True,
        append_images=images,
        optimize=False,
        duration=duration_per_frame,
        loop=0
    )
    return img_byte_arr.getvalue()

View File

@ -245,3 +245,42 @@ class InpaintModel:
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config) crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b] return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
class DiffusionInpaintModel(InpaintModel):
    """Shared driver for diffusion-based inpainting models (SD, PaintByExample,
    InstructPix2Pix): handles the optional crop box and the sd_scale
    downscale/upscale around the subclass's forward pass."""

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        if config.use_croper:
            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
            # .copy() is required: image[:, :, ::-1] is a negative-stride VIEW
            # of the caller's array, and assigning the crop result into it
            # would silently mutate the input image in place.
            inpaint_result = image[:, :, ::-1].copy()
            inpaint_result[t:b, l:r, :] = crop_image
        else:
            inpaint_result = self._scaled_pad_forward(image, mask, config)
        return inpaint_result

    def _scaled_pad_forward(self, image, mask, config: Config):
        """Downscale by config.sd_scale, run the padded forward pass, resize
        the result back, and keep inpainted pixels only where mask >= 127;
        all other pixels are restored from the original image (RGB -> BGR)."""
        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
        origin_size = image.shape[:2]
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
        logger.info(
            f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
        )
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
        # only paste masked area result
        inpaint_result = cv2.resize(
            inpaint_result,
            (origin_size[1], origin_size[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        original_pixel_indices = mask < 127
        inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
        return inpaint_result

View File

@ -0,0 +1,83 @@
import PIL.Image
import cv2
import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
class InstructPix2Pix(DiffusionInpaintModel):
    # Wrapper around diffusers' StableDiffusionInstructPix2PixPipeline
    # ("timbrooks/instruct-pix2pix"): edits the whole image from a text
    # instruction (config.prompt).
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from diffusers import StableDiffusionInstructPix2PixPipeline
        fp16 = not kwargs.get('no_half', False)

        model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
        # NSFW checker is also dropped under cpu_offload to reduce memory use.
        if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(dict(
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False
            ))

        # fp16 weights ("fp16" revision) only when running on CUDA with half
        # precision enabled; otherwise fall back to float32 "main" weights.
        use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
        self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            revision="fp16" if use_gpu and fp16 else "main",
            torch_dtype=torch_dtype,
            **model_kwargs
        )

        self.model.enable_attention_slicing()
        if kwargs.get('enable_xformers', False):
            self.model.enable_xformers_memory_efficient_attention()

        if kwargs.get('cpu_offload', False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE

        NOTE(review): `mask` is never read in this body — InstructPix2Pix
        edits the whole image from the prompt; confirm callers expect that.

        edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
        """
        set_seed(config.sd_seed)
        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.p2p_steps,
            image_guidance_scale=config.p2p_image_guidance_scale,
            guidance_scale=config.p2p_guidance_scale,
            output_type="np.array",
        ).images[0]
        # Pipeline returns float [0, 1]; convert to uint8 BGR for the caller.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    #
    # def forward_post_process(self, result, image, mask, config):
    #     if config.sd_match_histograms:
    #         result = self._match_histograms(result, image[:, :, ::-1], mask)
    #
    #     if config.sd_mask_blur != 0:
    #         k = 2 * config.sd_mask_blur + 1
    #         mask = cv2.GaussianBlur(mask, (k, k), 0)
    #     return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
        return True

View File

@ -1,19 +1,16 @@
import random
import PIL import PIL
import PIL.Image import PIL.Image
import cv2 import cv2
import numpy as np
import torch import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from loguru import logger from loguru import logger
from lama_cleaner.helper import resize_max_size from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
class PaintByExample(InpaintModel): class PaintByExample(DiffusionInpaintModel):
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
@ -53,11 +50,7 @@ class PaintByExample(InpaintModel):
mask: [H, W, 1] 255 means area to repaint mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE return: BGR IMAGE
""" """
seed = config.paint_by_example_seed set_seed(config.paint_by_example_seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
output = self.model( output = self.model(
image=PIL.Image.fromarray(image), image=PIL.Image.fromarray(image),
@ -71,42 +64,6 @@ class PaintByExample(InpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output return output
def _scaled_pad_forward(self, image, mask, config: Config):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info(
f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return inpaint_result
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def forward_post_process(self, result, image, mask, config): def forward_post_process(self, result, image, mask, config):
if config.paint_by_example_match_histograms: if config.paint_by_example_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask) result = self._match_histograms(result, image[:, :, ::-1], mask)

View File

@ -4,20 +4,25 @@ import PIL.Image
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \ from diffusers import (
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler PNDMScheduler,
DDIMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from loguru import logger from loguru import logger
from lama_cleaner.helper import resize_max_size from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.utils import torch_gc, set_seed
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.schema import Config, SDSampler from lama_cleaner.schema import Config, SDSampler
class CPUTextEncoderWrapper: class CPUTextEncoderWrapper:
def __init__(self, text_encoder, torch_dtype): def __init__(self, text_encoder, torch_dtype):
self.config = text_encoder.config self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True) self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
del text_encoder del text_encoder
@ -25,27 +30,40 @@ class CPUTextEncoderWrapper:
def __call__(self, x, **kwargs): def __call__(self, x, **kwargs):
input_device = x.device input_device = x.device
return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)] return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype
class SD(InpaintModel): class SD(DiffusionInpaintModel):
pad_mod = 8 pad_mod = 8
min_size = 512 min_size = 512
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])} fp16 = not kwargs.get("no_half", False)
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict( model_kwargs.update(
safety_checker=None, dict(
feature_extractor=None, safety_checker=None,
requires_safety_checker=False feature_extractor=None,
)) requires_safety_checker=False,
)
)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available() use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
@ -58,40 +76,23 @@ class SD(InpaintModel):
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get('enable_xformers', False): if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention() self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and use_gpu: if kwargs.get("cpu_offload", False) and use_gpu:
# TODO: gpu_id # TODO: gpu_id
logger.info("Enable sequential cpu offload") logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0) self.model.enable_sequential_cpu_offload(gpu_id=0)
else: else:
self.model = self.model.to(device) self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']: if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU") logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype) self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
def _scaled_pad_forward(self, image, mask, config: Config):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info(
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return inpaint_result
def forward(self, image, mask, config: Config): def forward(self, image, mask, config: Config):
"""Input image and output image have same size """Input image and output image have same size
image: [H, W, C] RGB image: [H, W, C] RGB
@ -118,11 +119,7 @@ class SD(InpaintModel):
self.model.scheduler = scheduler self.model.scheduler = scheduler
seed = config.sd_seed set_seed(config.sd_seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if config.sd_mask_blur != 0: if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1 k = 2 * config.sd_mask_blur + 1
@ -147,24 +144,6 @@ class SD(InpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output return output
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
# boxes = boxes_from_mask(mask)
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def forward_post_process(self, result, image, mask, config): def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms: if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask) result = self._match_histograms(result, image[:, :, ::-1], mask)

View File

@ -1,4 +1,5 @@
import math import math
import random
from typing import Any from typing import Any
import torch import torch
@ -713,3 +714,10 @@ def torch_gc():
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def set_seed(seed: int):
    """Seed every RNG source the models rely on (Python's random, NumPy,
    torch CPU, and all CUDA devices) so results are reproducible."""
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

View File

@ -7,13 +7,14 @@ from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga from lama_cleaner.model.manga import Manga
from lama_cleaner.model.mat import MAT from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2 from lama_cleaner.model.sd import SD15, SD2
from lama_cleaner.model.zits import ZITS from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga, models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
"sd2": SD2, "paint_by_example": PaintByExample} "sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
class ModelManager: class ModelManager:

View File

@ -5,9 +5,25 @@ from pathlib import Path
from loguru import logger from loguru import logger
from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, DISABLE_NSFW_HELP, \ from lama_cleaner.const import (
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \ AVAILABLE_MODELS,
OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR NO_HALF_HELP,
CPU_OFFLOAD_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
AVAILABLE_DEVICES,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR,
DEFAULT_MODEL,
MPS_SUPPORT_MODELS,
)
from lama_cleaner.runtime import dump_environment_info from lama_cleaner.runtime import dump_environment_info
@ -16,22 +32,40 @@ def parse_args():
parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--config-installer", action="store_true", parser.add_argument(
help="Open config web page, mainly for windows installer") "--config-installer",
parser.add_argument("--load-installer-config", action="store_true", action="store_true",
help="Load all cmd args from installer config file") help="Open config web page, mainly for windows installer",
parser.add_argument("--installer-config", default=None, help="Config file for windows installer") )
parser.add_argument(
"--load-installer-config",
action="store_true",
help="Load all cmd args from installer config file",
)
parser.add_argument(
"--installer-config", default=None, help="Config file for windows installer"
)
parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS) parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS)
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP) parser.add_argument(
parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP) "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP) )
parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES) parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
)
parser.add_argument(
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
)
parser.add_argument(
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
)
parser.add_argument("--gui", action="store_true", help=GUI_HELP) parser.add_argument("--gui", action="store_true", help=GUI_HELP)
parser.add_argument("--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP) parser.add_argument(
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
)
parser.add_argument( parser.add_argument(
"--gui-size", "--gui-size",
default=[1600, 1000], default=[1600, 1000],
@ -41,8 +75,14 @@ def parse_args():
) )
parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) parser.add_argument("--input", type=str, default=None, help=INPUT_HELP)
parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP)
parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP) parser.add_argument(
parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend") "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP
)
parser.add_argument(
"--disable-model-switch",
action="store_true",
help="Disable model switch in frontend",
)
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
# useless args # useless args
@ -64,7 +104,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--sd-enable-xformers", "--sd-enable-xformers",
action="store_true", action="store_true",
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers",
) )
args = parser.parse_args() args = parser.parse_args()
@ -74,14 +114,18 @@ def parse_args():
if args.config_installer: if args.config_installer:
if args.installer_config is None: if args.installer_config is None:
parser.error(f"args.config_installer==True, must set args.installer_config to store config file") parser.error(
f"args.config_installer==True, must set args.installer_config to store config file"
)
from lama_cleaner.web_config import main from lama_cleaner.web_config import main
logger.info(f"Launching installer web config page") logger.info(f"Launching installer web config page")
main(args.installer_config) main(args.installer_config)
exit() exit()
if args.load_installer_config: if args.load_installer_config:
from lama_cleaner.web_config import load_config from lama_cleaner.web_config import load_config
if args.installer_config and not os.path.exists(args.installer_config): if args.installer_config and not os.path.exists(args.installer_config):
parser.error(f"args.installer_config={args.installer_config} not exists") parser.error(f"args.installer_config={args.installer_config} not exists")
@ -93,9 +137,17 @@ def parse_args():
if args.device == "cuda": if args.device == "cuda":
import torch import torch
if torch.cuda.is_available() is False: if torch.cuda.is_available() is False:
parser.error( parser.error(
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
)
if args.device == "mps":
if args.model not in MPS_SUPPORT_MODELS:
parser.error(
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
)
if args.model_dir and args.model_dir is not None: if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
@ -115,7 +167,9 @@ def parse_args():
parser.error(f"invalid --input: {args.input} is not a valid image file") parser.error(f"invalid --input: {args.input} is not a valid image file")
else: else:
if args.output_dir is None: if args.output_dir is None:
parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required") parser.error(
f"invalid --input: {args.input} is a directory, --output-dir is required"
)
else: else:
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
if not output_dir.exists(): if not output_dir.exists():
@ -123,6 +177,8 @@ def parse_args():
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
else: else:
if not output_dir.is_dir(): if not output_dir.is_dir():
parser.error(f"invalid --output-dir: {output_dir} is not a directory") parser.error(
f"invalid --output-dir: {output_dir} is not a directory"
)
return args return args

View File

@ -88,3 +88,8 @@ class Config(BaseModel):
paint_by_example_seed: int = 42 paint_by_example_seed: int = 42
paint_by_example_match_histograms: bool = False paint_by_example_match_histograms: bool = False
paint_by_example_example_image: Image = None paint_by_example_example_image: Image = None
# InstructPix2Pix
p2p_steps: int = 50
p2p_image_guidance_scale: float = 1.5
p2p_guidance_scale: float = 7.5

View File

@ -19,6 +19,7 @@ from loguru import logger
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
from lama_cleaner.interactive_seg import InteractiveSeg, Click from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model_manager import ModelManager from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager from lama_cleaner.file_manager import FileManager
@ -31,7 +32,15 @@ try:
except: except:
pass pass
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify from flask import (
Flask,
request,
send_file,
cli,
make_response,
send_from_directory,
jsonify,
)
# Disable ability for Flask to display warning about using a development server in a production environment. # Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 # https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@ -42,6 +51,7 @@ from lama_cleaner.helper import (
load_img, load_img,
numpy_to_bytes, numpy_to_bytes,
resize_max_size, resize_max_size,
pil_to_bytes,
) )
NUM_THREADS = str(multiprocessing.cpu_count()) NUM_THREADS = str(multiprocessing.cpu_count())
@ -93,6 +103,25 @@ def diffuser_callback(i, t, latents):
# socketio.emit('diffusion_step', {'diffusion_step': step}) # socketio.emit('diffusion_step', {'diffusion_step': step})
@app.route("/make_gif", methods=["POST"])
def make_gif():
input = request.files
filename = request.form["filename"]
origin_image_bytes = input["origin_img"].read()
clean_image_bytes = input["clean_img"].read()
origin_image, _ = load_img(origin_image_bytes)
clean_image, _ = load_img(clean_image_bytes)
gif_bytes = make_compare_gif(
Image.fromarray(origin_image), Image.fromarray(clean_image)
)
return send_file(
io.BytesIO(gif_bytes),
mimetype="image/gif",
as_attachment=True,
attachment_filename=filename,
)
@app.route("/save_image", methods=["POST"]) @app.route("/save_image", methods=["POST"])
def save_image(): def save_image():
# all image in output directory # all image in output directory
@ -100,12 +129,12 @@ def save_image():
origin_image_bytes = input["image"].read() # RGB origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes) image, _ = load_img(origin_image_bytes)
thumb.save_to_output_directory(image, request.form["filename"]) thumb.save_to_output_directory(image, request.form["filename"])
return 'ok', 200 return "ok", 200
@app.route("/medias/<tab>") @app.route("/medias/<tab>")
def medias(tab): def medias(tab):
if tab == 'image': if tab == "image":
response = make_response(jsonify(thumb.media_names), 200) response = make_response(jsonify(thumb.media_names), 200)
else: else:
response = make_response(jsonify(thumb.output_media_names), 200) response = make_response(jsonify(thumb.output_media_names), 200)
@ -116,18 +145,18 @@ def medias(tab):
return response return response
@app.route('/media/<tab>/<filename>') @app.route("/media/<tab>/<filename>")
def media_file(tab, filename): def media_file(tab, filename):
if tab == 'image': if tab == "image":
return send_from_directory(thumb.root_directory, filename) return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename) return send_from_directory(thumb.output_dir, filename)
@app.route('/media_thumbnail/<tab>/<filename>') @app.route("/media_thumbnail/<tab>/<filename>")
def media_thumbnail_file(tab, filename): def media_thumbnail_file(tab, filename):
args = request.args args = request.args
width = args.get('width') width = args.get("width")
height = args.get('height') height = args.get("height")
if width is None and height is None: if width is None and height is None:
width = 256 width = 256
if width: if width:
@ -136,9 +165,11 @@ def media_thumbnail_file(tab, filename):
height = int(float(height)) height = int(float(height))
directory = thumb.root_directory directory = thumb.root_directory
if tab == 'output': if tab == "output":
directory = thumb.output_dir directory = thumb.output_dir
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height) thumb_filename, (width, height) = thumb.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath)) response = make_response(send_file(thumb_filepath))
@ -152,13 +183,16 @@ def process():
input = request.files input = request.files
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes) image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]: if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400 return (
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
400,
)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -171,7 +205,9 @@ def process():
size_limit = int(size_limit) size_limit = int(size_limit)
if "paintByExampleImage" in input: if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read()) paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else: else:
paint_by_example_example_image = None paint_by_example_example_image = None
@ -200,13 +236,16 @@ def process():
sd_seed=form["sdSeed"], sd_seed=form["sdSeed"],
sd_match_histograms=form["sdMatchHistograms"], sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"], cv2_flag=form["cv2Flag"],
cv2_radius=form['cv2Radius'], cv2_radius=form["cv2Radius"],
paint_by_example_steps=form["paintByExampleSteps"], paint_by_example_steps=form["paintByExampleSteps"],
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"], paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
paint_by_example_mask_blur=form["paintByExampleMaskBlur"], paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
paint_by_example_seed=form["paintByExampleSeed"], paint_by_example_seed=form["paintByExampleSeed"],
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"], paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
paint_by_example_example_image=paint_by_example_example_image, paint_by_example_example_image=paint_by_example_example_image,
p2p_steps=form["p2pSteps"],
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
p2p_guidance_scale=form["p2pGuidanceScale"],
) )
if config.sd_seed == -1: if config.sd_seed == -1:
@ -235,6 +274,7 @@ def process():
logger.info(f"process time: {(time.time() - start) * 1000}ms") logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache() torch.cuda.empty_cache()
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
if alpha_channel is not None: if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]: if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize( alpha_channel = cv2.resize(
@ -246,9 +286,15 @@ def process():
ext = get_image_ext(origin_image_bytes) ext = get_image_ext(origin_image_bytes)
if exif is not None:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
else:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
response = make_response( response = make_response(
send_file( send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)), # io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
mimetype=f"image/{ext}", mimetype=f"image/{ext}",
) )
) )
@ -261,7 +307,7 @@ def interactive_seg():
input = request.files input = request.files
origin_image_bytes = input["image"].read() # RGB origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes) image, _ = load_img(origin_image_bytes)
if 'mask' in input: if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
else: else:
mask = None mask = None
@ -269,14 +315,16 @@ def interactive_seg():
_clicks = json.loads(request.form["clicks"]) _clicks = json.loads(request.form["clicks"])
clicks = [] clicks = []
for i, click in enumerate(_clicks): for i, click in enumerate(_clicks):
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)) clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
start = time.time() start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask) new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms") logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
response = make_response( response = make_response(
send_file( send_file(
io.BytesIO(numpy_to_bytes(new_mask, 'png')), io.BytesIO(numpy_to_bytes(new_mask, "png")),
mimetype=f"image/png", mimetype=f"image/png",
) )
) )
@ -290,13 +338,13 @@ def current_model():
@app.route("/is_disable_model_switch") @app.route("/is_disable_model_switch")
def get_is_disable_model_switch(): def get_is_disable_model_switch():
res = 'true' if is_disable_model_switch else 'false' res = "true" if is_disable_model_switch else "false"
return res, 200 return res, 200
@app.route("/is_enable_file_manager") @app.route("/is_enable_file_manager")
def get_is_enable_file_manager(): def get_is_enable_file_manager():
res = 'true' if is_enable_file_manager else 'false' res = "true" if is_enable_file_manager else "false"
return res, 200 return res, 200
@ -365,14 +413,18 @@ def main(args):
is_disable_model_switch = args.disable_model_switch is_disable_model_switch = args.disable_model_switch
is_desktop = args.gui is_desktop = args.gui
if is_disable_model_switch: if is_disable_model_switch:
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable") logger.info(
f"Start with --disable-model-switch, model switch on frontend is disable"
)
if args.input and os.path.isdir(args.input): if args.input and os.path.isdir(args.input):
logger.info(f"Initialize file manager") logger.info(f"Initialize file manager")
thumb = FileManager(app) thumb = FileManager(app)
is_enable_file_manager = True is_enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails') app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
args.output_dir, "lama_cleaner_thumbnails"
)
thumb.output_dir = Path(args.output_dir) thumb.output_dir = Path(args.output_dir)
# thumb.start() # thumb.start()
# try: # try:
@ -408,8 +460,12 @@ def main(args):
from flaskwebgui import FlaskUI from flaskwebgui import FlaskUI
ui = FlaskUI( ui = FlaskUI(
app, width=app_width, height=app_height, host=args.host, port=args.port, app,
close_server_on_exit=not args.no_gui_auto_close width=app_width,
height=app_height,
host=args.host,
port=args.port,
close_server_on_exit=not args.no_gui_auto_close,
) )
ui.run() ui.run()
else: else:

View File

@ -0,0 +1,62 @@
from pathlib import Path
import pytest
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.tests.test_model import get_config, assert_equal
from lama_cleaner.schema import HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@pytest.mark.parametrize("disable_nsfw", [True, False])
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_instruct_pix2pix(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload)
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1)
name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
assert_equal(
model,
cfg,
f"instruct_pix2pix_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3
)
@pytest.mark.parametrize("disable_nsfw", [False])
@pytest.mark.parametrize("cpu_offload", [False])
def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload)
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)
name = f"snow"
assert_equal(
model,
cfg,
f"instruct_pix2pix_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -0,0 +1,42 @@
import io
from pathlib import Path
from PIL import Image
from lama_cleaner.helper import pil_to_bytes
current_dir = Path(__file__).parent.absolute().resolve()
png_img_p = current_dir / "image.png"
jpg_img_p = current_dir / "bunny.jpeg"
def print_exif(exif):
for k, v in exif.items():
print(f"{k}: {v}")
def test_png():
img = Image.open(png_img_p)
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="png", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)
def test_jpeg():
img = Image.open(jpg_img_p)
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)

View File

@ -8,33 +8,37 @@ from lama_cleaner.schema import HDStrategy, SDSampler
from lama_cleaner.tests.test_model import get_config, assert_equal from lama_cleaner.tests.test_model import get_config, assert_equal
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result' save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True) save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device) device = torch.device(device)
@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [True, False]) @pytest.mark.parametrize("cpu_textencoder", [True, False])
@pytest.mark.parametrize("disable_nsfw", [True, False]) @pytest.mark.parametrize("disable_nsfw", [True, False])
def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): def test_runway_sd_1_5_ddim(
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
):
def callback(i, t, latents): def callback(i, t, latents):
print(f"sd_step_{i}") pass
if sd_device == 'cuda' and not torch.cuda.is_available(): if sd_device == "cuda" and not torch.cuda.is_available():
return return
sd_steps = 50 if sd_device == 'cuda' else 1 sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(name="sd1.5", model = ModelManager(
device=torch.device(sd_device), name="sd1.5",
hf_access_token="", device=torch.device(sd_device),
sd_run_local=True, hf_access_token="",
disable_nsfw=disable_nsfw, sd_run_local=True,
sd_cpu_textencoder=cpu_textencoder, disable_nsfw=disable_nsfw,
callback=callback) sd_cpu_textencoder=cpu_textencoder,
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) callback=callback,
)
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -45,31 +49,35 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
f"runway_sd_{strategy.capitalize()}_{name}.png", f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3 fx=1.3,
) )
@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]) @pytest.mark.parametrize(
"sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]
)
@pytest.mark.parametrize("cpu_textencoder", [False]) @pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [True]) @pytest.mark.parametrize("disable_nsfw", [True])
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
def callback(i, t, latents): def callback(i, t, latents):
print(f"sd_step_{i}") print(f"sd_step_{i}")
if sd_device == 'cuda' and not torch.cuda.is_available(): if sd_device == "cuda" and not torch.cuda.is_available():
return return
sd_steps = 50 if sd_device == 'cuda' else 1 sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(name="sd1.5", model = ModelManager(
device=torch.device(sd_device), name="sd1.5",
hf_access_token="", device=torch.device(sd_device),
sd_run_local=True, hf_access_token="",
disable_nsfw=disable_nsfw, sd_run_local=True,
sd_cpu_textencoder=cpu_textencoder, disable_nsfw=disable_nsfw,
callback=callback) sd_cpu_textencoder=cpu_textencoder,
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) callback=callback,
)
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -80,35 +88,37 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
f"runway_sd_{strategy.capitalize()}_{name}.png", f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3 fx=1.3,
) )
@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
def callback(i, t, latents): def callback(i, t, latents):
pass pass
if sd_device == 'cuda' and not torch.cuda.is_available(): if sd_device == "cuda" and not torch.cuda.is_available():
return return
sd_steps = 50 if sd_device == 'cuda' else 1 sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(name="sd1.5", model = ModelManager(
device=torch.device(sd_device), name="sd1.5",
hf_access_token="", device=torch.device(sd_device),
sd_run_local=True, hf_access_token="",
disable_nsfw=False, sd_run_local=True,
sd_cpu_textencoder=False, disable_nsfw=False,
callback=callback) sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config( cfg = get_config(
strategy, strategy,
sd_steps=sd_steps, sd_steps=sd_steps,
prompt='Face of a fox, high resolution, sitting on a park bench', prompt="Face of a fox, high resolution, sitting on a park bench",
negative_prompt='orange, yellow, small', negative_prompt="orange, yellow, small",
sd_sampler=sampler, sd_sampler=sampler,
sd_match_histograms=True sd_match_histograms=True,
) )
name = f"{sampler}_negative_prompt" name = f"{sampler}_negative_prompt"
@ -119,27 +129,33 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
f"runway_sd_{strategy.capitalize()}_{name}.png", f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1 fx=1,
) )
@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) @pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
@pytest.mark.parametrize("cpu_textencoder", [False]) @pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [False]) @pytest.mark.parametrize("disable_nsfw", [False])
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw): def test_runway_sd_1_5_sd_scale(
if sd_device == 'cuda' and not torch.cuda.is_available(): sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
):
if sd_device == "cuda" and not torch.cuda.is_available():
return return
sd_steps = 50 if sd_device == 'cuda' else 1 sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(name="sd1.5", model = ModelManager(
device=torch.device(sd_device), name="sd1.5",
hf_access_token="", device=torch.device(sd_device),
sd_run_local=True, hf_access_token="",
disable_nsfw=disable_nsfw, sd_run_local=True,
sd_cpu_textencoder=cpu_textencoder) disable_nsfw=disable_nsfw,
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) sd_cpu_textencoder=cpu_textencoder,
)
cfg = get_config(
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}" name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -150,26 +166,30 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png", f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3 fx=1.3,
) )
@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) @pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
if sd_device == 'cuda' and not torch.cuda.is_available(): if sd_device == "cuda" and not torch.cuda.is_available():
return return
sd_steps = 50 if sd_device == 'cuda' else 1 sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(name="sd1.5", model = ModelManager(
device=torch.device(sd_device), name="sd1.5",
hf_access_token="", device=torch.device(sd_device),
sd_run_local=True, hf_access_token="",
disable_nsfw=True, sd_run_local=True,
sd_cpu_textencoder=False, disable_nsfw=True,
cpu_offload=True) sd_cpu_textencoder=False,
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) cpu_offload=True,
)
cfg = get_config(
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}" name = f"device_{sd_device}_{sampler}"
@ -182,27 +202,3 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
) )
@pytest.mark.parametrize("sd_device", ['cpu'])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=False,
sd_cpu_textencoder=False,
cpu_offload=True)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload_cpu_device.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -6,9 +6,24 @@ import gradio as gr
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \ from lama_cleaner.const import (
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \ AVAILABLE_MODELS,
GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR AVAILABLE_DEVICES,
CPU_OFFLOAD_HELP,
NO_HALF_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_MODEL,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS,
)
_config_file = None _config_file = None
@ -33,16 +48,28 @@ class Config(BaseModel):
def load_config(installer_config: str): def load_config(installer_config: str):
if os.path.exists(installer_config): if os.path.exists(installer_config):
with open(installer_config, "r", encoding='utf-8') as f: with open(installer_config, "r", encoding="utf-8") as f:
return Config(**json.load(f)) return Config(**json.load(f))
else: else:
return Config() return Config()
def save_config( def save_config(
host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload, host,
disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only, port,
model_dir, input, output_dir model,
device,
gui,
no_gui_auto_close,
no_half,
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
): ):
config = Config(**locals()) config = Config(**locals())
print(config) print(config)
@ -63,6 +90,7 @@ def save_config(
def close_server(*args): def close_server(*args):
# TODO: make close both browser and server works # TODO: make close both browser and server works
import os, signal import os, signal
pid = os.getpid() pid = os.getpid()
os.kill(pid, signal.SIGUSR1) os.kill(pid, signal.SIGUSR1)
@ -86,33 +114,53 @@ def main(config_file: str):
port = gr.Number(init_config.port, label="Port", precision=0) port = gr.Number(init_config.port, label="Port", precision=0)
with gr.Row(): with gr.Row():
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model) model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
device = gr.Radio(AVAILABLE_DEVICES, label="Device", value=init_config.device) device = gr.Radio(
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
)
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}") gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}") no_gui_auto_close = gr.Checkbox(
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}") cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}") disable_nsfw = gr.Checkbox(
sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}") init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}") )
local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}") sd_cpu_textencoder = gr.Checkbox(
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
)
enable_xformers = gr.Checkbox(
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
)
local_files_only = gr.Checkbox(
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
)
model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}") model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}") input = gr.Textbox(
output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}") init_config.input, label=f"Input file or directory. {INPUT_HELP}"
save_btn.click(save_config, [ )
host, output_dir = gr.Textbox(
port, init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}"
model, )
device, save_btn.click(
gui, save_config,
no_gui_auto_close, [
no_half, host,
cpu_offload, port,
disable_nsfw, model,
sd_cpu_textencoder, device,
enable_xformers, gui,
local_files_only, no_gui_auto_close,
model_dir, no_half,
input, cpu_offload,
output_dir, disable_nsfw,
], message) sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
],
message,
)
demo.launch(inbrowser=True, show_api=False) demo.launch(inbrowser=True, show_api=False)

View File

@ -1,6 +1,7 @@
torch>=1.9.0 torch>=1.9.0
opencv-python opencv-python
flask_cors flask_cors
Jinja2==2.11.3
flask==1.1.4 flask==1.1.4
flaskwebgui==0.3.5 flaskwebgui==0.3.5
tqdm tqdm
@ -11,7 +12,8 @@ pytest
yacs yacs
markupsafe==2.0.1 markupsafe==2.0.1
scikit-image==0.19.3 scikit-image==0.19.3
diffusers[torch]==0.10.2 diffusers[torch]==0.12.1
transformers>=4.25.1 transformers>=4.25.1
watchdog==2.2.1 watchdog==2.2.1
gradio gradio
piexif==1.1.3

View File

@ -21,7 +21,7 @@ def load_requirements():
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files # https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup( setuptools.setup(
name="lama-cleaner", name="lama-cleaner",
version="0.34.0", version="0.35.0",
author="PanicByte", author="PanicByte",
author_email="cwq1913@gmail.com", author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model", description="Image inpainting tool powered by SOTA AI Model",