This commit is contained in:
Qing 2023-12-11 22:28:07 +08:00
parent fecf4beef0
commit 354a1280a4
13 changed files with 531 additions and 747 deletions

View File

@ -36,16 +36,15 @@ class ModelManager:
return ControlNet(device, **{**kwargs, "model_info": model_info}) return ControlNet(device, **{**kwargs, "model_info": model_info})
else: else:
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
]: ]:
raise NotImplementedError(
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
)
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
return SD(device, model_id_or_path=model_info.path, **kwargs) return SD(device, model_id_or_path=model_info.path, **kwargs)
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT: if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL,
]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs) return SDXL(device, model_id_or_path=model_info.path, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}") raise NotImplementedError(f"Unsupported model: {name}")
@ -88,7 +87,7 @@ class ModelManager:
if self.kwargs["sd_controlnet_method"] == control_method: if self.kwargs["sd_controlnet_method"] == control_method:
return return
if not self.available_models[self.name].support_controlnet(): if not self.available_models[self.name].support_controlnet:
return return
del self.model del self.model
@ -105,7 +104,7 @@ class ModelManager:
if str(self.model.device) == "mps": if str(self.model.device) == "mps":
return return
if self.available_models[self.name].support_freeu(): if self.available_models[self.name].support_freeu:
if config.sd_freeu: if config.sd_freeu:
freeu_config = config.sd_freeu_config freeu_config = config.sd_freeu_config
self.model.model.enable_freeu( self.model.model.enable_freeu(
@ -118,7 +117,7 @@ class ModelManager:
self.model.model.disable_freeu() self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config): def enable_disable_lcm_lora(self, config: Config):
if self.available_models[self.name].support_lcm_lora(): if self.available_models[self.name].support_lcm_lora:
if config.sd_lcm_lora: if config.sd_lcm_lora:
if not self.model.model.pipe.get_list_adapters(): if not self.model.model.pipe.get_list_adapters():
self.model.model.load_lora_weights(self.model.lcm_lora_id) self.model.model.load_lora_weights(self.model.lcm_lora_id)

View File

@ -1,8 +1,14 @@
from typing import Optional from typing import Optional, List
from enum import Enum from enum import Enum
from PIL.Image import Image from PIL.Image import Image
from pydantic import BaseModel from pydantic import BaseModel, computed_field
from lama_cleaner.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
@ -31,6 +37,36 @@ class ModelInfo(BaseModel):
model_type: ModelType model_type: ModelType
is_single_file_diffusers: bool = False is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
"timbrooks/instruct-pix2pix",
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_lcm_lora(self) -> bool: def support_lcm_lora(self) -> bool:
return self.model_type in [ return self.model_type in [
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
@ -39,6 +75,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
] ]
@computed_field
@property
def support_controlnet(self) -> bool: def support_controlnet(self) -> bool:
return self.model_type in [ return self.model_type in [
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
@ -47,6 +85,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
] ]
@computed_field
@property
def support_freeu(self) -> bool: def support_freeu(self) -> bool:
return ( return (
self.model_type self.model_type
@ -56,7 +96,7 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
] ]
or "instruct-pix2pix" in self.name or "timbrooks/instruct-pix2pix" in self.name
) )

View File

