wip
This commit is contained in:
parent
fecf4beef0
commit
354a1280a4
@ -36,16 +36,15 @@ class ModelManager:
|
||||
return ControlNet(device, **{**kwargs, "model_info": model_info})
|
||||
else:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
raise NotImplementedError(
|
||||
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
|
||||
)
|
||||
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
|
||||
return SD(device, model_id_or_path=model_info.path, **kwargs)
|
||||
|
||||
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
|
||||
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
@ -88,7 +87,7 @@ class ModelManager:
|
||||
if self.kwargs["sd_controlnet_method"] == control_method:
|
||||
return
|
||||
|
||||
if not self.available_models[self.name].support_controlnet():
|
||||
if not self.available_models[self.name].support_controlnet:
|
||||
return
|
||||
|
||||
del self.model
|
||||
@ -105,7 +104,7 @@ class ModelManager:
|
||||
if str(self.model.device) == "mps":
|
||||
return
|
||||
|
||||
if self.available_models[self.name].support_freeu():
|
||||
if self.available_models[self.name].support_freeu:
|
||||
if config.sd_freeu:
|
||||
freeu_config = config.sd_freeu_config
|
||||
self.model.model.enable_freeu(
|
||||
@ -118,7 +117,7 @@ class ModelManager:
|
||||
self.model.model.disable_freeu()
|
||||
|
||||
def enable_disable_lcm_lora(self, config: Config):
|
||||
if self.available_models[self.name].support_lcm_lora():
|
||||
if self.available_models[self.name].support_lcm_lora:
|
||||
if config.sd_lcm_lora:
|
||||
if not self.model.model.pipe.get_list_adapters():
|
||||
self.model.model.load_lora_weights(self.model.lcm_lora_id)
|
||||
|
@ -1,8 +1,14 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, computed_field
|
||||
|
||||
from lama_cleaner.const import (
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
)
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||
@ -31,6 +37,36 @@ class ModelInfo(BaseModel):
|
||||
model_type: ModelType
|
||||
is_single_file_diffusers: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def need_prompt(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [
|
||||
"timbrooks/instruct-pix2pix",
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def controlnets(self) -> List[str]:
|
||||
if self.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]:
|
||||
return SDXL_CONTROLNET_CHOICES
|
||||
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
|
||||
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
|
||||
return SD2_CONTROLNET_CHOICES
|
||||
else:
|
||||
return SD_CONTROLNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_lcm_lora(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
@ -39,6 +75,8 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_controlnet(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
@ -47,6 +85,8 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_freeu(self) -> bool:
|
||||
return (
|
||||
self.model_type
|
||||
@ -56,7 +96,7 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
or "instruct-pix2pix" in self.name
|
||||
or "timbrooks/instruct-pix2pix" in self.name
|
||||
)
|
||||
|
||||
|
||||
|
@ -419,14 +419,8 @@ def run_plugin():
|
||||
|
||||
@app.route("/server_config", methods=["GET"])
|
||||
def get_server_config():
|
||||
controlnet = {
|
||||
"SD": SD_CONTROLNET_CHOICES,
|
||||
"SD2": SD2_CONTROLNET_CHOICES,
|
||||
"SDXL": SDXL_CONTROLNET_CHOICES,
|
||||
}
|
||||
return {
|
||||
"plugins": list(plugins.keys()),
|
||||
"availableControlNet": controlnet,
|
||||
"enableFileManager": enable_file_manager,
|
||||
"enableAutoSaving": enable_auto_saving,
|
||||
}, 200
|
||||
@ -434,20 +428,12 @@ def get_server_config():
|
||||
|
||||
@app.route("/models", methods=["GET"])
|
||||
def get_models():
|
||||
return [
|
||||
{
|
||||
**it.dict(),
|
||||
"support_lcm_lora": it.support_lcm_lora(),
|
||||
"support_controlnet": it.support_controlnet(),
|
||||
"support_freeu": it.support_freeu(),
|
||||
}
|
||||
for it in model.scan_models()
|
||||
]
|
||||
return [it.model_dump() for it in model.scan_models()]
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.available_models[model.name].dict(), 200
|
||||
return model.available_models[model.name].model_dump(), 200
|
||||
|
||||
|
||||
@app.route("/is_desktop")
|
||||
@ -600,8 +586,20 @@ def main(args):
|
||||
else:
|
||||
input_image_path = args.input
|
||||
|
||||
# 为了兼容性
|
||||
model_name_map = {
|
||||
"sd1.5": "runwayml/stable-diffusion-inpainting",
|
||||
"anything4": "Sanster/anything-4.0-inpainting",
|
||||
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
|
||||
"sd2": "stabilityai/stable-diffusion-2-inpainting",
|
||||
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
|
||||
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
|
||||
}
|
||||
|
||||
model = ModelManager(
|
||||
name=args.model,
|
||||
name=model_name_map.get(args.model, args.model),
|
||||
sd_controlnet=args.sd_controlnet,
|
||||
sd_controlnet_method=args.sd_controlnet_method,
|
||||
device=device,
|
||||
|
@ -6,14 +6,15 @@ import {
|
||||
TransformComponent,
|
||||
TransformWrapper,
|
||||
} from "react-zoom-pan-pinch"
|
||||
import { useKeyPressEvent, useWindowSize } from "react-use"
|
||||
import inpaint, { downloadToOutput, runPlugin } from "@/lib/api"
|
||||
import { useKeyPressEvent } from "react-use"
|
||||
import { downloadToOutput, runPlugin } from "@/lib/api"
|
||||
import { IconButton } from "@/components/ui/button"
|
||||
import {
|
||||
askWritePermission,
|
||||
copyCanvasImage,
|
||||
downloadImage,
|
||||
drawLines,
|
||||
generateMask,
|
||||
isMidClick,
|
||||
isRightClick,
|
||||
loadImage,
|
||||
@ -21,20 +22,13 @@ import {
|
||||
srcToFile,
|
||||
} from "@/lib/utils"
|
||||
import { Eraser, Eye, Redo, Undo, Expand, Download } from "lucide-react"
|
||||
import emitter, {
|
||||
EVENT_PROMPT,
|
||||
EVENT_CUSTOM_MASK,
|
||||
EVENT_PAINT_BY_EXAMPLE,
|
||||
RERUN_LAST_MASK,
|
||||
DREAM_BUTTON_MOUSE_ENTER,
|
||||
DREAM_BUTTON_MOUSE_LEAVE,
|
||||
} from "@/lib/event"
|
||||
import { useImage } from "@/hooks/useImage"
|
||||
import { Slider } from "./ui/slider"
|
||||
import { Line, LineGroup, PluginName } from "@/lib/types"
|
||||
import { PluginName } from "@/lib/types"
|
||||
import { useHotkeys } from "react-hotkeys-hook"
|
||||
import { useStore } from "@/lib/states"
|
||||
import Cropper from "./Cropper"
|
||||
import { InteractiveSegPoints } from "./InteractiveSeg"
|
||||
|
||||
const TOOLBAR_HEIGHT = 200
|
||||
const MIN_BRUSH_SIZE = 10
|
||||
@ -50,6 +44,8 @@ export default function Editor(props: EditorProps) {
|
||||
const { toast } = useToast()
|
||||
|
||||
const [
|
||||
idForUpdateView,
|
||||
windowSize,
|
||||
isInpainting,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
@ -75,7 +71,13 @@ export default function Editor(props: EditorProps) {
|
||||
redo,
|
||||
undoDisabled,
|
||||
redoDisabled,
|
||||
isProcessing,
|
||||
updateAppState,
|
||||
runMannually,
|
||||
runInpainting,
|
||||
] = useStore((state) => [
|
||||
state.idForUpdateView,
|
||||
state.windowSize,
|
||||
state.isInpainting,
|
||||
state.imageWidth,
|
||||
state.imageHeight,
|
||||
@ -101,65 +103,32 @@ export default function Editor(props: EditorProps) {
|
||||
state.redo,
|
||||
state.undoDisabled(),
|
||||
state.redoDisabled(),
|
||||
state.getIsProcessing(),
|
||||
state.updateAppState,
|
||||
state.runMannually(),
|
||||
state.runInpainting,
|
||||
])
|
||||
const baseBrushSize = useStore((state) => state.editorState.baseBrushSize)
|
||||
const brushSize = useStore((state) => state.getBrushSize())
|
||||
const renders = useStore((state) => state.editorState.renders)
|
||||
const extraMasks = useStore((state) => state.editorState.extraMasks)
|
||||
const lineGroups = useStore((state) => state.editorState.lineGroups)
|
||||
|
||||
const lastLineGroup = useStore((state) => state.editorState.lastLineGroup)
|
||||
const curLineGroup = useStore((state) => state.editorState.curLineGroup)
|
||||
const redoLineGroups = useStore((state) => state.editorState.redoLineGroups)
|
||||
|
||||
// 纯 local state
|
||||
// Local State
|
||||
const [showOriginal, setShowOriginal] = useState(false)
|
||||
|
||||
//
|
||||
const isProcessing = isInpainting
|
||||
const isDiffusionModels = false
|
||||
const isPix2Pix = false
|
||||
|
||||
const [showInteractiveSegModal, setShowInteractiveSegModal] = useState(false)
|
||||
const [interactiveSegMask, setInteractiveSegMask] = useState<
|
||||
HTMLImageElement | null | undefined
|
||||
>(null)
|
||||
|
||||
// only used while interactive segmentation is on
|
||||
const [tmpInteractiveSegMask, setTmpInteractiveSegMask] = useState<
|
||||
HTMLImageElement | null | undefined
|
||||
>(null)
|
||||
const [prevInteractiveSegMask, setPrevInteractiveSegMask] = useState<
|
||||
HTMLImageElement | null | undefined
|
||||
>(null)
|
||||
|
||||
// 仅用于在 dream button hover 时显示提示
|
||||
const [dreamButtonHoverSegMask, setDreamButtonHoverSegMask] = useState<
|
||||
HTMLImageElement | null | undefined
|
||||
>(null)
|
||||
const [dreamButtonHoverLineGroup, setDreamButtonHoverLineGroup] =
|
||||
useState<LineGroup>([])
|
||||
|
||||
const [original, isOriginalLoaded] = useImage(file)
|
||||
const [context, setContext] = useState<CanvasRenderingContext2D>()
|
||||
const [maskCanvas] = useState<HTMLCanvasElement>(() => {
|
||||
return document.createElement("canvas")
|
||||
})
|
||||
const [{ x, y }, setCoords] = useState({ x: -1, y: -1 })
|
||||
const [showBrush, setShowBrush] = useState(false)
|
||||
const [showRefBrush, setShowRefBrush] = useState(false)
|
||||
const [isPanning, setIsPanning] = useState<boolean>(false)
|
||||
const [isChangingBrushSizeByMouse, setIsChangingBrushSizeByMouse] =
|
||||
useState<boolean>(false)
|
||||
const [changeBrushSizeByMouseInit, setChangeBrushSizeByMouseInit] = useState({
|
||||
x: -1,
|
||||
y: -1,
|
||||
brushSize: 20,
|
||||
})
|
||||
|
||||
const [scale, setScale] = useState<number>(1)
|
||||
const [panned, setPanned] = useState<boolean>(false)
|
||||
const [minScale, setMinScale] = useState<number>(1.0)
|
||||
const windowSize = useWindowSize()
|
||||
const windowCenterX = windowSize.width / 2
|
||||
const windowCenterY = windowSize.height / 2
|
||||
const viewportRef = useRef<ReactZoomPanPinchContentRef | null>(null)
|
||||
@ -170,402 +139,77 @@ export default function Editor(props: EditorProps) {
|
||||
|
||||
const [sliderPos, setSliderPos] = useState<number>(0)
|
||||
|
||||
const draw = useCallback(
|
||||
(render: HTMLImageElement, lineGroup: LineGroup) => {
|
||||
if (!context) {
|
||||
return
|
||||
}
|
||||
console.log(
|
||||
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
|
||||
)
|
||||
|
||||
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
|
||||
context.drawImage(render, 0, 0, imageWidth, imageHeight)
|
||||
if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
|
||||
context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
|
||||
}
|
||||
if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
|
||||
context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
|
||||
}
|
||||
if (dreamButtonHoverSegMask) {
|
||||
context.drawImage(
|
||||
dreamButtonHoverSegMask,
|
||||
0,
|
||||
0,
|
||||
imageWidth,
|
||||
imageHeight
|
||||
)
|
||||
}
|
||||
drawLines(context, lineGroup)
|
||||
drawLines(context, dreamButtonHoverLineGroup)
|
||||
},
|
||||
[
|
||||
context,
|
||||
interactiveSegState,
|
||||
tmpInteractiveSegMask,
|
||||
dreamButtonHoverSegMask,
|
||||
interactiveSegMask,
|
||||
imageHeight,
|
||||
imageWidth,
|
||||
dreamButtonHoverLineGroup,
|
||||
]
|
||||
)
|
||||
|
||||
const drawLinesOnMask = useCallback(
|
||||
(_lineGroups: LineGroup[], maskImage?: HTMLImageElement | null) => {
|
||||
if (!context?.canvas.width || !context?.canvas.height) {
|
||||
throw new Error("canvas has invalid size")
|
||||
}
|
||||
maskCanvas.width = context?.canvas.width
|
||||
maskCanvas.height = context?.canvas.height
|
||||
const ctx = maskCanvas.getContext("2d")
|
||||
if (!ctx) {
|
||||
throw new Error("could not retrieve mask canvas")
|
||||
}
|
||||
|
||||
if (maskImage !== undefined && maskImage !== null) {
|
||||
// TODO: check whether draw yellow mask works on backend
|
||||
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
}
|
||||
|
||||
_lineGroups.forEach((lineGroup) => {
|
||||
drawLines(ctx, lineGroup, "white")
|
||||
})
|
||||
|
||||
if (
|
||||
(maskImage === undefined || maskImage === null) &&
|
||||
_lineGroups.length === 1 &&
|
||||
_lineGroups[0].length === 0 &&
|
||||
isPix2Pix
|
||||
) {
|
||||
// For InstructPix2Pix without mask
|
||||
drawLines(
|
||||
ctx,
|
||||
[
|
||||
{
|
||||
size: 9999999999,
|
||||
pts: [
|
||||
{ x: 0, y: 0 },
|
||||
{ x: imageWidth, y: 0 },
|
||||
{ x: imageWidth, y: imageHeight },
|
||||
{ x: 0, y: imageHeight },
|
||||
],
|
||||
},
|
||||
],
|
||||
"white"
|
||||
)
|
||||
}
|
||||
},
|
||||
[context, maskCanvas, imageWidth, imageHeight]
|
||||
)
|
||||
|
||||
const hadDrawSomething = useCallback(() => {
|
||||
return curLineGroup.length !== 0
|
||||
}, [curLineGroup])
|
||||
|
||||
// const drawOnCurrentRender = useCallback(
|
||||
// (lineGroup: LineGroup) => {
|
||||
// console.log("[drawOnCurrentRender] draw on current render")
|
||||
// if (renders.length === 0) {
|
||||
// draw(original, lineGroup)
|
||||
// } else {
|
||||
// draw(renders[renders.length - 1], lineGroup)
|
||||
// }
|
||||
// },
|
||||
// [original, renders, draw]
|
||||
// )
|
||||
|
||||
useEffect(() => {
|
||||
if (!context) {
|
||||
if (
|
||||
!context ||
|
||||
!isOriginalLoaded ||
|
||||
imageWidth === 0 ||
|
||||
imageHeight === 0
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
const render = renders.length === 0 ? original : renders[renders.length - 1]
|
||||
|
||||
console.log(
|
||||
`[draw] render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
|
||||
`[draw] renders.length: ${renders.length} render size: ${render.width}x${render.height} image size: ${imageWidth}x${imageHeight} canvas size: ${context.canvas.width}x${context.canvas.height}`
|
||||
)
|
||||
|
||||
context.canvas.width = imageWidth
|
||||
context.canvas.height = imageHeight
|
||||
|
||||
context.clearRect(0, 0, context.canvas.width, context.canvas.height)
|
||||
context.drawImage(render, 0, 0, imageWidth, imageHeight)
|
||||
// if (interactiveSegState.isInteractiveSeg && tmpInteractiveSegMask) {
|
||||
// context.drawImage(tmpInteractiveSegMask, 0, 0, imageWidth, imageHeight)
|
||||
// }
|
||||
// if (!interactiveSegState.isInteractiveSeg && interactiveSegMask) {
|
||||
// context.drawImage(interactiveSegMask, 0, 0, imageWidth, imageHeight)
|
||||
// }
|
||||
|
||||
extraMasks.forEach((maskImage) => {
|
||||
context.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
})
|
||||
|
||||
if (
|
||||
interactiveSegState.isInteractiveSeg &&
|
||||
interactiveSegState.tmpInteractiveSegMask
|
||||
) {
|
||||
context.drawImage(
|
||||
interactiveSegState.tmpInteractiveSegMask,
|
||||
0,
|
||||
0,
|
||||
imageWidth,
|
||||
imageHeight
|
||||
)
|
||||
}
|
||||
if (
|
||||
!interactiveSegState.isInteractiveSeg &&
|
||||
interactiveSegState.interactiveSegMask
|
||||
) {
|
||||
context.drawImage(
|
||||
interactiveSegState.interactiveSegMask,
|
||||
0,
|
||||
0,
|
||||
imageWidth,
|
||||
imageHeight
|
||||
)
|
||||
}
|
||||
// if (dreamButtonHoverSegMask) {
|
||||
// context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight)
|
||||
// }
|
||||
drawLines(context, curLineGroup)
|
||||
// drawLines(context, dreamButtonHoverLineGroup)
|
||||
}, [renders, file, original, context, curLineGroup, imageHeight, imageWidth])
|
||||
|
||||
const runInpainting = useCallback(
|
||||
async (
|
||||
useLastLineGroup?: boolean,
|
||||
customMask?: File,
|
||||
maskImage?: HTMLImageElement | null,
|
||||
paintByExampleImage?: File
|
||||
) => {
|
||||
// customMask: mask uploaded by user
|
||||
// maskImage: mask from interactive segmentation
|
||||
if (file === undefined) {
|
||||
return
|
||||
}
|
||||
const useCustomMask = customMask !== undefined && customMask !== null
|
||||
const useMaskImage = maskImage !== undefined && maskImage !== null
|
||||
// useLastLineGroup 的影响
|
||||
// 1. 使用上一次的 mask
|
||||
// 2. 结果替换当前 render
|
||||
console.log("runInpainting")
|
||||
console.log({
|
||||
useCustomMask,
|
||||
useMaskImage,
|
||||
})
|
||||
|
||||
let maskLineGroup: LineGroup = []
|
||||
if (useLastLineGroup === true) {
|
||||
if (lastLineGroup.length === 0) {
|
||||
return
|
||||
}
|
||||
maskLineGroup = lastLineGroup
|
||||
} else if (!useCustomMask) {
|
||||
if (!hadDrawSomething() && !useMaskImage) {
|
||||
return
|
||||
}
|
||||
|
||||
// setLastLineGroup(curLineGroup)
|
||||
maskLineGroup = curLineGroup
|
||||
}
|
||||
|
||||
const newLineGroups = [...lineGroups, maskLineGroup]
|
||||
|
||||
cleanCurLineGroup()
|
||||
setIsDraging(false)
|
||||
setIsInpainting(true)
|
||||
drawLinesOnMask([maskLineGroup], maskImage)
|
||||
|
||||
let targetFile = file
|
||||
console.log(
|
||||
`randers.length ${renders.length} useLastLineGroup: ${useLastLineGroup}`
|
||||
)
|
||||
if (useLastLineGroup === true) {
|
||||
// renders.length == 1 还是用原来的
|
||||
if (renders.length > 1) {
|
||||
const lastRender = renders[renders.length - 2]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
} else if (renders.length > 0) {
|
||||
console.info("gradually inpainting on last result")
|
||||
|
||||
const lastRender = renders[renders.length - 1]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
|
||||
try {
|
||||
console.log("before run inpaint")
|
||||
const res = await inpaint(
|
||||
targetFile,
|
||||
settings,
|
||||
cropperRect,
|
||||
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
||||
useCustomMask ? customMask : undefined,
|
||||
paintByExampleImage
|
||||
)
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
const { blob, seed } = res
|
||||
if (seed) {
|
||||
setSeed(parseInt(seed, 10))
|
||||
}
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
|
||||
if (useLastLineGroup === true) {
|
||||
const prevRenders = renders.slice(0, -1)
|
||||
const newRenders = [...prevRenders, newRender]
|
||||
// setRenders(newRenders)
|
||||
updateEditorState({ renders: newRenders })
|
||||
} else {
|
||||
const newRenders = [...renders, newRender]
|
||||
updateEditorState({ renders: newRenders })
|
||||
}
|
||||
|
||||
draw(newRender, [])
|
||||
// Only append new LineGroup after inpainting success
|
||||
// setLineGroups(newLineGroups)
|
||||
updateEditorState({ lineGroups: newLineGroups })
|
||||
|
||||
// clear redo stack
|
||||
resetRedoState()
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: "Uh oh! Something went wrong.",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
// drawOnCurrentRender([])
|
||||
}
|
||||
setIsInpainting(false)
|
||||
setPrevInteractiveSegMask(maskImage)
|
||||
setTmpInteractiveSegMask(null)
|
||||
setInteractiveSegMask(null)
|
||||
},
|
||||
[
|
||||
renders,
|
||||
lineGroups,
|
||||
curLineGroup,
|
||||
maskCanvas,
|
||||
settings,
|
||||
cropperRect,
|
||||
// drawOnCurrentRender,
|
||||
hadDrawSomething,
|
||||
drawLinesOnMask,
|
||||
]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_PROMPT, () => {
|
||||
if (hadDrawSomething() || interactiveSegMask) {
|
||||
runInpainting(false, undefined, interactiveSegMask)
|
||||
} else if (lastLineGroup.length !== 0) {
|
||||
// 使用上一次手绘的 mask 生成
|
||||
runInpainting(true, undefined, prevInteractiveSegMask)
|
||||
} else if (prevInteractiveSegMask) {
|
||||
// 使用上一次 IS 的 mask 生成
|
||||
runInpainting(false, undefined, prevInteractiveSegMask)
|
||||
} else if (isPix2Pix) {
|
||||
runInpainting(false, undefined, null)
|
||||
} else {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture.",
|
||||
})
|
||||
}
|
||||
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_PROMPT)
|
||||
}
|
||||
}, [
|
||||
hadDrawSomething,
|
||||
runInpainting,
|
||||
interactiveSegMask,
|
||||
prevInteractiveSegMask,
|
||||
renders,
|
||||
extraMasks,
|
||||
original,
|
||||
isOriginalLoaded,
|
||||
interactiveSegState,
|
||||
context,
|
||||
curLineGroup,
|
||||
imageHeight,
|
||||
imageWidth,
|
||||
])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(DREAM_BUTTON_MOUSE_ENTER, () => {
|
||||
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
|
||||
if (!hadDrawSomething() && !interactiveSegMask) {
|
||||
if (prevInteractiveSegMask) {
|
||||
setDreamButtonHoverSegMask(prevInteractiveSegMask)
|
||||
}
|
||||
let lineGroup2Show: LineGroup = []
|
||||
if (redoLineGroups.length !== 0) {
|
||||
lineGroup2Show = redoLineGroups[redoLineGroups.length - 1]
|
||||
} else if (lineGroups.length !== 0) {
|
||||
lineGroup2Show = lineGroups[lineGroups.length - 1]
|
||||
}
|
||||
console.log(
|
||||
`[DREAM_BUTTON_MOUSE_ENTER], prevInteractiveSegMask: ${prevInteractiveSegMask} lineGroup2Show: ${lineGroup2Show.length}`
|
||||
)
|
||||
if (lineGroup2Show) {
|
||||
setDreamButtonHoverLineGroup(lineGroup2Show)
|
||||
}
|
||||
}
|
||||
})
|
||||
return () => {
|
||||
emitter.off(DREAM_BUTTON_MOUSE_ENTER)
|
||||
}
|
||||
}, [
|
||||
hadDrawSomething,
|
||||
interactiveSegMask,
|
||||
prevInteractiveSegMask,
|
||||
// drawOnCurrentRender,
|
||||
lineGroups,
|
||||
redoLineGroups,
|
||||
])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(DREAM_BUTTON_MOUSE_LEAVE, () => {
|
||||
// 当前 canvas 上没有手绘 mask 或者 interactiveSegMask 时,显示上一次的 mask
|
||||
if (!hadDrawSomething() && !interactiveSegMask) {
|
||||
setDreamButtonHoverSegMask(null)
|
||||
setDreamButtonHoverLineGroup([])
|
||||
// drawOnCurrentRender([])
|
||||
}
|
||||
})
|
||||
return () => {
|
||||
emitter.off(DREAM_BUTTON_MOUSE_LEAVE)
|
||||
}
|
||||
}, [hadDrawSomething, interactiveSegMask])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
|
||||
// TODO: not work with paint by example
|
||||
runInpainting(false, data.mask)
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_CUSTOM_MASK)
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_PAINT_BY_EXAMPLE, (data: any) => {
|
||||
if (hadDrawSomething() || interactiveSegMask) {
|
||||
runInpainting(false, undefined, interactiveSegMask, data.image)
|
||||
} else if (lastLineGroup.length !== 0) {
|
||||
// 使用上一次手绘的 mask 生成
|
||||
runInpainting(true, undefined, prevInteractiveSegMask, data.image)
|
||||
} else if (prevInteractiveSegMask) {
|
||||
// 使用上一次 IS 的 mask 生成
|
||||
runInpainting(false, undefined, prevInteractiveSegMask, data.image)
|
||||
} else {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture.",
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_PAINT_BY_EXAMPLE)
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(RERUN_LAST_MASK, () => {
|
||||
if (lastLineGroup.length !== 0) {
|
||||
// 使用上一次手绘的 mask 生成
|
||||
runInpainting(true, undefined, prevInteractiveSegMask)
|
||||
} else if (prevInteractiveSegMask) {
|
||||
// 使用上一次 IS 的 mask 生成
|
||||
runInpainting(false, undefined, prevInteractiveSegMask)
|
||||
} else {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "No mask to reuse",
|
||||
})
|
||||
}
|
||||
})
|
||||
return () => {
|
||||
emitter.off(RERUN_LAST_MASK)
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
const getCurrentRender = useCallback(async () => {
|
||||
let targetFile = file
|
||||
if (renders.length > 0) {
|
||||
@ -575,126 +219,6 @@ export default function Editor(props: EditorProps) {
|
||||
return targetFile
|
||||
}, [file, renders])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.InteractiveSeg, () => {
|
||||
// setIsInteractiveSeg(true)
|
||||
if (interactiveSegMask !== null) {
|
||||
setShowInteractiveSegModal(true)
|
||||
}
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.InteractiveSeg)
|
||||
}
|
||||
})
|
||||
|
||||
const runRenderablePlugin = useCallback(
|
||||
async (name: string, data?: any) => {
|
||||
if (isProcessing) {
|
||||
return
|
||||
}
|
||||
try {
|
||||
// TODO 要不要加 undoCurrentLine??
|
||||
const start = new Date()
|
||||
setIsPluginRunning(true)
|
||||
const targetFile = await getCurrentRender()
|
||||
const res = await runPlugin(name, targetFile, data?.upscale)
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
const { blob } = res
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
setImageSize(newRender.height, newRender.width)
|
||||
const newRenders = [...renders, newRender]
|
||||
// setRenders(newRenders)
|
||||
updateEditorState({ renders: newRenders })
|
||||
const newLineGroups = [...lineGroups, []]
|
||||
updateEditorState({ lineGroups: newLineGroups })
|
||||
|
||||
const end = new Date()
|
||||
const time = end.getTime() - start.getTime()
|
||||
|
||||
toast({
|
||||
description: `Run ${name} successfully in ${time / 1000}s`,
|
||||
})
|
||||
|
||||
const rW = windowSize.width / newRender.width
|
||||
const rH = (windowSize.height - TOOLBAR_HEIGHT) / newRender.height
|
||||
let s = 1.0
|
||||
if (rW < 1 || rH < 1) {
|
||||
s = Math.min(rW, rH)
|
||||
}
|
||||
setMinScale(s)
|
||||
setScale(s)
|
||||
viewportRef.current?.centerView(s, 1)
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
} finally {
|
||||
setIsPluginRunning(false)
|
||||
}
|
||||
},
|
||||
[
|
||||
renders,
|
||||
// setRenders,
|
||||
getCurrentRender,
|
||||
setIsPluginRunning,
|
||||
isProcessing,
|
||||
setImageSize,
|
||||
lineGroups,
|
||||
viewportRef,
|
||||
windowSize,
|
||||
// setLineGroups,
|
||||
]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.RemoveBG, () => {
|
||||
runRenderablePlugin(PluginName.RemoveBG)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.RemoveBG)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.AnimeSeg, () => {
|
||||
runRenderablePlugin(PluginName.AnimeSeg)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.AnimeSeg)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.GFPGAN, () => {
|
||||
runRenderablePlugin(PluginName.GFPGAN)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.GFPGAN)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.RestoreFormer, () => {
|
||||
runRenderablePlugin(PluginName.RestoreFormer)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.RestoreFormer)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(PluginName.RealESRGAN, (data: any) => {
|
||||
runRenderablePlugin(PluginName.RealESRGAN, data)
|
||||
})
|
||||
return () => {
|
||||
emitter.off(PluginName.RealESRGAN)
|
||||
}
|
||||
}, [runRenderablePlugin])
|
||||
|
||||
const hadRunInpainting = () => {
|
||||
return renders.length !== 0
|
||||
}
|
||||
@ -736,14 +260,17 @@ export default function Editor(props: EditorProps) {
|
||||
setScale(s)
|
||||
|
||||
console.log(
|
||||
`[on file load] image size: ${width}x${height}, canvas size: ${context?.canvas.width}x${context?.canvas.height} scale: ${s}, initialCentered: ${initialCentered}`
|
||||
`[on file load] image size: ${width}x${height}, scale: ${s}, initialCentered: ${initialCentered}`
|
||||
)
|
||||
|
||||
if (context?.canvas) {
|
||||
context.canvas.width = width
|
||||
context.canvas.height = height
|
||||
console.log("[on file load] set canvas size && drawOnCurrentRender")
|
||||
// drawOnCurrentRender([])
|
||||
console.log("[on file load] set canvas size")
|
||||
if (width != context.canvas.width) {
|
||||
context.canvas.width = width
|
||||
}
|
||||
if (height != context.canvas.height) {
|
||||
context.canvas.height = height
|
||||
}
|
||||
}
|
||||
|
||||
if (!initialCentered) {
|
||||
@ -753,13 +280,11 @@ export default function Editor(props: EditorProps) {
|
||||
setInitialCentered(true)
|
||||
}
|
||||
}, [
|
||||
context?.canvas,
|
||||
viewportRef,
|
||||
original,
|
||||
isOriginalLoaded,
|
||||
windowSize,
|
||||
initialCentered,
|
||||
// drawOnCurrentRender,
|
||||
getCurrentWidthHeight,
|
||||
])
|
||||
|
||||
@ -807,17 +332,6 @@ export default function Editor(props: EditorProps) {
|
||||
}
|
||||
}, [windowSize, resetZoom])
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener("blur", () => {
|
||||
setIsChangingBrushSizeByMouse(false)
|
||||
})
|
||||
return () => {
|
||||
window.removeEventListener("blur", () => {
|
||||
setIsChangingBrushSizeByMouse(false)
|
||||
})
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleEscPressed = () => {
|
||||
if (isProcessing) {
|
||||
return
|
||||
@ -845,15 +359,6 @@ export default function Editor(props: EditorProps) {
|
||||
}
|
||||
|
||||
const onMouseDrag = (ev: SyntheticEvent) => {
|
||||
// if (isChangingBrushSizeByMouse) {
|
||||
// const initX = changeBrushSizeByMouseInit.x
|
||||
// // move right: increase brush size
|
||||
// const newSize = changeBrushSizeByMouseInit.brushSize + (x - initX)
|
||||
// if (newSize <= MAX_BRUSH_SIZE && newSize >= MIN_BRUSH_SIZE) {
|
||||
// setBaseBrushSize(newSize)
|
||||
// }
|
||||
// return
|
||||
// }
|
||||
if (interactiveSegState.isInteractiveSeg) {
|
||||
return
|
||||
}
|
||||
@ -868,26 +373,16 @@ export default function Editor(props: EditorProps) {
|
||||
}
|
||||
|
||||
handleCanvasMouseMove(mouseXY(ev))
|
||||
// const lineGroup = [...curLineGroup]
|
||||
// lineGroup[lineGroup.length - 1].pts.push(mouseXY(ev))
|
||||
// setCurLineGroup(lineGroup)
|
||||
// drawOnCurrentRender(lineGroup)
|
||||
}
|
||||
|
||||
const runInteractiveSeg = async (newClicks: number[][]) => {
|
||||
if (!file) {
|
||||
return
|
||||
}
|
||||
|
||||
// setIsInteractiveSegRunning(true)
|
||||
updateAppState({ isPluginRunning: true })
|
||||
const targetFile = await getCurrentRender()
|
||||
const prevMask = null
|
||||
try {
|
||||
const res = await runPlugin(
|
||||
PluginName.InteractiveSeg,
|
||||
targetFile,
|
||||
undefined,
|
||||
prevMask,
|
||||
newClicks
|
||||
)
|
||||
if (!res) {
|
||||
@ -896,7 +391,7 @@ export default function Editor(props: EditorProps) {
|
||||
const { blob } = res
|
||||
const img = new Image()
|
||||
img.onload = () => {
|
||||
setTmpInteractiveSegMask(img)
|
||||
updateInteractiveSegState({ tmpInteractiveSegMask: img })
|
||||
}
|
||||
img.src = blob
|
||||
} catch (e: any) {
|
||||
@ -905,7 +400,7 @@ export default function Editor(props: EditorProps) {
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
}
|
||||
// setIsInteractiveSegRunning(false)
|
||||
updateAppState({ isPluginRunning: false })
|
||||
}
|
||||
|
||||
const onPointerUp = (ev: SyntheticEvent) => {
|
||||
@ -933,7 +428,7 @@ export default function Editor(props: EditorProps) {
|
||||
return
|
||||
}
|
||||
|
||||
if (enableManualInpainting) {
|
||||
if (runMannually) {
|
||||
setIsDraging(false)
|
||||
} else {
|
||||
runInpainting()
|
||||
@ -959,15 +454,13 @@ export default function Editor(props: EditorProps) {
|
||||
const onCanvasMouseUp = (ev: SyntheticEvent) => {
|
||||
if (interactiveSegState.isInteractiveSeg) {
|
||||
const xy = mouseXY(ev)
|
||||
const isX = xy.x
|
||||
const isY = xy.y
|
||||
const newClicks: number[][] = [...interactiveSegState.clicks]
|
||||
if (isRightClick(ev)) {
|
||||
newClicks.push([isX, isY, 0, newClicks.length])
|
||||
newClicks.push([xy.x, xy.y, 0, newClicks.length])
|
||||
} else {
|
||||
newClicks.push([isX, isY, 1, newClicks.length])
|
||||
newClicks.push([xy.x, xy.y, 1, newClicks.length])
|
||||
}
|
||||
// runInteractiveSeg(newClicks)
|
||||
runInteractiveSeg(newClicks)
|
||||
updateInteractiveSegState({ clicks: newClicks })
|
||||
}
|
||||
}
|
||||
@ -979,13 +472,10 @@ export default function Editor(props: EditorProps) {
|
||||
if (interactiveSegState.isInteractiveSeg) {
|
||||
return
|
||||
}
|
||||
if (isChangingBrushSizeByMouse) {
|
||||
return
|
||||
}
|
||||
if (isPanning) {
|
||||
return
|
||||
}
|
||||
if (!original.src) {
|
||||
if (!isOriginalLoaded) {
|
||||
return
|
||||
}
|
||||
const canvas = context?.canvas
|
||||
@ -1002,26 +492,17 @@ export default function Editor(props: EditorProps) {
|
||||
return
|
||||
}
|
||||
|
||||
if (
|
||||
isDiffusionModels &&
|
||||
settings.showCroper &&
|
||||
isOutsideCroper(mouseXY(ev))
|
||||
) {
|
||||
return
|
||||
}
|
||||
// if (
|
||||
// isDiffusionModels &&
|
||||
// settings.showCroper &&
|
||||
// isOutsideCroper(mouseXY(ev))
|
||||
// ) {
|
||||
// // TODO: 去掉这个逻辑,在 cropper 层截断 click 点击?
|
||||
// return
|
||||
// }
|
||||
|
||||
setIsDraging(true)
|
||||
|
||||
// let lineGroup: LineGroup = []
|
||||
// if (enableManualInpainting) {
|
||||
// lineGroup = [...curLineGroup]
|
||||
// }
|
||||
// lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] })
|
||||
// setCurLineGroup(lineGroup)
|
||||
|
||||
handleCanvasMouseDown(mouseXY(ev))
|
||||
|
||||
// drawOnCurrentRender(lineGroup)
|
||||
}
|
||||
|
||||
const handleUndo = (keyboardEvent: KeyboardEvent | SyntheticEvent) => {
|
||||
@ -1092,7 +573,7 @@ export default function Editor(props: EditorProps) {
|
||||
let maskFileName = file.name.replace(/(\.[\w\d_-]+)$/i, "_mask$1")
|
||||
maskFileName = maskFileName.replace(/\.[^/.]+$/, ".jpg")
|
||||
|
||||
drawLinesOnMask(lineGroups)
|
||||
const maskCanvas = generateMask(imageWidth, imageHeight, lineGroups)
|
||||
// Create a link
|
||||
const aDownloadLink = document.createElement("a")
|
||||
// Add the name of the file to the link
|
||||
@ -1147,11 +628,11 @@ export default function Editor(props: EditorProps) {
|
||||
useHotkeys(
|
||||
"shift+r",
|
||||
() => {
|
||||
if (enableManualInpainting && hadDrawSomething()) {
|
||||
if (runMannually && hadDrawSomething()) {
|
||||
runInpainting()
|
||||
}
|
||||
},
|
||||
[enableManualInpainting, runInpainting, hadDrawSomething]
|
||||
[runMannually, runInpainting, hadDrawSomething]
|
||||
)
|
||||
|
||||
useHotkeys(
|
||||
@ -1192,13 +673,11 @@ export default function Editor(props: EditorProps) {
|
||||
(ev) => {
|
||||
ev?.preventDefault()
|
||||
ev?.stopPropagation()
|
||||
setIsChangingBrushSizeByMouse(true)
|
||||
setChangeBrushSizeByMouseInit({ x, y, brushSize })
|
||||
// TODO: mouse scroll increase/decrease brush size
|
||||
},
|
||||
(ev) => {
|
||||
ev?.preventDefault()
|
||||
ev?.stopPropagation()
|
||||
setIsChangingBrushSizeByMouse(false)
|
||||
}
|
||||
)
|
||||
|
||||
@ -1359,20 +838,24 @@ export default function Editor(props: EditorProps) {
|
||||
show={settings.showCroper}
|
||||
/>
|
||||
|
||||
{/* {interactiveSegState.isInteractiveSeg ? <InteractiveSeg /> : <></>} */}
|
||||
{interactiveSegState.isInteractiveSeg ? (
|
||||
<InteractiveSegPoints />
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</TransformComponent>
|
||||
</TransformWrapper>
|
||||
)
|
||||
}
|
||||
|
||||
const onInteractiveAccept = () => {
|
||||
setInteractiveSegMask(tmpInteractiveSegMask)
|
||||
setTmpInteractiveSegMask(null)
|
||||
// const onInteractiveAccept = () => {
|
||||
// setInteractiveSegMask(tmpInteractiveSegMask)
|
||||
// setTmpInteractiveSegMask(null)
|
||||
|
||||
if (!enableManualInpainting && tmpInteractiveSegMask) {
|
||||
runInpainting(false, undefined, tmpInteractiveSegMask)
|
||||
}
|
||||
}
|
||||
// if (!enableManualInpainting && tmpInteractiveSegMask) {
|
||||
// runInpainting(false, undefined, tmpInteractiveSegMask)
|
||||
// }
|
||||
// }
|
||||
|
||||
return (
|
||||
<div
|
||||
@ -1388,16 +871,11 @@ export default function Editor(props: EditorProps) {
|
||||
!isPanning &&
|
||||
(interactiveSegState.isInteractiveSeg
|
||||
? renderInteractiveSegCursor()
|
||||
: renderBrush(
|
||||
getBrushStyle(
|
||||
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.x : x,
|
||||
isChangingBrushSizeByMouse ? changeBrushSizeByMouseInit.y : y
|
||||
)
|
||||
))}
|
||||
: renderBrush(getBrushStyle(x, y)))}
|
||||
|
||||
{showRefBrush && renderBrush(getBrushStyle(windowCenterX, windowCenterY))}
|
||||
|
||||
<div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/50">
|
||||
<div className="fixed flex bottom-5 border px-4 py-2 rounded-[3rem] gap-8 items-center justify-center backdrop-filter backdrop-blur-md bg-background/70">
|
||||
<Slider
|
||||
className="w-48"
|
||||
defaultValue={[50]}
|
||||
@ -1464,16 +942,17 @@ export default function Editor(props: EditorProps) {
|
||||
<Download />
|
||||
</IconButton>
|
||||
|
||||
{settings.enableManualInpainting ? (
|
||||
{settings.enableManualInpainting &&
|
||||
settings.model.model_type === "inpaint" ? (
|
||||
<IconButton
|
||||
tooltip="Run Inpainting"
|
||||
disabled={
|
||||
isProcessing ||
|
||||
(!hadDrawSomething() && interactiveSegMask === null)
|
||||
(!hadDrawSomething() &&
|
||||
interactiveSegState.interactiveSegMask === null)
|
||||
}
|
||||
onClick={() => {
|
||||
// ensured by disabled
|
||||
runInpainting(false, undefined, interactiveSegMask)
|
||||
runInpainting()
|
||||
}}
|
||||
>
|
||||
<Eraser />
|
||||
|
@ -69,7 +69,7 @@ const Header = () => {
|
||||
)
|
||||
|
||||
return (
|
||||
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 backdrop-filter backdrop-blur-md border-b">
|
||||
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
|
||||
<div className="flex items-center gap-1">
|
||||
{enableFileManager ? (
|
||||
<FileManager
|
||||
|
@ -39,19 +39,20 @@ const InteractiveSegReplaceModal = (props: InteractiveSegReplaceModal) => {
|
||||
}
|
||||
|
||||
const InteractiveSegConfirmActions = () => {
|
||||
const [interactiveSegState, resetInteractiveSegState] = useStore((state) => [
|
||||
const [
|
||||
interactiveSegState,
|
||||
resetInteractiveSegState,
|
||||
handleInteractiveSegAccept,
|
||||
] = useStore((state) => [
|
||||
state.interactiveSegState,
|
||||
state.resetInteractiveSegState,
|
||||
state.handleInteractiveSegAccept,
|
||||
])
|
||||
|
||||
if (!interactiveSegState.isInteractiveSeg) {
|
||||
return null
|
||||
}
|
||||
|
||||
const onAcceptClick = () => {
|
||||
resetInteractiveSegState()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="z-10 absolute top-[68px] rounded-xl border-solid border p-[8px] left-1/2 translate-x-[-50%] flex justify-center items-center gap-[8px] bg-background">
|
||||
<Button
|
||||
@ -66,7 +67,7 @@ const InteractiveSegConfirmActions = () => {
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
onAcceptClick()
|
||||
handleInteractiveSegAccept()
|
||||
}}
|
||||
>
|
||||
Accept
|
||||
@ -84,11 +85,11 @@ interface ItemProps {
|
||||
const Item = (props: ItemProps) => {
|
||||
const { x, y, positive } = props
|
||||
const name = positive
|
||||
? "bg-[rgba(21,_215,_121,_0.936)] outline-[6px_solid_rgba(98,_255,_179,_0.31)]"
|
||||
: "bg-[rgba(237,_49,_55,_0.942)] outline-[6px_solid_rgba(255,_89,_95,_0.31)]"
|
||||
? "bg-[rgba(21,_215,_121,_0.936)] outline-[rgba(98,255,179,0.31)]"
|
||||
: "bg-[rgba(237,_49,_55,_0.942)] outline-[rgba(255,89,95,0.31)]"
|
||||
return (
|
||||
<div
|
||||
className={`absolute h-[8px] w-[8px] rounded-[50%] ${name}`}
|
||||
className={`absolute h-[10px] w-[10px] rounded-[50%] ${name} outline-8 outline`}
|
||||
style={{
|
||||
left: x,
|
||||
top: y,
|
||||
|
@ -15,9 +15,7 @@ import {
|
||||
Slice,
|
||||
Smile,
|
||||
} from "lucide-react"
|
||||
import { MixIcon } from "@radix-ui/react-icons"
|
||||
import { useStore } from "@/lib/states"
|
||||
import { InteractiveSeg } from "./InteractiveSeg"
|
||||
|
||||
export enum PluginName {
|
||||
RemoveBG = "RemoveBG",
|
||||
@ -56,44 +54,49 @@ const pluginMap = {
|
||||
}
|
||||
|
||||
const Plugins = () => {
|
||||
const [plugins, updateInteractiveSegState] = useStore((state) => [
|
||||
state.serverConfig.plugins,
|
||||
state.updateInteractiveSegState,
|
||||
])
|
||||
const [file, plugins, updateInteractiveSegState, runRenderablePlugin] =
|
||||
useStore((state) => [
|
||||
state.file,
|
||||
state.serverConfig.plugins,
|
||||
state.updateInteractiveSegState,
|
||||
state.runRenderablePlugin,
|
||||
])
|
||||
const disabled = !file
|
||||
|
||||
if (plugins.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const onPluginClick = (pluginName: string) => {
|
||||
// if (!disabled) {
|
||||
// emitter.emit(pluginName)
|
||||
// }
|
||||
if (pluginName === PluginName.InteractiveSeg) {
|
||||
updateInteractiveSegState({ isInteractiveSeg: true })
|
||||
} else {
|
||||
runRenderablePlugin(pluginName)
|
||||
}
|
||||
}
|
||||
|
||||
const onRealESRGANClick = (upscale: number) => {
|
||||
// if (!disabled) {
|
||||
// emitter.emit(PluginName.RealESRGAN, { upscale })
|
||||
// }
|
||||
}
|
||||
|
||||
const renderRealESRGANPlugin = () => {
|
||||
return (
|
||||
<DropdownMenuSub key="RealESRGAN">
|
||||
<DropdownMenuSubTrigger>
|
||||
<DropdownMenuSubTrigger disabled={disabled}>
|
||||
<div className="flex gap-2 items-center">
|
||||
<Fullscreen />
|
||||
RealESRGAN
|
||||
</div>
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuItem onClick={() => onRealESRGANClick(2)}>
|
||||
<DropdownMenuItem
|
||||
onClick={() =>
|
||||
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 2 })
|
||||
}
|
||||
>
|
||||
upscale 2x
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={() => onRealESRGANClick(4)}>
|
||||
<DropdownMenuItem
|
||||
onClick={() =>
|
||||
runRenderablePlugin(PluginName.RealESRGAN, { upscale: 4 })
|
||||
}
|
||||
>
|
||||
upscale 4x
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuSubContent>
|
||||
@ -108,7 +111,11 @@ const Plugins = () => {
|
||||
return renderRealESRGANPlugin()
|
||||
}
|
||||
return (
|
||||
<DropdownMenuItem key={plugin} onClick={() => onPluginClick(plugin)}>
|
||||
<DropdownMenuItem
|
||||
key={plugin}
|
||||
onClick={() => onPluginClick(plugin)}
|
||||
disabled={disabled}
|
||||
>
|
||||
<div className="flex gap-2 items-center">
|
||||
<IconClass className="p-1" />
|
||||
{showName}
|
||||
@ -121,7 +128,7 @@ const Plugins = () => {
|
||||
return (
|
||||
<DropdownMenu modal={false}>
|
||||
<DropdownMenuTrigger
|
||||
className="border rounded-lg z-10 bg-background"
|
||||
className="border rounded-lg z-10 bg-background outline-none"
|
||||
tabIndex={-1}
|
||||
>
|
||||
<Button variant="ghost" size="icon" asChild className="p-1.5">
|
||||
|
@ -1,19 +1,17 @@
|
||||
import React, { FormEvent } from "react"
|
||||
import emitter, {
|
||||
DREAM_BUTTON_MOUSE_ENTER,
|
||||
DREAM_BUTTON_MOUSE_LEAVE,
|
||||
EVENT_PROMPT,
|
||||
} from "@/lib/event"
|
||||
import { Button } from "./ui/button"
|
||||
import { Input } from "./ui/input"
|
||||
import { useStore } from "@/lib/states"
|
||||
|
||||
const PromptInput = () => {
|
||||
const [isInpainting, prompt, updateSettings] = useStore((state) => [
|
||||
state.isInpainting,
|
||||
state.settings.prompt,
|
||||
state.updateSettings,
|
||||
])
|
||||
const [isProcessing, prompt, updateSettings, runInpainting] = useStore(
|
||||
(state) => [
|
||||
state.getIsProcessing(),
|
||||
state.settings.prompt,
|
||||
state.updateSettings,
|
||||
state.runInpainting,
|
||||
]
|
||||
)
|
||||
|
||||
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
|
||||
evt.preventDefault()
|
||||
@ -22,25 +20,25 @@ const PromptInput = () => {
|
||||
updateSettings({ prompt: target.value })
|
||||
}
|
||||
|
||||
const handleRepaintClick = () => {
|
||||
if (prompt.length !== 0 && isInpainting) {
|
||||
emitter.emit(EVENT_PROMPT)
|
||||
const handleRepaintClick = async () => {
|
||||
if (prompt.length !== 0 && !isProcessing) {
|
||||
await runInpainting()
|
||||
}
|
||||
}
|
||||
|
||||
const onKeyUp = (e: React.KeyboardEvent) => {
|
||||
if (e.key === "Enter" && !isInpainting) {
|
||||
if (e.key === "Enter" && !isProcessing) {
|
||||
handleRepaintClick()
|
||||
}
|
||||
}
|
||||
|
||||
const onMouseEnter = () => {
|
||||
emitter.emit(DREAM_BUTTON_MOUSE_ENTER)
|
||||
}
|
||||
// const onMouseEnter = () => {
|
||||
// emitter.emit(DREAM_BUTTON_MOUSE_ENTER)
|
||||
// }
|
||||
|
||||
const onMouseLeave = () => {
|
||||
emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
|
||||
}
|
||||
// const onMouseLeave = () => {
|
||||
// emitter.emit(DREAM_BUTTON_MOUSE_LEAVE)
|
||||
// }
|
||||
|
||||
return (
|
||||
<div className="flex gap-4 items-center">
|
||||
@ -54,9 +52,9 @@ const PromptInput = () => {
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleRepaintClick}
|
||||
disabled={prompt.length === 0 || isInpainting}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseLeave={onMouseLeave}
|
||||
disabled={prompt.length === 0 || isProcessing}
|
||||
// onMouseEnter={onMouseEnter}
|
||||
// onMouseLeave={onMouseLeave}
|
||||
>
|
||||
Dream
|
||||
</Button>
|
||||
|
@ -29,7 +29,7 @@ const DropdownMenuSubTrigger = React.forwardRef<
|
||||
<DropdownMenuPrimitive.SubTrigger
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent",
|
||||
"flex cursor-default select-none items-center rounded-sm px-2 py-1.5 text-sm outline-none focus:bg-accent data-[state=open]:bg-accent data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
|
||||
inset && "pl-8",
|
||||
className
|
||||
)}
|
||||
@ -72,6 +72,7 @@ const DropdownMenuContent = React.forwardRef<
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
onCloseAutoFocus={(e) => e.preventDefault()}
|
||||
/>
|
||||
</DropdownMenuPrimitive.Portal>
|
||||
))
|
||||
|
@ -15,18 +15,13 @@ export default async function inpaint(
|
||||
imageFile: File,
|
||||
settings: Settings,
|
||||
croperRect: Rect,
|
||||
maskBase64?: string,
|
||||
customMask?: File,
|
||||
paintByExampleImage?: File
|
||||
mask: File | Blob,
|
||||
paintByExampleImage: File | null = null
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
const fd = new FormData()
|
||||
fd.append("image", imageFile)
|
||||
if (maskBase64 !== undefined) {
|
||||
fd.append("mask", dataURItoBlob(maskBase64))
|
||||
} else if (customMask !== undefined) {
|
||||
fd.append("mask", customMask)
|
||||
}
|
||||
fd.append("mask", mask)
|
||||
|
||||
fd.append("ldmSteps", settings.ldmSteps.toString())
|
||||
fd.append("ldmSampler", settings.ldmSampler.toString())
|
||||
@ -42,8 +37,7 @@ export default async function inpaint(
|
||||
fd.append("croperY", croperRect.y.toString())
|
||||
fd.append("croperHeight", croperRect.height.toString())
|
||||
fd.append("croperWidth", croperRect.width.toString())
|
||||
// fd.append("useCroper", settings.showCroper ? "true" : "false")
|
||||
fd.append("useCroper", "false")
|
||||
fd.append("useCroper", settings.showCroper ? "true" : "false")
|
||||
|
||||
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
|
||||
fd.append("sdStrength", settings.sdStrength.toString())
|
||||
@ -147,7 +141,6 @@ export async function runPlugin(
|
||||
name: string,
|
||||
imageFile: File,
|
||||
upscale?: number,
|
||||
maskFile?: File | null,
|
||||
clicks?: number[][]
|
||||
) {
|
||||
const fd = new FormData()
|
||||
@ -159,9 +152,6 @@ export async function runPlugin(
|
||||
if (clicks) {
|
||||
fd.append("clicks", JSON.stringify(clicks))
|
||||
}
|
||||
if (maskFile) {
|
||||
fd.append("mask", maskFile)
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { create } from "zustand"
|
||||
import { persist } from "zustand/middleware"
|
||||
import { shallow } from "zustand/shallow"
|
||||
import { immer } from "zustand/middleware/immer"
|
||||
import { castDraft } from "immer"
|
||||
import { nanoid } from "nanoid"
|
||||
import { createWithEqualityFn } from "zustand/traditional"
|
||||
import {
|
||||
CV2Flag,
|
||||
@ -10,6 +11,7 @@ import {
|
||||
Line,
|
||||
LineGroup,
|
||||
ModelInfo,
|
||||
PluginParams,
|
||||
Point,
|
||||
SDSampler,
|
||||
Size,
|
||||
@ -17,6 +19,9 @@ import {
|
||||
SortOrder,
|
||||
} from "./types"
|
||||
import { DEFAULT_BRUSH_SIZE, MODEL_TYPE_INPAINT } from "./const"
|
||||
import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils"
|
||||
import inpaint, { runPlugin } from "./api"
|
||||
import { toast, useToast } from "@/components/ui/use-toast"
|
||||
|
||||
type FileManagerState = {
|
||||
sortBy: SortBy
|
||||
@ -95,7 +100,9 @@ type ServerConfig = {
|
||||
|
||||
type InteractiveSegState = {
|
||||
isInteractiveSeg: boolean
|
||||
isInteractiveSegRunning: boolean
|
||||
interactiveSegMask: HTMLImageElement | null
|
||||
tmpInteractiveSegMask: HTMLImageElement | null
|
||||
prevInteractiveSegMask: HTMLImageElement | null
|
||||
clicks: number[][]
|
||||
}
|
||||
|
||||
@ -103,9 +110,11 @@ type EditorState = {
|
||||
baseBrushSize: number
|
||||
brushSizeScale: number
|
||||
renders: HTMLImageElement[]
|
||||
paintByExampleImage: File | null
|
||||
lineGroups: LineGroup[]
|
||||
lastLineGroup: LineGroup
|
||||
curLineGroup: LineGroup
|
||||
extraMasks: HTMLImageElement[]
|
||||
// redo 相关
|
||||
redoRenders: HTMLImageElement[]
|
||||
redoCurLines: Line[]
|
||||
@ -113,6 +122,8 @@ type EditorState = {
|
||||
}
|
||||
|
||||
type AppState = {
|
||||
idForUpdateView: string
|
||||
|
||||
file: File | null
|
||||
customMask: File | null
|
||||
imageHeight: number
|
||||
@ -136,6 +147,7 @@ type AppAction = {
|
||||
setCustomFile: (file: File) => void
|
||||
setIsInpainting: (newValue: boolean) => void
|
||||
setIsPluginRunning: (newValue: boolean) => void
|
||||
getIsProcessing: () => boolean
|
||||
setBaseBrushSize: (newValue: number) => void
|
||||
getBrushSize: () => number
|
||||
setImageSize: (width: number, height: number) => void
|
||||
@ -151,10 +163,18 @@ type AppAction = {
|
||||
updateFileManagerState: (newState: Partial<FileManagerState>) => void
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||
resetInteractiveSegState: () => void
|
||||
handleInteractiveSegAccept: () => void
|
||||
showPromptInput: () => boolean
|
||||
showSidePanel: () => boolean
|
||||
|
||||
runInpainting: () => Promise<void>
|
||||
runRenderablePlugin: (
|
||||
pluginName: string,
|
||||
params?: PluginParams
|
||||
) => Promise<void>
|
||||
|
||||
// EditorState
|
||||
getCurrentTargetFile: () => Promise<File>
|
||||
updateEditorState: (newState: Partial<EditorState>) => void
|
||||
runMannually: () => boolean
|
||||
handleCanvasMouseDown: (point: Point) => void
|
||||
@ -168,12 +188,15 @@ type AppAction = {
|
||||
}
|
||||
|
||||
const defaultValues: AppState = {
|
||||
idForUpdateView: nanoid(),
|
||||
|
||||
file: null,
|
||||
customMask: null,
|
||||
imageHeight: 0,
|
||||
imageWidth: 0,
|
||||
isInpainting: false,
|
||||
isPluginRunning: false,
|
||||
|
||||
windowSize: {
|
||||
height: 600,
|
||||
width: 800,
|
||||
@ -182,6 +205,8 @@ const defaultValues: AppState = {
|
||||
baseBrushSize: DEFAULT_BRUSH_SIZE,
|
||||
brushSizeScale: 1,
|
||||
renders: [],
|
||||
paintByExampleImage: null,
|
||||
extraMasks: [],
|
||||
lineGroups: [],
|
||||
lastLineGroup: [],
|
||||
curLineGroup: [],
|
||||
@ -192,7 +217,9 @@ const defaultValues: AppState = {
|
||||
|
||||
interactiveSegState: {
|
||||
isInteractiveSeg: false,
|
||||
isInteractiveSegRunning: false,
|
||||
interactiveSegMask: null,
|
||||
tmpInteractiveSegMask: null,
|
||||
prevInteractiveSegMask: null,
|
||||
clicks: [],
|
||||
},
|
||||
|
||||
@ -267,16 +294,208 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
immer((set, get) => ({
|
||||
...defaultValues,
|
||||
|
||||
getCurrentTargetFile: async (): Promise<File> => {
|
||||
const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数
|
||||
const renders = get().editorState.renders
|
||||
|
||||
let targetFile = file
|
||||
if (renders.length > 0) {
|
||||
const lastRender = renders[renders.length - 1]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
return targetFile
|
||||
},
|
||||
|
||||
runInpainting: async () => {
|
||||
const { file, imageWidth, imageHeight, settings, cropperState } = get()
|
||||
|
||||
if (file === null) {
|
||||
return
|
||||
}
|
||||
const {
|
||||
lastLineGroup,
|
||||
curLineGroup,
|
||||
lineGroups,
|
||||
renders,
|
||||
paintByExampleImage,
|
||||
} = get().editorState
|
||||
|
||||
const { interactiveSegMask, prevInteractiveSegMask } =
|
||||
get().interactiveSegState
|
||||
|
||||
const useLastLineGroup =
|
||||
curLineGroup.length === 0 && interactiveSegMask === null
|
||||
|
||||
const maskImage = useLastLineGroup
|
||||
? prevInteractiveSegMask
|
||||
: interactiveSegMask
|
||||
|
||||
// useLastLineGroup 的影响
|
||||
// 1. 使用上一次的 mask
|
||||
// 2. 结果替换当前 render
|
||||
let maskLineGroup: LineGroup = []
|
||||
if (useLastLineGroup === true) {
|
||||
if (lastLineGroup.length === 0 && maskImage === null) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture",
|
||||
})
|
||||
return
|
||||
}
|
||||
maskLineGroup = lastLineGroup
|
||||
} else {
|
||||
if (curLineGroup.length === 0 && maskImage === null) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: "Please draw mask on picture",
|
||||
})
|
||||
return
|
||||
}
|
||||
maskLineGroup = curLineGroup
|
||||
}
|
||||
|
||||
const newLineGroups = [...lineGroups, maskLineGroup]
|
||||
|
||||
set((state) => {
|
||||
state.isInpainting = true
|
||||
})
|
||||
|
||||
let targetFile = file
|
||||
if (useLastLineGroup === true) {
|
||||
// renders.length == 1 还是用原来的
|
||||
if (renders.length > 1) {
|
||||
const lastRender = renders[renders.length - 2]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
} else if (renders.length > 0) {
|
||||
const lastRender = renders[renders.length - 1]
|
||||
targetFile = await srcToFile(
|
||||
lastRender.currentSrc,
|
||||
file.name,
|
||||
file.type
|
||||
)
|
||||
}
|
||||
|
||||
const maskCanvas = generateMask(
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
[maskLineGroup],
|
||||
maskImage ? [maskImage] : []
|
||||
)
|
||||
|
||||
try {
|
||||
const res = await inpaint(
|
||||
targetFile,
|
||||
settings,
|
||||
cropperState,
|
||||
dataURItoBlob(maskCanvas.toDataURL()),
|
||||
paintByExampleImage
|
||||
)
|
||||
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
|
||||
const { blob, seed } = res
|
||||
if (seed) {
|
||||
set((state) => (state.settings.seed = parseInt(seed, 10)))
|
||||
}
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
if (useLastLineGroup === true) {
|
||||
const prevRenders = renders.slice(0, -1)
|
||||
const newRenders = [...prevRenders, newRender]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
lastLineGroup: curLineGroup,
|
||||
curLineGroup: [],
|
||||
})
|
||||
} else {
|
||||
const newRenders = [...renders, newRender]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
lastLineGroup: curLineGroup,
|
||||
curLineGroup: [],
|
||||
})
|
||||
}
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
}
|
||||
|
||||
get().resetRedoState()
|
||||
set((state) => {
|
||||
state.isInpainting = false
|
||||
})
|
||||
|
||||
const newInteractiveSegState = {
|
||||
...defaultValues.interactiveSegState,
|
||||
prevInteractiveSegMask: useLastLineGroup ? null : maskImage,
|
||||
}
|
||||
|
||||
set((state) => {
|
||||
state.interactiveSegState = castDraft(newInteractiveSegState)
|
||||
})
|
||||
},
|
||||
|
||||
runRenderablePlugin: async (
|
||||
pluginName: string,
|
||||
params: PluginParams = { upscale: 1 }
|
||||
) => {
|
||||
const { renders, lineGroups } = get().editorState
|
||||
set((state) => {
|
||||
state.isInpainting = true
|
||||
})
|
||||
|
||||
try {
|
||||
const start = new Date()
|
||||
const targetFile = await get().getCurrentTargetFile()
|
||||
const res = await runPlugin(pluginName, targetFile, params.upscale)
|
||||
if (!res) {
|
||||
throw new Error("Something went wrong on server side.")
|
||||
}
|
||||
const { blob } = res
|
||||
const newRender = new Image()
|
||||
await loadImage(newRender, blob)
|
||||
get().setImageSize(newRender.height, newRender.width)
|
||||
const newRenders = [...renders, newRender]
|
||||
const newLineGroups = [...lineGroups, []]
|
||||
get().updateEditorState({
|
||||
renders: newRenders,
|
||||
lineGroups: newLineGroups,
|
||||
})
|
||||
const end = new Date()
|
||||
const time = end.getTime() - start.getTime()
|
||||
toast({
|
||||
description: `Run ${pluginName} successfully in ${time / 1000}s`,
|
||||
})
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
}
|
||||
set((state) => {
|
||||
state.isInpainting = false
|
||||
})
|
||||
},
|
||||
|
||||
// Edirot State //
|
||||
updateEditorState: (newState: Partial<EditorState>) => {
|
||||
set((state) => {
|
||||
return {
|
||||
...state,
|
||||
editorState: {
|
||||
...state.editorState,
|
||||
...newState,
|
||||
},
|
||||
}
|
||||
state.editorState = castDraft({ ...state.editorState, ...newState })
|
||||
})
|
||||
},
|
||||
|
||||
@ -313,6 +532,10 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
)
|
||||
},
|
||||
|
||||
getIsProcessing: (): boolean => {
|
||||
return get().isInpainting || get().isPluginRunning
|
||||
},
|
||||
|
||||
// undo/redo
|
||||
|
||||
undoDisabled: (): boolean => {
|
||||
@ -468,15 +691,30 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => {
|
||||
set((state) => {
|
||||
state.interactiveSegState = {
|
||||
...state.interactiveSegState,
|
||||
...newState,
|
||||
return {
|
||||
...state,
|
||||
interactiveSegState: {
|
||||
...state.interactiveSegState,
|
||||
...newState,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
resetInteractiveSegState: () => {
|
||||
get().updateInteractiveSegState(defaultValues.interactiveSegState)
|
||||
},
|
||||
|
||||
handleInteractiveSegAccept: () => {
|
||||
set((state) => {
|
||||
state.interactiveSegState = defaultValues.interactiveSegState
|
||||
return {
|
||||
...state,
|
||||
interactiveSegState: {
|
||||
...defaultValues.interactiveSegState,
|
||||
interactiveSegMask:
|
||||
state.interactiveSegState.tmpInteractiveSegMask,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
@ -492,8 +730,12 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
|
||||
setFile: (file: File) =>
|
||||
set((state) => {
|
||||
// TODO: 清空各种状态
|
||||
state.file = file
|
||||
state.interactiveSegState = castDraft(
|
||||
defaultValues.interactiveSegState
|
||||
)
|
||||
state.editorState = castDraft(defaultValues.editorState)
|
||||
state.cropperState = defaultValues.cropperState
|
||||
}),
|
||||
|
||||
setCustomFile: (file: File) =>
|
||||
|
@ -25,6 +25,10 @@ export enum PluginName {
|
||||
InteractiveSeg = "InteractiveSeg",
|
||||
}
|
||||
|
||||
export interface PluginParams {
|
||||
upscale: number
|
||||
}
|
||||
|
||||
export enum SortBy {
|
||||
NAME = "name",
|
||||
CTIME = "ctime",
|
||||
|
@ -159,3 +159,28 @@ export function drawLines(
|
||||
ctx.stroke()
|
||||
})
|
||||
}
|
||||
|
||||
export const generateMask = (
|
||||
imageWidth: number,
|
||||
imageHeight: number,
|
||||
lineGroups: LineGroup[],
|
||||
maskImages: HTMLImageElement[] = []
|
||||
): HTMLCanvasElement => {
|
||||
const maskCanvas = document.createElement("canvas")
|
||||
maskCanvas.width = imageWidth
|
||||
maskCanvas.height = imageHeight
|
||||
const ctx = maskCanvas.getContext("2d")
|
||||
if (!ctx) {
|
||||
throw new Error("could not retrieve mask canvas")
|
||||
}
|
||||
|
||||
maskImages.forEach((maskImage) => {
|
||||
ctx.drawImage(maskImage, 0, 0, imageWidth, imageHeight)
|
||||
})
|
||||
|
||||
lineGroups.forEach((lineGroup) => {
|
||||
drawLines(ctx, lineGroup, "white")
|
||||
})
|
||||
|
||||
return maskCanvas
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user