From f27fc51e34e0c6d496ca42098d062cbe34d652d2 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 17 Dec 2023 22:15:48 +0800 Subject: [PATCH] update --- web_app/src/components/Editor.tsx | 16 +- web_app/src/components/Expender.tsx | 393 +++++++++++++++++++++++++ web_app/src/components/PromptInput.tsx | 39 ++- web_app/src/components/Settings.tsx | 18 +- web_app/src/components/SidePanel.tsx | 30 +- web_app/src/lib/api.ts | 6 +- web_app/src/lib/states.ts | 101 ++++++- web_app/src/lib/types.ts | 2 + web_app/src/lib/utils.ts | 23 +- 9 files changed, 591 insertions(+), 37 deletions(-) create mode 100644 web_app/src/components/Expender.tsx diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index c98d089..3c26d04 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -28,6 +28,7 @@ import { useStore } from "@/lib/states" import Cropper from "./Cropper" import { InteractiveSegPoints } from "./InteractiveSeg" import useHotKey from "@/hooks/useHotkey" +import Extender from "./Expender" const TOOLBAR_HEIGHT = 200 const MIN_BRUSH_SIZE = 10 @@ -170,11 +171,7 @@ export default function Editor(props: EditorProps) { imageHeight ) } - // if (dreamButtonHoverSegMask) { - // context.drawImage(dreamButtonHoverSegMask, 0, 0, imageWidth, imageHeight) - // } drawLines(context, curLineGroup) - // drawLines(context, dreamButtonHoverLineGroup) }, [ renders, extraMasks, @@ -788,7 +785,16 @@ export default function Editor(props: EditorProps) { minHeight={Math.min(256, imageHeight)} minWidth={Math.min(256, imageWidth)} scale={getCurScale()} - show={settings.showCroper} + show={settings.showCropper} + /> + + {interactiveSegState.isInteractiveSeg ? ( diff --git a/web_app/src/components/Expender.tsx b/web_app/src/components/Expender.tsx new file mode 100644 index 0000000..ea58c93 --- /dev/null +++ b/web_app/src/components/Expender.tsx @@ -0,0 +1,393 @@ +import { useStore } from "@/lib/states" +import { cn } from "@/lib/utils" +import React, { useEffect, useState } from "react" +import { twMerge } from "tailwind-merge" + +const DOC_MOVE_OPTS = { capture: true, passive: false } + +const DRAG_HANDLE_BORDER = 2 + +interface EVData { + initX: number + initY: number + initHeight: number + initWidth: number + startResizeX: number + startResizeY: number + ord: string // top/right/bottom/left +} + +interface Props { + maxHeight: number + maxWidth: number + scale: number + minHeight: number + minWidth: number + show: boolean +} + +const clamp = ( + newPos: number, + newLength: number, + oldPos: number, + oldLength: number, + minLength: number, + maxLength: number +) => { + return [newPos, newLength] + if (newPos !== oldPos && newLength === oldLength) { + if (newPos < 0) { + return [0, oldLength] + } + if (newPos + newLength > maxLength) { + return [maxLength - oldLength, oldLength] + } + } else { + if (newLength < minLength) { + if (newPos === oldPos) { + return [newPos, minLength] + } + return [newPos + newLength - minLength, minLength] + } + if (newPos < 0) { + return [0, newPos + newLength] + } + if (newPos + newLength > maxLength) { + return [newPos, maxLength - newPos] + } + } + + return [newPos, newLength] +} + +const Extender = (props: Props) => { + const { minHeight, minWidth, maxHeight, maxWidth, scale, show } = props + + const [ + imageWidth, + imageHeight, + isInpainting, + { x, y, width, height }, + setX, + setY, + setWidth, + setHeight, + ] = useStore((state) => [ + state.imageWidth, + state.imageHeight, + state.isInpainting, + state.extenderState, + state.setExtenderX, + state.setExtenderY, + state.setExtenderWidth, + state.setExtenderHeight, + ]) + + const [isResizing, setIsResizing] = useState(false) + const [isMoving, setIsMoving] = useState(false) + + useEffect(() => { + setX(Math.round((maxWidth - 512) / 2)) + setY(Math.round((maxHeight - 512) / 2)) + }, [maxHeight, maxWidth, imageWidth, imageHeight]) + + const [evData, setEVData] = useState({ + initX: 0, + initY: 0, + initHeight: 0, + initWidth: 0, + startResizeX: 0, + startResizeY: 0, + ord: "top", + }) + + const onDragFocus = () => { + // console.log("focus") + } + + const clampLeftRight = (newX: number, newWidth: number) => { + return clamp(newX, newWidth, x, width, minWidth, maxWidth) + } + + const clampTopBottom = (newY: number, newHeight: number) => { + return clamp(newY, newHeight, y, height, minHeight, maxHeight) + } + + const onPointerMove = (e: PointerEvent) => { + if (isInpainting) { + return + } + const curX = e.clientX + const curY = e.clientY + + const offsetY = Math.round((curY - evData.startResizeY) / scale) + const offsetX = Math.round((curX - evData.startResizeX) / scale) + + const moveTop = () => { + const newHeight = evData.initHeight - offsetY + const newY = evData.initY + offsetY + const [clampedY, clampedHeight] = clampTopBottom(newY, newHeight) + setHeight(clampedHeight) + setY(clampedY) + } + + const moveBottom = () => { + const newHeight = evData.initHeight + offsetY + const [clampedY, clampedHeight] = clampTopBottom(evData.initY, newHeight) + setHeight(clampedHeight) + setY(clampedY) + } + + const moveLeft = () => { + const newWidth = evData.initWidth - offsetX + const newX = evData.initX + offsetX + const [clampedX, clampedWidth] = clampLeftRight(newX, newWidth) + setWidth(clampedWidth) + setX(clampedX) + } + + const moveRight = () => { + const newWidth = evData.initWidth + offsetX + const [clampedX, clampedWidth] = clampLeftRight(evData.initX, newWidth) + setWidth(clampedWidth) + setX(clampedX) + } + + if (isResizing) { + switch (evData.ord) { + case "topleft": { + moveTop() + moveLeft() + break + } + case "topright": { + moveTop() + moveRight() + break + } + case "bottomleft": { + moveBottom() + moveLeft() + break + } + case "bottomright": { + moveBottom() + moveRight() + break + } + case "top": { + moveTop() + break + } + case "right": { + moveRight() + break + } + case "bottom": { + moveBottom() + break + } + case "left": { + moveLeft() + break + } + + default: + break + } + } + + if (isMoving) { + const newX = evData.initX + offsetX + const newY = evData.initY + offsetY + const [clampedX, clampedWidth] = clampLeftRight(newX, evData.initWidth) + const [clampedY, clampedHeight] = clampTopBottom(newY, evData.initHeight) + setWidth(clampedWidth) + setHeight(clampedHeight) + setX(clampedX) + setY(clampedY) + } + } + + const onPointerDone = (e: PointerEvent) => { + if (isResizing) { + setIsResizing(false) + } + + if (isMoving) { + setIsMoving(false) + } + } + + useEffect(() => { + if (isResizing || isMoving) { + document.addEventListener("pointermove", onPointerMove, DOC_MOVE_OPTS) + document.addEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.addEventListener("pointercancel", onPointerDone, DOC_MOVE_OPTS) + return () => { + document.removeEventListener( + "pointermove", + onPointerMove, + DOC_MOVE_OPTS + ) + document.removeEventListener("pointerup", onPointerDone, DOC_MOVE_OPTS) + document.removeEventListener( + "pointercancel", + onPointerDone, + DOC_MOVE_OPTS + ) + } + } + }, [isResizing, isMoving, width, height, evData]) + + const onCropPointerDown = (e: React.PointerEvent) => { + const { ord } = (e.target as HTMLElement).dataset + if (ord) { + setIsResizing(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord, + }) + } + } + + const createDragHandle = (cursor: string, side1: string, side2: string) => { + const sideLength = 12 + const halfSideLength = sideLength / 2 + const draghandleCls = `w-[${sideLength}px] h-[${sideLength}px] z-[4] absolute content-[''] block border-2 border-primary borde pointer-events-auto hover:bg-primary` + + let xTrans = "0" + let yTrans = "0" + + let side2Key = side2 + let side2Val = `${-halfSideLength}px` + if (side2 === "") { + side2Val = "50%" + if (side1 === "left" || side1 === "right") { + side2Key = "top" + yTrans = "-50%" + } else { + side2Key = "left" + xTrans = "-50%" + } + } + + return ( +
+ ) + } + + const createCropSelection = () => { + return ( +
+
+
+
+
+ {createDragHandle("cursor-nw-resize", "top", "left")} + {createDragHandle("cursor-ne-resize", "top", "right")} + {createDragHandle("cursor-sw-resize", "bottom", "left")} + {createDragHandle("cursor-se-resize", "bottom", "right")} + {createDragHandle("cursor-ns-resize", "top", "")} + {createDragHandle("cursor-ns-resize", "bottom", "")} + {createDragHandle("cursor-ew-resize", "left", "")} + {createDragHandle("cursor-ew-resize", "right", "")} +
+ ) + } + + const onInfoBarPointerDown = (e: React.PointerEvent) => { + setIsMoving(true) + setEVData({ + initX: x, + initY: y, + initHeight: height, + initWidth: width, + startResizeX: e.clientX, + startResizeY: e.clientY, + ord: "", + }) + } + + const createInfoBar = () => { + return ( +
+ {/* TODO: 移动的时候会显示 brush */} + {width} x {height} +
+ ) + } + + const createBorder = () => { + return ( +
+ ) + } + + if (show === false) { + return null + } + + return ( +
+
+ {createBorder()} + {createInfoBar()} + {createCropSelection()} +
+
+ ) +} + +export default Extender diff --git a/web_app/src/components/PromptInput.tsx b/web_app/src/components/PromptInput.tsx index 6bb539a..09e9571 100644 --- a/web_app/src/components/PromptInput.tsx +++ b/web_app/src/components/PromptInput.tsx @@ -5,14 +5,21 @@ import { useStore } from "@/lib/states" import { useClickAway } from "react-use" const PromptInput = () => { - const [isProcessing, prompt, updateSettings, runInpainting] = useStore( - (state) => [ - state.getIsProcessing(), - state.settings.prompt, - state.updateSettings, - state.runInpainting, - ] - ) + const [ + isProcessing, + prompt, + updateSettings, + runInpainting, + showPrevMask, + hidePrevMask, + ] = useStore((state) => [ + state.getIsProcessing(), + state.settings.prompt, + state.updateSettings, + state.runInpainting, + state.showPrevMask, + state.hidePrevMask, + ]) const ref = useRef(null) useClickAway(ref, () => { @@ -41,13 +48,13 @@ const PromptInput = () => { } } - // const onMouseEnter = () => { - // emitter.emit(DREAM_BUTTON_MOUSE_ENTER) - // } + const onMouseEnter = () => { + showPrevMask() + } - // const onMouseLeave = () => { - // emitter.emit(DREAM_BUTTON_MOUSE_LEAVE) - // } + const onMouseLeave = () => { + hidePrevMask() + } return (
@@ -63,8 +70,8 @@ const PromptInput = () => { size="sm" onClick={handleRepaintClick} disabled={prompt.length === 0 || isProcessing} - // onMouseEnter={onMouseEnter} - // onMouseLeave={onMouseLeave} + onMouseEnter={onMouseEnter} + onMouseLeave={onMouseLeave} > Dream diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index e3836c3..fb1d928 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -74,12 +74,14 @@ export function SettingsDialog() { updateSettings, fileManagerState, updateFileManagerState, + setAppModel, ] = useStore((state) => [ state.updateAppState, state.settings, state.updateSettings, state.fileManagerState, state.updateFileManagerState, + state.setModel, ]) const { toast } = useToast() const [model, setModel] = useState(settings.model) @@ -123,7 +125,7 @@ export function SettingsDialog() { toast({ title: `Switch to ${model.name} success`, }) - updateSettings({ model: model }) + setAppModel(model) } else { throw new Error("Server error") } @@ -142,10 +144,16 @@ export function SettingsDialog() { } } - useHotKey("s", () => { - toggleOpen() - onSubmit(form.getValues()) - }) + useHotKey( + "s", + () => { + toggleOpen() + if (open) { + onSubmit(form.getValues()) + } + }, + [open, form, model] + ) function onOpenChange(value: boolean) { toggleOpen() diff --git a/web_app/src/components/SidePanel.tsx b/web_app/src/components/SidePanel.tsx index 04d3200..7d48989 100644 --- a/web_app/src/components/SidePanel.tsx +++ b/web_app/src/components/SidePanel.tsx @@ -397,6 +397,27 @@ const SidePanel = () => { ) } + const renderExpender = () => { + return ( + <> + + + { + updateSettings({ showExpender: value }) + if (value) { + updateSettings({ showCropper: false }) + } + }} + /> + + + + ) + } + return ( { { - updateSettings({ showCroper: value }) + updateSettings({ showCropper: value }) + if (value) { + updateSettings({ showExpender: false }) + } }} /> + {renderExpender()} +
diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 4389788..56e324d 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -18,7 +18,6 @@ export default async function inpaint( mask: File | Blob, paintByExampleImage: File | null = null ) { - // 1080, 2000, Original const fd = new FormData() fd.append("image", imageFile) fd.append("mask", mask) @@ -37,7 +36,7 @@ export default async function inpaint( fd.append("croperY", croperRect.y.toString()) fd.append("croperHeight", croperRect.height.toString()) fd.append("croperWidth", croperRect.width.toString()) - fd.append("useCroper", settings.showCroper ? "true" : "false") + fd.append("useCroper", settings.showCropper ? "true" : "false") fd.append("sdMaskBlur", settings.sdMaskBlur.toString()) fd.append("sdStrength", settings.sdStrength.toString()) @@ -52,6 +51,9 @@ export default async function inpaint( fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false") fd.append("sdScale", (settings.sdScale / 100).toString()) + fd.append("enableFreeu", settings.enableFreeu.toString()) + fd.append("freeuConfig", JSON.stringify(settings.freeuConfig)) + fd.append("enableLCMLora", settings.enableLCMLora.toString()) fd.append("cv2Radius", settings.cv2Radius.toString()) fd.append("cv2Flag", settings.cv2Flag.toString()) diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 25de24a..a1c09c1 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -18,12 +18,19 @@ import { SortOrder, } from "./types" import { + BRUSH_COLOR, DEFAULT_BRUSH_SIZE, DEFAULT_NEGATIVE_PROMPT, MODEL_TYPE_INPAINT, PAINT_BY_EXAMPLE, } from "./const" -import { dataURItoBlob, generateMask, loadImage, srcToFile } from "./utils" +import { + canvasToImage, + dataURItoBlob, + generateMask, + loadImage, + srcToFile, +} from "./utils" import inpaint, { runPlugin } from "./api" import { toast } from "@/components/ui/use-toast" @@ -48,7 +55,8 @@ export type Settings = { enableDownloadMask: boolean enableManualInpainting: boolean enableUploadMask: boolean - showCroper: boolean + showCropper: boolean + showExpender: boolean // For LDM ldmSteps: number @@ -134,6 +142,7 @@ type AppState = { interactiveSegState: InteractiveSegState fileManagerState: FileManagerState cropperState: CropperState + extenderState: CropperState serverConfig: ServerConfig settings: Settings @@ -155,9 +164,15 @@ type AppAction = { setCropperWidth: (newValue: number) => void setCropperHeight: (newValue: number) => void + setExtenderX: (newValue: number) => void + setExtenderY: (newValue: number) => void + setExtenderWidth: (newValue: number) => void + setExtenderHeight: (newValue: number) => void + setServerConfig: (newValue: ServerConfig) => void setSeed: (newValue: number) => void updateSettings: (newSettings: Partial) => void + setModel: (newModel: ModelInfo) => void updateFileManagerState: (newState: Partial) => void updateInteractiveSegState: (newState: Partial) => void resetInteractiveSegState: () => void @@ -166,6 +181,8 @@ type AppAction = { showSidePanel: () => boolean runInpainting: () => Promise + showPrevMask: () => Promise + hidePrevMask: () => void runRenderablePlugin: ( pluginName: string, params?: PluginParams @@ -226,6 +243,13 @@ const defaultValues: AppState = { width: 512, height: 512, }, + extenderState: { + x: 0, + y: 0, + width: 512, + height: 512, + }, + fileManagerState: { sortBy: SortBy.CTIME, sortOrder: SortOrder.DESCENDING, @@ -248,6 +272,7 @@ const defaultValues: AppState = { model_type: "inpaint", support_controlnet: false, support_strength: false, + support_outpainting: false, controlnets: [], support_freeu: false, support_lcm_lora: false, @@ -255,7 +280,8 @@ const defaultValues: AppState = { need_prompt: false, }, enableControlnet: false, - showCroper: false, + showCropper: false, + showExpender: false, enableDownloadMask: false, enableManualInpainting: false, enableUploadMask: false, @@ -289,6 +315,38 @@ export const useStore = createWithEqualityFn()( immer((set, get) => ({ ...defaultValues, + showPrevMask: async () => { + const { lastLineGroup, curLineGroup } = get().editorState + const { prevInteractiveSegMask, interactiveSegMask } = + get().interactiveSegState + if (curLineGroup.length !== 0 || interactiveSegMask !== null) { + return + } + const { imageWidth, imageHeight } = get() + + const maskCanvas = generateMask( + imageWidth, + imageHeight, + [lastLineGroup], + prevInteractiveSegMask ? [prevInteractiveSegMask] : [], + BRUSH_COLOR + ) + try { + const maskImage = await canvasToImage(maskCanvas) + set((state) => { + state.editorState.extraMasks.push(castDraft(maskImage)) + }) + } catch (e) { + console.error(e) + return + } + }, + hidePrevMask: () => { + set((state) => { + state.editorState.extraMasks = [] + }) + }, + getCurrentTargetFile: async (): Promise => { const file = get().file! // 一定是在 file 加载了以后才可能调用这个函数 const renders = get().editorState.renders @@ -415,7 +473,7 @@ export const useStore = createWithEqualityFn()( get().updateEditorState({ renders: newRenders, lineGroups: newLineGroups, - lastLineGroup: curLineGroup, + lastLineGroup: maskLineGroup, curLineGroup: [], }) } catch (e: any) { @@ -432,7 +490,7 @@ export const useStore = createWithEqualityFn()( const newInteractiveSegState = { ...defaultValues.interactiveSegState, - prevInteractiveSegMask: useLastLineGroup ? null : maskImage, + prevInteractiveSegMask: maskImage, } set((state) => { @@ -675,6 +733,19 @@ export const useStore = createWithEqualityFn()( }) }, + setModel: (newModel: ModelInfo) => { + set((state) => { + state.settings.model = newModel + + if ( + newModel.support_controlnet && + !newModel.controlnets.includes(state.settings.controlnetMethod) + ) { + state.settings.controlnetMethod = newModel.controlnets[0] + } + }) + }, + updateFileManagerState: (newState: Partial) => { set((state) => { state.fileManagerState = { @@ -773,6 +844,26 @@ export const useStore = createWithEqualityFn()( state.cropperState.height = newValue }), + setExtenderX: (newValue: number) => + set((state) => { + state.extenderState.x = newValue + }), + + setExtenderY: (newValue: number) => + set((state) => { + state.extenderState.y = newValue + }), + + setExtenderWidth: (newValue: number) => + set((state) => { + state.extenderState.width = newValue + }), + + setExtenderHeight: (newValue: number) => + set((state) => { + state.extenderState.height = newValue + }), + setSeed: (newValue: number) => set((state) => { state.settings.seed = newValue diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 120c95f..ad3370d 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -9,6 +9,7 @@ export interface ModelInfo { | "diffusers_sdxl_inpaint" | "diffusers_other" support_strength: boolean + support_outpainting: boolean support_controlnet: boolean controlnets: string[] support_freeu: boolean @@ -66,6 +67,7 @@ export enum SDSampler { kEulerA = "k_euler_a", dpmPlusPlus = "dpm++", uni_pc = "uni_pc", + lcm = "lcm", } export interface FreeuConfig { diff --git a/web_app/src/lib/utils.ts b/web_app/src/lib/utils.ts index 53c5d4f..ffc52df 100644 --- a/web_app/src/lib/utils.ts +++ b/web_app/src/lib/utils.ts @@ -53,6 +53,24 @@ export function loadImage(image: HTMLImageElement, src: string) { }) } +export function canvasToImage( + canvas: HTMLCanvasElement +): Promise { + return new Promise((resolve, reject) => { + const image = new Image() + + image.addEventListener("load", () => { + resolve(image) + }) + + image.addEventListener("error", (error) => { + reject(error) + }) + + image.src = canvas.toDataURL() + }) +} + export function srcToFile(src: string, fileName: string, mimeType: string) { return fetch(src) .then(function (res) { @@ -164,7 +182,8 @@ export const generateMask = ( imageWidth: number, imageHeight: number, lineGroups: LineGroup[], - maskImages: HTMLImageElement[] = [] + maskImages: HTMLImageElement[] = [], + lineGroupsColor: string = "white" ): HTMLCanvasElement => { const maskCanvas = document.createElement("canvas") maskCanvas.width = imageWidth @@ -179,7 +198,7 @@ export const generateMask = ( }) lineGroups.forEach((lineGroup) => { - drawLines(ctx, lineGroup, "white") + drawLines(ctx, lineGroup, lineGroupsColor) }) return maskCanvas