Merge branch 'make_gif_share'
This commit is contained in:
commit
f837e4be8a
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,3 +7,4 @@ build
|
|||||||
!lama_cleaner/app/build
|
!lama_cleaner/app/build
|
||||||
dist/
|
dist/
|
||||||
lama_cleaner.egg-info/
|
lama_cleaner.egg-info/
|
||||||
|
venv/
|
||||||
|
@ -59,6 +59,6 @@ Only needed if you plan to modify the frontend and recompile yourself.
|
|||||||
Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures), You can experience their
|
Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures), You can experience their
|
||||||
great online services [here](https://cleanup.pictures/).
|
great online services [here](https://cleanup.pictures/).
|
||||||
|
|
||||||
- Install dependencies:`cd lama_cleaner/app/ && yarn`
|
- Install dependencies:`cd lama_cleaner/app/ && pnpm install`
|
||||||
- Start development server: `yarn start`
|
- Start development server: `pnpm start`
|
||||||
- Build: `yarn build`
|
- Build: `pnpm build`
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
{
|
{
|
||||||
"files": {
|
"files": {
|
||||||
"main.css": "/static/css/main.c28d98ca.css",
|
"main.css": "/static/css/main.e24c9a9b.css",
|
||||||
"main.js": "/static/js/main.1bd455bc.js",
|
"main.js": "/static/js/main.23732b19.js",
|
||||||
|
"static/media/coffee-machine-lineal.gif": "/static/media/coffee-machine-lineal.ee32631219cc3986f861.gif",
|
||||||
"static/media/WorkSans-SemiBold.ttf": "/static/media/WorkSans-SemiBold.1e98db4eb705b586728e.ttf",
|
"static/media/WorkSans-SemiBold.ttf": "/static/media/WorkSans-SemiBold.1e98db4eb705b586728e.ttf",
|
||||||
"static/media/WorkSans-Bold.ttf": "/static/media/WorkSans-Bold.2bea7a7f7d052c74da25.ttf",
|
"static/media/WorkSans-Bold.ttf": "/static/media/WorkSans-Bold.2bea7a7f7d052c74da25.ttf",
|
||||||
"static/media/WorkSans-Regular.ttf": "/static/media/WorkSans-Regular.bb287b894b27372d8ea7.ttf",
|
"static/media/WorkSans-Regular.ttf": "/static/media/WorkSans-Regular.bb287b894b27372d8ea7.ttf",
|
||||||
@ -9,7 +10,7 @@
|
|||||||
"index.html": "/index.html"
|
"index.html": "/index.html"
|
||||||
},
|
},
|
||||||
"entrypoints": [
|
"entrypoints": [
|
||||||
"static/css/main.c28d98ca.css",
|
"static/css/main.e24c9a9b.css",
|
||||||
"static/js/main.1bd455bc.js"
|
"static/js/main.23732b19.js"
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -1 +1 @@
|
|||||||
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.1bd455bc.js"></script><link href="/static/css/main.c28d98ca.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
<!doctype html><html lang="en"><head><meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate"/><meta http-equiv="Pragma" content="no-cache"/><meta http-equiv="Expires" content="0"/><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by SOTA AI model</title><script defer="defer" src="/static/js/main.23732b19.js"></script><link href="/static/css/main.e24c9a9b.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
|
File diff suppressed because one or more lines are too long
1
lama_cleaner/app/build/static/css/main.e24c9a9b.css
Normal file
1
lama_cleaner/app/build/static/css/main.e24c9a9b.css
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
lama_cleaner/app/build/static/js/main.23732b19.js
Normal file
2
lama_cleaner/app/build/static/js/main.23732b19.js
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
After Width: | Height: | Size: 422 KiB |
@ -2,7 +2,7 @@
|
|||||||
"name": "lama-cleaner",
|
"name": "lama-cleaner",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"private": true,
|
"private": true,
|
||||||
"proxy": "http://localhost:8080",
|
"proxy": "http://127.0.0.1:8080",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@babel/core": "^7.16.0",
|
"@babel/core": "^7.16.0",
|
||||||
"@heroicons/react": "^2.0.0",
|
"@heroicons/react": "^2.0.0",
|
||||||
|
14605
lama_cleaner/app/pnpm-lock.yaml
Normal file
14605
lama_cleaner/app/pnpm-lock.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
|||||||
import { Rect, Settings } from '../store/Atoms'
|
import { Rect, Settings } from '../store/Atoms'
|
||||||
import { dataURItoBlob, srcToFile } from '../utils'
|
import { dataURItoBlob, loadImage, srcToFile } from '../utils'
|
||||||
|
|
||||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||||
|
|
||||||
@ -82,6 +82,11 @@ export default async function inpaint(
|
|||||||
fd.append('paintByExampleImage', paintByExampleImage)
|
fd.append('paintByExampleImage', paintByExampleImage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstructPix2Pix
|
||||||
|
fd.append('p2pSteps', settings.p2pSteps.toString())
|
||||||
|
fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString())
|
||||||
|
fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString())
|
||||||
|
|
||||||
if (sizeLimit === undefined) {
|
if (sizeLimit === undefined) {
|
||||||
fd.append('sizeLimit', '1080')
|
fd.append('sizeLimit', '1080')
|
||||||
} else {
|
} else {
|
||||||
@ -223,3 +228,34 @@ export async function downloadToOutput(
|
|||||||
throw new Error(`Something went wrong: ${error}`)
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function makeGif(
|
||||||
|
originFile: File,
|
||||||
|
cleanImage: HTMLImageElement,
|
||||||
|
filename: string,
|
||||||
|
mimeType: string
|
||||||
|
) {
|
||||||
|
const cleanFile = await srcToFile(cleanImage.src, filename, mimeType)
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('origin_img', originFile)
|
||||||
|
fd.append('clean_img', cleanFile)
|
||||||
|
fd.append('filename', filename)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_ENDPOINT}/make_gif`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: fd,
|
||||||
|
})
|
||||||
|
if (!res.ok) {
|
||||||
|
const errMsg = await res.text()
|
||||||
|
throw new Error(errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
const blob = await res.blob()
|
||||||
|
const newImage = new Image()
|
||||||
|
await loadImage(newImage, URL.createObjectURL(blob))
|
||||||
|
return newImage
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ import React, { useState } from 'react'
|
|||||||
import { Coffee } from 'react-feather'
|
import { Coffee } from 'react-feather'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import Modal from '../shared/Modal'
|
import Modal from '../shared/Modal'
|
||||||
|
import CoffeeMachineGif from '../../media/coffee-machine-lineal.gif'
|
||||||
|
|
||||||
const CoffeeIcon = () => {
|
const CoffeeIcon = () => {
|
||||||
const [show, setShow] = useState(false)
|
const [show, setShow] = useState(false)
|
||||||
@ -24,10 +25,26 @@ const CoffeeIcon = () => {
|
|||||||
show={show}
|
show={show}
|
||||||
showCloseIcon={false}
|
showCloseIcon={false}
|
||||||
>
|
>
|
||||||
<h4 style={{ lineHeight: '24px' }}>
|
<div
|
||||||
Hi there, If you found my project is useful, and want to help keep it
|
style={{
|
||||||
alive please consider donating! Thank you for your support!
|
display: 'flex',
|
||||||
</h4>
|
flexDirection: 'column',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<h4 style={{ lineHeight: '24px' }}>
|
||||||
|
Hi, if you found my project is useful, please conside buy me a
|
||||||
|
coffee to support my work. Thanks!
|
||||||
|
</h4>
|
||||||
|
<img
|
||||||
|
src={CoffeeMachineGif}
|
||||||
|
alt="coffee machine"
|
||||||
|
style={{
|
||||||
|
height: 150,
|
||||||
|
objectFit: 'contain',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
style={{
|
style={{
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
@ -53,7 +70,6 @@ const CoffeeIcon = () => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
Sure
|
Sure
|
||||||
<Coffee />
|
|
||||||
</div>
|
</div>
|
||||||
</Button>
|
</Button>
|
||||||
</a>
|
</a>
|
||||||
|
@ -55,10 +55,9 @@
|
|||||||
position: fixed;
|
position: fixed;
|
||||||
bottom: 0.5rem;
|
bottom: 0.5rem;
|
||||||
border-radius: 3rem;
|
border-radius: 3rem;
|
||||||
padding: 0.6rem 3rem;
|
padding: 0.6rem 32px;
|
||||||
display: grid;
|
display: flex;
|
||||||
grid-template-areas: 'toolkit-size-selector toolkit-brush-slider toolkit-btns';
|
gap: 16px;
|
||||||
column-gap: 2rem;
|
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
backdrop-filter: blur(12px);
|
backdrop-filter: blur(12px);
|
||||||
@ -96,10 +95,8 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.editor-toolkit-btns {
|
.editor-toolkit-btns {
|
||||||
grid-area: toolkit-btns;
|
display: flex;
|
||||||
display: grid;
|
gap: 12px;
|
||||||
grid-auto-flow: column;
|
|
||||||
column-gap: 1rem;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.brush-shape {
|
.brush-shape {
|
||||||
|
@ -16,10 +16,11 @@ import {
|
|||||||
TransformComponent,
|
TransformComponent,
|
||||||
TransformWrapper,
|
TransformWrapper,
|
||||||
} from 'react-zoom-pan-pinch'
|
} from 'react-zoom-pan-pinch'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
|
||||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||||
import inpaint, {
|
import inpaint, {
|
||||||
downloadToOutput,
|
downloadToOutput,
|
||||||
|
makeGif,
|
||||||
postInteractiveSeg,
|
postInteractiveSeg,
|
||||||
} from '../../adapters/inpainting'
|
} from '../../adapters/inpainting'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
@ -43,10 +44,12 @@ import {
|
|||||||
imageHeightState,
|
imageHeightState,
|
||||||
imageWidthState,
|
imageWidthState,
|
||||||
interactiveSegClicksState,
|
interactiveSegClicksState,
|
||||||
|
isDiffusionModelsState,
|
||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
isInteractiveSegRunningState,
|
isInteractiveSegRunningState,
|
||||||
isInteractiveSegState,
|
isInteractiveSegState,
|
||||||
isPaintByExampleState,
|
isPaintByExampleState,
|
||||||
|
isPix2PixState,
|
||||||
isSDState,
|
isSDState,
|
||||||
negativePropmtState,
|
negativePropmtState,
|
||||||
propmtState,
|
propmtState,
|
||||||
@ -66,6 +69,7 @@ import FileSelect from '../FileSelect/FileSelect'
|
|||||||
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
||||||
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
|
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
|
||||||
import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal'
|
import InteractiveSegReplaceModal from '../InteractiveSeg/ReplaceModal'
|
||||||
|
import MakeGIF from './MakeGIF'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
const MIN_BRUSH_SIZE = 10
|
const MIN_BRUSH_SIZE = 10
|
||||||
@ -112,11 +116,11 @@ export default function Editor() {
|
|||||||
const settings = useRecoilValue(settingState)
|
const settings = useRecoilValue(settingState)
|
||||||
const [seedVal, setSeed] = useRecoilState(seedState)
|
const [seedVal, setSeed] = useRecoilState(seedState)
|
||||||
const croperRect = useRecoilValue(croperState)
|
const croperRect = useRecoilValue(croperState)
|
||||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
const setToastState = useSetRecoilState(toastState)
|
||||||
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
||||||
const runMannually = useRecoilValue(runManuallyState)
|
const runMannually = useRecoilValue(runManuallyState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
|
||||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
const isPix2Pix = useRecoilValue(isPix2PixState)
|
||||||
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
||||||
isInteractiveSegState
|
isInteractiveSegState
|
||||||
)
|
)
|
||||||
@ -181,8 +185,8 @@ export default function Editor() {
|
|||||||
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
|
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
|
||||||
const enableFileManager = useRecoilValue(enableFileManagerState)
|
const enableFileManager = useRecoilValue(enableFileManagerState)
|
||||||
|
|
||||||
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState)
|
const setImageWidth = useSetRecoilState(imageWidthState)
|
||||||
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState)
|
const setImageHeight = useSetRecoilState(imageHeightState)
|
||||||
const app = useRecoilValue(appState)
|
const app = useRecoilValue(appState)
|
||||||
|
|
||||||
const draw = useCallback(
|
const draw = useCallback(
|
||||||
@ -253,13 +257,40 @@ export default function Editor() {
|
|||||||
_lineGroups.forEach(lineGroup => {
|
_lineGroups.forEach(lineGroup => {
|
||||||
drawLines(ctx, lineGroup, 'white')
|
drawLines(ctx, lineGroup, 'white')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if (
|
||||||
|
(maskImage === undefined || maskImage === null) &&
|
||||||
|
_lineGroups.length === 1 &&
|
||||||
|
_lineGroups[0].length === 0 &&
|
||||||
|
isPix2Pix
|
||||||
|
) {
|
||||||
|
// For InstructPix2Pix without mask
|
||||||
|
drawLines(
|
||||||
|
ctx,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
size: 9999999999,
|
||||||
|
pts: [
|
||||||
|
{ x: 0, y: 0 },
|
||||||
|
{ x: original.naturalWidth, y: 0 },
|
||||||
|
{ x: original.naturalWidth, y: original.naturalHeight },
|
||||||
|
{ x: 0, y: original.naturalHeight },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
'white'
|
||||||
|
)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[context, maskCanvas]
|
[context, maskCanvas, isPix2Pix]
|
||||||
)
|
)
|
||||||
|
|
||||||
const hadDrawSomething = useCallback(() => {
|
const hadDrawSomething = useCallback(() => {
|
||||||
|
if (isPix2Pix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return curLineGroup.length !== 0
|
return curLineGroup.length !== 0
|
||||||
}, [curLineGroup])
|
}, [curLineGroup, isPix2Pix])
|
||||||
|
|
||||||
const drawOnCurrentRender = useCallback(
|
const drawOnCurrentRender = useCallback(
|
||||||
(lineGroup: LineGroup) => {
|
(lineGroup: LineGroup) => {
|
||||||
@ -424,6 +455,8 @@ export default function Editor() {
|
|||||||
} else if (prevInteractiveSegMask) {
|
} else if (prevInteractiveSegMask) {
|
||||||
// 使用上一次 IS 的 mask 生成
|
// 使用上一次 IS 的 mask 生成
|
||||||
runInpainting(false, undefined, prevInteractiveSegMask)
|
runInpainting(false, undefined, prevInteractiveSegMask)
|
||||||
|
} else if (isPix2Pix) {
|
||||||
|
runInpainting(false, undefined, null)
|
||||||
} else {
|
} else {
|
||||||
setToastState({
|
setToastState({
|
||||||
open: true,
|
open: true,
|
||||||
@ -839,7 +872,7 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(isSD || isPaintByExample) &&
|
isDiffusionModels &&
|
||||||
settings.showCroper &&
|
settings.showCroper &&
|
||||||
isOutsideCroper(mouseXY(ev))
|
isOutsideCroper(mouseXY(ev))
|
||||||
) {
|
) {
|
||||||
@ -1385,7 +1418,7 @@ export default function Editor() {
|
|||||||
minHeight={Math.min(256, original.naturalHeight)}
|
minHeight={Math.min(256, original.naturalHeight)}
|
||||||
minWidth={Math.min(256, original.naturalWidth)}
|
minWidth={Math.min(256, original.naturalWidth)}
|
||||||
scale={scale}
|
scale={scale}
|
||||||
show={(isSD || isPaintByExample) && settings.showCroper}
|
show={isDiffusionModels && settings.showCroper}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
|
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
|
||||||
@ -1439,7 +1472,7 @@ export default function Editor() {
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="editor-toolkit-panel">
|
<div className="editor-toolkit-panel">
|
||||||
{isSD || isPaintByExample || file === undefined ? (
|
{isDiffusionModels || file === undefined ? (
|
||||||
<></>
|
<></>
|
||||||
) : (
|
) : (
|
||||||
<SizeSelector
|
<SizeSelector
|
||||||
@ -1534,6 +1567,7 @@ export default function Editor() {
|
|||||||
}}
|
}}
|
||||||
disabled={renders.length === 0}
|
disabled={renders.length === 0}
|
||||||
/>
|
/>
|
||||||
|
<MakeGIF renders={renders} />
|
||||||
<Button
|
<Button
|
||||||
toolTip="Save Image"
|
toolTip="Save Image"
|
||||||
icon={<ArrowDownTrayIcon />}
|
icon={<ArrowDownTrayIcon />}
|
||||||
@ -1541,7 +1575,7 @@ export default function Editor() {
|
|||||||
onClick={download}
|
onClick={download}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{settings.runInpaintingManually && !isSD && !isPaintByExample && (
|
{settings.runInpaintingManually && !isDiffusionModels && (
|
||||||
<Button
|
<Button
|
||||||
toolTip="Run Inpainting"
|
toolTip="Run Inpainting"
|
||||||
icon={
|
icon={
|
||||||
|
114
lama_cleaner/app/src/components/Editor/MakeGIF.tsx
Normal file
114
lama_cleaner/app/src/components/Editor/MakeGIF.tsx
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import React, { useState } from 'react'
|
||||||
|
import { GifIcon } from '@heroicons/react/24/outline'
|
||||||
|
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'
|
||||||
|
import Button from '../shared/Button'
|
||||||
|
import { fileState, gifImageState, toastState } from '../../store/Atoms'
|
||||||
|
import { makeGif } from '../../adapters/inpainting'
|
||||||
|
import Modal from '../shared/Modal'
|
||||||
|
import { LoadingIcon } from '../shared/Toast'
|
||||||
|
import { downloadImage } from '../../utils'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
renders: HTMLImageElement[]
|
||||||
|
}
|
||||||
|
|
||||||
|
const MakeGIF = (props: Props) => {
|
||||||
|
const { renders } = props
|
||||||
|
const [gifImg, setGifImg] = useRecoilState(gifImageState)
|
||||||
|
const file = useRecoilValue(fileState)
|
||||||
|
const setToastState = useSetRecoilState(toastState)
|
||||||
|
const [show, setShow] = useState(false)
|
||||||
|
|
||||||
|
const handleOnClose = () => {
|
||||||
|
setShow(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleDownload = () => {
|
||||||
|
if (gifImg) {
|
||||||
|
const name = file.name.replace(/\.[^/.]+$/, '.gif')
|
||||||
|
downloadImage(gifImg.src, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Button
|
||||||
|
toolTip="Make Gif"
|
||||||
|
icon={<GifIcon />}
|
||||||
|
disabled={!renders.length}
|
||||||
|
onClick={async () => {
|
||||||
|
setShow(true)
|
||||||
|
setGifImg(null)
|
||||||
|
try {
|
||||||
|
const gif = await makeGif(
|
||||||
|
file,
|
||||||
|
renders[renders.length - 1],
|
||||||
|
file.name,
|
||||||
|
file.type
|
||||||
|
)
|
||||||
|
if (gif) {
|
||||||
|
setGifImg(gif)
|
||||||
|
}
|
||||||
|
} catch (e: any) {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: e.message ? e.message : e.toString(),
|
||||||
|
state: 'error',
|
||||||
|
duration: 2000,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Modal
|
||||||
|
onClose={handleOnClose}
|
||||||
|
title="GIF"
|
||||||
|
className="modal-setting"
|
||||||
|
show={show}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
flexDirection: 'column',
|
||||||
|
gap: 16,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{gifImg ? (
|
||||||
|
<img src={gifImg.src} style={{ borderRadius: 8 }} alt="gif" />
|
||||||
|
) : (
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
gap: 8,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<LoadingIcon />
|
||||||
|
Generating GIF...
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{gifImg && (
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
width: '100%',
|
||||||
|
justifyContent: 'flex-end',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: '12px',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Button onClick={handleDownload} border>
|
||||||
|
Download
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MakeGIF
|
@ -7,6 +7,7 @@ import {
|
|||||||
enableFileManagerState,
|
enableFileManagerState,
|
||||||
fileState,
|
fileState,
|
||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
|
isPix2PixState,
|
||||||
isSDState,
|
isSDState,
|
||||||
maskState,
|
maskState,
|
||||||
runManuallyState,
|
runManuallyState,
|
||||||
@ -30,6 +31,7 @@ const Header = () => {
|
|||||||
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
||||||
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
|
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
|
const isPix2Pix = useRecoilValue(isPix2PixState)
|
||||||
const runManually = useRecoilValue(runManuallyState)
|
const runManually = useRecoilValue(runManuallyState)
|
||||||
const [openMaskPopover, setOpenMaskPopover] = useState(false)
|
const [openMaskPopover, setOpenMaskPopover] = useState(false)
|
||||||
const [showFileManager, setShowFileManager] =
|
const [showFileManager, setShowFileManager] =
|
||||||
@ -172,7 +174,7 @@ const Header = () => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{isSD && file ? <PromptInput /> : <></>}
|
{(isSD || isPix2Pix) && file ? <PromptInput /> : <></>}
|
||||||
|
|
||||||
<div className="header-icons-wrapper">
|
<div className="header-icons-wrapper">
|
||||||
<CoffeeIcon />
|
<CoffeeIcon />
|
||||||
|
@ -179,28 +179,16 @@ function ModelSettingBlock() {
|
|||||||
|
|
||||||
const renderOptionDesc = (): ReactNode => {
|
const renderOptionDesc = (): ReactNode => {
|
||||||
switch (setting.model) {
|
switch (setting.model) {
|
||||||
case AIModel.LAMA:
|
|
||||||
return undefined
|
|
||||||
case AIModel.LDM:
|
case AIModel.LDM:
|
||||||
return renderLDMModelDesc()
|
return renderLDMModelDesc()
|
||||||
case AIModel.ZITS:
|
case AIModel.ZITS:
|
||||||
return renderZITSModelDesc()
|
return renderZITSModelDesc()
|
||||||
case AIModel.MAT:
|
|
||||||
return undefined
|
|
||||||
case AIModel.FCF:
|
case AIModel.FCF:
|
||||||
return renderFCFModelDesc()
|
return renderFCFModelDesc()
|
||||||
case AIModel.SD15:
|
|
||||||
return undefined
|
|
||||||
case AIModel.SD2:
|
|
||||||
return undefined
|
|
||||||
case AIModel.PAINT_BY_EXAMPLE:
|
|
||||||
return undefined
|
|
||||||
case AIModel.Mange:
|
|
||||||
return undefined
|
|
||||||
case AIModel.CV2:
|
case AIModel.CV2:
|
||||||
return renderOpenCV2Desc()
|
return renderOpenCV2Desc()
|
||||||
default:
|
default:
|
||||||
return <></>
|
return undefined
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,6 +254,12 @@ function ModelSettingBlock() {
|
|||||||
'https://arxiv.org/abs/2211.13227',
|
'https://arxiv.org/abs/2211.13227',
|
||||||
'https://github.com/Fantasy-Studio/Paint-by-Example'
|
'https://github.com/Fantasy-Studio/Paint-by-Example'
|
||||||
)
|
)
|
||||||
|
case AIModel.PIX2PIX:
|
||||||
|
return renderModelDesc(
|
||||||
|
'InstructPix2Pix',
|
||||||
|
'https://arxiv.org/abs/2211.09800',
|
||||||
|
'https://github.com/timothybrooks/instruct-pix2pix'
|
||||||
|
)
|
||||||
default:
|
default:
|
||||||
return <></>
|
return <></>
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
|
import { Cog6ToothIcon } from '@heroicons/react/24/outline'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
|
|
||||||
@ -16,29 +17,7 @@ const SettingIcon = () => {
|
|||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
toolTip="Settings"
|
toolTip="Settings"
|
||||||
style={{ border: 0 }}
|
style={{ border: 0 }}
|
||||||
icon={
|
icon={<Cog6ToothIcon />}
|
||||||
<svg
|
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
|
||||||
fill="none"
|
|
||||||
role="img"
|
|
||||||
width="28"
|
|
||||||
height="28"
|
|
||||||
viewBox="0 0 24 24"
|
|
||||||
stroke="currentColor"
|
|
||||||
strokeWidth="2"
|
|
||||||
>
|
|
||||||
<path
|
|
||||||
strokeLinecap="round"
|
|
||||||
strokeLinejoin="round"
|
|
||||||
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
|
|
||||||
/>
|
|
||||||
<path
|
|
||||||
strokeLinecap="round"
|
|
||||||
strokeLinejoin="round"
|
|
||||||
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
color: var(--modal-text-color);
|
color: var(--modal-text-color);
|
||||||
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
|
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
|
||||||
width: 680px;
|
width: 680px;
|
||||||
min-height: 420px;
|
|
||||||
|
|
||||||
@include mobile {
|
@include mobile {
|
||||||
display: grid;
|
display: grid;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import {
|
import {
|
||||||
|
isDiffusionModelsState,
|
||||||
isPaintByExampleState,
|
isPaintByExampleState,
|
||||||
isSDState,
|
isSDState,
|
||||||
settingState,
|
settingState,
|
||||||
@ -28,7 +29,7 @@ export default function SettingModal(props: SettingModalProps) {
|
|||||||
const { onClose } = props
|
const { onClose } = props
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
const isDiffusionModels = useRecoilValue(isDiffusionModelsState)
|
||||||
|
|
||||||
const handleOnClose = () => {
|
const handleOnClose = () => {
|
||||||
setSettingState(old => {
|
setSettingState(old => {
|
||||||
@ -56,9 +57,9 @@ export default function SettingModal(props: SettingModalProps) {
|
|||||||
show={setting.show}
|
show={setting.show}
|
||||||
>
|
>
|
||||||
<DownloadMaskSettingBlock />
|
<DownloadMaskSettingBlock />
|
||||||
{isSD || isPaintByExample ? <></> : <ManualRunInpaintingSettingBlock />}
|
{isDiffusionModels ? <></> : <ManualRunInpaintingSettingBlock />}
|
||||||
<ModelSettingBlock />
|
<ModelSettingBlock />
|
||||||
{isSD ? <></> : <HDSettingBlock />}
|
{isDiffusionModels ? <></> : <HDSettingBlock />}
|
||||||
</Modal>
|
</Modal>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,14 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.shortcut-options {
|
.shortcut-options {
|
||||||
display: grid;
|
display: flex;
|
||||||
row-gap: 1rem;
|
gap: 48px;
|
||||||
|
flex-direction: row;
|
||||||
|
|
||||||
.shortcut-option {
|
.shortcut-option {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: repeat(2, auto);
|
grid-template-columns: repeat(2, auto);
|
||||||
column-gap: 6rem;
|
column-gap: 2rem;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
|
||||||
@include mobile {
|
@include mobile {
|
||||||
@ -67,3 +68,10 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.shortcut-options-column {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 12px;
|
||||||
|
width: 320px;
|
||||||
|
}
|
||||||
|
@ -50,24 +50,37 @@ export default function ShortcutsModal() {
|
|||||||
show={shortcutsShow}
|
show={shortcutsShow}
|
||||||
>
|
>
|
||||||
<div className="shortcut-options">
|
<div className="shortcut-options">
|
||||||
<ShortCut
|
<div className="shortcut-options-column">
|
||||||
content="Multi-Stroke Mask Drawing"
|
<ShortCut content="Pan" keys={['Space + Drag']} />
|
||||||
keys={[`Hold ${CmdOrCtrl}`]}
|
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
||||||
/>
|
<ShortCut content="Decrease Brush Size" keys={['[']} />
|
||||||
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
<ShortCut content="Increase Brush Size" keys={[']']} />
|
||||||
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
||||||
<ShortCut content="Interactive Segmentation" keys={['I']} />
|
<ShortCut
|
||||||
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
content="Multi-Stroke Drawing"
|
||||||
<ShortCut content="Redo Inpainting" keys={[CmdOrCtrl, 'Shift', 'Z']} />
|
keys={[`Hold ${CmdOrCtrl}`]}
|
||||||
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
/>
|
||||||
<ShortCut content="Pan" keys={['Space + Drag']} />
|
<ShortCut content="Cancel Drawing" keys={['Esc']} />
|
||||||
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
</div>
|
||||||
<ShortCut content="Decrease Brush Size" keys={['[']} />
|
|
||||||
<ShortCut content="Increase Brush Size" keys={[']']} />
|
<div className="shortcut-options-column">
|
||||||
<ShortCut content="Toggle Dark Mode" keys={['Shift', 'D']} />
|
<ShortCut content="Undo" keys={[CmdOrCtrl, 'Z']} />
|
||||||
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
|
<ShortCut content="Redo" keys={[CmdOrCtrl, 'Shift', 'Z']} />
|
||||||
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
|
<ShortCut content="Copy Result" keys={[CmdOrCtrl, 'C']} />
|
||||||
<ShortCut content="Toggle File Manager" keys={['F']} />
|
<ShortCut content="Paste Image" keys={[CmdOrCtrl, 'V']} />
|
||||||
|
<ShortCut
|
||||||
|
content="Trigger Manually Inpainting"
|
||||||
|
keys={['Shift', 'R']}
|
||||||
|
/>
|
||||||
|
<ShortCut content="Trigger Interactive Segmentation" keys={['I']} />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="shortcut-options-column">
|
||||||
|
<ShortCut content="Switch Theme" keys={['Shift', 'D']} />
|
||||||
|
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
|
||||||
|
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
|
||||||
|
<ShortCut content="Toggle File Manager" keys={['F']} />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Modal>
|
</Modal>
|
||||||
)
|
)
|
||||||
|
224
lama_cleaner/app/src/components/SidePanel/P2PSidePanel.tsx
Normal file
224
lama_cleaner/app/src/components/SidePanel/P2PSidePanel.tsx
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
import React, { FormEvent } from 'react'
|
||||||
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
|
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||||
|
import { useToggle } from 'react-use'
|
||||||
|
import {
|
||||||
|
isInpaintingState,
|
||||||
|
negativePropmtState,
|
||||||
|
propmtState,
|
||||||
|
settingState,
|
||||||
|
} from '../../store/Atoms'
|
||||||
|
import NumberInputSetting from '../Settings/NumberInputSetting'
|
||||||
|
import SettingBlock from '../Settings/SettingBlock'
|
||||||
|
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||||
|
import TextAreaInput from '../shared/Textarea'
|
||||||
|
import emitter, { EVENT_PROMPT } from '../../event'
|
||||||
|
import ImageResizeScale from './ImageResizeScale'
|
||||||
|
|
||||||
|
const INPUT_WIDTH = 30
|
||||||
|
|
||||||
|
const P2PSidePanel = () => {
|
||||||
|
const [open, toggleOpen] = useToggle(true)
|
||||||
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
const [negativePrompt, setNegativePrompt] =
|
||||||
|
useRecoilState(negativePropmtState)
|
||||||
|
const isInpainting = useRecoilValue(isInpaintingState)
|
||||||
|
const prompt = useRecoilValue(propmtState)
|
||||||
|
|
||||||
|
const handleOnInput = (evt: FormEvent<HTMLTextAreaElement>) => {
|
||||||
|
evt.preventDefault()
|
||||||
|
evt.stopPropagation()
|
||||||
|
const target = evt.target as HTMLTextAreaElement
|
||||||
|
setNegativePrompt(target.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
const onKeyUp = (e: React.KeyboardEvent) => {
|
||||||
|
if (
|
||||||
|
e.key === 'Enter' &&
|
||||||
|
(e.ctrlKey || e.metaKey) &&
|
||||||
|
prompt.length !== 0 &&
|
||||||
|
!isInpainting
|
||||||
|
) {
|
||||||
|
emitter.emit(EVENT_PROMPT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="side-panel">
|
||||||
|
<PopoverPrimitive.Root open={open}>
|
||||||
|
<PopoverPrimitive.Trigger
|
||||||
|
className="btn-primary side-panel-trigger"
|
||||||
|
onClick={() => toggleOpen()}
|
||||||
|
>
|
||||||
|
Config
|
||||||
|
</PopoverPrimitive.Trigger>
|
||||||
|
<PopoverPrimitive.Portal>
|
||||||
|
<PopoverPrimitive.Content className="side-panel-content">
|
||||||
|
<SettingBlock
|
||||||
|
title="Croper"
|
||||||
|
input={
|
||||||
|
<Switch
|
||||||
|
checked={setting.showCroper}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, showCroper: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<ImageResizeScale />
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Steps"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
value={`${setting.p2pSteps}`}
|
||||||
|
desc="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, p2pSteps: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Guidance Scale"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
allowFloat
|
||||||
|
value={`${setting.p2pGuidanceScale}`}
|
||||||
|
desc="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality."
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, p2pGuidanceScale: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Image Guidance Scale"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
allowFloat
|
||||||
|
value={`${setting.p2pImageGuidanceScale}`}
|
||||||
|
desc=""
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, p2pImageGuidanceScale: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* <NumberInputSetting
|
||||||
|
title="Mask Blur"
|
||||||
|
width={INPUT_WIDTH}
|
||||||
|
value={`${setting.sdMaskBlur}`}
|
||||||
|
desc="Blur the edge of mask area. The higher the number the smoother blend with the original image"
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdMaskBlur: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/> */}
|
||||||
|
|
||||||
|
{/* <SettingBlock
|
||||||
|
title="Match Histograms"
|
||||||
|
desc="Match the inpainting result histogram to the source image histogram, will improves the inpainting quality for some images."
|
||||||
|
input={
|
||||||
|
<Switch
|
||||||
|
checked={setting.sdMatchHistograms}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdMatchHistograms: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
}
|
||||||
|
/> */}
|
||||||
|
|
||||||
|
{/* <SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title="Sampler"
|
||||||
|
input={
|
||||||
|
<Selector
|
||||||
|
width={80}
|
||||||
|
value={setting.sdSampler as string}
|
||||||
|
options={Object.values(SDSampler)}
|
||||||
|
onChange={val => {
|
||||||
|
const sampler = val as SDSampler
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdSampler: sampler }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/> */}
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
title="Seed"
|
||||||
|
input={
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
gap: 0,
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<NumberInputSetting
|
||||||
|
title=""
|
||||||
|
width={80}
|
||||||
|
value={`${setting.sdSeed}`}
|
||||||
|
desc=""
|
||||||
|
disable={!setting.sdSeedFixed}
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdSeed: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
checked={setting.sdSeedFixed}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdSeedFixed: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
style={{ marginLeft: '8px' }}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title="Negative prompt"
|
||||||
|
layout="v"
|
||||||
|
input={
|
||||||
|
<TextAreaInput
|
||||||
|
className="negative-prompt"
|
||||||
|
value={negativePrompt}
|
||||||
|
onInput={handleOnInput}
|
||||||
|
onKeyUp={onKeyUp}
|
||||||
|
placeholder=""
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</PopoverPrimitive.Content>
|
||||||
|
</PopoverPrimitive.Portal>
|
||||||
|
</PopoverPrimitive.Root>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default P2PSidePanel
|
@ -8,6 +8,7 @@ import {
|
|||||||
AIModel,
|
AIModel,
|
||||||
fileState,
|
fileState,
|
||||||
isPaintByExampleState,
|
isPaintByExampleState,
|
||||||
|
isPix2PixState,
|
||||||
isSDState,
|
isSDState,
|
||||||
settingState,
|
settingState,
|
||||||
showFileManagerState,
|
showFileManagerState,
|
||||||
@ -22,6 +23,7 @@ import {
|
|||||||
import SidePanel from './SidePanel/SidePanel'
|
import SidePanel from './SidePanel/SidePanel'
|
||||||
import PESidePanel from './SidePanel/PESidePanel'
|
import PESidePanel from './SidePanel/PESidePanel'
|
||||||
import FileManager from './FileManager/FileManager'
|
import FileManager from './FileManager/FileManager'
|
||||||
|
import P2PSidePanel from './SidePanel/P2PSidePanel'
|
||||||
|
|
||||||
const Workspace = () => {
|
const Workspace = () => {
|
||||||
const setFile = useSetRecoilState(fileState)
|
const setFile = useSetRecoilState(fileState)
|
||||||
@ -29,6 +31,7 @@ const Workspace = () => {
|
|||||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
const isPaintByExample = useRecoilValue(isPaintByExampleState)
|
||||||
|
const isPix2Pix = useRecoilValue(isPix2PixState)
|
||||||
|
|
||||||
const [showFileManager, setShowFileManager] =
|
const [showFileManager, setShowFileManager] =
|
||||||
useRecoilState(showFileManagerState)
|
useRecoilState(showFileManagerState)
|
||||||
@ -98,6 +101,7 @@ const Workspace = () => {
|
|||||||
<>
|
<>
|
||||||
{isSD ? <SidePanel /> : <></>}
|
{isSD ? <SidePanel /> : <></>}
|
||||||
{isPaintByExample ? <PESidePanel /> : <></>}
|
{isPaintByExample ? <PESidePanel /> : <></>}
|
||||||
|
{isPix2Pix ? <P2PSidePanel /> : <></>}
|
||||||
<FileManager
|
<FileManager
|
||||||
photoWidth={256}
|
photoWidth={256}
|
||||||
show={showFileManager}
|
show={showFileManager}
|
||||||
|
@ -3,7 +3,7 @@ import * as ToastPrimitive from '@radix-ui/react-toast'
|
|||||||
import { ToastProps } from '@radix-ui/react-toast'
|
import { ToastProps } from '@radix-ui/react-toast'
|
||||||
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline'
|
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline'
|
||||||
|
|
||||||
const LoadingIcon = () => {
|
export const LoadingIcon = () => {
|
||||||
return (
|
return (
|
||||||
<span className="loading-icon">
|
<span className="loading-icon">
|
||||||
<svg
|
<svg
|
||||||
|
BIN
lama_cleaner/app/src/media/coffee-machine-lineal.gif
Normal file
BIN
lama_cleaner/app/src/media/coffee-machine-lineal.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 422 KiB |
@ -14,6 +14,7 @@ export enum AIModel {
|
|||||||
CV2 = 'cv2',
|
CV2 = 'cv2',
|
||||||
Mange = 'manga',
|
Mange = 'manga',
|
||||||
PAINT_BY_EXAMPLE = 'paint_by_example',
|
PAINT_BY_EXAMPLE = 'paint_by_example',
|
||||||
|
PIX2PIX = 'instruct_pix2pix',
|
||||||
}
|
}
|
||||||
|
|
||||||
export const maskState = atom<File | undefined>({
|
export const maskState = atom<File | undefined>({
|
||||||
@ -45,6 +46,7 @@ interface AppState {
|
|||||||
interactiveSegClicks: number[][]
|
interactiveSegClicks: number[][]
|
||||||
showFileManager: boolean
|
showFileManager: boolean
|
||||||
enableFileManager: boolean
|
enableFileManager: boolean
|
||||||
|
gifImage: HTMLImageElement | undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
export const appState = atom<AppState>({
|
export const appState = atom<AppState>({
|
||||||
@ -61,6 +63,7 @@ export const appState = atom<AppState>({
|
|||||||
interactiveSegClicks: [],
|
interactiveSegClicks: [],
|
||||||
showFileManager: false,
|
showFileManager: false,
|
||||||
enableFileManager: false,
|
enableFileManager: false,
|
||||||
|
gifImage: undefined,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -134,6 +137,18 @@ export const enableFileManagerState = selector({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const gifImageState = selector({
|
||||||
|
key: 'gifImageState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.gifImage
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, gifImage: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const fileState = selector({
|
export const fileState = selector({
|
||||||
key: 'fileState',
|
key: 'fileState',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
@ -329,6 +344,11 @@ export interface Settings {
|
|||||||
paintByExampleSeedFixed: boolean
|
paintByExampleSeedFixed: boolean
|
||||||
paintByExampleMaskBlur: number
|
paintByExampleMaskBlur: number
|
||||||
paintByExampleMatchHistograms: boolean
|
paintByExampleMatchHistograms: boolean
|
||||||
|
|
||||||
|
// InstructPix2Pix
|
||||||
|
p2pSteps: number
|
||||||
|
p2pImageGuidanceScale: number
|
||||||
|
p2pGuidanceScale: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultHDSettings: ModelsHDSettings = {
|
const defaultHDSettings: ModelsHDSettings = {
|
||||||
@ -388,6 +408,13 @@ const defaultHDSettings: ModelsHDSettings = {
|
|||||||
hdStrategyCropMargin: 128,
|
hdStrategyCropMargin: 128,
|
||||||
enabled: false,
|
enabled: false,
|
||||||
},
|
},
|
||||||
|
[AIModel.PIX2PIX]: {
|
||||||
|
hdStrategy: HDStrategy.ORIGINAL,
|
||||||
|
hdStrategyResizeLimit: 768,
|
||||||
|
hdStrategyCropTrigerSize: 512,
|
||||||
|
hdStrategyCropMargin: 128,
|
||||||
|
enabled: false,
|
||||||
|
},
|
||||||
[AIModel.Mange]: {
|
[AIModel.Mange]: {
|
||||||
hdStrategy: HDStrategy.CROP,
|
hdStrategy: HDStrategy.CROP,
|
||||||
hdStrategyResizeLimit: 1280,
|
hdStrategyResizeLimit: 1280,
|
||||||
@ -457,6 +484,11 @@ export const settingStateDefault: Settings = {
|
|||||||
paintByExampleMaskBlur: 5,
|
paintByExampleMaskBlur: 5,
|
||||||
paintByExampleSeedFixed: false,
|
paintByExampleSeedFixed: false,
|
||||||
paintByExampleMatchHistograms: false,
|
paintByExampleMatchHistograms: false,
|
||||||
|
|
||||||
|
// InstructPix2Pix
|
||||||
|
p2pSteps: 50,
|
||||||
|
p2pImageGuidanceScale: 1.5,
|
||||||
|
p2pGuidanceScale: 7.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
const localStorageEffect =
|
const localStorageEffect =
|
||||||
@ -553,12 +585,33 @@ export const isPaintByExampleState = selector({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const isPix2PixState = selector({
|
||||||
|
key: 'isPix2PixState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const settings = get(settingState)
|
||||||
|
return settings.model === AIModel.PIX2PIX
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const runManuallyState = selector({
|
export const runManuallyState = selector({
|
||||||
key: 'runManuallyState',
|
key: 'runManuallyState',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
const settings = get(settingState)
|
const settings = get(settingState)
|
||||||
const isSD = get(isSDState)
|
const isSD = get(isSDState)
|
||||||
const isPaintByExample = get(isPaintByExampleState)
|
const isPaintByExample = get(isPaintByExampleState)
|
||||||
return settings.runInpaintingManually || isSD || isPaintByExample
|
const isPix2Pix = get(isPix2PixState)
|
||||||
|
return (
|
||||||
|
settings.runInpaintingManually || isSD || isPaintByExample || isPix2Pix
|
||||||
|
)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const isDiffusionModelsState = selector({
|
||||||
|
key: 'isDiffusionModelsState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const isSD = get(isSDState)
|
||||||
|
const isPaintByExample = get(isPaintByExampleState)
|
||||||
|
const isPix2Pix = get(isPix2PixState)
|
||||||
|
return isSD || isPaintByExample || isPix2Pix
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
MPS_SUPPORT_MODELS = [
|
||||||
|
"instruct_pix2pix",
|
||||||
|
"sd1.5",
|
||||||
|
"sd2",
|
||||||
|
"paint_by_example"
|
||||||
|
]
|
||||||
|
|
||||||
DEFAULT_MODEL = "lama"
|
DEFAULT_MODEL = "lama"
|
||||||
AVAILABLE_MODELS = [
|
AVAILABLE_MODELS = [
|
||||||
"lama",
|
"lama",
|
||||||
@ -11,7 +18,8 @@ AVAILABLE_MODELS = [
|
|||||||
"cv2",
|
"cv2",
|
||||||
"manga",
|
"manga",
|
||||||
"sd2",
|
"sd2",
|
||||||
"paint_by_example"
|
"paint_by_example",
|
||||||
|
"instruct_pix2pix",
|
||||||
]
|
]
|
||||||
|
|
||||||
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||||
|
@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
|
|||||||
model_path = download_model(url_or_path)
|
model_path = download_model(url_or_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state_dict = torch.load(model_path, map_location='cpu')
|
state_dict = torch.load(model_path, map_location="cpu")
|
||||||
model.load_state_dict(state_dict, strict=True)
|
model.load_state_dict(state_dict, strict=True)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
logger.info(f"Load model from: {model_path}")
|
logger.info(f"Load model from: {model_path}")
|
||||||
@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
|||||||
return image_bytes
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
def load_img(img_bytes, gray: bool = False):
|
def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
|
||||||
|
with io.BytesIO() as output:
|
||||||
|
pil_img.save(output, format=ext, exif=exif, quality=95)
|
||||||
|
image_bytes = output.getvalue()
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
|
||||||
alpha_channel = None
|
alpha_channel = None
|
||||||
image = Image.open(io.BytesIO(img_bytes))
|
image = Image.open(io.BytesIO(img_bytes))
|
||||||
|
|
||||||
|
try:
|
||||||
|
if return_exif:
|
||||||
|
exif = image.getexif()
|
||||||
|
except:
|
||||||
|
exif = None
|
||||||
|
logger.error("Failed to extract exif from image")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if gray:
|
if gray:
|
||||||
image = image.convert('L')
|
image = image.convert("L")
|
||||||
np_img = np.array(image)
|
np_img = np.array(image)
|
||||||
else:
|
else:
|
||||||
if image.mode == 'RGBA':
|
if image.mode == "RGBA":
|
||||||
np_img = np.array(image)
|
np_img = np.array(image)
|
||||||
alpha_channel = np_img[:, :, -1]
|
alpha_channel = np_img[:, :, -1]
|
||||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||||
else:
|
else:
|
||||||
image = image.convert('RGB')
|
image = image.convert("RGB")
|
||||||
np_img = np.array(image)
|
np_img = np.array(image)
|
||||||
|
|
||||||
|
if return_exif:
|
||||||
|
return np_img, alpha_channel, exif
|
||||||
return np_img, alpha_channel
|
return np_img, alpha_channel
|
||||||
|
|
||||||
|
|
||||||
|
125
lama_cleaner/make_gif.py
Normal file
125
lama_cleaner/make_gif.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import io
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
|
||||||
|
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
|
||||||
|
if img.width > img.height:
|
||||||
|
w = size
|
||||||
|
h = int(img.height * size / img.width)
|
||||||
|
else:
|
||||||
|
h = size
|
||||||
|
w = int(img.width * size / img.height)
|
||||||
|
return img.resize((w, h), resample)
|
||||||
|
|
||||||
|
|
||||||
|
def cubic_bezier(p1, p2, duration: int, frames: int):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p1:
|
||||||
|
p2:
|
||||||
|
duration: Total duration of the curve
|
||||||
|
frames:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
x0, y0 = (0, 0)
|
||||||
|
x1, y1 = p1
|
||||||
|
x2, y2 = p2
|
||||||
|
x3, y3 = (1, 1)
|
||||||
|
|
||||||
|
def cal_y(t):
|
||||||
|
return math.pow(1 - t, 3) * y0 + \
|
||||||
|
3 * math.pow(1 - t, 2) * t * y1 + \
|
||||||
|
3 * (1 - t) * math.pow(t, 2) * y2 + \
|
||||||
|
math.pow(t, 3) * y3
|
||||||
|
|
||||||
|
def cal_x(t):
|
||||||
|
return math.pow(1 - t, 3) * x0 + \
|
||||||
|
3 * math.pow(1 - t, 2) * t * x1 + \
|
||||||
|
3 * (1 - t) * math.pow(t, 2) * x2 + \
|
||||||
|
math.pow(t, 3) * x3
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for t in range(0, 1 * frames, duration):
|
||||||
|
t = t / frames
|
||||||
|
res.append((cal_x(t), cal_y(t)))
|
||||||
|
|
||||||
|
res.append((1, 0))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def make_compare_gif(
|
||||||
|
clean_img: Image.Image,
|
||||||
|
src_img: Image.Image,
|
||||||
|
max_side_length: int = 600,
|
||||||
|
splitter_width: int = 5,
|
||||||
|
splitter_color=(255, 203, 0, int(255 * 0.73))
|
||||||
|
):
|
||||||
|
if clean_img.size != src_img.size:
|
||||||
|
clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
|
||||||
|
|
||||||
|
duration_per_frame = 20
|
||||||
|
num_frames = 50
|
||||||
|
# erase-in-out
|
||||||
|
cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames)
|
||||||
|
cubic_bezier_points.reverse()
|
||||||
|
|
||||||
|
max_side_length = min(max_side_length, max(clean_img.size))
|
||||||
|
|
||||||
|
src_img = keep_ratio_resize(src_img, max_side_length)
|
||||||
|
clean_img = keep_ratio_resize(clean_img, max_side_length)
|
||||||
|
width, height = src_img.size
|
||||||
|
|
||||||
|
# Generate images to make Gif from right to left
|
||||||
|
images = []
|
||||||
|
|
||||||
|
for i in range(num_frames):
|
||||||
|
new_frame = Image.new('RGB', (width, height))
|
||||||
|
new_frame.paste(clean_img, (0, 0))
|
||||||
|
|
||||||
|
left = int(cubic_bezier_points[i][0] * width)
|
||||||
|
cropped_src_img = src_img.crop((left, 0, width, height))
|
||||||
|
new_frame.paste(cropped_src_img, (left, 0, width, height))
|
||||||
|
if i != num_frames - 1:
|
||||||
|
# draw a yellow splitter on the edge of the cropped image
|
||||||
|
draw = ImageDraw.Draw(new_frame)
|
||||||
|
draw.line([(left, 0), (left, height)], width=splitter_width, fill=splitter_color)
|
||||||
|
images.append(new_frame)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
images.append(src_img)
|
||||||
|
|
||||||
|
cubic_bezier_points.reverse()
|
||||||
|
# Generate images to make Gif from left to right
|
||||||
|
for i in range(num_frames):
|
||||||
|
new_frame = Image.new('RGB', (width, height))
|
||||||
|
new_frame.paste(src_img, (0, 0))
|
||||||
|
|
||||||
|
right = int(cubic_bezier_points[i][0] * width)
|
||||||
|
cropped_src_img = clean_img.crop((0, 0, right, height))
|
||||||
|
new_frame.paste(cropped_src_img, (0, 0, right, height))
|
||||||
|
if i != num_frames - 1:
|
||||||
|
# draw a yellow splitter on the edge of the cropped image
|
||||||
|
draw = ImageDraw.Draw(new_frame)
|
||||||
|
draw.line([(right, 0), (right, height)], width=splitter_width, fill=splitter_color)
|
||||||
|
images.append(new_frame)
|
||||||
|
|
||||||
|
images.append(clean_img)
|
||||||
|
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
clean_img.save(
|
||||||
|
img_byte_arr,
|
||||||
|
format='GIF',
|
||||||
|
save_all=True,
|
||||||
|
include_color_table=True,
|
||||||
|
append_images=images,
|
||||||
|
optimize=False,
|
||||||
|
duration=duration_per_frame,
|
||||||
|
loop=0
|
||||||
|
)
|
||||||
|
return img_byte_arr.getvalue()
|
@ -245,3 +245,42 @@ class InpaintModel:
|
|||||||
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
||||||
|
|
||||||
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionInpaintModel(InpaintModel):
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, image, mask, config: Config):
|
||||||
|
"""
|
||||||
|
images: [H, W, C] RGB, not normalized
|
||||||
|
masks: [H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
# boxes = boxes_from_mask(mask)
|
||||||
|
if config.use_croper:
|
||||||
|
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||||
|
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||||
|
inpaint_result = image[:, :, ::-1]
|
||||||
|
inpaint_result[t:b, l:r, :] = crop_image
|
||||||
|
else:
|
||||||
|
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||||
|
|
||||||
|
return inpaint_result
|
||||||
|
|
||||||
|
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||||
|
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||||
|
origin_size = image.shape[:2]
|
||||||
|
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||||
|
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||||
|
logger.info(
|
||||||
|
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
||||||
|
)
|
||||||
|
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||||
|
# only paste masked area result
|
||||||
|
inpaint_result = cv2.resize(
|
||||||
|
inpaint_result,
|
||||||
|
(origin_size[1], origin_size[0]),
|
||||||
|
interpolation=cv2.INTER_CUBIC,
|
||||||
|
)
|
||||||
|
original_pixel_indices = mask < 127
|
||||||
|
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||||
|
return inpaint_result
|
||||||
|
83
lama_cleaner/model/instruct_pix2pix.py
Normal file
83
lama_cleaner/model/instruct_pix2pix.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
|
from lama_cleaner.model.utils import set_seed
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
|
||||||
|
class InstructPix2Pix(DiffusionInpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 512
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||||
|
fp16 = not kwargs.get('no_half', False)
|
||||||
|
|
||||||
|
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||||
|
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(dict(
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False
|
||||||
|
))
|
||||||
|
|
||||||
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||||
|
"timbrooks/instruct-pix2pix",
|
||||||
|
revision="fp16" if use_gpu and fp16 else "main",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
**model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model.enable_attention_slicing()
|
||||||
|
if kwargs.get('enable_xformers', False):
|
||||||
|
self.model.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||||
|
logger.info("Enable sequential cpu offload")
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: Config):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
|
||||||
|
"""
|
||||||
|
set_seed(config.sd_seed)
|
||||||
|
|
||||||
|
output = self.model(
|
||||||
|
image=PIL.Image.fromarray(image),
|
||||||
|
prompt=config.prompt,
|
||||||
|
negative_prompt=config.negative_prompt,
|
||||||
|
num_inference_steps=config.p2p_steps,
|
||||||
|
image_guidance_scale=config.p2p_image_guidance_scale,
|
||||||
|
guidance_scale=config.p2p_guidance_scale,
|
||||||
|
output_type="np.array",
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
||||||
|
|
||||||
|
#
|
||||||
|
# def forward_post_process(self, result, image, mask, config):
|
||||||
|
# if config.sd_match_histograms:
|
||||||
|
# result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
#
|
||||||
|
# if config.sd_mask_blur != 0:
|
||||||
|
# k = 2 * config.sd_mask_blur + 1
|
||||||
|
# mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||||
|
# return result, image, mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_downloaded() -> bool:
|
||||||
|
# model will be downloaded when app start, and can't switch in frontend settings
|
||||||
|
return True
|
@ -1,19 +1,16 @@
|
|||||||
import random
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import resize_max_size
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.utils import set_seed
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
|
||||||
class PaintByExample(InpaintModel):
|
class PaintByExample(DiffusionInpaintModel):
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
@ -53,11 +50,7 @@ class PaintByExample(InpaintModel):
|
|||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
seed = config.paint_by_example_seed
|
set_seed(config.paint_by_example_seed)
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
|
|
||||||
output = self.model(
|
output = self.model(
|
||||||
image=PIL.Image.fromarray(image),
|
image=PIL.Image.fromarray(image),
|
||||||
@ -71,42 +64,6 @@ class PaintByExample(InpaintModel):
|
|||||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
|
||||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
|
||||||
origin_size = image.shape[:2]
|
|
||||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
|
||||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
|
||||||
logger.info(
|
|
||||||
f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
|
|
||||||
)
|
|
||||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
|
||||||
# only paste masked area result
|
|
||||||
inpaint_result = cv2.resize(
|
|
||||||
inpaint_result,
|
|
||||||
(origin_size[1], origin_size[0]),
|
|
||||||
interpolation=cv2.INTER_CUBIC,
|
|
||||||
)
|
|
||||||
original_pixel_indices = mask < 127
|
|
||||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
|
||||||
return inpaint_result
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(self, image, mask, config: Config):
|
|
||||||
"""
|
|
||||||
images: [H, W, C] RGB, not normalized
|
|
||||||
masks: [H, W]
|
|
||||||
return: BGR IMAGE
|
|
||||||
"""
|
|
||||||
if config.use_croper:
|
|
||||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
|
||||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
|
||||||
inpaint_result = image[:, :, ::-1]
|
|
||||||
inpaint_result[t:b, l:r, :] = crop_image
|
|
||||||
else:
|
|
||||||
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
|
||||||
|
|
||||||
return inpaint_result
|
|
||||||
|
|
||||||
def forward_post_process(self, result, image, mask, config):
|
def forward_post_process(self, result, image, mask, config):
|
||||||
if config.paint_by_example_match_histograms:
|
if config.paint_by_example_match_histograms:
|
||||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
@ -4,20 +4,25 @@ import PIL.Image
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
|
from diffusers import (
|
||||||
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
|
PNDMScheduler,
|
||||||
|
DDIMScheduler,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import resize_max_size
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.utils import torch_gc, set_seed
|
||||||
from lama_cleaner.model.utils import torch_gc
|
|
||||||
from lama_cleaner.schema import Config, SDSampler
|
from lama_cleaner.schema import Config, SDSampler
|
||||||
|
|
||||||
|
|
||||||
class CPUTextEncoderWrapper:
|
class CPUTextEncoderWrapper:
|
||||||
def __init__(self, text_encoder, torch_dtype):
|
def __init__(self, text_encoder, torch_dtype):
|
||||||
self.config = text_encoder.config
|
self.config = text_encoder.config
|
||||||
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||||
self.torch_dtype = torch_dtype
|
self.torch_dtype = torch_dtype
|
||||||
del text_encoder
|
del text_encoder
|
||||||
@ -25,27 +30,40 @@ class CPUTextEncoderWrapper:
|
|||||||
|
|
||||||
def __call__(self, x, **kwargs):
|
def __call__(self, x, **kwargs):
|
||||||
input_device = x.device
|
input_device = x.device
|
||||||
return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)]
|
return [
|
||||||
|
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
|
||||||
|
.to(input_device)
|
||||||
|
.to(self.torch_dtype)
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class SD(InpaintModel):
|
class SD(DiffusionInpaintModel):
|
||||||
pad_mod = 8
|
pad_mod = 8
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||||
fp16 = not kwargs.get('no_half', False)
|
|
||||||
|
|
||||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
fp16 = not kwargs.get("no_half", False)
|
||||||
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
|
||||||
|
model_kwargs = {
|
||||||
|
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
|
||||||
|
}
|
||||||
|
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(dict(
|
model_kwargs.update(
|
||||||
safety_checker=None,
|
dict(
|
||||||
feature_extractor=None,
|
safety_checker=None,
|
||||||
requires_safety_checker=False
|
feature_extractor=None,
|
||||||
))
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
@ -58,40 +76,23 @@ class SD(InpaintModel):
|
|||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||||
if kwargs.get('enable_xformers', False):
|
if kwargs.get("enable_xformers", False):
|
||||||
self.model.enable_xformers_memory_efficient_attention()
|
self.model.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
if kwargs.get('cpu_offload', False) and use_gpu:
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
# TODO: gpu_id
|
# TODO: gpu_id
|
||||||
logger.info("Enable sequential cpu offload")
|
logger.info("Enable sequential cpu offload")
|
||||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
else:
|
else:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
if kwargs['sd_cpu_textencoder']:
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
|
||||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
|
||||||
origin_size = image.shape[:2]
|
|
||||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
|
||||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
|
||||||
logger.info(
|
|
||||||
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
|
||||||
)
|
|
||||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
|
||||||
# only paste masked area result
|
|
||||||
inpaint_result = cv2.resize(
|
|
||||||
inpaint_result,
|
|
||||||
(origin_size[1], origin_size[0]),
|
|
||||||
interpolation=cv2.INTER_CUBIC,
|
|
||||||
)
|
|
||||||
original_pixel_indices = mask < 127
|
|
||||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
|
||||||
return inpaint_result
|
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
@ -118,11 +119,7 @@ class SD(InpaintModel):
|
|||||||
|
|
||||||
self.model.scheduler = scheduler
|
self.model.scheduler = scheduler
|
||||||
|
|
||||||
seed = config.sd_seed
|
set_seed(config.sd_seed)
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
if config.sd_mask_blur != 0:
|
||||||
k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
@ -147,24 +144,6 @@ class SD(InpaintModel):
|
|||||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(self, image, mask, config: Config):
|
|
||||||
"""
|
|
||||||
images: [H, W, C] RGB, not normalized
|
|
||||||
masks: [H, W]
|
|
||||||
return: BGR IMAGE
|
|
||||||
"""
|
|
||||||
# boxes = boxes_from_mask(mask)
|
|
||||||
if config.use_croper:
|
|
||||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
|
||||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
|
||||||
inpaint_result = image[:, :, ::-1]
|
|
||||||
inpaint_result[t:b, l:r, :] = crop_image
|
|
||||||
else:
|
|
||||||
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
|
||||||
|
|
||||||
return inpaint_result
|
|
||||||
|
|
||||||
def forward_post_process(self, result, image, mask, config):
|
def forward_post_process(self, result, image, mask, config):
|
||||||
if config.sd_match_histograms:
|
if config.sd_match_histograms:
|
||||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
import random
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -713,3 +714,10 @@ def torch_gc():
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
@ -7,13 +7,14 @@ from lama_cleaner.model.ldm import LDM
|
|||||||
from lama_cleaner.model.manga import Manga
|
from lama_cleaner.model.manga import Manga
|
||||||
from lama_cleaner.model.mat import MAT
|
from lama_cleaner.model.mat import MAT
|
||||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||||
|
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
||||||
from lama_cleaner.model.sd import SD15, SD2
|
from lama_cleaner.model.sd import SD15, SD2
|
||||||
from lama_cleaner.model.zits import ZITS
|
from lama_cleaner.model.zits import ZITS
|
||||||
from lama_cleaner.model.opencv2 import OpenCV2
|
from lama_cleaner.model.opencv2 import OpenCV2
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
||||||
"sd2": SD2, "paint_by_example": PaintByExample}
|
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
|
@ -5,9 +5,25 @@ from pathlib import Path
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import AVAILABLE_MODELS, NO_HALF_HELP, CPU_OFFLOAD_HELP, DISABLE_NSFW_HELP, \
|
from lama_cleaner.const import (
|
||||||
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, AVAILABLE_DEVICES, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, \
|
AVAILABLE_MODELS,
|
||||||
OUTPUT_DIR_HELP, INPUT_HELP, GUI_HELP, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR
|
NO_HALF_HELP,
|
||||||
|
CPU_OFFLOAD_HELP,
|
||||||
|
DISABLE_NSFW_HELP,
|
||||||
|
SD_CPU_TEXTENCODER_HELP,
|
||||||
|
LOCAL_FILES_ONLY_HELP,
|
||||||
|
AVAILABLE_DEVICES,
|
||||||
|
ENABLE_XFORMERS_HELP,
|
||||||
|
MODEL_DIR_HELP,
|
||||||
|
OUTPUT_DIR_HELP,
|
||||||
|
INPUT_HELP,
|
||||||
|
GUI_HELP,
|
||||||
|
DEFAULT_DEVICE,
|
||||||
|
NO_GUI_AUTO_CLOSE_HELP,
|
||||||
|
DEFAULT_MODEL_DIR,
|
||||||
|
DEFAULT_MODEL,
|
||||||
|
MPS_SUPPORT_MODELS,
|
||||||
|
)
|
||||||
from lama_cleaner.runtime import dump_environment_info
|
from lama_cleaner.runtime import dump_environment_info
|
||||||
|
|
||||||
|
|
||||||
@ -16,22 +32,40 @@ def parse_args():
|
|||||||
parser.add_argument("--host", default="127.0.0.1")
|
parser.add_argument("--host", default="127.0.0.1")
|
||||||
parser.add_argument("--port", default=8080, type=int)
|
parser.add_argument("--port", default=8080, type=int)
|
||||||
|
|
||||||
parser.add_argument("--config-installer", action="store_true",
|
parser.add_argument(
|
||||||
help="Open config web page, mainly for windows installer")
|
"--config-installer",
|
||||||
parser.add_argument("--load-installer-config", action="store_true",
|
action="store_true",
|
||||||
help="Load all cmd args from installer config file")
|
help="Open config web page, mainly for windows installer",
|
||||||
parser.add_argument("--installer-config", default=None, help="Config file for windows installer")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-installer-config",
|
||||||
|
action="store_true",
|
||||||
|
help="Load all cmd args from installer config file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--installer-config", default=None, help="Config file for windows installer"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--model", default="lama", choices=AVAILABLE_MODELS)
|
parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS)
|
||||||
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
|
parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
|
||||||
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
|
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
|
||||||
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
|
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
|
||||||
parser.add_argument("--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP)
|
parser.add_argument(
|
||||||
parser.add_argument("--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP)
|
"--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
|
||||||
parser.add_argument("--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP)
|
)
|
||||||
parser.add_argument("--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES)
|
parser.add_argument(
|
||||||
|
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
|
||||||
|
)
|
||||||
parser.add_argument("--gui", action="store_true", help=GUI_HELP)
|
parser.add_argument("--gui", action="store_true", help=GUI_HELP)
|
||||||
parser.add_argument("--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP)
|
parser.add_argument(
|
||||||
|
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gui-size",
|
"--gui-size",
|
||||||
default=[1600, 1000],
|
default=[1600, 1000],
|
||||||
@ -41,8 +75,14 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
parser.add_argument("--input", type=str, default=None, help=INPUT_HELP)
|
parser.add_argument("--input", type=str, default=None, help=INPUT_HELP)
|
||||||
parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP)
|
parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP)
|
||||||
parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP)
|
parser.add_argument(
|
||||||
parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend")
|
"--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-model-switch",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable model switch in frontend",
|
||||||
|
)
|
||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
|
|
||||||
# useless args
|
# useless args
|
||||||
@ -64,7 +104,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-enable-xformers",
|
"--sd-enable-xformers",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -74,14 +114,18 @@ def parse_args():
|
|||||||
|
|
||||||
if args.config_installer:
|
if args.config_installer:
|
||||||
if args.installer_config is None:
|
if args.installer_config is None:
|
||||||
parser.error(f"args.config_installer==True, must set args.installer_config to store config file")
|
parser.error(
|
||||||
|
f"args.config_installer==True, must set args.installer_config to store config file"
|
||||||
|
)
|
||||||
from lama_cleaner.web_config import main
|
from lama_cleaner.web_config import main
|
||||||
|
|
||||||
logger.info(f"Launching installer web config page")
|
logger.info(f"Launching installer web config page")
|
||||||
main(args.installer_config)
|
main(args.installer_config)
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
if args.load_installer_config:
|
if args.load_installer_config:
|
||||||
from lama_cleaner.web_config import load_config
|
from lama_cleaner.web_config import load_config
|
||||||
|
|
||||||
if args.installer_config and not os.path.exists(args.installer_config):
|
if args.installer_config and not os.path.exists(args.installer_config):
|
||||||
parser.error(f"args.installer_config={args.installer_config} not exists")
|
parser.error(f"args.installer_config={args.installer_config} not exists")
|
||||||
|
|
||||||
@ -93,9 +137,17 @@ def parse_args():
|
|||||||
|
|
||||||
if args.device == "cuda":
|
if args.device == "cuda":
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if torch.cuda.is_available() is False:
|
if torch.cuda.is_available() is False:
|
||||||
parser.error(
|
parser.error(
|
||||||
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.device == "mps":
|
||||||
|
if args.model not in MPS_SUPPORT_MODELS:
|
||||||
|
parser.error(
|
||||||
|
f"mps only support: {MPS_SUPPORT_MODELS}, but got {args.model}"
|
||||||
|
)
|
||||||
|
|
||||||
if args.model_dir and args.model_dir is not None:
|
if args.model_dir and args.model_dir is not None:
|
||||||
if os.path.isfile(args.model_dir):
|
if os.path.isfile(args.model_dir):
|
||||||
@ -115,7 +167,9 @@ def parse_args():
|
|||||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||||
else:
|
else:
|
||||||
if args.output_dir is None:
|
if args.output_dir is None:
|
||||||
parser.error(f"invalid --input: {args.input} is a directory, --output-dir is required")
|
parser.error(
|
||||||
|
f"invalid --input: {args.input} is a directory, --output-dir is required"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
@ -123,6 +177,8 @@ def parse_args():
|
|||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
else:
|
else:
|
||||||
if not output_dir.is_dir():
|
if not output_dir.is_dir():
|
||||||
parser.error(f"invalid --output-dir: {output_dir} is not a directory")
|
parser.error(
|
||||||
|
f"invalid --output-dir: {output_dir} is not a directory"
|
||||||
|
)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
@ -88,3 +88,8 @@ class Config(BaseModel):
|
|||||||
paint_by_example_seed: int = 42
|
paint_by_example_seed: int = 42
|
||||||
paint_by_example_match_histograms: bool = False
|
paint_by_example_match_histograms: bool = False
|
||||||
paint_by_example_example_image: Image = None
|
paint_by_example_example_image: Image = None
|
||||||
|
|
||||||
|
# InstructPix2Pix
|
||||||
|
p2p_steps: int = 50
|
||||||
|
p2p_image_guidance_scale: float = 1.5
|
||||||
|
p2p_guidance_scale: float = 7.5
|
||||||
|
@ -19,6 +19,7 @@ from loguru import logger
|
|||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler
|
||||||
|
|
||||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||||
|
from lama_cleaner.make_gif import make_compare_gif
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
from lama_cleaner.file_manager import FileManager
|
from lama_cleaner.file_manager import FileManager
|
||||||
@ -31,7 +32,15 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
|
from flask import (
|
||||||
|
Flask,
|
||||||
|
request,
|
||||||
|
send_file,
|
||||||
|
cli,
|
||||||
|
make_response,
|
||||||
|
send_from_directory,
|
||||||
|
jsonify,
|
||||||
|
)
|
||||||
|
|
||||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
# Disable ability for Flask to display warning about using a development server in a production environment.
|
||||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
||||||
@ -42,6 +51,7 @@ from lama_cleaner.helper import (
|
|||||||
load_img,
|
load_img,
|
||||||
numpy_to_bytes,
|
numpy_to_bytes,
|
||||||
resize_max_size,
|
resize_max_size,
|
||||||
|
pil_to_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||||
@ -93,6 +103,25 @@ def diffuser_callback(i, t, latents):
|
|||||||
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/make_gif", methods=["POST"])
|
||||||
|
def make_gif():
|
||||||
|
input = request.files
|
||||||
|
filename = request.form["filename"]
|
||||||
|
origin_image_bytes = input["origin_img"].read()
|
||||||
|
clean_image_bytes = input["clean_img"].read()
|
||||||
|
origin_image, _ = load_img(origin_image_bytes)
|
||||||
|
clean_image, _ = load_img(clean_image_bytes)
|
||||||
|
gif_bytes = make_compare_gif(
|
||||||
|
Image.fromarray(origin_image), Image.fromarray(clean_image)
|
||||||
|
)
|
||||||
|
return send_file(
|
||||||
|
io.BytesIO(gif_bytes),
|
||||||
|
mimetype="image/gif",
|
||||||
|
as_attachment=True,
|
||||||
|
attachment_filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/save_image", methods=["POST"])
|
@app.route("/save_image", methods=["POST"])
|
||||||
def save_image():
|
def save_image():
|
||||||
# all image in output directory
|
# all image in output directory
|
||||||
@ -100,12 +129,12 @@ def save_image():
|
|||||||
origin_image_bytes = input["image"].read() # RGB
|
origin_image_bytes = input["image"].read() # RGB
|
||||||
image, _ = load_img(origin_image_bytes)
|
image, _ = load_img(origin_image_bytes)
|
||||||
thumb.save_to_output_directory(image, request.form["filename"])
|
thumb.save_to_output_directory(image, request.form["filename"])
|
||||||
return 'ok', 200
|
return "ok", 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/medias/<tab>")
|
@app.route("/medias/<tab>")
|
||||||
def medias(tab):
|
def medias(tab):
|
||||||
if tab == 'image':
|
if tab == "image":
|
||||||
response = make_response(jsonify(thumb.media_names), 200)
|
response = make_response(jsonify(thumb.media_names), 200)
|
||||||
else:
|
else:
|
||||||
response = make_response(jsonify(thumb.output_media_names), 200)
|
response = make_response(jsonify(thumb.output_media_names), 200)
|
||||||
@ -116,18 +145,18 @@ def medias(tab):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.route('/media/<tab>/<filename>')
|
@app.route("/media/<tab>/<filename>")
|
||||||
def media_file(tab, filename):
|
def media_file(tab, filename):
|
||||||
if tab == 'image':
|
if tab == "image":
|
||||||
return send_from_directory(thumb.root_directory, filename)
|
return send_from_directory(thumb.root_directory, filename)
|
||||||
return send_from_directory(thumb.output_dir, filename)
|
return send_from_directory(thumb.output_dir, filename)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/media_thumbnail/<tab>/<filename>')
|
@app.route("/media_thumbnail/<tab>/<filename>")
|
||||||
def media_thumbnail_file(tab, filename):
|
def media_thumbnail_file(tab, filename):
|
||||||
args = request.args
|
args = request.args
|
||||||
width = args.get('width')
|
width = args.get("width")
|
||||||
height = args.get('height')
|
height = args.get("height")
|
||||||
if width is None and height is None:
|
if width is None and height is None:
|
||||||
width = 256
|
width = 256
|
||||||
if width:
|
if width:
|
||||||
@ -136,9 +165,11 @@ def media_thumbnail_file(tab, filename):
|
|||||||
height = int(float(height))
|
height = int(float(height))
|
||||||
|
|
||||||
directory = thumb.root_directory
|
directory = thumb.root_directory
|
||||||
if tab == 'output':
|
if tab == "output":
|
||||||
directory = thumb.output_dir
|
directory = thumb.output_dir
|
||||||
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
|
thumb_filename, (width, height) = thumb.get_thumbnail(
|
||||||
|
directory, filename, width, height
|
||||||
|
)
|
||||||
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
||||||
|
|
||||||
response = make_response(send_file(thumb_filepath))
|
response = make_response(send_file(thumb_filepath))
|
||||||
@ -152,13 +183,16 @@ def process():
|
|||||||
input = request.files
|
input = request.files
|
||||||
# RGB
|
# RGB
|
||||||
origin_image_bytes = input["image"].read()
|
origin_image_bytes = input["image"].read()
|
||||||
image, alpha_channel = load_img(origin_image_bytes)
|
image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
|
||||||
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||||
|
|
||||||
if image.shape[:2] != mask.shape[:2]:
|
if image.shape[:2] != mask.shape[:2]:
|
||||||
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
return (
|
||||||
|
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
original_shape = image.shape
|
original_shape = image.shape
|
||||||
interpolation = cv2.INTER_CUBIC
|
interpolation = cv2.INTER_CUBIC
|
||||||
@ -171,7 +205,9 @@ def process():
|
|||||||
size_limit = int(size_limit)
|
size_limit = int(size_limit)
|
||||||
|
|
||||||
if "paintByExampleImage" in input:
|
if "paintByExampleImage" in input:
|
||||||
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
|
paint_by_example_example_image, _ = load_img(
|
||||||
|
input["paintByExampleImage"].read()
|
||||||
|
)
|
||||||
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
||||||
else:
|
else:
|
||||||
paint_by_example_example_image = None
|
paint_by_example_example_image = None
|
||||||
@ -200,13 +236,16 @@ def process():
|
|||||||
sd_seed=form["sdSeed"],
|
sd_seed=form["sdSeed"],
|
||||||
sd_match_histograms=form["sdMatchHistograms"],
|
sd_match_histograms=form["sdMatchHistograms"],
|
||||||
cv2_flag=form["cv2Flag"],
|
cv2_flag=form["cv2Flag"],
|
||||||
cv2_radius=form['cv2Radius'],
|
cv2_radius=form["cv2Radius"],
|
||||||
paint_by_example_steps=form["paintByExampleSteps"],
|
paint_by_example_steps=form["paintByExampleSteps"],
|
||||||
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
|
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
|
||||||
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
|
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
|
||||||
paint_by_example_seed=form["paintByExampleSeed"],
|
paint_by_example_seed=form["paintByExampleSeed"],
|
||||||
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
|
paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
|
||||||
paint_by_example_example_image=paint_by_example_example_image,
|
paint_by_example_example_image=paint_by_example_example_image,
|
||||||
|
p2p_steps=form["p2pSteps"],
|
||||||
|
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
|
||||||
|
p2p_guidance_scale=form["p2pGuidanceScale"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.sd_seed == -1:
|
if config.sd_seed == -1:
|
||||||
@ -235,6 +274,7 @@ def process():
|
|||||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||||
if alpha_channel is not None:
|
if alpha_channel is not None:
|
||||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||||
alpha_channel = cv2.resize(
|
alpha_channel = cv2.resize(
|
||||||
@ -246,9 +286,15 @@ def process():
|
|||||||
|
|
||||||
ext = get_image_ext(origin_image_bytes)
|
ext = get_image_ext(origin_image_bytes)
|
||||||
|
|
||||||
|
if exif is not None:
|
||||||
|
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
|
||||||
|
else:
|
||||||
|
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
|
||||||
|
|
||||||
response = make_response(
|
response = make_response(
|
||||||
send_file(
|
send_file(
|
||||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||||
|
bytes_io,
|
||||||
mimetype=f"image/{ext}",
|
mimetype=f"image/{ext}",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -261,7 +307,7 @@ def interactive_seg():
|
|||||||
input = request.files
|
input = request.files
|
||||||
origin_image_bytes = input["image"].read() # RGB
|
origin_image_bytes = input["image"].read() # RGB
|
||||||
image, _ = load_img(origin_image_bytes)
|
image, _ = load_img(origin_image_bytes)
|
||||||
if 'mask' in input:
|
if "mask" in input:
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
@ -269,14 +315,16 @@ def interactive_seg():
|
|||||||
_clicks = json.loads(request.form["clicks"])
|
_clicks = json.loads(request.form["clicks"])
|
||||||
clicks = []
|
clicks = []
|
||||||
for i, click in enumerate(_clicks):
|
for i, click in enumerate(_clicks):
|
||||||
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
|
clicks.append(
|
||||||
|
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
||||||
|
)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||||
response = make_response(
|
response = make_response(
|
||||||
send_file(
|
send_file(
|
||||||
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
|
io.BytesIO(numpy_to_bytes(new_mask, "png")),
|
||||||
mimetype=f"image/png",
|
mimetype=f"image/png",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -290,13 +338,13 @@ def current_model():
|
|||||||
|
|
||||||
@app.route("/is_disable_model_switch")
|
@app.route("/is_disable_model_switch")
|
||||||
def get_is_disable_model_switch():
|
def get_is_disable_model_switch():
|
||||||
res = 'true' if is_disable_model_switch else 'false'
|
res = "true" if is_disable_model_switch else "false"
|
||||||
return res, 200
|
return res, 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/is_enable_file_manager")
|
@app.route("/is_enable_file_manager")
|
||||||
def get_is_enable_file_manager():
|
def get_is_enable_file_manager():
|
||||||
res = 'true' if is_enable_file_manager else 'false'
|
res = "true" if is_enable_file_manager else "false"
|
||||||
return res, 200
|
return res, 200
|
||||||
|
|
||||||
|
|
||||||
@ -365,14 +413,18 @@ def main(args):
|
|||||||
is_disable_model_switch = args.disable_model_switch
|
is_disable_model_switch = args.disable_model_switch
|
||||||
is_desktop = args.gui
|
is_desktop = args.gui
|
||||||
if is_disable_model_switch:
|
if is_disable_model_switch:
|
||||||
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable")
|
logger.info(
|
||||||
|
f"Start with --disable-model-switch, model switch on frontend is disable"
|
||||||
|
)
|
||||||
|
|
||||||
if args.input and os.path.isdir(args.input):
|
if args.input and os.path.isdir(args.input):
|
||||||
logger.info(f"Initialize file manager")
|
logger.info(f"Initialize file manager")
|
||||||
thumb = FileManager(app)
|
thumb = FileManager(app)
|
||||||
is_enable_file_manager = True
|
is_enable_file_manager = True
|
||||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails')
|
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||||
|
args.output_dir, "lama_cleaner_thumbnails"
|
||||||
|
)
|
||||||
thumb.output_dir = Path(args.output_dir)
|
thumb.output_dir = Path(args.output_dir)
|
||||||
# thumb.start()
|
# thumb.start()
|
||||||
# try:
|
# try:
|
||||||
@ -408,8 +460,12 @@ def main(args):
|
|||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
|
||||||
ui = FlaskUI(
|
ui = FlaskUI(
|
||||||
app, width=app_width, height=app_height, host=args.host, port=args.port,
|
app,
|
||||||
close_server_on_exit=not args.no_gui_auto_close
|
width=app_width,
|
||||||
|
height=app_height,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
close_server_on_exit=not args.no_gui_auto_close,
|
||||||
)
|
)
|
||||||
ui.run()
|
ui.run()
|
||||||
else:
|
else:
|
||||||
|
62
lama_cleaner/tests/test_instruct_pix2pix.py
Normal file
62
lama_cleaner/tests/test_instruct_pix2pix.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||||
|
from lama_cleaner.schema import HDStrategy
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / 'result'
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||||
|
@pytest.mark.parametrize("cpu_offload", [False, True])
|
||||||
|
def test_instruct_pix2pix(disable_nsfw, cpu_offload):
|
||||||
|
sd_steps = 50 if device == 'cuda' else 1
|
||||||
|
model = ModelManager(name="instruct_pix2pix",
|
||||||
|
device=torch.device(device),
|
||||||
|
hf_access_token="",
|
||||||
|
sd_run_local=True,
|
||||||
|
disable_nsfw=disable_nsfw,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
cpu_offload=cpu_offload)
|
||||||
|
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1)
|
||||||
|
|
||||||
|
name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"instruct_pix2pix_{name}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=1.3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("disable_nsfw", [False])
|
||||||
|
@pytest.mark.parametrize("cpu_offload", [False])
|
||||||
|
def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
|
||||||
|
sd_steps = 50 if device == 'cuda' else 1
|
||||||
|
model = ModelManager(name="instruct_pix2pix",
|
||||||
|
device=torch.device(device),
|
||||||
|
hf_access_token="",
|
||||||
|
sd_run_local=True,
|
||||||
|
disable_nsfw=disable_nsfw,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
cpu_offload=cpu_offload)
|
||||||
|
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)
|
||||||
|
|
||||||
|
name = f"snow"
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"instruct_pix2pix_{name}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
42
lama_cleaner/tests/test_save_exif.py
Normal file
42
lama_cleaner/tests/test_save_exif.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lama_cleaner.helper import pil_to_bytes
|
||||||
|
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
png_img_p = current_dir / "image.png"
|
||||||
|
jpg_img_p = current_dir / "bunny.jpeg"
|
||||||
|
|
||||||
|
|
||||||
|
def print_exif(exif):
|
||||||
|
for k, v in exif.items():
|
||||||
|
print(f"{k}: {v}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_png():
|
||||||
|
img = Image.open(png_img_p)
|
||||||
|
exif = img.getexif()
|
||||||
|
print_exif(exif)
|
||||||
|
|
||||||
|
pil_bytes = pil_to_bytes(img, ext="png", exif=exif)
|
||||||
|
|
||||||
|
res_img = Image.open(io.BytesIO(pil_bytes))
|
||||||
|
res_exif = res_img.getexif()
|
||||||
|
|
||||||
|
assert dict(exif) == dict(res_exif)
|
||||||
|
|
||||||
|
|
||||||
|
def test_jpeg():
|
||||||
|
img = Image.open(jpg_img_p)
|
||||||
|
exif = img.getexif()
|
||||||
|
print_exif(exif)
|
||||||
|
|
||||||
|
pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif)
|
||||||
|
|
||||||
|
res_img = Image.open(io.BytesIO(pil_bytes))
|
||||||
|
res_exif = res_img.getexif()
|
||||||
|
|
||||||
|
assert dict(exif) == dict(res_exif)
|
@ -8,33 +8,37 @@ from lama_cleaner.schema import HDStrategy, SDSampler
|
|||||||
from lama_cleaner.tests.test_model import get_config, assert_equal
|
from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
save_dir = current_dir / 'result'
|
save_dir = current_dir / "result"
|
||||||
save_dir.mkdir(exist_ok=True, parents=True)
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
@pytest.mark.parametrize("sd_device", ["cuda"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
||||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||||
def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
def test_runway_sd_1_5_ddim(
|
||||||
|
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
|
||||||
|
):
|
||||||
def callback(i, t, latents):
|
def callback(i, t, latents):
|
||||||
print(f"sd_step_{i}")
|
pass
|
||||||
|
|
||||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
sd_steps = 50 if sd_device == "cuda" else 1
|
||||||
model = ModelManager(name="sd1.5",
|
model = ModelManager(
|
||||||
device=torch.device(sd_device),
|
name="sd1.5",
|
||||||
hf_access_token="",
|
device=torch.device(sd_device),
|
||||||
sd_run_local=True,
|
hf_access_token="",
|
||||||
disable_nsfw=disable_nsfw,
|
sd_run_local=True,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
disable_nsfw=disable_nsfw,
|
||||||
callback=callback)
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
callback=callback,
|
||||||
|
)
|
||||||
|
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||||
@ -45,31 +49,35 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
|
|||||||
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.3
|
fx=1.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
@pytest.mark.parametrize("sd_device", ["cuda"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a])
|
@pytest.mark.parametrize(
|
||||||
|
"sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [False])
|
@pytest.mark.parametrize("cpu_textencoder", [False])
|
||||||
@pytest.mark.parametrize("disable_nsfw", [True])
|
@pytest.mark.parametrize("disable_nsfw", [True])
|
||||||
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||||
def callback(i, t, latents):
|
def callback(i, t, latents):
|
||||||
print(f"sd_step_{i}")
|
print(f"sd_step_{i}")
|
||||||
|
|
||||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
sd_steps = 50 if sd_device == "cuda" else 1
|
||||||
model = ModelManager(name="sd1.5",
|
model = ModelManager(
|
||||||
device=torch.device(sd_device),
|
name="sd1.5",
|
||||||
hf_access_token="",
|
device=torch.device(sd_device),
|
||||||
sd_run_local=True,
|
hf_access_token="",
|
||||||
disable_nsfw=disable_nsfw,
|
sd_run_local=True,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
disable_nsfw=disable_nsfw,
|
||||||
callback=callback)
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
callback=callback,
|
||||||
|
)
|
||||||
|
cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||||
@ -80,35 +88,37 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
|
|||||||
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.3
|
fx=1.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
@pytest.mark.parametrize("sd_device", ["cuda"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||||
def callback(i, t, latents):
|
def callback(i, t, latents):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
sd_steps = 50 if sd_device == "cuda" else 1
|
||||||
model = ModelManager(name="sd1.5",
|
model = ModelManager(
|
||||||
device=torch.device(sd_device),
|
name="sd1.5",
|
||||||
hf_access_token="",
|
device=torch.device(sd_device),
|
||||||
sd_run_local=True,
|
hf_access_token="",
|
||||||
disable_nsfw=False,
|
sd_run_local=True,
|
||||||
sd_cpu_textencoder=False,
|
disable_nsfw=False,
|
||||||
callback=callback)
|
sd_cpu_textencoder=False,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
strategy,
|
strategy,
|
||||||
sd_steps=sd_steps,
|
sd_steps=sd_steps,
|
||||||
prompt='Face of a fox, high resolution, sitting on a park bench',
|
prompt="Face of a fox, high resolution, sitting on a park bench",
|
||||||
negative_prompt='orange, yellow, small',
|
negative_prompt="orange, yellow, small",
|
||||||
sd_sampler=sampler,
|
sd_sampler=sampler,
|
||||||
sd_match_histograms=True
|
sd_match_histograms=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
name = f"{sampler}_negative_prompt"
|
name = f"{sampler}_negative_prompt"
|
||||||
@ -119,27 +129,33 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
|||||||
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1
|
fx=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
@pytest.mark.parametrize("sd_device", ["cuda"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [False])
|
@pytest.mark.parametrize("cpu_textencoder", [False])
|
||||||
@pytest.mark.parametrize("disable_nsfw", [False])
|
@pytest.mark.parametrize("disable_nsfw", [False])
|
||||||
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
def test_runway_sd_1_5_sd_scale(
|
||||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
|
||||||
|
):
|
||||||
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
sd_steps = 50 if sd_device == "cuda" else 1
|
||||||
model = ModelManager(name="sd1.5",
|
model = ModelManager(
|
||||||
device=torch.device(sd_device),
|
name="sd1.5",
|
||||||
hf_access_token="",
|
device=torch.device(sd_device),
|
||||||
sd_run_local=True,
|
hf_access_token="",
|
||||||
disable_nsfw=disable_nsfw,
|
sd_run_local=True,
|
||||||
sd_cpu_textencoder=cpu_textencoder)
|
disable_nsfw=disable_nsfw,
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
|
||||||
|
)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||||
@ -150,26 +166,30 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
|
|||||||
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
|
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.3
|
fx=1.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
@pytest.mark.parametrize("sd_device", ["cuda"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||||
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
||||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
sd_steps = 50 if sd_device == "cuda" else 1
|
||||||
model = ModelManager(name="sd1.5",
|
model = ModelManager(
|
||||||
device=torch.device(sd_device),
|
name="sd1.5",
|
||||||
hf_access_token="",
|
device=torch.device(sd_device),
|
||||||
sd_run_local=True,
|
hf_access_token="",
|
||||||
disable_nsfw=True,
|
sd_run_local=True,
|
||||||
sd_cpu_textencoder=False,
|
disable_nsfw=True,
|
||||||
cpu_offload=True)
|
sd_cpu_textencoder=False,
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
cpu_offload=True,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
|
||||||
|
)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}"
|
name = f"device_{sd_device}_{sampler}"
|
||||||
@ -182,27 +202,3 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
|||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cpu'])
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
|
||||||
def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
|
|
||||||
model = ModelManager(name="sd1.5",
|
|
||||||
device=torch.device(sd_device),
|
|
||||||
hf_access_token="",
|
|
||||||
sd_run_local=True,
|
|
||||||
disable_nsfw=False,
|
|
||||||
sd_cpu_textencoder=False,
|
|
||||||
cpu_offload=True)
|
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
name = f"device_{sd_device}_{sampler}"
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload_cpu_device.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
)
|
|
||||||
|
@ -6,9 +6,24 @@ import gradio as gr
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from lama_cleaner.const import AVAILABLE_MODELS, AVAILABLE_DEVICES, CPU_OFFLOAD_HELP, NO_HALF_HELP, DISABLE_NSFW_HELP, \
|
from lama_cleaner.const import (
|
||||||
SD_CPU_TEXTENCODER_HELP, LOCAL_FILES_ONLY_HELP, ENABLE_XFORMERS_HELP, MODEL_DIR_HELP, OUTPUT_DIR_HELP, INPUT_HELP, \
|
AVAILABLE_MODELS,
|
||||||
GUI_HELP, DEFAULT_MODEL, DEFAULT_DEVICE, NO_GUI_AUTO_CLOSE_HELP, DEFAULT_MODEL_DIR
|
AVAILABLE_DEVICES,
|
||||||
|
CPU_OFFLOAD_HELP,
|
||||||
|
NO_HALF_HELP,
|
||||||
|
DISABLE_NSFW_HELP,
|
||||||
|
SD_CPU_TEXTENCODER_HELP,
|
||||||
|
LOCAL_FILES_ONLY_HELP,
|
||||||
|
ENABLE_XFORMERS_HELP,
|
||||||
|
MODEL_DIR_HELP,
|
||||||
|
OUTPUT_DIR_HELP,
|
||||||
|
INPUT_HELP,
|
||||||
|
GUI_HELP,
|
||||||
|
DEFAULT_MODEL,
|
||||||
|
DEFAULT_DEVICE,
|
||||||
|
NO_GUI_AUTO_CLOSE_HELP,
|
||||||
|
DEFAULT_MODEL_DIR, MPS_SUPPORT_MODELS,
|
||||||
|
)
|
||||||
|
|
||||||
_config_file = None
|
_config_file = None
|
||||||
|
|
||||||
@ -33,16 +48,28 @@ class Config(BaseModel):
|
|||||||
|
|
||||||
def load_config(installer_config: str):
|
def load_config(installer_config: str):
|
||||||
if os.path.exists(installer_config):
|
if os.path.exists(installer_config):
|
||||||
with open(installer_config, "r", encoding='utf-8') as f:
|
with open(installer_config, "r", encoding="utf-8") as f:
|
||||||
return Config(**json.load(f))
|
return Config(**json.load(f))
|
||||||
else:
|
else:
|
||||||
return Config()
|
return Config()
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(
|
||||||
host, port, model, device, gui, no_gui_auto_close, no_half, cpu_offload,
|
host,
|
||||||
disable_nsfw, sd_cpu_textencoder, enable_xformers, local_files_only,
|
port,
|
||||||
model_dir, input, output_dir
|
model,
|
||||||
|
device,
|
||||||
|
gui,
|
||||||
|
no_gui_auto_close,
|
||||||
|
no_half,
|
||||||
|
cpu_offload,
|
||||||
|
disable_nsfw,
|
||||||
|
sd_cpu_textencoder,
|
||||||
|
enable_xformers,
|
||||||
|
local_files_only,
|
||||||
|
model_dir,
|
||||||
|
input,
|
||||||
|
output_dir,
|
||||||
):
|
):
|
||||||
config = Config(**locals())
|
config = Config(**locals())
|
||||||
print(config)
|
print(config)
|
||||||
@ -63,6 +90,7 @@ def save_config(
|
|||||||
def close_server(*args):
|
def close_server(*args):
|
||||||
# TODO: make close both browser and server works
|
# TODO: make close both browser and server works
|
||||||
import os, signal
|
import os, signal
|
||||||
|
|
||||||
pid = os.getpid()
|
pid = os.getpid()
|
||||||
os.kill(pid, signal.SIGUSR1)
|
os.kill(pid, signal.SIGUSR1)
|
||||||
|
|
||||||
@ -86,33 +114,53 @@ def main(config_file: str):
|
|||||||
port = gr.Number(init_config.port, label="Port", precision=0)
|
port = gr.Number(init_config.port, label="Port", precision=0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
|
model = gr.Radio(AVAILABLE_MODELS, label="Model", value=init_config.model)
|
||||||
device = gr.Radio(AVAILABLE_DEVICES, label="Device", value=init_config.device)
|
device = gr.Radio(
|
||||||
|
AVAILABLE_DEVICES, label=f"Device(mps supports {MPS_SUPPORT_MODELS})", value=init_config.device
|
||||||
|
)
|
||||||
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
|
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
|
||||||
no_gui_auto_close = gr.Checkbox(init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}")
|
no_gui_auto_close = gr.Checkbox(
|
||||||
|
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
|
||||||
|
)
|
||||||
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
|
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
|
||||||
cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
|
cpu_offload = gr.Checkbox(init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}")
|
||||||
disable_nsfw = gr.Checkbox(init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}")
|
disable_nsfw = gr.Checkbox(
|
||||||
sd_cpu_textencoder = gr.Checkbox(init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}")
|
init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
|
||||||
enable_xformers = gr.Checkbox(init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}")
|
)
|
||||||
local_files_only = gr.Checkbox(init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}")
|
sd_cpu_textencoder = gr.Checkbox(
|
||||||
|
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
|
||||||
|
)
|
||||||
|
enable_xformers = gr.Checkbox(
|
||||||
|
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
|
||||||
|
)
|
||||||
|
local_files_only = gr.Checkbox(
|
||||||
|
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
|
||||||
|
)
|
||||||
model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
|
model_dir = gr.Textbox(init_config.model_dir, label=f"{MODEL_DIR_HELP}")
|
||||||
input = gr.Textbox(init_config.input, label=f"Input file or directory. {INPUT_HELP}")
|
input = gr.Textbox(
|
||||||
output_dir = gr.Textbox(init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}")
|
init_config.input, label=f"Input file or directory. {INPUT_HELP}"
|
||||||
save_btn.click(save_config, [
|
)
|
||||||
host,
|
output_dir = gr.Textbox(
|
||||||
port,
|
init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}"
|
||||||
model,
|
)
|
||||||
device,
|
save_btn.click(
|
||||||
gui,
|
save_config,
|
||||||
no_gui_auto_close,
|
[
|
||||||
no_half,
|
host,
|
||||||
cpu_offload,
|
port,
|
||||||
disable_nsfw,
|
model,
|
||||||
sd_cpu_textencoder,
|
device,
|
||||||
enable_xformers,
|
gui,
|
||||||
local_files_only,
|
no_gui_auto_close,
|
||||||
model_dir,
|
no_half,
|
||||||
input,
|
cpu_offload,
|
||||||
output_dir,
|
disable_nsfw,
|
||||||
], message)
|
sd_cpu_textencoder,
|
||||||
|
enable_xformers,
|
||||||
|
local_files_only,
|
||||||
|
model_dir,
|
||||||
|
input,
|
||||||
|
output_dir,
|
||||||
|
],
|
||||||
|
message,
|
||||||
|
)
|
||||||
demo.launch(inbrowser=True, show_api=False)
|
demo.launch(inbrowser=True, show_api=False)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
torch>=1.9.0
|
torch>=1.9.0
|
||||||
opencv-python
|
opencv-python
|
||||||
flask_cors
|
flask_cors
|
||||||
|
Jinja2==2.11.3
|
||||||
flask==1.1.4
|
flask==1.1.4
|
||||||
flaskwebgui==0.3.5
|
flaskwebgui==0.3.5
|
||||||
tqdm
|
tqdm
|
||||||
@ -11,7 +12,8 @@ pytest
|
|||||||
yacs
|
yacs
|
||||||
markupsafe==2.0.1
|
markupsafe==2.0.1
|
||||||
scikit-image==0.19.3
|
scikit-image==0.19.3
|
||||||
diffusers[torch]==0.10.2
|
diffusers[torch]==0.12.1
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
watchdog==2.2.1
|
watchdog==2.2.1
|
||||||
gradio
|
gradio
|
||||||
|
piexif==1.1.3
|
2
setup.py
2
setup.py
@ -21,7 +21,7 @@ def load_requirements():
|
|||||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="lama-cleaner",
|
name="lama-cleaner",
|
||||||
version="0.34.0",
|
version="0.35.0",
|
||||||
author="PanicByte",
|
author="PanicByte",
|
||||||
author_email="cwq1913@gmail.com",
|
author_email="cwq1913@gmail.com",
|
||||||
description="Image inpainting tool powered by SOTA AI Model",
|
description="Image inpainting tool powered by SOTA AI Model",
|
||||||
|
Loading…
Reference in New Issue
Block a user