This commit is contained in:
Qing 2023-12-11 22:28:07 +08:00
parent fecf4beef0
commit 354a1280a4
13 changed files with 531 additions and 747 deletions

View File

@ -36,16 +36,15 @@ class ModelManager:
return ControlNet(device, **{**kwargs, "model_info": model_info})
else:
if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
]:
raise NotImplementedError(
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
)
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
return SD(device, model_id_or_path=model_info.path, **kwargs)
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL,
]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}")
@ -88,7 +87,7 @@ class ModelManager:
if self.kwargs["sd_controlnet_method"] == control_method:
return
if not self.available_models[self.name].support_controlnet():
if not self.available_models[self.name].support_controlnet:
return
del self.model
@ -105,7 +104,7 @@ class ModelManager:
if str(self.model.device) == "mps":
return
if self.available_models[self.name].support_freeu():
if self.available_models[self.name].support_freeu:
if config.sd_freeu:
freeu_config = config.sd_freeu_config
self.model.model.enable_freeu(
@ -118,7 +117,7 @@ class ModelManager:
self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config):
if self.available_models[self.name].support_lcm_lora():
if self.available_models[self.name].support_lcm_lora:
if config.sd_lcm_lora:
if not self.model.model.pipe.get_list_adapters():
self.model.model.load_lora_weights(self.model.lcm_lora_id)

View File

@ -1,8 +1,14 @@
from typing import Optional
from typing import Optional, List
from enum import Enum
from PIL.Image import Image
from pydantic import BaseModel
from pydantic import BaseModel, computed_field
from lama_cleaner.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
@ -31,6 +37,36 @@ class ModelInfo(BaseModel):
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
"timbrooks/instruct-pix2pix",
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
@ -39,6 +75,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
@ -47,6 +85,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return (
self.model_type
@ -56,7 +96,7 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
or "instruct-pix2pix" in self.name
or "timbrooks/instruct-pix2pix" in self.name
)

View File

