add adjust mask feature

This commit is contained in:
Qing 2024-01-05 14:57:30 +08:00
parent 2996544e75
commit e889e527ab
18 changed files with 507 additions and 76 deletions

View File

@ -29,6 +29,7 @@ from lama_cleaner.helper import (
numpy_to_bytes, numpy_to_bytes,
concat_alpha_channel, concat_alpha_channel,
gen_frontend_mask, gen_frontend_mask,
adjust_mask,
) )
from lama_cleaner.model.utils import torch_gc from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo from lama_cleaner.model_info import ModelInfo
@ -44,6 +45,7 @@ from lama_cleaner.schema import (
RunPluginRequest, RunPluginRequest,
SDSampler, SDSampler,
PluginInfo, PluginInfo,
AdjustMaskRequest,
) )
CURRENT_DIR = Path(__file__).parent.absolute().resolve() CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@ -150,6 +152,7 @@ class Api:
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on # fmt: on
@ -294,6 +297,13 @@ class Api:
def api_samplers(self) -> List[str]: def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()] return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
cv2.imwrite("tmp_adjust_mask_input.png", mask)
mask = adjust_mask(mask, req.kernel_size, req.operate)
cv2.imwrite("tmp_adjust_mask.png", mask)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
def launch(self): def launch(self):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run( uvicorn.run(

View File

@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
mask[mask >= 127] = 255 mask[mask >= 127] = 255
mask[mask < 127] = 0 mask[mask < 127] = 0
# fronted brush color "ffcc00bb" # fronted brush color "ffcc00bb"
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
)
if operate == "expand": if operate == "expand":
mask = cv2.dilate( mask = cv2.dilate(
mask, mask,
np.ones((kernel_size, kernel_size), np.uint8), kernel,
iterations=1, iterations=1,
) )
else: else:
mask = cv2.erode( mask = cv2.erode(
mask, mask,
np.ones((kernel_size, kernel_size), np.uint8), kernel,
iterations=1, iterations=1,
) )
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)

View File

@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
class InpaintRequest(BaseModel): class InpaintRequest(BaseModel):
image: Optional[str] = Field(..., description="base64 encoded image") image: Optional[str] = Field(None, description="base64 encoded image")
mask: Optional[str] = Field(..., description="base64 encoded mask") mask: Optional[str] = Field(None, description="base64 encoded mask")
ldm_steps: int = Field(20, description="Steps for ldm model.") ldm_steps: int = Field(20, description="Steps for ldm model.")
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.") ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
class SwitchModelRequest(BaseModel): class SwitchModelRequest(BaseModel):
name: str name: str
AdjustMaskOperate = Literal["expand", "shrink"]
class AdjustMaskRequest(BaseModel):
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
kernel_size: int = Field(5, description="Kernel size for expanding mask")

View File

@ -0,0 +1,15 @@
import cv2
from lama_cleaner.helper import adjust_mask
from lama_cleaner.tests.utils import current_dir, save_dir
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
def test_adjust_mask():
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
res_mask = adjust_mask(mask, 0, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
res_mask = adjust_mask(mask, 40, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
res_mask = adjust_mask(mask, 20, "shrink")
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)

View File

@ -31,8 +31,6 @@ def assert_equal(
mask_p=current_dir / "mask.png", mask_p=current_dir / "mask.png",
): ):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
print(f"Input image shape: {img.shape}") print(f"Input image shape: {img.shape}")
res = model(img, mask, config) res = model(img, mask, config)
ok = cv2.imwrite( ok = cv2.imwrite(

View File

@ -12,6 +12,7 @@
"@hookform/resolvers": "^3.3.2", "@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5", "@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-context-menu": "^2.1.5",
"@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",
@ -1537,6 +1538,34 @@
} }
} }
}, },
"node_modules/@radix-ui/react-context-menu": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz",
"integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/primitive": "1.0.1",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-menu": "2.0.6",
"@radix-ui/react-primitive": "1.0.3",
"@radix-ui/react-use-callback-ref": "1.0.1",
"@radix-ui/react-use-controllable-state": "1.0.1"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-dialog": { "node_modules/@radix-ui/react-dialog": {
"version": "1.0.5", "version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz", "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",

