wip: add interactive seg model
This commit is contained in:
parent
af87cca643
commit
023306ae40
@ -4,7 +4,7 @@
|
|||||||
"private": true,
|
"private": true,
|
||||||
"proxy": "http://localhost:8080",
|
"proxy": "http://localhost:8080",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@heroicons/react": "^1.0.4",
|
"@heroicons/react": "^2.0.0",
|
||||||
"@radix-ui/react-dialog": "0.1.8-rc.25",
|
"@radix-ui/react-dialog": "0.1.8-rc.25",
|
||||||
"@radix-ui/react-icons": "^1.1.1",
|
"@radix-ui/react-icons": "^1.1.1",
|
||||||
"@radix-ui/react-popover": "^1.0.0",
|
"@radix-ui/react-popover": "^1.0.0",
|
||||||
|
@ -114,3 +114,31 @@ export function modelDownloaded(name: string) {
|
|||||||
method: 'GET',
|
method: 'GET',
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function postInteractiveSeg(
|
||||||
|
imageFile: File,
|
||||||
|
maskFile: File | null,
|
||||||
|
clicks: number[][]
|
||||||
|
) {
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('image', imageFile)
|
||||||
|
fd.append('clicks', JSON.stringify(clicks))
|
||||||
|
if (maskFile !== null) {
|
||||||
|
fd.append('mask', maskFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_ENDPOINT}/interactive_seg`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: fd,
|
||||||
|
})
|
||||||
|
if (res.ok) {
|
||||||
|
const blob = await res.blob()
|
||||||
|
return { blob: URL.createObjectURL(blob) }
|
||||||
|
}
|
||||||
|
const errMsg = await res.text()
|
||||||
|
throw new Error(errMsg)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/outline'
|
|
||||||
import React, { useEffect, useState } from 'react'
|
import React, { useEffect, useState } from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import {
|
import {
|
||||||
|
@ -1,8 +1,3 @@
|
|||||||
import {
|
|
||||||
ArrowsExpandIcon,
|
|
||||||
DownloadIcon,
|
|
||||||
EyeIcon,
|
|
||||||
} from '@heroicons/react/outline'
|
|
||||||
import React, {
|
import React, {
|
||||||
SyntheticEvent,
|
SyntheticEvent,
|
||||||
useCallback,
|
useCallback,
|
||||||
@ -10,6 +5,12 @@ import React, {
|
|||||||
useRef,
|
useRef,
|
||||||
useState,
|
useState,
|
||||||
} from 'react'
|
} from 'react'
|
||||||
|
import {
|
||||||
|
CursorArrowRaysIcon,
|
||||||
|
EyeIcon,
|
||||||
|
ArrowsPointingOutIcon,
|
||||||
|
ArrowDownTrayIcon,
|
||||||
|
} from '@heroicons/react/24/outline'
|
||||||
import {
|
import {
|
||||||
ReactZoomPanPinchRef,
|
ReactZoomPanPinchRef,
|
||||||
TransformComponent,
|
TransformComponent,
|
||||||
@ -17,7 +18,7 @@ import {
|
|||||||
} from 'react-zoom-pan-pinch'
|
} from 'react-zoom-pan-pinch'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||||
import inpaint from '../../adapters/inpainting'
|
import inpaint, { postInteractiveSeg } from '../../adapters/inpainting'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import Slider from './Slider'
|
import Slider from './Slider'
|
||||||
import SizeSelector from './SizeSelector'
|
import SizeSelector from './SizeSelector'
|
||||||
@ -34,7 +35,10 @@ import {
|
|||||||
import {
|
import {
|
||||||
croperState,
|
croperState,
|
||||||
fileState,
|
fileState,
|
||||||
|
interactiveSegClicksState,
|
||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
|
isInteractiveSegRunningState,
|
||||||
|
isInteractiveSegState,
|
||||||
isSDState,
|
isSDState,
|
||||||
negativePropmtState,
|
negativePropmtState,
|
||||||
propmtState,
|
propmtState,
|
||||||
@ -51,6 +55,8 @@ import emitter, {
|
|||||||
CustomMaskEventData,
|
CustomMaskEventData,
|
||||||
} from '../../event'
|
} from '../../event'
|
||||||
import FileSelect from '../FileSelect/FileSelect'
|
import FileSelect from '../FileSelect/FileSelect'
|
||||||
|
import InteractiveSeg from '../InteractiveSeg/InteractiveSeg'
|
||||||
|
import InteractiveSegConfirmActions from '../InteractiveSeg/ConfirmActions'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
const MIN_BRUSH_SIZE = 10
|
const MIN_BRUSH_SIZE = 10
|
||||||
@ -101,6 +107,20 @@ export default function Editor() {
|
|||||||
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
|
||||||
const runMannually = useRecoilValue(runManuallyState)
|
const runMannually = useRecoilValue(runManuallyState)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
|
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
||||||
|
isInteractiveSegState
|
||||||
|
)
|
||||||
|
const [isInteractiveSegRunning, setIsInteractiveSegRunning] = useRecoilState(
|
||||||
|
isInteractiveSegRunningState
|
||||||
|
)
|
||||||
|
|
||||||
|
const [interactiveSegMask, setInteractiveSegMask] =
|
||||||
|
useState<HTMLImageElement | null>(null)
|
||||||
|
// only used while interactive segmentation is on
|
||||||
|
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] =
|
||||||
|
useState<HTMLImageElement | null>(null)
|
||||||
|
|
||||||
|
const [clicks, setClicks] = useRecoilState(interactiveSegClicksState)
|
||||||
|
|
||||||
const [brushSize, setBrushSize] = useState(40)
|
const [brushSize, setBrushSize] = useState(40)
|
||||||
const [original, isOriginalLoaded] = useImage(file)
|
const [original, isOriginalLoaded] = useImage(file)
|
||||||
@ -159,13 +179,37 @@ export default function Editor() {
|
|||||||
original.naturalWidth,
|
original.naturalWidth,
|
||||||
original.naturalHeight
|
original.naturalHeight
|
||||||
)
|
)
|
||||||
|
if (isInteractiveSeg && tmpInteractiveSegMask !== null) {
|
||||||
|
context.drawImage(
|
||||||
|
tmpInteractiveSegMask,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
original.naturalWidth,
|
||||||
|
original.naturalHeight
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (!isInteractiveSeg && interactiveSegMask !== null) {
|
||||||
|
context.drawImage(
|
||||||
|
interactiveSegMask,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
original.naturalWidth,
|
||||||
|
original.naturalHeight
|
||||||
|
)
|
||||||
|
}
|
||||||
drawLines(context, lineGroup)
|
drawLines(context, lineGroup)
|
||||||
},
|
},
|
||||||
[context, original]
|
[
|
||||||
|
context,
|
||||||
|
original,
|
||||||
|
isInteractiveSeg,
|
||||||
|
tmpInteractiveSegMask,
|
||||||
|
interactiveSegMask,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
const drawLinesOnMask = useCallback(
|
const drawLinesOnMask = useCallback(
|
||||||
(_lineGroups: LineGroup[]) => {
|
(_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => {
|
||||||
if (!context?.canvas.width || !context?.canvas.height) {
|
if (!context?.canvas.width || !context?.canvas.height) {
|
||||||
throw new Error('canvas has invalid size')
|
throw new Error('canvas has invalid size')
|
||||||
}
|
}
|
||||||
@ -176,6 +220,17 @@ export default function Editor() {
|
|||||||
throw new Error('could not retrieve mask canvas')
|
throw new Error('could not retrieve mask canvas')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (maskImage !== undefined && maskImage !== null) {
|
||||||
|
// TODO: check whether draw yellow mask works on backend
|
||||||
|
ctx.drawImage(
|
||||||
|
maskImage,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
original.naturalWidth,
|
||||||
|
original.naturalHeight
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
_lineGroups.forEach(lineGroup => {
|
_lineGroups.forEach(lineGroup => {
|
||||||
drawLines(ctx, lineGroup, 'white')
|
drawLines(ctx, lineGroup, 'white')
|
||||||
})
|
})
|
||||||
@ -199,15 +254,24 @@ export default function Editor() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
const runInpainting = useCallback(
|
const runInpainting = useCallback(
|
||||||
async (useLastLineGroup?: boolean, customMask?: File) => {
|
async (
|
||||||
|
useLastLineGroup?: boolean,
|
||||||
|
customMask?: File,
|
||||||
|
maskImage?: HTMLImageElement | null
|
||||||
|
) => {
|
||||||
if (file === undefined) {
|
if (file === undefined) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const useCustomMask = customMask !== undefined
|
const useCustomMask = customMask !== undefined
|
||||||
|
const useMaskImage = maskImage !== undefined && maskImage !== null
|
||||||
// useLastLineGroup 的影响
|
// useLastLineGroup 的影响
|
||||||
// 1. 使用上一次的 mask
|
// 1. 使用上一次的 mask
|
||||||
// 2. 结果替换当前 render
|
// 2. 结果替换当前 render
|
||||||
console.log('runInpainting')
|
console.log('runInpainting')
|
||||||
|
console.log({
|
||||||
|
useCustomMask,
|
||||||
|
useMaskImage,
|
||||||
|
})
|
||||||
|
|
||||||
let maskLineGroup: LineGroup = []
|
let maskLineGroup: LineGroup = []
|
||||||
if (useLastLineGroup === true) {
|
if (useLastLineGroup === true) {
|
||||||
@ -216,7 +280,7 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
maskLineGroup = lastLineGroup
|
maskLineGroup = lastLineGroup
|
||||||
} else if (!useCustomMask) {
|
} else if (!useCustomMask) {
|
||||||
if (!hadDrawSomething()) {
|
if (!hadDrawSomething() && !useMaskImage) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +294,7 @@ export default function Editor() {
|
|||||||
setIsDraging(false)
|
setIsDraging(false)
|
||||||
setIsInpainting(true)
|
setIsInpainting(true)
|
||||||
if (settings.graduallyInpainting) {
|
if (settings.graduallyInpainting) {
|
||||||
drawLinesOnMask([maskLineGroup])
|
drawLinesOnMask([maskLineGroup], maskImage)
|
||||||
} else {
|
} else {
|
||||||
drawLinesOnMask(newLineGroups)
|
drawLinesOnMask(newLineGroups)
|
||||||
}
|
}
|
||||||
@ -309,6 +373,8 @@ export default function Editor() {
|
|||||||
drawOnCurrentRender([])
|
drawOnCurrentRender([])
|
||||||
}
|
}
|
||||||
setIsInpainting(false)
|
setIsInpainting(false)
|
||||||
|
setTmpInteractiveSegMask(null)
|
||||||
|
setInteractiveSegMask(null)
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
lineGroups,
|
lineGroups,
|
||||||
@ -493,10 +559,22 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
const onInteractiveCancel = useCallback(() => {
|
||||||
|
setIsInteractiveSeg(false)
|
||||||
|
setIsInteractiveSegRunning(false)
|
||||||
|
setClicks([])
|
||||||
|
setTmpInteractiveSegMask(null)
|
||||||
|
}, [])
|
||||||
|
|
||||||
const handleEscPressed = () => {
|
const handleEscPressed = () => {
|
||||||
if (isInpainting) {
|
if (isInpainting) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
onInteractiveCancel()
|
||||||
|
}
|
||||||
|
|
||||||
if (isDraging || isMultiStrokeKeyPressed) {
|
if (isDraging || isMultiStrokeKeyPressed) {
|
||||||
setIsDraging(false)
|
setIsDraging(false)
|
||||||
setCurLineGroup([])
|
setCurLineGroup([])
|
||||||
@ -516,6 +594,8 @@ export default function Editor() {
|
|||||||
isDraging,
|
isDraging,
|
||||||
isInpainting,
|
isInpainting,
|
||||||
isMultiStrokeKeyPressed,
|
isMultiStrokeKeyPressed,
|
||||||
|
isInteractiveSeg,
|
||||||
|
onInteractiveCancel,
|
||||||
resetZoom,
|
resetZoom,
|
||||||
drawOnCurrentRender,
|
drawOnCurrentRender,
|
||||||
]
|
]
|
||||||
@ -536,6 +616,9 @@ export default function Editor() {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if (isPanning) {
|
if (isPanning) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -551,10 +634,58 @@ export default function Editor() {
|
|||||||
drawOnCurrentRender(lineGroup)
|
drawOnCurrentRender(lineGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const runInteractiveSeg = async (newClicks: number[][]) => {
|
||||||
|
if (!file) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsInteractiveSegRunning(true)
|
||||||
|
|
||||||
|
let targetFile = file
|
||||||
|
if (renders.length > 0) {
|
||||||
|
const lastRender = renders[renders.length - 1]
|
||||||
|
targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type)
|
||||||
|
}
|
||||||
|
|
||||||
|
const prevMask = null
|
||||||
|
// prev_mask seems to be not working better
|
||||||
|
// if (tmpInteractiveSegMask !== null) {
|
||||||
|
// prevMask = await srcToFile(
|
||||||
|
// tmpInteractiveSegMask.currentSrc,
|
||||||
|
// 'prev_mask.jpg',
|
||||||
|
// 'image/jpeg'
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await postInteractiveSeg(targetFile, prevMask, newClicks)
|
||||||
|
if (!res) {
|
||||||
|
throw new Error('Something went wrong on server side.')
|
||||||
|
}
|
||||||
|
const { blob } = res
|
||||||
|
const img = new Image()
|
||||||
|
img.onload = () => {
|
||||||
|
setTmpInteractiveSegMask(img)
|
||||||
|
}
|
||||||
|
img.src = blob
|
||||||
|
} catch (e: any) {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: e.message ? e.message : e.toString(),
|
||||||
|
state: 'error',
|
||||||
|
duration: 4000,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
setIsInteractiveSegRunning(false)
|
||||||
|
}
|
||||||
|
|
||||||
const onPointerUp = (ev: SyntheticEvent) => {
|
const onPointerUp = (ev: SyntheticEvent) => {
|
||||||
if (isMidClick(ev)) {
|
if (isMidClick(ev)) {
|
||||||
setIsPanning(false)
|
setIsPanning(false)
|
||||||
}
|
}
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if (isPanning) {
|
if (isPanning) {
|
||||||
return
|
return
|
||||||
@ -601,7 +732,24 @@ export default function Editor() {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onCanvasMouseUp = (ev: SyntheticEvent) => {
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
const xy = mouseXY(ev)
|
||||||
|
const newClicks: number[][] = [...clicks]
|
||||||
|
if (isRightClick(ev)) {
|
||||||
|
newClicks.push([xy.x, xy.y, 0, newClicks.length])
|
||||||
|
} else {
|
||||||
|
newClicks.push([xy.x, xy.y, 1, newClicks.length])
|
||||||
|
}
|
||||||
|
runInteractiveSeg(newClicks)
|
||||||
|
setClicks(newClicks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const onMouseDown = (ev: SyntheticEvent) => {
|
const onMouseDown = (ev: SyntheticEvent) => {
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if (isChangingBrushSizeByMouse) {
|
if (isChangingBrushSizeByMouse) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -714,6 +862,9 @@ export default function Editor() {
|
|||||||
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
|
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender, isSD])
|
||||||
|
|
||||||
const disableUndo = () => {
|
const disableUndo = () => {
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
if (isInpainting) {
|
if (isInpainting) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -790,6 +941,9 @@ export default function Editor() {
|
|||||||
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
|
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender, isSD])
|
||||||
|
|
||||||
const disableRedo = () => {
|
const disableRedo = () => {
|
||||||
|
if (isInteractiveSeg) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
if (isInpainting) {
|
if (isInpainting) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -877,6 +1031,16 @@ export default function Editor() {
|
|||||||
return undefined
|
return undefined
|
||||||
}, [showBrush, isPanning])
|
}, [showBrush, isPanning])
|
||||||
|
|
||||||
|
useHotKey(
|
||||||
|
'i',
|
||||||
|
() => {
|
||||||
|
if (!isInteractiveSeg) {
|
||||||
|
setIsInteractiveSeg(true)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
isInteractiveSeg
|
||||||
|
)
|
||||||
|
|
||||||
// Standard Hotkeys for Brush Size
|
// Standard Hotkeys for Brush Size
|
||||||
useHotKey('[', () => {
|
useHotKey('[', () => {
|
||||||
setBrushSize(currentBrushSize => {
|
setBrushSize(currentBrushSize => {
|
||||||
@ -1002,6 +1166,20 @@ export default function Editor() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderInteractiveSegCursor = () => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className="interactive-seg-cursor"
|
||||||
|
style={{
|
||||||
|
left: `${x}px`,
|
||||||
|
top: `${y}px`,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<CursorArrowRaysIcon />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderCanvas = () => {
|
const renderCanvas = () => {
|
||||||
return (
|
return (
|
||||||
<TransformWrapper
|
<TransformWrapper
|
||||||
@ -1029,7 +1207,11 @@ export default function Editor() {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<TransformComponent
|
<TransformComponent
|
||||||
contentClass={isInpainting ? 'editor-canvas-loading' : ''}
|
contentClass={
|
||||||
|
isInpainting || isInteractiveSegRunning
|
||||||
|
? 'editor-canvas-loading'
|
||||||
|
: ''
|
||||||
|
}
|
||||||
contentStyle={{
|
contentStyle={{
|
||||||
visibility: initialCentered ? 'visible' : 'hidden',
|
visibility: initialCentered ? 'visible' : 'hidden',
|
||||||
}}
|
}}
|
||||||
@ -1052,6 +1234,7 @@ export default function Editor() {
|
|||||||
onFocus={() => toggleShowBrush(true)}
|
onFocus={() => toggleShowBrush(true)}
|
||||||
onMouseLeave={() => toggleShowBrush(false)}
|
onMouseLeave={() => toggleShowBrush(false)}
|
||||||
onMouseDown={onMouseDown}
|
onMouseDown={onMouseDown}
|
||||||
|
onMouseUp={onCanvasMouseUp}
|
||||||
onMouseMove={onMouseDrag}
|
onMouseMove={onMouseDrag}
|
||||||
ref={r => {
|
ref={r => {
|
||||||
if (r && !context) {
|
if (r && !context) {
|
||||||
@ -1101,11 +1284,22 @@ export default function Editor() {
|
|||||||
) : (
|
) : (
|
||||||
<></>
|
<></>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{isInteractiveSeg ? <InteractiveSeg /> : <></>}
|
||||||
</TransformComponent>
|
</TransformComponent>
|
||||||
</TransformWrapper>
|
</TransformWrapper>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onInteractiveAccept = () => {
|
||||||
|
setInteractiveSegMask(tmpInteractiveSegMask)
|
||||||
|
setTmpInteractiveSegMask(null)
|
||||||
|
|
||||||
|
if (!runMannually && tmpInteractiveSegMask) {
|
||||||
|
runInpainting(false, undefined, tmpInteractiveSegMask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className="editor-container"
|
className="editor-container"
|
||||||
@ -1113,17 +1307,26 @@ export default function Editor() {
|
|||||||
onMouseMove={onMouseMove}
|
onMouseMove={onMouseMove}
|
||||||
onMouseUp={onPointerUp}
|
onMouseUp={onPointerUp}
|
||||||
>
|
>
|
||||||
|
<InteractiveSegConfirmActions
|
||||||
|
onAcceptClick={onInteractiveAccept}
|
||||||
|
onCancelClick={onInteractiveCancel}
|
||||||
|
/>
|
||||||
{file === undefined ? renderFileSelect() : renderCanvas()}
|
{file === undefined ? renderFileSelect() : renderCanvas()}
|
||||||
|
|
||||||
{showBrush && !isInpainting && !isPanning && (
|
{showBrush &&
|
||||||
<div
|
!isInpainting &&
|
||||||
className="brush-shape"
|
!isPanning &&
|
||||||
style={getBrushStyle(
|
(isInteractiveSeg ? (
|
||||||
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
|
renderInteractiveSegCursor()
|
||||||
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
|
) : (
|
||||||
)}
|
<div
|
||||||
/>
|
className="brush-shape"
|
||||||
)}
|
style={getBrushStyle(
|
||||||
|
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
|
||||||
|
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
|
||||||
{showRefBrush && (
|
{showRefBrush && (
|
||||||
<div
|
<div
|
||||||
@ -1151,10 +1354,17 @@ export default function Editor() {
|
|||||||
onClick={() => setShowRefBrush(false)}
|
onClick={() => setShowRefBrush(false)}
|
||||||
/>
|
/>
|
||||||
<div className="editor-toolkit-btns">
|
<div className="editor-toolkit-btns">
|
||||||
|
<Button
|
||||||
|
toolTip="Interactive Segmentation"
|
||||||
|
tooltipPosition="top"
|
||||||
|
icon={<CursorArrowRaysIcon />}
|
||||||
|
disabled={isInteractiveSeg || isInpainting}
|
||||||
|
onClick={() => setIsInteractiveSeg(true)}
|
||||||
|
/>
|
||||||
<Button
|
<Button
|
||||||
toolTip="Reset Zoom & Pan"
|
toolTip="Reset Zoom & Pan"
|
||||||
tooltipPosition="top"
|
tooltipPosition="top"
|
||||||
icon={<ArrowsExpandIcon />}
|
icon={<ArrowsPointingOutIcon />}
|
||||||
disabled={scale === minScale && panned === false}
|
disabled={scale === minScale && panned === false}
|
||||||
onClick={resetZoom}
|
onClick={resetZoom}
|
||||||
/>
|
/>
|
||||||
@ -1224,7 +1434,7 @@ export default function Editor() {
|
|||||||
<Button
|
<Button
|
||||||
toolTip="Save Image"
|
toolTip="Save Image"
|
||||||
tooltipPosition="top"
|
tooltipPosition="top"
|
||||||
icon={<DownloadIcon />}
|
icon={<ArrowDownTrayIcon />}
|
||||||
disabled={!renders.length}
|
disabled={!renders.length}
|
||||||
onClick={download}
|
onClick={download}
|
||||||
/>
|
/>
|
||||||
@ -1247,11 +1457,13 @@ export default function Editor() {
|
|||||||
/>
|
/>
|
||||||
</svg>
|
</svg>
|
||||||
}
|
}
|
||||||
disabled={!hadDrawSomething() || isInpainting}
|
disabled={
|
||||||
|
!interactiveSegMask &&
|
||||||
|
(!hadDrawSomething() || isInpainting || isInteractiveSeg)
|
||||||
|
}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (!isInpainting && hadDrawSomething()) {
|
// ensured by disabled
|
||||||
runInpainting()
|
runInpainting(false, undefined, interactiveSegMask)
|
||||||
}
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
|
import { ArrowUpTrayIcon } from '@heroicons/react/24/outline'
|
||||||
import { PlayIcon } from '@radix-ui/react-icons'
|
import { PlayIcon } from '@radix-ui/react-icons'
|
||||||
import React, { useState } from 'react'
|
import React, { useState } from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import {
|
import {
|
||||||
fileState,
|
fileState,
|
||||||
|
interactiveSegClicksState,
|
||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
isSDState,
|
isSDState,
|
||||||
maskState,
|
maskState,
|
||||||
@ -37,9 +38,11 @@ const Header = () => {
|
|||||||
>
|
>
|
||||||
<label htmlFor={uploadElemId}>
|
<label htmlFor={uploadElemId}>
|
||||||
<Button
|
<Button
|
||||||
icon={<UploadIcon />}
|
icon={<ArrowUpTrayIcon />}
|
||||||
style={{ border: 0 }}
|
style={{ border: 0 }}
|
||||||
disabled={isInpainting}
|
disabled={isInpainting}
|
||||||
|
toolTip="Upload image"
|
||||||
|
tooltipPosition="bottom"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
style={{ display: 'none' }}
|
style={{ display: 'none' }}
|
||||||
@ -67,7 +70,12 @@ const Header = () => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<label htmlFor={maskUploadElemId}>
|
<label htmlFor={maskUploadElemId}>
|
||||||
<Button style={{ border: 0 }} disabled={isInpainting}>
|
<Button
|
||||||
|
style={{ border: 0 }}
|
||||||
|
disabled={isInpainting}
|
||||||
|
toolTip="Upload custom mask"
|
||||||
|
tooltipPosition="bottom"
|
||||||
|
>
|
||||||
<input
|
<input
|
||||||
style={{ display: 'none' }}
|
style={{ display: 'none' }}
|
||||||
id={maskUploadElemId}
|
id={maskUploadElemId}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import React, { useEffect } from 'react'
|
import React, { useEffect } from 'react'
|
||||||
import { atom, useRecoilState } from 'recoil'
|
import { atom, useRecoilState } from 'recoil'
|
||||||
import { SunIcon, MoonIcon } from '@heroicons/react/outline'
|
import { SunIcon, MoonIcon } from '@heroicons/react/24/outline'
|
||||||
|
|
||||||
export const themeState = atom({
|
export const themeState = atom({
|
||||||
key: 'themeState',
|
key: 'themeState',
|
||||||
|
@ -0,0 +1,62 @@
|
|||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
|
import {
|
||||||
|
interactiveSegClicksState,
|
||||||
|
isInteractiveSegRunningState,
|
||||||
|
isInteractiveSegState,
|
||||||
|
} from '../../store/Atoms'
|
||||||
|
import Button from '../shared/Button'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
onCancelClick: () => void
|
||||||
|
onAcceptClick: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const InteractiveSegConfirmActions = (props: Props) => {
|
||||||
|
const { onCancelClick, onAcceptClick } = props
|
||||||
|
|
||||||
|
const [isInteractiveSeg, setIsInteractiveSeg] = useRecoilState(
|
||||||
|
isInteractiveSegState
|
||||||
|
)
|
||||||
|
const [isInteractiveSegRunning, setIsInteractiveSegRunning] = useRecoilState(
|
||||||
|
isInteractiveSegRunningState
|
||||||
|
)
|
||||||
|
const [clicks, setClicks] = useRecoilState(interactiveSegClicksState)
|
||||||
|
|
||||||
|
const clearState = () => {
|
||||||
|
setIsInteractiveSeg(false)
|
||||||
|
setIsInteractiveSegRunning(false)
|
||||||
|
setClicks([])
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className="interactive-seg-confirm-actions"
|
||||||
|
style={{
|
||||||
|
visibility: isInteractiveSeg ? 'visible' : 'hidden',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="action-buttons">
|
||||||
|
<Button
|
||||||
|
onClick={() => {
|
||||||
|
clearState()
|
||||||
|
onCancelClick()
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
border
|
||||||
|
onClick={() => {
|
||||||
|
clearState()
|
||||||
|
onAcceptClick()
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Accept
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default InteractiveSegConfirmActions
|
@ -0,0 +1,69 @@
|
|||||||
|
$positiveBackgroundColor: rgba(21, 215, 121, 0.936);
|
||||||
|
$positiveOutline: 6px solid rgba(98, 255, 179, 0.31);
|
||||||
|
|
||||||
|
$negativeBackgroundColor: rgba(237, 49, 55, 0.942);
|
||||||
|
$negativeOutline: 6px solid rgba(255, 89, 95, 0.31);
|
||||||
|
|
||||||
|
.interactive-seg-wrapper {
|
||||||
|
position: absolute;
|
||||||
|
height: 100%;
|
||||||
|
width: 100%;
|
||||||
|
z-index: 2;
|
||||||
|
overflow: hidden;
|
||||||
|
pointer-events: none;
|
||||||
|
|
||||||
|
.click-item {
|
||||||
|
position: absolute;
|
||||||
|
height: 8px;
|
||||||
|
width: 8px;
|
||||||
|
border-radius: 50%;
|
||||||
|
}
|
||||||
|
.click-item-positive {
|
||||||
|
background-color: $positiveBackgroundColor;
|
||||||
|
outline: $positiveOutline;
|
||||||
|
}
|
||||||
|
|
||||||
|
.click-item-negative {
|
||||||
|
background-color: $negativeBackgroundColor;
|
||||||
|
outline: $negativeOutline;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.interactive-seg-confirm-actions {
|
||||||
|
position: absolute;
|
||||||
|
top: 68px;
|
||||||
|
z-index: 5;
|
||||||
|
background-color: var(--page-bg);
|
||||||
|
border-radius: 16px;
|
||||||
|
border-style: solid;
|
||||||
|
border-color: var(--border-color);
|
||||||
|
border-width: 1px;
|
||||||
|
padding: 8px;
|
||||||
|
|
||||||
|
.action-buttons {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
to {
|
||||||
|
box-shadow: 0 0 0 14px rgba(21, 215, 121, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.interactive-seg-cursor {
|
||||||
|
position: absolute;
|
||||||
|
height: 20px;
|
||||||
|
width: 20px;
|
||||||
|
pointer-events: none;
|
||||||
|
color: hsla(137, 100%, 95.8%, 0.98);
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: $positiveBackgroundColor;
|
||||||
|
// outline: $positiveOutline;
|
||||||
|
transform: 'translate(-50%, -50%)';
|
||||||
|
box-shadow: 0 0 0 0 rgba(21, 215, 121, 0.936);
|
||||||
|
animation: pulse 1.5s infinite cubic-bezier(0.66, 0, 0, 1);
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
import { nanoid } from 'nanoid'
|
||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { useRecoilValue } from 'recoil'
|
||||||
|
import { interactiveSegClicksState } from '../../store/Atoms'
|
||||||
|
|
||||||
|
interface ItemProps {
|
||||||
|
x: number
|
||||||
|
y: number
|
||||||
|
positive: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const Item = (props: ItemProps) => {
|
||||||
|
const { x, y, positive } = props
|
||||||
|
const name = positive ? 'click-item-positive' : 'click-item-negative'
|
||||||
|
return <div className={`click-item ${name}`} style={{ left: x, top: y }} />
|
||||||
|
}
|
||||||
|
|
||||||
|
const InteractiveSeg = () => {
|
||||||
|
const clicks = useRecoilValue<number[][]>(interactiveSegClicksState)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="interactive-seg-wrapper">
|
||||||
|
{clicks.map(click => {
|
||||||
|
return (
|
||||||
|
<Item
|
||||||
|
key={click[3]}
|
||||||
|
x={click[0]}
|
||||||
|
y={click[1]}
|
||||||
|
positive={click[2] === 1}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default InteractiveSeg
|
@ -56,6 +56,7 @@ export default function ShortcutsModal() {
|
|||||||
/>
|
/>
|
||||||
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
||||||
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
||||||
|
<ShortCut content="Interactive Segmentation" keys={['I']} />
|
||||||
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
||||||
<ShortCut content="Redo Inpainting" keys={[CmdOrCtrl, 'Shift', 'Z']} />
|
<ShortCut content="Redo Inpainting" keys={[CmdOrCtrl, 'Shift', 'Z']} />
|
||||||
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { XIcon } from '@heroicons/react/outline'
|
import { XMarkIcon } from '@heroicons/react/24/outline'
|
||||||
import React, { ReactNode } from 'react'
|
import React, { ReactNode } from 'react'
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import * as DialogPrimitive from '@radix-ui/react-dialog'
|
import * as DialogPrimitive from '@radix-ui/react-dialog'
|
||||||
@ -41,7 +41,7 @@ const Modal = React.forwardRef<
|
|||||||
<div className="modal-header">
|
<div className="modal-header">
|
||||||
<DialogPrimitive.Title>{title}</DialogPrimitive.Title>
|
<DialogPrimitive.Title>{title}</DialogPrimitive.Title>
|
||||||
{showCloseIcon ? (
|
{showCloseIcon ? (
|
||||||
<Button icon={<XIcon />} onClick={onClose} />
|
<Button icon={<XMarkIcon />} onClick={onClose} />
|
||||||
) : (
|
) : (
|
||||||
<></>
|
<></>
|
||||||
)}
|
)}
|
||||||
|
@ -3,7 +3,7 @@ import {
|
|||||||
CheckIcon,
|
CheckIcon,
|
||||||
ChevronDownIcon,
|
ChevronDownIcon,
|
||||||
ChevronUpIcon,
|
ChevronUpIcon,
|
||||||
} from '@heroicons/react/outline'
|
} from '@heroicons/react/24/outline'
|
||||||
import * as Select from '@radix-ui/react-select'
|
import * as Select from '@radix-ui/react-select'
|
||||||
|
|
||||||
type SelectorChevronDirection = 'up' | 'down'
|
type SelectorChevronDirection = 'up' | 'down'
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
import * as ToastPrimitive from '@radix-ui/react-toast'
|
import * as ToastPrimitive from '@radix-ui/react-toast'
|
||||||
import { ToastProps } from '@radix-ui/react-toast'
|
import { ToastProps } from '@radix-ui/react-toast'
|
||||||
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
|
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline'
|
||||||
|
|
||||||
const LoadingIcon = () => {
|
const LoadingIcon = () => {
|
||||||
return (
|
return (
|
||||||
|
@ -14,11 +14,6 @@ export enum AIModel {
|
|||||||
Mange = 'manga',
|
Mange = 'manga',
|
||||||
}
|
}
|
||||||
|
|
||||||
export const fileState = atom<File | undefined>({
|
|
||||||
key: 'fileState',
|
|
||||||
default: undefined,
|
|
||||||
})
|
|
||||||
|
|
||||||
export const maskState = atom<File | undefined>({
|
export const maskState = atom<File | undefined>({
|
||||||
key: 'maskState',
|
key: 'maskState',
|
||||||
default: undefined,
|
default: undefined,
|
||||||
@ -32,17 +27,25 @@ export interface Rect {
|
|||||||
}
|
}
|
||||||
|
|
||||||
interface AppState {
|
interface AppState {
|
||||||
|
file: File | undefined
|
||||||
disableShortCuts: boolean
|
disableShortCuts: boolean
|
||||||
isInpainting: boolean
|
isInpainting: boolean
|
||||||
isDisableModelSwitch: boolean
|
isDisableModelSwitch: boolean
|
||||||
|
isInteractiveSeg: boolean
|
||||||
|
isInteractiveSegRunning: boolean
|
||||||
|
interactiveSegClicks: number[][]
|
||||||
}
|
}
|
||||||
|
|
||||||
export const appState = atom<AppState>({
|
export const appState = atom<AppState>({
|
||||||
key: 'appState',
|
key: 'appState',
|
||||||
default: {
|
default: {
|
||||||
|
file: undefined,
|
||||||
disableShortCuts: false,
|
disableShortCuts: false,
|
||||||
isInpainting: false,
|
isInpainting: false,
|
||||||
isDisableModelSwitch: false,
|
isDisableModelSwitch: false,
|
||||||
|
isInteractiveSeg: false,
|
||||||
|
isInteractiveSegRunning: false,
|
||||||
|
interactiveSegClicks: [],
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -68,6 +71,60 @@ export const isInpaintingState = selector({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const fileState = selector({
|
||||||
|
key: 'fileState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.file
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, {
|
||||||
|
...app,
|
||||||
|
file: newValue,
|
||||||
|
interactiveSegClicks: [],
|
||||||
|
isInteractiveSeg: false,
|
||||||
|
isInteractiveSegRunning: false,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const isInteractiveSegState = selector({
|
||||||
|
key: 'isInteractiveSegState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.isInteractiveSeg
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, isInteractiveSeg: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const isInteractiveSegRunningState = selector({
|
||||||
|
key: 'isInteractiveSegRunningState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.isInteractiveSegRunning
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, isInteractiveSegRunning: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const interactiveSegClicksState = selector({
|
||||||
|
key: 'interactiveSegClicksState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.interactiveSegClicks
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, interactiveSegClicks: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const isDisableModelSwitchState = selector({
|
export const isDisableModelSwitchState = selector({
|
||||||
key: 'isDisableModelSwitchState',
|
key: 'isDisableModelSwitchState',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
@use '../components/Settings/Settings.scss';
|
@use '../components/Settings/Settings.scss';
|
||||||
@use '../components/SidePanel/SidePanel.scss';
|
@use '../components/SidePanel/SidePanel.scss';
|
||||||
@use '../components/Croper/Croper.scss';
|
@use '../components/Croper/Croper.scss';
|
||||||
|
@use '../components/InteractiveSeg/InteractiveSeg.scss';
|
||||||
|
|
||||||
// Shared
|
// Shared
|
||||||
@use '../components/FileSelect/FileSelect';
|
@use '../components/FileSelect/FileSelect';
|
||||||
|
@ -1298,10 +1298,10 @@
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@hapi/hoek" "^8.3.0"
|
"@hapi/hoek" "^8.3.0"
|
||||||
|
|
||||||
"@heroicons/react@^1.0.4":
|
"@heroicons/react@^2.0.0":
|
||||||
version "1.0.4"
|
version "2.0.13"
|
||||||
resolved "https://registry.npmjs.org/@heroicons/react/-/react-1.0.4.tgz"
|
resolved "https://registry.npmmirror.com/@heroicons/react/-/react-2.0.13.tgz#9b1cc54ff77d6625c9565efdce0054a4bcd9074c"
|
||||||
integrity sha512-3kOrTmo8+Z8o6AL0rzN82MOf8J5CuxhRLFhpI8mrn+3OqekA6d5eb1GYO3EYYo1Vn6mYQSMNTzCWbEwUInb0cQ==
|
integrity sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==
|
||||||
|
|
||||||
"@humanwhocodes/config-array@^0.5.0":
|
"@humanwhocodes/config-array@^0.5.0":
|
||||||
version "0.5.0"
|
version "0.5.0"
|
||||||
|
@ -180,3 +180,29 @@ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
|||||||
boxes.append(box)
|
boxes.append(box)
|
||||||
|
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
|
|
||||||
|
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
mask: (h, w) 0~255
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
||||||
|
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
max_area = 0
|
||||||
|
max_index = -1
|
||||||
|
for i, cnt in enumerate(contours):
|
||||||
|
area = cv2.contourArea(cnt)
|
||||||
|
if area > max_area:
|
||||||
|
max_area = area
|
||||||
|
max_index = i
|
||||||
|
|
||||||
|
if max_index != -1:
|
||||||
|
new_mask = np.zeros_like(mask)
|
||||||
|
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|
||||||
|
else:
|
||||||
|
return mask
|
||||||
|
202
lama_cleaner/interactive_seg.py
Normal file
202
lama_cleaner/interactive_seg.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from typing import Tuple, List
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lama_cleaner.helper import only_keep_largest_contour, load_jit_model
|
||||||
|
|
||||||
|
|
||||||
|
class Click(BaseModel):
|
||||||
|
# [y, x]
|
||||||
|
coords: Tuple[float, float]
|
||||||
|
is_positive: bool
|
||||||
|
indx: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coords_and_indx(self):
|
||||||
|
return (*self.coords, self.indx)
|
||||||
|
|
||||||
|
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
|
||||||
|
return Click(
|
||||||
|
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
||||||
|
is_positive=self.is_positive,
|
||||||
|
indx=self.indx
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeTrans:
|
||||||
|
def __init__(self, size=480):
|
||||||
|
super().__init__()
|
||||||
|
self.crop_height = size
|
||||||
|
self.crop_width = size
|
||||||
|
|
||||||
|
def transform(self, image_nd, clicks_lists):
|
||||||
|
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
|
||||||
|
image_height, image_width = image_nd.shape[2:4]
|
||||||
|
self.image_height = image_height
|
||||||
|
self.image_width = image_width
|
||||||
|
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
|
||||||
|
|
||||||
|
y_ratio = self.crop_height / image_height
|
||||||
|
x_ratio = self.crop_width / image_width
|
||||||
|
|
||||||
|
clicks_lists_resized = []
|
||||||
|
for clicks_list in clicks_lists:
|
||||||
|
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
|
||||||
|
clicks_lists_resized.append(clicks_list_resized)
|
||||||
|
|
||||||
|
return image_nd_r, clicks_lists_resized
|
||||||
|
|
||||||
|
def inv_transform(self, prob_map):
|
||||||
|
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
|
||||||
|
align_corners=True)
|
||||||
|
|
||||||
|
return new_prob_map
|
||||||
|
|
||||||
|
|
||||||
|
class ISPredictor(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
device,
|
||||||
|
open_kernel_size: int,
|
||||||
|
dilate_kernel_size: int,
|
||||||
|
net_clicks_limit=None,
|
||||||
|
zoom_in=None,
|
||||||
|
infer_size=384,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.open_kernel_size = open_kernel_size
|
||||||
|
self.dilate_kernel_size = dilate_kernel_size
|
||||||
|
self.net_clicks_limit = net_clicks_limit
|
||||||
|
self.device = device
|
||||||
|
self.zoom_in = zoom_in
|
||||||
|
self.infer_size = infer_size
|
||||||
|
|
||||||
|
# self.transforms = [zoom_in] if zoom_in is not None else []
|
||||||
|
|
||||||
|
def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image: [1, 3, H, W] [0~1]
|
||||||
|
clicks: List[Click]
|
||||||
|
prev_mask: [1, 1, H, W]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
transforms = [ResizeTrans(self.infer_size)]
|
||||||
|
input_image = torch.cat((input_image, prev_mask), dim=1)
|
||||||
|
|
||||||
|
# image_nd resized to infer_size
|
||||||
|
for t in transforms:
|
||||||
|
image_nd, clicks_lists = t.transform(input_image, [clicks])
|
||||||
|
|
||||||
|
# image_nd.shape = [1, 4, 256, 256]
|
||||||
|
# points_nd.sha[e = [1, 2, 3]
|
||||||
|
# clicks_lists[0][0] Click 类
|
||||||
|
points_nd = self.get_points_nd(clicks_lists)
|
||||||
|
pred_logits = self.model(image_nd, points_nd)
|
||||||
|
pred = torch.sigmoid(pred_logits)
|
||||||
|
pred = self.post_process(pred)
|
||||||
|
|
||||||
|
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
|
||||||
|
size=image_nd.size()[2:])
|
||||||
|
|
||||||
|
for t in reversed(transforms):
|
||||||
|
prediction = t.inv_transform(prediction)
|
||||||
|
|
||||||
|
# if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
|
||||||
|
# return self.get_prediction(clicker)
|
||||||
|
|
||||||
|
return prediction.cpu().numpy()[0, 0]
|
||||||
|
|
||||||
|
def post_process(self, pred: torch.Tensor) -> torch.Tensor:
|
||||||
|
pred_mask = pred.cpu().numpy()[0][0]
|
||||||
|
# morph_open to remove small noise
|
||||||
|
kernel_size = self.open_kernel_size
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||||
|
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
||||||
|
|
||||||
|
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
||||||
|
dilate_kernel_size = self.dilate_kernel_size
|
||||||
|
if dilate_kernel_size > 1:
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
|
||||||
|
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
||||||
|
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
def get_points_nd(self, clicks_lists):
|
||||||
|
total_clicks = []
|
||||||
|
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
||||||
|
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
||||||
|
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
||||||
|
if self.net_clicks_limit is not None:
|
||||||
|
num_max_points = min(self.net_clicks_limit, num_max_points)
|
||||||
|
num_max_points = max(1, num_max_points)
|
||||||
|
|
||||||
|
for clicks_list in clicks_lists:
|
||||||
|
clicks_list = clicks_list[:self.net_clicks_limit]
|
||||||
|
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
||||||
|
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
||||||
|
|
||||||
|
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
||||||
|
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
||||||
|
total_clicks.append(pos_clicks + neg_clicks)
|
||||||
|
|
||||||
|
return torch.tensor(total_clicks, device=self.device)
|
||||||
|
|
||||||
|
|
||||||
|
INTERACTIVE_SEG_MODEL_URL = os.environ.get(
|
||||||
|
"INTERACTIVE_SEG_MODEL_URL",
|
||||||
|
"/Users/qing/code/github/ClickSEG/clickseg_pplnet.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InteractiveSeg:
|
||||||
|
def __init__(self, infer_size=448, open_kernel_size=3, dilate_kernel_size=3):
|
||||||
|
device = torch.device('cpu')
|
||||||
|
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval()
|
||||||
|
self.predictor = ISPredictor(model, device,
|
||||||
|
infer_size=infer_size,
|
||||||
|
open_kernel_size=open_kernel_size,
|
||||||
|
dilate_kernel_size=dilate_kernel_size)
|
||||||
|
|
||||||
|
def __call__(self, image, clicks, prev_mask=None):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: [H,W,C] RGB
|
||||||
|
clicks:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||||
|
image = torch.from_numpy((image / 255).transpose(2, 0, 1)).unsqueeze(0).float()
|
||||||
|
if prev_mask is None:
|
||||||
|
mask = torch.zeros_like(image[:, :1, :, :])
|
||||||
|
else:
|
||||||
|
logger.info('InteractiveSeg run with prev_mask')
|
||||||
|
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
||||||
|
|
||||||
|
pred_probs = self.predictor(image, clicks, mask)
|
||||||
|
pred_mask = pred_probs > 0.5
|
||||||
|
pred_mask = (pred_mask * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Find largest contour
|
||||||
|
# pred_mask = only_keep_largest_contour(pred_mask)
|
||||||
|
# To simplify frontend process, add mask brush color here
|
||||||
|
fg = pred_mask == 255
|
||||||
|
bg = pred_mask != 255
|
||||||
|
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGRA)
|
||||||
|
# frontend brush color "ffcc00bb"
|
||||||
|
pred_mask[bg] = 0
|
||||||
|
pred_mask[fg] = [255, 203, 0, int(255 * 0.73)]
|
||||||
|
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_BGRA2RGBA)
|
||||||
|
return pred_mask
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
@ -15,6 +16,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
@ -71,6 +73,7 @@ CORS(app, expose_headers=["Content-Disposition"])
|
|||||||
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')
|
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')
|
||||||
|
|
||||||
model: ModelManager = None
|
model: ModelManager = None
|
||||||
|
interactive_seg_model: InteractiveSeg = None
|
||||||
device = None
|
device = None
|
||||||
input_image_path: str = None
|
input_image_path: str = None
|
||||||
is_disable_model_switch: bool = False
|
is_disable_model_switch: bool = False
|
||||||
@ -97,6 +100,8 @@ def process():
|
|||||||
|
|
||||||
image, alpha_channel = load_img(origin_image_bytes)
|
image, alpha_channel = load_img(origin_image_bytes)
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
|
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||||
|
|
||||||
if image.shape[:2] != mask.shape[:2]:
|
if image.shape[:2] != mask.shape[:2]:
|
||||||
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
||||||
|
|
||||||
@ -181,6 +186,33 @@ def process():
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/interactive_seg", methods=["POST"])
|
||||||
|
def interactive_seg():
|
||||||
|
input = request.files
|
||||||
|
origin_image_bytes = input["image"].read() # RGB
|
||||||
|
image, _ = load_img(origin_image_bytes)
|
||||||
|
if 'mask' in input:
|
||||||
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
_clicks = json.loads(request.form["clicks"])
|
||||||
|
clicks = []
|
||||||
|
for i, click in enumerate(_clicks):
|
||||||
|
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||||
|
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||||
|
response = make_response(
|
||||||
|
send_file(
|
||||||
|
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
|
||||||
|
mimetype=f"image/png",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.route("/model")
|
@app.route("/model")
|
||||||
def current_model():
|
def current_model():
|
||||||
return model.name, 200
|
return model.name, 200
|
||||||
@ -240,6 +272,7 @@ def set_input_photo():
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
global model
|
global model
|
||||||
|
global interactive_seg_model
|
||||||
global device
|
global device
|
||||||
global input_image_path
|
global input_image_path
|
||||||
global is_disable_model_switch
|
global is_disable_model_switch
|
||||||
@ -263,6 +296,8 @@ def main(args):
|
|||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
interactive_seg_model = InteractiveSeg()
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
app_width, app_height = args.gui_size
|
app_width, app_height = args.gui_size
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
37
lama_cleaner/tests/test_interactive_seg.py
Normal file
37
lama_cleaner/tests/test_interactive_seg.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / 'result'
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_interactive_seg():
|
||||||
|
interactive_seg_model = InteractiveSeg()
|
||||||
|
img = cv2.imread(str(img_p))
|
||||||
|
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)])
|
||||||
|
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
|
||||||
|
|
||||||
|
|
||||||
|
def test_interactive_seg_with_negative_click():
|
||||||
|
interactive_seg_model = InteractiveSeg()
|
||||||
|
img = cv2.imread(str(img_p))
|
||||||
|
pred = interactive_seg_model(img, clicks=[
|
||||||
|
Click(coords=(256, 256), indx=0, is_positive=True),
|
||||||
|
Click(coords=(384, 256), indx=1, is_positive=False)
|
||||||
|
])
|
||||||
|
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
|
||||||
|
|
||||||
|
|
||||||
|
def test_interactive_seg_with_prev_mask():
|
||||||
|
interactive_seg_model = InteractiveSeg()
|
||||||
|
img = cv2.imread(str(img_p))
|
||||||
|
mask = np.zeros_like(img)[:, :, 0]
|
||||||
|
pred = interactive_seg_model(img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask)
|
||||||
|
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
|
Loading…
Reference in New Issue
Block a user