@ -419,14 +419,8 @@ def run_plugin():
@app.route("/server_config", methods=["GET"])
def get_server_config():
controlnet = {
"SD": SD_CONTROLNET_CHOICES,
"SD2": SD2_CONTROLNET_CHOICES,
"SDXL": SDXL_CONTROLNET_CHOICES,
}
return {
"plugins": list(plugins.keys()),
"availableControlNet": controlnet,
"enableFileManager": enable_file_manager,
"enableAutoSaving": enable_auto_saving,
}, 200
@ -434,20 +428,12 @@ def get_server_config():
@app.route("/models", methods=["GET"])
def get_models():
return [
{
**it.dict(),
"support_lcm_lora": it.support_lcm_lora(),
"support_controlnet": it.support_controlnet(),
"support_freeu": it.support_freeu(),
}
for it in model.scan_models()
]
return [it.model_dump() for it in model.scan_models()]
@app.route("/model")
def current_model():
return model.available_models[model.name].dict(), 200
return model.available_models[model.name].model_dump(), 200
@app.route("/is_desktop")
@ -600,8 +586,20 @@ def main(args):
else:
input_image_path = args.input
# 为了兼容性
model_name_map = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
"sd2": "stabilityai/stable-diffusion-2-inpainting",
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
}
model = ModelManager(
name=args.model,
name=model_name_map.get(args.model, args.model),
sd_controlnet=args.sd_controlnet,
sd_controlnet_method=args.sd_controlnet_method,
device=device,

View File

@ -6,14 +6,15 @@ import {
TransformComponent,
TransformWrapper,
} from "react-zoom-pan-pinch"
import { useKeyPressEvent, useWindowSize } from "react-use"
import inpaint, { downloadToOutput, runPlugin } from "@/lib/api"
import { useKeyPressEvent } from "react-use"
import { downloadToOutput, runPlugin } from "@/lib/api"
import { IconButton } from "@/components/ui/button"
import {
askWritePermission,
copyCanvasImage,
downloadImage,
drawLines,
generateMask,
isMidClick,
isRightClick,
loadImage,
@ -21,20 +22,13 @@ import {
srcToFile,
} from "@/lib/utils"
import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react"
import emitter, {
EVENT_PROMPT,
EVENT_CUSTOM_MASK,
EVENT_PAINT_BY_EXAMPLE,
RERUN_LAST_MASK,
DREAM_BUTTON_MOUSE_ENTER,
DREAM_BUTTON_MOUSE_LEAVE,
} from "@/lib/event"
import { useImage } from "@/hooks/useImage"
import { Slider } from "./ui/slider"
import { Line, LineGroup, PluginName } from "@/lib/types"
import { PluginName } from "@/lib/types"
import { useHotkeys } from "react-hotkeys-hook"
import { useStore } from "@/lib/states"
import Cropper from "./Cropper"
import { InteractiveSegPoints } from "./InteractiveSeg"
const TOOLBAR_HEIGHT = 200
const MIN_BRUSH_SIZE = 10
@ -50,6 +44,8 @@ export default function Editor(props: EditorProps) {
const { toast } = useToast()
const [
idForUpdateView,
windowSize,
isInpainting,
imageWidth,
imageHeight,
@ -75,7 +71,13 @@ export default function Editor(props: EditorProps) {
redo,
undoDisabled,
redoDisabled,
isProcessing,
updateAppState,
runMannually,
runInpainting,
] = useStore((state) => [
state.idForUpdateView,
state.windowSize,
state.isInpainting,
state.imageWidth,
state.imageHeight,
@ -101,65 +103,32 @@ export default function Editor(props: EditorProps) {
state.redo,
state.undoDisabled(),
state.redoDisabled(),
state.getIsProcessing(),
state.updateAppState,
state.runMannually(),
state.runInpainting,
])
const baseBrushSize = useStore((state) => state.editorState.baseBrushSize)
const brushSize = useStore((state) => state.getBrushSize())
const renders = useStore((state) => state.editorState.renders)
const extraMasks = useStore((state) => state.editorState.extraMasks)
const lineGroups = useStore((state) => state.editorState.lineGroups)
const lastLineGroup = useStore((state) => state.editorState.lastLineGroup)
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
const redoLineGroups = useStore((state) => state.editorState.redoLineGroups)
// 纯 local state
// Local State
const [showOriginal, setShowOriginal] = useState(false)
//
const isProcessing = isInpainting
const isDiffusionModels = false
const isPix2Pix = false
const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false)
const [interactiveSegMask, setInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
// only used while interactive segmentation is on
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
const [prevInteractiveSegMask, setPrevInteractiveSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
// 仅用于在 dream button hover 时显示提示
const [dreamButtonHoverSegMask, setDreamButtonHoverSegMask] = useState<
HTMLImageElement | null | undefined
>(null)
const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] =
useState<LineGroup>([])
const [original, isOriginalLoaded] = useImage(file)
const [context, setContext] = useState<CanvasRenderingContext2D>()
const [maskCanvas] = useState<HTMLCanvasElement>(() => {
return document.createElement("canvas")
})
const [{ x, y }, setCoords] = useState({ x: -1, y: -1 })
const [showBrush, setShowBrush] = useState(false)
const [showRefBrush, setShowRefBrush] = useState(false)
const [isPanning, setIsPanning] = useState<boolean>(false)
const [isChangingBrushSizeByMouse, setIsChangingBrushSizeByMouse] =
useState<boolean>(false)
const [changeBrushSizeByMouseInit, setChangeBrushSizeByMouseInit] = useState({
x: -1,
y: -1,
brushSize: 20,
})
const [scale, setScale] = useState<number>(1)
const [panned, setPanned] = useState<boolean>(false)
const [minScale, setMinScale] = useState<number>(1.0)
const windowSize = useWindowSize()
const windowCenterX = windowSize.width / 2
const windowCenterY = windowSize.height / 2
const viewportRef = useRef<ReactZoomPanPinchContentRef | null>(null)
@ -170,402 +139,77 @@ export default function Editor(props: EditorProps) {
const [sliderPos, setSliderPos] = useState<number>(0)
const draw = useCallback(
(render: HTMLImageElement, lineGroup: LineGroup) => {
if (!context) {
return
}
console.log(
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
)
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
context.drawImage(render, 0, 0, imageWidth, imageHeight)
if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
}
if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
}
if (dreamButtonHoverSegMask) {
context.drawImage(
dreamButtonHoverSegMask,
0,
0,
imageWidth,
imageHeight
)
}
drawLines(context, lineGroup)
drawLines(context, dreamButtonHoverLineGroup)
},
[
context,
interactiveSegState,
tmpInteractiveSegMask,
dreamButtonHoverSegMask,
interactiveSegMask,
imageHeight,
imageWidth,
dreamButtonHoverLineGroup,
]
)
const drawLinesOnMask = useCallback(
(_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => {
if (!context?.canvas.width || !context?.canvas.height) {
throw new Error("canvas has invalid size")
}
maskCanvas.width = context?.canvas.width
maskCanvas.height = context?.canvas.height
const ctx = maskCanvas.getContext("2d")
if (!ctx) {
throw new Error("could not retrieve mask canvas")
}
if (maskImage !== undefined && maskImage !== null) {
// TODO: check whether draw yellow mask works on backend
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
}
_lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
})
if (
(maskImage === undefined || maskImage === null) &&
_lineGroups.length === 1 &&
_lineGroups[0].length === 0 &&
isPix2Pix
) {
// For InstructPix2Pix without mask
drawLines(
ctx,
[
{
size: 9999999999,
pts: [
{ x: 0, y: 0 },
{ x: imageWidth, y: 0 },
{ x: imageWidth, y: imageHeight },
{ x: 0, y: imageHeight },
],
},
],
"white"
)
}
},
[context, maskCanvas, imageWidth, imageHeight]
)
const hadDrawSomething = useCallback(() => {
return curLineGroup.length !== 0
}, [curLineGroup])
// const drawOnCurrentRender = useCallback(
// (lineGroup: LineGroup) => {
// console.log("[drawOnCurrentRender] draw on current render")
// if (renders.length === 0) {
// draw(original, lineGroup)
// } else {
// draw(renders[renders.length - 1], lineGroup)
// }
// },
// [original, renders, draw]
// )
useEffect(() => {
if (!context) {
if (
!context ||
!isOriginalLoaded ||
imageWidth === 0 ||
imageHeight === 0
) {
return
}
const render = renders.length === 0 ? original : renders[renders.length - 1]
console.log(
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
`[draw] renders.length: ${renders.length} render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
)
context.canvas.width = imageWidth
context.canvas.height = imageHeight
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
context.drawImage(render, 0, 0, imageWidth, imageHeight)
// if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
// context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
// }
// if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
// context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
// }
extraMasks.forEach((maskImage) => {
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
if (
interactiveSegState.isInteractiveSeg &&
interactiveSegState.tmpInteractiveSegMask
) {
context.drawImage(
interactiveSegState.tmpInteractiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
if (
!interactiveSegState.isInteractiveSeg &&
interactiveSegState.interactiveSegMask
) {
context.drawImage(
interactiveSegState.interactiveSegMask,
0,
0,
imageWidth,
imageHeight
)
}
// if (dreamButtonHoverSegMask) {
// context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight)
// }
drawLines(context, curLineGroup)
// drawLines(context, dreamButtonHoverLineGroup)
}, [renders, file, original, context, curLineGroup, imageHeight, imageWidth])
const runInpainting = useCallback(
async (
useLastLineGroup?: boolean,
customMask?: File,
maskImage?: HTMLImageElement | null,
paintByExampleImage?: File
) => {
// customMask: mask uploaded by user
// maskImage: mask from interactive segmentation
if (file === undefined) {
return
}
const useCustomMask = customMask !== undefined && customMask !== null
const useMaskImage = maskImage !== undefined && maskImage !== null
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
console.log("runInpainting")
console.log({
useCustomMask,
useMaskImage,
})
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
if (lastLineGroup.length === 0) {
return
}
maskLineGroup = lastLineGroup
} else if (!useCustomMask) {
if (!hadDrawSomething() && !useMaskImage) {
return
}
// setLastLineGroup(curLineGroup)
maskLineGroup = curLineGroup
}
const newLineGroups = [...lineGroups, maskLineGroup]
cleanCurLineGroup()
setIsDraging(false)
setIsInpainting(true)
drawLinesOnMask([maskLineGroup], maskImage)
let targetFile = file
console.log(
`randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}`
)
if (useLastLineGroup === true) {
// renders.length == 1 还是用原来的
if (renders.length > 1) {
const lastRender = renders[renders.length - 2]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
} else if (renders.length > 0) {
console.info("gradually inpainting on last result")
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
try {
console.log("before run inpaint")
const res = await inpaint(
targetFile,
settings,
cropperRect,
useCustomMask ? undefined : maskCanvas.toDataURL(),
useCustomMask ? customMask : undefined,
paintByExampleImage
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
setSeed(parseInt(seed, 10))
}
const newRender = new Image()
await loadImage(newRender, blob)
if (useLastLineGroup === true) {
const prevRenders = renders.slice(0, -1)
const newRenders = [...prevRenders, newRender]
// setRenders(newRenders)
updateEditorState({ renders: newRenders })
} else {
const newRenders = [...renders, newRender]
updateEditorState({ renders: newRenders })
}
draw(newRender, [])
// Only append new LineGroup after inpainting success
// setLineGroups(newLineGroups)
updateEditorState({ lineGroups: newLineGroups })
// clear redo stack
resetRedoState()
} catch (e: any) {
toast({
variant: "destructive",
title: "Uh oh! Something went wrong.",
description: e.message ? e.message : e.toString(),
})
// drawOnCurrentRender([])
}
setIsInpainting(false)
setPrevInteractiveSegMask(maskImage)
setTmpInteractiveSegMask(null)
setInteractiveSegMask(null)
},
[
}, [
renders,
lineGroups,
extraMasks,
original,
isOriginalLoaded,
interactiveSegState,
context,
curLineGroup,
maskCanvas,
settings,
cropperRect,
// drawOnCurrentRender,
hadDrawSomething,
drawLinesOnMask,
]
)
useEffect(() => {
emitter.on(EVENT_PROMPT, () => {
if (hadDrawSomething() || interactiveSegMask) {
runInpainting(false, undefined, interactiveSegMask)
} else if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask)
} else if (isPix2Pix) {
runInpainting(false, undefined, null)
} else {
toast({
variant: "destructive",
description: "Please draw mask on picture.",
})
}
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
})
return () => {
emitter.off(EVENT_PROMPT)
}
}, [
hadDrawSomething,
runInpainting,
interactiveSegMask,
prevInteractiveSegMask,
imageHeight,
imageWidth,
])
useEffect(() => {
emitter.on(DREAM_BUTTON_MOUSE_ENTER, () => {
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
if (!hadDrawSomething() && !interactiveSegMask) {
if (prevInteractiveSegMask) {
setDreamButtonHoverSegMask(prevInteractiveSegMask)
}
let lineGroup2Show: LineGroup = []
if (redoLineGroups.length !== 0) {
lineGroup2Show = redoLineGroups[redoLineGroups.length - 1]
} else if (lineGroups.length !== 0) {
lineGroup2Show = lineGroups[lineGroups.length - 1]
}
console.log(
`[DREAM_BUTTON_MOUSE_ENTER], prevInteractiveSegMask: ${prevInteractiveSegMask} lineGroup2Show: ${lineGroup2Show.length}`
)
if (lineGroup2Show) {
setDreamButtonHoverLineGroup(lineGroup2Show)
}
}
})
return () => {
emitter.off(DREAM_BUTTON_MOUSE_ENTER)
}
}, [
hadDrawSomething,
interactiveSegMask,
prevInteractiveSegMask,
// drawOnCurrentRender,
lineGroups,
redoLineGroups,
])
useEffect(() => {
emitter.on(DREAM_BUTTON_MOUSE_LEAVE, () => {
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
if (!hadDrawSomething() && !interactiveSegMask) {
setDreamButtonHoverSegMask(null)
setDreamButtonHoverLineGroup([])
// drawOnCurrentRender([])
}
})
return () => {
emitter.off(DREAM_BUTTON_MOUSE_LEAVE)
}
}, [hadDrawSomething, interactiveSegMask])
useEffect(() => {
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
// TODO: not work with paint by example
runInpainting(false, data.mask)
})
return () => {
emitter.off(EVENT_CUSTOM_MASK)
}
}, [runInpainting])
useEffect(() => {
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
if (hadDrawSomething() || interactiveSegMask) {
runInpainting(false, undefined, interactiveSegMask, data.image)
} else if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
} else {
toast({
variant: "destructive",
description: "Please draw mask on picture.",
})
}
})
return () => {
emitter.off(EVENT_PAINT_BY_EXAMPLE)
}
}, [runInpainting])
useEffect(() => {
emitter.on(RERUN_LAST_MASK, () => {
if (lastLineGroup.length !== 0) {
// 使用上一次手绘的 mask 生成
runInpainting(true, undefined, prevInteractiveSegMask)
} else if (prevInteractiveSegMask) {
// 使用上一次 IS 的 mask 生成
runInpainting(false, undefined, prevInteractiveSegMask)
} else {
toast({
variant: "destructive",
description: "No mask to reuse",
})
}
})
return () => {
emitter.off(RERUN_LAST_MASK)
}
}, [runInpainting])
const getCurrentRender = useCallback(async () => {
let targetFile = file
if (renders.length > 0) {
@ -575,126 +219,6 @@ export default function Editor(props: EditorProps) {
return targetFile
}, [file, renders])
useEffect(() => {
emitter.on(PluginName.InteractiveSeg, () => {
// setIsInteractiveSeg(true)
if (interactiveSegMask !== null) {
setShowInteractiveSegModal(true)
}
})
return () => {
emitter.off(PluginName.InteractiveSeg)
}
})
const runRenderablePlugin = useCallback(
async (name: string, data?: any) => {
if (isProcessing) {
return
}
try {
// TODO 要不要加 undoCurrentLine
const start = new Date()
setIsPluginRunning(true)
const targetFile = await getCurrentRender()
const res = await runPlugin(name, targetFile, data?.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
setImageSize(newRender.height, newRender.width)
const newRenders = [...renders, newRender]
// setRenders(newRenders)
updateEditorState({ renders: newRenders })
const newLineGroups = [...lineGroups, []]
updateEditorState({ lineGroups: newLineGroups })
const end = new Date()
const time = end.getTime() - start.getTime()
toast({
description: `Run ${name} successfully in ${time / 1000}s`,
})
const rW = windowSize.width / newRender.width
const rH = (windowSize.height - TOOLBAR_HEIGHT) / newRender.height
let s = 1.0
if (rW < 1 || rH < 1) {
s = Math.min(rW, rH)
}
setMinScale(s)
setScale(s)
viewportRef.current?.centerView(s, 1)
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
} finally {
setIsPluginRunning(false)
}
},
[
renders,
// setRenders,
getCurrentRender,
setIsPluginRunning,
isProcessing,
setImageSize,
lineGroups,
viewportRef,
windowSize,
// setLineGroups,
]
)
useEffect(() => {
emitter.on(PluginName.RemoveBG, () => {
runRenderablePlugin(PluginName.RemoveBG)
})
return () => {
emitter.off(PluginName.RemoveBG)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.AnimeSeg, () => {
runRenderablePlugin(PluginName.AnimeSeg)
})
return () => {
emitter.off(PluginName.AnimeSeg)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.GFPGAN, () => {
runRenderablePlugin(PluginName.GFPGAN)
})
return () => {
emitter.off(PluginName.GFPGAN)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.RestoreFormer, () => {
runRenderablePlugin(PluginName.RestoreFormer)
})
return () => {
emitter.off(PluginName.RestoreFormer)
}
}, [runRenderablePlugin])
useEffect(() => {
emitter.on(PluginName.RealESRGAN, (data: any) => {
runRenderablePlugin(PluginName.RealESRGAN, data)
})
return () => {
emitter.off(PluginName.RealESRGAN)
}
}, [runRenderablePlugin])
const hadRunInpainting = () => {
return renders.length !== 0
}
@ -736,14 +260,17 @@ export default function Editor(props: EditorProps) {
setScale(s)
console.log(
`[on file load] image size: ${width}x${height}, canvas size: ${context?.canvas.width}x${context?.canvas.height} scale: ${s}, initialCentered: ${initialCentered}`
`[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}`
)
if (context?.canvas) {
console.log("[on file load] set canvas size")
if (width != context.canvas.width) {
context.canvas.width = width
}
if (height != context.canvas.height) {
context.canvas.height = height
console.log("[on file load] set canvas size && drawOnCurrentRender")
// drawOnCurrentRender([])
}
}
if (!initialCentered) {
@ -753,13 +280,11 @@ export default function Editor(props: EditorProps) {
setInitialCentered(true)
}
}, [
context?.canvas,
viewportRef,
original,
isOriginalLoaded,
windowSize,
initialCentered,
// drawOnCurrentRender,
getCurrentWidthHeight,
])
@ -807,17 +332,6 @@ export default function Editor(props: EditorProps) {
}
}, [windowSize, resetZoom])
useEffect(() => {
window.addEventListener("blur", () => {
setIsChangingBrushSizeByMouse(false)
})
return () => {
window.removeEventListener("blur", () => {
setIsChangingBrushSizeByMouse(false)
})
}
}, [])
const handleEscPressed = () => {
if (isProcessing) {
return
@ -845,15 +359,6 @@ export default function Editor(props: EditorProps) {
}
const onMouseDrag = (ev: SyntheticEvent) => {
// if (isChangingBrushSizeByMouse) {
// const initX = changeBrushSizeByMouseInit.x
// // move right: increase brush size
// const newSize = changeBrushSizeByMouseInit.brushSize + (x - initX)
// if (newSize <= MAX_BRUSH_SIZE && newSize >= MIN_BRUSH_SIZE) {
// setBaseBrushSize(newSize)
// }
// return
// }
if (interactiveSegState.isInteractiveSeg) {
return
}
@ -868,26 +373,16 @@ export default function Editor(props: EditorProps) {
}
handleCanvasMouseMove(mouseXY(ev))
// const lineGroup = [...curLineGroup]
// lineGroup[lineGroup.length - 1].pts.push(mouseXY(ev))
// setCurLineGroup(lineGroup)
// drawOnCurrentRender(lineGroup)
}
const runInteractiveSeg = async (newClicks: number[][]) => {
if (!file) {
return
}
// setIsInteractiveSegRunning(true)
updateAppState({ isPluginRunning: true })
const targetFile = await getCurrentRender()
const prevMask = null
try {
const res = await runPlugin(
PluginName.InteractiveSeg,
targetFile,
undefined,
prevMask,
newClicks
)
if (!res) {
@ -896,7 +391,7 @@ export default function Editor(props: EditorProps) {
const { blob } = res
const img = new Image()
img.onload = () => {
setTmpInteractiveSegMask(img)
updateInteractiveSegState({ tmpInteractiveSegMask: img })
}
img.src = blob
} catch (e: any) {
@ -905,7 +400,7 @@ export default function Editor(props: EditorProps) {
description: e.message ? e.message : e.toString(),
})
}
// setIsInteractiveSegRunning(false)
updateAppState({ isPluginRunning: false })
}
const onPointerUp = (ev: SyntheticEvent) => {
@ -933,7 +428,7 @@ export default function Editor(props: EditorProps) {
return
}
if (enableManualInpainting) {
if (runMannually) {
setIsDraging(false)
} else {
runInpainting()
@ -959,15 +454,13 @@ export default function Editor(props: EditorProps) {
const onCanvasMouseUp = (ev: SyntheticEvent) => {
if (interactiveSegState.isInteractiveSeg) {
const xy = mouseXY(ev)
const isX = xy.x
const isY = xy.y
const newClicks: number[][] = [...interactiveSegState.clicks]
if (isRightClick(ev)) {
newClicks.push([isX, isY, 0, newClicks.length])
newClicks.push([xy.x, xy.y, 0, newClicks.length])
} else {
newClicks.push([isX, isY, 1, newClicks.length])
newClicks.push([xy.x, xy.y, 1, newClicks.length])
}
// runInteractiveSeg(newClicks)
runInteractiveSeg(newClicks)
updateInteractiveSegState({ clicks: newClicks })
}
}
@ -979,13 +472,10 @@ export default function Editor(props: EditorProps) {
if (interactiveSegState.isInteractiveSeg) {
return
}
if (isChangingBrushSizeByMouse) {
return
}
if (isPanning) {
return
}
if (!original.src) {
if (!isOriginalLoaded) {
return
}
const canvas = context?.canvas
@ -1002,26 +492,17 @@ export default function Editor(props: EditorProps) {
return
}
if (
isDiffusionModels &&
settings.showCroper &&
isOutsideCroper(mouseXY(ev))
) {
return
}
// if (
// isDiffusionModels &&
// settings.showCroper &&
// isOutsideCroper(mouseXY(ev))
// ) {
// // TODO: 去掉这个逻辑,在 cropper 层截断 click 点击?
// return
// }
setIsDraging(true)
// let lineGroup: LineGroup = []
// if (enableManualInpainting) {
// lineGroup = [...curLineGroup]
// }
// lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] })
// setCurLineGroup(lineGroup)
handleCanvasMouseDown(mouseXY(ev))
// drawOnCurrentRender(lineGroup)
}
const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => {
@ -1092,7 +573,7 @@ export default function Editor(props: EditorProps) {
let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1")
maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg")
drawLinesOnMask(lineGroups)
const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups)
// Create a link
const aDownloadLink = document.createElement("a")
// Add the name of the file to the link
@ -1147,11 +628,11 @@ export default function Editor(props: EditorProps) {
useHotkeys(
"shift+r",
() => {
if (enableManualInpainting && hadDrawSomething()) {
if (runMannually && hadDrawSomething()) {
runInpainting()
}
},
[enableManualInpainting, runInpainting, hadDrawSomething]
[runMannually, runInpainting, hadDrawSomething]
)
useHotkeys(
@ -1192,13 +673,11 @@ export default function Editor(props: EditorProps) {
(ev) => {
ev?.preventDefault()
ev?.stopPropagation()
setIsChangingBrushSizeByMouse(true)
setChangeBrushSizeByMouseInit({ x, y, brushSize })
// TODO: mouse scroll increase/decrease brush size
},
(ev) => {
ev?.preventDefault()
ev?.stopPropagation()
setIsChangingBrushSizeByMouse(false)
}
)
@ -1359,20 +838,24 @@ export default function Editor(props: EditorProps) {
show={settings.showCroper}
/>
{/* {interactiveSegState.isInteractiveSeg ? <InteractiveSeg /> : <></>} */}
{interactiveSegState.isInteractiveSeg ? (
<InteractiveSegPoints />
) : (
<></>
)}
</TransformComponent>
</TransformWrapper>
)
}
const onInteractiveAccept = () => {
setInteractiveSegMask(tmpInteractiveSegMask)
setTmpInteractiveSegMask(null)
// const onInteractiveAccept = () => {
// setInteractiveSegMask(tmpInteractiveSegMask)
// setTmpInteractiveSegMask(null)
if (!enableManualInpainting && tmpInteractiveSegMask) {
runInpainting(false, undefined, tmpInteractiveSegMask)
}
}
// if (!enableManualInpainting && tmpInteractiveSegMask) {
// runInpainting(false, undefined, tmpInteractiveSegMask)
// }
// }
return (
<div
@ -1388,16 +871,11 @@ export default function Editor(props: EditorProps) {
!isPanning &&
(interactiveSegState.isInteractiveSeg
? renderInteractiveSegCursor()
: renderBrush(
getBrushStyle(
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
)
))}
: renderBrush(getBrushStyle(x, y)))}
{showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))}
<div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/50">
<div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/70">
<Slider
className="w-48"
defaultValue={[50]}
@ -1464,16 +942,17 @@ export default function Editor(props: EditorProps) {
<Download />
</IconButton>
{settings.enableManualInpainting ? (
{settings.enableManualInpainting &&
settings.model.model_type === "inpaint" ? (
<IconButton
tooltip="Run Inpainting"
disabled={
isProcessing ||
(!hadDrawSomething() && interactiveSegMask === null)
(!hadDrawSomething() &&
interactiveSegState.interactiveSegMask === null)
}
onClick={() => {
// ensured by disabled
runInpainting(false, undefined, interactiveSegMask)
runInpainting()
}}
>
<Eraser />

View File

@ -69,7 +69,7 @@ const Header = () => {
)
return (
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 backdrop-filter backdrop-blur-md border-b">
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
<div className="flex items-center gap-1">
{enableFileManager ? (
<FileManager

View File

@ -39,19 +39,20 @@ const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => {
}
const InteractiveSegConfirmActions = () => {
const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [
const [
interactiveSegState,
resetInteractiveSegState,
handleInteractiveSegAccept,
] = useStore((state) => [
state.interactiveSegState,
state.resetInteractiveSegState,
state.handleInteractiveSegAccept,
])
if (!interactiveSegState.isInteractiveSeg) {
return null
}
const onAcceptClick = () => {
resetInteractiveSegState()
}
return (
<div className="z-10 absolute top-[68px] rounded-xl border-solid border p-[8px] left-1/2 translate-x-[-50%] flex justify-center items-center gap-[8px] bg-background">
<Button
@ -66,7 +67,7 @@ const InteractiveSegConfirmActions = () => {
<Button
size="sm"
onClick={() => {
onAcceptClick()
handleInteractiveSegAccept()
}}
>
Accept
@ -84,11 +85,11 @@ interface ItemProps {
const Item = (props: ItemProps) => {
const { x, y, positive } = props
const name = positive
? "bg-[rgba(21,_215,_121,_0.936)] outline-[6px_solid_rgba(98,_255,_179,_0.31)]"
: "bg-[rgba(237,_49,_55,_0.942)] outline-[6px_solid_rgba(255,_89,_95,_0.31)]"
? "bg-[rgba(21,_215,_121,_0.936)] outline-[rgba(98,255,179,0.31)]"
: "bg-[rgba(237,_49,_55,_0.942)] outline-[rgba(255,89,95,0.31)]"
return (
<div
className={`absolute h-[8px] w-[8px] rounded-[50%] ${name}`}
className={`absolute h-[10px] w-[10px] rounded-[50%] ${name} outline-8 outline`}
style={{
left: x,
top: y,

View File

@ -15,9 +15,7 @@ import {
Slice,
Smile,
} from "lucide-react"
import { MixIcon } from "@radix-ui/react-icons"
import { useStore } from "@/lib/states"
import { InteractiveSeg } from "./InteractiveSeg"
export enum PluginName {
RemoveBG = "RemoveBG",
@ -56,44 +54,49 @@ const pluginMap = {
}
const Plugins = () => {
const [plugins, updateInteractiveSegState] = useStore((state) => [
const [file, plugins, updateInteractiveSegState, runRenderablePlugin] =
useStore((state) => [
state.file,
state.serverConfig.plugins,
state.updateInteractiveSegState,
state.runRenderablePlugin,
])
const disabled = !file
if (plugins.length === 0) {
return null
}
const onPluginClick = (pluginName: string) => {
// if (!disabled) {
// emitter.emit(pluginName)
// }
if (pluginName === PluginName.InteractiveSeg) {
updateInteractiveSegState({ isInteractiveSeg: true })
} else {
runRenderablePlugin(pluginName)
}
}
const onRealESRGANClick = (upscale: number) => {
// if (!disabled) {
// emitter.emit(PluginName.RealESRGAN, { upscale })
// }
}
const renderRealESRGANPlugin = () => {
return (
<DropdownMenuSub key="RealESRGAN">
<DropdownMenuSubTrigger>
<DropdownMenuSubTrigger disabled={disabled}>
<div className="flex gap-2 items-center">
<Fullscreen />
RealESRGAN
</div>
</DropdownMenuSubTrigger>
<DropdownMenuSubContent>
<DropdownMenuItem onClick={() => onRealESRGANClick(2)}>
<DropdownMenuItem
onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 })
}
>
upscale 2x
</DropdownMenuItem>
<DropdownMenuItem onClick={() => onRealESRGANClick(4)}>
<DropdownMenuItem
onClick={() =>
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 })
}
>
upscale 4x
</DropdownMenuItem>
</DropdownMenuSubContent>
@ -108,7 +111,11 @@ const Plugins = () => {
return renderRealESRGANPlugin()
}
return (
<DropdownMenuItem key={plugin} onClick={() => onPluginClick(plugin)}>
<DropdownMenuItem
key={plugin}
onClick={() => onPluginClick(plugin)}
disabled={disabled}
>
<div className="flex gap-2 items-center">
<IconClass className="p-1" />
{showName}
@ -121,7 +128,7 @@ const Plugins = () => {
return (
<DropdownMenu modal={false}>
<DropdownMenuTrigger
className="border rounded-lg z-10 bg-background"
className="border rounded-lg z-10 bg-background outline-none"
tabIndex={-1}
>
<Button variant="ghost" size="icon" asChild className="p-1.5">

View File

@ -1,19 +1,17 @@
import React, { FormEvent } from "react"
import emitter, {
DREAM_BUTTON_MOUSE_ENTER,
DREAM_BUTTON_MOUSE_LEAVE,
EVENT_PROMPT,
} from "@/lib/event"
import { Button } from "./ui/button"
import { Input } from "./ui/input"
import { useStore } from "@/lib/states"
const PromptInput = () => {
const [isInpainting, prompt, updateSettings] = useStore((state) => [
state.isInpainting,
const [isProcessing, prompt, updateSettings, runInpainting] = useStore(
(state) => [
state.getIsProcessing(),
state.settings.prompt,
state.updateSettings,
])
state.runInpainting,
]
)
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
evt.preventDefault()
@ -22,25 +20,25 @@ const PromptInput = () => {
updateSettings({ prompt: target.value })
}
const handleRepaintClick = () => {
if (prompt.length !== 0 && isInpainting) {
emitter.emit(EVENT_PROMPT)
const handleRepaintClick = async () => {
if (prompt.length !== 0 && !isProcessing) {
await runInpainting()
}
}
const onKeyUp = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && !isInpainting) {
if (e.key === "Enter" && !isProcessing) {
handleRepaintClick()
}
}
const onMouseEnter = () => {
emitter.emit(DREAM_BUTTON_MOUSE_ENTER)
}
// const onMouseEnter = () => {
// emitter.emit(DREAM_BUTTON_MOUSE_ENTER)
// }
const onMouseLeave = () => {
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
}
// const onMouseLeave = () => {
// emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
// }
return (
<div className="flex gap-4 items-center">
@ -54,9 +52,9 @@ const PromptInput = () => {
<Button
size="sm"
onClick={handleRepaintClick}
disabled={prompt.length === 0 || isInpainting}
onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave}
disabled={prompt.length === 0 || isProcessing}
// onMouseEnter={onMouseEnter}
// onMouseLeave={onMouseLeave}
>
Dream
</Button>

View File

@ -29,7 +29,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
<DropdownMenuPrimitive.SubTrigger
ref={ref}
className={cn(
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent",
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
inset && "pl-8",
className
)}
@ -72,6 +72,7 @@ const DropdownMenuContent = React.forwardRef<
className
)}
{...props}
onCloseAutoFocus={(e) => e.preventDefault()}
/>
</DropdownMenuPrimitive.Portal>
))

View File

@ -15,18 +15,13 @@ export default async function inpaint(
imageFile: File,
settings: Settings,
croperRect: Rect,
maskBase64?: string,
customMask?: File,
paintByExampleImage?: File
mask: File | Blob,
paintByExampleImage: File | null = null
) {
// 1080, 2000, Original
const fd = new FormData()
fd.append("image", imageFile)
if (maskBase64 !== undefined) {
fd.append("mask", dataURItoBlob(maskBase64))
} else if (customMask !== undefined) {
fd.append("mask", customMask)
}
fd.append("mask", mask)
fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString())
@ -42,8 +37,7 @@ export default async function inpaint(
fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString())
// fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("useCroper", "false")
fd.append("useCroper", settings.showCroper ? "true" : "false")
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString())
@ -147,7 +141,6 @@ export async function runPlugin(
name: string,
imageFile: File,
upscale?: number,
maskFile?: File | null,
clicks?: number[][]
) {
const fd = new FormData()
@ -159,9 +152,6 @@ export async function runPlugin(
if (clicks) {
fd.append("clicks", JSON.stringify(clicks))
}
if (maskFile) {
fd.append("mask", maskFile)
}
try {
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {

View File

@ -1,7 +1,8 @@
import { create } from "zustand"
import { persist } from "zustand/middleware"
import { shallow } from "zustand/shallow"
import { immer } from "zustand/middleware/immer"
import { castDraft } from "immer"
import { nanoid } from "nanoid"
import { createWithEqualityFn } from "zustand/traditional"
import {
CV2Flag,
@ -10,6 +11,7 @@ import {
Line,
LineGroup,
ModelInfo,
PluginParams,
Point,
SDSampler,
Size,
@ -17,6 +19,9 @@ import {
SortOrder,
} from "./types"
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
import inpaint, { runPlugin } from "./api"
import { toast, useToast } from "@/components/ui/use-toast"
type FileManagerState = {
sortBy: SortBy
@ -95,7 +100,9 @@ type ServerConfig = {
type InteractiveSegState = {
isInteractiveSeg: boolean
isInteractiveSegRunning: boolean
interactiveSegMask: HTMLImageElement | null
tmpInteractiveSegMask: HTMLImageElement | null
prevInteractiveSegMask: HTMLImageElement | null
clicks: number[][]
}
@ -103,9 +110,11 @@ type EditorState = {
baseBrushSize: number
brushSizeScale: number
renders: HTMLImageElement[]
paintByExampleImage: File | null
lineGroups: LineGroup[]
lastLineGroup: LineGroup
curLineGroup: LineGroup
extraMasks: HTMLImageElement[]
// redo 相关
redoRenders: HTMLImageElement[]
redoCurLines: Line[]
@ -113,6 +122,8 @@ type EditorState = {
}
type AppState = {
idForUpdateView: string
file: File | null
customMask: File | null
imageHeight: number
@ -136,6 +147,7 @@ type AppAction = {
setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
getIsProcessing: () => boolean
setBaseBrushSize: (newValue: number) => void
getBrushSize: () => number
setImageSize: (width: number, height: number) => void
@ -151,10 +163,18 @@ type AppAction = {
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void
handleInteractiveSegAccept: () => void
showPromptInput: () => boolean
showSidePanel: () => boolean
runInpainting: () => Promise<void>
runRenderablePlugin: (
pluginName: string,
params?: PluginParams
) => Promise<void>
// EditorState
getCurrentTargetFile: () => Promise<File>
updateEditorState: (newState: Partial<EditorState>) => void
runMannually: () => boolean
handleCanvasMouseDown: (point: Point) => void
@ -168,12 +188,15 @@ type AppAction = {
}
const defaultValues: AppState = {
idForUpdateView: nanoid(),
file: null,
customMask: null,
imageHeight: 0,
imageWidth: 0,
isInpainting: false,
isPluginRunning: false,
windowSize: {
height: 600,
width: 800,
@ -182,6 +205,8 @@ const defaultValues: AppState = {
baseBrushSize: DEFAULT_BRUSH_SIZE,
brushSizeScale: 1,
renders: [],
paintByExampleImage: null,
extraMasks: [],
lineGroups: [],
lastLineGroup: [],
curLineGroup: [],
@ -192,7 +217,9 @@ const defaultValues: AppState = {
interactiveSegState: {
isInteractiveSeg: false,
isInteractiveSegRunning: false,
interactiveSegMask: null,
tmpInteractiveSegMask: null,
prevInteractiveSegMask: null,
clicks: [],
},
@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
immer((set, get) => ({
...defaultValues,
getCurrentTargetFile: async (): Promise<File> => {
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
const renders = get().editorState.renders
let targetFile = file
if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
return targetFile
},
runInpainting: async () => {
const { file, imageWidth, imageHeight, settings, cropperState } = get()
if (file === null) {
return
}
const {
lastLineGroup,
curLineGroup,
lineGroups,
renders,
paintByExampleImage,
} = get().editorState
const { interactiveSegMask, prevInteractiveSegMask } =
get().interactiveSegState
const useLastLineGroup =
curLineGroup.length === 0 && interactiveSegMask === null
const maskImage = useLastLineGroup
? prevInteractiveSegMask
: interactiveSegMask
// useLastLineGroup 的影响
// 1. 使用上一次的 mask
// 2. 结果替换当前 render
let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) {
if (lastLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = lastLineGroup
} else {
if (curLineGroup.length === 0 && maskImage === null) {
toast({
variant: "destructive",
description: "Please draw mask on picture",
})
return
}
maskLineGroup = curLineGroup
}
const newLineGroups = [...lineGroups, maskLineGroup]
set((state) => {
state.isInpainting = true
})
let targetFile = file
if (useLastLineGroup === true) {
// renders.length == 1 还是用原来的
if (renders.length > 1) {
const lastRender = renders[renders.length - 2]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
} else if (renders.length > 0) {
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
const maskCanvas = generateMask(
imageWidth,
imageHeight,
[maskLineGroup],
maskImage ? [maskImage] : []
)
try {
const res = await inpaint(
targetFile,
settings,
cropperState,
dataURItoBlob(maskCanvas.toDataURL()),
paintByExampleImage
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
set((state) => (state.settings.seed = parseInt(seed, 10)))
}
const newRender = new Image()
await loadImage(newRender, blob)
if (useLastLineGroup === true) {
const prevRenders = renders.slice(0, -1)
const newRenders = [...prevRenders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
} else {
const newRenders = [...renders, newRender]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
lastLineGroup: curLineGroup,
curLineGroup: [],
})
}
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
get().resetRedoState()
set((state) => {
state.isInpainting = false
})
const newInteractiveSegState = {
...defaultValues.interactiveSegState,
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
}
set((state) => {
state.interactiveSegState = castDraft(newInteractiveSegState)
})
},
runRenderablePlugin: async (
pluginName: string,
params: PluginParams = { upscale: 1 }
) => {
const { renders, lineGroups } = get().editorState
set((state) => {
state.isInpainting = true
})
try {
const start = new Date()
const targetFile = await get().getCurrentTargetFile()
const res = await runPlugin(pluginName, targetFile, params.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
get().setImageSize(newRender.height, newRender.width)
const newRenders = [...renders, newRender]
const newLineGroups = [...lineGroups, []]
get().updateEditorState({
renders: newRenders,
lineGroups: newLineGroups,
})
const end = new Date()
const time = end.getTime() - start.getTime()
toast({
description: `Run ${pluginName} successfully in ${time / 1000}s`,
})
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
set((state) => {
state.isInpainting = false
})
},
// Edirot State //
updateEditorState: (newState: Partial<EditorState>) => {
set((state) => {
return {
...state,
editorState: {
...state.editorState,
...newState,
},
}
state.editorState = castDraft({ ...state.editorState, ...newState })
})
},
@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
)
},
getIsProcessing: (): boolean => {
return get().isInpainting || get().isPluginRunning
},
// undo/redo
undoDisabled: (): boolean => {
@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
set((state) => {
state.interactiveSegState = {
return {
...state,
interactiveSegState: {
...state.interactiveSegState,
...newState,
},
}
})
},
resetInteractiveSegState: () => {
get().updateInteractiveSegState(defaultValues.interactiveSegState)
},
handleInteractiveSegAccept: () => {
set((state) => {
state.interactiveSegState = defaultValues.interactiveSegState
return {
...state,
interactiveSegState: {
...defaultValues.interactiveSegState,
interactiveSegMask:
state.interactiveSegState.tmpInteractiveSegMask,
},
}
})
},
@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
setFile: (file: File) =>
set((state) => {
// TODO: 清空各种状态
state.file = file
state.interactiveSegState = castDraft(
defaultValues.interactiveSegState
)
state.editorState = castDraft(defaultValues.editorState)
state.cropperState = defaultValues.cropperState
}),
setCustomFile: (file: File) =>

View File

@ -25,6 +25,10 @@ export enum PluginName {
InteractiveSeg = "InteractiveSeg",
}
export interface PluginParams {
upscale: number
}
export enum SortBy {
NAME = "name",
CTIME = "ctime",

View File

@ -159,3 +159,28 @@ export function drawLines(
ctx.stroke()
})
}
export const generateMask = (
imageWidth: number,
imageHeight: number,
lineGroups: LineGroup[],
maskImages: HTMLImageElement[] = []
): HTMLCanvasElement => {
const maskCanvas = document.createElement("canvas")
maskCanvas.width = imageWidth
maskCanvas.height = imageHeight
const ctx = maskCanvas.getContext("2d")
if (!ctx) {
throw new Error("could not retrieve mask canvas")
}
maskImages.forEach((maskImage) => {
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
})
lineGroups.forEach((lineGroup) => {
drawLines(ctx, lineGroup, "white")
})
return maskCanvas
}