add mask tab

This commit is contained in:
Qing 2024-08-12 12:13:37 +08:00
parent 60b1411d6b
commit 820ce5e4d0
11 changed files with 121 additions and 58 deletions

View File

@ -19,7 +19,6 @@ try:
except: except:
pass pass
import uvicorn import uvicorn
from PIL import Image from PIL import Image
from fastapi import APIRouter, FastAPI, Request, UploadFile from fastapi import APIRouter, FastAPI, Request, UploadFile
@ -127,7 +126,7 @@ def api_middleware(app: FastAPI):
"allow_headers": ["*"], "allow_headers": ["*"],
"allow_origins": ["*"], "allow_origins": ["*"],
"allow_credentials": True, "allow_credentials": True,
"expose_headers": ["X-Seed"] "expose_headers": ["X-Seed"],
} }
app.add_middleware(CORSMiddleware, **cors_options) app.add_middleware(CORSMiddleware, **cors_options)
@ -159,7 +158,8 @@ class Api:
# fmt: off # fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
response_model=ServerConfigResponse)
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
@ -361,6 +361,7 @@ class Api:
return FileManager( return FileManager(
app=self.app, app=self.app,
input_dir=self.config.input, input_dir=self.config.input,
mask_dir=self.config.mask_dir,
output_dir=self.config.output_dir, output_dir=self.config.output_dir,
) )
return None return None

View File

