diff --git a/README.md b/README.md index 972b202..8bd1ad3 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ 1. [ZITS](https://github.com/DQiaole/ZITS_inpainting) 1. [MAT](https://github.com/fenglinglwb/MAT) 1. [FcF](https://github.com/SHI-Labs/FcF-Inpainting) + 1. [SD](https://github.com/CompVis/stable-diffusion) - Support CPU & GPU - Various inpainting [strategy](#inpainting-strategy) - Run as a desktop APP @@ -54,15 +55,16 @@ lama-cleaner --model=lama --device=cpu --port=8080 Available arguments: -| Name | Description | Default | -| ---------- | ---------------------------------------------------------------- | -------- | -| --model | lama/ldm/zits. See details in [Inpaint Model](#inpainting-model) | lama | -| --device | cuda or cpu | cuda | -| --port | Port for backend flask web server | 8080 | -| --gui | Launch lama-cleaner as a desktop application | | -| --gui_size | Set the window size for the application | 1200 900 | -| --input | Path to image you want to load by default | None | -| --debug | Enable debug mode for flask web server | | +| Name | Description | Default | +| ----------------- | -------------------------------------------------------------------------------------------------------- | -------- | +| --model | lama/ldm/zits/mat/fcf/sd. See details in [Inpaint Model](#inpainting-model) | lama | +| --hf_access_token | The stable-diffusion (sd) model needs a Hugging Face access token: https://huggingface.co/docs/hub/security-tokens | | +| --device | cuda or cpu | cuda | +| --port | Port for backend flask web server | 8080 | +| --gui | Launch lama-cleaner as a desktop application | | +| --gui_size | Set the window size for the application | 1200 900 | +| --input | Path to image you want to load by default | None | +| --debug | Enable debug mode for flask web server | | ## Inpainting Model @@ -73,6 +75,7 @@ Available arguments: | ZITS | :+1: Better holistic structures compared with previous methods
:neutral_face: Wireframe module is **very** slow on CPU | `Wireframe`: Enable edge and line detect | | MAT | TODO | | | FcF | :+1: Better structure and texture generation
:neutral_face: Only support fixed size (512x512) input | | +| SD | :+1: SOTA text-to-image diffusion model | | ### LaMa vs LDM diff --git a/lama_cleaner/app/package.json b/lama_cleaner/app/package.json index 5e7db27..17e38ee 100644 --- a/lama_cleaner/app/package.json +++ b/lama_cleaner/app/package.json @@ -6,6 +6,7 @@ "dependencies": { "@heroicons/react": "^1.0.4", "@radix-ui/react-dialog": "0.1.8-rc.25", + "@radix-ui/react-popover": "^1.0.0", "@radix-ui/react-select": "0.1.2-rc.27", "@radix-ui/react-switch": "^0.1.5", "@radix-ui/react-toast": "^0.1.1", @@ -20,14 +21,17 @@ "@types/react-dom": "^17.0.9", "cross-env": "7.x", "lodash": "^4.17.21", + "mitt": "^3.0.0", "nanoid": "^4.0.0", "npm-run-all": "4.x", "react": "^17.0.2", "react-dom": "^17.0.2", + "react-hotkeys-hook": "^3.4.7", "react-scripts": "4.0.3", "react-use": "^17.3.1", "react-zoom-pan-pinch": "^2.1.3", "recoil": "^0.6.1", + "socket.io-client": "^4.5.2", "typescript": "4.x" }, "scripts": { diff --git a/lama_cleaner/app/src/App.tsx b/lama_cleaner/app/src/App.tsx index f5b849d..7df91a1 100644 --- a/lama_cleaner/app/src/App.tsx +++ b/lama_cleaner/app/src/App.tsx @@ -1,5 +1,4 @@ import React, { useEffect, useMemo } from 'react' -import { useKeyPressEvent } from 'react-use' import { useRecoilState } from 'recoil' import { nanoid } from 'nanoid' import useInputImage from './hooks/useInputImage' @@ -9,6 +8,7 @@ import Workspace from './components/Workspace' import { fileState } from './store/Atoms' import { keepGUIAlive } from './utils' import Header from './components/Header/Header' +import useHotKey from './hooks/useHotkey' // Keeping GUI Window Open keepGUIAlive() @@ -24,11 +24,15 @@ function App() { }, [userInputImage, setFile]) // Dark Mode Hotkey - useKeyPressEvent('D', ev => { - ev?.preventDefault() - const newTheme = theme === 'light' ? 'dark' : 'light' - setTheme(newTheme) - }) + useHotKey( + 'shift+d', + () => { + const newTheme = theme === 'light' ? 
'dark' : 'light' + setTheme(newTheme) + }, + {}, + [theme] + ) useEffect(() => { document.body.setAttribute('data-theme', theme) diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index e82f855..38bb884 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -1,4 +1,4 @@ -import { Settings } from '../store/Atoms' +import { Rect, Settings } from '../store/Atoms' import { dataURItoBlob } from '../utils' export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` @@ -7,6 +7,8 @@ export default async function inpaint( imageFile: File, maskBase64: string, settings: Settings, + croperRect: Rect, + prompt?: string, sizeLimit?: string ) { // 1080, 2000, Original @@ -30,6 +32,18 @@ export default async function inpaint( hdSettings.hdStrategyResizeLimit.toString() ) + fd.append('prompt', prompt === undefined ? '' : prompt) + fd.append('croperX', croperRect.x.toString()) + fd.append('croperY', croperRect.y.toString()) + fd.append('croperHeight', croperRect.height.toString()) + fd.append('croperWidth', croperRect.width.toString()) + fd.append('useCroper', settings.showCroper ? 'true' : 'false') + fd.append('sdStrength', settings.sdStrength.toString()) + fd.append('sdSteps', settings.sdSteps.toString()) + fd.append('sdGuidanceScale', settings.sdGuidanceScale.toString()) + fd.append('sdSampler', settings.sdSampler.toString()) + fd.append('sdSeed', settings.sdSeedFixed ? 
settings.sdSeed.toString() : '-1') + if (sizeLimit === undefined) { fd.append('sizeLimit', '1080') } else { diff --git a/lama_cleaner/app/src/components/Croper/Croper.scss b/lama_cleaner/app/src/components/Croper/Croper.scss new file mode 100644 index 0000000..f1642be --- /dev/null +++ b/lama_cleaner/app/src/components/Croper/Croper.scss @@ -0,0 +1,126 @@ +@use 'sass:math'; + +$drag-handle-shortside: 12px; +$drag-handle-longside: 40px; + +$half-handle-shortside: math.div($drag-handle-shortside, 2); +$half-handle-longside: math.div($drag-handle-longside, 2); + +.crop-border { + outline-color: var(--yellow-accent); + outline-style: dashed; +} + +.info-bar { + position: absolute; + pointer-events: auto; + font-size: 1rem; + padding: 0.2rem 0.8rem; + display: flex; + align-items: center; + justify-content: center; + gap: 12px; + color: var(--text-color); + background-color: var(--page-bg); + border-radius: 9999px; + + border: var(--editor-toolkit-panel-border); + box-shadow: 0 0 0 1px #0000001a, 0 3px 16px #00000014, 0 2px 6px 1px #00000017; + + &:hover { + cursor: move; + } +} + +.croper-wrapper { + position: absolute; + height: 100%; + width: 100%; + z-index: 2; + overflow: hidden; + pointer-events: none; +} + +.croper { + position: relative; + top: 0; + bottom: 0; + left: 0; + right: 0; + z-index: 2; + pointer-events: none; + + display: flex; + flex-direction: column; + align-items: center; + + box-shadow: 0 0 0 9999px rgba(0, 0, 0, 0.5); +} + +.drag-handle { + width: $drag-handle-shortside; + height: $drag-handle-shortside; + z-index: 4; + position: absolute; + display: block; + content: ''; + border: 2px solid var(--yellow-accent); + background-color: var(--yellow-accent-light); + pointer-events: auto; + + &:hover { + background-color: var(--yellow-accent); + } +} + +.ord-topleft { + cursor: nw-resize; + top: (-$half-handle-shortside)-1px; + left: (-$half-handle-shortside)-1px; +} + +.ord-topright { + cursor: ne-resize; + top: -($half-handle-shortside)-1px; + 
right: -($half-handle-shortside)-1px; +} + +.ord-bottomright { + cursor: se-resize; + bottom: -($half-handle-shortside)-1px; + right: -($half-handle-shortside)-1px; +} + +.ord-bottomleft { + cursor: sw-resize; + bottom: -($half-handle-shortside)-1px; + left: -($half-handle-shortside)-1px; +} + +.ord-top, +.ord-bottom { + left: calc(50% - $half-handle-shortside); + cursor: ns-resize; +} + +.ord-top { + top: (-$half-handle-shortside)-1px; +} + +.ord-bottom { + bottom: -($half-handle-shortside)-1px; +} + +.ord-left, +.ord-right { + top: calc(50% - $half-handle-shortside); + cursor: ew-resize; +} + +.ord-left { + left: (-$half-handle-shortside)-1px; +} + +.ord-right { + right: -($half-handle-shortside)-1px; +} diff --git a/lama_cleaner/app/src/components/Croper/Croper.tsx b/lama_cleaner/app/src/components/Croper/Croper.tsx new file mode 100644 index 0000000..496d909 --- /dev/null +++ b/lama_cleaner/app/src/components/Croper/Croper.tsx @@ -0,0 +1,326 @@ +import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/outline' +import React, { useEffect, useState } from 'react' +import { useRecoilState, useRecoilValue } from 'recoil' +import { + croperHeight, + croperWidth, + croperX, + croperY, + isInpaintingState, +} from '../../store/Atoms' + +const DOC_MOVE_OPTS = { capture: true, passive: false } + +const DRAG_HANDLE_BORDER = 2 +const DRAG_HANDLE_SHORT = 12 +const DRAG_HANDLE_LONG = 40 + +interface EVData { + initX: number + initY: number + initHeight: number + initWidth: number + startResizeX: number + startResizeY: number + ord: string // top/right/bottom/left +} + +interface Props { + maxHeight: number + maxWidth: number + scale: number + minHeight: number + minWidth: number +} + +const Croper = (props: Props) => { + const { minHeight, minWidth, maxHeight, maxWidth, scale } = props + const [x, setX] = useRecoilState(croperX) + const [y, setY] = useRecoilState(croperY) + const [height, setHeight] = useRecoilState(croperHeight) + const [width, setWidth] = 
useRecoilState(croperWidth) + const isInpainting = useRecoilValue(isInpaintingState) + + const [isResizing, setIsResizing] = useState(false) + const [isMoving, setIsMoving] = useState(false) + + useEffect(() => { + setX(Math.round((maxWidth - 512) / 2)) + setY(Math.round((maxHeight - 512) / 2)) + }, [maxHeight, maxWidth, minHeight, minWidth]) + + const [evData, setEVData] = useState({ + initX: 0, + initY: 0, + initHeight: 0, + initWidth: 0, + startResizeX: 0, + startResizeY: 0, + ord: 'top', + }) + + const onDragFocus = () => { + console.log('focus') + } + + const checkTopBottomLimit = (newY: number, newHeight: number) => { + if (newY > 0 && newHeight > minHeight && newY + newHeight <= maxHeight) { + return true + } + return false + } + + const checkLeftRightLimit = (newX: number, newWidth: number) => { + if (newX > 0 && newWidth > minWidth && newX + newWidth <= maxWidth) { + return true + } + return false + } + + const onPointerMove = (e: PointerEvent) => { + if (isInpainting) { + return + } + const curX = e.clientX + const curY = e.clientY + if (isResizing) { + switch (evData.ord) { + case 'top': { + // TODO: add the four corner handles and the drag bar handle + const offset = Math.round((curY - evData.startResizeY) / scale) + const newHeight = evData.initHeight - offset + const newY = evData.initY + offset + if (checkTopBottomLimit(newY, newHeight)) { + setHeight(newHeight) + setY(newY) + } + break + } + case 'right': { + const offset = Math.round((curX - evData.startResizeX) / scale) + const newWidth = evData.initWidth + offset + if (checkLeftRightLimit(evData.initX, newWidth)) { + setWidth(newWidth) + } + break + } + case 'bottom': { + const offset = Math.round((curY - evData.startResizeY) / scale) + const newHeight = evData.initHeight + offset + if (checkTopBottomLimit(evData.initY, newHeight)) { + setHeight(newHeight) + } + break + } + case 'left': { + const offset = Math.round((curX - evData.startResizeX) / scale) + const newWidth = evData.initWidth - offset + const newX = 
evData.initX + offset + if (checkLeftRightLimit(newX, newWidth)) { + setWidth(newWidth) + setX(newX) + } + break + } + + default: + break + } + } + + if (isMoving) { + const offsetX = Math.round((curX - evData.startResizeX) / scale) + const offsetY = Math.round((curY - evData.startResizeY) / scale) + const newX = evData.initX + offsetX + const newY = evData.initY + offsetY + if ( + checkLeftRightLimit(newX, evData.initWidth) && + checkTopBottomLimit(newY, evData.initHeight) + ) { + setX(newX) + setY(newY) + } + } + } + + const onPointerDone = (e: PointerEvent) => { + if (isResizing) { + setIsResizing(false) + } + + if (isMoving) { + setIsMoving(false) + } + } + + useEffect(() => { + if (isResizing || isMoving) { + document.addEventListener('pointermove', onPointerMove, DOC_MOVE_OPTS) + document.addEventListener('pointerup', onPointerDone, DOC_MOVE_OPTS) + document.addEventListener('pointercancel', onPointerDone, DOC_MOVE_OPTS) + return () => { + document.removeEventListener( + 'pointermove', + onPointerMove, + DOC_MOVE_OPTS + ) + document.removeEventListener('pointerup', onPointerDone, DOC_MOVE_OPTS) + document.removeEventListener( + 'pointercancel', + onPointerDone, + DOC_MOVE_OPTS + ) + } + } + }, [isResizing, isMoving, width, height, evData]) + + const onCropPointerDown = (e: React.PointerEvent) => { + const { ord } = (e.target as HTMLElement).dataset + if (ord) { + setIsResizing(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord, + }) + } + } + + const createCropSelection = () => { + return ( +
+
+ +
+ +
+ +
+ +
+
+
+
+
+ ) + } + + const onInfoBarPointerDown = (e: React.PointerEvent) => { + setIsMoving(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord: '', + }) + } + + const createInfoBar = () => { + return ( +
+
+ {width} x {height} +
+
+ ) + } + + const createBorder = () => { + return ( +
+ ) + } + + return ( +
+
+ {createBorder()} + {createInfoBar()} + {createCropSelection()} +
+
+ ) +} + +export default Croper diff --git a/lama_cleaner/app/src/components/Editor/Editor.scss b/lama_cleaner/app/src/components/Editor/Editor.scss index 4b00593..6f6a485 100644 --- a/lama_cleaner/app/src/components/Editor/Editor.scss +++ b/lama_cleaner/app/src/components/Editor/Editor.scss @@ -55,7 +55,7 @@ position: fixed; bottom: 0.5rem; border-radius: 3rem; - padding: 1rem 3rem; + padding: 0.6rem 3rem; display: grid; grid-template-areas: 'toolkit-size-selector toolkit-brush-slider toolkit-btns'; column-gap: 2rem; diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx index 4c07990..154658b 100644 --- a/lama_cleaner/app/src/components/Editor/Editor.tsx +++ b/lama_cleaner/app/src/components/Editor/Editor.tsx @@ -22,7 +22,6 @@ import Button from '../shared/Button' import Slider from './Slider' import SizeSelector from './SizeSelector' import { - dataURItoBlob, downloadImage, isMidClick, isRightClick, @@ -30,7 +29,18 @@ import { srcToFile, useImage, } from '../../utils' -import { settingState, toastState } from '../../store/Atoms' +import { + croperState, + isInpaintingState, + isSDState, + propmtState, + runManuallyState, + settingState, + toastState, +} from '../../store/Atoms' +import useHotKey from '../../hooks/useHotkey' +import Croper from '../Croper/Croper' +import emitter, { EVENT_PROMPT } from '../../event' const TOOLBAR_SIZE = 200 const BRUSH_COLOR = '#ffcc00bb' @@ -74,8 +84,14 @@ function mouseXY(ev: SyntheticEvent) { export default function Editor(props: EditorProps) { const { file } = props + const promptVal = useRecoilValue(propmtState) const settings = useRecoilValue(settingState) + const croperRect = useRecoilValue(croperState) const [toastVal, setToastState] = useRecoilState(toastState) + const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) + const runMannually = useRecoilValue(runManuallyState) + const isSD = useRecoilValue(isSDState) + const [brushSize, 
setBrushSize] = useState(40) const [original, isOriginalLoaded] = useImage(file) const [renders, setRenders] = useState([]) @@ -90,7 +106,6 @@ export default function Editor(props: EditorProps) { const [showRefBrush, setShowRefBrush] = useState(false) const [isPanning, setIsPanning] = useState(false) const [showOriginal, setShowOriginal] = useState(false) - const [isInpaintingLoading, setIsInpaintingLoading] = useState(false) const [scale, setScale] = useState(1) const [panned, setPanned] = useState(false) const [minScale, setMinScale] = useState(1.0) @@ -130,83 +145,28 @@ export default function Editor(props: EditorProps) { [context, original] ) - const drawLinesOnMask = (_lineGroups: LineGroup[]) => { - if (!context?.canvas.width || !context?.canvas.height) { - throw new Error('canvas has invalid size') - } - maskCanvas.width = context?.canvas.width - maskCanvas.height = context?.canvas.height - const ctx = maskCanvas.getContext('2d') - if (!ctx) { - throw new Error('could not retrieve mask canvas') - } - - _lineGroups.forEach(lineGroup => { - drawLines(ctx, lineGroup, 'white') - }) - } - - const runInpainting = async () => { - if (!hadDrawSomething()) { - return - } - - const newLineGroups = [...lineGroups, curLineGroup] - setCurLineGroup([]) - setIsDraging(false) - setIsInpaintingLoading(true) - if (settings.graduallyInpainting) { - drawLinesOnMask([curLineGroup]) - } else { - drawLinesOnMask(newLineGroups) - } - - let targetFile = file - if (settings.graduallyInpainting === true && renders.length > 0) { - console.info('gradually inpainting on last result') - const lastRender = renders[renders.length - 1] - targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type) - } - - try { - const res = await inpaint( - targetFile, - maskCanvas.toDataURL(), - settings, - sizeLimit.toString() - ) - if (!res) { - throw new Error('empty response') + const drawLinesOnMask = useCallback( + (_lineGroups: LineGroup[]) => { + if (!context?.canvas.width || 
!context?.canvas.height) { + throw new Error('canvas has invalid size') + } + maskCanvas.width = context?.canvas.width + maskCanvas.height = context?.canvas.height + const ctx = maskCanvas.getContext('2d') + if (!ctx) { + throw new Error('could not retrieve mask canvas') } - const newRender = new Image() - await loadImage(newRender, res) - const newRenders = [...renders, newRender] - setRenders(newRenders) - draw(newRender, []) - // Only append new LineGroup after inpainting success - setLineGroups(newLineGroups) - // clear redo stack - resetRedoState() - } catch (e: any) { - setToastState({ - open: true, - desc: e.message ? e.message : e.toString(), - state: 'error', - duration: 2000, + _lineGroups.forEach(lineGroup => { + drawLines(ctx, lineGroup, 'white') }) - drawOnCurrentRender([]) - } - setIsInpaintingLoading(false) - } + }, + [context, maskCanvas] + ) - const hadDrawSomething = () => { + const hadDrawSomething = useCallback(() => { return curLineGroup.length !== 0 - } - - const hadRunInpainting = () => { - return renders.length !== 0 - } + }, [curLineGroup]) const drawOnCurrentRender = useCallback( (lineGroup: LineGroup) => { @@ -219,8 +179,107 @@ export default function Editor(props: EditorProps) { [original, renders, draw] ) + const runInpainting = useCallback( + async (prompt?: string) => { + console.log('runInpainting') + if (!hadDrawSomething()) { + return + } + console.log(prompt) + + const newLineGroups = [...lineGroups, curLineGroup] + setCurLineGroup([]) + setIsDraging(false) + setIsInpainting(true) + if (settings.graduallyInpainting) { + drawLinesOnMask([curLineGroup]) + } else { + drawLinesOnMask(newLineGroups) + } + + let targetFile = file + if (settings.graduallyInpainting === true && renders.length > 0) { + console.info('gradually inpainting on last result') + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type + ) + } + + try { + const res = await inpaint( + 
targetFile, + maskCanvas.toDataURL(), + settings, + croperRect, + prompt, + sizeLimit.toString() + ) + if (!res) { + throw new Error('empty response') + } + const newRender = new Image() + await loadImage(newRender, res) + const newRenders = [...renders, newRender] + setRenders(newRenders) + draw(newRender, []) + // Only append new LineGroup after inpainting success + setLineGroups(newLineGroups) + + // clear redo stack + resetRedoState() + } catch (e: any) { + setToastState({ + open: true, + desc: e.message ? e.message : e.toString(), + state: 'error', + duration: 4000, + }) + drawOnCurrentRender([]) + } + setIsInpainting(false) + }, + [ + lineGroups, + curLineGroup, + maskCanvas, + settings.graduallyInpainting, + settings, + croperRect, + sizeLimit, + promptVal, + drawOnCurrentRender, + hadDrawSomething, + drawLinesOnMask, + ] + ) + + useEffect(() => { + emitter.on(EVENT_PROMPT, () => { + if (hadDrawSomething()) { + runInpainting(promptVal) + } else { + setToastState({ + open: true, + desc: 'Please draw mask on picture', + state: 'error', + duration: 1500, + }) + } + }) + return () => { + emitter.off(EVENT_PROMPT) + } + }, [hadDrawSomething, runInpainting, prompt]) + + const hadRunInpainting = () => { + return renders.length !== 0 + } + const handleMultiStrokeKeyDown = () => { - if (isInpaintingLoading) { + if (isInpainting) { return } setIsMultiStrokeKeyPressed(true) @@ -230,13 +289,13 @@ export default function Editor(props: EditorProps) { if (!isMultiStrokeKeyPressed) { return } - if (isInpaintingLoading) { + if (isInpainting) { return } setIsMultiStrokeKeyPressed(false) - if (!settings.runInpaintingManually) { + if (!runMannually) { runInpainting() } } @@ -246,7 +305,7 @@ export default function Editor(props: EditorProps) { } useKey(predicate, handleMultiStrokeKeyup, { event: 'keyup' }, [ - isInpaintingLoading, + isInpainting, isMultiStrokeKeyPressed, hadDrawSomething, ]) @@ -257,7 +316,7 @@ export default function Editor(props: EditorProps) { { event: 
'keydown', }, - [isInpaintingLoading] + [isInpainting] ) // Draw once the original image is loaded @@ -341,7 +400,7 @@ export default function Editor(props: EditorProps) { }, [windowSize, resetZoom]) const handleEscPressed = () => { - if (isInpaintingLoading) { + if (isInpainting) { return } if (isDraging || isMultiStrokeKeyPressed) { @@ -361,7 +420,7 @@ export default function Editor(props: EditorProps) { }, [ isDraging, - isInpaintingLoading, + isInpainting, isMultiStrokeKeyPressed, resetZoom, drawOnCurrentRender, @@ -404,7 +463,7 @@ export default function Editor(props: EditorProps) { if (!canvas) { return } - if (isInpaintingLoading) { + if (isInpainting) { return } if (!isDraging) { @@ -416,13 +475,29 @@ export default function Editor(props: EditorProps) { return } - if (settings.runInpaintingManually) { + if (runMannually) { setIsDraging(false) } else { runInpainting() } } + const isOutsideCroper = (clickPnt: { x: number; y: number }) => { + if (clickPnt.x < croperRect.x) { + return true + } + if (clickPnt.y < croperRect.y) { + return true + } + if (clickPnt.x > croperRect.x + croperRect.width) { + return true + } + if (clickPnt.y > croperRect.y + croperRect.height) { + return true + } + return false + } + const onMouseDown = (ev: SyntheticEvent) => { if (isPanning) { return @@ -434,7 +509,7 @@ export default function Editor(props: EditorProps) { if (!canvas) { return } - if (isInpaintingLoading) { + if (isInpainting) { return } @@ -447,10 +522,14 @@ export default function Editor(props: EditorProps) { return } + if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) { + return + } + setIsDraging(true) let lineGroup: LineGroup = [] - if (isMultiStrokeKeyPressed || settings.runInpaintingManually) { + if (isMultiStrokeKeyPressed || runMannually) { lineGroup = [...curLineGroup] } lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] }) @@ -501,7 +580,7 @@ export default function Editor(props: EditorProps) { }, [draw, renders, redoRenders, 
redoLineGroups, lineGroups, original]) const undo = () => { - if (settings.runInpaintingManually && curLineGroup.length !== 0) { + if (runMannually && curLineGroup.length !== 0) { undoStroke() } else { undoRender() @@ -527,14 +606,14 @@ export default function Editor(props: EditorProps) { useKey(undoPredicate, undo, undefined, [undoStroke, undoRender]) const disableUndo = () => { - if (isInpaintingLoading) { + if (isInpainting) { return true } if (renders.length > 0) { return false } - if (settings.runInpaintingManually) { + if (runMannually) { if (curLineGroup.length === 0) { return true } @@ -575,7 +654,7 @@ export default function Editor(props: EditorProps) { }, [draw, renders, redoRenders, redoLineGroups, lineGroups, original]) const redo = () => { - if (settings.runInpaintingManually && redoCurLines.length !== 0) { + if (runMannually && redoCurLines.length !== 0) { redoStroke() } else { redoRender() @@ -603,14 +682,14 @@ export default function Editor(props: EditorProps) { useKey(redoPredicate, redo, undefined, [redoStroke, redoRender]) const disableRedo = () => { - if (isInpaintingLoading) { + if (isInpainting) { return true } if (redoRenders.length > 0) { return false } - if (settings.runInpaintingManually) { + if (runMannually) { if (redoCurLines.length === 0) { return true } @@ -688,7 +767,7 @@ export default function Editor(props: EditorProps) { }, [showBrush, isPanning]) // Standard Hotkeys for Brush Size - useKeyPressEvent('[', () => { + useHotKey('[', () => { setBrushSize(currentBrushSize => { if (currentBrushSize > 10) { return currentBrushSize - 10 @@ -700,18 +779,23 @@ export default function Editor(props: EditorProps) { }) }) - useKeyPressEvent(']', () => { + useHotKey(']', () => { setBrushSize(currentBrushSize => { return currentBrushSize + 10 }) }) // Manual Inpainting Hotkey - useKeyPressEvent('R', () => { - if (settings.runInpaintingManually && hadDrawSomething()) { - runInpainting() - } - }) + useHotKey( + 'shift+r', + () => { + if 
(runMannually && hadDrawSomething()) { + runInpainting() + } + }, + {}, + [runMannually] + ) // Toggle clean/zoom tool on spacebar. useKeyPressEvent( @@ -792,7 +876,7 @@ export default function Editor(props: EditorProps) { }} >
+ + {settings.showCroper ? ( + + ) : ( + <> + )} - {showBrush && !isInpaintingLoading && !isPanning && ( + {showBrush && !isInpainting && !isPanning && (
)} @@ -867,11 +963,15 @@ export default function Editor(props: EditorProps) { )}
- + {isSD ? ( + <> + ) : ( + + )} } - disabled={!hadDrawSomething() || isInpaintingLoading} + disabled={!hadDrawSomething() || isInpainting} onClick={() => { - if (!isInpaintingLoading && hadDrawSomething()) { + if (!isInpainting && hadDrawSomething()) { runInpainting() } }} diff --git a/lama_cleaner/app/src/components/Header/Header.scss b/lama_cleaner/app/src/components/Header/Header.scss index cd5909e..b8c887a 100644 --- a/lama_cleaner/app/src/components/Header/Header.scss +++ b/lama_cleaner/app/src/components/Header/Header.scss @@ -1,6 +1,6 @@ header { height: 60px; - padding: 1rem 2rem; + padding: 1rem 1.5rem; position: absolute; top: 0; display: flex; @@ -31,4 +31,4 @@ header { align-items: center; gap: 6px; justify-self: end; -} \ No newline at end of file +} diff --git a/lama_cleaner/app/src/components/Header/Header.tsx b/lama_cleaner/app/src/components/Header/Header.tsx index 4d62259..d839d01 100644 --- a/lama_cleaner/app/src/components/Header/Header.tsx +++ b/lama_cleaner/app/src/components/Header/Header.tsx @@ -1,17 +1,19 @@ import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline' import React, { useState } from 'react' -import { useRecoilState } from 'recoil' -import { fileState } from '../../store/Atoms' +import { useRecoilState, useRecoilValue } from 'recoil' +import { fileState, isSDState } from '../../store/Atoms' import Button from '../shared/Button' import Shortcuts from '../Shortcuts/Shortcuts' import useResolution from '../../hooks/useResolution' import { ThemeChanger } from './ThemeChanger' import SettingIcon from '../Settings/SettingIcon' +import PromptInput from './PromptInput' const Header = () => { const [file, setFile] = useRecoilState(fileState) const resolution = useResolution() const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`) + const isSD = useRecoilValue(isSDState) const renderHeader = () => { return ( @@ -37,6 +39,8 @@ const Header = () => {
+ {isSD && file ? : <>} +
{file && ( diff --git a/lama_cleaner/app/src/components/Header/PromptInput.scss b/lama_cleaner/app/src/components/Header/PromptInput.scss new file mode 100644 index 0000000..1f8e748 --- /dev/null +++ b/lama_cleaner/app/src/components/Header/PromptInput.scss @@ -0,0 +1,18 @@ +.prompt-wrapper { + display: flex; + gap: 12px; +} + +.prompt-wrapper input { + all: unset; + border-width: 0; + border-radius: 0.5rem; + min-width: 600px; + padding: 0 0.8rem; + outline: 1px solid var(--border-color); + + &:focus-visible { + border-width: 0; + outline: 1px solid var(--yellow-accent); + } +} diff --git a/lama_cleaner/app/src/components/Header/PromptInput.tsx b/lama_cleaner/app/src/components/Header/PromptInput.tsx new file mode 100644 index 0000000..7d0d3a5 --- /dev/null +++ b/lama_cleaner/app/src/components/Header/PromptInput.tsx @@ -0,0 +1,44 @@ +import React, { FormEvent, useState } from 'react' +import { useRecoilState } from 'recoil' +import emitter, { EVENT_PROMPT } from '../../event' +import { appState, propmtState } from '../../store/Atoms' +import Button from '../shared/Button' +import TextInput from '../shared/Input' + +// TODO: show progress in input +const PromptInput = () => { + const [app, setAppState] = useRecoilState(appState) + const [prompt, setPrompt] = useRecoilState(propmtState) + + const handleOnInput = (evt: FormEvent) => { + evt.preventDefault() + evt.stopPropagation() + const target = evt.target as HTMLInputElement + setPrompt(target.value) + } + + const handleRepaintClick = () => { + if (prompt.length !== 0 && !app.isInpainting) { + emitter.emit(EVENT_PROMPT) + } + } + + return ( +
+ + +
+ ) +} + +export default PromptInput diff --git a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx index 386202c..19aeea0 100644 --- a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx +++ b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx @@ -1,6 +1,6 @@ import React, { ReactNode } from 'react' import { useRecoilState } from 'recoil' -import { AIModel, settingState } from '../../store/Atoms' +import { AIModel, SDSampler, settingState } from '../../store/Atoms' import Selector from '../shared/Selector' import { Switch, SwitchThumb } from '../shared/Switch' import Tooltip from '../shared/Tooltip' @@ -145,6 +145,8 @@ function ModelSettingBlock() { return undefined case AIModel.FCF: return renderFCFModelDesc() + case AIModel.SD14: + return undefined default: return <> } @@ -182,6 +184,12 @@ function ModelSettingBlock() { 'https://arxiv.org/abs/2208.03382', 'https://github.com/SHI-Labs/FcF-Inpainting' ) + case AIModel.SD14: + return renderModelDesc( + 'Stable Diffusion', + 'https://ommer-lab.com/research/latent-diffusion-models/', + 'https://github.com/CompVis/stable-diffusion' + ) default: return <> } diff --git a/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx index f18e17e..a71aa8f 100644 --- a/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx +++ b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx @@ -4,14 +4,28 @@ import SettingBlock from './SettingBlock' interface NumberInputSettingProps { title: string + allowFloat?: boolean desc?: string value: string suffix?: string + width?: number + widthUnit?: string + disable?: boolean onValue: (val: string) => void } function NumberInputSetting(props: NumberInputSettingProps) { - const { title, desc, value, suffix, onValue } = props + const { + title, + allowFloat, + desc, + value, + suffix, + onValue, + 
width, + widthUnit, + disable, + } = props return ( {suffix && {suffix}} @@ -39,4 +55,11 @@ function NumberInputSetting(props: NumberInputSettingProps) { ) } +NumberInputSetting.defaultProps = { + allowFloat: false, + width: 80, + widthUnit: 'px', + disable: false, +} + export default NumberInputSetting diff --git a/lama_cleaner/app/src/components/Settings/SettingsModal.tsx b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx index 0d441cb..29c563e 100644 --- a/lama_cleaner/app/src/components/Settings/SettingsModal.tsx +++ b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx @@ -1,13 +1,14 @@ import React from 'react' -import { useRecoilState } from 'recoil' -import { settingState } from '../../store/Atoms' +import { useRecoilState, useRecoilValue } from 'recoil' +import { isSDState, settingState } from '../../store/Atoms' import Modal from '../shared/Modal' import ManualRunInpaintingSettingBlock from './ManualRunInpaintingSettingBlock' import HDSettingBlock from './HDSettingBlock' import ModelSettingBlock from './ModelSettingBlock' import GraduallyInpaintingSettingBlock from './GraduallyInpaintingSettingBlock' import DownloadMaskSettingBlock from './DownloadMaskSettingBlock' +import useHotKey from '../../hooks/useHotkey' interface SettingModalProps { onClose: () => void @@ -15,6 +16,7 @@ interface SettingModalProps { export default function SettingModal(props: SettingModalProps) { const { onClose } = props const [setting, setSettingState] = useRecoilState(settingState) + const isSD = useRecoilValue(isSDState) const handleOnClose = () => { setSettingState(old => { @@ -23,6 +25,17 @@ export default function SettingModal(props: SettingModalProps) { onClose() } + useHotKey( + 's', + () => { + setSettingState(old => { + return { ...old, show: !old.show } + }) + }, + {}, + [] + ) + return ( - + {isSD ? <> : } + - + {isSD ? 
<> : } ) } diff --git a/lama_cleaner/app/src/components/Shortcuts/Shortcuts.tsx b/lama_cleaner/app/src/components/Shortcuts/Shortcuts.tsx index 6dda87b..36bf4d0 100644 --- a/lama_cleaner/app/src/components/Shortcuts/Shortcuts.tsx +++ b/lama_cleaner/app/src/components/Shortcuts/Shortcuts.tsx @@ -1,6 +1,6 @@ import React from 'react' -import { useKeyPressEvent } from 'react-use' import { useRecoilState } from 'recoil' +import useHotKey from '../../hooks/useHotkey' import { shortcutsState } from '../../store/Atoms' import Button from '../shared/Button' @@ -13,8 +13,7 @@ const Shortcuts = () => { }) } - useKeyPressEvent('h', ev => { - ev?.preventDefault() + useHotKey('h', () => { shortcutStateHandler() }) diff --git a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx index f2b500e..d23329e 100644 --- a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx +++ b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx @@ -64,7 +64,8 @@ export default function ShortcutsModal() { - + +
) diff --git a/lama_cleaner/app/src/components/SidePanel/SidePanel.scss b/lama_cleaner/app/src/components/SidePanel/SidePanel.scss new file mode 100644 index 0000000..ff54c20 --- /dev/null +++ b/lama_cleaner/app/src/components/SidePanel/SidePanel.scss @@ -0,0 +1,57 @@ +@use '../../styles/Mixins/' as *; + +.side-panel { + position: absolute; + top: 68px; + right: 1.5rem; + padding: 0.3rem 0.3rem; + z-index: 4; + + border-radius: 0.8rem; + border-style: solid; + border-color: var(--border-color); + border-width: 1px; +} + +.side-panel-trigger { + font-family: 'WorkSans', sans-serif; + font-size: 16px; + border: 0px; +} + +.side-panel-content { + position: relative; + font-family: 'WorkSans', sans-serif; + font-size: 14px; + top: 1rem; + right: 1.5rem; + padding: 1rem 1rem; + z-index: 9; + + // backdrop-filter: blur(12px); + color: var(--text-color); + background-color: var(--page-bg); + + border-radius: 0.8rem; + border-style: solid; + border-color: var(--border-color); + border-width: 1px; + + display: flex; + flex-direction: column; + gap: 12px; + + .setting-block-content { + gap: 1rem; + } + + // input { + // height: 24px; + // // border-radius: 4px; + // } + + // button { + // height: 28px; + // // border-radius: 4px; + // } +} diff --git a/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx new file mode 100644 index 0000000..44337bb --- /dev/null +++ b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx @@ -0,0 +1,188 @@ +import React, { useState } from 'react' +import { useRecoilState } from 'recoil' +import * as PopoverPrimitive from '@radix-ui/react-popover' +import { useToggle } from 'react-use' +import { SDSampler, settingState } from '../../store/Atoms' +import NumberInputSetting from '../Settings/NumberInputSetting' +import SettingBlock from '../Settings/SettingBlock' +import Selector from '../shared/Selector' +import { Switch, SwitchThumb } from '../shared/Switch' +import Button from 
'../shared/Button' +import emitter, { EVENT_PROMPT } from '../../event' + +const INPUT_WIDTH = 30 + +// TODO: 添加收起来的按钮 +const SidePanel = () => { + const [open, toggleOpen] = useToggle(false) + const [setting, setSettingState] = useRecoilState(settingState) + + const onReRunBtnClick = () => { + emitter.emit(EVENT_PROMPT) + } + + return ( +
+ + toggleOpen()} + > + Stable Diffusion + + + + { + setSettingState(old => { + return { ...old, showCroper: value } + }) + }} + > + + + } + /> + {/* + { + const val = value.length === 0 ? 0 : parseInt(value, 10) + setSettingState(old => { + return { ...old, sdNumSamples: val } + }) + }} + /> */} + + { + const val = value.length === 0 ? 0 : parseInt(value, 10) + setSettingState(old => { + return { ...old, sdSteps: val } + }) + }} + /> + + { + const val = value.length === 0 ? 0 : parseFloat(value) + console.log(val) + setSettingState(old => { + return { ...old, sdStrength: val } + }) + }} + /> + + { + const val = value.length === 0 ? 0 : parseFloat(value) + setSettingState(old => { + return { ...old, sdGuidanceScale: val } + }) + }} + /> + + { + const sampler = val as SDSampler + setSettingState(old => { + return { ...old, sdSampler: sampler } + }) + }} + /> + } + /> + + +
+ } + /> + + + +
+ ) +} + +export default SidePanel diff --git a/lama_cleaner/app/src/components/Workspace.tsx b/lama_cleaner/app/src/components/Workspace.tsx index f8beba9..73c79ed 100644 --- a/lama_cleaner/app/src/components/Workspace.tsx +++ b/lama_cleaner/app/src/components/Workspace.tsx @@ -10,6 +10,7 @@ import { modelDownloaded, switchModel, } from '../adapters/inpainting' +import SidePanel from './SidePanel/SidePanel' interface WorkspaceProps { file: File @@ -82,6 +83,7 @@ const Workspace = ({ file }: WorkspaceProps) => { return ( <> + diff --git a/lama_cleaner/app/src/components/shared/Button.scss b/lama_cleaner/app/src/components/shared/Button.scss index 4a0c3d8..010cfcb 100644 --- a/lama_cleaner/app/src/components/shared/Button.scss +++ b/lama_cleaner/app/src/components/shared/Button.scss @@ -4,6 +4,7 @@ display: grid; grid-auto-flow: column; column-gap: 1rem; + background-color: var(--page-bg); color: var(--btn-text-color); font-family: 'WorkSans', sans-serif; width: max-content; @@ -25,6 +26,13 @@ } .btn-primary-disabled { + background-color: var(--page-bg); pointer-events: none; opacity: 0.5; } + +.btn-border { + border-color: var(--btn-border-color); + border-width: 1px; + border-style: solid; +} diff --git a/lama_cleaner/app/src/components/shared/Button.tsx b/lama_cleaner/app/src/components/shared/Button.tsx index ea35962..707c44a 100644 --- a/lama_cleaner/app/src/components/shared/Button.tsx +++ b/lama_cleaner/app/src/components/shared/Button.tsx @@ -1,6 +1,7 @@ import React, { ReactNode } from 'react' interface ButtonProps { + border?: boolean disabled?: boolean children?: ReactNode className?: string @@ -17,6 +18,7 @@ interface ButtonProps { const Button: React.FC = props => { const { children, + border, className, disabled, icon, @@ -55,6 +57,7 @@ const Button: React.FC = props => { toolTip ? 'info-tooltip' : '', tooltipPosition ? `info-tooltip-${tooltipPosition}` : '', className, + border ? 
`btn-border` : '', ].join(' ')} > {icon} @@ -65,6 +68,7 @@ const Button: React.FC = props => { Button.defaultProps = { disabled: false, + border: false, } export default Button diff --git a/lama_cleaner/app/src/components/shared/Input.tsx b/lama_cleaner/app/src/components/shared/Input.tsx new file mode 100644 index 0000000..5646bff --- /dev/null +++ b/lama_cleaner/app/src/components/shared/Input.tsx @@ -0,0 +1,42 @@ +import React, { FocusEvent, InputHTMLAttributes } from 'react' +import { useRecoilState } from 'recoil' +import { appState } from '../../store/Atoms' + +const TextInput = React.forwardRef< + HTMLInputElement, + InputHTMLAttributes +>((props: InputHTMLAttributes, forwardedRef) => { + const { onFocus, onBlur, ...itemProps } = props + const [_, setAppState] = useRecoilState(appState) + + const handleOnFocus = (evt: FocusEvent) => { + setAppState(old => { + return { ...old, disableShortCuts: true } + }) + onFocus?.(evt) + } + + const handleOnBlur = (evt: FocusEvent) => { + setAppState(old => { + return { ...old, disableShortCuts: false } + }) + onBlur?.(evt) + } + + return ( + { + if (e.key === 'Escape') { + e.currentTarget.blur() + } + }} + /> + ) +}) + +export default TextInput diff --git a/lama_cleaner/app/src/components/shared/Modal.tsx b/lama_cleaner/app/src/components/shared/Modal.tsx index 81507a9..602a7bf 100644 --- a/lama_cleaner/app/src/components/shared/Modal.tsx +++ b/lama_cleaner/app/src/components/shared/Modal.tsx @@ -1,7 +1,9 @@ import { XIcon } from '@heroicons/react/outline' import React, { ReactNode } from 'react' +import { useRecoilState } from 'recoil' import * as DialogPrimitive from '@radix-ui/react-dialog' import Button from './Button' +import { appState } from '../../store/Atoms' export interface ModalProps { show: boolean @@ -16,10 +18,14 @@ const Modal = React.forwardRef< ModalProps >((props, forwardedRef) => { const { show, children, onClose, className, title } = props + const [_, setAppState] = useRecoilState(appState) const 
onOpenChange = (open: boolean) => { if (!open) { onClose?.() + setAppState(old => { + return { ...old, disableShortCuts: false } + }) } } diff --git a/lama_cleaner/app/src/components/shared/NumberInput.scss b/lama_cleaner/app/src/components/shared/NumberInput.scss index cd44cd0..cfea79c 100644 --- a/lama_cleaner/app/src/components/shared/NumberInput.scss +++ b/lama_cleaner/app/src/components/shared/NumberInput.scss @@ -5,8 +5,13 @@ padding: 0 0.8rem; outline: 1px solid var(--border-color); height: 32px; + text-align: right; &:focus-visible { outline: 1px solid var(--yellow-accent); } + + &:disabled { + color: var(--border-color); + } } diff --git a/lama_cleaner/app/src/components/shared/NumberInput.tsx b/lama_cleaner/app/src/components/shared/NumberInput.tsx index 5f93bca..ead35aa 100644 --- a/lama_cleaner/app/src/components/shared/NumberInput.tsx +++ b/lama_cleaner/app/src/components/shared/NumberInput.tsx @@ -1,31 +1,44 @@ -import React, { FormEvent, InputHTMLAttributes } from 'react' +import React, { FormEvent, InputHTMLAttributes, useState } from 'react' +import TextInput from './Input' interface NumberInputProps extends InputHTMLAttributes { value: string + allowFloat?: boolean onValue?: (val: string) => void } const NumberInput = React.forwardRef( (props: NumberInputProps, forwardedRef) => { - const { value, onValue, ...itemProps } = props + const { value, allowFloat, onValue, ...itemProps } = props + const [innerValue, setInnerValue] = useState(value) const handleOnInput = (evt: FormEvent) => { const target = evt.target as HTMLInputElement - const val = target.value.replace(/\D/g, '') - onValue?.(val) + let val = target.value + if (allowFloat) { + val = val.replace(/[^0-9.]/g, '').replace(/(\..*?)\..*/g, '$1') + onValue?.(val) + } else { + val = val.replace(/\D/g, '') + onValue?.(val) + } + setInnerValue(val) } return ( - ) } ) +NumberInput.defaultProps = { + allowFloat: false, +} + export default NumberInput diff --git 
a/lama_cleaner/app/src/components/shared/Selector.tsx b/lama_cleaner/app/src/components/shared/Selector.tsx index 5037b7d..3635382 100644 --- a/lama_cleaner/app/src/components/shared/Selector.tsx +++ b/lama_cleaner/app/src/components/shared/Selector.tsx @@ -51,6 +51,7 @@ const Selector = (props: Props) => { className="select-trigger" style={{ width }} ref={contentRef} + onKeyDown={e => e.preventDefault()} > diff --git a/lama_cleaner/app/src/components/shared/Switch.tsx b/lama_cleaner/app/src/components/shared/Switch.tsx index fea7e49..ce5338e 100644 --- a/lama_cleaner/app/src/components/shared/Switch.tsx +++ b/lama_cleaner/app/src/components/shared/Switch.tsx @@ -12,6 +12,7 @@ const Switch = React.forwardRef< {...itemProps} ref={forwardedRef} className={`switch-root ${className}`} + onKeyDown={e => e.preventDefault()} /> ) }) diff --git a/lama_cleaner/app/src/components/shared/Toast.scss b/lama_cleaner/app/src/components/shared/Toast.scss index 828f744..21b2b03 100644 --- a/lama_cleaner/app/src/components/shared/Toast.scss +++ b/lama_cleaner/app/src/components/shared/Toast.scss @@ -1,7 +1,7 @@ .toast-viewpoint { position: fixed; - top: 48px; - right: 0; + bottom: 48px; + right: 1.5rem; display: flex; flex-direction: row; padding: 25px; diff --git a/lama_cleaner/app/src/event.ts b/lama_cleaner/app/src/event.ts new file mode 100644 index 0000000..e4344fb --- /dev/null +++ b/lama_cleaner/app/src/event.ts @@ -0,0 +1,7 @@ +import mitt from 'mitt' + +export const EVENT_PROMPT = 'prompt' + +const emitter = mitt() + +export default emitter diff --git a/lama_cleaner/app/src/hooks/useHotkey.tsx b/lama_cleaner/app/src/hooks/useHotkey.tsx new file mode 100644 index 0000000..ed2ed62 --- /dev/null +++ b/lama_cleaner/app/src/hooks/useHotkey.tsx @@ -0,0 +1,22 @@ +import { Options, useHotkeys } from 'react-hotkeys-hook' +import { useRecoilValue } from 'recoil' +import { appState } from '../store/Atoms' + +const useHotKey = ( + keys: string, + callback: any, + options?: Options, + 
deps?: any[] +) => { + const app = useRecoilValue(appState) + + const ref = useHotkeys( + keys, + callback, + { ...options, enabled: !app.disableShortCuts }, + deps + ) + return ref +} + +export default useHotKey diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx index 05c6c08..ed8890b 100644 --- a/lama_cleaner/app/src/store/Atoms.tsx +++ b/lama_cleaner/app/src/store/Atoms.tsx @@ -9,6 +9,7 @@ export enum AIModel { ZITS = 'zits', MAT = 'mat', FCF = 'fcf', + SD14 = 'sd1.4', } export const fileState = atom({ @@ -16,6 +17,89 @@ export const fileState = atom({ default: undefined, }) +export interface Rect { + x: number + y: number + width: number + height: number +} + +interface AppState { + disableShortCuts: boolean + isInpainting: boolean +} + +export const appState = atom({ + key: 'appState', + default: { + disableShortCuts: false, + isInpainting: false, + }, +}) + +export const propmtState = atom({ + key: 'promptState', + default: '', +}) + +export const isInpaintingState = selector({ + key: 'isInpainting', + get: ({ get }) => { + const app = get(appState) + return app.isInpainting + }, + set: ({ get, set }, newValue: any) => { + const app = get(appState) + set(appState, { ...app, isInpainting: newValue }) + }, +}) + +export const croperState = atom({ + key: 'croperState', + default: { + x: 0, + y: 0, + width: 512, + height: 512, + }, +}) + +export const croperX = selector({ + key: 'croperX', + get: ({ get }) => get(croperState).x, + set: ({ get, set }, newValue: any) => { + const rect = get(croperState) + set(croperState, { ...rect, x: newValue }) + }, +}) + +export const croperY = selector({ + key: 'croperY', + get: ({ get }) => get(croperState).y, + set: ({ get, set }, newValue: any) => { + const rect = get(croperState) + set(croperState, { ...rect, y: newValue }) + }, +}) + +export const croperHeight = selector({ + key: 'croperHeight', + get: ({ get }) => get(croperState).height, + set: ({ get, set }, newValue: any) => { + 
const rect = get(croperState) + set(croperState, { ...rect, height: newValue }) + }, +}) + +export const croperWidth = selector({ + key: 'croperWidth', + get: ({ get }) => get(croperState).width, + set: ({ get, set }, newValue: any) => { + const rect = get(croperState) + set(croperState, { ...rect, width: newValue }) + }, +}) + interface ToastAtomState { open: boolean desc: string @@ -50,6 +134,7 @@ type ModelsHDSettings = { [key in AIModel]: HDSettings } export interface Settings { show: boolean + showCroper: boolean downloadMask: boolean graduallyInpainting: boolean runInpaintingManually: boolean @@ -62,6 +147,16 @@ export interface Settings { // For ZITS zitsWireframe: boolean + + // For SD + sdMode: SDMode + sdStrength: number + sdSteps: number + sdGuidanceScale: number + sdSampler: SDSampler + sdSeed: number + sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend + sdNumSamples: number } const defaultHDSettings: ModelsHDSettings = { @@ -100,10 +195,28 @@ const defaultHDSettings: ModelsHDSettings = { hdStrategyCropMargin: 128, enabled: false, }, + [AIModel.SD14]: { + hdStrategy: HDStrategy.ORIGINAL, + hdStrategyResizeLimit: 768, + hdStrategyCropTrigerSize: 512, + hdStrategyCropMargin: 128, + enabled: true, + }, +} + +export enum SDSampler { + ddim = 'ddim', +} + +export enum SDMode { + text2img = 'text2img', + img2img = 'img2img', + inpainting = 'inpainting', } export const settingStateDefault: Settings = { show: false, + showCroper: false, downloadMask: false, graduallyInpainting: true, runInpaintingManually: false, @@ -114,6 +227,16 @@ export const settingStateDefault: Settings = { ldmSampler: LDMSampler.plms, zitsWireframe: true, + + // SD + sdMode: SDMode.inpainting, + sdStrength: 0.75, + sdSteps: 50, + sdGuidanceScale: 7.5, + sdSampler: SDSampler.ddim, + sdSeed: 42, + sdSeedFixed: true, + sdNumSamples: 1, } const localStorageEffect = @@ -164,3 +287,20 @@ export const hdSettingsState = selector({ }) }, }) + +export const isSDState 
= selector({ + key: 'isSD', + get: ({ get }) => { + const settings = get(settingState) + return settings.model === AIModel.SD14 + }, +}) + +export const runManuallyState = selector({ + key: 'runManuallyState', + get: ({ get }) => { + const settings = get(settingState) + const isSD = get(isSDState) + return settings.runInpaintingManually || isSD + }, +}) diff --git a/lama_cleaner/app/src/styles/_Colors.scss b/lama_cleaner/app/src/styles/_Colors.scss index 0c2d493..ec138e7 100644 --- a/lama_cleaner/app/src/styles/_Colors.scss +++ b/lama_cleaner/app/src/styles/_Colors.scss @@ -6,6 +6,7 @@ --page-bg-light: rgb(255, 255, 255, 0.5); --page-text-color: #040404; --yellow-accent: #ffcc00; + --yellow-accent-light: #ffcc0055; --link-color: rgb(0, 0, 0); --border-color: rgb(100, 100, 120); --border-color-light: rgba(100, 100, 120, 0.5); @@ -57,4 +58,6 @@ --box-shadow: inset 0 0.5px rgba(255, 255, 255, 0.1), inset 0 1px 5px hsl(210 16.7% 97.6%), 0px 0px 0px 0.5px hsl(205 10.7% 78%), 0px 2px 1px -1px hsl(205 10.7% 78%), 0 1px hsl(205 10.7% 78%); + + --croper-bg: rgba(0, 0, 0, 0.5); } diff --git a/lama_cleaner/app/src/styles/_ColorsDark.scss b/lama_cleaner/app/src/styles/_ColorsDark.scss index 7bd9ecb..e5e4dab 100644 --- a/lama_cleaner/app/src/styles/_ColorsDark.scss +++ b/lama_cleaner/app/src/styles/_ColorsDark.scss @@ -6,6 +6,7 @@ --page-bg-light: #04040488; --page-text-color: #f9f9f9; --yellow-accent: #ffcc00; + --yellow-accent-light: #ffcc0055; --link-color: var(--yellow-accent); --border-color: rgb(100, 100, 120); --border-color-light: rgba(102, 102, 102); @@ -55,4 +56,6 @@ --box-shadow: inset 0 0.5px rgba(255, 255, 255, 0.1), inset 0 1px 5px hsl(195 7.1% 11%), 0px 0px 0px 0.5px hsl(207 5.6% 31.6%), 0px 2px 1px -1px hsl(207 5.6% 31.6%), 0 1px hsl(207 5.6% 31.6%); + + --croper-bg: rgba(0, 0, 0, 0.5); } diff --git a/lama_cleaner/app/src/styles/_index.scss b/lama_cleaner/app/src/styles/_index.scss index 65d0fbd..ae3f150 100644 --- a/lama_cleaner/app/src/styles/_index.scss +++ 
b/lama_cleaner/app/src/styles/_index.scss @@ -9,9 +9,12 @@ @use '../components/Editor/Editor'; @use '../components/LandingPage/LandingPage'; @use '../components/Header/Header'; +@use '../components/Header/PromptInput'; @use '../components/Header/ThemeChanger'; @use '../components/Shortcuts/Shortcuts'; @use '../components/Settings/Settings.scss'; +@use '../components/SidePanel/SidePanel.scss'; +@use '../components/Croper/Croper.scss'; // Shared @use '../components/FileSelect/FileSelect'; diff --git a/lama_cleaner/app/yarn.lock b/lama_cleaner/app/yarn.lock index a3e70f6..c1a266d 100644 --- a/lama_cleaner/app/yarn.lock +++ b/lama_cleaner/app/yarn.lock @@ -1241,6 +1241,26 @@ minimatch "^3.0.4" strip-json-comments "^3.1.1" +"@floating-ui/core@^0.7.3": + version "0.7.3" + resolved "https://registry.npmmirror.com/@floating-ui/core/-/core-0.7.3.tgz#d274116678ffae87f6b60e90f88cc4083eefab86" + integrity sha512-buc8BXHmG9l82+OQXOFU3Kr2XQx9ys01U/Q9HMIrZ300iLc8HLMgh7dcCqgYzAzf4BkoQvDcXf5Y+CuEZ5JBYg== + +"@floating-ui/dom@^0.5.3": + version "0.5.4" + resolved "https://registry.npmmirror.com/@floating-ui/dom/-/dom-0.5.4.tgz#4eae73f78bcd4bd553ae2ade30e6f1f9c73fe3f1" + integrity sha512-419BMceRLq0RrmTSDxn8hf9R3VCJv2K9PUfugh5JyEFmdjzDo+e8U5EdR8nzKq8Yj1htzLm3b6eQEEam3/rrtg== + dependencies: + "@floating-ui/core" "^0.7.3" + +"@floating-ui/react-dom@0.7.2": + version "0.7.2" + resolved "https://registry.npmmirror.com/@floating-ui/react-dom/-/react-dom-0.7.2.tgz#0bf4ceccb777a140fc535c87eb5d6241c8e89864" + integrity sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg== + dependencies: + "@floating-ui/dom" "^0.5.3" + use-isomorphic-layout-effect "^1.1.1" + "@gar/promisify@^1.0.1": version "1.1.2" resolved "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.2.tgz" @@ -1566,6 +1586,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/primitive@1.0.0": + version "1.0.0" + resolved 
"https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-1.0.0.tgz#e1d8ef30b10ea10e69c76e896f608d9276352253" + integrity sha512-3e7rn8FDMin4CgeL7Z/49smCA3rFYY3Ha2rUQ7HRWFadS5iCRw08ZgVT1LaNTCNqgvrUiyczLflrVrF0SRQtNA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-arrow@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-arrow/-/react-arrow-0.1.4.tgz#a871448a418cd3507d83840fdd47558cb961672b" @@ -1574,6 +1601,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-primitive" "0.1.4" +"@radix-ui/react-arrow@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-arrow/-/react-arrow-1.0.0.tgz#c461f4c2cab3317e3d42a1ae62910a4cbb0192a1" + integrity sha512-1MUuv24HCdepi41+qfv125EwMuxgQ+U+h0A9K3BjCO/J8nVRREKHHpkD9clwfnjEDk9hgGzCnff4aUKCPiRepw== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-collection@0.1.5-rc.18": version "0.1.5-rc.18" resolved "https://registry.npmmirror.com/@radix-ui/react-collection/-/react-collection-0.1.5-rc.18.tgz#4dc03a8f464643748c0dad781b472f149d671d5c" @@ -1599,6 +1634,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-compose-refs@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.0.tgz#37595b1f16ec7f228d698590e78eeed18ff218ae" + integrity sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-context@0.1.1": version "0.1.1" resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-0.1.1.tgz#06996829ea124d9a1bc1dbe3e51f33588fab0875" @@ -1613,6 +1655,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-context@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-1.0.0.tgz#f38e30c5859a9fb5e9aa9a9da452ee3ed9e0aee0" + integrity 
sha512-1pVM9RfOQ+n/N5PJK33kRSKsr1glNxomxONs5c49MliinBY6Yw2Q995qfBUUo0/Mbg05B/sGA0gkgPI7kmSHBg== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-dialog@0.1.8-rc.25": version "0.1.8-rc.25" resolved "https://registry.npmmirror.com/@radix-ui/react-dialog/-/react-dialog-0.1.8-rc.25.tgz#dea6af32268b34070346ed5d6d609ff699a1de43" @@ -1667,6 +1716,18 @@ "@radix-ui/react-use-callback-ref" "0.1.1-rc.18" "@radix-ui/react-use-escape-keydown" "0.1.1-rc.18" +"@radix-ui/react-dismissable-layer@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.0.tgz#35b7826fa262fd84370faef310e627161dffa76b" + integrity sha512-n7kDRfx+LB1zLueRDvZ1Pd0bxdJWDUZNQ/GWoxDn2prnuJKRdxsjulejX/ePkOsLi2tTm6P24mDqlMSgQpsT6g== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.0" + "@radix-ui/react-compose-refs" "1.0.0" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-use-callback-ref" "1.0.0" + "@radix-ui/react-use-escape-keydown" "1.0.0" + "@radix-ui/react-focus-guards@0.1.1-rc.18": version "0.1.1-rc.18" resolved "https://registry.npmmirror.com/@radix-ui/react-focus-guards/-/react-focus-guards-0.1.1-rc.18.tgz#f0e2ebd3cbfd363a71682e3234b274ab7d7df4ce" @@ -1674,6 +1735,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-focus-guards@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.0.tgz#339c1c69c41628c1a5e655f15f7020bf11aa01fa" + integrity sha512-UagjDk4ijOAnGu4WMUPj9ahi7/zJJqNZ9ZAiGPp7waUWJO0O1aWXi/udPphI0IUjvrhBsZJGSN66dR2dsueLWQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-focus-scope@0.1.5-rc.18": version "0.1.5-rc.18" resolved "https://registry.npmmirror.com/@radix-ui/react-focus-scope/-/react-focus-scope-0.1.5-rc.18.tgz#e26a0317130687fd3668af8ec68e19e04dc7668f" @@ -1684,6 +1752,16 @@ "@radix-ui/react-primitive" "0.1.5-rc.18" "@radix-ui/react-use-callback-ref" 
"0.1.1-rc.18" +"@radix-ui/react-focus-scope@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.0.tgz#95a0c1188276dc8933b1eac5f1cdb6471e01ade5" + integrity sha512-C4SWtsULLGf/2L4oGeIHlvWQx7Rf+7cX/vKOAD2dXW0A1b5QXwi3wWeaEgW+wn+SEVrraMUk05vLU9fZZz5HbQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-compose-refs" "1.0.0" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-use-callback-ref" "1.0.0" + "@radix-ui/react-id@0.1.5": version "0.1.5" resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8" @@ -1700,6 +1778,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-use-layout-effect" "0.1.1-rc.18" +"@radix-ui/react-id@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-1.0.0.tgz#8d43224910741870a45a8c9d092f25887bb6d11e" + integrity sha512-Q6iAB/U7Tq3NTolBBQbHTgclPmGWE3OlktGGqrClPozSw4vkQ1DfQAOtzgRPecKsMdJINE05iaoDUG8tRzCBjw== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-layout-effect" "1.0.0" + "@radix-ui/react-label@0.1.5": version "0.1.5" resolved "https://registry.npmmirror.com/@radix-ui/react-label/-/react-label-0.1.5.tgz#12cd965bfc983e0148121d4c99fb8e27a917c45c" @@ -1722,6 +1808,28 @@ "@radix-ui/react-id" "0.1.6-rc.18" "@radix-ui/react-primitive" "0.1.5-rc.18" +"@radix-ui/react-popover@^1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-popover/-/react-popover-1.0.0.tgz#5ee72013089fdf9038417fc1eb98a749c17457fd" + integrity sha512-osxFFO0TiZ9ABpEOitZu0R1Fdd+tSpJgAqLZxRLLdZQ7ya0onSODcITp5hXDVuYQeVXH6pKEBGwXN6ZGjZ0a5g== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.0" + "@radix-ui/react-compose-refs" "1.0.0" + "@radix-ui/react-context" "1.0.0" + "@radix-ui/react-dismissable-layer" "1.0.0" + "@radix-ui/react-focus-guards" "1.0.0" + "@radix-ui/react-focus-scope" "1.0.0" + 
"@radix-ui/react-id" "1.0.0" + "@radix-ui/react-popper" "1.0.0" + "@radix-ui/react-portal" "1.0.0" + "@radix-ui/react-presence" "1.0.0" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-slot" "1.0.0" + "@radix-ui/react-use-controllable-state" "1.0.0" + aria-hidden "^1.1.1" + react-remove-scroll "2.5.4" + "@radix-ui/react-popper@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-popper/-/react-popper-0.1.4.tgz#dfc055dcd7dfae6a2eff7a70d333141d15a5d029" @@ -1737,6 +1845,22 @@ "@radix-ui/react-use-size" "0.1.1" "@radix-ui/rect" "0.1.1" +"@radix-ui/react-popper@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-popper/-/react-popper-1.0.0.tgz#fb4f937864bf39c48f27f55beee61fa9f2bef93c" + integrity sha512-k2dDd+1Wl0XWAMs9ZvAxxYsB9sOsEhrFQV4CINd7IUZf0wfdye4OHen9siwxvZImbzhgVeKTJi68OQmPRvVdMg== + dependencies: + "@babel/runtime" "^7.13.10" + "@floating-ui/react-dom" "0.7.2" + "@radix-ui/react-arrow" "1.0.0" + "@radix-ui/react-compose-refs" "1.0.0" + "@radix-ui/react-context" "1.0.0" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-use-layout-effect" "1.0.0" + "@radix-ui/react-use-rect" "1.0.0" + "@radix-ui/react-use-size" "1.0.0" + "@radix-ui/rect" "1.0.0" + "@radix-ui/react-portal@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2" @@ -1755,6 +1879,14 @@ "@radix-ui/react-primitive" "0.1.5-rc.18" "@radix-ui/react-use-layout-effect" "0.1.1-rc.18" +"@radix-ui/react-portal@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-1.0.0.tgz#7220b66743394fabb50c55cb32381395cc4a276b" + integrity sha512-a8qyFO/Xb99d8wQdu4o7qnigNjTPG123uADNecz0eX4usnQEj7o+cG4ZX4zkqq98NYekT7UoEQIjxBNWIFuqTA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-primitive" "1.0.0" + "@radix-ui/react-presence@0.1.2": version "0.1.2" resolved 
"https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3" @@ -1773,6 +1905,15 @@ "@radix-ui/react-compose-refs" "0.1.1-rc.18" "@radix-ui/react-use-layout-effect" "0.1.1-rc.18" +"@radix-ui/react-presence@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-1.0.0.tgz#814fe46df11f9a468808a6010e3f3ca7e0b2e84a" + integrity sha512-A+6XEvN01NfVWiKu38ybawfHsBjWum42MRPnEuqPsBZ4eV7e/7K321B5VgYMPv3Xx5An6o1/l9ZuDBgmcmWK3w== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-compose-refs" "1.0.0" + "@radix-ui/react-use-layout-effect" "1.0.0" + "@radix-ui/react-primitive@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1" @@ -1789,6 +1930,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-slot" "0.1.3-rc.18" +"@radix-ui/react-primitive@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-1.0.0.tgz#376cd72b0fcd5e0e04d252ed33eb1b1f025af2b0" + integrity sha512-EyXe6mnRlHZ8b6f4ilTDrXmkLShICIuOTTj0GX4w1rp+wSxf3+TD05u1UOITC8VsJ2a9nwHvdXtOXEOl0Cw/zQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-slot" "1.0.0" + "@radix-ui/react-select@0.1.2-rc.27": version "0.1.2-rc.27" resolved "https://registry.npmmirror.com/@radix-ui/react-select/-/react-select-0.1.2-rc.27.tgz#91948d482b3db8cf83172838dfae0f4bedec9566" @@ -1831,6 +1980,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-compose-refs" "0.1.1-rc.18" +"@radix-ui/react-slot@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-1.0.0.tgz#7fa805b99891dea1e862d8f8fbe07f4d6d0fd698" + integrity sha512-3mrKauI/tWXo1Ll+gN5dHcxDPdm/Df1ufcDLCecn+pnCIVcdWE7CujXo8QaXOWRJyZyQWWbpB8eFwHzWXlv5mQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-compose-refs" 
"1.0.0" + "@radix-ui/react-switch@^0.1.5": version "0.1.5" resolved "https://registry.npmmirror.com/@radix-ui/react-switch/-/react-switch-0.1.5.tgz#071ffa19a17a47fdc5c5e6f371bd5901c9fef2f4" @@ -1915,6 +2072,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-use-callback-ref@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.0.tgz#9e7b8b6b4946fe3cbe8f748c82a2cce54e7b6a90" + integrity sha512-GZtyzoHz95Rhs6S63D2t/eqvdFCm7I+yHMLVQheKM7nBD8mbZIt+ct1jz4536MDnaOGKIxynJ8eHTkVGVVkoTg== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-controllable-state@0.1.0": version "0.1.0" resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-0.1.0.tgz#4fced164acfc69a4e34fb9d193afdab973a55de1" @@ -1931,6 +2095,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-use-callback-ref" "0.1.1-rc.18" +"@radix-ui/react-use-controllable-state@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.0.tgz#a64deaafbbc52d5d407afaa22d493d687c538b7f" + integrity sha512-FohDoZvk3mEXh9AWAVyRTYR4Sq7/gavuofglmiXB2g1aKyboUD4YtgWxKj8O5n+Uak52gXQ4wKz5IFST4vtJHg== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-callback-ref" "1.0.0" + "@radix-ui/react-use-escape-keydown@0.1.0": version "0.1.0" resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874" @@ -1947,6 +2119,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-use-callback-ref" "0.1.1-rc.18" +"@radix-ui/react-use-escape-keydown@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.0.tgz#aef375db4736b9de38a5a679f6f49b45a060e5d1" + integrity 
sha512-JwfBCUIfhXRxKExgIqGa4CQsiMemo1Xt0W/B4ei3fpzpvPENKpMKQ8mZSB6Acj3ebrAEgi2xiQvcI1PAAodvyg== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-callback-ref" "1.0.0" + "@radix-ui/react-use-layout-effect@0.1.0": version "0.1.0" resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223" @@ -1961,6 +2141,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-use-layout-effect@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.0.tgz#2fc19e97223a81de64cd3ba1dc42ceffd82374dc" + integrity sha512-6Tpkq+R6LOlmQb1R5NNETLG0B4YP0wc+klfXafpUCj6JGyaUc8il7/kUZ7m59rGbXGczE9Bs+iz2qloqsZBduQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-previous@0.1.1": version "0.1.1" resolved "https://registry.npmmirror.com/@radix-ui/react-use-previous/-/react-use-previous-0.1.1.tgz#0226017f72267200f6e832a7103760e96a6db5d0" @@ -1983,6 +2170,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/rect" "0.1.1" +"@radix-ui/react-use-rect@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-rect/-/react-use-rect-1.0.0.tgz#b040cc88a4906b78696cd3a32b075ed5b1423b3e" + integrity sha512-TB7pID8NRMEHxb/qQJpvSt3hQU4sqNPM1VCTjTRjEOa7cEop/QMuq8S6fb/5Tsz64kqSvB9WnwsDHtjnrM9qew== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/rect" "1.0.0" + "@radix-ui/react-use-size@0.1.1": version "0.1.1" resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-0.1.1.tgz#f6b75272a5d41c3089ca78c8a2e48e5f204ef90f" @@ -1990,6 +2185,14 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-use-size@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-1.0.0.tgz#a0b455ac826749419f6354dc733e2ca465054771" + integrity 
sha512-imZ3aYcoYCKhhgNpkNDh/aTiU05qw9hX+HHI1QDBTyIlcFjgeFlKKySNGMwTp7nYFLQg/j0VA2FmCY4WPDDHMg== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-layout-effect" "1.0.0" + "@radix-ui/react-visually-hidden@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03" @@ -2013,6 +2216,13 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/rect@1.0.0": + version "1.0.0" + resolved "https://registry.npmmirror.com/@radix-ui/rect/-/rect-1.0.0.tgz#0dc8e6a829ea2828d53cbc94b81793ba6383bf3c" + integrity sha512-d0O68AYy/9oeEy1DdC07bz1/ZXX+DqCskRd3i4JzLSTXwefzaepQrKjXC7aNM8lTHjFLDO0pDgaEiQ7jEk+HVg== + dependencies: + "@babel/runtime" "^7.13.10" + "@rollup/plugin-node-resolve@^7.1.1": version "7.1.3" resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz" @@ -2055,6 +2265,11 @@ dependencies: "@sinonjs/commons" "^1.7.0" +"@socket.io/component-emitter@~3.1.0": + version "3.1.0" + resolved "https://registry.npmmirror.com/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz#96116f2a912e0c02817345b3c10751069920d553" + integrity sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg== + "@surma/rollup-plugin-off-main-thread@^1.1.1": version "1.4.2" resolved "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-1.4.2.tgz" @@ -4621,6 +4836,13 @@ debug@^3.1.1, debug@^3.2.6, debug@^3.2.7: dependencies: ms "^2.1.1" +debug@~4.3.1, debug@~4.3.2: + version "4.3.4" + resolved "https://registry.npmmirror.com/debug/-/debug-4.3.4.tgz#1319f6579357f2338d3337d2cdd4914bb5dcc865" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + decamelize@^1.2.0: version "1.2.0" resolved "https://registry.npmjs.org/decamelize/-/decamelize-1.2.0.tgz" @@ -5004,6 +5226,22 
@@ end-of-stream@^1.0.0, end-of-stream@^1.1.0: dependencies: once "^1.4.0" +engine.io-client@~6.2.1: + version "6.2.2" + resolved "https://registry.npmmirror.com/engine.io-client/-/engine.io-client-6.2.2.tgz#c6c5243167f5943dcd9c4abee1bfc634aa2cbdd0" + integrity sha512-8ZQmx0LQGRTYkHuogVZuGSpDqYZtCM/nv8zQ68VZ+JkOpazJ7ICdsSpaO6iXwvaU30oFg5QJOJWj8zWqhbKjkQ== + dependencies: + "@socket.io/component-emitter" "~3.1.0" + debug "~4.3.1" + engine.io-parser "~5.0.3" + ws "~8.2.3" + xmlhttprequest-ssl "~2.0.0" + +engine.io-parser@~5.0.3: + version "5.0.4" + resolved "https://registry.npmmirror.com/engine.io-parser/-/engine.io-parser-5.0.4.tgz#0b13f704fa9271b3ec4f33112410d8f3f41d0fc0" + integrity sha512-+nVFp+5z1E3HcToEnO7ZIj3g+3k9389DvWtvJZz0T6/eOCPIyyxehFcedoYrZQrp0LgQbD9pPXhpMBKMd5QURg== + enhanced-resolve@^4.3.0: version "4.5.0" resolved "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-4.5.0.tgz" @@ -6215,6 +6453,11 @@ hosted-git-info@^2.1.4: resolved "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz" integrity sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw== +hotkeys-js@3.9.4: + version "3.9.4" + resolved "https://registry.npmmirror.com/hotkeys-js/-/hotkeys-js-3.9.4.tgz#ce1aa4c3a132b6a63a9dd5644fc92b8a9b9cbfb9" + integrity sha512-2zuLt85Ta+gIyvs4N88pCYskNrxf1TFv3LR9t5mdAZIX8BcgQQ48F2opUptvHa6m8zsy5v/a0i9mWzTrlNWU0Q== + hpack.js@^2.1.6: version "2.1.6" resolved "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz" @@ -7785,16 +8028,11 @@ lodash.uniq@^4.5.0: resolved "https://registry.npmjs.org/lodash.uniq/-/lodash.uniq-4.5.0.tgz" integrity sha1-0CJTc662Uq3BvILklFM5qEJ1R3M= -"lodash@>=3.5 <5", lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.5, lodash@^4.7.0: +"lodash@>=3.5 <5", lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.21, lodash@^4.17.5, lodash@^4.7.0: version "4.17.21" resolved 
"https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== -lodash@^4.17.21: - version "4.17.21" - resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" - integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== - loglevel@^1.6.8: version "1.7.1" resolved "https://registry.npmjs.org/loglevel/-/loglevel-1.7.1.tgz" @@ -8095,6 +8333,11 @@ mississippi@^3.0.0: stream-each "^1.1.0" through2 "^2.0.0" +mitt@^3.0.0: + version "3.0.0" + resolved "https://registry.npmmirror.com/mitt/-/mitt-3.0.0.tgz#69ef9bd5c80ff6f57473e8d89326d01c414be0bd" + integrity sha512-7dX2/10ITVyqh4aOSVI9gdape+t9l2/8QxHrFmUXu4EEUpdlxl6RudZUPZoc+zuY2hk1j7XxVroIVIan/pD/SQ== + mixin-deep@^1.2.0: version "1.3.2" resolved "https://registry.npmjs.org/mixin-deep/-/mixin-deep-1.3.2.tgz" @@ -9924,6 +10167,13 @@ react-error-overlay@^6.0.9: resolved "https://registry.npmjs.org/react-error-overlay/-/react-error-overlay-6.0.9.tgz" integrity sha512-nQTTcUu+ATDbrSD1BZHr5kgSD4oF8OFjxun8uAaL8RwPBacGBNPf/yAuVVdx17N8XNzRDMrZ9XcKZHCjPW+9ew== +react-hotkeys-hook@^3.4.7: + version "3.4.7" + resolved "https://registry.npmmirror.com/react-hotkeys-hook/-/react-hotkeys-hook-3.4.7.tgz#e16a0a85f59feed9f48d12cfaf166d7df4c96b7a" + integrity sha512-+bbPmhPAl6ns9VkXkNNyxlmCAIyDAcWbB76O4I0ntr3uWCRuIQf/aRLartUahe9chVMPj+OEzzfk3CQSjclUEQ== + dependencies: + hotkeys-js "3.9.4" + react-is@^16.8.1: version "16.13.1" resolved "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz" @@ -9947,6 +10197,25 @@ react-remove-scroll-bar@^2.3.0: react-style-singleton "^2.2.0" tslib "^2.0.0" +react-remove-scroll-bar@^2.3.3: + version "2.3.3" + resolved "https://registry.npmmirror.com/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.3.tgz#e291f71b1bb30f5f67f023765b7435f4b2b2cd94" + integrity 
sha512-i9GMNWwpz8XpUpQ6QlevUtFjHGqnPG4Hxs+wlIJntu/xcsZVEpJcIV71K3ZkqNy2q3GfgvkD7y6t/Sv8ofYSbw== + dependencies: + react-style-singleton "^2.2.1" + tslib "^2.0.0" + +react-remove-scroll@2.5.4: + version "2.5.4" + resolved "https://registry.npmmirror.com/react-remove-scroll/-/react-remove-scroll-2.5.4.tgz#afe6491acabde26f628f844b67647645488d2ea0" + integrity sha512-xGVKJJr0SJGQVirVFAUZ2k1QLyO6m+2fy0l8Qawbp5Jgrv3DeLalrfMNBFSlmz5kriGGzsVBtGVnf4pTKIhhWA== + dependencies: + react-remove-scroll-bar "^2.3.3" + react-style-singleton "^2.2.1" + tslib "^2.1.0" + use-callback-ref "^1.3.0" + use-sidecar "^1.1.2" + react-remove-scroll@^2.4.0: version "2.5.1" resolved "https://registry.npmmirror.com/react-remove-scroll/-/react-remove-scroll-2.5.1.tgz#28c318c2e076040e5d6172bf28aab2916ad89b46" @@ -10033,6 +10302,15 @@ react-style-singleton@^2.2.0: invariant "^2.2.4" tslib "^2.0.0" +react-style-singleton@^2.2.1: + version "2.2.1" + resolved "https://registry.npmmirror.com/react-style-singleton/-/react-style-singleton-2.2.1.tgz#f99e420492b2d8f34d38308ff660b60d0b1205b4" + integrity sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g== + dependencies: + get-nonce "^1.0.0" + invariant "^2.2.4" + tslib "^2.0.0" + react-universal-interface@^0.6.2: version "0.6.2" resolved "https://registry.npmjs.org/react-universal-interface/-/react-universal-interface-0.6.2.tgz" @@ -10845,6 +11123,24 @@ snapdragon@^0.8.1: source-map-resolve "^0.5.0" use "^3.1.0" +socket.io-client@^4.5.2: + version "4.5.2" + resolved "https://registry.npmmirror.com/socket.io-client/-/socket.io-client-4.5.2.tgz#9481518c560388c980c88b01e3cf62f367f04c96" + integrity sha512-naqYfFu7CLDiQ1B7AlLhRXKX3gdeaIMfgigwavDzgJoIUYulc1qHH5+2XflTsXTPY7BlPH5rppJyUjhjrKQKLg== + dependencies: + "@socket.io/component-emitter" "~3.1.0" + debug "~4.3.2" + engine.io-client "~6.2.1" + socket.io-parser "~4.2.0" + +socket.io-parser@~4.2.0: + version "4.2.1" + resolved 
"https://registry.npmmirror.com/socket.io-parser/-/socket.io-parser-4.2.1.tgz#01c96efa11ded938dcb21cbe590c26af5eff65e5" + integrity sha512-V4GrkLy+HeF1F/en3SpUaM+7XxYXpuMUWLGde1kSSh5nQMN4hLrbPIkD+otwh6q9R6NOQBN4AMaOZ2zVjui82g== + dependencies: + "@socket.io/component-emitter" "~3.1.0" + debug "~4.3.1" + sockjs-client@^1.5.0: version "1.5.2" resolved "https://registry.npmjs.org/sockjs-client/-/sockjs-client-1.5.2.tgz" @@ -11855,6 +12151,11 @@ use-callback-ref@^1.3.0: dependencies: tslib "^2.0.0" +use-isomorphic-layout-effect@^1.1.1: + version "1.1.2" + resolved "https://registry.npmmirror.com/use-isomorphic-layout-effect/-/use-isomorphic-layout-effect-1.1.2.tgz#497cefb13d863d687b08477d9e5a164ad8c1a6fb" + integrity sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA== + use-sidecar@^1.1.2: version "1.1.2" resolved "https://registry.npmmirror.com/use-sidecar/-/use-sidecar-1.1.2.tgz#2f43126ba2d7d7e117aa5855e5d8f0276dfe73c2" @@ -12410,6 +12711,11 @@ ws@^7.4.6: resolved "https://registry.npmjs.org/ws/-/ws-7.5.5.tgz" integrity sha512-BAkMFcAzl8as1G/hArkxOxq3G7pjUqQ3gzYbLL0/5zNkph70e+lCoxBGnm6AW1+/aiNeV4fnKqZ8m4GZewmH2w== +ws@~8.2.3: + version "8.2.3" + resolved "https://registry.npmmirror.com/ws/-/ws-8.2.3.tgz#63a56456db1b04367d0b721a0b80cae6d8becbba" + integrity sha512-wBuoj1BDpC6ZQ1B7DWQBYVLphPWkm8i9Y0/3YdHjHKHiohOJ1ws+3OccDWtH+PoC9DZD5WOTrJvNbWvjS6JWaA== + xml-name-validator@^3.0.0: version "3.0.0" resolved "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-3.0.0.tgz" @@ -12420,6 +12726,11 @@ xmlchars@^2.2.0: resolved "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz" integrity sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw== +xmlhttprequest-ssl@~2.0.0: + version "2.0.0" + resolved "https://registry.npmmirror.com/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz#91360c86b914e67f44dce769180027c0da618c67" + integrity 
sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A== + xtend@^4.0.0, xtend@~4.0.1: version "4.0.2" resolved "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz" diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 1cb8b57..80252f3 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -14,17 +14,17 @@ class InpaintModel: pad_mod = 8 pad_to_square = False - def __init__(self, device): + def __init__(self, device, **kwargs): """ Args: device: """ self.device = device - self.init_model(device) + self.init_model(device, **kwargs) @abc.abstractmethod - def init_model(self, device): + def init_model(self, device, **kwargs): ... @staticmethod @@ -36,15 +36,19 @@ class InpaintModel: def forward(self, image, mask, config: Config): """Input images and output images have same size images: [H, W, C] RGB - masks: [H, W] 255 为 masks 区域 + masks: [H, W, 1] 255 为 masks 区域 return: BGR IMAGE """ ... def _pad_forward(self, image, mask, config: Config): origin_height, origin_width = image.shape[:2] - pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size) - pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size) + pad_image = pad_img_to_modulo( + image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) + pad_mask = pad_img_to_modulo( + mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size + ) logger.info(f"final forward pad size: {pad_image.shape}") @@ -81,18 +85,30 @@ class InpaintModel: elif config.hd_strategy == HDStrategy.RESIZE: if max(image.shape) > config.hd_strategy_resize_limit: origin_size = image.shape[:2] - downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit) - downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit) + downsize_image = resize_max_size( + image, 
size_limit=config.hd_strategy_resize_limit + ) + downsize_mask = resize_max_size( + mask, size_limit=config.hd_strategy_resize_limit + ) - logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}") - inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) + logger.info( + f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}" + ) + inpaint_result = self._pad_forward( + downsize_image, downsize_mask, config + ) # only paste masked area result - inpaint_result = cv2.resize(inpaint_result, - (origin_size[1], origin_size[0]), - interpolation=cv2.INTER_CUBIC) + inpaint_result = cv2.resize( + inpaint_result, + (origin_size[1], origin_size[0]), + interpolation=cv2.INTER_CUBIC, + ) original_pixel_indices = mask < 127 - inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices] + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] if inpaint_result is None: inpaint_result = self._pad_forward(image, mask, config) @@ -133,11 +149,11 @@ class InpaintModel: if _l < 0: r += abs(_l) if _r > img_w: - l -= (_r - img_w) + l -= _r - img_w if _t < 0: b += abs(_t) if _b > img_h: - t -= (_b - img_h) + t -= _b - img_h l = max(l, 0) r = min(r, img_w) diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index b030e0d..9e9e8c0 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -1135,7 +1135,7 @@ class FcF(InpaintModel): pad_mod = 512 pad_to_square = True - def init_model(self, device): + def init_model(self, device, **kwargs): seed = 0 random.seed(seed) np.random.seed(seed) diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index eba3860..b414f1d 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -18,16 +18,7 @@ LAMA_MODEL_URL = os.environ.get( class LaMa(InpaintModel): pad_mod = 8 - def __init__(self, device): - """ - - Args: - device: - """ - 
super().__init__(device) - self.device = device - - def init_model(self, device): + def init_model(self, device, **kwargs): if os.environ.get("LAMA_MODEL"): model_path = os.environ.get("LAMA_MODEL") if not os.path.exists(model_path): diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 5364cf2..a45a7e9 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -227,7 +227,7 @@ class LDM(InpaintModel): super().__init__(device) self.device = device - def init_model(self, device): + def init_model(self, device, **kwargs): self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device) self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device) self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device) diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index dbd634f..67020bc 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -1405,7 +1405,7 @@ class MAT(InpaintModel): pad_mod = 512 pad_to_square = True - def init_model(self, device): + def init_model(self, device, **kwargs): seed = 240 # pick up a random number random.seed(seed) np.random.seed(seed) diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py new file mode 100644 index 0000000..db418b4 --- /dev/null +++ b/lama_cleaner/model/sd.py @@ -0,0 +1,152 @@ +import random + +import PIL.Image +import cv2 +import numpy as np +import torch +from loguru import logger + +from lama_cleaner.helper import norm_img + +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + + +# +# +# def preprocess_image(image): +# w, h = image.size +# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 +# image = image.resize((w, h), resample=PIL.Image.LANCZOS) +# image = np.array(image).astype(np.float32) / 255.0 +# image = image[None].transpose(0, 3, 1, 2) +# image = torch.from_numpy(image) +# # [-1, 1] +# return 2.0 * image - 1.0 +# +# +# def 
preprocess_mask(mask): +# mask = mask.convert("L") +# w, h = mask.size +# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 +# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) +# mask = np.array(mask).astype(np.float32) / 255.0 +# mask = np.tile(mask, (4, 1, 1)) +# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? +# mask = 1 - mask # repaint white, keep black +# mask = torch.from_numpy(mask) +# return mask + + +class SD(InpaintModel): + pad_mod = 32 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + # return + from .sd_pipeline import StableDiffusionInpaintPipeline + + self.model = StableDiffusionInpaintPipeline.from_pretrained( + self.model_id_or_path, + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=kwargs["hf_access_token"], + ) + # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + self.model = self.model.to(device) + self.callbacks = kwargs.pop("callbacks", None) + + @torch.cuda.amp.autocast() + def forward(self, image, mask, config: Config): + # return image + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + + # image = norm_img(image) # [0, 1] + # image = image * 2 - 1 # [0, 1] -> [-1, 1] + + # resize to latent feature map size + # h, w = mask.shape[:2] + # mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA) + # mask = norm_img(mask) + # + # image = torch.from_numpy(image).unsqueeze(0).to(self.device) + # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) + # import time + # time.sleep(2) + # return image + seed = config.sd_seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + output = self.model( + prompt=config.prompt, + 
init_image=PIL.Image.fromarray(image), + mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), + strength=config.sd_strength, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np.array", + callbacks=self.callbacks, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + @torch.no_grad() + def __call__(self, image, mask, config: Config): + """ + images: [H, W, C] RGB, not normalized + masks: [H, W] + return: BGR IMAGE + """ + img_h, img_w = image.shape[:2] + + # boxes = boxes_from_mask(mask) + if config.use_croper: + logger.info("use croper") + l, t, w, h = ( + config.croper_x, + config.croper_y, + config.croper_width, + config.croper_height, + ) + r = l + w + b = t + h + + l = max(l, 0) + r = min(r, img_w) + t = max(t, 0) + b = min(b, img_h) + + crop_img = image[t:b, l:r, :] + crop_mask = mask[t:b, l:r] + + crop_image = self._pad_forward(crop_img, crop_mask, config) + + inpaint_result = image[:, :, ::-1] + inpaint_result[t:b, l:r, :] = crop_image + else: + inpaint_result = self._pad_forward(image, mask, config) + + return inpaint_result + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True + + +class SD14(SD): + model_id_or_path = "CompVis/stable-diffusion-v1-4" + + +class SD15(SD): + model_id_or_path = "CompVis/stable-diffusion-v1-5" diff --git a/lama_cleaner/model/sd_pipeline.py b/lama_cleaner/model/sd_pipeline.py new file mode 100644 index 0000000..6616406 --- /dev/null +++ b/lama_cleaner/model/sd_pipeline.py @@ -0,0 +1,309 @@ +import inspect +from typing import List, Optional, Union, Callable + +import numpy as np +import torch + +import PIL +from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, 
StableDiffusionPipelineOutput +from diffusers.utils import logging +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +logger = logging.get_logger(__name__) + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. 
+ + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `set_attention_slice` + self.enable_attention_slice(None) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callbacks: List[Callable[[int], None]] = None + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. 
The mask image will be + converted to a single channel (luminance) before use. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + offset = 0 + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # preprocess image + init_image = preprocess_image(init_image).to(self.device) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + init_latents_orig = init_latents + + # preprocess mask + mask = preprocess_mask(mask_image).to(self.device) + mask = torch.cat([mask] * batch_size) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the 
original timestep using init_timestep + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + if callbacks is not None: + for callback in callbacks: + callback(i) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git 
a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index f3d3483..db59c63 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -216,7 +216,7 @@ class ZITS(InpaintModel): self.device = device self.sample_edge_line_iterations = 1 - def init_model(self, device): + def init_model(self, device, **kwargs): self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device) self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device) self.structure_upsample = load_jit_model( diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 01435f7..1419183 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -2,27 +2,23 @@ from lama_cleaner.model.fcf import FcF from lama_cleaner.model.lama import LaMa from lama_cleaner.model.ldm import LDM from lama_cleaner.model.mat import MAT +from lama_cleaner.model.sd import SD14 from lama_cleaner.model.zits import ZITS from lama_cleaner.schema import Config -models = { - 'lama': LaMa, - 'ldm': LDM, - 'zits': ZITS, - 'mat': MAT, - 'fcf': FcF -} +models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14} class ModelManager: - def __init__(self, name: str, device): + def __init__(self, name: str, device, **kwargs): self.name = name self.device = device - self.model = self.init_model(name, device) + self.kwargs = kwargs + self.model = self.init_model(name, device, **kwargs) - def init_model(self, name: str, device): + def init_model(self, name: str, device, **kwargs): if name in models: - model = models[name](device) + model = models[name](device, **kwargs) else: raise NotImplementedError(f"Not supported model: {name}") return model @@ -40,7 +36,7 @@ class ModelManager: if new_name == self.name: return try: - self.model = self.init_model(new_name, self.device) + self.model = self.init_model(new_name, self.device, **self.kwargs) self.name = new_name except NotImplementedError as e: raise e diff --git a/lama_cleaner/parse_args.py 
b/lama_cleaner/parse_args.py index 6481f82..3b0218d 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -7,7 +7,16 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", default=8080, type=int) - parser.add_argument("--model", default="lama", choices=["lama", "ldm", "zits", "mat", 'fcf']) + parser.add_argument( + "--model", + default="lama", + choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.4"], + ) + parser.add_argument( + "--hf_access_token", + default="", + help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens", + ) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument( @@ -29,4 +38,10 @@ def parse_args(): if imghdr.what(args.input) is None: parser.error(f"invalid --input: {args.input} is not a valid image file") + if args.model.startswith("sd"): + if not args.hf_access_token.startswith("hf_"): + parser.error( + f"sd(stable-diffusion) model requires huggingface access token. 
Check how to get token from: https://huggingface.co/docs/hub/security-tokens" + ) + return args diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index dffd984..1f6e3ef 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -22,3 +22,19 @@ class Config(BaseModel): hd_strategy_crop_margin: int hd_strategy_crop_trigger_size: int hd_strategy_resize_limit: int + + prompt: str = '' + # 始终是在原图尺度上的值 + use_croper: bool = False + croper_x: int = None + croper_y: int = None + croper_height: int = None + croper_width: int = None + + # sd + sd_strength: float = 0.75 + sd_steps: int = 50 + sd_guidance_scale: float = 7.5 + sd_sampler: str = 'ddim' # ddim/pndm + # -1 mean random seed + sd_seed: int = 42 diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 5b44592..0c7a404 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -4,6 +4,7 @@ import io import logging import multiprocessing import os +import random import time import imghdr from pathlib import Path @@ -41,7 +42,7 @@ from lama_cleaner.helper import ( NUM_THREADS = str(multiprocessing.cpu_count()) # fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 -os.environ["KMP_DUPLICATE_LIB_OK"]="True" +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" os.environ["OMP_NUM_THREADS"] = NUM_THREADS os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS @@ -64,6 +65,10 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app.config["JSON_AS_ASCII"] = False CORS(app, expose_headers=["Content-Disposition"]) +# MAX_BUFFER_SIZE = 50 * 1000 * 1000 # 50 MB +# async_mode 优先级: eventlet/gevent_uwsgi/gevent/threading +# only threading works on macOS +# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading') model: ModelManager = None device = None @@ -77,6 +82,11 @@ def get_image_ext(img_bytes): return w +def diffuser_callback(step: int): + pass + # 
socketio.emit('diffusion_step', {'diffusion_step': step}) + + @app.route("/inpaint", methods=["POST"]) def process(): input = request.files @@ -102,8 +112,24 @@ def process(): hd_strategy_crop_margin=form["hdStrategyCropMargin"], hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"], hd_strategy_resize_limit=form["hdStrategyResizeLimit"], + + prompt=form['prompt'], + use_croper=form['useCroper'], + croper_x=form['croperX'], + croper_y=form['croperY'], + croper_height=form['croperHeight'], + croper_width=form['croperWidth'], + + sd_strength=form["sdStrength"], + sd_steps=form["sdSteps"], + sd_guidance_scale=form["sdGuidanceScale"], + sd_sampler=form["sdSampler"], + sd_seed=form["sdSeed"], ) + if config.sd_seed == -1: + config.sd_seed = random.randint(1, 9999999) + logger.info(f"Origin image shape: {original_shape}") image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) logger.info(f"Resized image shape: {image.shape}") @@ -184,7 +210,8 @@ def main(args): device = torch.device(args.device) input_image_path = args.input - model = ModelManager(name=args.model, device=device) + model = ModelManager(name=args.model, device=device, hf_access_token=args.hf_access_token, + callbacks=[diffuser_callback]) if args.gui: app_width, app_height = args.gui_size @@ -195,4 +222,5 @@ def main(args): ) ui.run() else: + # TODO: socketio app.run(host=args.host, port=args.port, debug=args.debug) diff --git a/lama_cleaner/tests/overture-creations-5sI6fQgYIuo.png b/lama_cleaner/tests/overture-creations-5sI6fQgYIuo.png new file mode 100644 index 0000000..e84dfc8 Binary files /dev/null and b/lama_cleaner/tests/overture-creations-5sI6fQgYIuo.png differ diff --git a/lama_cleaner/tests/overture-creations-5sI6fQgYIuo_mask.png b/lama_cleaner/tests/overture-creations-5sI6fQgYIuo_mask.png new file mode 100644 index 0000000..7f3c753 Binary files /dev/null and b/lama_cleaner/tests/overture-creations-5sI6fQgYIuo_mask.png differ diff --git 
a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index 8ade0c2..0c292d1 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import cv2 @@ -11,10 +12,10 @@ current_dir = Path(__file__).parent.absolute().resolve() device = 'cuda' if torch.cuda.is_available() else 'cpu' -def get_data(fx=1, fy=1.0): - img = cv2.imread(str(current_dir / "image.png")) +def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"): + img = cv2.imread(str(img_p)) img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) - mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE) + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) if fx != 1: img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) @@ -35,8 +36,8 @@ def get_config(strategy, **kwargs): return Config(**data) -def assert_equal(model, config, gt_name, fx=1, fy=1): - img, mask = get_data(fx=fx, fy=fy) +def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"): + img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) res = model(img, mask, config) cv2.imwrite( str(current_dir / gt_name), @@ -153,3 +154,26 @@ def test_fcf(strategy): fx=3.8, fy=2 ) + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_sd(strategy, capfd): + def callback(step: int): + print(f"sd_step_{step}") + + sd_steps = 2 + model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback]) + cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps) + + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}_result.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=0.5, + fy=0.5 + ) + + captured = capfd.readouterr() + for i in range(sd_steps): + assert 
f'sd_step_{i}' in captured.out diff --git a/requirements.txt b/requirements.txt index c635120..b029bb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ pytest yacs markupsafe==2.0.1 scikit-image==0.19.3 +diffusers==0.3.0 +transformers==4.20.0