diff --git a/lama_cleaner/app/package.json b/lama_cleaner/app/package.json index 94ccd95..840714e 100644 --- a/lama_cleaner/app/package.json +++ b/lama_cleaner/app/package.json @@ -4,7 +4,7 @@ "private": true, "proxy": "http://localhost:8080", "dependencies": { - "@heroicons/react": "^1.0.4", + "@heroicons/react": "^2.0.0", "@radix-ui/react-dialog": "0.1.8-rc.25", "@radix-ui/react-icons": "^1.1.1", "@radix-ui/react-popover": "^1.0.0", diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index 11b9ebe..bb7c292 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -114,3 +114,31 @@ export function modelDownloaded(name: string) { method: 'GET', }) } + +export async function postInteractiveSeg( + imageFile: File, + maskFile: File | null, + clicks: number[][] +) { + const fd = new FormData() + fd.append('image', imageFile) + fd.append('clicks', JSON.stringify(clicks)) + if (maskFile !== null) { + fd.append('mask', maskFile) + } + + try { + const res = await fetch(`${API_ENDPOINT}/interactive_seg`, { + method: 'POST', + body: fd, + }) + if (res.ok) { + const blob = await res.blob() + return { blob: URL.createObjectURL(blob) } + } + const errMsg = await res.text() + throw new Error(errMsg) + } catch (error) { + throw new Error(`Something went wrong: ${error}`) + } +} diff --git a/lama_cleaner/app/src/components/Croper/Croper.tsx b/lama_cleaner/app/src/components/Croper/Croper.tsx index c764e63..92d61f2 100644 --- a/lama_cleaner/app/src/components/Croper/Croper.tsx +++ b/lama_cleaner/app/src/components/Croper/Croper.tsx @@ -1,4 +1,3 @@ -import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/outline' import React, { useEffect, useState } from 'react' import { useRecoilState, useRecoilValue } from 'recoil' import { diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx index 003313b..68b6521 100644 --- a/lama_cleaner/app/src/components/Editor/Editor.tsx +++ b/lama_cleaner/app/src/components/Editor/Editor.tsx @@ -1,8 +1,3 @@ -import { - ArrowsExpandIcon, - DownloadIcon, - EyeIcon, -} from '@heroicons/react/outline' import React, { SyntheticEvent, useCallback, @@ -10,6 +5,12 @@ import React, { useRef, useState, } from 'react' +import { + CursorArrowRaysIcon, + EyeIcon, + ArrowsPointingOutIcon, + ArrowDownTrayIcon, +} from '@heroicons/react/24/outline' import { ReactZoomPanPinchRef, TransformComponent, @@ -17,7 +18,7 @@ import { } from 'react-zoom-pan-pinch' import { useRecoilState, useRecoilValue } from 'recoil' import { useWindowSize, useKey, useKeyPressEvent } from 'react-use' -import inpaint from '../../adapters/inpainting' +import inpaint, { postInteractiveSeg } from '../../adapters/inpainting' import Button from '../shared/Button' import Slider from './Slider' import SizeSelector from './SizeSelector' @@ -34,7 +35,10 @@ import { import { croperState, fileState, + interactiveSegClicksState, isInpaintingState, + isInteractiveSegRunningState, + isInteractiveSegState, isSDState, negativePropmtState, propmtState, @@ -51,6 +55,8 @@ import emitter, { CustomMaskEventData, } from '../../event' import FileSelect from '../FileSelect/FileSelect' +import InteractiveSeg from '../InteractiveSeg/InteractiveSeg' +import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions' const TOOLBAR_SIZE = 200 const MIN_BRUSH_SIZE = 10 @@ -101,6 +107,20 @@ export default function Editor() { const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState) const runMannually = useRecoilValue(runManuallyState) const isSD = useRecoilValue(isSDState) + const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState( + isInteractiveSegState + ) + const [isInteractiveSegRunning, setIsInteractiveSegRunning] = useRecoilState( + isInteractiveSegRunningState + ) + + const [interactiveSegMask, setInteractiveSegMask] = + useState(null) + // only used while interactive segmentation is on + const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = + useState(null) + + const [clicks, setClicks] = useRecoilState(interactiveSegClicksState) const [brushSize, setBrushSize] = useState(40) const [original, isOriginalLoaded] = useImage(file) @@ -159,13 +179,37 @@ export default function Editor() { original.naturalWidth, original.naturalHeight ) + if (isInteractiveSeg && tmpInteractiveSegMask !== null) { + context.drawImage( + tmpInteractiveSegMask, + 0, + 0, + original.naturalWidth, + original.naturalHeight + ) + } + if (!isInteractiveSeg && interactiveSegMask !== null) { + context.drawImage( + interactiveSegMask, + 0, + 0, + original.naturalWidth, + original.naturalHeight + ) + } drawLines(context, lineGroup) }, - [context, original] + [ + context, + original, + isInteractiveSeg, + tmpInteractiveSegMask, + interactiveSegMask, + ] ) const drawLinesOnMask = useCallback( - (_lineGroups: LineGroup[]) => { + (_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => { if (!context?.canvas.width || !context?.canvas.height) { throw new Error('canvas has invalid size') } @@ -176,6 +220,17 @@ export default function Editor() { throw new Error('could not retrieve mask canvas') } + if (maskImage !== undefined && maskImage !== null) { + // TODO: check whether draw yellow mask works on backend + ctx.drawImage( + maskImage, + 0, + 0, + original.naturalWidth, + original.naturalHeight + ) + } + _lineGroups.forEach(lineGroup => { drawLines(ctx, lineGroup, 'white') }) @@ -199,15 +254,24 @@ export default function Editor() { ) const runInpainting = useCallback( - async (useLastLineGroup?: boolean, customMask?: File) => { + async ( + useLastLineGroup?: boolean, + customMask?: File, + maskImage?: HTMLImageElement | null + ) => { if (file === undefined) { return } const useCustomMask = customMask !== undefined + const useMaskImage = maskImage !== undefined && maskImage !== null // useLastLineGroup 的影响 // 1. 使用上一次的 mask // 2. 结果替换当前 render console.log('runInpainting') + console.log({ + useCustomMask, + useMaskImage, + }) let maskLineGroup: LineGroup = [] if (useLastLineGroup === true) { @@ -216,7 +280,7 @@ export default function Editor() { } maskLineGroup = lastLineGroup } else if (!useCustomMask) { - if (!hadDrawSomething()) { + if (!hadDrawSomething() && !useMaskImage) { return } @@ -230,7 +294,7 @@ export default function Editor() { setIsDraging(false) setIsInpainting(true) if (settings.graduallyInpainting) { - drawLinesOnMask([maskLineGroup]) + drawLinesOnMask([maskLineGroup], maskImage) } else { drawLinesOnMask(newLineGroups) } @@ -309,6 +373,8 @@ export default function Editor() { drawOnCurrentRender([]) } setIsInpainting(false) + setTmpInteractiveSegMask(null) + setInteractiveSegMask(null) }, [ lineGroups, @@ -493,10 +559,22 @@ export default function Editor() { } }, []) + const onInteractiveCancel = useCallback(() => { + setIsInteractiveSeg(false) + setIsInteractiveSegRunning(false) + setClicks([]) + setTmpInteractiveSegMask(null) + }, []) + const handleEscPressed = () => { if (isInpainting) { return } + + if (isInteractiveSeg) { + onInteractiveCancel() + } + if (isDraging || isMultiStrokeKeyPressed) { setIsDraging(false) setCurLineGroup([]) @@ -516,6 +594,8 @@ export default function Editor() { isDraging, isInpainting, isMultiStrokeKeyPressed, + isInteractiveSeg, + onInteractiveCancel, resetZoom, drawOnCurrentRender, ] @@ -536,6 +616,9 @@ export default function Editor() { } return } + if (isInteractiveSeg) { + return + } if (isPanning) { return } @@ -551,10 +634,58 @@ export default function Editor() { drawOnCurrentRender(lineGroup) } + const runInteractiveSeg = async (newClicks: number[][]) => { + if (!file) { + return + } + + setIsInteractiveSegRunning(true) + + let targetFile = file + if (renders.length > 0) { + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type) + } + + const prevMask = null + // prev_mask seems to be not working better + // if (tmpInteractiveSegMask !== null) { + // prevMask = await srcToFile( + // tmpInteractiveSegMask.currentSrc, + // 'prev_mask.jpg', + // 'image/jpeg' + // ) + // } + + try { + const res = await postInteractiveSeg(targetFile, prevMask, newClicks) + if (!res) { + throw new Error('Something went wrong on server side.') + } + const { blob } = res + const img = new Image() + img.onload = () => { + setTmpInteractiveSegMask(img) + } + img.src = blob + } catch (e: any) { + setToastState({ + open: true, + desc: e.message ? e.message : e.toString(), + state: 'error', + duration: 4000, + }) + } + setIsInteractiveSegRunning(false) + } + const onPointerUp = (ev: SyntheticEvent) => { if (isMidClick(ev)) { setIsPanning(false) } + if (isInteractiveSeg) { + return + } if (isPanning) { return @@ -601,7 +732,24 @@ export default function Editor() { return false } + const onCanvasMouseUp = (ev: SyntheticEvent) => { + if (isInteractiveSeg) { + const xy = mouseXY(ev) + const newClicks: number[][] = [...clicks] + if (isRightClick(ev)) { + newClicks.push([xy.x, xy.y, 0, newClicks.length]) + } else { + newClicks.push([xy.x, xy.y, 1, newClicks.length]) + } + runInteractiveSeg(newClicks) + setClicks(newClicks) + } + } + const onMouseDown = (ev: SyntheticEvent) => { + if (isInteractiveSeg) { + return + } if (isChangingBrushSizeByMouse) { return } @@ -714,6 +862,9 @@ export default function Editor() { useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD]) const disableUndo = () => { + if (isInteractiveSeg) { + return true + } if (isInpainting) { return true } @@ -790,6 +941,9 @@ export default function Editor() { useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD]) const disableRedo = () => { + if (isInteractiveSeg) { + return true + } if (isInpainting) { return true } @@ -877,6 +1031,16 @@ export default function Editor() { return undefined }, [showBrush, isPanning]) + useHotKey( + 'i', + () => { + if (!isInteractiveSeg) { + setIsInteractiveSeg(true) + } + }, + isInteractiveSeg + ) + // Standard Hotkeys for Brush Size useHotKey('[', () => { setBrushSize(currentBrushSize => { @@ -1002,6 +1166,20 @@ export default function Editor() { ) } + const renderInteractiveSegCursor = () => { + return ( +
+ +
+ ) + } + const renderCanvas = () => { return ( toggleShowBrush(true)} onMouseLeave={() => toggleShowBrush(false)} onMouseDown={onMouseDown} + onMouseUp={onCanvasMouseUp} onMouseMove={onMouseDrag} ref={r => { if (r && !context) { @@ -1101,11 +1284,22 @@ export default function Editor() { ) : ( <> )} + + {isInteractiveSeg ? : <>} ) } + const onInteractiveAccept = () => { + setInteractiveSegMask(tmpInteractiveSegMask) + setTmpInteractiveSegMask(null) + + if (!runMannually && tmpInteractiveSegMask) { + runInpainting(false, undefined, tmpInteractiveSegMask) + } + } + return (
+ {file === undefined ? renderFileSelect() : renderCanvas()} - {showBrush && !isInpainting && !isPanning && ( -
- )} + {showBrush && + !isInpainting && + !isPanning && + (isInteractiveSeg ? ( + renderInteractiveSegCursor() + ) : ( +
+ ))} {showRefBrush && (
setShowRefBrush(false)} />
+ + +
+
+ ) +} + +export default InteractiveSegConfirmActions diff --git a/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.scss b/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.scss new file mode 100644 index 0000000..f7996a7 --- /dev/null +++ b/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.scss @@ -0,0 +1,69 @@ +$positiveBackgroundColor: rgba(21, 215, 121, 0.936); +$positiveOutline: 6px solid rgba(98, 255, 179, 0.31); + +$negativeBackgroundColor: rgba(237, 49, 55, 0.942); +$negativeOutline: 6px solid rgba(255, 89, 95, 0.31); + +.interactive-seg-wrapper { + position: absolute; + height: 100%; + width: 100%; + z-index: 2; + overflow: hidden; + pointer-events: none; + + .click-item { + position: absolute; + height: 8px; + width: 8px; + border-radius: 50%; + } + .click-item-positive { + background-color: $positiveBackgroundColor; + outline: $positiveOutline; + } + + .click-item-negative { + background-color: $negativeBackgroundColor; + outline: $negativeOutline; + } +} + +.interactive-seg-confirm-actions { + position: absolute; + top: 68px; + z-index: 5; + background-color: var(--page-bg); + border-radius: 16px; + border-style: solid; + border-color: var(--border-color); + border-width: 1px; + padding: 8px; + + .action-buttons { + display: flex; + justify-content: center; + align-items: center; + gap: 8px; + } +} + +@keyframes pulse { + to { + box-shadow: 0 0 0 14px rgba(21, 215, 121, 0); + } +} + +.interactive-seg-cursor { + position: absolute; + height: 20px; + width: 20px; + pointer-events: none; + color: hsla(137, 100%, 95.8%, 0.98); + border-radius: 50%; + background-color: $positiveBackgroundColor; + // outline: $positiveOutline; + transform: 'translate(-50%, -50%)'; + box-shadow: 0 0 0 0 rgba(21, 215, 121, 0.936); + animation: pulse 1.5s infinite cubic-bezier(0.66, 0, 0, 1); +} diff --git a/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.tsx b/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.tsx new file mode 100644 index 0000000..54a7ff5 --- /dev/null +++ b/lama_cleaner/app/src/components/InteractiveSeg/InteractiveSeg.tsx @@ -0,0 +1,37 @@ +import { nanoid } from 'nanoid' +import React, { useEffect, useState } from 'react' +import { useRecoilValue } from 'recoil' +import { interactiveSegClicksState } from '../../store/Atoms' + +interface ItemProps { + x: number + y: number + positive: boolean +} + +const Item = (props: ItemProps) => { + const { x, y, positive } = props + const name = positive ? 'click-item-positive' : 'click-item-negative' + return
+} + +const InteractiveSeg = () => { + const clicks = useRecoilValue(interactiveSegClicksState) + + return ( +
+ {clicks.map(click => { + return ( + + ) + })} +
+ ) +} + +export default InteractiveSeg diff --git a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx index d23329e..f38e3a7 100644 --- a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx +++ b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx @@ -56,6 +56,7 @@ export default function ShortcutsModal() { /> + diff --git a/lama_cleaner/app/src/components/shared/Modal.tsx b/lama_cleaner/app/src/components/shared/Modal.tsx index c523b6b..0489797 100644 --- a/lama_cleaner/app/src/components/shared/Modal.tsx +++ b/lama_cleaner/app/src/components/shared/Modal.tsx @@ -1,4 +1,4 @@ -import { XIcon } from '@heroicons/react/outline' +import { XMarkIcon } from '@heroicons/react/24/outline' import React, { ReactNode } from 'react' import { useRecoilState } from 'recoil' import * as DialogPrimitive from '@radix-ui/react-dialog' @@ -41,7 +41,7 @@ const Modal = React.forwardRef<
{title} {showCloseIcon ? ( -