Merge branch 'make_gif_share'

This commit is contained in:
Qing 2023-02-07 21:47:11 +08:00
commit f837e4be8a
49 changed files with 15954 additions and 12250 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ build
!lama_cleaner/app/build
dist/
lama_cleaner.egg-info/
venv/

View File

@ -59,6 +59,6 @@ Only needed if you plan to modify the frontend and recompile yourself.
Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures), You can experience their
great online services [here](https://cleanup.pictures/).
- Install dependencies:`cd lama_cleaner/app/ && yarn`
- Start development server: `yarn start`
- Build: `yarn build`
- Install dependencies:`cd lama_cleaner/app/ && pnpm install`
- Start development server: `pnpm start`
- Build: `pnpm build`

View File

@ -1,7 +1,8 @@
{
"files": {
"main.css": "/static/css/main.c28d98ca.css",
"main.js": "/static/js/main.1bd455bc.js",
"main.css": "/static/css/main.e24c9a9b.css",
"main.js": "/static/js/main.23732b19.js",
"static/media/coffee-machine-lineal.gif": "/static/media/coffee-machine-lineal.ee32631219cc3986f861.gif",
"static/media/WorkSans-SemiBold.ttf": "/static/media/WorkSans-SemiBold.1e98db4eb705b586728e.ttf",
"static/media/WorkSans-Bold.ttf": "/static/media/WorkSans-Bold.2bea7a7f7d052c74da25.ttf",
"static/media/WorkSans-Regular.ttf": "/static/media/WorkSans-Regular.bb287b894b27372d8ea7.ttf",
@ -9,7 +10,7 @@
"index.html": "/index.html"
},
"entrypoints": [
"static/css/main.c28d98ca.css",
"static/js/main.1bd455bc.js"
"static/css/main.e24c9a9b.css",
"static/js/main.23732b19.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.1bd455bc.js"></script><link href="/static/css/main.c28d98ca.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.23732b19.js"></script><link href="/static/css/main.e24c9a9b.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 KiB

View File

@ -2,7 +2,7 @@
"name": "lama-cleaner",
"version": "0.1.0",
"private": true,
"proxy": "http://localhost:8080",
"proxy": "http://127.0.0.1:8080",
"dependencies": {
"@babel/core": "^7.16.0",
"@heroicons/react": "^2.0.0",

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
import { Rect, Settings } from '../store/Atoms'
import { dataURItoBlob, srcToFile } from '../utils'
import { dataURItoBlob, loadImage, srcToFile } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -82,6 +82,11 @@ export default async function inpaint(
fd.append('paintByExampleImage', paintByExampleImage)
}
// InstructPix2Pix
fd.append('p2pSteps', settings.p2pSteps.toString())
fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString())
fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString())
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
@ -223,3 +228,34 @@ export async function downloadToOutput(
throw new Error(`Something went wrong: ${error}`)
}
}
export async function makeGif(
originFile: File,
cleanImage: HTMLImageElement,
filename: string,
mimeType: string
) {
const cleanFile = await srcToFile(cleanImage.src, filename, mimeType)
const fd = new FormData()
fd.append('origin_img', originFile)
fd.append('clean_img', cleanFile)
fd.append('filename', filename)
try {
const res = await fetch(`${API_ENDPOINT}/make_gif`, {
method: 'POST',
body: fd,
})
if (!res.ok) {
const errMsg = await res.text()
throw new Error(errMsg)
}
const blob = await res.blob()
const newImage = new Image()
await loadImage(newImage, URL.createObjectURL(blob))
return newImage
} catch (error) {
throw new Error(`Something went wrong: ${error}`)
}
}

View File

@ -2,6 +2,7 @@ import React, { useState } from 'react'
import { Coffee } from 'react-feather'
import Button from '../shared/Button'
import Modal from '../shared/Modal'
import CoffeeMachineGif from '../../media/coffee-machine-lineal.gif'
const CoffeeIcon = () => {
const [show, setShow] = useState(false)
@ -24,10 +25,26 @@ const CoffeeIcon = () => {
show={show}
showCloseIcon={false}
>
<h4 style={{ lineHeight: '24px' }}>
Hi there, If you found my project is useful, and want to help keep it
alive please consider donating! Thank you for your support!
</h4>
<div
style={{
display: 'flex',
flexDirection: 'column',
}}
>
<h4 style={{ lineHeight: '24px' }}>
Hi, if you found my project is useful, please conside buy me a
coffee to support my work. Thanks!
</h4>
<img
src={CoffeeMachineGif}
alt="coffee machine"
style={{
height: 150,
objectFit: 'contain',
}}
/>
</div>
<div
style={{
display: 'flex',
@ -53,7 +70,6 @@ const CoffeeIcon = () => {
}}
>
Sure
<Coffee />
</div>
</Button>
</a>

View File

@ -55,10 +55,9 @@
position: fixed;
bottom: 0.5rem;
border-radius: 3rem;
padding: 0.6rem 3rem;
display: grid;
grid-template-areas: 'toolkit-size-selector toolkit-brush-slider toolkit-btns';
column-gap: 2rem;
padding: 0.6rem 32px;
display: flex;
gap: 16px;
align-items: center;
justify-content: center;
backdrop-filter: blur(12px);
@ -96,10 +95,8 @@
}
.editor-toolkit-btns {
grid-area: toolkit-btns;
display: grid;
grid-auto-flow: column;
column-gap: 1rem;
display: flex;
gap: 12px;
}
.brush-shape {

View File

@ -16,10 +16,11 @@ import {
TransformComponent,
TransformWrapper,
} from 'react-zoom-pan-pinch'
import { useRecoilState, useRecoilValue } from 'recoil'
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint, {
downloadToOutput,
makeGif,
postInteractiveSeg,
} from '../../adapters/inpainting'
import Button from '../shared/Button'
@ -43,10 +44,12 @@ import {
imageHeightState,
imageWidthState,
interactiveSegClicksState,
isDiffusionModelsState,
isInpaintingState,
isInteractiveSegRunningState,
isInteractiveSegState,
isPaintByExampleState,
isPix2PixState,
isSDState,
negativePropmtState,
propmtState,
@ -66,6 +69,7 @@ import FileSelect from '../FileSelect/FileSelect'
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal'
import MakeGIF from './MakeGIF'
const TOOLBAR_SIZE = 200
const MIN_BRUSH_SIZE = 10
@ -112,11 +116,11 @@ export default function Editor() {
const settings = useRecoilValue(settingState)
const [seedVal, setSeed] = useRecoilState(seedState)
const croperRect = useRecoilValue(croperState)
const [toastVal, setToastState] = useRecoilState(toastState)
const setToastState = useSetRecoilState(toastState)
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const runMannually = useRecoilValue(runManuallyState)
const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState)
const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
isInteractiveSegState
)
@ -181,8 +185,8 @@ export default function Editor() {
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
const enableFileManager = useRecoilValue(enableFileManagerState)
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState)
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState)
const setImageWidth = useSetRecoilState(imageWidthState)
const setImageHeight = useSetRecoilState(imageHeightState)
const app = useRecoilValue(appState)
const draw = useCallback(
@ -253,13 +257,40 @@ export default function Editor() {
_lineGroups.forEach(lineGroup => {
drawLines(ctx, lineGroup, 'white')
})
if (
(maskImage === undefined || maskImage === null) &&
_lineGroups.length === 1 &&
_lineGroups[0].length === 0 &&
isPix2Pix
) {
// For InstructPix2Pix without mask
drawLines(
ctx,
[
{
size: 9999999999,
pts: [
{ x: 0, y: 0 },
{ x: original.naturalWidth, y: 0 },
{ x: original.naturalWidth, y: original.naturalHeight },
{ x: 0, y: original.naturalHeight },
],
},
],
'white'
)
}
},
[context, maskCanvas]
[context, maskCanvas, isPix2Pix]
)
const hadDrawSomething = useCallback(() => {
if (isPix2Pix) {
return true
}
return curLineGroup.length !== 0
}, [curLineGroup])
}, [curLineGroup, isPix2Pix])
const drawOnCurrentRender = useCallback(
(lineGroup: LineGroup) => {
@ -424,6 +455,8 @@ export default function Editor() {
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask)
} else if (isPix2Pix) {
runInpainting(false, undefined, null)
} else {
setToastState({
open: true,
@ -839,7 +872,7 @@ export default function Editor() {
}
if (
(isSD || isPaintByExample) &&
isDiffusionModels &&
settings.showCroper &&
isOutsideCroper(mouseXY(ev))
) {
@ -1385,7 +1418,7 @@ export default function Editor() {
minHeight={Math.min(256, original.naturalHeight)}
minWidth={Math.min(256, original.naturalWidth)}
scale={scale}
show={(isSD || isPaintByExample) && settings.showCroper}
show={isDiffusionModels && settings.showCroper}
/>
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
@ -1439,7 +1472,7 @@ export default function Editor() {
)}
<div className="editor-toolkit-panel">
{isSD || isPaintByExample || file === undefined ? (
{isDiffusionModels || file === undefined ? (
<></>
) : (
<SizeSelector
@ -1534,6 +1567,7 @@ export default function Editor() {
}}
disabled={renders.length === 0}
/>
<MakeGIF renders={renders} />
<Button
toolTip="Save Image"
icon={<ArrowDownTrayIcon />}
@ -1541,7 +1575,7 @@ export default function Editor() {
onClick={download}
/>
{settings.runInpaintingManually && !isSD && !isPaintByExample && (
{settings.runInpaintingManually && !isDiffusionModels && (
<Button
toolTip="Run Inpainting"
icon={

View File

@ -0,0 +1,114 @@
import React, { useState } from 'react'
import { GifIcon } from '@heroicons/react/24/outline'
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
import Button from '../shared/Button'
import { fileState, gifImageState, toastState } from '../../store/Atoms'
import { makeGif } from '../../adapters/inpainting'
import Modal from '../shared/Modal'
import { LoadingIcon } from '../shared/Toast'
import { downloadImage } from '../../utils'
interface Props {
renders: HTMLImageElement[]
}
const MakeGIF = (props: Props) => {
const { renders } = props
const [gifImg, setGifImg] = useRecoilState(gifImageState)
const file = useRecoilValue(fileState)
const setToastState = useSetRecoilState(toastState)
const [show, setShow] = useState(false)
const handleOnClose = () => {
setShow(false)
}
const handleDownload = () => {
if (gifImg) {
const name = file.name.replace(/\.[^/.]+$/, '.gif')
downloadImage(gifImg.src, name)
}
}
return (
<div>
<Button
toolTip="Make Gif"
icon={<GifIcon />}
disabled={!renders.length}
onClick={async () => {
setShow(true)
setGifImg(null)
try {
const gif = await makeGif(
file,
renders[renders.length - 1],
file.name,
file.type
)
if (gif) {
setGifImg(gif)
}
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 2000,
})
}
}}
/>
<Modal
onClose={handleOnClose}
title="GIF"
className="modal-setting"
show={show}
>
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
flexDirection: 'column',
gap: 16,
}}
>
{gifImg ? (
<img src={gifImg.src} style={{ borderRadius: 8 }} alt="gif" />
) : (
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
gap: 8,
}}
>
<LoadingIcon />
Generating GIF...
</div>
)}
{gifImg && (
<div
style={{
display: 'flex',
width: '100%',
justifyContent: 'flex-end',
alignItems: 'center',
gap: '12px',
}}
>
<Button onClick={handleDownload} border>
Download
</Button>
</div>
)}
</div>
</Modal>
</div>
)
}
export default MakeGIF

View File

@ -7,6 +7,7 @@ import {
enableFileManagerState,
fileState,
isInpaintingState,
isPix2PixState,
isSDState,
maskState,
runManuallyState,
@ -30,6 +31,7 @@ const Header = () => {
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
const isSD = useRecoilValue(isSDState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const runManually = useRecoilValue(runManuallyState)
const [openMaskPopover, setOpenMaskPopover] = useState(false)
const [showFileManager, setShowFileManager] =
@ -172,7 +174,7 @@ const Header = () => {
</div>
</div>
{isSD && file ? <PromptInput /> : <></>}
{(isSD || isPix2Pix) && file ? <PromptInput /> : <></>}
<div className="header-icons-wrapper">
<CoffeeIcon />

View File

@ -179,28 +179,16 @@ function ModelSettingBlock() {
const renderOptionDesc = (): ReactNode => {
switch (setting.model) {
case AIModel.LAMA:
return undefined
case AIModel.LDM:
return renderLDMModelDesc()
case AIModel.ZITS:
return renderZITSModelDesc()
case AIModel.MAT:
return undefined
case AIModel.FCF:
return renderFCFModelDesc()
case AIModel.SD15:
return undefined
case AIModel.SD2:
return undefined
case AIModel.PAINT_BY_EXAMPLE:
return undefined
case AIModel.Mange:
return undefined
case AIModel.CV2:
return renderOpenCV2Desc()
default:
return <></>
return undefined
}
}
@ -266,6 +254,12 @@ function ModelSettingBlock() {
'https://arxiv.org/abs/2211.13227',
'https://github.com/Fantasy-Studio/Paint-by-Example'
)
case AIModel.PIX2PIX:
return renderModelDesc(
'InstructPix2Pix',
'https://arxiv.org/abs/2211.09800',
'https://github.com/timothybrooks/instruct-pix2pix'
)
default:
return <></>
}

View File

@ -1,5 +1,6 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { Cog6ToothIcon } from '@heroicons/react/24/outline'
import { settingState } from '../../store/Atoms'
import Button from '../shared/Button'
@ -16,29 +17,7 @@ const SettingIcon = () => {
onClick={onClick}
toolTip="Settings"
style={{ border: 0 }}
icon={
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
role="img"
width="28"
height="28"
viewBox="0 0 24 24"
stroke="currentColor"
strokeWidth="2"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
/>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
/>
</svg>
}
icon={<Cog6ToothIcon />}
/>
</div>
)

View File

@ -8,7 +8,6 @@
color: var(--modal-text-color);
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
width: 680px;
min-height: 420px;
@include mobile {
display: grid;

View File

@ -1,6 +1,7 @@
import React from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import {
isDiffusionModelsState,
isPaintByExampleState,
isSDState,
settingState,
@ -28,7 +29,7 @@ export default function SettingModal(props: SettingModalProps) {
const { onClose } = props
const [setting, setSettingState] = useRecoilState(settingState)
const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState)
const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
const handleOnClose = () => {
setSettingState(old => {
@ -56,9 +57,9 @@ export default function SettingModal(props: SettingModalProps) {
show={setting.show}
>
<DownloadMaskSettingBlock />
{isSD || isPaintByExample ? <></> : <ManualRunInpaintingSettingBlock />}
{isDiffusionModels ? <></> : <ManualRunInpaintingSettingBlock />}
<ModelSettingBlock />
{isSD ? <></> : <HDSettingBlock />}
{isDiffusionModels ? <></> : <HDSettingBlock />}
</Modal>
)
}

View File

@ -15,13 +15,14 @@
}
.shortcut-options {
display: grid;
row-gap: 1rem;
display: flex;
gap: 48px;
flex-direction: row;
.shortcut-option {
display: grid;
grid-template-columns: repeat(2, auto);
column-gap: 6rem;
column-gap: 2rem;
align-items: center;
@include mobile {
@ -67,3 +68,10 @@
}
}
}
.shortcut-options-column {
display: flex;
flex-direction: column;
gap: 12px;
width: 320px;
}

View File

@ -50,24 +50,37 @@ export default function ShortcutsModal() {
show={shortcutsShow}
>
<div className="shortcut-options">
<ShortCut
content="Multi-Stroke Mask Drawing"
keys={[`Hold ${CmdOrCtrl}`]}
/>
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
<ShortCut content="Interactive Segmentation" keys={['I']} />
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
<ShortCut content="Redo Inpainting" keys={[CmdOrCtrl, 'Shift', 'Z']} />
<ShortCut content="View Original Image" keys={['Hold Tab']} />
<ShortCut content="Pan" keys={['Space + Drag']} />
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
<ShortCut content="Decrease Brush Size" keys={['[']} />
<ShortCut content="Increase Brush Size" keys={[']']} />
<ShortCut content="Toggle Dark Mode" keys={['Shift', 'D']} />
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
<ShortCut content="Toggle File Manager" keys={['F']} />
<div className="shortcut-options-column">
<ShortCut content="Pan" keys={['Space + Drag']} />
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
<ShortCut content="Decrease Brush Size" keys={['[']} />
<ShortCut content="Increase Brush Size" keys={[']']} />
<ShortCut content="View Original Image" keys={['Hold Tab']} />
<ShortCut
content="Multi-Stroke Drawing"
keys={[`Hold ${CmdOrCtrl}`]}
/>
<ShortCut content="Cancel Drawing" keys={['Esc']} />
</div>
<div className="shortcut-options-column">
<ShortCut content="Undo" keys={[CmdOrCtrl, 'Z']} />
<ShortCut content="Redo" keys={[CmdOrCtrl, 'Shift', 'Z']} />
<ShortCut content="Copy Result" keys={[CmdOrCtrl, 'C']} />
<ShortCut content="Paste Image" keys={[CmdOrCtrl, 'V']} />
<ShortCut
content="Trigger Manually Inpainting"
keys={['Shift', 'R']}
/>
<ShortCut content="Trigger Interactive Segmentation" keys={['I']} />
</div>
<div className="shortcut-options-column">
<ShortCut content="Switch Theme" keys={['Shift', 'D']} />
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
<ShortCut content="Toggle File Manager" keys={['F']} />
</div>
</div>
</Modal>
)

View File

@ -0,0 +1,224 @@
import React, { FormEvent } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import {
isInpaintingState,
negativePropmtState,
propmtState,
settingState,
} from '../../store/Atoms'
import NumberInputSetting from '../Settings/NumberInputSetting'
import SettingBlock from '../Settings/SettingBlock'
import { Switch, SwitchThumb } from '../shared/Switch'
import TextAreaInput from '../shared/Textarea'
import emitter, { EVENT_PROMPT } from '../../event'
import ImageResizeScale from './ImageResizeScale'
const INPUT_WIDTH = 30
const P2PSidePanel = () => {
const [open, toggleOpen] = useToggle(true)
const [setting, setSettingState] = useRecoilState(settingState)
const [negativePrompt, setNegativePrompt] =
useRecoilState(negativePropmtState)
const isInpainting = useRecoilValue(isInpaintingState)
const prompt = useRecoilValue(propmtState)
const handleOnInput = (evt: FormEvent<HTMLTextAreaElement>) => {
evt.preventDefault()
evt.stopPropagation()
const target = evt.target as HTMLTextAreaElement
setNegativePrompt(target.value)
}
const onKeyUp = (e: React.KeyboardEvent) => {
if (
e.key === 'Enter' &&
(e.ctrlKey || e.metaKey) &&
prompt.length !== 0 &&
!isInpainting
) {
emitter.emit(EVENT_PROMPT)
}
}
return (
<div className="side-panel">
<PopoverPrimitive.Root open={open}>
<PopoverPrimitive.Trigger
className="btn-primary side-panel-trigger"
onClick={() => toggleOpen()}
>
Config
</PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content">
<SettingBlock
title="Croper"
input={
<Switch
checked={setting.showCroper}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, showCroper: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/>
<ImageResizeScale />
<NumberInputSetting
title="Steps"
width={INPUT_WIDTH}
value={`${setting.p2pSteps}`}
desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, p2pSteps: val }
})
}}
/>
<NumberInputSetting
title="Guidance Scale"
width={INPUT_WIDTH}
allowFloat
value={`${setting.p2pGuidanceScale}`}
desc="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality."
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, p2pGuidanceScale: val }
})
}}
/>
<NumberInputSetting
title="Image Guidance Scale"
width={INPUT_WIDTH}
allowFloat
value={`${setting.p2pImageGuidanceScale}`}
desc=""
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, p2pImageGuidanceScale: val }
})
}}
/>
{/* <NumberInputSetting
title="Mask Blur"
width={INPUT_WIDTH}
value={`${setting.sdMaskBlur}`}
desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, sdMaskBlur: val }
})
}}
/> */}
{/* <SettingBlock
title="Match Histograms"
desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
input={
<Switch
checked={setting.sdMatchHistograms}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, sdMatchHistograms: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/> */}
{/* <SettingBlock
className="sub-setting-block"
title="Sampler"
input={
<Selector
width={80}
value={setting.sdSampler as string}
options={Object.values(SDSampler)}
onChange={val => {
const sampler = val as SDSampler
setSettingState(old => {
return { ...old, sdSampler: sampler }
})
}}
/>
}
/> */}
<SettingBlock
title="Seed"
input={
<div
style={{
display: 'flex',
gap: 0,
justifyContent: 'center',
alignItems: 'center',
}}
>
<NumberInputSetting
title=""
width={80}
value={`${setting.sdSeed}`}
desc=""
disable={!setting.sdSeedFixed}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, sdSeed: val }
})
}}
/>
<Switch
checked={setting.sdSeedFixed}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, sdSeedFixed: value }
})
}}
style={{ marginLeft: '8px' }}
>
<SwitchThumb />
</Switch>
</div>
}
/>
<SettingBlock
className="sub-setting-block"
title="Negative prompt"
layout="v"
input={
<TextAreaInput
className="negative-prompt"
value={negativePrompt}
onInput={handleOnInput}
onKeyUp={onKeyUp}
placeholder=""
/>
}
/>
</PopoverPrimitive.Content>
</PopoverPrimitive.Portal>
</PopoverPrimitive.Root>
</div>
)
}
export default P2PSidePanel

