switch to FastAPI
This commit is contained in:
parent
c4abda3942
commit
79a41454f6
@ -175,6 +175,7 @@ class Api:
|
||||
def api_inpaint(self, req: InpaintRequest):
|
||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
raise HTTPException(
|
||||
|
@ -7,13 +7,7 @@ from PIL import Image, ImageOps, PngImagePlugin
|
||||
from fastapi import FastAPI, UploadFile, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from ..schema import (
|
||||
MediasResponse,
|
||||
MediasRequest,
|
||||
MediaFileRequest,
|
||||
MediaTab,
|
||||
MediaThumbnailFileRequest,
|
||||
)
|
||||
from ..schema import MediasResponse, MediaTab
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
@ -34,9 +28,9 @@ class FileManager:
|
||||
|
||||
# fmt: off
|
||||
self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
||||
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse])
|
||||
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None)
|
||||
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None)
|
||||
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
|
||||
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
|
||||
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
|
||||
# fmt: on
|
||||
|
||||
def api_save_image(self, file: UploadFile):
|
||||
@ -45,18 +39,21 @@ class FileManager:
|
||||
with open(self.output_dir / filename, "wb") as fw:
|
||||
fw.write(origin_image_bytes)
|
||||
|
||||
def api_medias(self, req: MediasRequest) -> List[MediasResponse]:
|
||||
img_dir = self._get_dir(req.tab)
|
||||
def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
|
||||
img_dir = self._get_dir(tab)
|
||||
return self._media_names(img_dir)
|
||||
|
||||
def api_media_file(self, req: MediaFileRequest) -> FileResponse:
|
||||
file_path = self._get_file(req.tab, req.filename)
|
||||
return FileResponse(file_path)
|
||||
def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
|
||||
file_path = self._get_file(tab, filename)
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse:
|
||||
img_dir = self._get_dir(req.tab)
|
||||
# tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
|
||||
def api_media_thumbnail_file(
|
||||
self, tab: MediaTab, filename: str, width: int, height: int
|
||||
) -> FileResponse:
|
||||
img_dir = self._get_dir(tab)
|
||||
thumb_filename, (width, height) = self.get_thumbnail(
|
||||
img_dir, req.filename, width=req.width, height=req.height
|
||||
img_dir, filename, width=width, height=height
|
||||
)
|
||||
thumbnail_filepath = self.thumbnail_directory / thumb_filename
|
||||
return FileResponse(
|
||||
@ -65,6 +62,7 @@ class FileManager:
|
||||
"X-Width": str(width),
|
||||
"X-Height": str(height),
|
||||
},
|
||||
media_type="image/jpeg",
|
||||
)
|
||||
|
||||
def _get_dir(self, tab: MediaTab) -> Path:
|
||||
|
@ -236,10 +236,6 @@ class RunPluginRequest(BaseModel):
|
||||
MediaTab = Literal["input", "output"]
|
||||
|
||||
|
||||
class MediasRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
|
||||
|
||||
class MediasResponse(BaseModel):
|
||||
name: str
|
||||
height: int
|
||||
@ -248,18 +244,6 @@ class MediasResponse(BaseModel):
|
||||
mtime: float
|
||||
|
||||
|
||||
class MediaFileRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
filename: str
|
||||
|
||||
|
||||
class MediaThumbnailFileRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
filename: str
|
||||
width: int = 0
|
||||
height: int = 0
|
||||
|
||||
|
||||
class GenInfoResponse(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
|
@ -1,186 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import cv2
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
cv2.setNumThreads(NUM_THREADS)
|
||||
|
||||
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
|
||||
import hashlib
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
|
||||
import io
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from lama_cleaner.const import *
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import (
|
||||
InteractiveSeg,
|
||||
RemoveBG,
|
||||
AnimeSeg,
|
||||
build_plugins,
|
||||
)
|
||||
from lama_cleaner.schema import InpaintRequest
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
pil_to_bytes,
|
||||
is_mac,
|
||||
get_image_ext, concat_alpha_channel,
|
||||
)
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
def diffuser_callback(i, t, latents):
|
||||
socketio.emit("diffusion_progress", {"step": i})
|
||||
|
||||
|
||||
def start(
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
no_half: bool,
|
||||
cpu_offload: bool,
|
||||
disable_nsfw_checker,
|
||||
cpu_textencoder: bool,
|
||||
device: Device,
|
||||
gui: bool,
|
||||
disable_model_switch: bool,
|
||||
input: Path,
|
||||
output_dir: Path,
|
||||
quality: int,
|
||||
enable_interactive_seg: bool,
|
||||
interactive_seg_model: InteractiveSegModel,
|
||||
interactive_seg_device: Device,
|
||||
enable_remove_bg: bool,
|
||||
enable_anime_seg: bool,
|
||||
enable_realesrgan: bool,
|
||||
realesrgan_device: Device,
|
||||
realesrgan_model: RealESRGANModel,
|
||||
enable_gfpgan: bool,
|
||||
gfpgan_device: Device,
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
):
|
||||
if input:
|
||||
if not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit()
|
||||
if input.is_dir():
|
||||
logger.info(f"Initialize file manager")
|
||||
file_manager = FileManager(app)
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
file_manager.output_dir = output_dir
|
||||
global_config.file_manager = file_manager
|
||||
else:
|
||||
global_config.input_image_path = input
|
||||
|
||||
global_config.image_quality = quality
|
||||
global_config.disable_model_switch = disable_model_switch
|
||||
global_config.is_desktop = gui
|
||||
build_plugins(
|
||||
global_config,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
no_half,
|
||||
)
|
||||
if output_dir:
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
logger.info(f"Image will auto save to output dir: {output_dir}")
|
||||
if not output_dir.exists():
|
||||
logger.info(f"Create output dir: {output_dir}")
|
||||
output_dir.mkdir(parents=True)
|
||||
global_config.output_dir = output_dir
|
||||
|
||||
global_config.model_manager = ModelManager(
|
||||
name=model,
|
||||
device=torch.device(device),
|
||||
no_half=no_half,
|
||||
disable_nsfw=disable_nsfw_checker,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
cpu_offload=cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
if gui:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
ui = FlaskUI(
|
||||
app,
|
||||
socketio=socketio,
|
||||
width=1200,
|
||||
height=800,
|
||||
host=host,
|
||||
port=port,
|
||||
close_server_on_exit=True,
|
||||
idle_interval=60,
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
socketio.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
allow_unsafe_werkzeug=True,
|
||||
)
|
@ -48,7 +48,7 @@ const SORT_BY_NAME = "Name"
|
||||
const SORT_BY_CREATED_TIME = "Created time"
|
||||
const SORT_BY_MODIFIED_TIME = "Modified time"
|
||||
|
||||
const IMAGE_TAB = "image"
|
||||
const IMAGE_TAB = "input"
|
||||
const OUTPUT_TAB = "output"
|
||||
|
||||
const SortByMap = {
|
||||
@ -158,7 +158,9 @@ export default function FileManager(props: Props) {
|
||||
const newPhotos = filteredFilenames.map((filename: Filename) => {
|
||||
const width = photoWidth
|
||||
const height = filename.height * (width / filename.width)
|
||||
const src = `${API_ENDPOINT}/media_thumbnail/${tab}/${filename.name}?width=${width}&height=${height}`
|
||||
const src = `${API_ENDPOINT}/media_thumbnail_file?tab=${tab}&filename=${encodeURIComponent(
|
||||
filename.name
|
||||
)}&width=${Math.ceil(width)}&height=${Math.ceil(height)}`
|
||||
return { src, height, width, name: filename.name }
|
||||
})
|
||||
setPhotos(newPhotos)
|
||||
|
@ -71,8 +71,16 @@ const Header = () => {
|
||||
<FileManager
|
||||
photoWidth={512}
|
||||
onPhotoClick={async (tab: string, filename: string) => {
|
||||
try {
|
||||
const newFile = await getMediaFile(tab, filename)
|
||||
setFile(newFile)
|
||||
} catch (e: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
description: e.message ? e.message : e.toString(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
|
@ -115,19 +115,15 @@ export function SettingsDialog() {
|
||||
updateAppState({ disableShortCuts: true })
|
||||
switchModel(model.name)
|
||||
.then((res) => {
|
||||
if (res.ok) {
|
||||
toast({
|
||||
title: `Switch to ${model.name} success`,
|
||||
})
|
||||
setAppModel(model)
|
||||
} else {
|
||||
throw new Error("Server error")
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
.catch((error: any) => {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch to ${model.name} failed`,
|
||||
title: `Switch to ${model.name} failed: ${error}`,
|
||||
})
|
||||
setModel(settings.model)
|
||||
})
|
||||
@ -168,17 +164,21 @@ export function SettingsDialog() {
|
||||
.filter((info) => model_types.includes(info.model_type))
|
||||
.map((info: ModelInfo) => {
|
||||
return (
|
||||
<div key={info.name} onClick={() => onModelSelect(info)}>
|
||||
<div
|
||||
key={info.name}
|
||||
onClick={() => onModelSelect(info)}
|
||||
className="px-2"
|
||||
>
|
||||
<div
|
||||
className={cn([
|
||||
info.name === model.name ? "bg-muted" : "hover:bg-muted",
|
||||
"rounded-md px-2 py-1 my-1",
|
||||
"rounded-md px-2 py-2",
|
||||
"cursor-default",
|
||||
])}
|
||||
>
|
||||
<div className="text-base">{info.name}</div>
|
||||
</div>
|
||||
<Separator />
|
||||
<Separator className="my-1" />
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
|
||||
import { Settings } from "@/lib/states"
|
||||
import { convertToBase64, srcToFile } from "@/lib/utils"
|
||||
import axios, { AxiosError } from "axios"
|
||||
import axios from "axios"
|
||||
|
||||
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
|
||||
? import.meta.env.VITE_BACKEND
|
||||
@ -93,12 +93,7 @@ export function getServerConfig() {
|
||||
}
|
||||
|
||||
export function switchModel(name: string) {
|
||||
const fd = new FormData()
|
||||
fd.append("name", name)
|
||||
return fetch(`${API_ENDPOINT}/model`, {
|
||||
method: "POST",
|
||||
body: fd,
|
||||
})
|
||||
return axios.post(`${API_ENDPOINT}/model`, { name })
|
||||
}
|
||||
|
||||
export function currentModel() {
|
||||
@ -151,14 +146,18 @@ export async function runPlugin(
|
||||
|
||||
export async function getMediaFile(tab: string, filename: string) {
|
||||
const res = await fetch(
|
||||
`${API_ENDPOINT}/media/${tab}/${encodeURIComponent(filename)}`,
|
||||
`${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent(
|
||||
filename
|
||||
)}`,
|
||||
{
|
||||
method: "GET",
|
||||
}
|
||||
)
|
||||
if (res.ok) {
|
||||
const blob = await res.blob()
|
||||
const file = new File([blob], filename)
|
||||
const file = new File([blob], filename, {
|
||||
type: res.headers.get("Content-Type") ?? "image/png",
|
||||
})
|
||||
return file
|
||||
}
|
||||
const errMsg = await res.text()
|
||||
@ -166,15 +165,8 @@ export async function getMediaFile(tab: string, filename: string) {
|
||||
}
|
||||
|
||||
export async function getMedias(tab: string): Promise<Filename[]> {
|
||||
const res = await fetch(`${API_ENDPOINT}/medias/${tab}`, {
|
||||
method: "GET",
|
||||
})
|
||||
if (res.ok) {
|
||||
const filenames = await res.json()
|
||||
return filenames
|
||||
}
|
||||
const errMsg = await res.text()
|
||||
throw new Error(errMsg)
|
||||
const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } })
|
||||
return res.data
|
||||
}
|
||||
|
||||
export async function downloadToOutput(
|
||||
@ -192,7 +184,6 @@ export async function downloadToOutput(
|
||||
method: "POST",
|
||||
body: fd,
|
||||
})
|
||||
console.log(res.ok)
|
||||
if (!res.ok) {
|
||||
const errMsg = await res.text()
|
||||
throw new Error(errMsg)
|
||||
|
Loading…
Reference in New Issue
Block a user