add mask tab
This commit is contained in:
parent
60b1411d6b
commit
820ce5e4d0
@ -19,7 +19,6 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
||||
@ -127,7 +126,7 @@ def api_middleware(app: FastAPI):
|
||||
"allow_headers": ["*"],
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"expose_headers": ["X-Seed"]
|
||||
"expose_headers": ["X-Seed"],
|
||||
}
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
||||
@ -159,7 +158,8 @@ class Api:
|
||||
|
||||
# fmt: off
|
||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
|
||||
response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
@ -361,6 +361,7 @@ class Api:
|
||||
return FileManager(
|
||||
app=self.app,
|
||||
input_dir=self.config.input,
|
||||
mask_dir=self.config.mask_dir,
|
||||
output_dir=self.config.output_dir,
|
||||
)
|
||||
return None
|
||||
|
@ -1,7 +1,7 @@
|
||||
import webbrowser
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from fastapi import FastAPI
|
||||
@ -120,6 +120,9 @@ def start(
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
input: Optional[Path] = Option(None, help=INPUT_HELP),
|
||||
mask_dir: Optional[Path] = Option(
|
||||
None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
output_dir: Optional[Path] = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
@ -145,8 +148,11 @@ def start(
|
||||
if input and not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit(-1)
|
||||
if mask_dir and not mask_dir.exists():
|
||||
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
|
||||
exit(-1)
|
||||
if input and input.is_dir() and not output_dir:
|
||||
logger.error(f"invalid --output-dir: must be set when --input is a directory")
|
||||
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
|
||||
exit(-1)
|
||||
if output_dir:
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
@ -154,6 +160,8 @@ def start(
|
||||
if not output_dir.exists():
|
||||
logger.info(f"Create output directory {output_dir}")
|
||||
output_dir.mkdir(parents=True)
|
||||
if mask_dir:
|
||||
mask_dir = mask_dir.expanduser().absolute()
|
||||
|
||||
model_dir = model_dir.expanduser().absolute()
|
||||
|
||||
@ -192,6 +200,7 @@ def start(
|
||||
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
||||
device=device,
|
||||
input=input,
|
||||
mask_dir=mask_dir,
|
||||
output_dir=output_dir,
|
||||
quality=quality,
|
||||
enable_interactive_seg=enable_interactive_seg,
|
||||
|
@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
||||
|
||||
SD_BRUSHNET_CHOICES: List[str] = [
|
||||
"Sanster/brushnet_random_mask",
|
||||
"Sanster/brushnet_segmentation_mask"
|
||||
"Sanster/brushnet_segmentation_mask",
|
||||
]
|
||||
|
||||
SD2_CONTROLNET_CHOICES = [
|
||||
@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """
|
||||
Result images will be saved to output directory automatically.
|
||||
"""
|
||||
|
||||
MASK_DIR_HELP = """
|
||||
You can view masks in FileManager
|
||||
"""
|
||||
|
||||
INPUT_HELP = """
|
||||
If input is image, it will be loaded by default.
|
||||
If input is directory, you can browse and select image in file manager.
|
||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
from fastapi import FastAPI, UploadFile, HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from ..schema import MediasResponse, MediaTab
|
||||
@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img
|
||||
|
||||
|
||||
class FileManager:
|
||||
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
|
||||
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
|
||||
self.app = app
|
||||
self.input_dir: Path = input_dir
|
||||
self.mask_dir: Path = mask_dir
|
||||
self.output_dir: Path = output_dir
|
||||
|
||||
self.image_dir_filenames = []
|
||||
@ -63,6 +64,8 @@ class FileManager:
|
||||
return self.input_dir
|
||||
elif tab == "output":
|
||||
return self.output_dir
|
||||
elif tab == "mask":
|
||||
return self.mask_dir
|
||||
else:
|
||||
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
||||
|
||||
|
@ -244,6 +244,7 @@ class ApiConfig(BaseModel):
|
||||
cpu_textencoder: bool
|
||||
device: Device
|
||||
input: Optional[Path]
|
||||
mask_dir: Optional[Path]
|
||||
output_dir: Optional[Path]
|
||||
quality: int
|
||||
enable_interactive_seg: bool
|
||||
@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel):
|
||||
scale: float = Field(2.0, description="Scale for upscaling")
|
||||
|
||||
|
||||
MediaTab = Literal["input", "output"]
|
||||
MediaTab = Literal["input", "output", "mask"]
|
||||
|
||||
|
||||
class MediasResponse(BaseModel):
|
||||
|
@ -3,10 +3,11 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import mimetypes
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
from iopaint.schema import (
|
||||
Device,
|
||||
@ -78,40 +79,43 @@ def load_config(p: Path) -> WebConfig:
|
||||
|
||||
|
||||
def save_config(
|
||||
host,
|
||||
port,
|
||||
model,
|
||||
model_dir,
|
||||
no_half,
|
||||
low_mem,
|
||||
cpu_offload,
|
||||
disable_nsfw_checker,
|
||||
local_files_only,
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
remove_bg_model,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
inbrowser,
|
||||
host,
|
||||
port,
|
||||
model,
|
||||
model_dir,
|
||||
no_half,
|
||||
low_mem,
|
||||
cpu_offload,
|
||||
disable_nsfw_checker,
|
||||
local_files_only,
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
mask_dir,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
remove_bg_model,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
inbrowser,
|
||||
):
|
||||
config = WebConfig(**locals())
|
||||
if str(config.input) == ".":
|
||||
config.input = None
|
||||
if str(config.output_dir) == ".":
|
||||
config.output_dir = None
|
||||
if str(config.mask_dir) == ".":
|
||||
config.mask_dir = None
|
||||
config.model = config.model.strip()
|
||||
print(config.model_dump_json(indent=4))
|
||||
if config.input and not os.path.exists(config.input):
|
||||
@ -166,7 +170,7 @@ def main(config_file: Path):
|
||||
model = gr.Textbox(
|
||||
init_config.model,
|
||||
label="Current Model. Model will be automatically downloaded. "
|
||||
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
|
||||
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
|
||||
)
|
||||
|
||||
device = gr.Radio(
|
||||
@ -207,6 +211,10 @@ def main(config_file: Path):
|
||||
init_config.output_dir,
|
||||
label=f"Output directory. {OUTPUT_DIR_HELP}",
|
||||
)
|
||||
mask_dir = gr.Textbox(
|
||||
init_config.mask_dir,
|
||||
label=f"Mask directory. {MASK_DIR_HELP}",
|
||||
)
|
||||
|
||||
with gr.Tab("Plugins"):
|
||||
with gr.Row():
|
||||
@ -288,6 +296,7 @@ def main(config_file: Path):
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
mask_dir,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
|
@ -50,6 +50,7 @@ const SORT_BY_MODIFIED_TIME = "Modified time"
|
||||
|
||||
const IMAGE_TAB = "input"
|
||||
const OUTPUT_TAB = "output"
|
||||
export const MASK_TAB = "mask"
|
||||
|
||||
const SortByMap = {
|
||||
[SortBy.NAME]: SORT_BY_NAME,
|
||||
@ -264,6 +265,7 @@ export default function FileManager(props: Props) {
|
||||
<TabsList aria-label="Manage your account">
|
||||
<TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger>
|
||||
<TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger>
|
||||
<TabsTrigger value={MASK_TAB}>Mask Directory</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs>
|
||||
|
||||
|
@ -7,8 +7,8 @@ import { useImage } from "@/hooks/useImage"
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
|
||||
import PromptInput from "./PromptInput"
|
||||
import { RotateCw, Image, Upload } from "lucide-react"
|
||||
import FileManager from "./FileManager"
|
||||
import { getMediaFile } from "@/lib/api"
|
||||
import FileManager, { MASK_TAB } from "./FileManager"
|
||||
import { getMediaBlob, getMediaFile } from "@/lib/api"
|
||||
import { useStore } from "@/lib/states"
|
||||
import SettingsDialog from "./Settings"
|
||||
import { cn, fileToImage } from "@/lib/utils"
|
||||
@ -31,6 +31,7 @@ const Header = () => {
|
||||
hidePrevMask,
|
||||
imageHeight,
|
||||
imageWidth,
|
||||
handleFileManagerMaskSelect,
|
||||
] = useStore((state) => [
|
||||
state.file,
|
||||
state.customMask,
|
||||
@ -46,6 +47,7 @@ const Header = () => {
|
||||
state.hidePrevMask,
|
||||
state.imageHeight,
|
||||
state.imageWidth,
|
||||
state.handleFileManagerMaskSelect,
|
||||
])
|
||||
|
||||
const { toast } = useToast()
|
||||
@ -64,25 +66,29 @@ const Header = () => {
|
||||
hidePrevMask()
|
||||
}
|
||||
|
||||
const handleOnPhotoClick = async (tab: string, filename: string) => {
|
||||
try {
|
||||
if (tab === MASK_TAB) {
|
||||
const maskBlob = await getMediaBlob(tab, filename)
|
||||
handleFileManagerMaskSelect(maskBlob)
|
||||
} else {
|
||||
const newFile = await getMediaFile(tab, filename)
|
||||
setFile(newFile)
|
||||
}
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
|
||||
<div className="flex items-center gap-1">
|
||||
{serverConfig.enableFileManager ? (
|
||||
<FileManager
|
||||
photoWidth={512}
|
||||
onPhotoClick={async (tab: string, filename: string) => {
|
||||
try {
|
||||
const newFile = await getMediaFile(tab, filename)
|
||||
setFile(newFile)
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<FileManager photoWidth={512} onPhotoClick={handleOnPhotoClick} />
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
@ -135,7 +135,7 @@ export async function runPlugin(
|
||||
body: JSON.stringify({
|
||||
name,
|
||||
image: imageBase64,
|
||||
scale:upscale,
|
||||
scale: upscale,
|
||||
clicks,
|
||||
}),
|
||||
})
|
||||
@ -167,6 +167,23 @@ export async function getMediaFile(tab: string, filename: string) {
|
||||
throw new Error(errMsg.errors)
|
||||
}
|
||||
|
||||
export async function getMediaBlob(tab: string, filename: string) {
|
||||
const res = await fetch(
|
||||
`${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent(
|
||||
filename
|
||||
)}`,
|
||||
{
|
||||
method: "GET",
|
||||
}
|
||||
)
|
||||
if (res.ok) {
|
||||
const blob = await res.blob()
|
||||
return blob
|
||||
}
|
||||
const errMsg = await res.json()
|
||||
throw new Error(errMsg.errors)
|
||||
}
|
||||
|
||||
export async function getMedias(tab: string): Promise<Filename[]> {
|
||||
const res = await api.get(`medias`, { params: { tab } })
|
||||
return res.data
|
||||
|
@ -207,6 +207,7 @@ type AppAction = {
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||
resetInteractiveSegState: () => void
|
||||
handleInteractiveSegAccept: () => void
|
||||
handleFileManagerMaskSelect: (blob: Blob) => Promise<void>
|
||||
showPromptInput: () => boolean
|
||||
|
||||
runInpainting: () => Promise<void>
|
||||
@ -903,6 +904,16 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
})
|
||||
},
|
||||
|
||||
handleFileManagerMaskSelect: async (blob: Blob) => {
|
||||
const newMask = new Image()
|
||||
|
||||
await loadImage(newMask, URL.createObjectURL(blob))
|
||||
set((state) => {
|
||||
state.editorState.extraMasks.push(castDraft(newMask))
|
||||
})
|
||||
get().runInpainting()
|
||||
},
|
||||
|
||||
setIsInpainting: (newValue: boolean) =>
|
||||
set((state) => {
|
||||
state.isInpainting = newValue
|
||||
|
Loading…
Reference in New Issue
Block a user