From 820ce5e4d0e3700ab0b692d47bb717c09ff58689 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 12 Aug 2024 12:13:37 +0800 Subject: [PATCH] add mask tab --- iopaint/api.py | 7 +- iopaint/cli.py | 13 +++- iopaint/const.py | 6 +- iopaint/file_manager/file_manager.py | 7 +- iopaint/schema.py | 3 +- ...verture-creations-5sI6fQgYIuo_all_mask.png | 0 iopaint/web_config.py | 71 +++++++++++-------- web_app/src/components/FileManager.tsx | 2 + web_app/src/components/Header.tsx | 40 ++++++----- web_app/src/lib/api.ts | 19 ++++- web_app/src/lib/states.ts | 11 +++ 11 files changed, 121 insertions(+), 58 deletions(-) delete mode 100644 iopaint/tests/overture-creations-5sI6fQgYIuo_all_mask.png diff --git a/iopaint/api.py b/iopaint/api.py index 90d6f36..6725479 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -19,7 +19,6 @@ try: except: pass - import uvicorn from PIL import Image from fastapi import APIRouter, FastAPI, Request, UploadFile @@ -127,7 +126,7 @@ def api_middleware(app: FastAPI): "allow_headers": ["*"], "allow_origins": ["*"], "allow_credentials": True, - "expose_headers": ["X-Seed"] + "expose_headers": ["X-Seed"], } app.add_middleware(CORSMiddleware, **cors_options) @@ -159,7 +158,8 @@ class Api: # fmt: off self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) - self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) + self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], + response_model=ServerConfigResponse) self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) @@ -361,6 +361,7 @@ class Api: return FileManager( app=self.app, input_dir=self.config.input, + mask_dir=self.config.mask_dir, output_dir=self.config.output_dir, ) return None diff --git a/iopaint/cli.py b/iopaint/cli.py index 951fbb4..9ba686b 100644 --- a/iopaint/cli.py +++ b/iopaint/cli.py @@ -1,7 +1,7 @@ import webbrowser from contextlib import asynccontextmanager from pathlib import Path -from typing import Dict, Optional +from typing import Optional import typer from fastapi import FastAPI @@ -120,6 +120,9 @@ def start( local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), device: Device = Option(Device.cpu), input: Optional[Path] = Option(None, help=INPUT_HELP), + mask_dir: Optional[Path] = Option( + None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False + ), output_dir: Optional[Path] = Option( None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False ), @@ -145,8 +148,11 @@ def start( if input and not input.exists(): logger.error(f"invalid --input: {input} not exists") exit(-1) + if mask_dir and not mask_dir.exists(): + logger.error(f"invalid --mask-dir: {mask_dir} not exists") + exit(-1) if input and input.is_dir() and not output_dir: - logger.error(f"invalid --output-dir: must be set when --input is a directory") + logger.error("invalid --output-dir: --output-dir must be set when --input is a directory") exit(-1) if output_dir: output_dir = output_dir.expanduser().absolute() @@ -154,6 +160,8 @@ def start( if not output_dir.exists(): logger.info(f"Create output directory {output_dir}") output_dir.mkdir(parents=True) + if mask_dir: + mask_dir = mask_dir.expanduser().absolute() model_dir = model_dir.expanduser().absolute() @@ -192,6 +200,7 @@ def start( cpu_textencoder=cpu_textencoder if device == Device.cuda else False, device=device, input=input, + mask_dir=mask_dir, output_dir=output_dir, quality=quality, enable_interactive_seg=enable_interactive_seg, diff --git a/iopaint/const.py b/iopaint/const.py index 9a0386b..b18254b 100644 --- a/iopaint/const.py +++ b/iopaint/const.py @@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [ SD_BRUSHNET_CHOICES: List[str] = [ "Sanster/brushnet_random_mask", - "Sanster/brushnet_segmentation_mask" + "Sanster/brushnet_segmentation_mask", ] SD2_CONTROLNET_CHOICES = [ @@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """ Result images will be saved to output directory automatically. """ +MASK_DIR_HELP = """ +You can view masks in FileManager +""" + INPUT_HELP = """ If input is image, it will be loaded by default. If input is directory, you can browse and select image in file manager. diff --git a/iopaint/file_manager/file_manager.py b/iopaint/file_manager/file_manager.py index 413162c..c24f54f 100644 --- a/iopaint/file_manager/file_manager.py +++ b/iopaint/file_manager/file_manager.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from PIL import Image, ImageOps, PngImagePlugin -from fastapi import FastAPI, UploadFile, HTTPException +from fastapi import FastAPI, HTTPException from starlette.responses import FileResponse from ..schema import MediasResponse, MediaTab @@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img class FileManager: - def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path): + def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path): self.app = app self.input_dir: Path = input_dir + self.mask_dir: Path = mask_dir self.output_dir: Path = output_dir self.image_dir_filenames = [] @@ -63,6 +64,8 @@ class FileManager: return self.input_dir elif tab == "output": return self.output_dir + elif tab == "mask": + return self.mask_dir else: raise HTTPException(status_code=422, detail=f"tab not found: {tab}") diff --git a/iopaint/schema.py b/iopaint/schema.py index 318c21a..3150ab4 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -244,6 +244,7 @@ class ApiConfig(BaseModel): cpu_textencoder: bool device: Device input: Optional[Path] + mask_dir: Optional[Path] output_dir: Optional[Path] quality: int enable_interactive_seg: bool @@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel): scale: float = Field(2.0, description="Scale for upscaling") -MediaTab = Literal["input", "output"] +MediaTab = Literal["input", "output", "mask"] class MediasResponse(BaseModel): diff --git a/iopaint/tests/overture-creations-5sI6fQgYIuo_all_mask.png b/iopaint/tests/overture-creations-5sI6fQgYIuo_all_mask.png deleted file mode 100644 index e69de29..0000000 diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 934d3a9..d582536 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -3,10 +3,11 @@ import os from pathlib import Path import mimetypes + # fix for windows mimetypes registry entries being borked # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 -mimetypes.add_type('application/javascript', '.js') -mimetypes.add_type('text/css', '.css') +mimetypes.add_type("application/javascript", ".js") +mimetypes.add_type("text/css", ".css") from iopaint.schema import ( Device, @@ -78,40 +79,43 @@ def load_config(p: Path) -> WebConfig: def save_config( - host, - port, - model, - model_dir, - no_half, - low_mem, - cpu_offload, - disable_nsfw_checker, - local_files_only, - cpu_textencoder, - device, - input, - output_dir, - quality, - enable_interactive_seg, - interactive_seg_model, - interactive_seg_device, - enable_remove_bg, - remove_bg_model, - enable_anime_seg, - enable_realesrgan, - realesrgan_device, - realesrgan_model, - enable_gfpgan, - gfpgan_device, - enable_restoreformer, - restoreformer_device, - inbrowser, + host, + port, + model, + model_dir, + no_half, + low_mem, + cpu_offload, + disable_nsfw_checker, + local_files_only, + cpu_textencoder, + device, + input, + mask_dir, + output_dir, + quality, + enable_interactive_seg, + interactive_seg_model, + interactive_seg_device, + enable_remove_bg, + remove_bg_model, + enable_anime_seg, + enable_realesrgan, + realesrgan_device, + realesrgan_model, + enable_gfpgan, + gfpgan_device, + enable_restoreformer, + restoreformer_device, + inbrowser, ): config = WebConfig(**locals()) if str(config.input) == ".": config.input = None if str(config.output_dir) == ".": config.output_dir = None + if str(config.mask_dir) == ".": + config.mask_dir = None config.model = config.model.strip() print(config.model_dump_json(indent=4)) if config.input and not os.path.exists(config.input): @@ -166,7 +170,7 @@ def main(config_file: Path): model = gr.Textbox( init_config.model, label="Current Model. Model will be automatically downloaded. " - "You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.", + "You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.", ) device = gr.Radio( @@ -207,6 +211,10 @@ def main(config_file: Path): init_config.output_dir, label=f"Output directory. {OUTPUT_DIR_HELP}", ) + mask_dir = gr.Textbox( + init_config.mask_dir, + label=f"Mask directory. {MASK_DIR_HELP}", + ) with gr.Tab("Plugins"): with gr.Row(): @@ -288,6 +296,7 @@ def main(config_file: Path): cpu_textencoder, device, input, + mask_dir, output_dir, quality, enable_interactive_seg, diff --git a/web_app/src/components/FileManager.tsx b/web_app/src/components/FileManager.tsx index 84e04ed..6c2a841 100644 --- a/web_app/src/components/FileManager.tsx +++ b/web_app/src/components/FileManager.tsx @@ -50,6 +50,7 @@ const SORT_BY_MODIFIED_TIME = "Modified time" const IMAGE_TAB = "input" const OUTPUT_TAB = "output" +export const MASK_TAB = "mask" const SortByMap = { [SortBy.NAME]: SORT_BY_NAME, @@ -264,6 +265,7 @@ export default function FileManager(props: Props) { Image Directory Output Directory + Mask Directory diff --git a/web_app/src/components/Header.tsx b/web_app/src/components/Header.tsx index 05acce6..e9d594e 100644 --- a/web_app/src/components/Header.tsx +++ b/web_app/src/components/Header.tsx @@ -7,8 +7,8 @@ import { useImage } from "@/hooks/useImage" import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" import PromptInput from "./PromptInput" import { RotateCw, Image, Upload } from "lucide-react" -import FileManager from "./FileManager" -import { getMediaFile } from "@/lib/api" +import FileManager, { MASK_TAB } from "./FileManager" +import { getMediaBlob, getMediaFile } from "@/lib/api" import { useStore } from "@/lib/states" import SettingsDialog from "./Settings" import { cn, fileToImage } from "@/lib/utils" @@ -31,6 +31,7 @@ const Header = () => { hidePrevMask, imageHeight, imageWidth, + handleFileManagerMaskSelect, ] = useStore((state) => [ state.file, state.customMask, @@ -46,6 +47,7 @@ const Header = () => { state.hidePrevMask, state.imageHeight, state.imageWidth, + state.handleFileManagerMaskSelect, ]) const { toast } = useToast() @@ -64,25 +66,29 @@ const Header = () => { hidePrevMask() } + const handleOnPhotoClick = async (tab: string, filename: string) => { + try { + if (tab === MASK_TAB) { + const maskBlob = await getMediaBlob(tab, filename) + handleFileManagerMaskSelect(maskBlob) + } else { + const newFile = await getMediaFile(tab, filename) + setFile(newFile) + } + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + return + } + } + return (
{serverConfig.enableFileManager ? ( - { - try { - const newFile = await getMediaFile(tab, filename) - setFile(newFile) - } catch (e: any) { - toast({ - variant: "destructive", - description: e.message ? e.message : e.toString(), - }) - return - } - }} - /> + ) : ( <> )} diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index f77e38e..a8f0b7f 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -135,7 +135,7 @@ export async function runPlugin( body: JSON.stringify({ name, image: imageBase64, - scale:upscale, + scale: upscale, clicks, }), }) @@ -167,6 +167,23 @@ export async function getMediaFile(tab: string, filename: string) { throw new Error(errMsg.errors) } +export async function getMediaBlob(tab: string, filename: string) { + const res = await fetch( + `${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent( + filename + )}`, + { + method: "GET", + } + ) + if (res.ok) { + const blob = await res.blob() + return blob + } + const errMsg = await res.json() + throw new Error(errMsg.errors) +} + export async function getMedias(tab: string): Promise { const res = await api.get(`medias`, { params: { tab } }) return res.data diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index a9cde43..2d91142 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -207,6 +207,7 @@ type AppAction = { updateInteractiveSegState: (newState: Partial) => void resetInteractiveSegState: () => void handleInteractiveSegAccept: () => void + handleFileManagerMaskSelect: (blob: Blob) => Promise showPromptInput: () => boolean runInpainting: () => Promise @@ -903,6 +904,16 @@ export const useStore = createWithEqualityFn()( }) }, + handleFileManagerMaskSelect: async (blob: Blob) => { + const newMask = new Image() + + await loadImage(newMask, URL.createObjectURL(blob)) + set((state) => { + state.editorState.extraMasks.push(castDraft(newMask)) + }) + get().runInpainting() + }, + setIsInpainting: (newValue: boolean) => set((state) => { state.isInpainting = newValue