From 354a1280a41bcb58bfc42b860e84d9cf926295f3 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 11 Dec 2023 22:28:07 +0800 Subject: [PATCH] wip --- lama_cleaner/model_manager.py | 17 +- lama_cleaner/schema.py | 46 +- lama_cleaner/server.py | 32 +- web_app/src/components/Editor.tsx | 749 +++----------------- web_app/src/components/Header.tsx | 2 +- web_app/src/components/InteractiveSeg.tsx | 19 +- web_app/src/components/Plugins.tsx | 47 +- web_app/src/components/PromptInput.tsx | 44 +- web_app/src/components/ui/dropdown-menu.tsx | 3 +- web_app/src/lib/api.ts | 18 +- web_app/src/lib/states.ts | 272 ++++++- web_app/src/lib/types.ts | 4 + web_app/src/lib/utils.ts | 25 + 13 files changed, 531 insertions(+), 747 deletions(-) diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 90e0f12..6149ecf 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -36,16 +36,15 @@ class ModelManager: return ControlNet(device, **{**kwargs, "model_info": model_info}) else: if model_info.model_type in [ + ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD, - ModelType.DIFFUSERS_SDXL, ]: - raise NotImplementedError( - f"When using non inpaint Stable Diffusion model, you must enable controlnet" - ) - if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT: return SD(device, model_id_or_path=model_info.path, **kwargs) - if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT: + if model_info.model_type in [ + ModelType.DIFFUSERS_SDXL_INPAINT, + ModelType.DIFFUSERS_SDXL, + ]: return SDXL(device, model_id_or_path=model_info.path, **kwargs) raise NotImplementedError(f"Unsupported model: {name}") @@ -88,7 +87,7 @@ class ModelManager: if self.kwargs["sd_controlnet_method"] == control_method: return - if not self.available_models[self.name].support_controlnet(): + if not self.available_models[self.name].support_controlnet: return del self.model @@ -105,7 +104,7 @@ class ModelManager: if str(self.model.device) == "mps": return - if self.available_models[self.name].support_freeu(): + if self.available_models[self.name].support_freeu: if config.sd_freeu: freeu_config = config.sd_freeu_config self.model.model.enable_freeu( @@ -118,7 +117,7 @@ class ModelManager: self.model.model.disable_freeu() def enable_disable_lcm_lora(self, config: Config): - if self.available_models[self.name].support_lcm_lora(): + if self.available_models[self.name].support_lcm_lora: if config.sd_lcm_lora: if not self.model.model.pipe.get_list_adapters(): self.model.model.load_lora_weights(self.model.lcm_lora_id) diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index f7b253e..cd46e78 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -1,8 +1,14 @@ -from typing import Optional +from typing import Optional, List from enum import Enum from PIL.Image import Image -from pydantic import BaseModel +from pydantic import BaseModel, computed_field + +from lama_cleaner.const import ( + SDXL_CONTROLNET_CHOICES, + SD2_CONTROLNET_CHOICES, + SD_CONTROLNET_CHOICES, +) DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" @@ -31,6 +37,36 @@ class ModelInfo(BaseModel): model_type: ModelType is_single_file_diffusers: bool = False + @computed_field + @property + def need_prompt(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SDXL_INPAINT, + ] or self.name in [ + "timbrooks/instruct-pix2pix", + "kandinsky-community/kandinsky-2-2-decoder-inpaint", + ] + + @computed_field + @property + def controlnets(self) -> List[str]: + if self.model_type in [ + ModelType.DIFFUSERS_SDXL, + ModelType.DIFFUSERS_SDXL_INPAINT, + ]: + return SDXL_CONTROLNET_CHOICES + if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]: + if self.name in ["stabilityai/stable-diffusion-2-inpainting"]: + return SD2_CONTROLNET_CHOICES + else: + return SD_CONTROLNET_CHOICES + return [] + + @computed_field + @property def support_lcm_lora(self) -> bool: return self.model_type in [ ModelType.DIFFUSERS_SD, @@ -39,6 +75,8 @@ class ModelInfo(BaseModel): ModelType.DIFFUSERS_SDXL_INPAINT, ] + @computed_field + @property def support_controlnet(self) -> bool: return self.model_type in [ ModelType.DIFFUSERS_SD, @@ -47,6 +85,8 @@ class ModelInfo(BaseModel): ModelType.DIFFUSERS_SDXL_INPAINT, ] + @computed_field + @property def support_freeu(self) -> bool: return ( self.model_type @@ -56,7 +96,7 @@ class ModelInfo(BaseModel): ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT, ] - or "instruct-pix2pix" in self.name + or "timbrooks/instruct-pix2pix" in self.name ) diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index c302aeb..2922ee3 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -419,14 +419,8 @@ def run_plugin(): @app.route("/server_config", methods=["GET"]) def get_server_config(): - controlnet = { - "SD": SD_CONTROLNET_CHOICES, - "SD2": SD2_CONTROLNET_CHOICES, - "SDXL": SDXL_CONTROLNET_CHOICES, - } return { "plugins": list(plugins.keys()), - "availableControlNet": controlnet, "enableFileManager": enable_file_manager, "enableAutoSaving": enable_auto_saving, }, 200 @@ -434,20 +428,12 @@ def get_server_config(): @app.route("/models", methods=["GET"]) def get_models(): - return [ - { - **it.dict(), - "support_lcm_lora": it.support_lcm_lora(), - "support_controlnet": it.support_controlnet(), - "support_freeu": it.support_freeu(), - } - for it in model.scan_models() - ] + return [it.model_dump() for it in model.scan_models()] @app.route("/model") def current_model(): - return model.available_models[model.name].dict(), 200 + return model.available_models[model.name].model_dump(), 200 @app.route("/is_desktop") @@ -600,8 +586,20 @@ def main(args): else: input_image_path = args.input + # 为了兼容性 + model_name_map = { + "sd1.5": "runwayml/stable-diffusion-inpainting", + "anything4": "Sanster/anything-4.0-inpainting", + "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting", + "sd2": "stabilityai/stable-diffusion-2-inpainting", + "sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + "kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint", + "paint_by_example": "Fantasy-Studio/Paint-by-Example", + "instruct_pix2pix": "timbrooks/instruct-pix2pix", + } + model = ModelManager( - name=args.model, + name=model_name_map.get(args.model, args.model), sd_controlnet=args.sd_controlnet, sd_controlnet_method=args.sd_controlnet_method, device=device, diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 1a892a0..54ddfa7 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -6,14 +6,15 @@ import { TransformComponent, TransformWrapper, } from "react-zoom-pan-pinch" -import { useKeyPressEvent, useWindowSize } from "react-use" -import inpaint, { downloadToOutput, runPlugin } from "@/lib/api" +import { useKeyPressEvent } from "react-use" +import { downloadToOutput, runPlugin } from "@/lib/api" import { IconButton } from "@/components/ui/button" import { askWritePermission, copyCanvasImage, downloadImage, drawLines, + generateMask, isMidClick, isRightClick, loadImage, @@ -21,20 +22,13 @@ import { srcToFile, } from "@/lib/utils" import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" -import emitter, { - EVENT_PROMPT, - EVENT_CUSTOM_MASK, - EVENT_PAINT_BY_EXAMPLE, - RERUN_LAST_MASK, - DREAM_BUTTON_MOUSE_ENTER, - DREAM_BUTTON_MOUSE_LEAVE, -} from "@/lib/event" import { useImage } from "@/hooks/useImage" import { Slider } from "./ui/slider" -import { Line, LineGroup, PluginName } from "@/lib/types" +import { PluginName } from "@/lib/types" import { useHotkeys } from "react-hotkeys-hook" import { useStore } from "@/lib/states" import Cropper from "./Cropper" +import { InteractiveSegPoints } from "./InteractiveSeg" const TOOLBAR_HEIGHT = 200 const MIN_BRUSH_SIZE = 10 @@ -50,6 +44,8 @@ export default function Editor(props: EditorProps) { const { toast } = useToast() const [ + idForUpdateView, + windowSize, isInpainting, imageWidth, imageHeight, @@ -75,7 +71,13 @@ export default function Editor(props: EditorProps) { redo, undoDisabled, redoDisabled, + isProcessing, + updateAppState, + runMannually, + runInpainting, ] = useStore((state) => [ + state.idForUpdateView, + state.windowSize, state.isInpainting, state.imageWidth, state.imageHeight, @@ -101,65 +103,32 @@ export default function Editor(props: EditorProps) { state.redo, state.undoDisabled(), state.redoDisabled(), + state.getIsProcessing(), + state.updateAppState, + state.runMannually(), + state.runInpainting, ]) const baseBrushSize = useStore((state) => state.editorState.baseBrushSize) const brushSize = useStore((state) => state.getBrushSize()) const renders = useStore((state) => state.editorState.renders) + const extraMasks = useStore((state) => state.editorState.extraMasks) const lineGroups = useStore((state) => state.editorState.lineGroups) - const lastLineGroup = useStore((state) => state.editorState.lastLineGroup) const curLineGroup = useStore((state) => state.editorState.curLineGroup) const redoLineGroups = useStore((state) => state.editorState.redoLineGroups) - // 纯 local state + // Local State const [showOriginal, setShowOriginal] = useState(false) - - // - const isProcessing = isInpainting - const isDiffusionModels = false - const isPix2Pix = false - - const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false) - const [interactiveSegMask, setInteractiveSegMask] = useState< - HTMLImageElement | null | undefined - >(null) - - // only used while interactive segmentation is on - const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState< - HTMLImageElement | null | undefined - >(null) - const [prevInteractiveSegMask, setPrevInteractiveSegMask] = useState< - HTMLImageElement | null | undefined - >(null) - - // 仅用于在 dream button hover 时显示提示 - const [dreamButtonHoverSegMask, setDreamButtonHoverSegMask] = useState< - HTMLImageElement | null | undefined - >(null) - const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] = - useState([]) - const [original, isOriginalLoaded] = useImage(file) const [context, setContext] = useState() - const [maskCanvas] = useState(() => { - return document.createElement("canvas") - }) const [{ x, y }, setCoords] = useState({ x: -1, y: -1 }) const [showBrush, setShowBrush] = useState(false) const [showRefBrush, setShowRefBrush] = useState(false) const [isPanning, setIsPanning] = useState(false) - const [isChangingBrushSizeByMouse, setIsChangingBrushSizeByMouse] = - useState(false) - const [changeBrushSizeByMouseInit, setChangeBrushSizeByMouseInit] = useState({ - x: -1, - y: -1, - brushSize: 20, - }) const [scale, setScale] = useState(1) const [panned, setPanned] = useState(false) const [minScale, setMinScale] = useState(1.0) - const windowSize = useWindowSize() const windowCenterX = windowSize.width / 2 const windowCenterY = windowSize.height / 2 const viewportRef = useRef(null) @@ -170,402 +139,77 @@ export default function Editor(props: EditorProps) { const [sliderPos, setSliderPos] = useState(0) - const draw = useCallback( - (render: HTMLImageElement, lineGroup: LineGroup) => { - if (!context) { - return - } - console.log( - `[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}` - ) - - context.clearRect(0, 0, context.canvas.width, context.canvas.height) - context.drawImage(render, 0, 0, imageWidth, imageHeight) - if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) { - context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight) - } - if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) { - context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight) - } - if (dreamButtonHoverSegMask) { - context.drawImage( - dreamButtonHoverSegMask, - 0, - 0, - imageWidth, - imageHeight - ) - } - drawLines(context, lineGroup) - drawLines(context, dreamButtonHoverLineGroup) - }, - [ - context, - interactiveSegState, - tmpInteractiveSegMask, - dreamButtonHoverSegMask, - interactiveSegMask, - imageHeight, - imageWidth, - dreamButtonHoverLineGroup, - ] - ) - - const drawLinesOnMask = useCallback( - (_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => { - if (!context?.canvas.width || !context?.canvas.height) { - throw new Error("canvas has invalid size") - } - maskCanvas.width = context?.canvas.width - maskCanvas.height = context?.canvas.height - const ctx = maskCanvas.getContext("2d") - if (!ctx) { - throw new Error("could not retrieve mask canvas") - } - - if (maskImage !== undefined && maskImage !== null) { - // TODO: check whether draw yellow mask works on backend - ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight) - } - - _lineGroups.forEach((lineGroup) => { - drawLines(ctx, lineGroup, "white") - }) - - if ( - (maskImage === undefined || maskImage === null) && - _lineGroups.length === 1 && - _lineGroups[0].length === 0 && - isPix2Pix - ) { - // For InstructPix2Pix without mask - drawLines( - ctx, - [ - { - size: 9999999999, - pts: [ - { x: 0, y: 0 }, - { x: imageWidth, y: 0 }, - { x: imageWidth, y: imageHeight }, - { x: 0, y: imageHeight }, - ], - }, - ], - "white" - ) - } - }, - [context, maskCanvas, imageWidth, imageHeight] - ) - const hadDrawSomething = useCallback(() => { return curLineGroup.length !== 0 }, [curLineGroup]) - // const drawOnCurrentRender = useCallback( - // (lineGroup: LineGroup) => { - // console.log("[drawOnCurrentRender] draw on current render") - // if (renders.length === 0) { - // draw(original, lineGroup) - // } else { - // draw(renders[renders.length - 1], lineGroup) - // } - // }, - // [original, renders, draw] - // ) - useEffect(() => { - if (!context) { + if ( + !context || + !isOriginalLoaded || + imageWidth === 0 || + imageHeight === 0 + ) { return } const render = renders.length === 0 ? original : renders[renders.length - 1] console.log( - `[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}` + `[draw] renders.length: ${renders.length} render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}` ) + context.canvas.width = imageWidth + context.canvas.height = imageHeight + context.clearRect(0, 0, context.canvas.width, context.canvas.height) context.drawImage(render, 0, 0, imageWidth, imageHeight) - // if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) { - // context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight) - // } - // if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) { - // context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight) - // } + + extraMasks.forEach((maskImage) => { + context.drawImage(maskImage, 0, 0, imageWidth, imageHeight) + }) + + if ( + interactiveSegState.isInteractiveSeg && + interactiveSegState.tmpInteractiveSegMask + ) { + context.drawImage( + interactiveSegState.tmpInteractiveSegMask, + 0, + 0, + imageWidth, + imageHeight + ) + } + if ( + !interactiveSegState.isInteractiveSeg && + interactiveSegState.interactiveSegMask + ) { + context.drawImage( + interactiveSegState.interactiveSegMask, + 0, + 0, + imageWidth, + imageHeight + ) + } // if (dreamButtonHoverSegMask) { // context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight) // } drawLines(context, curLineGroup) // drawLines(context, dreamButtonHoverLineGroup) - }, [renders, file, original, context, curLineGroup, imageHeight, imageWidth]) - - const runInpainting = useCallback( - async ( - useLastLineGroup?: boolean, - customMask?: File, - maskImage?: HTMLImageElement | null, - paintByExampleImage?: File - ) => { - // customMask: mask uploaded by user - // maskImage: mask from interactive segmentation - if (file === undefined) { - return - } - const useCustomMask = customMask !== undefined && customMask !== null - const useMaskImage = maskImage !== undefined && maskImage !== null - // useLastLineGroup 的影响 - // 1. 使用上一次的 mask - // 2. 结果替换当前 render - console.log("runInpainting") - console.log({ - useCustomMask, - useMaskImage, - }) - - let maskLineGroup: LineGroup = [] - if (useLastLineGroup === true) { - if (lastLineGroup.length === 0) { - return - } - maskLineGroup = lastLineGroup - } else if (!useCustomMask) { - if (!hadDrawSomething() && !useMaskImage) { - return - } - - // setLastLineGroup(curLineGroup) - maskLineGroup = curLineGroup - } - - const newLineGroups = [...lineGroups, maskLineGroup] - - cleanCurLineGroup() - setIsDraging(false) - setIsInpainting(true) - drawLinesOnMask([maskLineGroup], maskImage) - - let targetFile = file - console.log( - `randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}` - ) - if (useLastLineGroup === true) { - // renders.length == 1 还是用原来的 - if (renders.length > 1) { - const lastRender = renders[renders.length - 2] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - } else if (renders.length > 0) { - console.info("gradually inpainting on last result") - - const lastRender = renders[renders.length - 1] - targetFile = await srcToFile( - lastRender.currentSrc, - file.name, - file.type - ) - } - - try { - console.log("before run inpaint") - const res = await inpaint( - targetFile, - settings, - cropperRect, - useCustomMask ? undefined : maskCanvas.toDataURL(), - useCustomMask ? customMask : undefined, - paintByExampleImage - ) - if (!res) { - throw new Error("Something went wrong on server side.") - } - const { blob, seed } = res - if (seed) { - setSeed(parseInt(seed, 10)) - } - const newRender = new Image() - await loadImage(newRender, blob) - - if (useLastLineGroup === true) { - const prevRenders = renders.slice(0, -1) - const newRenders = [...prevRenders, newRender] - // setRenders(newRenders) - updateEditorState({ renders: newRenders }) - } else { - const newRenders = [...renders, newRender] - updateEditorState({ renders: newRenders }) - } - - draw(newRender, []) - // Only append new LineGroup after inpainting success - // setLineGroups(newLineGroups) - updateEditorState({ lineGroups: newLineGroups }) - - // clear redo stack - resetRedoState() - } catch (e: any) { - toast({ - variant: "destructive", - title: "Uh oh! Something went wrong.", - description: e.message ? e.message : e.toString(), - }) - // drawOnCurrentRender([]) - } - setIsInpainting(false) - setPrevInteractiveSegMask(maskImage) - setTmpInteractiveSegMask(null) - setInteractiveSegMask(null) - }, - [ - renders, - lineGroups, - curLineGroup, - maskCanvas, - settings, - cropperRect, - // drawOnCurrentRender, - hadDrawSomething, - drawLinesOnMask, - ] - ) - - useEffect(() => { - emitter.on(EVENT_PROMPT, () => { - if (hadDrawSomething() || interactiveSegMask) { - runInpainting(false, undefined, interactiveSegMask) - } else if (lastLineGroup.length !== 0) { - // 使用上一次手绘的 mask 生成 - runInpainting(true, undefined, prevInteractiveSegMask) - } else if (prevInteractiveSegMask) { - // 使用上一次 IS 的 mask 生成 - runInpainting(false, undefined, prevInteractiveSegMask) - } else if (isPix2Pix) { - runInpainting(false, undefined, null) - } else { - toast({ - variant: "destructive", - description: "Please draw mask on picture.", - }) - } - emitter.emit(DREAM_BUTTON_MOUSE_LEAVE) - }) - - return () => { - emitter.off(EVENT_PROMPT) - } }, [ - hadDrawSomething, - runInpainting, - interactiveSegMask, - prevInteractiveSegMask, + renders, + extraMasks, + original, + isOriginalLoaded, + interactiveSegState, + context, + curLineGroup, + imageHeight, + imageWidth, ]) - useEffect(() => { - emitter.on(DREAM_BUTTON_MOUSE_ENTER, () => { - // 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask - if (!hadDrawSomething() && !interactiveSegMask) { - if (prevInteractiveSegMask) { - setDreamButtonHoverSegMask(prevInteractiveSegMask) - } - let lineGroup2Show: LineGroup = [] - if (redoLineGroups.length !== 0) { - lineGroup2Show = redoLineGroups[redoLineGroups.length - 1] - } else if (lineGroups.length !== 0) { - lineGroup2Show = lineGroups[lineGroups.length - 1] - } - console.log( - `[DREAM_BUTTON_MOUSE_ENTER], prevInteractiveSegMask: ${prevInteractiveSegMask} lineGroup2Show: ${lineGroup2Show.length}` - ) - if (lineGroup2Show) { - setDreamButtonHoverLineGroup(lineGroup2Show) - } - } - }) - return () => { - emitter.off(DREAM_BUTTON_MOUSE_ENTER) - } - }, [ - hadDrawSomething, - interactiveSegMask, - prevInteractiveSegMask, - // drawOnCurrentRender, - lineGroups, - redoLineGroups, - ]) - - useEffect(() => { - emitter.on(DREAM_BUTTON_MOUSE_LEAVE, () => { - // 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask - if (!hadDrawSomething() && !interactiveSegMask) { - setDreamButtonHoverSegMask(null) - setDreamButtonHoverLineGroup([]) - // drawOnCurrentRender([]) - } - }) - return () => { - emitter.off(DREAM_BUTTON_MOUSE_LEAVE) - } - }, [hadDrawSomething, interactiveSegMask]) - - useEffect(() => { - emitter.on(EVENT_CUSTOM_MASK, (data: any) => { - // TODO: not work with paint by example - runInpainting(false, data.mask) - }) - - return () => { - emitter.off(EVENT_CUSTOM_MASK) - } - }, [runInpainting]) - - useEffect(() => { - emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => { - if (hadDrawSomething() || interactiveSegMask) { - runInpainting(false, undefined, interactiveSegMask, data.image) - } else if (lastLineGroup.length !== 0) { - // 使用上一次手绘的 mask 生成 - runInpainting(true, undefined, prevInteractiveSegMask, data.image) - } else if (prevInteractiveSegMask) { - // 使用上一次 IS 的 mask 生成 - runInpainting(false, undefined, prevInteractiveSegMask, data.image) - } else { - toast({ - variant: "destructive", - description: "Please draw mask on picture.", - }) - } - }) - - return () => { - emitter.off(EVENT_PAINT_BY_EXAMPLE) - } - }, [runInpainting]) - - useEffect(() => { - emitter.on(RERUN_LAST_MASK, () => { - if (lastLineGroup.length !== 0) { - // 使用上一次手绘的 mask 生成 - runInpainting(true, undefined, prevInteractiveSegMask) - } else if (prevInteractiveSegMask) { - // 使用上一次 IS 的 mask 生成 - runInpainting(false, undefined, prevInteractiveSegMask) - } else { - toast({ - variant: "destructive", - description: "No mask to reuse", - }) - } - }) - return () => { - emitter.off(RERUN_LAST_MASK) - } - }, [runInpainting]) - const getCurrentRender = useCallback(async () => { let targetFile = file if (renders.length > 0) { @@ -575,126 +219,6 @@ export default function Editor(props: EditorProps) { return targetFile }, [file, renders]) - useEffect(() => { - emitter.on(PluginName.InteractiveSeg, () => { - // setIsInteractiveSeg(true) - if (interactiveSegMask !== null) { - setShowInteractiveSegModal(true) - } - }) - return () => { - emitter.off(PluginName.InteractiveSeg) - } - }) - - const runRenderablePlugin = useCallback( - async (name: string, data?: any) => { - if (isProcessing) { - return - } - try { - // TODO 要不要加 undoCurrentLine?? - const start = new Date() - setIsPluginRunning(true) - const targetFile = await getCurrentRender() - const res = await runPlugin(name, targetFile, data?.upscale) - if (!res) { - throw new Error("Something went wrong on server side.") - } - const { blob } = res - const newRender = new Image() - await loadImage(newRender, blob) - setImageSize(newRender.height, newRender.width) - const newRenders = [...renders, newRender] - // setRenders(newRenders) - updateEditorState({ renders: newRenders }) - const newLineGroups = [...lineGroups, []] - updateEditorState({ lineGroups: newLineGroups }) - - const end = new Date() - const time = end.getTime() - start.getTime() - - toast({ - description: `Run ${name} successfully in ${time / 1000}s`, - }) - - const rW = windowSize.width / newRender.width - const rH = (windowSize.height - TOOLBAR_HEIGHT) / newRender.height - let s = 1.0 - if (rW < 1 || rH < 1) { - s = Math.min(rW, rH) - } - setMinScale(s) - setScale(s) - viewportRef.current?.centerView(s, 1) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - } finally { - setIsPluginRunning(false) - } - }, - [ - renders, - // setRenders, - getCurrentRender, - setIsPluginRunning, - isProcessing, - setImageSize, - lineGroups, - viewportRef, - windowSize, - // setLineGroups, - ] - ) - - useEffect(() => { - emitter.on(PluginName.RemoveBG, () => { - runRenderablePlugin(PluginName.RemoveBG) - }) - return () => { - emitter.off(PluginName.RemoveBG) - } - }, [runRenderablePlugin]) - - useEffect(() => { - emitter.on(PluginName.AnimeSeg, () => { - runRenderablePlugin(PluginName.AnimeSeg) - }) - return () => { - emitter.off(PluginName.AnimeSeg) - } - }, [runRenderablePlugin]) - - useEffect(() => { - emitter.on(PluginName.GFPGAN, () => { - runRenderablePlugin(PluginName.GFPGAN) - }) - return () => { - emitter.off(PluginName.GFPGAN) - } - }, [runRenderablePlugin]) - - useEffect(() => { - emitter.on(PluginName.RestoreFormer, () => { - runRenderablePlugin(PluginName.RestoreFormer) - }) - return () => { - emitter.off(PluginName.RestoreFormer) - } - }, [runRenderablePlugin]) - - useEffect(() => { - emitter.on(PluginName.RealESRGAN, (data: any) => { - runRenderablePlugin(PluginName.RealESRGAN, data) - }) - return () => { - emitter.off(PluginName.RealESRGAN) - } - }, [runRenderablePlugin]) - const hadRunInpainting = () => { return renders.length !== 0 } @@ -736,14 +260,17 @@ export default function Editor(props: EditorProps) { setScale(s) console.log( - `[on file load] image size: ${width}x${height}, canvas size: ${context?.canvas.width}x${context?.canvas.height} scale: ${s}, initialCentered: ${initialCentered}` + `[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}` ) if (context?.canvas) { - context.canvas.width = width - context.canvas.height = height - console.log("[on file load] set canvas size && drawOnCurrentRender") - // drawOnCurrentRender([]) + console.log("[on file load] set canvas size") + if (width != context.canvas.width) { + context.canvas.width = width + } + if (height != context.canvas.height) { + context.canvas.height = height + } } if (!initialCentered) { @@ -753,13 +280,11 @@ export default function Editor(props: EditorProps) { setInitialCentered(true) } }, [ - context?.canvas, viewportRef, original, isOriginalLoaded, windowSize, initialCentered, - // drawOnCurrentRender, getCurrentWidthHeight, ]) @@ -807,17 +332,6 @@ export default function Editor(props: EditorProps) { } }, [windowSize, resetZoom]) - useEffect(() => { - window.addEventListener("blur", () => { - setIsChangingBrushSizeByMouse(false) - }) - return () => { - window.removeEventListener("blur", () => { - setIsChangingBrushSizeByMouse(false) - }) - } - }, []) - const handleEscPressed = () => { if (isProcessing) { return @@ -845,15 +359,6 @@ export default function Editor(props: EditorProps) { } const onMouseDrag = (ev: SyntheticEvent) => { - // if (isChangingBrushSizeByMouse) { - // const initX = changeBrushSizeByMouseInit.x - // // move right: increase brush size - // const newSize = changeBrushSizeByMouseInit.brushSize + (x - initX) - // if (newSize <= MAX_BRUSH_SIZE && newSize >= MIN_BRUSH_SIZE) { - // setBaseBrushSize(newSize) - // } - // return - // } if (interactiveSegState.isInteractiveSeg) { return } @@ -868,26 +373,16 @@ export default function Editor(props: EditorProps) { } handleCanvasMouseMove(mouseXY(ev)) - // const lineGroup = [...curLineGroup] - // lineGroup[lineGroup.length - 1].pts.push(mouseXY(ev)) - // setCurLineGroup(lineGroup) - // drawOnCurrentRender(lineGroup) } const runInteractiveSeg = async (newClicks: number[][]) => { - if (!file) { - return - } - - // setIsInteractiveSegRunning(true) + updateAppState({ isPluginRunning: true }) const targetFile = await getCurrentRender() - const prevMask = null try { const res = await runPlugin( PluginName.InteractiveSeg, targetFile, undefined, - prevMask, newClicks ) if (!res) { @@ -896,7 +391,7 @@ export default function Editor(props: EditorProps) { const { blob } = res const img = new Image() img.onload = () => { - setTmpInteractiveSegMask(img) + updateInteractiveSegState({ tmpInteractiveSegMask: img }) } img.src = blob } catch (e: any) { @@ -905,7 +400,7 @@ export default function Editor(props: EditorProps) { description: e.message ? e.message : e.toString(), }) } - // setIsInteractiveSegRunning(false) + updateAppState({ isPluginRunning: false }) } const onPointerUp = (ev: SyntheticEvent) => { @@ -933,7 +428,7 @@ export default function Editor(props: EditorProps) { return } - if (enableManualInpainting) { + if (runMannually) { setIsDraging(false) } else { runInpainting() @@ -959,15 +454,13 @@ export default function Editor(props: EditorProps) { const onCanvasMouseUp = (ev: SyntheticEvent) => { if (interactiveSegState.isInteractiveSeg) { const xy = mouseXY(ev) - const isX = xy.x - const isY = xy.y const newClicks: number[][] = [...interactiveSegState.clicks] if (isRightClick(ev)) { - newClicks.push([isX, isY, 0, newClicks.length]) + newClicks.push([xy.x, xy.y, 0, newClicks.length]) } else { - newClicks.push([isX, isY, 1, newClicks.length]) + newClicks.push([xy.x, xy.y, 1, newClicks.length]) } - // runInteractiveSeg(newClicks) + runInteractiveSeg(newClicks) updateInteractiveSegState({ clicks: newClicks }) } } @@ -979,13 +472,10 @@ export default function Editor(props: EditorProps) { if (interactiveSegState.isInteractiveSeg) { return } - if (isChangingBrushSizeByMouse) { - return - } if (isPanning) { return } - if (!original.src) { + if (!isOriginalLoaded) { return } const canvas = context?.canvas @@ -1002,26 +492,17 @@ export default function Editor(props: EditorProps) { return } - if ( - isDiffusionModels && - settings.showCroper && - isOutsideCroper(mouseXY(ev)) - ) { - return - } + // if ( + // isDiffusionModels && + // settings.showCroper && + // isOutsideCroper(mouseXY(ev)) + // ) { + // // TODO: 去掉这个逻辑,在 cropper 层截断 click 点击? + // return + // } setIsDraging(true) - - // let lineGroup: LineGroup = [] - // if (enableManualInpainting) { - // lineGroup = [...curLineGroup] - // } - // lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] }) - // setCurLineGroup(lineGroup) - handleCanvasMouseDown(mouseXY(ev)) - - // drawOnCurrentRender(lineGroup) } const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { @@ -1092,7 +573,7 @@ export default function Editor(props: EditorProps) { let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") - drawLinesOnMask(lineGroups) + const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups) // Create a link const aDownloadLink = document.createElement("a") // Add the name of the file to the link @@ -1147,11 +628,11 @@ export default function Editor(props: EditorProps) { useHotkeys( "shift+r", () => { - if (enableManualInpainting && hadDrawSomething()) { + if (runMannually && hadDrawSomething()) { runInpainting() } }, - [enableManualInpainting, runInpainting, hadDrawSomething] + [runMannually, runInpainting, hadDrawSomething] ) useHotkeys( @@ -1192,13 +673,11 @@ export default function Editor(props: EditorProps) { (ev) => { ev?.preventDefault() ev?.stopPropagation() - setIsChangingBrushSizeByMouse(true) - setChangeBrushSizeByMouseInit({ x, y, brushSize }) + // TODO: mouse scroll increase/decrease brush size }, (ev) => { ev?.preventDefault() ev?.stopPropagation() - setIsChangingBrushSizeByMouse(false) } ) @@ -1359,20 +838,24 @@ export default function Editor(props: EditorProps) { show={settings.showCroper} /> - {/* {interactiveSegState.isInteractiveSeg ? : <>} */} + {interactiveSegState.isInteractiveSeg ? ( + + ) : ( + <> + )} ) } - const onInteractiveAccept = () => { - setInteractiveSegMask(tmpInteractiveSegMask) - setTmpInteractiveSegMask(null) + // const onInteractiveAccept = () => { + // setInteractiveSegMask(tmpInteractiveSegMask) + // setTmpInteractiveSegMask(null) - if (!enableManualInpainting && tmpInteractiveSegMask) { - runInpainting(false, undefined, tmpInteractiveSegMask) - } - } + // if (!enableManualInpainting && tmpInteractiveSegMask) { + // runInpainting(false, undefined, tmpInteractiveSegMask) + // } + // } return (
+
- {settings.enableManualInpainting ? ( + {settings.enableManualInpainting && + settings.model.model_type === "inpaint" ? ( { - // ensured by disabled - runInpainting(false, undefined, interactiveSegMask) + runInpainting() }} > diff --git a/web_app/src/components/Header.tsx b/web_app/src/components/Header.tsx index 1941465..8322f25 100644 --- a/web_app/src/components/Header.tsx +++ b/web_app/src/components/Header.tsx @@ -69,7 +69,7 @@ const Header = () => { ) return ( -
+
{enableFileManager ? ( { } const InteractiveSegConfirmActions = () => { - const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [ + const [ + interactiveSegState, + resetInteractiveSegState, + handleInteractiveSegAccept, + ] = useStore((state) => [ state.interactiveSegState, state.resetInteractiveSegState, + state.handleInteractiveSegAccept, ]) if (!interactiveSegState.isInteractiveSeg) { return null } - const onAcceptClick = () => { - resetInteractiveSegState() - } - return (
diff --git a/web_app/src/components/ui/dropdown-menu.tsx b/web_app/src/components/ui/dropdown-menu.tsx index 0e4dccf..1938c75 100644 --- a/web_app/src/components/ui/dropdown-menu.tsx +++ b/web_app/src/components/ui/dropdown-menu.tsx @@ -29,7 +29,7 @@ const DropdownMenuSubTrigger = React.forwardRef< e.preventDefault()} /> )) diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 58d8eef..6ab2af2 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -15,18 +15,13 @@ export default async function inpaint( imageFile: File, settings: Settings, croperRect: Rect, - maskBase64?: string, - customMask?: File, - paintByExampleImage?: File + mask: File | Blob, + paintByExampleImage: File | null = null ) { // 1080, 2000, Original const fd = new FormData() fd.append("image", imageFile) - if (maskBase64 !== undefined) { - fd.append("mask", dataURItoBlob(maskBase64)) - } else if (customMask !== undefined) { - fd.append("mask", customMask) - } + fd.append("mask", mask) fd.append("ldmSteps", settings.ldmSteps.toString()) fd.append("ldmSampler", settings.ldmSampler.toString()) @@ -42,8 +37,7 @@ export default async function inpaint( fd.append("croperY", croperRect.y.toString()) fd.append("croperHeight", croperRect.height.toString()) fd.append("croperWidth", croperRect.width.toString()) - // fd.append("useCroper", settings.showCroper ? "true" : "false") - fd.append("useCroper", "false") + fd.append("useCroper", settings.showCroper ? "true" : "false") fd.append("sdMaskBlur", settings.sdMaskBlur.toString()) fd.append("sdStrength", settings.sdStrength.toString()) @@ -147,7 +141,6 @@ export async function runPlugin( name: string, imageFile: File, upscale?: number, - maskFile?: File | null, clicks?: number[][] ) { const fd = new FormData() @@ -159,9 +152,6 @@ export async function runPlugin( if (clicks) { fd.append("clicks", JSON.stringify(clicks)) } - if (maskFile) { - fd.append("mask", maskFile) - } try { const res = await fetch(`${API_ENDPOINT}/run_plugin`, { diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 797ff52..aa0e425 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -1,7 +1,8 @@ -import { create } from "zustand" import { persist } from "zustand/middleware" import { shallow } from "zustand/shallow" import { immer } from "zustand/middleware/immer" +import { castDraft } from "immer" +import { nanoid } from "nanoid" import { createWithEqualityFn } from "zustand/traditional" import { CV2Flag, @@ -10,6 +11,7 @@ import { Line, LineGroup, ModelInfo, + PluginParams, Point, SDSampler, Size, @@ -17,6 +19,9 @@ import { SortOrder, } from "./types" import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const" +import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils" +import inpaint, { runPlugin } from "./api" +import { toast, useToast } from "@/components/ui/use-toast" type FileManagerState = { sortBy: SortBy @@ -95,7 +100,9 @@ type ServerConfig = { type InteractiveSegState = { isInteractiveSeg: boolean - isInteractiveSegRunning: boolean + interactiveSegMask: HTMLImageElement | null + tmpInteractiveSegMask: HTMLImageElement | null + prevInteractiveSegMask: HTMLImageElement | null clicks: number[][] } @@ -103,9 +110,11 @@ type EditorState = { baseBrushSize: number brushSizeScale: number renders: HTMLImageElement[] + paintByExampleImage: File | null lineGroups: LineGroup[] lastLineGroup: LineGroup curLineGroup: LineGroup + extraMasks: HTMLImageElement[] // redo 相关 redoRenders: HTMLImageElement[] redoCurLines: Line[] @@ -113,6 +122,8 @@ type EditorState = { } type AppState = { + idForUpdateView: string + file: File | null customMask: File | null imageHeight: number @@ -136,6 +147,7 @@ type AppAction = { setCustomFile: (file: File) => void setIsInpainting: (newValue: boolean) => void setIsPluginRunning: (newValue: boolean) => void + getIsProcessing: () => boolean setBaseBrushSize: (newValue: number) => void getBrushSize: () => number setImageSize: (width: number, height: number) => void @@ -151,10 +163,18 @@ type AppAction = { updateFileManagerState: (newState: Partial) => void updateInteractiveSegState: (newState: Partial) => void resetInteractiveSegState: () => void + handleInteractiveSegAccept: () => void showPromptInput: () => boolean showSidePanel: () => boolean + runInpainting: () => Promise + runRenderablePlugin: ( + pluginName: string, + params?: PluginParams + ) => Promise + // EditorState + getCurrentTargetFile: () => Promise updateEditorState: (newState: Partial) => void runMannually: () => boolean handleCanvasMouseDown: (point: Point) => void @@ -168,12 +188,15 @@ type AppAction = { } const defaultValues: AppState = { + idForUpdateView: nanoid(), + file: null, customMask: null, imageHeight: 0, imageWidth: 0, isInpainting: false, isPluginRunning: false, + windowSize: { height: 600, width: 800, @@ -182,6 +205,8 @@ const defaultValues: AppState = { baseBrushSize: DEFAULT_BRUSH_SIZE, brushSizeScale: 1, renders: [], + paintByExampleImage: null, + extraMasks: [], lineGroups: [], lastLineGroup: [], curLineGroup: [], @@ -192,7 +217,9 @@ const defaultValues: AppState = { interactiveSegState: { isInteractiveSeg: false, - isInteractiveSegRunning: false, + interactiveSegMask: null, + tmpInteractiveSegMask: null, + prevInteractiveSegMask: null, clicks: [], }, @@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn()( immer((set, get) => ({ ...defaultValues, + getCurrentTargetFile: async (): Promise => { + const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数 + const renders = get().editorState.renders + + let targetFile = file + if (renders.length > 0) { + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type + ) + } + return targetFile + }, + + runInpainting: async () => { + const { file, imageWidth, imageHeight, settings, cropperState } = get() + + if (file === null) { + return + } + const { + lastLineGroup, + curLineGroup, + lineGroups, + renders, + paintByExampleImage, + } = get().editorState + + const { interactiveSegMask, prevInteractiveSegMask } = + get().interactiveSegState + + const useLastLineGroup = + curLineGroup.length === 0 && interactiveSegMask === null + + const maskImage = useLastLineGroup + ? prevInteractiveSegMask + : interactiveSegMask + + // useLastLineGroup 的影响 + // 1. 使用上一次的 mask + // 2. 结果替换当前 render + let maskLineGroup: LineGroup = [] + if (useLastLineGroup === true) { + if (lastLineGroup.length === 0 && maskImage === null) { + toast({ + variant: "destructive", + description: "Please draw mask on picture", + }) + return + } + maskLineGroup = lastLineGroup + } else { + if (curLineGroup.length === 0 && maskImage === null) { + toast({ + variant: "destructive", + description: "Please draw mask on picture", + }) + return + } + maskLineGroup = curLineGroup + } + + const newLineGroups = [...lineGroups, maskLineGroup] + + set((state) => { + state.isInpainting = true + }) + + let targetFile = file + if (useLastLineGroup === true) { + // renders.length == 1 还是用原来的 + if (renders.length > 1) { + const lastRender = renders[renders.length - 2] + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type + ) + } + } else if (renders.length > 0) { + const lastRender = renders[renders.length - 1] + targetFile = await srcToFile( + lastRender.currentSrc, + file.name, + file.type + ) + } + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [maskLineGroup], + maskImage ? [maskImage] : [] + ) + + try { + const res = await inpaint( + targetFile, + settings, + cropperState, + dataURItoBlob(maskCanvas.toDataURL()), + paintByExampleImage + ) + + if (!res) { + throw new Error("Something went wrong on server side.") + } + + const { blob, seed } = res + if (seed) { + set((state) => (state.settings.seed = parseInt(seed, 10))) + } + const newRender = new Image() + await loadImage(newRender, blob) + if (useLastLineGroup === true) { + const prevRenders = renders.slice(0, -1) + const newRenders = [...prevRenders, newRender] + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + lastLineGroup: curLineGroup, + curLineGroup: [], + }) + } else { + const newRenders = [...renders, newRender] + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + lastLineGroup: curLineGroup, + curLineGroup: [], + }) + } + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + } + + get().resetRedoState() + set((state) => { + state.isInpainting = false + }) + + const newInteractiveSegState = { + ...defaultValues.interactiveSegState, + prevInteractiveSegMask: useLastLineGroup ? null : maskImage, + } + + set((state) => { + state.interactiveSegState = castDraft(newInteractiveSegState) + }) + }, + + runRenderablePlugin: async ( + pluginName: string, + params: PluginParams = { upscale: 1 } + ) => { + const { renders, lineGroups } = get().editorState + set((state) => { + state.isInpainting = true + }) + + try { + const start = new Date() + const targetFile = await get().getCurrentTargetFile() + const res = await runPlugin(pluginName, targetFile, params.upscale) + if (!res) { + throw new Error("Something went wrong on server side.") + } + const { blob } = res + const newRender = new Image() + await loadImage(newRender, blob) + get().setImageSize(newRender.height, newRender.width) + const newRenders = [...renders, newRender] + const newLineGroups = [...lineGroups, []] + get().updateEditorState({ + renders: newRenders, + lineGroups: newLineGroups, + }) + const end = new Date() + const time = end.getTime() - start.getTime() + toast({ + description: `Run ${pluginName} successfully in ${time / 1000}s`, + }) + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + } + set((state) => { + state.isInpainting = false + }) + }, + // Edirot State // updateEditorState: (newState: Partial) => { set((state) => { - return { - ...state, - editorState: { - ...state.editorState, - ...newState, - }, - } + state.editorState = castDraft({ ...state.editorState, ...newState }) }) }, @@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn()( ) }, + getIsProcessing: (): boolean => { + return get().isInpainting || get().isPluginRunning + }, + // undo/redo undoDisabled: (): boolean => { @@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn()( updateInteractiveSegState: (newState: Partial) => { set((state) => { - state.interactiveSegState = { - ...state.interactiveSegState, - ...newState, + return { + ...state, + interactiveSegState: { + ...state.interactiveSegState, + ...newState, + }, } }) }, + resetInteractiveSegState: () => { + get().updateInteractiveSegState(defaultValues.interactiveSegState) + }, + + handleInteractiveSegAccept: () => { set((state) => { - state.interactiveSegState = defaultValues.interactiveSegState + return { + ...state, + interactiveSegState: { + ...defaultValues.interactiveSegState, + interactiveSegMask: + state.interactiveSegState.tmpInteractiveSegMask, + }, + } }) }, @@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn()( setFile: (file: File) => set((state) => { - // TODO: 清空各种状态 state.file = file + state.interactiveSegState = castDraft( + defaultValues.interactiveSegState + ) + state.editorState = castDraft(defaultValues.editorState) + state.cropperState = defaultValues.cropperState }), setCustomFile: (file: File) => diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 9ce1c9f..ed3d797 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -25,6 +25,10 @@ export enum PluginName { InteractiveSeg = "InteractiveSeg", } +export interface PluginParams { + upscale: number +} + export enum SortBy { NAME = "name", CTIME = "ctime", diff --git a/web_app/src/lib/utils.ts b/web_app/src/lib/utils.ts index b1e9917..53c5d4f 100644 --- a/web_app/src/lib/utils.ts +++ b/web_app/src/lib/utils.ts @@ -159,3 +159,28 @@ export function drawLines( ctx.stroke() }) } + +export const generateMask = ( + imageWidth: number, + imageHeight: number, + lineGroups: LineGroup[], + maskImages: HTMLImageElement[] = [] +): HTMLCanvasElement => { + const maskCanvas = document.createElement("canvas") + maskCanvas.width = imageWidth + maskCanvas.height = imageHeight + const ctx = maskCanvas.getContext("2d") + if (!ctx) { + throw new Error("could not retrieve mask canvas") + } + + maskImages.forEach((maskImage) => { + ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight) + }) + + lineGroups.forEach((lineGroup) => { + drawLines(ctx, lineGroup, "white") + }) + + return maskCanvas +}