From e889e527ab545707f4c3c3136f43d46739f22cd1 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 5 Jan 2024 14:57:30 +0800 Subject: [PATCH] add adjust mask feature --- lama_cleaner/api.py | 10 + lama_cleaner/helper.py | 7 +- lama_cleaner/schema.py | 13 +- lama_cleaner/tests/test_adjust_mask.py | 15 ++ lama_cleaner/tests/utils.py | 2 - web_app/package-lock.json | 29 +++ web_app/package.json | 1 + web_app/src/App.tsx | 9 +- web_app/src/components/DiffusionProgress.tsx | 4 +- web_app/src/components/Editor.tsx | 17 +- web_app/src/components/PromptInput.tsx | 1 - .../components/SidePanel/DiffusionOptions.tsx | 96 ++++++++- web_app/src/components/ui/context-menu.tsx | 202 ++++++++++++++++++ web_app/src/components/ui/input.tsx | 1 + web_app/src/lib/api.ts | 25 +++ web_app/src/lib/states.ts | 142 ++++++++---- web_app/src/lib/types.ts | 2 + web_app/src/lib/utils.ts | 7 + 18 files changed, 507 insertions(+), 76 deletions(-) create mode 100644 lama_cleaner/tests/test_adjust_mask.py create mode 100644 web_app/src/components/ui/context-menu.tsx diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index 434926c..53a33d2 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -29,6 +29,7 @@ from lama_cleaner.helper import ( numpy_to_bytes, concat_alpha_channel, gen_frontend_mask, + adjust_mask, ) from lama_cleaner.model.utils import torch_gc from lama_cleaner.model_info import ModelInfo @@ -44,6 +45,7 @@ from lama_cleaner.schema import ( RunPluginRequest, SDSampler, PluginInfo, + AdjustMaskRequest, ) CURRENT_DIR = Path(__file__).parent.absolute().resolve() @@ -150,6 +152,7 @@ class Api: self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"]) self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) + self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"]) self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") # fmt: on @@ -294,6 +297,13 @@ class Api: def api_samplers(self) -> List[str]: return [member.value for member in SDSampler.__members__.values()] + def api_adjust_mask(self, req: AdjustMaskRequest): + mask, _, _ = decode_base64_to_image(req.mask, gray=True) + cv2.imwrite("tmp_adjust_mask_input.png", mask) + mask = adjust_mask(mask, req.kernel_size, req.operate) + cv2.imwrite("tmp_adjust_mask.png", mask) + return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png") + def launch(self): self.app.include_router(self.router) uvicorn.run( diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 551c457..177df02 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate): mask[mask >= 127] = 255 mask[mask < 127] = 0 # fronted brush color "ffcc00bb" + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1) + ) if operate == "expand": mask = cv2.dilate( mask, - np.ones((kernel_size, kernel_size), np.uint8), + kernel, iterations=1, ) else: mask = cv2.erode( mask, - np.ones((kernel_size, kernel_size), np.uint8), + kernel, iterations=1, ) res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 626e897..65e11f4 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -110,8 +110,8 @@ class ApiConfig(BaseModel): class InpaintRequest(BaseModel): - image: Optional[str] = Field(..., description="base64 encoded image") - mask: Optional[str] = Field(..., description="base64 encoded mask") + image: Optional[str] = Field(None, description="base64 encoded image") + mask: Optional[str] = Field(None, description="base64 encoded mask") ldm_steps: int = Field(20, description="Steps for ldm model.") ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.") @@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel): class SwitchModelRequest(BaseModel): name: str + + +AdjustMaskOperate = Literal["expand", "shrink"] + + +class AdjustMaskRequest(BaseModel): + mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint") + operate: AdjustMaskOperate = Field(..., description="expand or shrink") + kernel_size: int = Field(5, description="Kernel size for expanding mask") diff --git a/lama_cleaner/tests/test_adjust_mask.py b/lama_cleaner/tests/test_adjust_mask.py new file mode 100644 index 0000000..33de61e --- /dev/null +++ b/lama_cleaner/tests/test_adjust_mask.py @@ -0,0 +1,15 @@ +import cv2 +from lama_cleaner.helper import adjust_mask +from lama_cleaner.tests.utils import current_dir, save_dir + +mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png" + + +def test_adjust_mask(): + mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) + res_mask = adjust_mask(mask, 0, "expand") + cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask) + res_mask = adjust_mask(mask, 40, "expand") + cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask) + res_mask = adjust_mask(mask, 20, "shrink") + cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask) diff --git a/lama_cleaner/tests/utils.py b/lama_cleaner/tests/utils.py index 6786080..342e179 100644 --- a/lama_cleaner/tests/utils.py +++ b/lama_cleaner/tests/utils.py @@ -31,8 +31,6 @@ def assert_equal( mask_p=current_dir / "mask.png", ): img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) - config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0] - config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0] print(f"Input image shape: {img.shape}") res = model(img, mask, config) ok = cv2.imwrite( diff --git a/web_app/package-lock.json b/web_app/package-lock.json index 408241e..a6e54bf 100644 --- a/web_app/package-lock.json +++ b/web_app/package-lock.json @@ -12,6 +12,7 @@ "@hookform/resolvers": "^3.3.2", "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-alert-dialog": "^1.0.5", + "@radix-ui/react-context-menu": "^2.1.5", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-icons": "^1.3.0", @@ -1537,6 +1538,34 @@ } } }, + "node_modules/@radix-ui/react-context-menu": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.1.5.tgz", + "integrity": "sha512-R5XaDj06Xul1KGb+WP8qiOh7tKJNz2durpLBXAGZjSVtctcRFCuEvy2gtMwRJGePwQQE5nV77gs4FwRi8T+r2g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-menu": "2.0.6", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-controllable-state": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-dialog": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz", diff --git a/web_app/package.json b/web_app/package.json index 81f7926..ba5cb30 100644 --- a/web_app/package.json +++ b/web_app/package.json @@ -14,6 +14,7 @@ "@hookform/resolvers": "^3.3.2", "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-alert-dialog": "^1.0.5", + "@radix-ui/react-context-menu": "^2.1.5", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-icons": "^1.3.0", diff --git a/web_app/src/App.tsx b/web_app/src/App.tsx index e00ec21..6650840 100644 --- a/web_app/src/App.tsx +++ b/web_app/src/App.tsx @@ -1,5 +1,4 @@ -import { useCallback, useEffect, useMemo, useRef } from "react" -import { nanoid } from "nanoid" +import { useCallback, useEffect, useRef } from "react" import useInputImage from "@/hooks/useInputImage" import { keepGUIAlive } from "@/lib/utils" @@ -52,10 +51,6 @@ function Home() { fetchServerConfig() }, []) - const workspaceId = useMemo(() => { - return nanoid() - }, [file]) - const dragCounter = useRef(0) const handleDrag = useCallback((event: any) => { @@ -155,7 +150,7 @@ function Home() {
- + {!file ? ( { diff --git a/web_app/src/components/DiffusionProgress.tsx b/web_app/src/components/DiffusionProgress.tsx index 3585146..8c56ad2 100644 --- a/web_app/src/components/DiffusionProgress.tsx +++ b/web_app/src/components/DiffusionProgress.tsx @@ -52,9 +52,9 @@ const DiffusionProgress = () => { return (
diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 038d2b2..a96d418 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -95,6 +95,7 @@ export default function Editor(props: EditorProps) { const brushSize = useStore((state) => state.getBrushSize()) const renders = useStore((state) => state.editorState.renders) const extraMasks = useStore((state) => state.editorState.extraMasks) + const temporaryMasks = useStore((state) => state.editorState.temporaryMasks) const lineGroups = useStore((state) => state.editorState.lineGroups) const curLineGroup = useStore((state) => state.editorState.curLineGroup) @@ -166,6 +167,9 @@ export default function Editor(props: EditorProps) { context.canvas.width = imageWidth context.canvas.height = imageHeight context.clearRect(0, 0, context.canvas.width, context.canvas.height) + temporaryMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) + }) extraMasks.forEach((maskImage) => { context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) }) @@ -182,20 +186,9 @@ export default function Editor(props: EditorProps) { imageHeight ) } - if ( - !interactiveSegState.isInteractiveSeg && - interactiveSegState.interactiveSegMask - ) { - context.drawImage( - interactiveSegState.interactiveSegMask, - 0, - 0, - imageWidth, - imageHeight - ) - } drawLines(context, curLineGroup) }, [ + temporaryMasks, extraMasks, isOriginalLoaded, interactiveSegState, diff --git a/web_app/src/components/PromptInput.tsx b/web_app/src/components/PromptInput.tsx index 5b65851..da63524 100644 --- a/web_app/src/components/PromptInput.tsx +++ b/web_app/src/components/PromptInput.tsx @@ -21,7 +21,6 @@ const PromptInput = () => { state.hidePrevMask, ]) const ref = useRef(null) - useClickAway(ref, () => { if (ref?.current) { const input = ref.current as HTMLInputElement diff --git a/web_app/src/components/SidePanel/DiffusionOptions.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx index 3b8f608..55e418e 100644 --- a/web_app/src/components/SidePanel/DiffusionOptions.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -1,4 +1,4 @@ -import { FormEvent } from "react" +import { FormEvent, useRef } from "react" import { useStore } from "@/lib/states" import { Switch } from "../ui/switch" import { NumberInput } from "../ui/input" @@ -18,7 +18,8 @@ import { Slider } from "../ui/slider" import { useImage } from "@/hooks/useImage" import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const" import { RowContainer, LabelTitle } from "./LabelTitle" -import { Upload } from "lucide-react" +import { Minus, Plus, Upload } from "lucide-react" +import { useClickAway } from "react-use" const ExtenderButton = ({ text, @@ -31,7 +32,7 @@ const ExtenderButton = ({ return (