View File

@ -8,6 +8,7 @@ import {
AIModel,
fileState,
isPaintByExampleState,
isPix2PixState,
isSDState,
settingState,
showFileManagerState,
@ -22,6 +23,7 @@ import {
import SidePanel from './SidePanel/SidePanel'
import PESidePanel from './SidePanel/PESidePanel'
import FileManager from './FileManager/FileManager'
import P2PSidePanel from './SidePanel/P2PSidePanel'
const Workspace = () => {
const setFile = useSetRecoilState(fileState)
@ -29,6 +31,7 @@ const Workspace = () => {
const [toastVal, setToastState] = useRecoilState(toastState)
const isSD = useRecoilValue(isSDState)
const isPaintByExample = useRecoilValue(isPaintByExampleState)
const isPix2Pix = useRecoilValue(isPix2PixState)
const [showFileManager, setShowFileManager] =
useRecoilState(showFileManagerState)
@ -98,6 +101,7 @@ const Workspace = () => {
<>
{isSD ? <SidePanel /> : <></>}
{isPaintByExample ? <PESidePanel /> : <></>}
{isPix2Pix ? <P2PSidePanel /> : <></>}
<FileManager
photoWidth={256}
show={showFileManager}

View File

@ -3,7 +3,7 @@ import * as ToastPrimitive from '@radix-ui/react-toast'
import { ToastProps } from '@radix-ui/react-toast'
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline'
const LoadingIcon = () => {
export const LoadingIcon = () => {
return (
<span className="loading-icon">
<svg

Binary file not shown.

After

Width:  |  Height:  |  Size: 422 KiB

View File

@ -14,6 +14,7 @@ export enum AIModel {
CV2 = 'cv2',
Mange = 'manga',
PAINT_BY_EXAMPLE = 'paint_by_example',
PIX2PIX = 'instruct_pix2pix',
}
export const maskState = atom<File | undefined>({
@ -45,6 +46,7 @@ interface AppState {
interactiveSegClicks: number[][]
showFileManager: boolean
enableFileManager: boolean
gifImage: HTMLImageElement | undefined
}
export const appState = atom<AppState>({
@ -61,6 +63,7 @@ export const appState = atom<AppState>({
interactiveSegClicks: [],
showFileManager: false,
enableFileManager: false,
gifImage: undefined,
},
})
@ -134,6 +137,18 @@ export const enableFileManagerState = selector({
},
})
export const gifImageState = selector({
key: 'gifImageState',
get: ({ get }) => {
const app = get(appState)
return app.gifImage
},
set: ({ get, set }, newValue: any) => {
const app = get(appState)
set(appState, { ...app, gifImage: newValue })
},
})
export const fileState = selector({
key: 'fileState',
get: ({ get }) => {
@ -329,6 +344,11 @@ export interface Settings {
paintByExampleSeedFixed: boolean
paintByExampleMaskBlur: number
paintByExampleMatchHistograms: boolean
// InstructPix2Pix
p2pSteps: number
p2pImageGuidanceScale: number
p2pGuidanceScale: number
}
const defaultHDSettings: ModelsHDSettings = {
@ -388,6 +408,13 @@ const defaultHDSettings: ModelsHDSettings = {
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.PIX2PIX]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.Mange]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1280,
@ -457,6 +484,11 @@ export const settingStateDefault: Settings = {
paintByExampleMaskBlur: 5,
paintByExampleSeedFixed: false,
paintByExampleMatchHistograms: false,
// InstructPix2Pix
p2pSteps: 50,
p2pImageGuidanceScale: 1.5,
p2pGuidanceScale: 7.5,
}
const localStorageEffect =
@ -553,12 +585,33 @@ export const isPaintByExampleState = selector({
},
})
export const isPix2PixState = selector({
key: 'isPix2PixState',
get: ({ get }) => {
const settings = get(settingState)
return settings.model === AIModel.PIX2PIX
},
})
export const runManuallyState = selector({
key: 'runManuallyState',
get: ({ get }) => {
const settings = get(settingState)
const isSD = get(isSDState)
const isPaintByExample = get(isPaintByExampleState)
return settings.runInpaintingManually || isSD || isPaintByExample
const isPix2Pix = get(isPix2PixState)
return (
settings.runInpaintingManually || isSD || isPaintByExample || isPix2Pix
)
},
})
export const isDiffusionModelsState = selector({
key: 'isDiffusionModelsState',
get: ({ get }) => {
const isSD = get(isSDState)
const isPaintByExample = get(isPaintByExampleState)
const isPix2Pix = get(isPix2PixState)
return isSD || isPaintByExample || isPix2Pix
},
})

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,12 @@
import os
MPS_SUPPORT_MODELS = [
"instruct_pix2pix",
"sd1.5",
"sd2",
"paint_by_example"
]
DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = [
"lama",
@ -11,7 +18,8 @@ AVAILABLE_MODELS = [
"cv2",
"manga",
"sd2",
"paint_by_example"
"paint_by_example",
"instruct_pix2pix",
]
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]

View File

@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
model_path = download_model(url_or_path)
try:
state_dict = torch.load(model_path, map_location='cpu')
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
logger.info(f"Load model from: {model_path}")
@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
return image_bytes
def load_img(img_bytes, gray: bool = False):
def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
with io.BytesIO() as output:
pil_img.save(output, format=ext, exif=exif, quality=95)
image_bytes = output.getvalue()
return image_bytes
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
alpha_channel = None
image = Image.open(io.BytesIO(img_bytes))
try:
if return_exif:
exif = image.getexif()
except:
exif = None
logger.error("Failed to extract exif from image")
try:
image = ImageOps.exif_transpose(image)
except:
pass
if gray:
image = image.convert('L')
image = image.convert("L")
np_img = np.array(image)
else:
if image.mode == 'RGBA':
if image.mode == "RGBA":
np_img = np.array(image)
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else:
image = image.convert('RGB')
image = image.convert("RGB")
np_img = np.array(image)
if return_exif:
return np_img, alpha_channel, exif
return np_img, alpha_channel

125
lama_cleaner/make_gif.py Normal file
View File

@ -0,0 +1,125 @@
import io
import math
from pathlib import Path
from PIL import Image, ImageDraw
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
if img.width > img.height:
w = size
h = int(img.height * size / img.width)
else:
h = size
w = int(img.width * size / img.height)
return img.resize((w, h), resample)
def cubic_bezier(p1, p2, duration: int, frames: int):
"""
Args:
p1:
p2:
duration: Total duration of the curve
frames:
Returns:
"""
x0, y0 = (0, 0)
x1, y1 = p1
x2, y2 = p2
x3, y3 = (1, 1)
def cal_y(t):
return math.pow(1 - t, 3) * y0 + \
3 * math.pow(1 - t, 2) * t * y1 + \
3 * (1 - t) * math.pow(t, 2) * y2 + \
math.pow(t, 3) * y3
def cal_x(t):
return math.pow(1 - t, 3) * x0 + \
3 * math.pow(1 - t, 2) * t * x1 + \
3 * (1 - t) * math.pow(t, 2) * x2 + \
math.pow(t, 3) * x3
res = []
for t in range(0, 1 * frames, duration):
t = t / frames
res.append((cal_x(t), cal_y(t)))
res.append((1, 0))
return res
def make_compare_gif(
clean_img: Image.Image,
src_img: Image.Image,
max_side_length: int = 600,
splitter_width: int = 5,
splitter_color=(255, 203, 0, int(255 * 0.73))
):
if clean_img.size != src_img.size:
clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
duration_per_frame = 20
num_frames = 50
# erase-in-out
cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames)
cubic_bezier_points.reverse()
max_side_length = min(max_side_length, max(clean_img.size))
src_img = keep_ratio_resize(src_img, max_side_length)
clean_img = keep_ratio_resize(clean_img, max_side_length)
width, height = src_img.size
# Generate images to make Gif from right to left
images = []
for i in range(num_frames):
new_frame = Image.new('RGB', (width, height))
new_frame.paste(clean_img, (0, 0))
left = int(cubic_bezier_points[i][0] * width)
cropped_src_img = src_img.crop((left, 0, width, height))
new_frame.paste(cropped_src_img, (left, 0, width, height))
if i != num_frames - 1:
# draw a yellow splitter on the edge of the cropped image
draw = ImageDraw.Draw(new_frame)
draw.line([(left, 0), (left, height)], width=splitter_width, fill=splitter_color)
images.append(new_frame)
for i in range(10):
images.append(src_img)
cubic_bezier_points.reverse()
# Generate images to make Gif from left to right
for i in range(num_frames):
new_frame = Image.new('RGB', (width, height))
new_frame.paste(src_img, (0, 0))
right = int(cubic_bezier_points[i][0] * width)
cropped_src_img = clean_img.crop((0, 0, right, height))
new_frame.paste(cropped_src_img, (0, 0, right, height))
if i != num_frames - 1:
# draw a yellow splitter on the edge of the cropped image
draw = ImageDraw.Draw(new_frame)
draw.line([(right, 0), (right, height)], width=splitter_width, fill=splitter_color)
images.append(new_frame)
images.append(clean_img)
img_byte_arr = io.BytesIO()
clean_img.save(
img_byte_arr,
format='GIF',
save_all=True,
include_color_table=True,
append_images=images,
optimize=False,
duration=duration_per_frame,
loop=0
)
return img_byte_arr.getvalue()

View File

@ -245,3 +245,42 @@ class InpaintModel:
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
class DiffusionInpaintModel(InpaintModel):
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
# boxes = boxes_from_mask(mask)
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def _scaled_pad_forward(self, image, mask, config: Config):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info(
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return inpaint_result

View File

@ -0,0 +1,83 @@
import PIL.Image
import cv2
import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
class InstructPix2Pix(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from diffusers import StableDiffusionInstructPix2PixPipeline
fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
))
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix",
revision="fp16" if use_gpu and fp16 else "main",
torch_dtype=torch_dtype,
**model_kwargs
)
self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
"""
set_seed(config.sd_seed)
output = self.model(
image=PIL.Image.fromarray(image),
prompt=config.prompt,
negative_prompt=config.negative_prompt,
num_inference_steps=config.p2p_steps,
image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.p2p_guidance_scale,
output_type="np.array",
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
#
# def forward_post_process(self, result, image, mask, config):
# if config.sd_match_histograms:
# result = self._match_histograms(result, image[:, :, ::-1], mask)
#
# if config.sd_mask_blur != 0:
# k = 2 * config.sd_mask_blur + 1
# mask = cv2.GaussianBlur(mask, (k, k), 0)
# return result, image, mask
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings
return True

View File

@ -1,19 +1,16 @@
import random
import PIL
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import DiffusionPipeline
from loguru import logger
from lama_cleaner.helper import resize_max_size
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
class PaintByExample(InpaintModel):
class PaintByExample(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
@ -53,11 +50,7 @@ class PaintByExample(InpaintModel):
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
seed = config.paint_by_example_seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(config.paint_by_example_seed)
output = self.model(
image=PIL.Image.fromarray(image),
@ -71,42 +64,6 @@ class PaintByExample(InpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def _scaled_pad_forward(self, image, mask, config: Config):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info(
f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return inpaint_result
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def forward_post_process(self, result, image, mask, config):
if config.paint_by_example_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)

View File

@ -4,20 +4,25 @@ import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers import (
PNDMScheduler,
DDIMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from loguru import logger
from lama_cleaner.helper import resize_max_size
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, set_seed
from lama_cleaner.schema import Config, SDSampler
class CPUTextEncoderWrapper:
def __init__(self, text_encoder, torch_dtype):
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
@ -25,27 +30,40 @@ class CPUTextEncoderWrapper:
def __call__(self, x, **kwargs):
input_device = x.device
return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)]
return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype
class SD(InpaintModel):
class SD(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
fp16 = not kwargs.get("no_half", False)
model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
))
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
@ -58,40 +76,23 @@ class SD(InpaintModel):
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get('enable_xformers', False):
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and use_gpu:
if kwargs.get("cpu_offload", False) and use_gpu:
# TODO: gpu_id
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']:
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def _scaled_pad_forward(self, image, mask, config: Config):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info(
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return inpaint_result
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -118,11 +119,7 @@ class SD(InpaintModel):
self.model.scheduler = scheduler
seed = config.sd_seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(config.sd_seed)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
@ -147,24 +144,6 @@ class SD(InpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
# boxes = boxes_from_mask(mask)
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)

View File

@ -1,4 +1,5 @@
import math
import random
from typing import Any
import torch
@ -713,3 +714,10 @@ def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@ -7,13 +7,14 @@ from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2
from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
"sd2": SD2, "paint_by_example": PaintByExample}
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
class ModelManager:

View File

@ -5,9 +5,25 @@ from pathlib import Path
from loguru import logger
from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, DISABLE_NSFW_HELP, \
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \
OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR
from lama_cleaner.const import (
AVAILABLE_MODELS,
NO_HALF_HELP,
CPU_OFFLOAD_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
AVAILABLE_DEVICES,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR,
DEFAULT_MODEL,
MPS_SUPPORT_MODELS,
)
from lama_cleaner.runtime import dump_environment_info
@ -16,22 +32,40 @@ def parse_args():
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--config-installer", action="store_true",
help="Open config web page, mainly for windows installer")
parser.add_argument("--load-installer-config", action="store_true",
help="Load all cmd args from installer config file")
parser.add_argument("--installer-config", default=None, help="Config file for windows installer")
parser.add_argument(
"--config-installer",
action="store_true",
help="Open config web page, mainly for windows installer",
)
parser.add_argument(
"--load-installer-config",
action="store_true",
help="Load all cmd args from installer config file",
)
parser.add_argument(
"--installer-config", default=None, help="Config file for windows installer"
)
parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS)
parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS)
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP)
parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP)
parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP)
parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES)
parser.add_argument(
"--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
)
parser.add_argument(
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
)
parser.add_argument(
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
)
parser.add_argument(
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
)
parser.add_argument("--gui", action="store_true", help=GUI_HELP)
parser.add_argument("--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP)
parser.add_argument(
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
)
parser.add_argument(
"--gui-size",
default=[1600, 1000],
@ -41,8 +75,14 @@ def parse_args():
)
parser.add_argument("--input", type=str, default=None, help=INPUT_HELP)
parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP)
parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP)
parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend")
parser.add_argument(
"--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP
)
parser.add_argument(
"--disable-model-switch",
action="store_true",
help="Disable model switch in frontend",
)
parser.add_argument("--debug", action="store_true")
# useless args
@ -64,7 +104,7 @@ def parse_args():
parser.add_argument(
"--sd-enable-xformers",
action="store_true",
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers",
)
args = parser.parse_args()
@ -74,14 +114,18 @@ def parse_args():
if args.config_installer:
if args.installer_config is None:
parser.error(f"args.config_installer==True, must set args.installer_config to store config file")
parser.error(
f"args.config_installer==True, must set args.installer_config to store config file"
)
from lama_cleaner.web_config import main
logger.info(f"Launching installer web config page")
main(args.installer_config)
exit()
if args.load_installer_config:
from lama_cleaner.web_config import load_config
if args.installer_config and not os.path.exists(args.installer_config):
parser.error(f"args.installer_config={args.installer_config} not exists")
@ -93,9 +137,17 @@ def parse_args():
if args.device == "cuda":
import torch
if torch.cuda.is_available() is False:
parser.error(
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
)
if args.device == "mps":
if args.model not in MPS_SUPPORT_MODELS:
parser.error(
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
)
if args.model_dir and args.model_dir is not None:
if os.path.isfile(args.model_dir):
@ -115,7 +167,9 @@ def parse_args():
parser.error(f"invalid --input: {args.input} is not a valid image file")
else:
if args.output_dir is None:
parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required")
parser.error(
f"invalid --input: {args.input} is a directory, --output-dir is required"
)
else:
output_dir = Path(args.output_dir)
if not output_dir.exists():
@ -123,6 +177,8 @@ def parse_args():
output_dir.mkdir(parents=True)
else:
if not output_dir.is_dir():
parser.error(f"invalid --output-dir: {output_dir} is not a directory")
parser.error(
f"invalid --output-dir: {output_dir} is not a directory"
)
return args

View File

@ -88,3 +88,8 @@ class Config(BaseModel):
paint_by_example_seed: int = 42
paint_by_example_match_histograms: bool = False
paint_by_example_example_image: Image = None
# InstructPix2Pix
p2p_steps: int = 50
p2p_image_guidance_scale: float = 1.5
p2p_guidance_scale: float = 7.5

View File

@ -19,6 +19,7 @@ from loguru import logger
from watchdog.events import FileSystemEventHandler
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
@ -31,7 +32,15 @@ try:
except:
pass
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
from flask import (
Flask,
request,
send_file,
cli,
make_response,
send_from_directory,
jsonify,
)
# Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@ -42,6 +51,7 @@ from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
)
NUM_THREADS = str(multiprocessing.cpu_count())
@ -93,6 +103,25 @@ def diffuser_callback(i, t, latents):
# socketio.emit('diffusion_step', {'diffusion_step': step})
@app.route("/make_gif", methods=["POST"])
def make_gif():
input = request.files
filename = request.form["filename"]
origin_image_bytes = input["origin_img"].read()
clean_image_bytes = input["clean_img"].read()
origin_image, _ = load_img(origin_image_bytes)
clean_image, _ = load_img(clean_image_bytes)
gif_bytes = make_compare_gif(
Image.fromarray(origin_image), Image.fromarray(clean_image)
)
return send_file(
io.BytesIO(gif_bytes),
mimetype="image/gif",
as_attachment=True,
attachment_filename=filename,
)
@app.route("/save_image", methods=["POST"])
def save_image():
# all image in output directory
@ -100,12 +129,12 @@ def save_image():
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
thumb.save_to_output_directory(image, request.form["filename"])
return 'ok', 200
return "ok", 200
@app.route("/medias/<tab>")
def medias(tab):
if tab == 'image':
if tab == "image":
response = make_response(jsonify(thumb.media_names), 200)
else:
response = make_response(jsonify(thumb.output_media_names), 200)
@ -116,18 +145,18 @@ def medias(tab):
return response
@app.route('/media/<tab>/<filename>')
@app.route("/media/<tab>/<filename>")
def media_file(tab, filename):
if tab == 'image':
if tab == "image":
return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename)
@app.route('/media_thumbnail/<tab>/<filename>')
@app.route("/media_thumbnail/<tab>/<filename>")
def media_thumbnail_file(tab, filename):
args = request.args
width = args.get('width')
height = args.get('height')
width = args.get("width")
height = args.get("height")
if width is None and height is None:
width = 256
if width:
@ -136,9 +165,11 @@ def media_thumbnail_file(tab, filename):
height = int(float(height))
directory = thumb.root_directory
if tab == 'output':
if tab == "output":
directory = thumb.output_dir
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
thumb_filename, (width, height) = thumb.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath))
@ -152,13 +183,16 @@ def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
return (
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
400,
)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
@ -171,7 +205,9 @@ def process():
size_limit = int(size_limit)
if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else:
paint_by_example_example_image = None
@ -200,13 +236,16 @@ def process():
sd_seed=form["sdSeed"],
sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"],
cv2_radius=form['cv2Radius'],
cv2_radius=form["cv2Radius"],
paint_by_example_steps=form["paintByExampleSteps"],
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
paint_by_example_seed=form["paintByExampleSeed"],
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
paint_by_example_example_image=paint_by_example_example_image,
p2p_steps=form["p2pSteps"],
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
p2p_guidance_scale=form["p2pGuidanceScale"],
)
if config.sd_seed == -1:
@ -235,6 +274,7 @@ def process():
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
@ -246,9 +286,15 @@ def process():
ext = get_image_ext(origin_image_bytes)
if exif is not None:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
else:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
mimetype=f"image/{ext}",
)
)
@ -261,7 +307,7 @@ def interactive_seg():
input = request.files
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
if 'mask' in input:
if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
@ -269,14 +315,16 @@ def interactive_seg():
_clicks = json.loads(request.form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
io.BytesIO(numpy_to_bytes(new_mask, "png")),
mimetype=f"image/png",
)
)
@ -290,13 +338,13 @@ def current_model():
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
res = 'true' if is_disable_model_switch else 'false'
res = "true" if is_disable_model_switch else "false"
return res, 200
@app.route("/is_enable_file_manager")
def get_is_enable_file_manager():
res = 'true' if is_enable_file_manager else 'false'
res = "true" if is_enable_file_manager else "false"
return res, 200
@ -365,14 +413,18 @@ def main(args):
is_disable_model_switch = args.disable_model_switch
is_desktop = args.gui
if is_disable_model_switch:
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable")
logger.info(
f"Start with --disable-model-switch, model switch on frontend is disable"
)
if args.input and os.path.isdir(args.input):
logger.info(f"Initialize file manager")
thumb = FileManager(app)
is_enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails')
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
args.output_dir, "lama_cleaner_thumbnails"
)
thumb.output_dir = Path(args.output_dir)
# thumb.start()
# try:
@ -408,8 +460,12 @@ def main(args):
from flaskwebgui import FlaskUI
ui = FlaskUI(
app, width=app_width, height=app_height, host=args.host, port=args.port,
close_server_on_exit=not args.no_gui_auto_close
app,
width=app_width,
height=app_height,
host=args.host,
port=args.port,
close_server_on_exit=not args.no_gui_auto_close,
)
ui.run()
else:

View File

@ -0,0 +1,62 @@
from pathlib import Path
import pytest
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.tests.test_model import get_config, assert_equal
from lama_cleaner.schema import HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@pytest.mark.parametrize("disable_nsfw", [True, False])
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_instruct_pix2pix(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload)
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1)
name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
assert_equal(
model,
cfg,
f"instruct_pix2pix_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3
)
@pytest.mark.parametrize("disable_nsfw", [False])
@pytest.mark.parametrize("cpu_offload", [False])
def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=False,
cpu_offload=cpu_offload)
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)
name = f"snow"
assert_equal(
model,
cfg,
f"instruct_pix2pix_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -0,0 +1,42 @@
import io
from pathlib import Path
from PIL import Image
from lama_cleaner.helper import pil_to_bytes
current_dir = Path(__file__).parent.absolute().resolve()
png_img_p = current_dir / "image.png"
jpg_img_p = current_dir / "bunny.jpeg"
def print_exif(exif):
for k, v in exif.items():
print(f"{k}: {v}")
def test_png():
img = Image.open(png_img_p)
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="png", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)
def test_jpeg():
img = Image.open(jpg_img_p)
exif = img.getexif()
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)

View File

@ -8,33 +8,37 @@ from lama_cleaner.schema import HDStrategy, SDSampler
from lama_cleaner.tests.test_model import get_config, assert_equal
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [True, False])
@pytest.mark.parametrize("disable_nsfw", [True, False])
def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
def test_runway_sd_1_5_ddim(
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
):
def callback(i, t, latents):
print(f"sd_step_{i}")
pass
if sd_device == 'cuda' and not torch.cuda.is_available():
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
callback=callback)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
callback=callback,
)
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -45,31 +49,35 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3
fx=1.3,
)
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a])
@pytest.mark.parametrize(
"sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]
)
@pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [True])
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
def callback(i, t, latents):
print(f"sd_step_{i}")
if sd_device == 'cuda' and not torch.cuda.is_available():
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
callback=callback)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
callback=callback,
)
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -80,35 +88,37 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3
fx=1.3,
)
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
def callback(i, t, latents):
pass
if sd_device == 'cuda' and not torch.cuda.is_available():
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=False,
sd_cpu_textencoder=False,
callback=callback)
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=False,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
strategy,
sd_steps=sd_steps,
prompt='Face of a fox, high resolution, sitting on a park bench',
negative_prompt='orange, yellow, small',
prompt="Face of a fox, high resolution, sitting on a park bench",
negative_prompt="orange, yellow, small",
sd_sampler=sampler,
sd_match_histograms=True
sd_match_histograms=True,
)
name = f"{sampler}_negative_prompt"
@ -119,27 +129,33 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
f"runway_sd_{strategy.capitalize()}_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1
fx=1,
)
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
@pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [False])
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
if sd_device == 'cuda' and not torch.cuda.is_available():
def test_runway_sd_1_5_sd_scale(
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
):
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder,
)
cfg = get_config(
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
@ -150,26 +166,30 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.3
fx=1.3,
)
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
if sd_device == 'cuda' and not torch.cuda.is_available():
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
sd_steps = 50 if sd_device == "cuda" else 1
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
)
cfg = get_config(
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
@ -182,27 +202,3 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
@pytest.mark.parametrize("sd_device", ['cpu'])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=True,
disable_nsfw=False,
sd_cpu_textencoder=False,
cpu_offload=True)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload_cpu_device.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -6,9 +6,24 @@ import gradio as gr
from loguru import logger
from pydantic import BaseModel
from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \
GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR
from lama_cleaner.const import (
AVAILABLE_MODELS,
AVAILABLE_DEVICES,
CPU_OFFLOAD_HELP,
NO_HALF_HELP,
DISABLE_NSFW_HELP,
SD_CPU_TEXTENCODER_HELP,
LOCAL_FILES_ONLY_HELP,
ENABLE_XFORMERS_HELP,
MODEL_DIR_HELP,
OUTPUT_DIR_HELP,
INPUT_HELP,
GUI_HELP,
DEFAULT_MODEL,
DEFAULT_DEVICE,
NO_GUI_AUTO_CLOSE_HELP,
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS,
)
_config_file = None
@ -33,16 +48,28 @@ class Config(BaseModel):
def load_config(installer_config: str):
if os.path.exists(installer_config):
with open(installer_config, "r", encoding='utf-8') as f:
with open(installer_config, "r", encoding="utf-8") as f:
return Config(**json.load(f))
else:
return Config()
def save_config(
host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload,
disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only,
model_dir, input, output_dir
host,
port,
model,
device,
gui,
no_gui_auto_close,
no_half,
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
):
config = Config(**locals())
print(config)
@ -63,6 +90,7 @@ def save_config(
def close_server(*args):
# TODO: make close both browser and server works
import os, signal
pid = os.getpid()
os.kill(pid, signal.SIGUSR1)
@ -86,33 +114,53 @@ def main(config_file: str):
port = gr.Number(init_config.port, label="Port", precision=0)
with gr.Row():
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
device = gr.Radio(AVAILABLE_DEVICES, label="Device", value=init_config.device)
device = gr.Radio(
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
)
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}")
no_gui_auto_close = gr.Checkbox(
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}")
sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}")
enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}")
local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}")
disable_nsfw = gr.Checkbox(
init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
)
sd_cpu_textencoder = gr.Checkbox(
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
)
enable_xformers = gr.Checkbox(
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
)
local_files_only = gr.Checkbox(
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
)
model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}")
output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}")
save_btn.click(save_config, [
host,
port,
model,
device,
gui,
no_gui_auto_close,
no_half,
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
], message)
input = gr.Textbox(
init_config.input, label=f"Input file or directory. {INPUT_HELP}"
)
output_dir = gr.Textbox(
init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}"
)
save_btn.click(
save_config,
[
host,
port,
model,
device,
gui,
no_gui_auto_close,
no_half,
cpu_offload,
disable_nsfw,
sd_cpu_textencoder,
enable_xformers,
local_files_only,
model_dir,
input,
output_dir,
],
message,
)
demo.launch(inbrowser=True, show_api=False)

View File

@ -1,6 +1,7 @@
torch>=1.9.0
opencv-python
flask_cors
Jinja2==2.11.3
flask==1.1.4
flaskwebgui==0.3.5
tqdm
@ -11,7 +12,8 @@ pytest
yacs
markupsafe==2.0.1
scikit-image==0.19.3
diffusers[torch]==0.10.2
diffusers[torch]==0.12.1
transformers>=4.25.1
watchdog==2.2.1
gradio
gradio
piexif==1.1.3

View File

@ -21,7 +21,7 @@ def load_requirements():
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup(
name="lama-cleaner",
version="0.34.0",
version="0.35.0",
author="PanicByte",
author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model",