add adjust mask feature

This commit is contained in:
Qing 2024-01-05 14:57:30 +08:00
parent 2996544e75
commit e889e527ab
18 changed files with 507 additions and 76 deletions

View File

@ -29,6 +29,7 @@ from lama_cleaner.helper import (
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
adjust_mask,
)
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo
@ -44,6 +45,7 @@ from lama_cleaner.schema import (
RunPluginRequest,
SDSampler,
PluginInfo,
AdjustMaskRequest,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@ -150,6 +152,7 @@ class Api:
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
@ -294,6 +297,13 @@ class Api:
def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
cv2.imwrite("tmp_adjust_mask_input.png", mask)
mask = adjust_mask(mask, req.kernel_size, req.operate)
cv2.imwrite("tmp_adjust_mask.png", mask)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
def launch(self):
self.app.include_router(self.router)
uvicorn.run(

View File

@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
mask[mask >= 127] = 255
mask[mask < 127] = 0
# fronted brush color "ffcc00bb"
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
)
if operate == "expand":
mask = cv2.dilate(
mask,
np.ones((kernel_size, kernel_size), np.uint8),
kernel,
iterations=1,
)
else:
mask = cv2.erode(
mask,
np.ones((kernel_size, kernel_size), np.uint8),
kernel,
iterations=1,
)
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)

View File

@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
class InpaintRequest(BaseModel):
image: Optional[str] = Field(..., description="base64 encoded image")
mask: Optional[str] = Field(..., description="base64 encoded mask")
image: Optional[str] = Field(None, description="base64 encoded image")
mask: Optional[str] = Field(None, description="base64 encoded mask")
ldm_steps: int = Field(20, description="Steps for ldm model.")
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
class SwitchModelRequest(BaseModel):
name: str
AdjustMaskOperate = Literal["expand", "shrink"]
class AdjustMaskRequest(BaseModel):
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
kernel_size: int = Field(5, description="Kernel size for expanding mask")

View File

@ -0,0 +1,15 @@
import cv2
from lama_cleaner.helper import adjust_mask
from lama_cleaner.tests.utils import current_dir, save_dir
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
def test_adjust_mask():
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
res_mask = adjust_mask(mask, 0, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
res_mask = adjust_mask(mask, 40, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
res_mask = adjust_mask(mask, 20, "shrink")
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)

View File

@ -31,8 +31,6 @@ def assert_equal(
mask_p=current_dir / "mask.png",
):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
print(f"Input image shape: {img.shape}")
res = model(img, mask, config)
ok = cv2.imwrite(

View File

@ -12,6 +12,7 @@
"@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-context-menu": "^2.1.5",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0",
@ -1537,6 +1538,34 @@
}
}
},
"node_modules/@radix-ui/react-context-menu": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz",
"integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/primitive": "1.0.1",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-menu": "2.0.6",
"@radix-ui/react-primitive": "1.0.3",
"@radix-ui/react-use-callback-ref": "1.0.1",
"@radix-ui/react-use-controllable-state": "1.0.1"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-dialog": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",

View File

@ -14,6 +14,7 @@
"@hookform/resolvers": "^3.3.2",
"@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.5",
"@radix-ui/react-context-menu": "^2.1.5",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-dropdown-menu": "^2.0.6",
"@radix-ui/react-icons": "^1.3.0",

View File

