add mask tab
This commit is contained in:
parent
60b1411d6b
commit
820ce5e4d0
@ -19,7 +19,6 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
||||||
@ -127,7 +126,7 @@ def api_middleware(app: FastAPI):
|
|||||||
"allow_headers": ["*"],
|
"allow_headers": ["*"],
|
||||||
"allow_origins": ["*"],
|
"allow_origins": ["*"],
|
||||||
"allow_credentials": True,
|
"allow_credentials": True,
|
||||||
"expose_headers": ["X-Seed"]
|
"expose_headers": ["X-Seed"],
|
||||||
}
|
}
|
||||||
app.add_middleware(CORSMiddleware, **cors_options)
|
app.add_middleware(CORSMiddleware, **cors_options)
|
||||||
|
|
||||||
@ -159,7 +158,8 @@ class Api:
|
|||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
|
||||||
|
response_model=ServerConfigResponse)
|
||||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||||
@ -361,6 +361,7 @@ class Api:
|
|||||||
return FileManager(
|
return FileManager(
|
||||||
app=self.app,
|
app=self.app,
|
||||||
input_dir=self.config.input,
|
input_dir=self.config.input,
|
||||||
|
mask_dir=self.config.mask_dir,
|
||||||
output_dir=self.config.output_dir,
|
output_dir=self.config.output_dir,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import webbrowser
|
import webbrowser
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@ -120,6 +120,9 @@ def start(
|
|||||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||||
device: Device = Option(Device.cpu),
|
device: Device = Option(Device.cpu),
|
||||||
input: Optional[Path] = Option(None, help=INPUT_HELP),
|
input: Optional[Path] = Option(None, help=INPUT_HELP),
|
||||||
|
mask_dir: Optional[Path] = Option(
|
||||||
|
None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||||
|
),
|
||||||
output_dir: Optional[Path] = Option(
|
output_dir: Optional[Path] = Option(
|
||||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||||
),
|
),
|
||||||
@ -145,8 +148,11 @@ def start(
|
|||||||
if input and not input.exists():
|
if input and not input.exists():
|
||||||
logger.error(f"invalid --input: {input} not exists")
|
logger.error(f"invalid --input: {input} not exists")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
if mask_dir and not mask_dir.exists():
|
||||||
|
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
|
||||||
|
exit(-1)
|
||||||
if input and input.is_dir() and not output_dir:
|
if input and input.is_dir() and not output_dir:
|
||||||
logger.error(f"invalid --output-dir: must be set when --input is a directory")
|
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
if output_dir:
|
if output_dir:
|
||||||
output_dir = output_dir.expanduser().absolute()
|
output_dir = output_dir.expanduser().absolute()
|
||||||
@ -154,6 +160,8 @@ def start(
|
|||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
logger.info(f"Create output directory {output_dir}")
|
logger.info(f"Create output directory {output_dir}")
|
||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
|
if mask_dir:
|
||||||
|
mask_dir = mask_dir.expanduser().absolute()
|
||||||
|
|
||||||
model_dir = model_dir.expanduser().absolute()
|
model_dir = model_dir.expanduser().absolute()
|
||||||
|
|
||||||
@ -192,6 +200,7 @@ def start(
|
|||||||
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
||||||
device=device,
|
device=device,
|
||||||
input=input,
|
input=input,
|
||||||
|
mask_dir=mask_dir,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
enable_interactive_seg=enable_interactive_seg,
|
enable_interactive_seg=enable_interactive_seg,
|
||||||
|
@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
|||||||
|
|
||||||
SD_BRUSHNET_CHOICES: List[str] = [
|
SD_BRUSHNET_CHOICES: List[str] = [
|
||||||
"Sanster/brushnet_random_mask",
|
"Sanster/brushnet_random_mask",
|
||||||
"Sanster/brushnet_segmentation_mask"
|
"Sanster/brushnet_segmentation_mask",
|
||||||
]
|
]
|
||||||
|
|
||||||
SD2_CONTROLNET_CHOICES = [
|
SD2_CONTROLNET_CHOICES = [
|
||||||
@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """
|
|||||||
Result images will be saved to output directory automatically.
|
Result images will be saved to output directory automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
MASK_DIR_HELP = """
|
||||||
|
You can view masks in FileManager
|
||||||
|
"""
|
||||||
|
|
||||||
INPUT_HELP = """
|
INPUT_HELP = """
|
||||||
If input is image, it will be loaded by default.
|
If input is image, it will be loaded by default.
|
||||||
If input is directory, you can browse and select image in file manager.
|
If input is directory, you can browse and select image in file manager.
|
||||||
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from PIL import Image, ImageOps, PngImagePlugin
|
from PIL import Image, ImageOps, PngImagePlugin
|
||||||
from fastapi import FastAPI, UploadFile, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from starlette.responses import FileResponse
|
from starlette.responses import FileResponse
|
||||||
|
|
||||||
from ..schema import MediasResponse, MediaTab
|
from ..schema import MediasResponse, MediaTab
|
||||||
@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img
|
|||||||
|
|
||||||
|
|
||||||
class FileManager:
|
class FileManager:
|
||||||
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
|
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.input_dir: Path = input_dir
|
self.input_dir: Path = input_dir
|
||||||
|
self.mask_dir: Path = mask_dir
|
||||||
self.output_dir: Path = output_dir
|
self.output_dir: Path = output_dir
|
||||||
|
|
||||||
self.image_dir_filenames = []
|
self.image_dir_filenames = []
|
||||||
@ -63,6 +64,8 @@ class FileManager:
|
|||||||
return self.input_dir
|
return self.input_dir
|
||||||
elif tab == "output":
|
elif tab == "output":
|
||||||
return self.output_dir
|
return self.output_dir
|
||||||
|
elif tab == "mask":
|
||||||
|
return self.mask_dir
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
||||||
|
|
||||||
|
@ -244,6 +244,7 @@ class ApiConfig(BaseModel):
|
|||||||
cpu_textencoder: bool
|
cpu_textencoder: bool
|
||||||
device: Device
|
device: Device
|
||||||
input: Optional[Path]
|
input: Optional[Path]
|
||||||
|
mask_dir: Optional[Path]
|
||||||
output_dir: Optional[Path]
|
output_dir: Optional[Path]
|
||||||
quality: int
|
quality: int
|
||||||
enable_interactive_seg: bool
|
enable_interactive_seg: bool
|
||||||
@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel):
|
|||||||
scale: float = Field(2.0, description="Scale for upscaling")
|
scale: float = Field(2.0, description="Scale for upscaling")
|
||||||
|
|
||||||
|
|
||||||
MediaTab = Literal["input", "output"]
|
MediaTab = Literal["input", "output", "mask"]
|
||||||
|
|
||||||
|
|
||||||
class MediasResponse(BaseModel):
|
class MediasResponse(BaseModel):
|
||||||
|
@ -3,10 +3,11 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type('text/css', '.css')
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
from iopaint.schema import (
|
from iopaint.schema import (
|
||||||
Device,
|
Device,
|
||||||
@ -90,6 +91,7 @@ def save_config(
|
|||||||
cpu_textencoder,
|
cpu_textencoder,
|
||||||
device,
|
device,
|
||||||
input,
|
input,
|
||||||
|
mask_dir,
|
||||||
output_dir,
|
output_dir,
|
||||||
quality,
|
quality,
|
||||||
enable_interactive_seg,
|
enable_interactive_seg,
|
||||||
@ -112,6 +114,8 @@ def save_config(
|
|||||||
config.input = None
|
config.input = None
|
||||||
if str(config.output_dir) == ".":
|
if str(config.output_dir) == ".":
|
||||||
config.output_dir = None
|
config.output_dir = None
|
||||||
|
if str(config.mask_dir) == ".":
|
||||||
|
config.mask_dir = None
|
||||||
config.model = config.model.strip()
|
config.model = config.model.strip()
|
||||||
print(config.model_dump_json(indent=4))
|
print(config.model_dump_json(indent=4))
|
||||||
if config.input and not os.path.exists(config.input):
|
if config.input and not os.path.exists(config.input):
|
||||||
@ -207,6 +211,10 @@ def main(config_file: Path):
|
|||||||
init_config.output_dir,
|
init_config.output_dir,
|
||||||
label=f"Output directory. {OUTPUT_DIR_HELP}",
|
label=f"Output directory. {OUTPUT_DIR_HELP}",
|
||||||
)
|
)
|
||||||
|
mask_dir = gr.Textbox(
|
||||||
|
init_config.mask_dir,
|
||||||
|
label=f"Mask directory. {MASK_DIR_HELP}",
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab("Plugins"):
|
with gr.Tab("Plugins"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -288,6 +296,7 @@ def main(config_file: Path):
|
|||||||
cpu_textencoder,
|
cpu_textencoder,
|
||||||
device,
|
device,
|
||||||
input,
|
input,
|
||||||
|
mask_dir,
|
||||||
output_dir,
|
output_dir,
|
||||||
quality,
|
quality,
|
||||||
enable_interactive_seg,
|
enable_interactive_seg,
|
||||||
|
@ -50,6 +50,7 @@ const SORT_BY_MODIFIED_TIME = "Modified time"
|
|||||||
|
|
||||||
const IMAGE_TAB = "input"
|
const IMAGE_TAB = "input"
|
||||||
const OUTPUT_TAB = "output"
|
const OUTPUT_TAB = "output"
|
||||||
|
export const MASK_TAB = "mask"
|
||||||
|
|
||||||
const SortByMap = {
|
const SortByMap = {
|
||||||
[SortBy.NAME]: SORT_BY_NAME,
|
[SortBy.NAME]: SORT_BY_NAME,
|
||||||
@ -264,6 +265,7 @@ export default function FileManager(props: Props) {
|
|||||||
<TabsList aria-label="Manage your account">
|
<TabsList aria-label="Manage your account">
|
||||||
<TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger>
|
<TabsTrigger value={IMAGE_TAB}>Image Directory</TabsTrigger>
|
||||||
<TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger>
|
<TabsTrigger value={OUTPUT_TAB}>Output Directory</TabsTrigger>
|
||||||
|
<TabsTrigger value={MASK_TAB}>Mask Directory</TabsTrigger>
|
||||||
</TabsList>
|
</TabsList>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ import { useImage } from "@/hooks/useImage"
|
|||||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
|
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"
|
||||||
import PromptInput from "./PromptInput"
|
import PromptInput from "./PromptInput"
|
||||||
import { RotateCw, Image, Upload } from "lucide-react"
|
import { RotateCw, Image, Upload } from "lucide-react"
|
||||||
import FileManager from "./FileManager"
|
import FileManager, { MASK_TAB } from "./FileManager"
|
||||||
import { getMediaFile } from "@/lib/api"
|
import { getMediaBlob, getMediaFile } from "@/lib/api"
|
||||||
import { useStore } from "@/lib/states"
|
import { useStore } from "@/lib/states"
|
||||||
import SettingsDialog from "./Settings"
|
import SettingsDialog from "./Settings"
|
||||||
import { cn, fileToImage } from "@/lib/utils"
|
import { cn, fileToImage } from "@/lib/utils"
|
||||||
@ -31,6 +31,7 @@ const Header = () => {
|
|||||||
hidePrevMask,
|
hidePrevMask,
|
||||||
imageHeight,
|
imageHeight,
|
||||||
imageWidth,
|
imageWidth,
|
||||||
|
handleFileManagerMaskSelect,
|
||||||
] = useStore((state) => [
|
] = useStore((state) => [
|
||||||
state.file,
|
state.file,
|
||||||
state.customMask,
|
state.customMask,
|
||||||
@ -46,6 +47,7 @@ const Header = () => {
|
|||||||
state.hidePrevMask,
|
state.hidePrevMask,
|
||||||
state.imageHeight,
|
state.imageHeight,
|
||||||
state.imageWidth,
|
state.imageWidth,
|
||||||
|
state.handleFileManagerMaskSelect,
|
||||||
])
|
])
|
||||||
|
|
||||||
const { toast } = useToast()
|
const { toast } = useToast()
|
||||||
@ -64,16 +66,15 @@ const Header = () => {
|
|||||||
hidePrevMask()
|
hidePrevMask()
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
const handleOnPhotoClick = async (tab: string, filename: string) => {
|
||||||
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
|
|
||||||
<div className="flex items-center gap-1">
|
|
||||||
{serverConfig.enableFileManager ? (
|
|
||||||
<FileManager
|
|
||||||
photoWidth={512}
|
|
||||||
onPhotoClick={async (tab: string, filename: string) => {
|
|
||||||
try {
|
try {
|
||||||
|
if (tab === MASK_TAB) {
|
||||||
|
const maskBlob = await getMediaBlob(tab, filename)
|
||||||
|
handleFileManagerMaskSelect(maskBlob)
|
||||||
|
} else {
|
||||||
const newFile = await getMediaFile(tab, filename)
|
const newFile = await getMediaFile(tab, filename)
|
||||||
setFile(newFile)
|
setFile(newFile)
|
||||||
|
}
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
toast({
|
toast({
|
||||||
variant: "destructive",
|
variant: "destructive",
|
||||||
@ -81,8 +82,13 @@ const Header = () => {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}}
|
}
|
||||||
/>
|
|
||||||
|
return (
|
||||||
|
<header className="h-[60px] px-6 py-4 absolute top-[0] flex justify-between items-center w-full z-20 border-b backdrop-filter backdrop-blur-md bg-background/70">
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
{serverConfig.enableFileManager ? (
|
||||||
|
<FileManager photoWidth={512} onPhotoClick={handleOnPhotoClick} />
|
||||||
) : (
|
) : (
|
||||||
<></>
|
<></>
|
||||||
)}
|
)}
|
||||||
|
@ -167,6 +167,23 @@ export async function getMediaFile(tab: string, filename: string) {
|
|||||||
throw new Error(errMsg.errors)
|
throw new Error(errMsg.errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getMediaBlob(tab: string, filename: string) {
|
||||||
|
const res = await fetch(
|
||||||
|
`${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent(
|
||||||
|
filename
|
||||||
|
)}`,
|
||||||
|
{
|
||||||
|
method: "GET",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if (res.ok) {
|
||||||
|
const blob = await res.blob()
|
||||||
|
return blob
|
||||||
|
}
|
||||||
|
const errMsg = await res.json()
|
||||||
|
throw new Error(errMsg.errors)
|
||||||
|
}
|
||||||
|
|
||||||
export async function getMedias(tab: string): Promise<Filename[]> {
|
export async function getMedias(tab: string): Promise<Filename[]> {
|
||||||
const res = await api.get(`medias`, { params: { tab } })
|
const res = await api.get(`medias`, { params: { tab } })
|
||||||
return res.data
|
return res.data
|
||||||
|
@ -207,6 +207,7 @@ type AppAction = {
|
|||||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||||
resetInteractiveSegState: () => void
|
resetInteractiveSegState: () => void
|
||||||
handleInteractiveSegAccept: () => void
|
handleInteractiveSegAccept: () => void
|
||||||
|
handleFileManagerMaskSelect: (blob: Blob) => Promise<void>
|
||||||
showPromptInput: () => boolean
|
showPromptInput: () => boolean
|
||||||
|
|
||||||
runInpainting: () => Promise<void>
|
runInpainting: () => Promise<void>
|
||||||
@ -903,6 +904,16 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
||||||
|
handleFileManagerMaskSelect: async (blob: Blob) => {
|
||||||
|
const newMask = new Image()
|
||||||
|
|
||||||
|
await loadImage(newMask, URL.createObjectURL(blob))
|
||||||
|
set((state) => {
|
||||||
|
state.editorState.extraMasks.push(castDraft(newMask))
|
||||||
|
})
|
||||||
|
get().runInpainting()
|
||||||
|
},
|
||||||
|
|
||||||
setIsInpainting: (newValue: boolean) =>
|
setIsInpainting: (newValue: boolean) =>
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.isInpainting = newValue
|
state.isInpainting = newValue
|
||||||
|
Loading…
Reference in New Issue
Block a user