add adjust mask feature
This commit is contained in:
parent
2996544e75
commit
e889e527ab
@ -29,6 +29,7 @@ from lama_cleaner.helper import (
|
|||||||
numpy_to_bytes,
|
numpy_to_bytes,
|
||||||
concat_alpha_channel,
|
concat_alpha_channel,
|
||||||
gen_frontend_mask,
|
gen_frontend_mask,
|
||||||
|
adjust_mask,
|
||||||
)
|
)
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model_info import ModelInfo
|
from lama_cleaner.model_info import ModelInfo
|
||||||
@ -44,6 +45,7 @@ from lama_cleaner.schema import (
|
|||||||
RunPluginRequest,
|
RunPluginRequest,
|
||||||
SDSampler,
|
SDSampler,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
|
AdjustMaskRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||||
@ -150,6 +152,7 @@ class Api:
|
|||||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||||
|
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
||||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@ -294,6 +297,13 @@ class Api:
|
|||||||
def api_samplers(self) -> List[str]:
|
def api_samplers(self) -> List[str]:
|
||||||
return [member.value for member in SDSampler.__members__.values()]
|
return [member.value for member in SDSampler.__members__.values()]
|
||||||
|
|
||||||
|
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||||
|
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||||
|
cv2.imwrite("tmp_adjust_mask_input.png", mask)
|
||||||
|
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||||
|
cv2.imwrite("tmp_adjust_mask.png", mask)
|
||||||
|
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||||
|
|
||||||
def launch(self):
|
def launch(self):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
|
@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
|
|||||||
mask[mask >= 127] = 255
|
mask[mask >= 127] = 255
|
||||||
mask[mask < 127] = 0
|
mask[mask < 127] = 0
|
||||||
# fronted brush color "ffcc00bb"
|
# fronted brush color "ffcc00bb"
|
||||||
|
kernel = cv2.getStructuringElement(
|
||||||
|
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
|
||||||
|
)
|
||||||
if operate == "expand":
|
if operate == "expand":
|
||||||
mask = cv2.dilate(
|
mask = cv2.dilate(
|
||||||
mask,
|
mask,
|
||||||
np.ones((kernel_size, kernel_size), np.uint8),
|
kernel,
|
||||||
iterations=1,
|
iterations=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
mask = cv2.erode(
|
mask = cv2.erode(
|
||||||
mask,
|
mask,
|
||||||
np.ones((kernel_size, kernel_size), np.uint8),
|
kernel,
|
||||||
iterations=1,
|
iterations=1,
|
||||||
)
|
)
|
||||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||||
|
@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class InpaintRequest(BaseModel):
|
class InpaintRequest(BaseModel):
|
||||||
image: Optional[str] = Field(..., description="base64 encoded image")
|
image: Optional[str] = Field(None, description="base64 encoded image")
|
||||||
mask: Optional[str] = Field(..., description="base64 encoded mask")
|
mask: Optional[str] = Field(None, description="base64 encoded mask")
|
||||||
|
|
||||||
ldm_steps: int = Field(20, description="Steps for ldm model.")
|
ldm_steps: int = Field(20, description="Steps for ldm model.")
|
||||||
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
|
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
|
||||||
@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
|
|||||||
|
|
||||||
class SwitchModelRequest(BaseModel):
|
class SwitchModelRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
AdjustMaskOperate = Literal["expand", "shrink"]
|
||||||
|
|
||||||
|
|
||||||
|
class AdjustMaskRequest(BaseModel):
|
||||||
|
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
|
||||||
|
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
|
||||||
|
kernel_size: int = Field(5, description="Kernel size for expanding mask")
|
||||||
|
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import cv2
|
||||||
|
from lama_cleaner.helper import adjust_mask
|
||||||
|
from lama_cleaner.tests.utils import current_dir, save_dir
|
||||||
|
|
||||||
|
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
|
||||||
|
|
||||||
|
|
||||||
|
def test_adjust_mask():
|
||||||
|
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||||
|
res_mask = adjust_mask(mask, 0, "expand")
|
||||||
|
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
|
||||||
|
res_mask = adjust_mask(mask, 40, "expand")
|
||||||
|
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
|
||||||
|
res_mask = adjust_mask(mask, 20, "shrink")
|
||||||
|
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)
|
@ -31,8 +31,6 @@ def assert_equal(
|
|||||||
mask_p=current_dir / "mask.png",
|
mask_p=current_dir / "mask.png",
|
||||||
):
|
):
|
||||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
|
|
||||||
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
|
|
||||||
print(f"Input image shape: {img.shape}")
|
print(f"Input image shape: {img.shape}")
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
ok = cv2.imwrite(
|
ok = cv2.imwrite(
|
||||||
|
29
web_app/package-lock.json
generated
29
web_app/package-lock.json
generated
@ -12,6 +12,7 @@
|
|||||||
"@hookform/resolvers": "^3.3.2",
|
"@hookform/resolvers": "^3.3.2",
|
||||||
"@radix-ui/react-accordion": "^1.1.2",
|
"@radix-ui/react-accordion": "^1.1.2",
|
||||||
"@radix-ui/react-alert-dialog": "^1.0.5",
|
"@radix-ui/react-alert-dialog": "^1.0.5",
|
||||||
|
"@radix-ui/react-context-menu": "^2.1.5",
|
||||||
"@radix-ui/react-dialog": "^1.0.5",
|
"@radix-ui/react-dialog": "^1.0.5",
|
||||||
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
||||||
"@radix-ui/react-icons": "^1.3.0",
|
"@radix-ui/react-icons": "^1.3.0",
|
||||||
@ -1537,6 +1538,34 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@radix-ui/react-context-menu": {
|
||||||
|
"version": "2.1.5",
|
||||||
|
"resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz",
|
||||||
|
"integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==",
|
||||||
|
"dependencies": {
|
||||||
|
"@babel/runtime": "^7.13.10",
|
||||||
|
"@radix-ui/primitive": "1.0.1",
|
||||||
|
"@radix-ui/react-context": "1.0.1",
|
||||||
|
"@radix-ui/react-menu": "2.0.6",
|
||||||
|
"@radix-ui/react-primitive": "1.0.3",
|
||||||
|
"@radix-ui/react-use-callback-ref": "1.0.1",
|
||||||
|
"@radix-ui/react-use-controllable-state": "1.0.1"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@types/react": "*",
|
||||||
|
"@types/react-dom": "*",
|
||||||
|
"react": "^16.8 || ^17.0 || ^18.0",
|
||||||
|
"react-dom": "^16.8 || ^17.0 || ^18.0"
|
||||||
|
},
|
||||||
|
"peerDependenciesMeta": {
|
||||||
|
"@types/react": {
|
||||||
|
"optional": true
|
||||||
|
},
|
||||||
|
"@types/react-dom": {
|
||||||
|
"optional": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@radix-ui/react-dialog": {
|
"node_modules/@radix-ui/react-dialog": {
|
||||||
"version": "1.0.5",
|
"version": "1.0.5",
|
||||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",
|
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
"@hookform/resolvers": "^3.3.2",
|
"@hookform/resolvers": "^3.3.2",
|
||||||
"@radix-ui/react-accordion": "^1.1.2",
|
"@radix-ui/react-accordion": "^1.1.2",
|
||||||
"@radix-ui/react-alert-dialog": "^1.0.5",
|
"@radix-ui/react-alert-dialog": "^1.0.5",
|
||||||
|
"@radix-ui/react-context-menu": "^2.1.5",
|
||||||
"@radix-ui/react-dialog": "^1.0.5",
|
"@radix-ui/react-dialog": "^1.0.5",
|
||||||
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
"@radix-ui/react-dropdown-menu": "^2.0.6",
|
||||||
"@radix-ui/react-icons": "^1.3.0",
|
"@radix-ui/react-icons": "^1.3.0",
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { useCallback, useEffect, useMemo, useRef } from "react"
|
import { useCallback, useEffect, useRef } from "react"
|
||||||
import { nanoid } from "nanoid"
|
|
||||||
|
|
||||||
import useInputImage from "@/hooks/useInputImage"
|
import useInputImage from "@/hooks/useInputImage"
|
||||||
import { keepGUIAlive } from "@/lib/utils"
|
import { keepGUIAlive } from "@/lib/utils"
|
||||||
@ -52,10 +51,6 @@ function Home() {
|
|||||||
fetchServerConfig()
|
fetchServerConfig()
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const workspaceId = useMemo(() => {
|
|
||||||
return nanoid()
|
|
||||||
}, [file])
|
|
||||||
|
|
||||||
const dragCounter = useRef(0)
|
const dragCounter = useRef(0)
|
||||||
|
|
||||||
const handleDrag = useCallback((event: any) => {
|
const handleDrag = useCallback((event: any) => {
|
||||||
@ -155,7 +150,7 @@ function Home() {
|
|||||||
<main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat">
|
<main className="flex min-h-screen flex-col items-center justify-between w-full bg-[radial-gradient(circle_at_1px_1px,_#8e8e8e8e_1px,_transparent_0)] [background-size:20px_20px] bg-repeat">
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<Header />
|
<Header />
|
||||||
<Workspace key={workspaceId} />
|
<Workspace />
|
||||||
{!file ? (
|
{!file ? (
|
||||||
<FileSelect
|
<FileSelect
|
||||||
onSelection={async (f) => {
|
onSelection={async (f) => {
|
||||||
|
@ -52,9 +52,9 @@ const DiffusionProgress = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
|
className="z-10 fixed bg-background w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
|
||||||
style={{
|
style={{
|
||||||
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden",
|
visibility: isConnected && isInpainting && isSD ? "visible" : "hidden",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Progress value={progress} />
|
<Progress value={progress} />
|
||||||
|
@ -95,6 +95,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
const brushSize = useStore((state) => state.getBrushSize())
|
const brushSize = useStore((state) => state.getBrushSize())
|
||||||
const renders = useStore((state) => state.editorState.renders)
|
const renders = useStore((state) => state.editorState.renders)
|
||||||
const extraMasks = useStore((state) => state.editorState.extraMasks)
|
const extraMasks = useStore((state) => state.editorState.extraMasks)
|
||||||
|
const temporaryMasks = useStore((state) => state.editorState.temporaryMasks)
|
||||||
const lineGroups = useStore((state) => state.editorState.lineGroups)
|
const lineGroups = useStore((state) => state.editorState.lineGroups)
|
||||||
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
|
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
|
||||||
|
|
||||||
@ -166,6 +167,9 @@ export default function Editor(props: EditorProps) {
|
|||||||
context.canvas.width = imageWidth
|
context.canvas.width = imageWidth
|
||||||
context.canvas.height = imageHeight
|
context.canvas.height = imageHeight
|
||||||
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
|
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
|
||||||
|
temporaryMasks.forEach((maskImage) => {
|
||||||
|
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||||
|
})
|
||||||
extraMasks.forEach((maskImage) => {
|
extraMasks.forEach((maskImage) => {
|
||||||
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||||
})
|
})
|
||||||
@ -182,20 +186,9 @@ export default function Editor(props: EditorProps) {
|
|||||||
imageHeight
|
imageHeight
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if (
|
|
||||||
!interactiveSegState.isInteractiveSeg &&
|
|
||||||
interactiveSegState.interactiveSegMask
|
|
||||||
) {
|
|
||||||
context.drawImage(
|
|
||||||
interactiveSegState.interactiveSegMask,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
imageWidth,
|
|
||||||
imageHeight
|
|
||||||
)
|
|
||||||
}
|
|
||||||
drawLines(context, curLineGroup)
|
drawLines(context, curLineGroup)
|
||||||
}, [
|
}, [
|
||||||
|
temporaryMasks,
|
||||||
extraMasks,
|
extraMasks,
|
||||||
isOriginalLoaded,
|
isOriginalLoaded,
|
||||||
interactiveSegState,
|
interactiveSegState,
|
||||||
|
@ -21,7 +21,6 @@ const PromptInput = () => {
|
|||||||
state.hidePrevMask,
|
state.hidePrevMask,
|
||||||
])
|
])
|
||||||
const ref = useRef(null)
|
const ref = useRef(null)
|
||||||
|
|
||||||
useClickAway<MouseEvent>(ref, () => {
|
useClickAway<MouseEvent>(ref, () => {
|
||||||
if (ref?.current) {
|
if (ref?.current) {
|
||||||
const input = ref.current as HTMLInputElement
|
const input = ref.current as HTMLInputElement
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { FormEvent } from "react"
|
import { FormEvent, useRef } from "react"
|
||||||
import { useStore } from "@/lib/states"
|
import { useStore } from "@/lib/states"
|
||||||
import { Switch } from "../ui/switch"
|
import { Switch } from "../ui/switch"
|
||||||
import { NumberInput } from "../ui/input"
|
import { NumberInput } from "../ui/input"
|
||||||
@ -18,7 +18,8 @@ import { Slider } from "../ui/slider"
|
|||||||
import { useImage } from "@/hooks/useImage"
|
import { useImage } from "@/hooks/useImage"
|
||||||
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
|
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
|
||||||
import { RowContainer, LabelTitle } from "./LabelTitle"
|
import { RowContainer, LabelTitle } from "./LabelTitle"
|
||||||
import { Upload } from "lucide-react"
|
import { Minus, Plus, Upload } from "lucide-react"
|
||||||
|
import { useClickAway } from "react-use"
|
||||||
|
|
||||||
const ExtenderButton = ({
|
const ExtenderButton = ({
|
||||||
text,
|
text,
|
||||||
@ -31,7 +32,7 @@ const ExtenderButton = ({
|
|||||||
return (
|
return (
|
||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
className="p-1 h-7"
|
className="p-1 h-8"
|
||||||
disabled={!showExtender}
|
disabled={!showExtender}
|
||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
>
|
>
|
||||||
@ -51,6 +52,8 @@ const DiffusionOptions = () => {
|
|||||||
updateAppState,
|
updateAppState,
|
||||||
updateExtenderByBuiltIn,
|
updateExtenderByBuiltIn,
|
||||||
updateExtenderDirection,
|
updateExtenderDirection,
|
||||||
|
adjustMask,
|
||||||
|
clearMask,
|
||||||
] = useStore((state) => [
|
] = useStore((state) => [
|
||||||
state.serverConfig.samplers,
|
state.serverConfig.samplers,
|
||||||
state.settings,
|
state.settings,
|
||||||
@ -61,8 +64,17 @@ const DiffusionOptions = () => {
|
|||||||
state.updateAppState,
|
state.updateAppState,
|
||||||
state.updateExtenderByBuiltIn,
|
state.updateExtenderByBuiltIn,
|
||||||
state.updateExtenderDirection,
|
state.updateExtenderDirection,
|
||||||
|
state.adjustMask,
|
||||||
|
state.clearMask,
|
||||||
])
|
])
|
||||||
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
||||||
|
const negativePromptRef = useRef(null)
|
||||||
|
useClickAway<MouseEvent>(negativePromptRef, () => {
|
||||||
|
if (negativePromptRef?.current) {
|
||||||
|
const input = negativePromptRef.current as HTMLInputElement
|
||||||
|
input.blur()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
const onKeyUp = (e: React.KeyboardEvent) => {
|
const onKeyUp = (e: React.KeyboardEvent) => {
|
||||||
// negativePrompt 回车触发 inpainting
|
// negativePrompt 回车触发 inpainting
|
||||||
@ -320,6 +332,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<div className="pl-2 pr-4">
|
<div className="pl-2 pr-4">
|
||||||
<Textarea
|
<Textarea
|
||||||
|
ref={negativePromptRef}
|
||||||
rows={4}
|
rows={4}
|
||||||
onKeyUp={onKeyUp}
|
onKeyUp={onKeyUp}
|
||||||
className="max-h-[8rem] overflow-y-auto mb-2"
|
className="max-h-[8rem] overflow-y-auto mb-2"
|
||||||
@ -550,7 +563,7 @@ const DiffusionOptions = () => {
|
|||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Task"
|
text="Task"
|
||||||
toolTip="When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
|
toolTip="PowerPaint task. When using extender, image-outpainting task will be auto used. For object-removal and image-outpainting, it is recommended to set the guidance_scale at 10 or above."
|
||||||
/>
|
/>
|
||||||
<Select
|
<Select
|
||||||
defaultValue={settings.powerpaintTask}
|
defaultValue={settings.powerpaintTask}
|
||||||
@ -760,10 +773,85 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderMaskAdjuster = () => {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className="flex flex-col gap-1">
|
||||||
|
<LabelTitle
|
||||||
|
htmlFor="adjustMaskKernelSize"
|
||||||
|
text="Adjust Mask"
|
||||||
|
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
|
||||||
|
/>
|
||||||
|
<RowContainer>
|
||||||
|
<Slider
|
||||||
|
className="w-[180px]"
|
||||||
|
defaultValue={[12]}
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
value={[Math.floor(settings.adjustMaskKernelSize)]}
|
||||||
|
onValueChange={(vals) =>
|
||||||
|
updateSettings({ adjustMaskKernelSize: vals[0] })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<NumberInput
|
||||||
|
id="adjustMaskKernelSize"
|
||||||
|
className="w-[60px] rounded-full"
|
||||||
|
numberValue={settings.adjustMaskKernelSize}
|
||||||
|
allowFloat={false}
|
||||||
|
onNumberValueChange={(val) => {
|
||||||
|
updateSettings({ adjustMaskKernelSize: val })
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer>
|
||||||
|
|
||||||
|
<RowContainer>
|
||||||
|
<div className="flex gap-1 justify-start">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="p-1 h-8"
|
||||||
|
onClick={() => adjustMask("expand")}
|
||||||
|
disabled={isProcessing}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-1 select-none">
|
||||||
|
<Plus size={16} />
|
||||||
|
expand
|
||||||
|
</div>
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="p-1 h-8"
|
||||||
|
onClick={() => adjustMask("shrink")}
|
||||||
|
disabled={isProcessing}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-1 select-none">
|
||||||
|
<Minus size={16} />
|
||||||
|
Shrink
|
||||||
|
</div>
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="p-1 h-8 justify-self-end"
|
||||||
|
onClick={clearMask}
|
||||||
|
disabled={isProcessing}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-1 select-none">Clear</div>
|
||||||
|
</Button>
|
||||||
|
</RowContainer>
|
||||||
|
</div>
|
||||||
|
<Separator />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-4 mt-4">
|
<div className="flex flex-col gap-4 mt-4">
|
||||||
{renderCropper()}
|
{renderCropper()}
|
||||||
{renderExtender()}
|
{renderExtender()}
|
||||||
|
{renderMaskAdjuster()}
|
||||||
{renderPowerPaintTaskType()}
|
{renderPowerPaintTaskType()}
|
||||||
{renderSteps()}
|
{renderSteps()}
|
||||||
{renderGuidanceScale()}
|
{renderGuidanceScale()}
|
||||||
|
202
web_app/src/components/ui/context-menu.tsx
Normal file
202
web_app/src/components/ui/context-menu.tsx
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import * as React from "react"
|
||||||
|
import * as ContextMenuPrimitive from "@radix-ui/react-context-menu"
|
||||||
|
import {
|
||||||
|
CheckIcon,
|
||||||
|
ChevronRightIcon,
|
||||||
|
DotFilledIcon,
|
||||||
|
} from "@radix-ui/react-icons"
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils"
|
||||||
|
|
||||||
|
const ContextMenu = ContextMenuPrimitive.Root
|
||||||
|
|
||||||
|
const ContextMenuTrigger = ContextMenuPrimitive.Trigger
|
||||||
|
|
||||||
|
const ContextMenuGroup = ContextMenuPrimitive.Group
|
||||||
|
|
||||||
|
const ContextMenuPortal = ContextMenuPrimitive.Portal
|
||||||
|
|
||||||
|
const ContextMenuSub = ContextMenuPrimitive.Sub
|
||||||
|
|
||||||
|
const ContextMenuRadioGroup = ContextMenuPrimitive.RadioGroup
|
||||||
|
|
||||||
|
const ContextMenuSubTrigger = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.SubTrigger>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubTrigger> & {
|
||||||
|
inset?: boolean
|
||||||
|
}
|
||||||
|
>(({ className, inset, children, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.SubTrigger
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground",
|
||||||
|
inset && "pl-8",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
<ChevronRightIcon className="ml-auto h-4 w-4" />
|
||||||
|
</ContextMenuPrimitive.SubTrigger>
|
||||||
|
))
|
||||||
|
ContextMenuSubTrigger.displayName = ContextMenuPrimitive.SubTrigger.displayName
|
||||||
|
|
||||||
|
const ContextMenuSubContent = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.SubContent>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.SubContent>
|
||||||
|
>(({ className, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.SubContent
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-lg data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
ContextMenuSubContent.displayName = ContextMenuPrimitive.SubContent.displayName
|
||||||
|
|
||||||
|
const ContextMenuContent = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.Content>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Content>
|
||||||
|
>(({ className, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.Portal>
|
||||||
|
<ContextMenuPrimitive.Content
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"z-50 min-w-[8rem] overflow-hidden rounded-md border bg-popover p-1 text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
</ContextMenuPrimitive.Portal>
|
||||||
|
))
|
||||||
|
ContextMenuContent.displayName = ContextMenuPrimitive.Content.displayName
|
||||||
|
|
||||||
|
const ContextMenuItem = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.Item>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Item> & {
|
||||||
|
inset?: boolean
|
||||||
|
}
|
||||||
|
>(({ className, inset, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.Item
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"relative flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||||
|
inset && "pl-8",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
ContextMenuItem.displayName = ContextMenuPrimitive.Item.displayName
|
||||||
|
|
||||||
|
const ContextMenuCheckboxItem = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.CheckboxItem>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.CheckboxItem>
|
||||||
|
>(({ className, children, checked, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.CheckboxItem
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
checked={checked}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||||
|
<ContextMenuPrimitive.ItemIndicator>
|
||||||
|
<CheckIcon className="h-4 w-4" />
|
||||||
|
</ContextMenuPrimitive.ItemIndicator>
|
||||||
|
</span>
|
||||||
|
{children}
|
||||||
|
</ContextMenuPrimitive.CheckboxItem>
|
||||||
|
))
|
||||||
|
ContextMenuCheckboxItem.displayName =
|
||||||
|
ContextMenuPrimitive.CheckboxItem.displayName
|
||||||
|
|
||||||
|
const ContextMenuRadioItem = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.RadioItem>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.RadioItem>
|
||||||
|
>(({ className, children, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.RadioItem
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"relative flex cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
|
||||||
|
<ContextMenuPrimitive.ItemIndicator>
|
||||||
|
<DotFilledIcon className="h-4 w-4 fill-current" />
|
||||||
|
</ContextMenuPrimitive.ItemIndicator>
|
||||||
|
</span>
|
||||||
|
{children}
|
||||||
|
</ContextMenuPrimitive.RadioItem>
|
||||||
|
))
|
||||||
|
ContextMenuRadioItem.displayName = ContextMenuPrimitive.RadioItem.displayName
|
||||||
|
|
||||||
|
const ContextMenuLabel = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.Label>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Label> & {
|
||||||
|
inset?: boolean
|
||||||
|
}
|
||||||
|
>(({ className, inset, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.Label
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"px-2 py-1.5 text-sm font-semibold text-foreground",
|
||||||
|
inset && "pl-8",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
ContextMenuLabel.displayName = ContextMenuPrimitive.Label.displayName
|
||||||
|
|
||||||
|
const ContextMenuSeparator = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ContextMenuPrimitive.Separator>,
|
||||||
|
React.ComponentPropsWithoutRef<typeof ContextMenuPrimitive.Separator>
|
||||||
|
>(({ className, ...props }, ref) => (
|
||||||
|
<ContextMenuPrimitive.Separator
|
||||||
|
ref={ref}
|
||||||
|
className={cn("-mx-1 my-1 h-px bg-border", className)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
ContextMenuSeparator.displayName = ContextMenuPrimitive.Separator.displayName
|
||||||
|
|
||||||
|
const ContextMenuShortcut = ({
|
||||||
|
className,
|
||||||
|
...props
|
||||||
|
}: React.HTMLAttributes<HTMLSpanElement>) => {
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
className={cn(
|
||||||
|
"ml-auto text-xs tracking-widest text-muted-foreground",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ContextMenuShortcut.displayName = "ContextMenuShortcut"
|
||||||
|
|
||||||
|
export {
|
||||||
|
ContextMenu,
|
||||||
|
ContextMenuTrigger,
|
||||||
|
ContextMenuContent,
|
||||||
|
ContextMenuItem,
|
||||||
|
ContextMenuCheckboxItem,
|
||||||
|
ContextMenuRadioItem,
|
||||||
|
ContextMenuLabel,
|
||||||
|
ContextMenuSeparator,
|
||||||
|
ContextMenuShortcut,
|
||||||
|
ContextMenuGroup,
|
||||||
|
ContextMenuPortal,
|
||||||
|
ContextMenuSub,
|
||||||
|
ContextMenuSubContent,
|
||||||
|
ContextMenuSubTrigger,
|
||||||
|
ContextMenuRadioGroup,
|
||||||
|
}
|
@ -17,6 +17,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
|||||||
const handleOnBlur = () => {
|
const handleOnBlur = () => {
|
||||||
updateAppState({ disableShortCuts: false })
|
updateAppState({ disableShortCuts: false })
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<input
|
<input
|
||||||
type={type}
|
type={type}
|
||||||
|
@ -201,3 +201,28 @@ export async function getSamplers(): Promise<string[]> {
|
|||||||
const res = await api.post("/samplers")
|
const res = await api.post("/samplers")
|
||||||
return res.data
|
return res.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function postAdjustMask(
|
||||||
|
mask: File | Blob,
|
||||||
|
operate: "expand" | "shrink",
|
||||||
|
kernel_size: number
|
||||||
|
) {
|
||||||
|
const maskBase64 = await convertToBase64(mask)
|
||||||
|
const res = await fetch(`${API_ENDPOINT}/adjust_mask`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
mask: maskBase64,
|
||||||
|
operate: operate,
|
||||||
|
kernel_size: kernel_size,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if (res.ok) {
|
||||||
|
const blob = await res.blob()
|
||||||
|
return blob
|
||||||
|
}
|
||||||
|
const errMsg = await res.json()
|
||||||
|
throw new Error(errMsg)
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ import { immer } from "zustand/middleware/immer"
|
|||||||
import { castDraft } from "immer"
|
import { castDraft } from "immer"
|
||||||
import { createWithEqualityFn } from "zustand/traditional"
|
import { createWithEqualityFn } from "zustand/traditional"
|
||||||
import {
|
import {
|
||||||
|
AdjustMaskOperate,
|
||||||
CV2Flag,
|
CV2Flag,
|
||||||
ExtenderDirection,
|
ExtenderDirection,
|
||||||
FreeuConfig,
|
FreeuConfig,
|
||||||
@ -27,13 +28,14 @@ import {
|
|||||||
PAINT_BY_EXAMPLE,
|
PAINT_BY_EXAMPLE,
|
||||||
} from "./const"
|
} from "./const"
|
||||||
import {
|
import {
|
||||||
|
blobToImage,
|
||||||
canvasToImage,
|
canvasToImage,
|
||||||
dataURItoBlob,
|
dataURItoBlob,
|
||||||
generateMask,
|
generateMask,
|
||||||
loadImage,
|
loadImage,
|
||||||
srcToFile,
|
srcToFile,
|
||||||
} from "./utils"
|
} from "./utils"
|
||||||
import inpaint, { getGenInfo, runPlugin } from "./api"
|
import inpaint, { getGenInfo, postAdjustMask, runPlugin } from "./api"
|
||||||
import { toast } from "@/components/ui/use-toast"
|
import { toast } from "@/components/ui/use-toast"
|
||||||
|
|
||||||
type FileManagerState = {
|
type FileManagerState = {
|
||||||
@ -102,13 +104,14 @@ export type Settings = {
|
|||||||
|
|
||||||
// PowerPaint
|
// PowerPaint
|
||||||
powerpaintTask: PowerPaintTask
|
powerpaintTask: PowerPaintTask
|
||||||
|
|
||||||
|
// AdjustMask
|
||||||
|
adjustMaskKernelSize: number
|
||||||
}
|
}
|
||||||
|
|
||||||
type InteractiveSegState = {
|
type InteractiveSegState = {
|
||||||
isInteractiveSeg: boolean
|
isInteractiveSeg: boolean
|
||||||
interactiveSegMask: HTMLImageElement | null
|
|
||||||
tmpInteractiveSegMask: HTMLImageElement | null
|
tmpInteractiveSegMask: HTMLImageElement | null
|
||||||
prevInteractiveSegMask: HTMLImageElement | null
|
|
||||||
clicks: number[][]
|
clicks: number[][]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,8 +122,12 @@ type EditorState = {
|
|||||||
lineGroups: LineGroup[]
|
lineGroups: LineGroup[]
|
||||||
lastLineGroup: LineGroup
|
lastLineGroup: LineGroup
|
||||||
curLineGroup: LineGroup
|
curLineGroup: LineGroup
|
||||||
// 只用来显示
|
|
||||||
|
// mask from interactive-seg or other segmentation models
|
||||||
extraMasks: HTMLImageElement[]
|
extraMasks: HTMLImageElement[]
|
||||||
|
prevExtraMasks: HTMLImageElement[]
|
||||||
|
|
||||||
|
temporaryMasks: HTMLImageElement[]
|
||||||
// redo 相关
|
// redo 相关
|
||||||
redoRenders: HTMLImageElement[]
|
redoRenders: HTMLImageElement[]
|
||||||
redoCurLines: Line[]
|
redoCurLines: Line[]
|
||||||
@ -135,6 +142,7 @@ type AppState = {
|
|||||||
imageWidth: number
|
imageWidth: number
|
||||||
isInpainting: boolean
|
isInpainting: boolean
|
||||||
isPluginRunning: boolean
|
isPluginRunning: boolean
|
||||||
|
isAdjustingMask: boolean
|
||||||
windowSize: Size
|
windowSize: Size
|
||||||
editorState: EditorState
|
editorState: EditorState
|
||||||
disableShortCuts: boolean
|
disableShortCuts: boolean
|
||||||
@ -209,6 +217,9 @@ type AppAction = {
|
|||||||
redo: () => void
|
redo: () => void
|
||||||
undoDisabled: () => boolean
|
undoDisabled: () => boolean
|
||||||
redoDisabled: () => boolean
|
redoDisabled: () => boolean
|
||||||
|
|
||||||
|
adjustMask: (operate: AdjustMaskOperate) => Promise<void>
|
||||||
|
clearMask: () => void
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultValues: AppState = {
|
const defaultValues: AppState = {
|
||||||
@ -219,6 +230,7 @@ const defaultValues: AppState = {
|
|||||||
imageWidth: 0,
|
imageWidth: 0,
|
||||||
isInpainting: false,
|
isInpainting: false,
|
||||||
isPluginRunning: false,
|
isPluginRunning: false,
|
||||||
|
isAdjustingMask: false,
|
||||||
disableShortCuts: false,
|
disableShortCuts: false,
|
||||||
|
|
||||||
windowSize: {
|
windowSize: {
|
||||||
@ -230,6 +242,8 @@ const defaultValues: AppState = {
|
|||||||
brushSizeScale: 1,
|
brushSizeScale: 1,
|
||||||
renders: [],
|
renders: [],
|
||||||
extraMasks: [],
|
extraMasks: [],
|
||||||
|
prevExtraMasks: [],
|
||||||
|
temporaryMasks: [],
|
||||||
lineGroups: [],
|
lineGroups: [],
|
||||||
lastLineGroup: [],
|
lastLineGroup: [],
|
||||||
curLineGroup: [],
|
curLineGroup: [],
|
||||||
@ -240,9 +254,7 @@ const defaultValues: AppState = {
|
|||||||
|
|
||||||
interactiveSegState: {
|
interactiveSegState: {
|
||||||
isInteractiveSeg: false,
|
isInteractiveSeg: false,
|
||||||
interactiveSegMask: null,
|
|
||||||
tmpInteractiveSegMask: null,
|
tmpInteractiveSegMask: null,
|
||||||
prevInteractiveSegMask: null,
|
|
||||||
clicks: [],
|
clicks: [],
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -323,6 +335,7 @@ const defaultValues: AppState = {
|
|||||||
enableFreeu: false,
|
enableFreeu: false,
|
||||||
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
||||||
powerpaintTask: PowerPaintTask.text_guided,
|
powerpaintTask: PowerPaintTask.text_guided,
|
||||||
|
adjustMaskKernelSize: 12,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -335,10 +348,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
if (get().settings.showExtender) {
|
if (get().settings.showExtender) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const { lastLineGroup, curLineGroup } = get().editorState
|
const { lastLineGroup, curLineGroup, prevExtraMasks, extraMasks } =
|
||||||
const { prevInteractiveSegMask, interactiveSegMask } =
|
get().editorState
|
||||||
get().interactiveSegState
|
if (curLineGroup.length !== 0 || extraMasks.length !== 0) {
|
||||||
if (curLineGroup.length !== 0 || interactiveSegMask !== null) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const { imageWidth, imageHeight } = get()
|
const { imageWidth, imageHeight } = get()
|
||||||
@ -347,13 +359,13 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
imageWidth,
|
imageWidth,
|
||||||
imageHeight,
|
imageHeight,
|
||||||
[lastLineGroup],
|
[lastLineGroup],
|
||||||
prevInteractiveSegMask ? [prevInteractiveSegMask] : [],
|
prevExtraMasks,
|
||||||
BRUSH_COLOR
|
BRUSH_COLOR
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
const maskImage = await canvasToImage(maskCanvas)
|
const maskImage = await canvasToImage(maskCanvas)
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.editorState.extraMasks.push(castDraft(maskImage))
|
state.editorState.temporaryMasks.push(castDraft(maskImage))
|
||||||
})
|
})
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e)
|
console.error(e)
|
||||||
@ -362,7 +374,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
},
|
},
|
||||||
hidePrevMask: () => {
|
hidePrevMask: () => {
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.editorState.extraMasks = []
|
state.editorState.temporaryMasks = []
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -408,33 +420,36 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const { lastLineGroup, curLineGroup, lineGroups, renders } =
|
const {
|
||||||
get().editorState
|
lastLineGroup,
|
||||||
|
curLineGroup,
|
||||||
const { interactiveSegMask, prevInteractiveSegMask } =
|
lineGroups,
|
||||||
get().interactiveSegState
|
renders,
|
||||||
|
prevExtraMasks,
|
||||||
|
extraMasks,
|
||||||
|
} = get().editorState
|
||||||
|
|
||||||
const useLastLineGroup =
|
const useLastLineGroup =
|
||||||
curLineGroup.length === 0 &&
|
curLineGroup.length === 0 &&
|
||||||
interactiveSegMask === null &&
|
extraMasks.length === 0 &&
|
||||||
!settings.showExtender
|
!settings.showExtender
|
||||||
|
|
||||||
// useLastLineGroup 的影响
|
// useLastLineGroup 的影响
|
||||||
// 1. 使用上一次的 mask
|
// 1. 使用上一次的 mask
|
||||||
// 2. 结果替换当前 render
|
// 2. 结果替换当前 render
|
||||||
let maskImage = null
|
let maskImages: HTMLImageElement[] = []
|
||||||
let maskLineGroup: LineGroup = []
|
let maskLineGroup: LineGroup = []
|
||||||
if (useLastLineGroup === true) {
|
if (useLastLineGroup === true) {
|
||||||
maskLineGroup = lastLineGroup
|
maskLineGroup = lastLineGroup
|
||||||
maskImage = prevInteractiveSegMask
|
maskImages = prevExtraMasks
|
||||||
} else {
|
} else {
|
||||||
maskLineGroup = curLineGroup
|
maskLineGroup = curLineGroup
|
||||||
maskImage = interactiveSegMask
|
maskImages = extraMasks
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
maskLineGroup.length === 0 &&
|
maskLineGroup.length === 0 &&
|
||||||
maskImage === null &&
|
maskImages === null &&
|
||||||
!settings.showExtender
|
!settings.showExtender
|
||||||
) {
|
) {
|
||||||
toast({
|
toast({
|
||||||
@ -474,7 +489,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
imageWidth,
|
imageWidth,
|
||||||
imageHeight,
|
imageHeight,
|
||||||
[maskLineGroup],
|
[maskLineGroup],
|
||||||
maskImage ? [maskImage] : []
|
maskImages
|
||||||
)
|
)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -500,6 +515,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
lineGroups: newLineGroups,
|
lineGroups: newLineGroups,
|
||||||
lastLineGroup: maskLineGroup,
|
lastLineGroup: maskLineGroup,
|
||||||
curLineGroup: [],
|
curLineGroup: [],
|
||||||
|
extraMasks: [],
|
||||||
|
prevExtraMasks: maskImages,
|
||||||
})
|
})
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
toast({
|
toast({
|
||||||
@ -512,15 +529,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
set((state) => {
|
set((state) => {
|
||||||
state.isInpainting = false
|
state.isInpainting = false
|
||||||
})
|
})
|
||||||
|
|
||||||
const newInteractiveSegState = {
|
|
||||||
...defaultValues.interactiveSegState,
|
|
||||||
prevInteractiveSegMask: maskImage,
|
|
||||||
}
|
|
||||||
|
|
||||||
set((state) => {
|
|
||||||
state.interactiveSegState = castDraft(newInteractiveSegState)
|
|
||||||
})
|
|
||||||
},
|
},
|
||||||
|
|
||||||
runRenderablePlugin: async (
|
runRenderablePlugin: async (
|
||||||
@ -557,8 +565,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
} else {
|
} else {
|
||||||
const newMask = new Image()
|
const newMask = new Image()
|
||||||
await loadImage(newMask, blob)
|
await loadImage(newMask, blob)
|
||||||
get().updateInteractiveSegState({
|
set((state) => {
|
||||||
interactiveSegMask: newMask,
|
state.editorState.extraMasks.push(castDraft(newMask))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
const end = new Date()
|
const end = new Date()
|
||||||
@ -618,7 +626,9 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
},
|
},
|
||||||
|
|
||||||
getIsProcessing: (): boolean => {
|
getIsProcessing: (): boolean => {
|
||||||
return get().isInpainting || get().isPluginRunning
|
return (
|
||||||
|
get().isInpainting || get().isPluginRunning || get().isAdjustingMask
|
||||||
|
)
|
||||||
},
|
},
|
||||||
|
|
||||||
isSD: (): boolean => {
|
isSD: (): boolean => {
|
||||||
@ -809,14 +819,14 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
|
|
||||||
handleInteractiveSegAccept: () => {
|
handleInteractiveSegAccept: () => {
|
||||||
set((state) => {
|
set((state) => {
|
||||||
return {
|
if (state.interactiveSegState.tmpInteractiveSegMask) {
|
||||||
...state,
|
state.editorState.extraMasks.push(
|
||||||
interactiveSegState: {
|
castDraft(state.interactiveSegState.tmpInteractiveSegMask)
|
||||||
...defaultValues.interactiveSegState,
|
)
|
||||||
interactiveSegMask:
|
|
||||||
state.interactiveSegState.tmpInteractiveSegMask,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
state.interactiveSegState = castDraft({
|
||||||
|
...defaultValues.interactiveSegState,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -986,10 +996,54 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
set((state) => {
|
set((state) => {
|
||||||
state.settings.seed = newValue
|
state.settings.seed = newValue
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
adjustMask: async (operate: AdjustMaskOperate) => {
|
||||||
|
const { imageWidth, imageHeight } = get()
|
||||||
|
const { curLineGroup, extraMasks } = get().editorState
|
||||||
|
const { adjustMaskKernelSize } = get().settings
|
||||||
|
if (curLineGroup.length === 0 && extraMasks.length === 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
set((state) => {
|
||||||
|
state.isAdjustingMask = true
|
||||||
|
})
|
||||||
|
|
||||||
|
const maskCanvas = generateMask(
|
||||||
|
imageWidth,
|
||||||
|
imageHeight,
|
||||||
|
[curLineGroup],
|
||||||
|
extraMasks,
|
||||||
|
BRUSH_COLOR
|
||||||
|
)
|
||||||
|
const maskBlob = dataURItoBlob(maskCanvas.toDataURL())
|
||||||
|
const newMaskBlob = await postAdjustMask(
|
||||||
|
maskBlob,
|
||||||
|
operate,
|
||||||
|
adjustMaskKernelSize
|
||||||
|
)
|
||||||
|
const newMask = await blobToImage(newMaskBlob)
|
||||||
|
|
||||||
|
// TODO: currently ignore stroke undo/redo
|
||||||
|
set((state) => {
|
||||||
|
state.editorState.extraMasks = [castDraft(newMask)]
|
||||||
|
state.editorState.curLineGroup = []
|
||||||
|
})
|
||||||
|
|
||||||
|
set((state) => {
|
||||||
|
state.isAdjustingMask = false
|
||||||
|
})
|
||||||
|
},
|
||||||
|
clearMask: () => {
|
||||||
|
set((state) => {
|
||||||
|
state.editorState.extraMasks = []
|
||||||
|
state.editorState.curLineGroup = []
|
||||||
|
})
|
||||||
|
},
|
||||||
})),
|
})),
|
||||||
{
|
{
|
||||||
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
||||||
version: 0,
|
version: 1,
|
||||||
partialize: (state) =>
|
partialize: (state) =>
|
||||||
Object.fromEntries(
|
Object.fromEntries(
|
||||||
Object.entries(state).filter(([key]) =>
|
Object.entries(state).filter(([key]) =>
|
||||||
|
@ -125,3 +125,5 @@ export enum PowerPaintTask {
|
|||||||
object_remove = "object-remove",
|
object_remove = "object-remove",
|
||||||
outpainting = "outpainting",
|
outpainting = "outpainting",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type AdjustMaskOperate = "expand" | "shrink"
|
||||||
|
@ -53,6 +53,13 @@ export function loadImage(image: HTMLImageElement, src: string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function blobToImage(blob: Blob) {
|
||||||
|
const dataURL = URL.createObjectURL(blob)
|
||||||
|
const newImage = new Image()
|
||||||
|
await loadImage(newImage, dataURL)
|
||||||
|
return newImage
|
||||||
|
}
|
||||||
|
|
||||||
export function canvasToImage(
|
export function canvasToImage(
|
||||||
canvas: HTMLCanvasElement
|
canvas: HTMLCanvasElement
|
||||||
): Promise<HTMLImageElement> {
|
): Promise<HTMLImageElement> {
|
||||||
|
Loading…
Reference in New Issue
Block a user