@ -1,5 +1,4 @@
import { useCallback, useEffect, useMemo, useRef } from "react"
import { nanoid } from "nanoid"
import { useCallback, useEffect, useRef } from "react"
import useInputImage from "@/hooks/useInputImage"
import { keepGUIAlive } from "@/lib/utils"
@ -52,10 +51,6 @@ function Home() {
fetchServerConfig()
}, [])
const workspaceId = useMemo(() => {
return nanoid()
}, [file])
const dragCounter = useRef(0)
const handleDrag = useCallback((event: any) => {
@ -155,7 +150,7 @@ function Home() {
<main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat">
<Toaster />
<Header />
<Workspace key={workspaceId} />
<Workspace />
{!file ? (
<FileSelect
onSelection={async (f) => {

View File

@ -52,9 +52,9 @@ const DiffusionProgress = () => {
return (
<div
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
className="z-10 fixed bg-background w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
style={{
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden",
visibility: isConnected && isInpainting && isSD ? "visible" : "hidden",
}}
>
<Progress value={progress} />

View File

@ -95,6 +95,7 @@ export default function Editor(props: EditorProps) {
const brushSize = useStore((state) => state.getBrushSize())
const renders = useStore((state) => state.editorState.renders)
const extraMasks = useStore((state) => state.editorState.extraMasks)
const temporaryMasks = useStore((state) => state.editorState.temporaryMasks)
const lineGroups = useStore((state) => state.editorState.lineGroups)
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
@ -166,6 +167,9 @@ export default function Editor(props: EditorProps) {
context.canvas.width = imageWidth
context.canvas.height = imageHeight
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
temporaryMasks.forEach((maskImage) => {
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
extraMasks.forEach((maskImage) => {
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
@ -182,20 +186,9 @@ export default function Editor(props: EditorProps) {
imageHeight
)
}
if (
!interactiveSegState.isInteractiveSeg &&
interactiveSegState.interactiveSegMask
) {
context.drawImage(
interactiveSegState.interactiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
drawLines(context, curLineGroup)
}, [
temporaryMasks,
extraMasks,
isOriginalLoaded,
interactiveSegState,

View File

@ -21,7 +21,6 @@ const PromptInput = () => {
state.hidePrevMask,
])
const ref = useRef(null)
useClickAway<MouseEvent>(ref, () => {
if (ref?.current) {
const input = ref.current as HTMLInputElement

View File

@ -1,4 +1,4 @@
import { FormEvent } from "react"
import { FormEvent, useRef } from "react"
import { useStore } from "@/lib/states"
import { Switch } from "../ui/switch"
import { NumberInput } from "../ui/input"
@ -18,7 +18,8 @@ import { Slider } from "../ui/slider"
import { useImage } from "@/hooks/useImage"
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
import { RowContainer, LabelTitle } from "./LabelTitle"
import { Upload } from "lucide-react"
import { Minus, Plus, Upload } from "lucide-react"
import { useClickAway } from "react-use"
const ExtenderButton = ({
text,
@ -31,7 +32,7 @@ const ExtenderButton = ({
return (
<Button
variant="outline"
className="p-1 h-7"
className="p-1 h-8"
disabled={!showExtender}
onClick={onClick}
>
@ -51,6 +52,8 @@ const DiffusionOptions = () => {
updateAppState,
updateExtenderByBuiltIn,
updateExtenderDirection,
adjustMask,
clearMask,
] = useStore((state) => [
state.serverConfig.samplers,
state.settings,
@ -61,8 +64,17 @@ const DiffusionOptions = () => {
state.updateAppState,
state.updateExtenderByBuiltIn,
state.updateExtenderDirection,
state.adjustMask,
state.clearMask,
])
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
const negativePromptRef = useRef(null)
useClickAway<MouseEvent>(negativePromptRef, () => {
if (negativePromptRef?.current) {
const input = negativePromptRef.current as HTMLInputElement
input.blur()
}
})
const onKeyUp = (e: React.KeyboardEvent) => {
// negativePrompt 回车触发 inpainting
@ -320,6 +332,7 @@ const DiffusionOptions = () => {
/>
<div className="pl-2 pr-4">
<Textarea
ref={negativePromptRef}
rows={4}
onKeyUp={onKeyUp}
className="max-h-[8rem] overflow-y-auto mb-2"
@ -550,7 +563,7 @@ const DiffusionOptions = () => {
<RowContainer>
<LabelTitle
text="Task"
toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
toolTip="PowerPaint task. When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
/>
<Select
defaultValue={settings.powerpaintTask}
@ -760,10 +773,85 @@ const DiffusionOptions = () => {
)
}
const renderMaskAdjuster = () => {
return (
<>
<div className="flex flex-col gap-1">
<LabelTitle
htmlFor="adjustMaskKernelSize"
text="Adjust Mask"
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
/>
<RowContainer>
<Slider
className="w-[180px]"
defaultValue={[12]}
min={1}
max={100}
step={1}
value={[Math.floor(settings.adjustMaskKernelSize)]}
onValueChange={(vals) =>
updateSettings({ adjustMaskKernelSize: vals[0] })
}
/>
<NumberInput
id="adjustMaskKernelSize"
className="w-[60px] rounded-full"
numberValue={settings.adjustMaskKernelSize}
allowFloat={false}
onNumberValueChange={(val) => {
updateSettings({ adjustMaskKernelSize: val })
}}
/>
</RowContainer>
<RowContainer>
<div className="flex gap-1 justify-start">
<Button
variant="outline"
className="p-1 h-8"
onClick={() => adjustMask("expand")}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">
<Plus size={16} />
expand
</div>
</Button>
<Button
variant="outline"
className="p-1 h-8"
onClick={() => adjustMask("shrink")}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">
<Minus size={16} />
Shrink
</div>
</Button>
</div>
<Button
variant="outline"
className="p-1 h-8 justify-self-end"
onClick={clearMask}
disabled={isProcessing}
>
<div className="flex items-center gap-1 select-none">Clear</div>
</Button>
</RowContainer>
</div>
<Separator />
</>
)
}
return (
<div className="flex flex-col gap-4 mt-4">
{renderCropper()}
{renderExtender()}
{renderMaskAdjuster()}
{renderPowerPaintTaskType()}
{renderSteps()}
{renderGuidanceScale()}

View File

@ -0,0 +1,202 @@
import * as React from "react"
import * as ContextMenuPrimitive from "@radix-ui/react-context-menu"
import {
CheckIcon,
ChevronRightIcon,
DotFilledIcon,
} from "@radix-ui/react-icons"
import { cn } from "@/lib/utils"
const ContextMenu = ContextMenuPrimitive.Root
const ContextMenuTrigger = ContextMenuPrimitive.Trigger
const ContextMenuGroup = ContextMenuPrimitive.Group
const ContextMenuPortal = ContextMenuPrimitive.Portal
const ContextMenuSub = ContextMenuPrimitive.Sub
const ContextMenuRadioGroup = ContextMenuPrimitive.RadioGroup
const ContextMenuSubTrigger = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.SubTrigger>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubTrigger> & {
inset?: boolean
}
>(({ className, inset, children, ...props }, ref) => (
<ContextMenuPrimitive.SubTrigger
ref={ref}
className={cn(
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground",
inset && "pl-8",
className
)}
{...props}
>
{children}
<ChevronRightIcon className="ml-auto h-4 w-4" />
</ContextMenuPrimitive.SubTrigger>
))
ContextMenuSubTrigger.displayName = ContextMenuPrimitive.SubTrigger.displayName
const ContextMenuSubContent = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.SubContent>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubContent>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.SubContent
ref={ref}
className={cn(
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-lg data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
className
)}
{...props}
/>
))
ContextMenuSubContent.displayName = ContextMenuPrimitive.SubContent.displayName
const ContextMenuContent = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Content>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.Portal>
<ContextMenuPrimitive.Content
ref={ref}
className={cn(
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
className
)}
{...props}
/>
</ContextMenuPrimitive.Portal>
))
ContextMenuContent.displayName = ContextMenuPrimitive.Content.displayName
const ContextMenuItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Item> & {
inset?: boolean
}
>(({ className, inset, ...props }, ref) => (
<ContextMenuPrimitive.Item
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
inset && "pl-8",
className
)}
{...props}
/>
))
ContextMenuItem.displayName = ContextMenuPrimitive.Item.displayName
const ContextMenuCheckboxItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.CheckboxItem>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.CheckboxItem>
>(({ className, children, checked, ...props }, ref) => (
<ContextMenuPrimitive.CheckboxItem
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
className
)}
checked={checked}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<ContextMenuPrimitive.ItemIndicator>
<CheckIcon className="h-4 w-4" />
</ContextMenuPrimitive.ItemIndicator>
</span>
{children}
</ContextMenuPrimitive.CheckboxItem>
))
ContextMenuCheckboxItem.displayName =
ContextMenuPrimitive.CheckboxItem.displayName
const ContextMenuRadioItem = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.RadioItem>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.RadioItem>
>(({ className, children, ...props }, ref) => (
<ContextMenuPrimitive.RadioItem
ref={ref}
className={cn(
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
className
)}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<ContextMenuPrimitive.ItemIndicator>
<DotFilledIcon className="h-4 w-4 fill-current" />
</ContextMenuPrimitive.ItemIndicator>
</span>
{children}
</ContextMenuPrimitive.RadioItem>
))
ContextMenuRadioItem.displayName = ContextMenuPrimitive.RadioItem.displayName
const ContextMenuLabel = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Label>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Label> & {
inset?: boolean
}
>(({ className, inset, ...props }, ref) => (
<ContextMenuPrimitive.Label
ref={ref}
className={cn(
"px-2 py-1.5 text-sm font-semibold text-foreground",
inset && "pl-8",
className
)}
{...props}
/>
))
ContextMenuLabel.displayName = ContextMenuPrimitive.Label.displayName
const ContextMenuSeparator = React.forwardRef<
React.ElementRef<typeof ContextMenuPrimitive.Separator>,
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Separator>
>(({ className, ...props }, ref) => (
<ContextMenuPrimitive.Separator
ref={ref}
className={cn("-mx-1 my-1 h-px bg-border", className)}
{...props}
/>
))
ContextMenuSeparator.displayName = ContextMenuPrimitive.Separator.displayName
const ContextMenuShortcut = ({
className,
...props
}: React.HTMLAttributes<HTMLSpanElement>) => {
return (
<span
className={cn(
"ml-auto text-xs tracking-widest text-muted-foreground",
className
)}
{...props}
/>
)
}
ContextMenuShortcut.displayName = "ContextMenuShortcut"
export {
ContextMenu,
ContextMenuTrigger,
ContextMenuContent,
ContextMenuItem,
ContextMenuCheckboxItem,
ContextMenuRadioItem,
ContextMenuLabel,
ContextMenuSeparator,
ContextMenuShortcut,
ContextMenuGroup,
ContextMenuPortal,
ContextMenuSub,
ContextMenuSubContent,
ContextMenuSubTrigger,
ContextMenuRadioGroup,
}

View File

@ -17,6 +17,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
const handleOnBlur = () => {
updateAppState({ disableShortCuts: false })
}
return (
<input
type={type}

View File

@ -201,3 +201,28 @@ export async function getSamplers(): Promise<string[]> {
const res = await api.post("/samplers")
return res.data
}
export async function postAdjustMask(
mask: File | Blob,
operate: "expand" | "shrink",
kernel_size: number
) {
const maskBase64 = await convertToBase64(mask)
const res = await fetch(`${API_ENDPOINT}/adjust_mask`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
mask: maskBase64,
operate: operate,
kernel_size: kernel_size,
}),
})
if (res.ok) {
const blob = await res.blob()
return blob
}
const errMsg = await res.json()
throw new Error(errMsg)
}

View File

@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer"
import { createWithEqualityFn } from "zustand/traditional"
import {
AdjustMaskOperate,
CV2Flag,
ExtenderDirection,
FreeuConfig,
@ -27,13 +28,14 @@ import {
PAINT_BY_EXAMPLE,
} from "./const"
import {
blobToImage,
canvasToImage,
dataURItoBlob,
generateMask,
loadImage,
srcToFile,
} from "./utils"
import inpaint, { getGenInfo, runPlugin } from "./api"
import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
import { toast } from "@/components/ui/use-toast"
type FileManagerState = {
@ -102,13 +104,14 @@ export type Settings = {
// PowerPaint
powerpaintTask: PowerPaintTask
// AdjustMask
adjustMaskKernelSize: number
}
type InteractiveSegState = {
isInteractiveSeg: boolean
interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][]
}
@ -119,8 +122,12 @@ type EditorState = {
lineGroups: LineGroup[]
lastLineGroup: LineGroup
curLineGroup: LineGroup
// 只用来显示
// mask from interactive-seg or other segmentation models
extraMasks: HTMLImageElement[]
prevExtraMasks: HTMLImageElement[]
temporaryMasks: HTMLImageElement[]
// redo 相关
redoRenders: HTMLImageElement[]
redoCurLines: Line[]
@ -135,6 +142,7 @@ type AppState = {
imageWidth: number
isInpainting: boolean
isPluginRunning: boolean
isAdjustingMask: boolean
windowSize: Size
editorState: EditorState
disableShortCuts: boolean
@ -209,6 +217,9 @@ type AppAction = {
redo: () => void
undoDisabled: () => boolean
redoDisabled: () => boolean
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
clearMask: () => void
}
const defaultValues: AppState = {
@ -219,6 +230,7 @@ const defaultValues: AppState = {
imageWidth: 0,
isInpainting: false,
isPluginRunning: false,
isAdjustingMask: false,
disableShortCuts: false,
windowSize: {
@ -230,6 +242,8 @@ const defaultValues: AppState = {
brushSizeScale: 1,
renders: [],
extraMasks: [],
prevExtraMasks: [],
temporaryMasks: [],
lineGroups: [],
lastLineGroup: [],
curLineGroup: [],
@ -240,9 +254,7 @@ const defaultValues: AppState = {
interactiveSegState: {
isInteractiveSeg: false,
interactiveSegMask: null,
tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [],
},
@ -323,6 +335,7 @@ const defaultValues: AppState = {
enableFreeu: false,
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
powerpaintTask: PowerPaintTask.text_guided,
adjustMaskKernelSize: 12,
},
}
@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (get().settings.showExtender) {
return
}
const { lastLineGroup, curLineGroup } = get().editorState
const { prevInteractiveSegMask, interactiveSegMask } =
get().interactiveSegState
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
get().editorState
if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
return
}
const { imageWidth, imageHeight } = get()
@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth,
imageHeight,
[lastLineGroup],
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
prevExtraMasks,
BRUSH_COLOR
)
try {
const maskImage = await canvasToImage(maskCanvas)
set((state) => {
state.editorState.extraMasks.push(castDraft(maskImage))
state.editorState.temporaryMasks.push(castDraft(maskImage))
})
} catch (e) {
console.error(e)
@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
},
hidePrevMask: () => {
set((state) => {
state.editorState.extraMasks = []
state.editorState.temporaryMasks = []
})
},
@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
return
}
const { lastLineGroup, curLineGroup, lineGroups, renders } =
get().editorState
const { interactiveSegMask, prevInteractiveSegMask } =
get().interactiveSegState
const {
lastLineGroup,
curLineGroup,
lineGroups,
renders,
prevExtraMasks,
extraMasks,
} = get().editorState
const useLastLineGroup =
curLineGroup.length === 0 &&
interactiveSegMask === null &&
extraMasks.length === 0 &&
!settings.showExtender
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
let maskImage = null
let maskImages: HTMLImageElement[] = []
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
maskLineGroup = lastLineGroup
maskImage = prevInteractiveSegMask
maskImages = prevExtraMasks
} else {
maskLineGroup = curLineGroup
maskImage = interactiveSegMask
maskImages = extraMasks
}
if (
maskLineGroup.length === 0 &&
maskImage === null &&
maskImages === null &&
!settings.showExtender
) {
toast({
@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
imageWidth,
imageHeight,
[maskLineGroup],
maskImage ? [maskImage] : []
maskImages
)
try {
@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
lineGroups: newLineGroups,
lastLineGroup: maskLineGroup,
curLineGroup: [],
extraMasks: [],
prevExtraMasks: maskImages,
})
} catch (e: any) {
toast({
@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => {
state.isInpainting = false
})
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
},
runRenderablePlugin: async (
@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
} else {
const newMask = new Image()
await loadImage(newMask, blob)
get().updateInteractiveSegState({
interactiveSegMask: newMask,
set((state) => {
state.editorState.extraMasks.push(castDraft(newMask))
})
}
const end = new Date()
@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
},
getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning
return (
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
)
},
isSD: (): boolean => {
@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
handleInteractiveSegAccept: () => {
set((state) => {
return {
...state,
interactiveSegState: {
...defaultValues.interactiveSegState,
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
if (state.interactiveSegState.tmpInteractiveSegMask) {
state.editorState.extraMasks.push(
castDraft(state.interactiveSegState.tmpInteractiveSegMask)
)
}
state.interactiveSegState = castDraft({
...defaultValues.interactiveSegState,
})
})
},
@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
set((state) => {
state.settings.seed = newValue
}),
adjustMask: async (operate: AdjustMaskOperate) => {
const { imageWidth, imageHeight } = get()
const { curLineGroup, extraMasks } = get().editorState
const { adjustMaskKernelSize } = get().settings
if (curLineGroup.length === 0 && extraMasks.length === 0) {
return
}
set((state) => {
state.isAdjustingMask = true
})
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[curLineGroup],
extraMasks,
BRUSH_COLOR
)
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
const newMaskBlob = await postAdjustMask(
maskBlob,
operate,
adjustMaskKernelSize
)
const newMask = await blobToImage(newMaskBlob)
// TODO: currently ignore stroke undo/redo
set((state) => {
state.editorState.extraMasks = [castDraft(newMask)]
state.editorState.curLineGroup = []
})
set((state) => {
state.isAdjustingMask = false
})
},
clearMask: () => {
set((state) => {
state.editorState.extraMasks = []
state.editorState.curLineGroup = []
})
},
})),
{
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
version: 0,
version: 1,
partialize: (state) =>
Object.fromEntries(
Object.entries(state).filter(([key]) =>

View File

@ -125,3 +125,5 @@ export enum PowerPaintTask {
object_remove = "object-remove",
outpainting = "outpainting",
}
export type AdjustMaskOperate = "expand" | "shrink"

View File

@ -53,6 +53,13 @@ export function loadImage(image: HTMLImageElement, src: string) {
})
}
export async function blobToImage(blob: Blob) {
const dataURL = URL.createObjectURL(blob)
const newImage = new Image()
await loadImage(newImage, dataURL)
return newImage
}
export function canvasToImage(
canvas: HTMLCanvasElement
): Promise<HTMLImageElement> {