@ -1,7 +1,7 @@
import webbrowser import webbrowser
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Optional
import typer import typer
from fastapi import FastAPI from fastapi import FastAPI
@ -120,6 +120,9 @@ def start(
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP), local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu), device: Device = Option(Device.cpu),
input: Optional[Path] = Option(None, help=INPUT_HELP), input: Optional[Path] = Option(None, help=INPUT_HELP),
mask_dir: Optional[Path] = Option(
None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
),
output_dir: Optional[Path] = Option( output_dir: Optional[Path] = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
), ),
@ -145,8 +148,11 @@ def start(
if input and not input.exists(): if input and not input.exists():
logger.error(f"invalid --input: {input} not exists") logger.error(f"invalid --input: {input} not exists")
exit(-1) exit(-1)
if mask_dir and not mask_dir.exists():
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
exit(-1)
if input and input.is_dir() and not output_dir: if input and input.is_dir() and not output_dir:
logger.error(f"invalid --output-dir: must be set when --input is a directory") logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
exit(-1) exit(-1)
if output_dir: if output_dir:
output_dir = output_dir.expanduser().absolute() output_dir = output_dir.expanduser().absolute()
@ -154,6 +160,8 @@ def start(
if not output_dir.exists(): if not output_dir.exists():
logger.info(f"Create output directory {output_dir}") logger.info(f"Create output directory {output_dir}")
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
if mask_dir:
mask_dir = mask_dir.expanduser().absolute()
model_dir = model_dir.expanduser().absolute() model_dir = model_dir.expanduser().absolute()
@ -192,6 +200,7 @@ def start(
cpu_textencoder=cpu_textencoder if device == Device.cuda else False, cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
device=device, device=device,
input=input, input=input,
mask_dir=mask_dir,
output_dir=output_dir, output_dir=output_dir,
quality=quality, quality=quality,
enable_interactive_seg=enable_interactive_seg, enable_interactive_seg=enable_interactive_seg,

View File

@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [
SD_BRUSHNET_CHOICES: List[str] = [ SD_BRUSHNET_CHOICES: List[str] = [
"Sanster/brushnet_random_mask", "Sanster/brushnet_random_mask",
"Sanster/brushnet_segmentation_mask" "Sanster/brushnet_segmentation_mask",
] ]
SD2_CONTROLNET_CHOICES = [ SD2_CONTROLNET_CHOICES = [
@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """
Result images will be saved to output directory automatically. Result images will be saved to output directory automatically.
""" """
MASK_DIR_HELP = """
You can view masks in FileManager
"""
INPUT_HELP = """ INPUT_HELP = """
If input is image, it will be loaded by default. If input is image, it will be loaded by default.
If input is directory, you can browse and select image in file manager. If input is directory, you can browse and select image in file manager.

View File

@ -4,7 +4,7 @@ from pathlib import Path
from typing import List from typing import List
from PIL import Image, ImageOps, PngImagePlugin from PIL import Image, ImageOps, PngImagePlugin
from fastapi import FastAPI, UploadFile, HTTPException from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse from starlette.responses import FileResponse
from ..schema import MediasResponse, MediaTab from ..schema import MediasResponse, MediaTab
@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img
class FileManager: class FileManager:
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path): def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
self.app = app self.app = app
self.input_dir: Path = input_dir self.input_dir: Path = input_dir
self.mask_dir: Path = mask_dir
self.output_dir: Path = output_dir self.output_dir: Path = output_dir
self.image_dir_filenames = [] self.image_dir_filenames = []
@ -63,6 +64,8 @@ class FileManager:
return self.input_dir return self.input_dir
elif tab == "output": elif tab == "output":
return self.output_dir return self.output_dir
elif tab == "mask":
return self.mask_dir
else: else:
raise HTTPException(status_code=422, detail=f"tab not found: {tab}") raise HTTPException(status_code=422, detail=f"tab not found: {tab}")

View File

@ -244,6 +244,7 @@ class ApiConfig(BaseModel):
cpu_textencoder: bool cpu_textencoder: bool
device: Device device: Device
input: Optional[Path] input: Optional[Path]
mask_dir: Optional[Path]
output_dir: Optional[Path] output_dir: Optional[Path]
quality: int quality: int
enable_interactive_seg: bool enable_interactive_seg: bool
@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel):
scale: float = Field(2.0, description="Scale for upscaling") scale: float = Field(2.0, description="Scale for upscaling")
MediaTab = Literal["input", "output"] MediaTab = Literal["input", "output", "mask"]
class MediasResponse(BaseModel): class MediasResponse(BaseModel):

View File

@ -3,10 +3,11 @@ import os
from pathlib import Path from pathlib import Path
import mimetypes import mimetypes
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type('text/css', '.css') mimetypes.add_type("text/css", ".css")
from iopaint.schema import ( from iopaint.schema import (
Device, Device,
@ -90,6 +91,7 @@ def save_config(
cpu_textencoder, cpu_textencoder,
device, device,
input, input,
mask_dir,
output_dir, output_dir,
quality, quality,
enable_interactive_seg, enable_interactive_seg,
@ -112,6 +114,8 @@ def save_config(
config.input = None config.input = None
if str(config.output_dir) == ".": if str(config.output_dir) == ".":
config.output_dir = None config.output_dir = None
if str(config.mask_dir) == ".":
config.mask_dir = None
config.model = config.model.strip() config.model = config.model.strip()
print(config.model_dump_json(indent=4)) print(config.model_dump_json(indent=4))
if config.input and not os.path.exists(config.input): if config.input and not os.path.exists(config.input):
@ -207,6 +211,10 @@ def main(config_file: Path):
init_config.output_dir, init_config.output_dir,
label=f"Output directory. {OUTPUT_DIR_HELP}", label=f"Output directory. {OUTPUT_DIR_HELP}",
) )
mask_dir = gr.Textbox(
init_config.mask_dir,
label=f"Mask directory. {MASK_DIR_HELP}",
)
with gr.Tab("Plugins"): with gr.Tab("Plugins"):
with gr.Row(): with gr.Row():
@ -288,6 +296,7 @@ def main(config_file: Path):
cpu_textencoder, cpu_textencoder,
device, device,
input, input,
mask_dir,
output_dir, output_dir,
quality, quality,
enable_interactive_seg, enable_interactive_seg,

View File

@ -50,6 +50,7 @@ const SORT_BY_MODIFIED_TIME = "Modified time"
const IMAGE_TAB = "input" const IMAGE_TAB = "input"
const OUTPUT_TAB = "output" const OUTPUT_TAB = "output"
export const MASK_TAB = "mask"
const SortByMap = { const SortByMap = {
[SortBy.NAME]: SORT_BY_NAME, [SortBy.NAME]: SORT_BY_NAME,
@ -264,6 +265,7 @@ export default function FileManager(props: Props) {
<TabsList aria-label="Manage your account"> <TabsList aria-label="Manage your account">
<TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger> <TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger>
<TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger> <TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger>
<TabsTrigger value={MASK_TAB}>Mask Directory</TabsTrigger>
</TabsList> </TabsList>
</Tabs> </Tabs>

View File

@ -7,8 +7,8 @@ import { useImage } from "@/hooks/useImage"
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover" import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
import PromptInput from "./PromptInput" import PromptInput from "./PromptInput"
import { RotateCw, Image, Upload } from "lucide-react" import { RotateCw, Image, Upload } from "lucide-react"
import FileManager from "./FileManager" import FileManager, { MASK_TAB } from "./FileManager"
import { getMediaFile } from "@/lib/api" import { getMediaBlob, getMediaFile } from "@/lib/api"
import { useStore } from "@/lib/states" import { useStore } from "@/lib/states"
import SettingsDialog from "./Settings" import SettingsDialog from "./Settings"
import { cn, fileToImage } from "@/lib/utils" import { cn, fileToImage } from "@/lib/utils"
@ -31,6 +31,7 @@ const Header = () => {
hidePrevMask, hidePrevMask,
imageHeight, imageHeight,
imageWidth, imageWidth,
handleFileManagerMaskSelect,
] = useStore((state) => [ ] = useStore((state) => [
state.file, state.file,
state.customMask, state.customMask,
@ -46,6 +47,7 @@ const Header = () => {
state.hidePrevMask, state.hidePrevMask,
state.imageHeight, state.imageHeight,
state.imageWidth, state.imageWidth,
state.handleFileManagerMaskSelect,
]) ])
const { toast } = useToast() const { toast } = useToast()
@ -64,16 +66,15 @@ const Header = () => {
hidePrevMask() hidePrevMask()
} }
return ( const handleOnPhotoClick = async (tab: string, filename: string) => {
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
<div className="flex items-center gap-1">
{serverConfig.enableFileManager ? (
<FileManager
photoWidth={512}
onPhotoClick={async (tab: string, filename: string) => {
try { try {
if (tab === MASK_TAB) {
const maskBlob = await getMediaBlob(tab, filename)
handleFileManagerMaskSelect(maskBlob)
} else {
const newFile = await getMediaFile(tab, filename) const newFile = await getMediaFile(tab, filename)
setFile(newFile) setFile(newFile)
}
} catch (e: any) { } catch (e: any) {
toast({ toast({
variant: "destructive", variant: "destructive",
@ -81,8 +82,13 @@ const Header = () => {
}) })
return return
} }
}} }
/>
return (
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
<div className="flex items-center gap-1">
{serverConfig.enableFileManager ? (
<FileManager photoWidth={512} onPhotoClick={handleOnPhotoClick} />
) : ( ) : (
<></> <></>
)} )}

View File

@ -167,6 +167,23 @@ export async function getMediaFile(tab: string, filename: string) {
throw new Error(errMsg.errors) throw new Error(errMsg.errors)
} }
export async function getMediaBlob(tab: string, filename: string) {
const res = await fetch(
`${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent(
filename
)}`,
{
method: "GET",
}
)
if (res.ok) {
const blob = await res.blob()
return blob
}
const errMsg = await res.json()
throw new Error(errMsg.errors)
}
export async function getMedias(tab: string): Promise<Filename[]> { export async function getMedias(tab: string): Promise<Filename[]> {
const res = await api.get(`medias`, { params: { tab } }) const res = await api.get(`medias`, { params: { tab } })
return res.data return res.data

View File

@ -207,6 +207,7 @@ type AppAction = {
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void resetInteractiveSegState: () => void
handleInteractiveSegAccept: () => void handleInteractiveSegAccept: () => void
handleFileManagerMaskSelect: (blob: Blob) => Promise<void>
showPromptInput: () => boolean showPromptInput: () => boolean
runInpainting: () => Promise<void> runInpainting: () => Promise<void>
@ -903,6 +904,16 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
}) })
}, },
handleFileManagerMaskSelect: async (blob: Blob) => {
const newMask = new Image()
await loadImage(newMask, URL.createObjectURL(blob))
set((state) => {
state.editorState.extraMasks.push(castDraft(newMask))
})
get().runInpainting()
},
setIsInpainting: (newValue: boolean) => setIsInpainting: (newValue: boolean) =>
set((state) => { set((state) => {
state.isInpainting = newValue state.isInpainting = newValue