add mask tab

This commit is contained in:
Qing 2024-08-12 12:13:37 +08:00
parent 60b1411d6b
commit 820ce5e4d0
11 changed files with 121 additions and 58 deletions

View File

@ -19,7 +19,6 @@ try:
except:
pass
import uvicorn
from PIL import Image
from fastapi import APIRouter, FastAPI, Request, UploadFile
@ -127,7 +126,7 @@ def api_middleware(app: FastAPI):
"allow_headers": ["*"],
"allow_origins": ["*"],
"allow_credentials": True,
"expose_headers": ["X-Seed"]
"expose_headers": ["X-Seed"],
}
app.add_middleware(CORSMiddleware, **cors_options)
@ -159,7 +158,8 @@ class Api:
# fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
response_model=ServerConfigResponse)
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
@ -361,6 +361,7 @@ class Api:
return FileManager(
app=self.app,
input_dir=self.config.input,
mask_dir=self.config.mask_dir,
output_dir=self.config.output_dir,
)
return None

View File

@ -1,7 +1,7 @@
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Dict, Optional
from typing import Optional
import typer
from fastapi import FastAPI
@ -120,6 +120,9 @@ def start(
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu),
input: Optional[Path] = Option(None, help=INPUT_HELP),
# --mask-dir: optional directory of masks. Use the mask-specific help text
# instead of the model-directory help that was copied from another option.
# NOTE(review): assumes MASK_DIR_HELP (added to the constants module in this
# same commit) is in scope here the same way MODEL_DIR_HELP is — confirm.
mask_dir: Optional[Path] = Option(
    None, help=MASK_DIR_HELP, dir_okay=True, file_okay=False
),
output_dir: Optional[Path] = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
),
@ -145,8 +148,11 @@ def start(
if input and not input.exists():
logger.error(f"invalid --input: {input} not exists")
exit(-1)
if mask_dir and not mask_dir.exists():
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
exit(-1)
if input and input.is_dir() and not output_dir:
logger.error(f"invalid --output-dir: must be set when --input is a directory")
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
exit(-1)
if output_dir:
output_dir = output_dir.expanduser().absolute()
@ -154,6 +160,8 @@ def start(
if not output_dir.exists():
logger.info(f"Create output directory {output_dir}")
output_dir.mkdir(parents=True)
if mask_dir:
mask_dir = mask_dir.expanduser().absolute()
model_dir = model_dir.expanduser().absolute()
@ -192,6 +200,7 @@ def start(
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
device=device,
input=input,
mask_dir=mask_dir,
output_dir=output_dir,
quality=quality,
enable_interactive_seg=enable_interactive_seg,

View File

@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [
SD_BRUSHNET_CHOICES: List[str] = [
"Sanster/brushnet_random_mask",
"Sanster/brushnet_segmentation_mask"
"Sanster/brushnet_segmentation_mask",
]
SD2_CONTROLNET_CHOICES = [
@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """
Result images will be saved to output directory automatically.
"""
MASK_DIR_HELP = """
You can view masks in FileManager
"""
INPUT_HELP = """
If input is image, it will be loaded by default.
If input is directory, you can browse and select image in file manager.

View File

@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from PIL import Image, ImageOps, PngImagePlugin
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse
from ..schema import MediasResponse, MediaTab
@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img
class FileManager:
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
self.app = app
self.input_dir: Path = input_dir
self.mask_dir: Path = mask_dir
self.output_dir: Path = output_dir
self.image_dir_filenames = []
@ -63,6 +64,8 @@ class FileManager:
return self.input_dir
elif tab == "output":
return self.output_dir
elif tab == "mask":
return self.mask_dir
else:
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")

View File

@ -244,6 +244,7 @@ class ApiConfig(BaseModel):
cpu_textencoder: bool
device: Device
input: Optional[Path]
mask_dir: Optional[Path]
output_dir: Optional[Path]
quality: int
enable_interactive_seg: bool
@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel):
scale: float = Field(2.0, description="Scale for upscaling")
MediaTab = Literal["input", "output"]
MediaTab = Literal["input", "output", "mask"]
class MediasResponse(BaseModel):

View File

@ -3,10 +3,11 @@ import os
from pathlib import Path
import mimetypes
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
from iopaint.schema import (
Device,
@ -78,40 +79,43 @@ def load_config(p: Path) -> WebConfig:
def save_config(
host,
port,
model,
model_dir,
no_half,
low_mem,
cpu_offload,
disable_nsfw_checker,
local_files_only,
cpu_textencoder,
device,
input,
output_dir,
quality,
enable_interactive_seg,
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
remove_bg_model,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
realesrgan_model,
enable_gfpgan,
gfpgan_device,
enable_restoreformer,
restoreformer_device,
inbrowser,
host,
port,
model,
model_dir,
no_half,
low_mem,
cpu_offload,
disable_nsfw_checker,
local_files_only,
cpu_textencoder,
device,
input,
mask_dir,
output_dir,
quality,
enable_interactive_seg,
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
remove_bg_model,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
realesrgan_model,
enable_gfpgan,
gfpgan_device,
enable_restoreformer,
restoreformer_device,
inbrowser,
):
config = WebConfig(**locals())
if str(config.input) == ".":
config.input = None
if str(config.output_dir) == ".":
config.output_dir = None
if str(config.mask_dir) == ".":
config.mask_dir = None
config.model = config.model.strip()
print(config.model_dump_json(indent=4))
if config.input and not os.path.exists(config.input):
@ -166,7 +170,7 @@ def main(config_file: Path):
model = gr.Textbox(
init_config.model,
label="Current Model. Model will be automatically downloaded. "
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
)
device = gr.Radio(
@ -207,6 +211,10 @@ def main(config_file: Path):
init_config.output_dir,
label=f"Output directory. {OUTPUT_DIR_HELP}",
)
mask_dir = gr.Textbox(
init_config.mask_dir,
label=f"Mask directory. {MASK_DIR_HELP}",
)
with gr.Tab("Plugins"):
with gr.Row():
@ -288,6 +296,7 @@ def main(config_file: Path):
cpu_textencoder,
device,
input,
mask_dir,
output_dir,
quality,
enable_interactive_seg,

View File

@ -50,6 +50,7 @@ const SORT_BY_MODIFIED_TIME = "Modified time"
const IMAGE_TAB = "input"
const OUTPUT_TAB = "output"
export const MASK_TAB = "mask"
const SortByMap = {
[SortBy.NAME]: SORT_BY_NAME,
@ -264,6 +265,7 @@ export default function FileManager(props: Props) {
<TabsList aria-label="Manage your account">
<TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger>
<TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger>
<TabsTrigger value={MASK_TAB}>Mask Directory</TabsTrigger>
</TabsList>
</Tabs>

View File

@ -7,8 +7,8 @@ import { useImage } from "@/hooks/useImage"
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
import PromptInput from "./PromptInput"
import { RotateCw, Image, Upload } from "lucide-react"
import FileManager from "./FileManager"
import { getMediaFile } from "@/lib/api"
import FileManager, { MASK_TAB } from "./FileManager"
import { getMediaBlob, getMediaFile } from "@/lib/api"
import { useStore } from "@/lib/states"
import SettingsDialog from "./Settings"
import { cn, fileToImage } from "@/lib/utils"
@ -31,6 +31,7 @@ const Header = () => {
hidePrevMask,
imageHeight,
imageWidth,
handleFileManagerMaskSelect,
] = useStore((state) => [
state.file,
state.customMask,
@ -46,6 +47,7 @@ const Header = () => {
state.hidePrevMask,
state.imageHeight,
state.imageWidth,
state.handleFileManagerMaskSelect,
])
const { toast } = useToast()
@ -64,25 +66,29 @@ const Header = () => {
hidePrevMask()
}
// Click handler passed to FileManager. A file picked from the mask tab is
// fetched as a Blob and handed to handleFileManagerMaskSelect; a file from any
// other tab is fetched with getMediaFile and passed to setFile. Any failure is
// reported through a destructive toast.
const handleOnPhotoClick = async (tab: string, filename: string) => {
  try {
    if (tab !== MASK_TAB) {
      setFile(await getMediaFile(tab, filename))
      return
    }
    const maskBlob = await getMediaBlob(tab, filename)
    handleFileManagerMaskSelect(maskBlob)
  } catch (e: any) {
    const description = e.message ? e.message : e.toString()
    toast({ variant: "destructive", description })
  }
}
return (
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
<div className="flex items-center gap-1">
{serverConfig.enableFileManager ? (
<FileManager
photoWidth={512}
onPhotoClick={async (tab: string, filename: string) => {
try {
const newFile = await getMediaFile(tab, filename)
setFile(newFile)
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
return
}
}}
/>
<FileManager photoWidth={512} onPhotoClick={handleOnPhotoClick} />
) : (
<></>
)}

View File

@ -135,7 +135,7 @@ export async function runPlugin(
body: JSON.stringify({
name,
image: imageBase64,
scale:upscale,
scale: upscale,
clicks,
}),
})
@ -167,6 +167,23 @@ export async function getMediaFile(tab: string, filename: string) {
throw new Error(errMsg.errors)
}
/**
 * Download a media file from the server and return it as a raw Blob.
 *
 * @param tab - Value sent as the `tab` query parameter of `media_file`.
 * @param filename - Value sent as the `filename` query parameter.
 * @returns The response body as a Blob when the request succeeds.
 * @throws Error when the response is not OK. The message is the `errors`
 *   field of the JSON error body when present, otherwise the HTTP status.
 */
export async function getMediaBlob(
  tab: string,
  filename: string
): Promise<Blob> {
  // Encode both values so reserved characters cannot corrupt the query string.
  const query = `tab=${encodeURIComponent(tab)}&filename=${encodeURIComponent(
    filename
  )}`
  const res = await fetch(`${API_ENDPOINT}/media_file?${query}`, {
    method: "GET",
  })
  if (res.ok) {
    return res.blob()
  }

  // Fall back to the HTTP status when the error body is missing or not JSON,
  // so the caller sees the real failure rather than a JSON parse error.
  let message = `Failed to load ${filename}: ${res.status} ${res.statusText}`
  try {
    const errMsg = await res.json()
    if (errMsg && errMsg.errors) {
      message = String(errMsg.errors)
    }
  } catch {
    // Non-JSON error body: keep the status-based message.
  }
  throw new Error(message)
}
export async function getMedias(tab: string): Promise<Filename[]> {
const res = await api.get(`medias`, { params: { tab } })
return res.data

View File

@ -207,6 +207,7 @@ type AppAction = {
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
resetInteractiveSegState: () => void
handleInteractiveSegAccept: () => void
handleFileManagerMaskSelect: (blob: Blob) => Promise<void>
showPromptInput: () => boolean
runInpainting: () => Promise<void>
@ -903,6 +904,16 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
})
},
// Store action: load the given mask Blob into an image element, append it to
// editorState.extraMasks, then start inpainting.
// NOTE(review): the object URL is never revoked and runInpainting() is not
// awaited; both are kept as-is — confirm whether that is intentional.
handleFileManagerMaskSelect: async (blob: Blob) => {
  const maskUrl = URL.createObjectURL(blob)
  const maskImage = new Image()
  await loadImage(maskImage, maskUrl)

  set((state) => {
    state.editorState.extraMasks.push(castDraft(maskImage))
  })

  get().runInpainting()
},
setIsInpainting: (newValue: boolean) =>
set((state) => {
state.isInpainting = newValue