View File

@ -14,6 +14,7 @@
"@hookform/resolvers": "^3.3.2", "@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5", "@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-context-menu": "^2.1.5",
"@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",

View File

@ -1,5 +1,4 @@
import { useCallback, useEffect, useMemo, useRef } from "react" import { useCallback, useEffect, useRef } from "react"
import { nanoid } from "nanoid"
import useInputImage from "@/hooks/useInputImage" import useInputImage from "@/hooks/useInputImage"
import { keepGUIAlive } from "@/lib/utils" import { keepGUIAlive } from "@/lib/utils"
@ -52,10 +51,6 @@ function Home() {
fetchServerConfig() fetchServerConfig()
}, []) }, [])
const workspaceId = useMemo(() => {
return nanoid()
}, [file])
const dragCounter = useRef(0) const dragCounter = useRef(0)
const handleDrag = useCallback((event: any) => { const handleDrag = useCallback((event: any) => {
@ -155,7 +150,7 @@ function Home() {
<main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat"> <main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat">
<Toaster /> <Toaster />
<Header /> <Header />
<Workspace key={workspaceId} /> <Workspace />
{!file ? ( {!file ? (
<FileSelect <FileSelect
onSelection={async (f) => { onSelection={async (f) => {

View File

@ -52,9 +52,9 @@ const DiffusionProgress = () => {
return ( return (
<div <div
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]" className="z-10 fixed bg-background w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
style={{ style={{
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden", visibility: isConnected && isInpainting && isSD ? "visible" : "hidden",
}} }}
> >
<Progress value={progress} /> <Progress value={progress} />

View File

@ -95,6 +95,7 @@ export default function Editor(props: EditorProps) {
const brushSize = useStore((state) => state.getBrushSize()) const brushSize = useStore((state) => state.getBrushSize())
const renders = useStore((state) => state.editorState.renders) const renders = useStore((state) => state.editorState.renders)
const extraMasks = useStore((state) => state.editorState.extraMasks) const extraMasks = useStore((state) => state.editorState.extraMasks)
const temporaryMasks = useStore((state) => state.editorState.temporaryMasks)
const lineGroups = useStore((state) => state.editorState.lineGroups) const lineGroups = useStore((state) => state.editorState.lineGroups)
const curLineGroup = useStore((state) => state.editorState.curLineGroup) const curLineGroup = useStore((state) => state.editorState.curLineGroup)
@ -166,6 +167,9 @@ export default function Editor(props: EditorProps) {
context.canvas.width = imageWidth context.canvas.width = imageWidth
context.canvas.height = imageHeight context.canvas.height = imageHeight
context.clearRect(0, 0, context.canvas.width, context.canvas.height) context.clearRect(0, 0, context.canvas.width, context.canvas.height)
temporaryMasks.forEach((maskImage) => {
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
extraMasks.forEach((maskImage) => { extraMasks.forEach((maskImage) => {
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
}) })
@ -182,20 +186,9 @@ export default function Editor(props: EditorProps) {
imageHeight imageHeight
) )
} }
if (
!interactiveSegState.isInteractiveSeg &&
interactiveSegState.interactiveSegMask
) {
context.drawImage(
interactiveSegState.interactiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
drawLines(context, curLineGroup) drawLines(context, curLineGroup)
}, [ }, [
temporaryMasks,
extraMasks, extraMasks,
isOriginalLoaded, isOriginalLoaded,
interactiveSegState, interactiveSegState,

View File

@ -21,7 +21,6 @@ const PromptInput = () => {
state.hidePrevMask, state.hidePrevMask,
]) ])
const ref = useRef(null) const ref = useRef(null)
useClickAway<MouseEvent>(ref, () => { useClickAway<MouseEvent>(ref, () => {
if (ref?.current) { if (ref?.current) {
const input = ref.current as HTMLInputElement const input = ref.current as HTMLInputElement

View File

@ -1,4 +1,4 @@
import { FormEvent } from "react" import { FormEvent, useRef } from "react"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { Switch } from "../ui/switch" import { Switch } from "../ui/switch"
import { NumberInput } from "../ui/input" import { NumberInput } from "../ui/input"
@ -18,7 +18,8 @@ import { Slider } from "../ui/slider"
import { useImage } from "@/hooks/useImage" import { useImage } from "@/hooks/useImage"
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const" import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
import { RowContainer, LabelTitle } from "./LabelTitle" import { RowContainer, LabelTitle } from "./LabelTitle"
import { Upload } from "lucide-react" import { Minus, Plus, Upload } from "lucide-react"
import { useClickAway } from "react-use"
const ExtenderButton = ({ const ExtenderButton = ({
text, text,
@ -31,7 +32,7 @@ const ExtenderButton = ({
return ( return (
<Button <Button
variant="outline" variant="outline"
className="p-1 h-7" className="p-1 h-8"
disabled={!showExtender} disabled={!showExtender}
onClick={onClick} onClick={onClick}
> >
@ -51,6 +52,8 @@ const DiffusionOptions = () => {
updateAppState, updateAppState,
updateExtenderByBuiltIn, updateExtenderByBuiltIn,
updateExtenderDirection, updateExtenderDirection,
adjustMask,
clearMask,
] = useStore((state) => [ ] = useStore((state) => [
state.serverConfig.samplers, state.serverConfig.samplers,
state.settings, state.settings,
@ -61,8 +64,17 @@ const DiffusionOptions = () => {
state.updateAppState, state.updateAppState,
state.updateExtenderByBuiltIn, state.updateExtenderByBuiltIn,
state.updateExtenderDirection, state.updateExtenderDirection,
state.adjustMask,
state.clearMask,
]) ])
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile) const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
const negativePromptRef = useRef(null)
useClickAway<MouseEvent>(negativePromptRef, () => {
if (negativePromptRef?.current) {
const input = negativePromptRef.current as HTMLInputElement
input.blur()
}
})
const onKeyUp = (e: React.KeyboardEvent) => { const onKeyUp = (e: React.KeyboardEvent) => {
// negativePrompt 回车触发 inpainting // negativePrompt 回车触发 inpainting
@ -320,6 +332,7 @@ const DiffusionOptions = () => {
/> />
<div className="pl-2 pr-4"> <div className="pl-2 pr-4">
<Textarea <Textarea
ref={negativePromptRef}
rows={4} rows={4}
onKeyUp={onKeyUp} onKeyUp={onKeyUp}
className="max-h-[8rem] overflow-y-auto mb-2" className="max-h-[8rem] overflow-y-auto mb-2"
@ -550,7 +563,7 @@ const DiffusionOptions = () => {
<RowContainer> <RowContainer>
<LabelTitle <LabelTitle
text="Task" text="Task"
toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above." toolTip="PowerPaint task. When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
/> />
<Select <Select
defaultValue={settings.powerpaintTask} defaultValue={settings.powerpaintTask}
@ -760,10 +773,85 @@ const DiffusionOptions = () => {
) )
} }
const renderMaskAdjuster = () => {
return (
<>
<div className="flex flex-col gap-1">
<LabelTitle
htmlFor="adjustMaskKernelSize"
text="Adjust Mask"
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
/>
<RowContainer>
<Slider
className="w-[180px]"
defaultValue={[12]}
min={1}
max={100}
step={1}
value={[Math.floor(settings.adjustMaskKernelSize)]}
onValueChange={(vals) =>
updateSettings({ adjustMaskKernelSize: vals[0] })
}
/>
<NumberInput
id="adjustMaskKernelSize"
className="w-[60px] rounded-full"
numberValue={settings.adjustMaskKernelSize}
allowFloat={false}
onNumberValueChange={(val) => {
updateSettings({ adjustMaskKernelSize: val })
}}
/>
</RowContainer>
<RowContainer>
<div className="flex gap-1 justify-start">
<Button
variant="outline"
className="p-1 h-8"
onClick={() => adjustMask("expand")}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">
<Plus size={16} />
expand
</div>
</Button>
<Button
variant="outline"
className="p-1 h-8"
onClick={() => adjustMask("shrink")}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">
<Minus size={16} />
Shrink
</div>
</Button>
</div>
<Button
variant="outline"
className="p-1 h-8 justify-self-end"
onClick={clearMask}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">Clear</div>
</Button>
</RowContainer>
</div>
<Separator />
</>
)
}
return ( return (
<div className="flex flex-col gap-4 mt-4"> <div className="flex flex-col gap-4 mt-4">
{renderCropper()} {renderCropper()}
{renderExtender()} {renderExtender()}
{renderMaskAdjuster()}
{renderPowerPaintTaskType()} {renderPowerPaintTaskType()}
{renderSteps()} {renderSteps()}
{renderGuidanceScale()} {renderGuidanceScale()}

View File

@ -0,0 +1,202 @@
import * as React from "react"
import * as ContextMenuPrimitive from "@radix-ui/react-context-menu"
import {
CheckIcon,
ChevronRightIcon,
DotFilledIcon,
} from "@radix-ui/react-icons"
import { cn } from "@/lib/utils"
const ContextMenu = ContextMenuPrimitive.Root
const ContextMenuTrigger = ContextMenuPrimitive.Trigger
const ContextMenuGroup = ContextMenuPrimitive.Group
const ContextMenuPortal = ContextMenuPrimitive.Portal
const ContextMenuSub = ContextMenuPrimitive.Sub
const ContextMenuRadioGroup = ContextMenuPrimitive.RadioGroup
const ContextMenuSubTrigger = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.SubTrigger>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubTrigger> & {
inset?: boolean
}
>(({ className, inset, children, ...props }, ref) => (
<ContextMenuPrimitive.SubTrigger
ref={ref}
className={cn(
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground",
inset && "pl-8",
className
)}
{...props}
>
{children}
<ChevronRightIcon className="ml-auto h-4 w-4" />
</ContextMenuPrimitive.SubTrigger>
))
ContextMenuSubTrigger.displayName = ContextMenuPrimitive.SubTrigger.displayName
const ContextMenuSubContent = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.SubContent>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubContent>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.SubContent
ref={ref}
className={cn(
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-lg data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
className
)}
{...props}
/>
))
ContextMenuSubContent.displayName = ContextMenuPrimitive.SubContent.displayName
const ContextMenuContent = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Content>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.Portal>
<ContextMenuPrimitive.Content
ref={ref}
className={cn(
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
className
)}
{...props}
/>
</ContextMenuPrimitive.Portal>
))
ContextMenuContent.displayName = ContextMenuPrimitive.Content.displayName
const ContextMenuItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Item> & {
inset?: boolean
}
>(({ className, inset, ...props }, ref) => (
<ContextMenuPrimitive.Item
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
inset && "pl-8",
className
)}
{...props}
/>
))
ContextMenuItem.displayName = ContextMenuPrimitive.Item.displayName
const ContextMenuCheckboxItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.CheckboxItem>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.CheckboxItem>
>(({ className, children, checked, ...props }, ref) => (
<ContextMenuPrimitive.CheckboxItem
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
className
)}
checked={checked}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<ContextMenuPrimitive.ItemIndicator>
<CheckIcon className="h-4 w-4" />
</ContextMenuPrimitive.ItemIndicator>
</span>
{children}
</ContextMenuPrimitive.CheckboxItem>
))
ContextMenuCheckboxItem.displayName =
ContextMenuPrimitive.CheckboxItem.displayName
const ContextMenuRadioItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.RadioItem>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.RadioItem>
>(({ className, children, ...props }, ref) => (
<ContextMenuPrimitive.RadioItem
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
className
)}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<ContextMenuPrimitive.ItemIndicator>
<DotFilledIcon className="h-4 w-4 fill-current" />
</ContextMenuPrimitive.ItemIndicator>
</span>
{children}
</ContextMenuPrimitive.RadioItem>
))
ContextMenuRadioItem.displayName = ContextMenuPrimitive.RadioItem.displayName
const ContextMenuLabel = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Label>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Label> & {
inset?: boolean
}
>(({ className, inset, ...props }, ref) => (
<ContextMenuPrimitive.Label
ref={ref}
className={cn(
"px-2 py-1.5 text-sm font-semibold text-foreground",
inset && "pl-8",
className
)}
{...props}
/>
))
ContextMenuLabel.displayName = ContextMenuPrimitive.Label.displayName
const ContextMenuSeparator = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Separator>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Separator>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.Separator
ref={ref}
className={cn("-mx-1 my-1 h-px bg-border", className)}
{...props}
/>
))
ContextMenuSeparator.displayName = ContextMenuPrimitive.Separator.displayName
const ContextMenuShortcut = ({
className,
...props
}: React.HTMLAttributes<HTMLSpanElement>) => {
return (
<span
className={cn(
"ml-auto text-xs tracking-widest text-muted-foreground",
className
)}
{...props}
/>
)
}
ContextMenuShortcut.displayName = "ContextMenuShortcut"
export {
ContextMenu,
ContextMenuTrigger,
ContextMenuContent,
ContextMenuItem,
ContextMenuCheckboxItem,
ContextMenuRadioItem,
ContextMenuLabel,
ContextMenuSeparator,
ContextMenuShortcut,
ContextMenuGroup,
ContextMenuPortal,
ContextMenuSub,
ContextMenuSubContent,
ContextMenuSubTrigger,
ContextMenuRadioGroup,
}

View File

@ -17,6 +17,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
const handleOnBlur = () => { const handleOnBlur = () => {
updateAppState({ disableShortCuts: false }) updateAppState({ disableShortCuts: false })
} }
return ( return (
<input <input
type={type} type={type}

View File

@ -201,3 +201,28 @@ export async function getSamplers(): Promise<string[]> {
const res = await api.post("/samplers") const res = await api.post("/samplers")
return res.data return res.data
} }
export async function postAdjustMask(
mask: File | Blob,
operate: "expand" | "shrink",
kernel_size: number
) {
const maskBase64 = await convertToBase64(mask)
const res = await fetch(`${API_ENDPOINT}/adjust_mask`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
mask: maskBase64,
operate: operate,
kernel_size: kernel_size,
}),
})
if (res.ok) {
const blob = await res.blob()
return blob
}
const errMsg = await res.json()
throw new Error(errMsg)
}

View File

@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer" import { castDraft } from "immer"
import { createWithEqualityFn } from "zustand/traditional" import { createWithEqualityFn } from "zustand/traditional"
import { import {
AdjustMaskOperate,
CV2Flag, CV2Flag,
ExtenderDirection, ExtenderDirection,
FreeuConfig, FreeuConfig,
@ -27,13 +28,14 @@ import {
PAINT_BY_EXAMPLE, PAINT_BY_EXAMPLE,
} from "./const" } from "./const"
import { import {
blobToImage,
canvasToImage, canvasToImage,
dataURItoBlob, dataURItoBlob,
generateMask, generateMask,
loadImage, loadImage,
srcToFile, srcToFile,
} from "./utils" } from "./utils"
import inpaint, { getGenInfo, runPlugin } from "./api" import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
import { toast } from "@/components/ui/use-toast" import { toast } from "@/components/ui/use-toast"
type FileManagerState = { type FileManagerState = {
@ -102,13 +104,14 @@ export type Settings = {
// PowerPaint // PowerPaint
powerpaintTask: PowerPaintTask powerpaintTask: PowerPaintTask
// AdjustMask
adjustMaskKernelSize: number
} }
type InteractiveSegState = { type InteractiveSegState = {
isInteractiveSeg: boolean isInteractiveSeg: boolean
interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][] clicks: number[][]
} }
@ -119,8 +122,12 @@ type EditorState = {
lineGroups: LineGroup[] lineGroups: LineGroup[]
lastLineGroup: LineGroup lastLineGroup: LineGroup
curLineGroup: LineGroup curLineGroup: LineGroup
// 只用来显示
// mask from interactive-seg or other segmentation models
extraMasks: HTMLImageElement[] extraMasks: HTMLImageElement[]
prevExtraMasks: HTMLImageElement[]
temporaryMasks: HTMLImageElement[]
// redo 相关 // redo 相关
redoRenders: HTMLImageElement[] redoRenders: HTMLImageElement[]
redoCurLines: Line[] redoCurLines: Line[]
@ -135,6 +142,7 @@ type AppState = {
imageWidth: number imageWidth: number
isInpainting: boolean isInpainting: boolean
isPluginRunning: boolean isPluginRunning: boolean
isAdjustingMask: boolean
windowSize: Size windowSize: Size
editorState: EditorState editorState: EditorState
disableShortCuts: boolean disableShortCuts: boolean
@ -209,6 +217,9 @@ type AppAction = {
redo: () => void redo: () => void
undoDisabled: () => boolean undoDisabled: () => boolean
redoDisabled: () => boolean redoDisabled: () => boolean
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
clearMask: () => void
} }
const defaultValues: AppState = { const defaultValues: AppState = {
@ -219,6 +230,7 @@ const defaultValues: AppState = {
imageWidth: 0, imageWidth: 0,
isInpainting: false, isInpainting: false,
isPluginRunning: false, isPluginRunning: false,
isAdjustingMask: false,
disableShortCuts: false, disableShortCuts: false,
windowSize: { windowSize: {
@ -230,6 +242,8 @@ const defaultValues: AppState = {
brushSizeScale: 1, brushSizeScale: 1,
renders: [], renders: [],
extraMasks: [], extraMasks: [],
prevExtraMasks: [],
temporaryMasks: [],
lineGroups: [], lineGroups: [],
lastLineGroup: [], lastLineGroup: [],
curLineGroup: [], curLineGroup: [],
@ -240,9 +254,7 @@ const defaultValues: AppState = {
interactiveSegState: { interactiveSegState: {
isInteractiveSeg: false, isInteractiveSeg: false,
interactiveSegMask: null,
tmpInteractiveSegMask: null, tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [], clicks: [],
}, },
@ -323,6 +335,7 @@ const defaultValues: AppState = {
enableFreeu: false, enableFreeu: false,
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 }, freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
powerpaintTask: PowerPaintTask.text_guided, powerpaintTask: PowerPaintTask.text_guided,
adjustMaskKernelSize: 12,
}, },
} }
@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (get().settings.showExtender) { if (get().settings.showExtender) {
return return
} }
const { lastLineGroup, curLineGroup } = get().editorState const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
const { prevInteractiveSegMask, interactiveSegMask } = get().editorState
get().interactiveSegState if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
return return
} }
const { imageWidth, imageHeight } = get() const { imageWidth, imageHeight } = get()
@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth, imageWidth,
imageHeight, imageHeight,
[lastLineGroup], [lastLineGroup],
prevInteractiveSegMask ? [prevInteractiveSegMask] : [], prevExtraMasks,
BRUSH_COLOR BRUSH_COLOR
) )
try { try {
const maskImage = await canvasToImage(maskCanvas) const maskImage = await canvasToImage(maskCanvas)
set((state) => { set((state) => {
state.editorState.extraMasks.push(castDraft(maskImage)) state.editorState.temporaryMasks.push(castDraft(maskImage))
}) })
} catch (e) { } catch (e) {
console.error(e) console.error(e)
@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
}, },
hidePrevMask: () => { hidePrevMask: () => {
set((state) => { set((state) => {
state.editorState.extraMasks = [] state.editorState.temporaryMasks = []
}) })
}, },
@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
return return
} }
const { lastLineGroup, curLineGroup, lineGroups, renders } = const {
get().editorState lastLineGroup,
curLineGroup,
const { interactiveSegMask, prevInteractiveSegMask } = lineGroups,
get().interactiveSegState renders,
prevExtraMasks,
extraMasks,
} = get().editorState
const useLastLineGroup = const useLastLineGroup =
curLineGroup.length === 0 && curLineGroup.length === 0 &&
interactiveSegMask === null && extraMasks.length === 0 &&
!settings.showExtender !settings.showExtender
// useLastLineGroup 的影响 // useLastLineGroup 的影响
// 1. 使用上一次的 mask // 1. 使用上一次的 mask
// 2. 结果替换当前 render // 2. 结果替换当前 render
let maskImage = null let maskImages: HTMLImageElement[] = []
let maskLineGroup: LineGroup = [] let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) { if (useLastLineGroup === true) {
maskLineGroup = lastLineGroup maskLineGroup = lastLineGroup
maskImage = prevInteractiveSegMask maskImages = prevExtraMasks
} else { } else {
maskLineGroup = curLineGroup maskLineGroup = curLineGroup
maskImage = interactiveSegMask maskImages = extraMasks
} }
if ( if (
maskLineGroup.length === 0 && maskLineGroup.length === 0 &&
maskImage === null && maskImages === null &&
!settings.showExtender !settings.showExtender
) { ) {
toast({ toast({
@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth, imageWidth,
imageHeight, imageHeight,
[maskLineGroup], [maskLineGroup],
maskImage ? [maskImage] : [] maskImages
) )
try { try {
@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
lineGroups: newLineGroups, lineGroups: newLineGroups,
lastLineGroup: maskLineGroup, lastLineGroup: maskLineGroup,
curLineGroup: [], curLineGroup: [],
extraMasks: [],
prevExtraMasks: maskImages,
}) })
} catch (e: any) { } catch (e: any) {
toast({ toast({
@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => { set((state) => {
state.isInpainting = false state.isInpainting = false
}) })
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
}, },
runRenderablePlugin: async ( runRenderablePlugin: async (
@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
} else { } else {
const newMask = new Image() const newMask = new Image()
await loadImage(newMask, blob) await loadImage(newMask, blob)
get().updateInteractiveSegState({ set((state) => {
interactiveSegMask: newMask, state.editorState.extraMasks.push(castDraft(newMask))
}) })
} }
const end = new Date() const end = new Date()
@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
}, },
getIsProcessing: (): boolean => { getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning return (
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
)
}, },
isSD: (): boolean => { isSD: (): boolean => {
@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
handleInteractiveSegAccept: () => { handleInteractiveSegAccept: () => {
set((state) => { set((state) => {
return { if (state.interactiveSegState.tmpInteractiveSegMask) {
...state, state.editorState.extraMasks.push(
interactiveSegState: { castDraft(state.interactiveSegState.tmpInteractiveSegMask)
...defaultValues.interactiveSegState, )
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
} }
state.interactiveSegState = castDraft({
...defaultValues.interactiveSegState,
})
}) })
}, },
@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => { set((state) => {
state.settings.seed = newValue state.settings.seed = newValue
}), }),
adjustMask: async (operate: AdjustMaskOperate) => {
const { imageWidth, imageHeight } = get()
const { curLineGroup, extraMasks } = get().editorState
const { adjustMaskKernelSize } = get().settings
if (curLineGroup.length === 0 && extraMasks.length === 0) {
return
}
set((state) => {
state.isAdjustingMask = true
})
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[curLineGroup],
extraMasks,
BRUSH_COLOR
)
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
const newMaskBlob = await postAdjustMask(
maskBlob,
operate,
adjustMaskKernelSize
)
const newMask = await blobToImage(newMaskBlob)
// TODO: currently ignore stroke undo/redo
set((state) => {
state.editorState.extraMasks = [castDraft(newMask)]
state.editorState.curLineGroup = []
})
set((state) => {
state.isAdjustingMask = false
})
},
clearMask: () => {
set((state) => {
state.editorState.extraMasks = []
state.editorState.curLineGroup = []
})
},
})), })),
{ {
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique) name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
version: 0, version: 1,
partialize: (state) => partialize: (state) =>
Object.fromEntries( Object.fromEntries(
Object.entries(state).filter(([key]) => Object.entries(state).filter(([key]) =>

View File

@ -125,3 +125,5 @@ export enum PowerPaintTask {
object_remove = "object-remove", object_remove = "object-remove",
outpainting = "outpainting", outpainting = "outpainting",
} }
export type AdjustMaskOperate = "expand" | "shrink"

View File

@ -53,6 +53,13 @@ export function loadImage(image: HTMLImageElement, src: string) {
}) })
} }
export async function blobToImage(blob: Blob) {
const dataURL = URL.createObjectURL(blob)
const newImage = new Image()
await loadImage(newImage, dataURL)
return newImage
}
export function canvasToImage( export function canvasToImage(
canvas: HTMLCanvasElement canvas: HTMLCanvasElement
): Promise<HTMLImageElement> { ): Promise<HTMLImageElement> {