@ -419,14 +419,8 @@ def run_plugin():
@app.route("/server_config", methods=["GET"]) @app.route("/server_config", methods=["GET"])
def get_server_config(): def get_server_config():
controlnet = {
"SD": SD_CONTROLNET_CHOICES,
"SD2": SD2_CONTROLNET_CHOICES,
"SDXL": SDXL_CONTROLNET_CHOICES,
}
return { return {
"plugins": list(plugins.keys()), "plugins": list(plugins.keys()),
"availableControlNet": controlnet,
"enableFileManager": enable_file_manager, "enableFileManager": enable_file_manager,
"enableAutoSaving": enable_auto_saving, "enableAutoSaving": enable_auto_saving,
}, 200 }, 200
@ -434,20 +428,12 @@ def get_server_config():
@app.route("/models", methods=["GET"]) @app.route("/models", methods=["GET"])
def get_models(): def get_models():
return [ return [it.model_dump() for it in model.scan_models()]
{
**it.dict(),
"support_lcm_lora": it.support_lcm_lora(),
"support_controlnet": it.support_controlnet(),
"support_freeu": it.support_freeu(),
}
for it in model.scan_models()
]
@app.route("/model") @app.route("/model")
def current_model(): def current_model():
return model.available_models[model.name].dict(), 200 return model.available_models[model.name].model_dump(), 200
@app.route("/is_desktop") @app.route("/is_desktop")
@ -600,8 +586,20 @@ def main(args):
else: else:
input_image_path = args.input input_image_path = args.input
# 为了兼容性
model_name_map = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
"sd2": "stabilityai/stable-diffusion-2-inpainting",
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
}
model = ModelManager( model = ModelManager(
name=args.model, name=model_name_map.get(args.model, args.model),
sd_controlnet=args.sd_controlnet, sd_controlnet=args.sd_controlnet,
sd_controlnet_method=args.sd_controlnet_method, sd_controlnet_method=args.sd_controlnet_method,
device=device, device=device,

View File

@ -6,14 +6,15 @@ import {
TransformComponent, TransformComponent,
TransformWrapper, TransformWrapper,
} from "react-zoom-pan-pinch" } from "react-zoom-pan-pinch"
import { useKeyPressEvent, useWindowSize } from "react-use" import { useKeyPressEvent } from "react-use"
import inpaint, { downloadToOutput, runPlugin } from "@/lib/api" import { downloadToOutput, runPlugin } from "@/lib/api"
import { IconButton } from "@/components/ui/button" import { IconButton } from "@/components/ui/button"
import { import {
askWritePermission, askWritePermission,
copyCanvasImage, copyCanvasImage,
downloadImage, downloadImage,
drawLines, drawLines,
generateMask,
isMidClick, isMidClick,
isRightClick, isRightClick,
loadImage, loadImage,
@ -21,20 +22,13 @@ import {
srcToFile, srcToFile,
} from "@/lib/utils" } from "@/lib/utils"
import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react" import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react"
import emitter, {
EVENT_PROMPT,
EVENT_CUSTOM_MASK,
EVENT_PAINT_BY_EXAMPLE,
RERUN_LAST_MASK,
DREAM_BUTTON_MOUSE_ENTER,
DREAM_BUTTON_MOUSE_LEAVE,
} from "@/lib/event"
import { useImage } from "@/hooks/useImage" import { useImage } from "@/hooks/useImage"
import { Slider } from "./ui/slider" import { Slider } from "./ui/slider"
import { Line, LineGroup, PluginName } from "@/lib/types" import { PluginName } from "@/lib/types"
import { useHotkeys } from "react-hotkeys-hook" import { useHotkeys } from "react-hotkeys-hook"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import Cropper from "./Cropper" import Cropper from "./Cropper"
import { InteractiveSegPoints } from "./InteractiveSeg"
const TOOLBAR_HEIGHT = 200 const TOOLBAR_HEIGHT = 200
const MIN_BRUSH_SIZE = 10 const MIN_BRUSH_SIZE = 10
@ -50,6 +44,8 @@ export default function Editor(props: EditorProps) {
const { toast } = useToast() const { toast } = useToast()
const [ const [
idForUpdateView,
windowSize,
isInpainting, isInpainting,
imageWidth, imageWidth,
imageHeight, imageHeight,
@ -75,7 +71,13 @@ export default function Editor(props: EditorProps) {
redo, redo,
undoDisabled, undoDisabled,
redoDisabled, redoDisabled,
isProcessing,
updateAppState,
runMannually,
runInpainting,
] = useStore((state) => [ ] = useStore((state) => [
state.idForUpdateView,
state.windowSize,
state.isInpainting, state.isInpainting,
state.imageWidth, state.imageWidth,
state.imageHeight, state.imageHeight,
@ -101,65 +103,32 @@ export default function Editor(props: EditorProps) {
state.redo, state.redo,
state.undoDisabled(), state.undoDisabled(),
state.redoDisabled(), state.redoDisabled(),
state.getIsProcessing(),
state.updateAppState,
state.runMannually(),
state.runInpainting,
]) ])
const baseBrushSize = useStore((state) => state.editorState.baseBrushSize) const baseBrushSize = useStore((state) => state.editorState.baseBrushSize)
const brushSize = useStore((state) => state.getBrushSize()) const brushSize = useStore((state) => state.getBrushSize())
const renders = useStore((state) => state.editorState.renders) const renders = useStore((state) => state.editorState.renders)
const extraMasks = useStore((state) => state.editorState.extraMasks)
const lineGroups = useStore((state) => state.editorState.lineGroups) const lineGroups = useStore((state) => state.editorState.lineGroups)
const lastLineGroup = useStore((state) => state.editorState.lastLineGroup) const lastLineGroup = useStore((state) => state.editorState.lastLineGroup)
const curLineGroup = useStore((state) => state.editorState.curLineGroup) const curLineGroup = useStore((state) => state.editorState.curLineGroup)
const redoLineGroups = useStore((state) => state.editorState.redoLineGroups) const redoLineGroups = useStore((state) => state.editorState.redoLineGroups)
// 纯 local state // Local State
const [showOriginal, setShowOriginal] = useState(false) const [showOriginal, setShowOriginal] = useState(false)
//
const isProcessing = isInpainting
const isDiffusionModels = false
const isPix2Pix = false
const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false)
const [interactiveSegMask, setInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
// only used while interactive segmentation is on
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
const [prevInteractiveSegMask, setPrevInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
// 仅用于在 dream button hover 时显示提示
const [dreamButtonHoverSegMask, setDreamButtonHoverSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] =
useState<LineGroup>([])
const [original, isOriginalLoaded] = useImage(file) const [original, isOriginalLoaded] = useImage(file)
const [context, setContext] = useState<CanvasRenderingContext2D>() const [context, setContext] = useState<CanvasRenderingContext2D>()
const [maskCanvas] = useState<HTMLCanvasElement>(() => {
return document.createElement("canvas")
})
const [{ x, y }, setCoords] = useState({ x: -1, y: -1 }) const [{ x, y }, setCoords] = useState({ x: -1, y: -1 })
const [showBrush, setShowBrush] = useState(false) const [showBrush, setShowBrush] = useState(false)
const [showRefBrush, setShowRefBrush] = useState(false) const [showRefBrush, setShowRefBrush] = useState(false)
const [isPanning, setIsPanning] = useState<boolean>(false) const [isPanning, setIsPanning] = useState<boolean>(false)
const [isChangingBrushSizeByMouse, setIsChangingBrushSizeByMouse] =
useState<boolean>(false)
const [changeBrushSizeByMouseInit, setChangeBrushSizeByMouseInit] = useState({
x: -1,
y: -1,
brushSize: 20,
})
const [scale, setScale] = useState<number>(1) const [scale, setScale] = useState<number>(1)
const [panned, setPanned] = useState<boolean>(false) const [panned, setPanned] = useState<boolean>(false)
const [minScale, setMinScale] = useState<number>(1.0) const [minScale, setMinScale] = useState<number>(1.0)
const windowSize = useWindowSize()
const windowCenterX = windowSize.width / 2 const windowCenterX = windowSize.width / 2
const windowCenterY = windowSize.height / 2 const windowCenterY = windowSize.height / 2
const viewportRef = useRef<ReactZoomPanPinchContentRef | null>(null) const viewportRef = useRef<ReactZoomPanPinchContentRef | null>(null)
@ -170,402 +139,77 @@ export default function Editor(props: EditorProps) {
const [sliderPos, setSliderPos] = useState<number>(0) const [sliderPos, setSliderPos] = useState<number>(0)
const draw = useCallback(
(render: HTMLImageElement, lineGroup: LineGroup) => {
if (!context) {
return
}
console.log(
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
)
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
context.drawImage(render, 0, 0, imageWidth, imageHeight)
if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
}
if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
}
if (dreamButtonHoverSegMask) {
context.drawImage(
dreamButtonHoverSegMask,
0,
0,
imageWidth,
imageHeight
)
}
drawLines(context, lineGroup)
drawLines(context, dreamButtonHoverLineGroup)
},
[
context,
interactiveSegState,
tmpInteractiveSegMask,
dreamButtonHoverSegMask,
interactiveSegMask,
imageHeight,
imageWidth,
dreamButtonHoverLineGroup,
]
)
const drawLinesOnMask = useCallback(
(_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => {
if (!context?.canvas.width || !context?.canvas.height) {
throw new Error("canvas has invalid size")
}
maskCanvas.width = context?.canvas.width
maskCanvas.height = context?.canvas.height
const ctx = maskCanvas.getContext("2d")
if (!ctx) {
throw new Error("could not retrieve mask canvas")
}
if (maskImage !== undefined && maskImage !== null) {
// TODO: check whether draw yellow mask works on backend
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
}
_lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
})
if (
(maskImage === undefined || maskImage === null) &&
_lineGroups.length === 1 &&
_lineGroups[0].length === 0 &&
isPix2Pix
) {
// For InstructPix2Pix without mask
drawLines(
ctx,
[
{
size: 9999999999,
pts: [
{ x: 0, y: 0 },
{ x: imageWidth, y: 0 },
{ x: imageWidth, y: imageHeight },
{ x: 0, y: imageHeight },
],
},
],
"white"
)
}
},
[context, maskCanvas, imageWidth, imageHeight]
)
const hadDrawSomething = useCallback(() => { const hadDrawSomething = useCallback(() => {
return curLineGroup.length !== 0 return curLineGroup.length !== 0
}, [curLineGroup]) }, [curLineGroup])
// const drawOnCurrentRender = useCallback(
// (lineGroup: LineGroup) => {
// console.log("[drawOnCurrentRender] draw on current render")
// if (renders.length === 0) {
// draw(original, lineGroup)
// } else {
// draw(renders[renders.length - 1], lineGroup)
// }
// },
// [original, renders, draw]
// )
useEffect(() => { useEffect(() => {
if (!context) { if (
!context ||
!isOriginalLoaded ||
imageWidth === 0 ||
imageHeight === 0
) {
return return
} }
const render = renders.length === 0 ? original : renders[renders.length - 1] const render = renders.length === 0 ? original : renders[renders.length - 1]
console.log( console.log(
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}` `[draw] renders.length: ${renders.length} render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
) )
context.canvas.width = imageWidth
context.canvas.height = imageHeight
context.clearRect(0, 0, context.canvas.width, context.canvas.height) context.clearRect(0, 0, context.canvas.width, context.canvas.height)
context.drawImage(render, 0, 0, imageWidth, imageHeight) context.drawImage(render, 0, 0, imageWidth, imageHeight)
// if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
// context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight) extraMasks.forEach((maskImage) => {
// } context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
// if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) { })
// context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
// } if (
interactiveSegState.isInteractiveSeg &&
interactiveSegState.tmpInteractiveSegMask
) {
context.drawImage(
interactiveSegState.tmpInteractiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
if (
!interactiveSegState.isInteractiveSeg &&
interactiveSegState.interactiveSegMask
) {
context.drawImage(
interactiveSegState.interactiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
// if (dreamButtonHoverSegMask) { // if (dreamButtonHoverSegMask) {
// context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight) // context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight)
// } // }
drawLines(context, curLineGroup) drawLines(context, curLineGroup)
// drawLines(context, dreamButtonHoverLineGroup) // drawLines(context, dreamButtonHoverLineGroup)
}, [renders, file, original, context, curLineGroup, imageHeight, imageWidth])
const runInpainting = useCallback(
async (
useLastLineGroup?: boolean,
customMask?: File,
maskImage?: HTMLImageElement | null,
paintByExampleImage?: File
) => {
// customMask: mask uploaded by user
// maskImage: mask from interactive segmentation
if (file === undefined) {
return
}
const useCustomMask = customMask !== undefined && customMask !== null
const useMaskImage = maskImage !== undefined && maskImage !== null
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
console.log("runInpainting")
console.log({
useCustomMask,
useMaskImage,
})
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
if (lastLineGroup.length === 0) {
return
}
maskLineGroup = lastLineGroup
} else if (!useCustomMask) {
if (!hadDrawSomething() && !useMaskImage) {
return
}
// setLastLineGroup(curLineGroup)
maskLineGroup = curLineGroup
}
const newLineGroups = [...lineGroups, maskLineGroup]
cleanCurLineGroup()
setIsDraging(false)
setIsInpainting(true)
drawLinesOnMask([maskLineGroup], maskImage)
let targetFile = file
console.log(
`randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}`
)
if (useLastLineGroup === true) {
// renders.length == 1 还是用原来的
if (renders.length > 1) {
const lastRender = renders[renders.length - 2]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
} else if (renders.length > 0) {
console.info("gradually inpainting on last result")
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
try {
console.log("before run inpaint")
const res = await inpaint(
targetFile,
settings,
cropperRect,
useCustomMask ? undefined : maskCanvas.toDataURL(),
useCustomMask ? customMask : undefined,
paintByExampleImage
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
setSeed(parseInt(seed, 10))
}
const newRender = new Image()
await loadImage(newRender, blob)
if (useLastLineGroup === true) {
const prevRenders = renders.slice(0, -1)
const newRenders = [...prevRenders, newRender]
// setRenders(newRenders)
updateEditorState({ renders: newRenders })
} else {
const newRenders = [...renders, newRender]
updateEditorState({ renders: newRenders })
}
draw(newRender, [])
// Only append new LineGroup after inpainting success
// setLineGroups(newLineGroups)
updateEditorState({ lineGroups: newLineGroups })
// clear redo stack
resetRedoState()
} catch (e: any) {
toast({
variant: "destructive",
title: "Uh oh! Something went wrong.",
description: e.message ? e.message : e.toString(),
})
// drawOnCurrentRender([])
}
setIsInpainting(false)
setPrevInteractiveSegMask(maskImage)
setTmpInteractiveSegMask(null)
setInteractiveSegMask(null)
},
[
renders,
lineGroups,
curLineGroup,
maskCanvas,
settings,
cropperRect,
// drawOnCurrentRender,
hadDrawSomething,
drawLinesOnMask,
]
)
useEffect(() => {
emitter.on(EVENT_PROMPT, () => {
if (hadDrawSomething() || interactiveSegMask) {
runInpainting(false, undefined, interactiveSegMask)
} else if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask)
} else if (isPix2Pix) {
runInpainting(false, undefined, null)
} else {
toast({
variant: "destructive",
description: "Please draw mask on picture.",
})
}
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
})
return () => {
emitter.off(EVENT_PROMPT)
}
}, [ }, [
hadDrawSomething, renders,
runInpainting, extraMasks,
interactiveSegMask, original,
prevInteractiveSegMask, isOriginalLoaded,
interactiveSegState,
context,
curLineGroup,
imageHeight,
imageWidth,
]) ])
useEffect(() => {
emitter.on(DREAM_BUTTON_MOUSE_ENTER, () => {
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
if (!hadDrawSomething() && !interactiveSegMask) {
if (prevInteractiveSegMask) {
setDreamButtonHoverSegMask(prevInteractiveSegMask)
}
let lineGroup2Show: LineGroup = []
if (redoLineGroups.length !== 0) {
lineGroup2Show = redoLineGroups[redoLineGroups.length - 1]
} else if (lineGroups.length !== 0) {
lineGroup2Show = lineGroups[lineGroups.length - 1]
}
console.log(
`[DREAM_BUTTON_MOUSE_ENTER], prevInteractiveSegMask: ${prevInteractiveSegMask} lineGroup2Show: ${lineGroup2Show.length}`
)
if (lineGroup2Show) {
setDreamButtonHoverLineGroup(lineGroup2Show)
}
}
})
return () => {
emitter.off(DREAM_BUTTON_MOUSE_ENTER)
}
}, [
hadDrawSomething,
interactiveSegMask,
prevInteractiveSegMask,
// drawOnCurrentRender,
lineGroups,
redoLineGroups,
])
useEffect(() => {
emitter.on(DREAM_BUTTON_MOUSE_LEAVE, () => {
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
if (!hadDrawSomething() && !interactiveSegMask) {
setDreamButtonHoverSegMask(null)
setDreamButtonHoverLineGroup([])
// drawOnCurrentRender([])
}
})
return () => {
emitter.off(DREAM_BUTTON_MOUSE_LEAVE)
}
}, [hadDrawSomething, interactiveSegMask])
useEffect(() => {
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
// TODO: not work with paint by example
runInpainting(false, data.mask)
})
return () => {
emitter.off(EVENT_CUSTOM_MASK)
}
}, [runInpainting])
useEffect(() => {
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
if (hadDrawSomething() || interactiveSegMask) {
runInpainting(false, undefined, interactiveSegMask, data.image)
} else if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
} else {
toast({
variant: "destructive",
description: "Please draw mask on picture.",
})
}
})
return () => {
emitter.off(EVENT_PAINT_BY_EXAMPLE)
}
}, [runInpainting])
useEffect(() => {
emitter.on(RERUN_LAST_MASK, () => {
if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask)
} else {
toast({
variant: "destructive",
description: "No mask to reuse",
})
}
})
return () => {
emitter.off(RERUN_LAST_MASK)
}
}, [runInpainting])
const getCurrentRender = useCallback(async () => { const getCurrentRender = useCallback(async () => {
let targetFile = file let targetFile = file
if (renders.length > 0) { if (renders.length > 0) {
@ -575,126 +219,6 @@ export default function Editor(props: EditorProps) {
return targetFile return targetFile
}, [file, renders]) }, [file, renders])
useEffect(() => {
emitter.on(PluginName.InteractiveSeg, () => {
// setIsInteractiveSeg(true)
if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true)
}
})
return () => {
emitter.off(PluginName.InteractiveSeg)
}
})
const runRenderablePlugin = useCallback(
async (name: string, data?: any) => {
if (isProcessing) {
return
}
try {
// TODO 要不要加 undoCurrentLine
const start = new Date()
setIsPluginRunning(true)
const targetFile = await getCurrentRender()
const res = await runPlugin(name, targetFile, data?.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
setImageSize(newRender.height, newRender.width)
const newRenders = [...renders, newRender]
// setRenders(newRenders)
updateEditorState({ renders: newRenders })
const newLineGroups = [...lineGroups, []]
updateEditorState({ lineGroups: newLineGroups })
const end = new Date()
const time = end.getTime() - start.getTime()
toast({
description: `Run ${name} successfully in ${time / 1000}s`,
})
const rW = windowSize.width / newRender.width
const rH = (windowSize.height - TOOLBAR_HEIGHT) / newRender.height
let s = 1.0
if (rW < 1 || rH < 1) {
s = Math.min(rW, rH)
}
setMinScale(s)
setScale(s)
viewportRef.current?.centerView(s, 1)
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
} finally {
setIsPluginRunning(false)
}
},
[
renders,
// setRenders,
getCurrentRender,
setIsPluginRunning,
isProcessing,
setImageSize,
lineGroups,
viewportRef,
windowSize,
// setLineGroups,
]
)
useEffect(() => {
emitter.on(PluginName.RemoveBG, () => {
runRenderablePlugin(PluginName.RemoveBG)
})
return () => {
emitter.off(PluginName.RemoveBG)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.AnimeSeg, () => {
runRenderablePlugin(PluginName.AnimeSeg)
})
return () => {
emitter.off(PluginName.AnimeSeg)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.GFPGAN, () => {
runRenderablePlugin(PluginName.GFPGAN)
})
return () => {
emitter.off(PluginName.GFPGAN)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.RestoreFormer, () => {
runRenderablePlugin(PluginName.RestoreFormer)
})
return () => {
emitter.off(PluginName.RestoreFormer)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.RealESRGAN, (data: any) => {
runRenderablePlugin(PluginName.RealESRGAN, data)
})
return () => {
emitter.off(PluginName.RealESRGAN)
}
}, [runRenderablePlugin])
const hadRunInpainting = () => { const hadRunInpainting = () => {
return renders.length !== 0 return renders.length !== 0
} }
@ -736,14 +260,17 @@ export default function Editor(props: EditorProps) {
setScale(s) setScale(s)
console.log( console.log(
`[on file load] image size: ${width}x${height}, canvas size: ${context?.canvas.width}x${context?.canvas.height} scale: ${s}, initialCentered: ${initialCentered}` `[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}`
) )
if (context?.canvas) { if (context?.canvas) {
context.canvas.width = width console.log("[on file load] set canvas size")
context.canvas.height = height if (width != context.canvas.width) {
console.log("[on file load] set canvas size && drawOnCurrentRender") context.canvas.width = width
// drawOnCurrentRender([]) }
if (height != context.canvas.height) {
context.canvas.height = height
}
} }
if (!initialCentered) { if (!initialCentered) {
@ -753,13 +280,11 @@ export default function Editor(props: EditorProps) {
setInitialCentered(true) setInitialCentered(true)
} }
}, [ }, [
context?.canvas,
viewportRef, viewportRef,
original, original,
isOriginalLoaded, isOriginalLoaded,
windowSize, windowSize,
initialCentered, initialCentered,
// drawOnCurrentRender,
getCurrentWidthHeight, getCurrentWidthHeight,
]) ])
@ -807,17 +332,6 @@ export default function Editor(props: EditorProps) {
} }
}, [windowSize, resetZoom]) }, [windowSize, resetZoom])
useEffect(() => {
window.addEventListener("blur", () => {
setIsChangingBrushSizeByMouse(false)
})
return () => {
window.removeEventListener("blur", () => {
setIsChangingBrushSizeByMouse(false)
})
}
}, [])
const handleEscPressed = () => { const handleEscPressed = () => {
if (isProcessing) { if (isProcessing) {
return return
@ -845,15 +359,6 @@ export default function Editor(props: EditorProps) {
} }
const onMouseDrag = (ev: SyntheticEvent) => { const onMouseDrag = (ev: SyntheticEvent) => {
// if (isChangingBrushSizeByMouse) {
// const initX = changeBrushSizeByMouseInit.x
// // move right: increase brush size
// const newSize = changeBrushSizeByMouseInit.brushSize + (x - initX)
// if (newSize <= MAX_BRUSH_SIZE && newSize >= MIN_BRUSH_SIZE) {
// setBaseBrushSize(newSize)
// }
// return
// }
if (interactiveSegState.isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
return return
} }
@ -868,26 +373,16 @@ export default function Editor(props: EditorProps) {
} }
handleCanvasMouseMove(mouseXY(ev)) handleCanvasMouseMove(mouseXY(ev))
// const lineGroup = [...curLineGroup]
// lineGroup[lineGroup.length - 1].pts.push(mouseXY(ev))
// setCurLineGroup(lineGroup)
// drawOnCurrentRender(lineGroup)
} }
const runInteractiveSeg = async (newClicks: number[][]) => { const runInteractiveSeg = async (newClicks: number[][]) => {
if (!file) { updateAppState({ isPluginRunning: true })
return
}
// setIsInteractiveSegRunning(true)
const targetFile = await getCurrentRender() const targetFile = await getCurrentRender()
const prevMask = null
try { try {
const res = await runPlugin( const res = await runPlugin(
PluginName.InteractiveSeg, PluginName.InteractiveSeg,
targetFile, targetFile,
undefined, undefined,
prevMask,
newClicks newClicks
) )
if (!res) { if (!res) {
@ -896,7 +391,7 @@ export default function Editor(props: EditorProps) {
const { blob } = res const { blob } = res
const img = new Image() const img = new Image()
img.onload = () => { img.onload = () => {
setTmpInteractiveSegMask(img) updateInteractiveSegState({ tmpInteractiveSegMask: img })
} }
img.src = blob img.src = blob
} catch (e: any) { } catch (e: any) {
@ -905,7 +400,7 @@ export default function Editor(props: EditorProps) {
description: e.message ? e.message : e.toString(), description: e.message ? e.message : e.toString(),
}) })
} }
// setIsInteractiveSegRunning(false) updateAppState({ isPluginRunning: false })
} }
const onPointerUp = (ev: SyntheticEvent) => { const onPointerUp = (ev: SyntheticEvent) => {
@ -933,7 +428,7 @@ export default function Editor(props: EditorProps) {
return return
} }
if (enableManualInpainting) { if (runMannually) {
setIsDraging(false) setIsDraging(false)
} else { } else {
runInpainting() runInpainting()
@ -959,15 +454,13 @@ export default function Editor(props: EditorProps) {
const onCanvasMouseUp = (ev: SyntheticEvent) => { const onCanvasMouseUp = (ev: SyntheticEvent) => {
if (interactiveSegState.isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
const xy = mouseXY(ev) const xy = mouseXY(ev)
const isX = xy.x
const isY = xy.y
const newClicks: number[][] = [...interactiveSegState.clicks] const newClicks: number[][] = [...interactiveSegState.clicks]
if (isRightClick(ev)) { if (isRightClick(ev)) {
newClicks.push([isX, isY, 0, newClicks.length]) newClicks.push([xy.x, xy.y, 0, newClicks.length])
} else { } else {
newClicks.push([isX, isY, 1, newClicks.length]) newClicks.push([xy.x, xy.y, 1, newClicks.length])
} }
// runInteractiveSeg(newClicks) runInteractiveSeg(newClicks)
updateInteractiveSegState({ clicks: newClicks }) updateInteractiveSegState({ clicks: newClicks })
} }
} }
@ -979,13 +472,10 @@ export default function Editor(props: EditorProps) {
if (interactiveSegState.isInteractiveSeg) { if (interactiveSegState.isInteractiveSeg) {
return return
} }
if (isChangingBrushSizeByMouse) {
return
}
if (isPanning) { if (isPanning) {
return return
} }
if (!original.src) { if (!isOriginalLoaded) {
return return
} }
const canvas = context?.canvas const canvas = context?.canvas
@ -1002,26 +492,17 @@ export default function Editor(props: EditorProps) {
return return
} }
if ( // if (
isDiffusionModels && // isDiffusionModels &&
settings.showCroper && // settings.showCroper &&
isOutsideCroper(mouseXY(ev)) // isOutsideCroper(mouseXY(ev))
) { // ) {
return // // TODO: 去掉这个逻辑,在 cropper 层截断 click 点击?
} // return
// }
setIsDraging(true) setIsDraging(true)
// let lineGroup: LineGroup = []
// if (enableManualInpainting) {
// lineGroup = [...curLineGroup]
// }
// lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] })
// setCurLineGroup(lineGroup)
handleCanvasMouseDown(mouseXY(ev)) handleCanvasMouseDown(mouseXY(ev))
// drawOnCurrentRender(lineGroup)
} }
const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => { const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => {
@ -1092,7 +573,7 @@ export default function Editor(props: EditorProps) {
let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1") let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1")
maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg") maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg")
drawLinesOnMask(lineGroups) const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups)
// Create a link // Create a link
const aDownloadLink = document.createElement("a") const aDownloadLink = document.createElement("a")
// Add the name of the file to the link // Add the name of the file to the link
@ -1147,11 +628,11 @@ export default function Editor(props: EditorProps) {
useHotkeys( useHotkeys(
"shift+r", "shift+r",
() => { () => {
if (enableManualInpainting && hadDrawSomething()) { if (runMannually && hadDrawSomething()) {
runInpainting() runInpainting()
} }
}, },
[enableManualInpainting, runInpainting, hadDrawSomething] [runMannually, runInpainting, hadDrawSomething]
) )
useHotkeys( useHotkeys(
@ -1192,13 +673,11 @@ export default function Editor(props: EditorProps) {
(ev) => { (ev) => {
ev?.preventDefault() ev?.preventDefault()
ev?.stopPropagation() ev?.stopPropagation()
setIsChangingBrushSizeByMouse(true) // TODO: mouse scroll increase/decrease brush size
setChangeBrushSizeByMouseInit({ x, y, brushSize })
}, },
(ev) => { (ev) => {
ev?.preventDefault() ev?.preventDefault()
ev?.stopPropagation() ev?.stopPropagation()
setIsChangingBrushSizeByMouse(false)
} }
) )
@ -1359,20 +838,24 @@ export default function Editor(props: EditorProps) {
show={settings.showCroper} show={settings.showCroper}
/> />
{/* {interactiveSegState.isInteractiveSeg ? <InteractiveSeg /> : <></>} */} {interactiveSegState.isInteractiveSeg ? (
<InteractiveSegPoints />
) : (
<></>
)}
</TransformComponent> </TransformComponent>
</TransformWrapper> </TransformWrapper>
) )
} }
const onInteractiveAccept = () => { // const onInteractiveAccept = () => {
setInteractiveSegMask(tmpInteractiveSegMask) // setInteractiveSegMask(tmpInteractiveSegMask)
setTmpInteractiveSegMask(null) // setTmpInteractiveSegMask(null)
if (!enableManualInpainting && tmpInteractiveSegMask) { // if (!enableManualInpainting && tmpInteractiveSegMask) {
runInpainting(false, undefined, tmpInteractiveSegMask) // runInpainting(false, undefined, tmpInteractiveSegMask)
} // }
} // }
return ( return (
<div <div
@ -1388,16 +871,11 @@ export default function Editor(props: EditorProps) {
!isPanning && !isPanning &&
(interactiveSegState.isInteractiveSeg (interactiveSegState.isInteractiveSeg
? renderInteractiveSegCursor() ? renderInteractiveSegCursor()
: renderBrush( : renderBrush(getBrushStyle(x, y)))}
getBrushStyle(
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
)
))}
{showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))} {showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))}
<div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/50"> <div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/70">
<Slider <Slider
className="w-48" className="w-48"
defaultValue={[50]} defaultValue={[50]}
@ -1464,16 +942,17 @@ export default function Editor(props: EditorProps) {
<Download /> <Download />
</IconButton> </IconButton>
{settings.enableManualInpainting ? ( {settings.enableManualInpainting &&
settings.model.model_type === "inpaint" ? (
<IconButton <IconButton
tooltip="Run Inpainting" tooltip="Run Inpainting"
disabled={ disabled={
isProcessing || isProcessing ||
(!hadDrawSomething() && interactiveSegMask === null) (!hadDrawSomething() &&
interactiveSegState.interactiveSegMask === null)
} }
onClick={() => { onClick={() => {
// ensured by disabled runInpainting()
runInpainting(false, undefined, interactiveSegMask)
}} }}
> >
<Eraser /> <Eraser />

View File

@ -69,7 +69,7 @@ const Header = () => {
) )
return ( return (
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 backdrop-filter backdrop-blur-md border-b"> <header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
<div className="flex items-center gap-1"> <div className="flex items-center gap-1">
{enableFileManager ? ( {enableFileManager ? (
<FileManager <FileManager

View File

@ -39,19 +39,20 @@ const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => {
} }
const InteractiveSegConfirmActions = () => { const InteractiveSegConfirmActions = () => {
const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [ const [
interactiveSegState,
resetInteractiveSegState,
handleInteractiveSegAccept,
] = useStore((state) => [
state.interactiveSegState, state.interactiveSegState,
state.resetInteractiveSegState, state.resetInteractiveSegState,
state.handleInteractiveSegAccept,
]) ])
if (!interactiveSegState.isInteractiveSeg) { if (!interactiveSegState.isInteractiveSeg) {
return null return null
} }
const onAcceptClick = () => {
resetInteractiveSegState()
}
return ( return (
<div className="z-10 absolute top-[68px] rounded-xl border-solid border p-[8px] left-1/2 translate-x-[-50%] flex justify-center items-center gap-[8px] bg-background"> <div className="z-10 absolute top-[68px] rounded-xl border-solid border p-[8px] left-1/2 translate-x-[-50%] flex justify-center items-center gap-[8px] bg-background">
<Button <Button
@ -66,7 +67,7 @@ const InteractiveSegConfirmActions = () => {
<Button <Button
size="sm" size="sm"
onClick={() => { onClick={() => {
onAcceptClick() handleInteractiveSegAccept()
}} }}
> >
Accept Accept
@ -84,11 +85,11 @@ interface ItemProps {
const Item = (props: ItemProps) => { const Item = (props: ItemProps) => {
const { x, y, positive } = props const { x, y, positive } = props
const name = positive const name = positive
? "bg-[rgba(21,_215,_121,_0.936)] outline-[6px_solid_rgba(98,_255,_179,_0.31)]" ? "bg-[rgba(21,_215,_121,_0.936)] outline-[rgba(98,255,179,0.31)]"
: "bg-[rgba(237,_49,_55,_0.942)] outline-[6px_solid_rgba(255,_89,_95,_0.31)]" : "bg-[rgba(237,_49,_55,_0.942)] outline-[rgba(255,89,95,0.31)]"
return ( return (
<div <div
className={`absolute h-[8px] w-[8px] rounded-[50%] ${name}`} className={`absolute h-[10px] w-[10px] rounded-[50%] ${name} outline-8 outline`}
style={{ style={{
left: x, left: x,
top: y, top: y,

View File

@ -15,9 +15,7 @@ import {
Slice, Slice,
Smile, Smile,
} from "lucide-react" } from "lucide-react"
import { MixIcon } from "@radix-ui/react-icons"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import { InteractiveSeg } from "./InteractiveSeg"
export enum PluginName { export enum PluginName {
RemoveBG = "RemoveBG", RemoveBG = "RemoveBG",
@ -56,44 +54,49 @@ const pluginMap = {
} }
const Plugins = () => { const Plugins = () => {
const [plugins, updateInteractiveSegState] = useStore((state) => [ const [file, plugins, updateInteractiveSegState, runRenderablePlugin] =
state.serverConfig.plugins, useStore((state) => [
state.updateInteractiveSegState, state.file,
]) state.serverConfig.plugins,
state.updateInteractiveSegState,
state.runRenderablePlugin,
])
const disabled = !file
if (plugins.length === 0) { if (plugins.length === 0) {
return null return null
} }
const onPluginClick = (pluginName: string) => { const onPluginClick = (pluginName: string) => {
// if (!disabled) {
// emitter.emit(pluginName)
// }
if (pluginName === PluginName.InteractiveSeg) { if (pluginName === PluginName.InteractiveSeg) {
updateInteractiveSegState({ isInteractiveSeg: true }) updateInteractiveSegState({ isInteractiveSeg: true })
} else {
runRenderablePlugin(pluginName)
} }
} }
const onRealESRGANClick = (upscale: number) => {
// if (!disabled) {
// emitter.emit(PluginName.RealESRGAN, { upscale })
// }
}
const renderRealESRGANPlugin = () => { const renderRealESRGANPlugin = () => {
return ( return (
<DropdownMenuSub key="RealESRGAN"> <DropdownMenuSub key="RealESRGAN">
<DropdownMenuSubTrigger> <DropdownMenuSubTrigger disabled={disabled}>
<div className="flex gap-2 items-center"> <div className="flex gap-2 items-center">
<Fullscreen /> <Fullscreen />
RealESRGAN RealESRGAN
</div> </div>
</DropdownMenuSubTrigger> </DropdownMenuSubTrigger>
<DropdownMenuSubContent> <DropdownMenuSubContent>
<DropdownMenuItem onClick={() => onRealESRGANClick(2)}> <DropdownMenuItem
onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 })
}
>
upscale 2x upscale 2x
</DropdownMenuItem> </DropdownMenuItem>
<DropdownMenuItem onClick={() => onRealESRGANClick(4)}> <DropdownMenuItem
onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 })
}
>
upscale 4x upscale 4x
</DropdownMenuItem> </DropdownMenuItem>
</DropdownMenuSubContent> </DropdownMenuSubContent>
@ -108,7 +111,11 @@ const Plugins = () => {
return renderRealESRGANPlugin() return renderRealESRGANPlugin()
} }
return ( return (
<DropdownMenuItem key={plugin} onClick={() => onPluginClick(plugin)}> <DropdownMenuItem
key={plugin}
onClick={() => onPluginClick(plugin)}
disabled={disabled}
>
<div className="flex gap-2 items-center"> <div className="flex gap-2 items-center">
<IconClass className="p-1" /> <IconClass className="p-1" />
{showName} {showName}
@ -121,7 +128,7 @@ const Plugins = () => {
return ( return (
<DropdownMenu modal={false}> <DropdownMenu modal={false}>
<DropdownMenuTrigger <DropdownMenuTrigger
className="border rounded-lg z-10 bg-background" className="border rounded-lg z-10 bg-background outline-none"
tabIndex={-1} tabIndex={-1}
> >
<Button variant="ghost" size="icon" asChild className="p-1.5"> <Button variant="ghost" size="icon" asChild className="p-1.5">

View File

@ -1,19 +1,17 @@
import React, { FormEvent } from "react" import React, { FormEvent } from "react"
import emitter, {
DREAM_BUTTON_MOUSE_ENTER,
DREAM_BUTTON_MOUSE_LEAVE,
EVENT_PROMPT,
} from "@/lib/event"
import { Button } from "./ui/button" import { Button } from "./ui/button"
import { Input } from "./ui/input" import { Input } from "./ui/input"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
const PromptInput = () => { const PromptInput = () => {
const [isInpainting, prompt, updateSettings] = useStore((state) => [ const [isProcessing, prompt, updateSettings, runInpainting] = useStore(
state.isInpainting, (state) => [
state.settings.prompt, state.getIsProcessing(),
state.updateSettings, state.settings.prompt,
]) state.updateSettings,
state.runInpainting,
]
)
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => { const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
evt.preventDefault() evt.preventDefault()
@ -22,25 +20,25 @@ const PromptInput = () => {
updateSettings({ prompt: target.value }) updateSettings({ prompt: target.value })
} }
const handleRepaintClick = () => { const handleRepaintClick = async () => {
if (prompt.length !== 0 && isInpainting) { if (prompt.length !== 0 && !isProcessing) {
emitter.emit(EVENT_PROMPT) await runInpainting()
} }
} }
const onKeyUp = (e: React.KeyboardEvent) => { const onKeyUp = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && !isInpainting) { if (e.key === "Enter" && !isProcessing) {
handleRepaintClick() handleRepaintClick()
} }
} }
const onMouseEnter = () => { // const onMouseEnter = () => {
emitter.emit(DREAM_BUTTON_MOUSE_ENTER) // emitter.emit(DREAM_BUTTON_MOUSE_ENTER)
} // }
const onMouseLeave = () => { // const onMouseLeave = () => {
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE) // emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
} // }
return ( return (
<div className="flex gap-4 items-center"> <div className="flex gap-4 items-center">
@ -54,9 +52,9 @@ const PromptInput = () => {
<Button <Button
size="sm" size="sm"
onClick={handleRepaintClick} onClick={handleRepaintClick}
disabled={prompt.length === 0 || isInpainting} disabled={prompt.length === 0 || isProcessing}
onMouseEnter={onMouseEnter} // onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave} // onMouseLeave={onMouseLeave}
> >
Dream Dream
</Button> </Button>

View File

@ -29,7 +29,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
<DropdownMenuPrimitive.SubTrigger <DropdownMenuPrimitive.SubTrigger
ref={ref} ref={ref}
className={cn( className={cn(
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent", "flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
inset && "pl-8", inset && "pl-8",
className className
)} )}
@ -72,6 +72,7 @@ const DropdownMenuContent = React.forwardRef<
className className
)} )}
{...props} {...props}
onCloseAutoFocus={(e) => e.preventDefault()}
/> />
</DropdownMenuPrimitive.Portal> </DropdownMenuPrimitive.Portal>
)) ))

View File

@ -15,18 +15,13 @@ export default async function inpaint(
imageFile: File, imageFile: File,
settings: Settings, settings: Settings,
croperRect: Rect, croperRect: Rect,
maskBase64?: string, mask: File | Blob,
customMask?: File, paintByExampleImage: File | null = null
paintByExampleImage?: File
) { ) {
// 1080, 2000, Original // 1080, 2000, Original
const fd = new FormData() const fd = new FormData()
fd.append("image", imageFile) fd.append("image", imageFile)
if (maskBase64 !== undefined) { fd.append("mask", mask)
fd.append("mask", dataURItoBlob(maskBase64))
} else if (customMask !== undefined) {
fd.append("mask", customMask)
}
fd.append("ldmSteps", settings.ldmSteps.toString()) fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString()) fd.append("ldmSampler", settings.ldmSampler.toString())
@ -42,8 +37,7 @@ export default async function inpaint(
fd.append("croperY", croperRect.y.toString()) fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString()) fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString()) fd.append("croperWidth", croperRect.width.toString())
// fd.append("useCroper", settings.showCroper ? "true" : "false") fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("useCroper", "false")
fd.append("sdMaskBlur", settings.sdMaskBlur.toString()) fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString()) fd.append("sdStrength", settings.sdStrength.toString())
@ -147,7 +141,6 @@ export async function runPlugin(
name: string, name: string,
imageFile: File, imageFile: File,
upscale?: number, upscale?: number,
maskFile?: File | null,
clicks?: number[][] clicks?: number[][]
) { ) {
const fd = new FormData() const fd = new FormData()
@ -159,9 +152,6 @@ export async function runPlugin(
if (clicks) { if (clicks) {
fd.append("clicks", JSON.stringify(clicks)) fd.append("clicks", JSON.stringify(clicks))
} }
if (maskFile) {
fd.append("mask", maskFile)
}
try { try {
const res = await fetch(`${API_ENDPOINT}/run_plugin`, { const res = await fetch(`${API_ENDPOINT}/run_plugin`, {

View File

@ -1,7 +1,8 @@
import { create } from "zustand"
import { persist } from "zustand/middleware" import { persist } from "zustand/middleware"
import { shallow } from "zustand/shallow" import { shallow } from "zustand/shallow"
import { immer } from "zustand/middleware/immer" import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer"
import { nanoid } from "nanoid"
import { createWithEqualityFn } from "zustand/traditional" import { createWithEqualityFn } from "zustand/traditional"
import { import {
CV2Flag, CV2Flag,
@ -10,6 +11,7 @@ import {
Line, Line,
LineGroup, LineGroup,
ModelInfo, ModelInfo,
PluginParams,
Point, Point,
SDSampler, SDSampler,
Size, Size,
@ -17,6 +19,9 @@ import {
SortOrder, SortOrder,
} from "./types" } from "./types"
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const" import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
import inpaint, { runPlugin } from "./api"
import { toast, useToast } from "@/components/ui/use-toast"
type FileManagerState = { type FileManagerState = {
sortBy: SortBy sortBy: SortBy
@ -95,7 +100,9 @@ type ServerConfig = {
type InteractiveSegState = { type InteractiveSegState = {
isInteractiveSeg: boolean isInteractiveSeg: boolean
isInteractiveSegRunning: boolean interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][] clicks: number[][]
} }
@ -103,9 +110,11 @@ type EditorState = {
baseBrushSize: number baseBrushSize: number
brushSizeScale: number brushSizeScale: number
renders: HTMLImageElement[] renders: HTMLImageElement[]
paintByExampleImage: File | null
lineGroups: LineGroup[] lineGroups: LineGroup[]
lastLineGroup: LineGroup lastLineGroup: LineGroup
curLineGroup: LineGroup curLineGroup: LineGroup
extraMasks: HTMLImageElement[]
// redo 相关 // redo 相关
redoRenders: HTMLImageElement[] redoRenders: HTMLImageElement[]
redoCurLines: Line[] redoCurLines: Line[]
@ -113,6 +122,8 @@ type EditorState = {
} }
type AppState = { type AppState = {
idForUpdateView: string
file: File | null file: File | null
customMask: File | null customMask: File | null
imageHeight: number imageHeight: number
@ -136,6 +147,7 @@ type AppAction = {
setCustomFile: (file: File) => void setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void setIsPluginRunning: (newValue: boolean) => void
getIsProcessing: () => boolean
setBaseBrushSize: (newValue: number) => void setBaseBrushSize: (newValue: number) => void
getBrushSize: () => number getBrushSize: () => number
setImageSize: (width: number, height: number) => void setImageSize: (width: number, height: number) => void
@ -151,10 +163,18 @@ type AppAction = {
updateFileManagerState: (newState: Partial<FileManagerState>) => void updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void resetInteractiveSegState: () => void
handleInteractiveSegAccept: () => void
showPromptInput: () => boolean showPromptInput: () => boolean
showSidePanel: () => boolean showSidePanel: () => boolean
runInpainting: () => Promise<void>
runRenderablePlugin: (
pluginName: string,
params?: PluginParams
) => Promise<void>
// EditorState // EditorState
getCurrentTargetFile: () => Promise<File>
updateEditorState: (newState: Partial<EditorState>) => void updateEditorState: (newState: Partial<EditorState>) => void
runMannually: () => boolean runMannually: () => boolean
handleCanvasMouseDown: (point: Point) => void handleCanvasMouseDown: (point: Point) => void
@ -168,12 +188,15 @@ type AppAction = {
} }
const defaultValues: AppState = { const defaultValues: AppState = {
idForUpdateView: nanoid(),
file: null, file: null,
customMask: null, customMask: null,
imageHeight: 0, imageHeight: 0,
imageWidth: 0, imageWidth: 0,
isInpainting: false, isInpainting: false,
isPluginRunning: false, isPluginRunning: false,
windowSize: { windowSize: {
height: 600, height: 600,
width: 800, width: 800,
@ -182,6 +205,8 @@ const defaultValues: AppState = {
baseBrushSize: DEFAULT_BRUSH_SIZE, baseBrushSize: DEFAULT_BRUSH_SIZE,
brushSizeScale: 1, brushSizeScale: 1,
renders: [], renders: [],
paintByExampleImage: null,
extraMasks: [],
lineGroups: [], lineGroups: [],
lastLineGroup: [], lastLineGroup: [],
curLineGroup: [], curLineGroup: [],
@ -192,7 +217,9 @@ const defaultValues: AppState = {
interactiveSegState: { interactiveSegState: {
isInteractiveSeg: false, isInteractiveSeg: false,
isInteractiveSegRunning: false, interactiveSegMask: null,
tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [], clicks: [],
}, },
@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
immer((set, get) => ({ immer((set, get) => ({
...defaultValues, ...defaultValues,
getCurrentTargetFile: async (): Promise<File> => {
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
const renders = get().editorState.renders
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
return targetFile
},
runInpainting: async () => {
const { file, imageWidth, imageHeight, settings, cropperState } = get()
if (file === null) {
return
}
const {
lastLineGroup,
curLineGroup,
lineGroups,
renders,
paintByExampleImage,
} = get().editorState
const { interactiveSegMask, prevInteractiveSegMask } =
get().interactiveSegState
const useLastLineGroup =
curLineGroup.length === 0 && interactiveSegMask === null
const maskImage = useLastLineGroup
? prevInteractiveSegMask
: interactiveSegMask
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
if (lastLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = lastLineGroup
} else {
if (curLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = curLineGroup
}
const newLineGroups = [...lineGroups, maskLineGroup]
set((state) => {
state.isInpainting = true
})
let targetFile = file
if (useLastLineGroup === true) {
// renders.length == 1 还是用原来的
if (renders.length > 1) {
const lastRender = renders[renders.length - 2]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
} else if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[maskLineGroup],
maskImage ? [maskImage] : []
)
try {
const res = await inpaint(
targetFile,
settings,
cropperState,
dataURItoBlob(maskCanvas.toDataURL()),
paintByExampleImage
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
set((state) => (state.settings.seed = parseInt(seed, 10)))
}
const newRender = new Image()
await loadImage(newRender, blob)
if (useLastLineGroup === true) {
const prevRenders = renders.slice(0, -1)
const newRenders = [...prevRenders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
} else {
const newRenders = [...renders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
}
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
get().resetRedoState()
set((state) => {
state.isInpainting = false
})
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
},
runRenderablePlugin: async (
pluginName: string,
params: PluginParams = { upscale: 1 }
) => {
const { renders, lineGroups } = get().editorState
set((state) => {
state.isInpainting = true
})
try {
const start = new Date()
const targetFile = await get().getCurrentTargetFile()
const res = await runPlugin(pluginName, targetFile, params.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
get().setImageSize(newRender.height, newRender.width)
const newRenders = [...renders, newRender]
const newLineGroups = [...lineGroups, []]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
})
const end = new Date()
const time = end.getTime() - start.getTime()
toast({
description: `Run ${pluginName} successfully in ${time / 1000}s`,
})
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
set((state) => {
state.isInpainting = false
})
},
// Edirot State // // Edirot State //
updateEditorState: (newState: Partial<EditorState>) => { updateEditorState: (newState: Partial<EditorState>) => {
set((state) => { set((state) => {
return { state.editorState = castDraft({ ...state.editorState, ...newState })
...state,
editorState: {
...state.editorState,
...newState,
},
}
}) })
}, },
@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
) )
}, },
getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning
},
// undo/redo // undo/redo
undoDisabled: (): boolean => { undoDisabled: (): boolean => {
@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => { updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
set((state) => { set((state) => {
state.interactiveSegState = { return {
...state.interactiveSegState, ...state,
...newState, interactiveSegState: {
...state.interactiveSegState,
...newState,
},
} }
}) })
}, },
resetInteractiveSegState: () => { resetInteractiveSegState: () => {
get().updateInteractiveSegState(defaultValues.interactiveSegState)
},
handleInteractiveSegAccept: () => {
set((state) => { set((state) => {
state.interactiveSegState = defaultValues.interactiveSegState return {
...state,
interactiveSegState: {
...defaultValues.interactiveSegState,
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
}
}) })
}, },
@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
setFile: (file: File) => setFile: (file: File) =>
set((state) => { set((state) => {
// TODO: 清空各种状态
state.file = file state.file = file
state.interactiveSegState = castDraft(
defaultValues.interactiveSegState
)
state.editorState = castDraft(defaultValues.editorState)
state.cropperState = defaultValues.cropperState
}), }),
setCustomFile: (file: File) => setCustomFile: (file: File) =>

View File

@ -25,6 +25,10 @@ export enum PluginName {
InteractiveSeg = "InteractiveSeg", InteractiveSeg = "InteractiveSeg",
} }
export interface PluginParams {
upscale: number
}
export enum SortBy { export enum SortBy {
NAME = "name", NAME = "name",
CTIME = "ctime", CTIME = "ctime",

View File

@ -159,3 +159,28 @@ export function drawLines(
ctx.stroke() ctx.stroke()
}) })
} }
export const generateMask = (
imageWidth: number,
imageHeight: number,
lineGroups: LineGroup[],
maskImages: HTMLImageElement[] = []
): HTMLCanvasElement => {
const maskCanvas = document.createElement("canvas")
maskCanvas.width = imageWidth
maskCanvas.height = imageHeight
const ctx = maskCanvas.getContext("2d")
if (!ctx) {
throw new Error("could not retrieve mask canvas")
}
maskImages.forEach((maskImage) => {
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
})
return maskCanvas
}