add adjust mask feature
This commit is contained in:
parent
2996544e75
commit
e889e527ab
@ -29,6 +29,7 @@ from lama_cleaner.helper import (
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
adjust_mask,
|
||||
)
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo
|
||||
@ -44,6 +45,7 @@ from lama_cleaner.schema import (
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
AdjustMaskRequest,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
@ -150,6 +152,7 @@ class Api:
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
@ -294,6 +297,13 @@ class Api:
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
cv2.imwrite("tmp_adjust_mask_input.png", mask)
|
||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||
cv2.imwrite("tmp_adjust_mask.png", mask)
|
||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
|
@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
|
||||
mask[mask >= 127] = 255
|
||||
mask[mask < 127] = 0
|
||||
# fronted brush color "ffcc00bb"
|
||||
kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
|
||||
)
|
||||
if operate == "expand":
|
||||
mask = cv2.dilate(
|
||||
mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
else:
|
||||
mask = cv2.erode(
|
||||
mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||
|
@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
|
||||
|
||||
|
||||
class InpaintRequest(BaseModel):
|
||||
image: Optional[str] = Field(..., description="base64 encoded image")
|
||||
mask: Optional[str] = Field(..., description="base64 encoded mask")
|
||||
image: Optional[str] = Field(None, description="base64 encoded image")
|
||||
mask: Optional[str] = Field(None, description="base64 encoded mask")
|
||||
|
||||
ldm_steps: int = Field(20, description="Steps for ldm model.")
|
||||
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
|
||||
@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
|
||||
|
||||
class SwitchModelRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
AdjustMaskOperate = Literal["expand", "shrink"]
|
||||
|
||||
|
||||
class AdjustMaskRequest(BaseModel):
|
||||
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
|
||||
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
|
||||
kernel_size: int = Field(5, description="Kernel size for expanding mask")
|
||||
|
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
@ -0,0 +1,15 @@
|
||||
import cv2
|
||||
from lama_cleaner.helper import adjust_mask
|
||||
from lama_cleaner.tests.utils import current_dir, save_dir
|
||||
|
||||
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
|
||||
def test_adjust_mask():
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
res_mask = adjust_mask(mask, 0, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 40, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 20, "shrink")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)
|
@ -31,8 +31,6 @@ def assert_equal(
|
||||
mask_p=current_dir / "mask.png",
|
||||
):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
|
||||
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
|
||||
print(f"Input image shape: {img.shape}")
|
||||
res = model(img, mask, config)
|
||||
ok = cv2.imwrite(
|
||||
|
29
web_app/package-lock.json
generated
29
web_app/package-lock.json
generated
@ -12,6 +12,7 @@
|
||||
"@hookform/resolvers": "^3.3.2",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.5",
|
||||
"@radix-ui/react-context-menu": "^2.1.5",
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
||||
"@radix-ui/react-icons": "^1.3.0",
|
||||
@ -1537,6 +1538,34 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-context-menu": {
|
||||
"version": "2.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz",
|
||||
"integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.13.10",
|
||||
"@radix-ui/primitive": "1.0.1",
|
||||
"@radix-ui/react-context": "1.0.1",
|
||||
"@radix-ui/react-menu": "2.0.6",
|
||||
"@radix-ui/react-primitive": "1.0.3",
|
||||
"@radix-ui/react-use-callback-ref": "1.0.1",
|
||||
"@radix-ui/react-use-controllable-state": "1.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-dialog": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",
|
||||
|
@ -14,6 +14,7 @@
|
||||
"@hookform/resolvers": "^3.3.2",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.5",
|
||||
"@radix-ui/react-context-menu": "^2.1.5",
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
||||
"@radix-ui/react-icons": "^1.3.0",
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react"
|
||||
import { nanoid } from "nanoid"
|
||||
import { useCallback, useEffect, useRef } from "react"
|
||||
|
||||
import useInputImage from "@/hooks/useInputImage"
|
||||
import { keepGUIAlive } from "@/lib/utils"
|
||||
@ -52,10 +51,6 @@ function Home() {
|
||||
fetchServerConfig()
|
||||
}, [])
|
||||
|
||||
const workspaceId = useMemo(() => {
|
||||
return nanoid()
|
||||
}, [file])
|
||||
|
||||
const dragCounter = useRef(0)
|
||||
|
||||
const handleDrag = useCallback((event: any) => {
|
||||
@ -155,7 +150,7 @@ function Home() {
|
||||
<main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat">
|
||||
<Toaster />
|
||||
<Header />
|
||||
<Workspace key={workspaceId} />
|
||||
<Workspace />
|
||||
{!file ? (
|
||||
<FileSelect
|
||||
onSelection={async (f) => {
|
||||
|
@ -52,9 +52,9 @@ const DiffusionProgress = () => {
|
||||
|
||||
return (
|
||||
<div
|
||||
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
|
||||
className="z-10 fixed bg-background w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
|
||||
style={{
|
||||
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden",
|
||||
visibility: isConnected && isInpainting && isSD ? "visible" : "hidden",
|
||||
}}
|
||||
>
|
||||
<Progress value={progress} />
|
||||
|
@ -95,6 +95,7 @@ export default function Editor(props: EditorProps) {
|
||||
const brushSize = useStore((state) => state.getBrushSize())
|
||||
const renders = useStore((state) => state.editorState.renders)
|
||||
const extraMasks = useStore((state) => state.editorState.extraMasks)
|
||||
const temporaryMasks = useStore((state) => state.editorState.temporaryMasks)
|
||||
const lineGroups = useStore((state) => state.editorState.lineGroups)
|
||||
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
|
||||
|
||||
@ -166,6 +167,9 @@ export default function Editor(props: EditorProps) {
|
||||
context.canvas.width = imageWidth
|
||||
context.canvas.height = imageHeight
|
||||
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
|
||||
temporaryMasks.forEach((maskImage) => {
|
||||
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
})
|
||||
extraMasks.forEach((maskImage) => {
|
||||
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
})
|
||||
@ -182,20 +186,9 @@ export default function Editor(props: EditorProps) {
|
||||
imageHeight
|
||||
)
|
||||
}
|
||||
if (
|
||||
!interactiveSegState.isInteractiveSeg &&
|
||||
interactiveSegState.interactiveSegMask
|
||||
) {
|
||||
context.drawImage(
|
||||
interactiveSegState.interactiveSegMask,
|
||||
0,
|
||||
0,
|
||||
imageWidth,
|
||||
imageHeight
|
||||
)
|
||||
}
|
||||
drawLines(context, curLineGroup)
|
||||
}, [
|
||||
temporaryMasks,
|
||||
extraMasks,
|
||||
isOriginalLoaded,
|
||||
interactiveSegState,
|
||||
|
@ -21,7 +21,6 @@ const PromptInput = () => {
|
||||
state.hidePrevMask,
|
||||
])
|
||||
const ref = useRef(null)
|
||||
|
||||
useClickAway<MouseEvent>(ref, () => {
|
||||
if (ref?.current) {
|
||||
const input = ref.current as HTMLInputElement
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { FormEvent } from "react"
|
||||
import { FormEvent, useRef } from "react"
|
||||
import { useStore } from "@/lib/states"
|
||||
import { Switch } from "../ui/switch"
|
||||
import { NumberInput } from "../ui/input"
|
||||
@ -18,7 +18,8 @@ import { Slider } from "../ui/slider"
|
||||
import { useImage } from "@/hooks/useImage"
|
||||
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
|
||||
import { RowContainer, LabelTitle } from "./LabelTitle"
|
||||
import { Upload } from "lucide-react"
|
||||
import { Minus, Plus, Upload } from "lucide-react"
|
||||
import { useClickAway } from "react-use"
|
||||
|
||||
const ExtenderButton = ({
|
||||
text,
|
||||
@ -31,7 +32,7 @@ const ExtenderButton = ({
|
||||
return (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="p-1 h-7"
|
||||
className="p-1 h-8"
|
||||
disabled={!showExtender}
|
||||
onClick={onClick}
|
||||
>
|
||||
@ -51,6 +52,8 @@ const DiffusionOptions = () => {
|
||||
updateAppState,
|
||||
updateExtenderByBuiltIn,
|
||||
updateExtenderDirection,
|
||||
adjustMask,
|
||||
clearMask,
|
||||
] = useStore((state) => [
|
||||
state.serverConfig.samplers,
|
||||
state.settings,
|
||||
@ -61,8 +64,17 @@ const DiffusionOptions = () => {
|
||||
state.updateAppState,
|
||||
state.updateExtenderByBuiltIn,
|
||||
state.updateExtenderDirection,
|
||||
state.adjustMask,
|
||||
state.clearMask,
|
||||
])
|
||||
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
||||
const negativePromptRef = useRef(null)
|
||||
useClickAway<MouseEvent>(negativePromptRef, () => {
|
||||
if (negativePromptRef?.current) {
|
||||
const input = negativePromptRef.current as HTMLInputElement
|
||||
input.blur()
|
||||
}
|
||||
})
|
||||
|
||||
const onKeyUp = (e: React.KeyboardEvent) => {
|
||||
// negativePrompt 回车触发 inpainting
|
||||
@ -320,6 +332,7 @@ const DiffusionOptions = () => {
|
||||
/>
|
||||
<div className="pl-2 pr-4">
|
||||
<Textarea
|
||||
ref={negativePromptRef}
|
||||
rows={4}
|
||||
onKeyUp={onKeyUp}
|
||||
className="max-h-[8rem] overflow-y-auto mb-2"
|
||||
@ -550,7 +563,7 @@ const DiffusionOptions = () => {
|
||||
<RowContainer>
|
||||
<LabelTitle
|
||||
text="Task"
|
||||
toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
|
||||
toolTip="PowerPaint task. When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
|
||||
/>
|
||||
<Select
|
||||
defaultValue={settings.powerpaintTask}
|
||||
@ -760,10 +773,85 @@ const DiffusionOptions = () => {
|
||||
)
|
||||
}
|
||||
|
||||
const renderMaskAdjuster = () => {
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-1">
|
||||
<LabelTitle
|
||||
htmlFor="adjustMaskKernelSize"
|
||||
text="Adjust Mask"
|
||||
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
|
||||
/>
|
||||
<RowContainer>
|
||||
<Slider
|
||||
className="w-[180px]"
|
||||
defaultValue={[12]}
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
value={[Math.floor(settings.adjustMaskKernelSize)]}
|
||||
onValueChange={(vals) =>
|
||||
updateSettings({ adjustMaskKernelSize: vals[0] })
|
||||
}
|
||||
/>
|
||||
<NumberInput
|
||||
id="adjustMaskKernelSize"
|
||||
className="w-[60px] rounded-full"
|
||||
numberValue={settings.adjustMaskKernelSize}
|
||||
allowFloat={false}
|
||||
onNumberValueChange={(val) => {
|
||||
updateSettings({ adjustMaskKernelSize: val })
|
||||
}}
|
||||
/>
|
||||
</RowContainer>
|
||||
|
||||
<RowContainer>
|
||||
<div className="flex gap-1 justify-start">
|
||||
<Button
|
||||
variant="outline"
|
||||
className="p-1 h-8"
|
||||
onClick={() => adjustMask("expand")}
|
||||
disabled={isProcessing}
|
||||
>
|
||||
<div className="flex items-center gap-1 select-none">
|
||||
<Plus size={16} />
|
||||
expand
|
||||
</div>
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
className="p-1 h-8"
|
||||
onClick={() => adjustMask("shrink")}
|
||||
disabled={isProcessing}
|
||||
>
|
||||
<div className="flex items-center gap-1 select-none">
|
||||
<Minus size={16} />
|
||||
Shrink
|
||||
</div>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
className="p-1 h-8 justify-self-end"
|
||||
onClick={clearMask}
|
||||
disabled={isProcessing}
|
||||
>
|
||||
<div className="flex items-center gap-1 select-none">Clear</div>
|
||||
</Button>
|
||||
</RowContainer>
|
||||
</div>
|
||||
<Separator />
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4 mt-4">
|
||||
{renderCropper()}
|
||||
{renderExtender()}
|
||||
{renderMaskAdjuster()}
|
||||
{renderPowerPaintTaskType()}
|
||||
{renderSteps()}
|
||||
{renderGuidanceScale()}
|
||||
|
202
web_app/src/components/ui/context-menu.tsx
Normal file
202
web_app/src/components/ui/context-menu.tsx
Normal file
@ -0,0 +1,202 @@
|
||||
import * as React from "react"
|
||||
import * as ContextMenuPrimitive from "@radix-ui/react-context-menu"
|
||||
import {
|
||||
CheckIcon,
|
||||
ChevronRightIcon,
|
||||
DotFilledIcon,
|
||||
} from "@radix-ui/react-icons"
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
const ContextMenu = ContextMenuPrimitive.Root
|
||||
|
||||
const ContextMenuTrigger = ContextMenuPrimitive.Trigger
|
||||
|
||||
const ContextMenuGroup = ContextMenuPrimitive.Group
|
||||
|
||||
const ContextMenuPortal = ContextMenuPrimitive.Portal
|
||||
|
||||
const ContextMenuSub = ContextMenuPrimitive.Sub
|
||||
|
||||
const ContextMenuRadioGroup = ContextMenuPrimitive.RadioGroup
|
||||
|
||||
const ContextMenuSubTrigger = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.SubTrigger>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubTrigger> & {
|
||||
inset?: boolean
|
||||
}
|
||||
>(({ className, inset, children, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.SubTrigger
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground",
|
||||
inset && "pl-8",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<ChevronRightIcon className="ml-auto h-4 w-4" />
|
||||
</ContextMenuPrimitive.SubTrigger>
|
||||
))
|
||||
ContextMenuSubTrigger.displayName = ContextMenuPrimitive.SubTrigger.displayName
|
||||
|
||||
const ContextMenuSubContent = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.SubContent>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubContent>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.SubContent
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-lg data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
ContextMenuSubContent.displayName = ContextMenuPrimitive.SubContent.displayName
|
||||
|
||||
const ContextMenuContent = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Content>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.Portal>
|
||||
<ContextMenuPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
</ContextMenuPrimitive.Portal>
|
||||
))
|
||||
ContextMenuContent.displayName = ContextMenuPrimitive.Content.displayName
|
||||
|
||||
const ContextMenuItem = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.Item>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Item> & {
|
||||
inset?: boolean
|
||||
}
|
||||
>(({ className, inset, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.Item
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||
inset && "pl-8",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
ContextMenuItem.displayName = ContextMenuPrimitive.Item.displayName
|
||||
|
||||
const ContextMenuCheckboxItem = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.CheckboxItem>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.CheckboxItem>
|
||||
>(({ className, children, checked, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.CheckboxItem
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||
className
|
||||
)}
|
||||
checked={checked}
|
||||
{...props}
|
||||
>
|
||||
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||
<ContextMenuPrimitive.ItemIndicator>
|
||||
<CheckIcon className="h-4 w-4" />
|
||||
</ContextMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</ContextMenuPrimitive.CheckboxItem>
|
||||
))
|
||||
ContextMenuCheckboxItem.displayName =
|
||||
ContextMenuPrimitive.CheckboxItem.displayName
|
||||
|
||||
const ContextMenuRadioItem = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.RadioItem>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.RadioItem>
|
||||
>(({ className, children, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.RadioItem
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||
<ContextMenuPrimitive.ItemIndicator>
|
||||
<DotFilledIcon className="h-4 w-4 fill-current" />
|
||||
</ContextMenuPrimitive.ItemIndicator>
|
||||
</span>
|
||||
{children}
|
||||
</ContextMenuPrimitive.RadioItem>
|
||||
))
|
||||
ContextMenuRadioItem.displayName = ContextMenuPrimitive.RadioItem.displayName
|
||||
|
||||
const ContextMenuLabel = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.Label>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Label> & {
|
||||
inset?: boolean
|
||||
}
|
||||
>(({ className, inset, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.Label
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"px-2 py-1.5 text-sm font-semibold text-foreground",
|
||||
inset && "pl-8",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
ContextMenuLabel.displayName = ContextMenuPrimitive.Label.displayName
|
||||
|
||||
const ContextMenuSeparator = React.forwardRef<
|
||||
React.ElementRef<typeof ContextMenuPrimitive.Separator>,
|
||||
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Separator>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<ContextMenuPrimitive.Separator
|
||||
ref={ref}
|
||||
className={cn("-mx-1 my-1 h-px bg-border", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
ContextMenuSeparator.displayName = ContextMenuPrimitive.Separator.displayName
|
||||
|
||||
const ContextMenuShortcut = ({
|
||||
className,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLSpanElement>) => {
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
"ml-auto text-xs tracking-widest text-muted-foreground",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
ContextMenuShortcut.displayName = "ContextMenuShortcut"
|
||||
|
||||
export {
|
||||
ContextMenu,
|
||||
ContextMenuTrigger,
|
||||
ContextMenuContent,
|
||||
ContextMenuItem,
|
||||
ContextMenuCheckboxItem,
|
||||
ContextMenuRadioItem,
|
||||
ContextMenuLabel,
|
||||
ContextMenuSeparator,
|
||||
ContextMenuShortcut,
|
||||
ContextMenuGroup,
|
||||
ContextMenuPortal,
|
||||
ContextMenuSub,
|
||||
ContextMenuSubContent,
|
||||
ContextMenuSubTrigger,
|
||||
ContextMenuRadioGroup,
|
||||
}
|
@ -17,6 +17,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
const handleOnBlur = () => {
|
||||
updateAppState({ disableShortCuts: false })
|
||||
}
|
||||
|
||||
return (
|
||||
<input
|
||||
type={type}
|
||||
|
@ -201,3 +201,28 @@ export async function getSamplers(): Promise<string[]> {
|
||||
const res = await api.post("/samplers")
|
||||
return res.data
|
||||
}
|
||||
|
||||
export async function postAdjustMask(
|
||||
mask: File | Blob,
|
||||
operate: "expand" | "shrink",
|
||||
kernel_size: number
|
||||
) {
|
||||
const maskBase64 = await convertToBase64(mask)
|
||||
const res = await fetch(`${API_ENDPOINT}/adjust_mask`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
mask: maskBase64,
|
||||
operate: operate,
|
||||
kernel_size: kernel_size,
|
||||
}),
|
||||
})
|
||||
if (res.ok) {
|
||||
const blob = await res.blob()
|
||||
return blob
|
||||
}
|
||||
const errMsg = await res.json()
|
||||
throw new Error(errMsg)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
|
||||
import { castDraft } from "immer"
|
||||
import { createWithEqualityFn } from "zustand/traditional"
|
||||
import {
|
||||
AdjustMaskOperate,
|
||||
CV2Flag,
|
||||
ExtenderDirection,
|
||||
FreeuConfig,
|
||||
@ -27,13 +28,14 @@ import {
|
||||
PAINT_BY_EXAMPLE,
|
||||
} from "./const"
|
||||
import {
|
||||
blobToImage,
|
||||
canvasToImage,
|
||||
dataURItoBlob,
|
||||
generateMask,
|
||||
loadImage,
|
||||
srcToFile,
|
||||
} from "./utils"
|
||||
import inpaint, { getGenInfo, runPlugin } from "./api"
|
||||
import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
|
||||
import { toast } from "@/components/ui/use-toast"
|
||||
|
||||
type FileManagerState = {
|
||||
@ -102,13 +104,14 @@ export type Settings = {
|
||||
|
||||
// PowerPaint
|
||||
powerpaintTask: PowerPaintTask
|
||||
|
||||
// AdjustMask
|
||||
adjustMaskKernelSize: number
|
||||
}
|
||||
|
||||
type InteractiveSegState = {
|
||||
isInteractiveSeg: boolean
|
||||
interactiveSegMask: HTMLImageElement | null
|
||||
tmpInteractiveSegMask: HTMLImageElement | null
|
||||
prevInteractiveSegMask: HTMLImageElement | null
|
||||
clicks: number[][]
|
||||
}
|
||||
|
||||
@ -119,8 +122,12 @@ type EditorState = {
|
||||
lineGroups: LineGroup[]
|
||||
lastLineGroup: LineGroup
|
||||
curLineGroup: LineGroup
|
||||
// 只用来显示
|
||||
|
||||
// mask from interactive-seg or other segmentation models
|
||||
extraMasks: HTMLImageElement[]
|
||||
prevExtraMasks: HTMLImageElement[]
|
||||
|
||||
temporaryMasks: HTMLImageElement[]
|
||||
// redo 相关
|
||||
redoRenders: HTMLImageElement[]
|
||||
redoCurLines: Line[]
|
||||
@ -135,6 +142,7 @@ type AppState = {
|
||||
imageWidth: number
|
||||
isInpainting: boolean
|
||||
isPluginRunning: boolean
|
||||
isAdjustingMask: boolean
|
||||
windowSize: Size
|
||||
editorState: EditorState
|
||||
disableShortCuts: boolean
|
||||
@ -209,6 +217,9 @@ type AppAction = {
|
||||
redo: () => void
|
||||
undoDisabled: () => boolean
|
||||
redoDisabled: () => boolean
|
||||
|
||||
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
|
||||
clearMask: () => void
|
||||
}
|
||||
|
||||
const defaultValues: AppState = {
|
||||
@ -219,6 +230,7 @@ const defaultValues: AppState = {
|
||||
imageWidth: 0,
|
||||
isInpainting: false,
|
||||
isPluginRunning: false,
|
||||
isAdjustingMask: false,
|
||||
disableShortCuts: false,
|
||||
|
||||
windowSize: {
|
||||
@ -230,6 +242,8 @@ const defaultValues: AppState = {
|
||||
brushSizeScale: 1,
|
||||
renders: [],
|
||||
extraMasks: [],
|
||||
prevExtraMasks: [],
|
||||
temporaryMasks: [],
|
||||
lineGroups: [],
|
||||
lastLineGroup: [],
|
||||
curLineGroup: [],
|
||||
@ -240,9 +254,7 @@ const defaultValues: AppState = {
|
||||
|
||||
interactiveSegState: {
|
||||
isInteractiveSeg: false,
|
||||
interactiveSegMask: null,
|
||||
tmpInteractiveSegMask: null,
|
||||
prevInteractiveSegMask: null,
|
||||
clicks: [],
|
||||
},
|
||||
|
||||
@ -323,6 +335,7 @@ const defaultValues: AppState = {
|
||||
enableFreeu: false,
|
||||
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
||||
powerpaintTask: PowerPaintTask.text_guided,
|
||||
adjustMaskKernelSize: 12,
|
||||
},
|
||||
}
|
||||
|
||||
@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
if (get().settings.showExtender) {
|
||||
return
|
||||
}
|
||||
const { lastLineGroup, curLineGroup } = get().editorState
|
||||
const { prevInteractiveSegMask, interactiveSegMask } =
|
||||
get().interactiveSegState
|
||||
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
|
||||
const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
|
||||
get().editorState
|
||||
if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
|
||||
return
|
||||
}
|
||||
const { imageWidth, imageHeight } = get()
|
||||
@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
[lastLineGroup],
|
||||
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
|
||||
prevExtraMasks,
|
||||
BRUSH_COLOR
|
||||
)
|
||||
try {
|
||||
const maskImage = await canvasToImage(maskCanvas)
|
||||
set((state) => {
|
||||
state.editorState.extraMasks.push(castDraft(maskImage))
|
||||
state.editorState.temporaryMasks.push(castDraft(maskImage))
|
||||
})
|
||||
} catch (e) {
|
||||
console.error(e)
|
||||
@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
},
|
||||
hidePrevMask: () => {
|
||||
set((state) => {
|
||||
state.editorState.extraMasks = []
|
||||
state.editorState.temporaryMasks = []
|
||||
})
|
||||
},
|
||||
|
||||
@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
return
|
||||
}
|
||||
|
||||
const { lastLineGroup, curLineGroup, lineGroups, renders } =
|
||||
get().editorState
|
||||
|
||||
const { interactiveSegMask, prevInteractiveSegMask } =
|
||||
get().interactiveSegState
|
||||
const {
|
||||
lastLineGroup,
|
||||
curLineGroup,
|
||||
lineGroups,
|
||||
renders,
|
||||
prevExtraMasks,
|
||||
extraMasks,
|
||||
} = get().editorState
|
||||
|
||||
const useLastLineGroup =
|
||||
curLineGroup.length === 0 &&
|
||||
interactiveSegMask === null &&
|
||||
extraMasks.length === 0 &&
|
||||
!settings.showExtender
|
||||
|
||||
// useLastLineGroup 的影响
|
||||
// 1. 使用上一次的 mask
|
||||
// 2. 结果替换当前 render
|
||||
let maskImage = null
|
||||
let maskImages: HTMLImageElement[] = []
|
||||
let maskLineGroup: LineGroup = []
|
||||
if (useLastLineGroup === true) {
|
||||
maskLineGroup = lastLineGroup
|
||||
maskImage = prevInteractiveSegMask
|
||||
maskImages = prevExtraMasks
|
||||
} else {
|
||||
maskLineGroup = curLineGroup
|
||||
maskImage = interactiveSegMask
|
||||
maskImages = extraMasks
|
||||
}
|
||||
|
||||
if (
|
||||
maskLineGroup.length === 0 &&
|
||||
maskImage === null &&
|
||||
maskImages === null &&
|
||||
!settings.showExtender
|
||||
) {
|
||||
toast({
|
||||
@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
[maskLineGroup],
|
||||
maskImage ? [maskImage] : []
|
||||
maskImages
|
||||
)
|
||||
|
||||
try {
|
||||
@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
lineGroups: newLineGroups,
|
||||
lastLineGroup: maskLineGroup,
|
||||
curLineGroup: [],
|
||||
extraMasks: [],
|
||||
prevExtraMasks: maskImages,
|
||||
})
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
set((state) => {
|
||||
state.isInpainting = false
|
||||
})
|
||||
|
||||
const newInteractiveSegState = {
|
||||
...defaultValues.interactiveSegState,
|
||||
prevInteractiveSegMask: maskImage,
|
||||
}
|
||||
|
||||
set((state) => {
|
||||
state.interactiveSegState = castDraft(newInteractiveSegState)
|
||||
})
|
||||
},
|
||||
|
||||
runRenderablePlugin: async (
|
||||
@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
} else {
|
||||
const newMask = new Image()
|
||||
await loadImage(newMask, blob)
|
||||
get().updateInteractiveSegState({
|
||||
interactiveSegMask: newMask,
|
||||
set((state) => {
|
||||
state.editorState.extraMasks.push(castDraft(newMask))
|
||||
})
|
||||
}
|
||||
const end = new Date()
|
||||
@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
},
|
||||
|
||||
getIsProcessing: (): boolean => {
|
||||
return get().isInpainting || get().isPluginRunning
|
||||
return (
|
||||
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
|
||||
)
|
||||
},
|
||||
|
||||
isSD: (): boolean => {
|
||||
@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
|
||||
handleInteractiveSegAccept: () => {
|
||||
set((state) => {
|
||||
return {
|
||||
...state,
|
||||
interactiveSegState: {
|
||||
...defaultValues.interactiveSegState,
|
||||
interactiveSegMask:
|
||||
state.interactiveSegState.tmpInteractiveSegMask,
|
||||
},
|
||||
if (state.interactiveSegState.tmpInteractiveSegMask) {
|
||||
state.editorState.extraMasks.push(
|
||||
castDraft(state.interactiveSegState.tmpInteractiveSegMask)
|
||||
)
|
||||
}
|
||||
state.interactiveSegState = castDraft({
|
||||
...defaultValues.interactiveSegState,
|
||||
})
|
||||
})
|
||||
},
|
||||
|
||||
@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
set((state) => {
|
||||
state.settings.seed = newValue
|
||||
}),
|
||||
|
||||
adjustMask: async (operate: AdjustMaskOperate) => {
|
||||
const { imageWidth, imageHeight } = get()
|
||||
const { curLineGroup, extraMasks } = get().editorState
|
||||
const { adjustMaskKernelSize } = get().settings
|
||||
if (curLineGroup.length === 0 && extraMasks.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
set((state) => {
|
||||
state.isAdjustingMask = true
|
||||
})
|
||||
|
||||
const maskCanvas = generateMask(
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
[curLineGroup],
|
||||
extraMasks,
|
||||
BRUSH_COLOR
|
||||
)
|
||||
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
|
||||
const newMaskBlob = await postAdjustMask(
|
||||
maskBlob,
|
||||
operate,
|
||||
adjustMaskKernelSize
|
||||
)
|
||||
const newMask = await blobToImage(newMaskBlob)
|
||||
|
||||
// TODO: currently ignore stroke undo/redo
|
||||
set((state) => {
|
||||
state.editorState.extraMasks = [castDraft(newMask)]
|
||||
state.editorState.curLineGroup = []
|
||||
})
|
||||
|
||||
set((state) => {
|
||||
state.isAdjustingMask = false
|
||||
})
|
||||
},
|
||||
clearMask: () => {
|
||||
set((state) => {
|
||||
state.editorState.extraMasks = []
|
||||
state.editorState.curLineGroup = []
|
||||
})
|
||||
},
|
||||
})),
|
||||
{
|
||||
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
||||
version: 0,
|
||||
version: 1,
|
||||
partialize: (state) =>
|
||||
Object.fromEntries(
|
||||
Object.entries(state).filter(([key]) =>
|
||||
|
@ -125,3 +125,5 @@ export enum PowerPaintTask {
|
||||
object_remove = "object-remove",
|
||||
outpainting = "outpainting",
|
||||
}
|
||||
|
||||
export type AdjustMaskOperate = "expand" | "shrink"
|
||||
|
@ -53,6 +53,13 @@ export function loadImage(image: HTMLImageElement, src: string) {
|
||||
})
|
||||
}
|
||||
|
||||
export async function blobToImage(blob: Blob) {
|
||||
const dataURL = URL.createObjectURL(blob)
|
||||
const newImage = new Image()
|
||||
await loadImage(newImage, dataURL)
|
||||
return newImage
|
||||
}
|
||||
|
||||
export function canvasToImage(
|
||||
canvas: HTMLCanvasElement
|
||||
): Promise<HTMLImageElement> {
|
||||
|
Loading…
Reference in New Issue
Block a user