switch to FastAPI

This commit is contained in:
Qing 2024-01-01 16:05:34 +08:00
parent c4abda3942
commit 79a41454f6
8 changed files with 54 additions and 256 deletions

View File

@ -175,6 +175,7 @@ class Api:
def api_inpaint(self, req: InpaintRequest): def api_inpaint(self, req: InpaintRequest):
image, alpha_channel, infos = decode_base64_to_image(req.image) image, alpha_channel, infos = decode_base64_to_image(req.image)
mask, _, _ = decode_base64_to_image(req.mask, gray=True) mask, _, _ = decode_base64_to_image(req.mask, gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]: if image.shape[:2] != mask.shape[:2]:
raise HTTPException( raise HTTPException(

View File

@ -7,13 +7,7 @@ from PIL import Image, ImageOps, PngImagePlugin
from fastapi import FastAPI, UploadFile, HTTPException from fastapi import FastAPI, UploadFile, HTTPException
from starlette.responses import FileResponse from starlette.responses import FileResponse
from ..schema import ( from ..schema import MediasResponse, MediaTab
MediasResponse,
MediasRequest,
MediaFileRequest,
MediaTab,
MediaThumbnailFileRequest,
)
LARGE_ENOUGH_NUMBER = 100 LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
@ -34,9 +28,9 @@ class FileManager:
# fmt: off # fmt: off
self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse]) self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None) self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None) self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
# fmt: on # fmt: on
def api_save_image(self, file: UploadFile): def api_save_image(self, file: UploadFile):
@ -45,18 +39,21 @@ class FileManager:
with open(self.output_dir / filename, "wb") as fw: with open(self.output_dir / filename, "wb") as fw:
fw.write(origin_image_bytes) fw.write(origin_image_bytes)
def api_medias(self, req: MediasRequest) -> List[MediasResponse]: def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
img_dir = self._get_dir(req.tab) img_dir = self._get_dir(tab)
return self._media_names(img_dir) return self._media_names(img_dir)
def api_media_file(self, req: MediaFileRequest) -> FileResponse: def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
file_path = self._get_file(req.tab, req.filename) file_path = self._get_file(tab, filename)
return FileResponse(file_path) return FileResponse(file_path, media_type="image/png")
def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse: # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
img_dir = self._get_dir(req.tab) def api_media_thumbnail_file(
self, tab: MediaTab, filename: str, width: int, height: int
) -> FileResponse:
img_dir = self._get_dir(tab)
thumb_filename, (width, height) = self.get_thumbnail( thumb_filename, (width, height) = self.get_thumbnail(
img_dir, req.filename, width=req.width, height=req.height img_dir, filename, width=width, height=height
) )
thumbnail_filepath = self.thumbnail_directory / thumb_filename thumbnail_filepath = self.thumbnail_directory / thumb_filename
return FileResponse( return FileResponse(
@ -65,6 +62,7 @@ class FileManager:
"X-Width": str(width), "X-Width": str(width),
"X-Height": str(height), "X-Height": str(height),
}, },
media_type="image/jpeg",
) )
def _get_dir(self, tab: MediaTab) -> Path: def _get_dir(self, tab: MediaTab) -> Path:

View File

@ -236,10 +236,6 @@ class RunPluginRequest(BaseModel):
MediaTab = Literal["input", "output"] MediaTab = Literal["input", "output"]
class MediasRequest(BaseModel):
tab: MediaTab
class MediasResponse(BaseModel): class MediasResponse(BaseModel):
name: str name: str
height: int height: int
@ -248,18 +244,6 @@ class MediasResponse(BaseModel):
mtime: float mtime: float
class MediaFileRequest(BaseModel):
tab: MediaTab
filename: str
class MediaThumbnailFileRequest(BaseModel):
tab: MediaTab
filename: str
width: int = 0
height: int = 0
class GenInfoResponse(BaseModel): class GenInfoResponse(BaseModel):
prompt: str = "" prompt: str = ""
negative_prompt: str = "" negative_prompt: str = ""

View File

@ -1,186 +0,0 @@
#!/usr/bin/env python3
import multiprocessing
import os
import cv2
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
NUM_THREADS = str(multiprocessing.cpu_count())
cv2.setNumThreads(NUM_THREADS)
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
import hashlib
import traceback
from dataclasses import dataclass
import io
import random
import time
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import (
InteractiveSeg,
RemoveBG,
AnimeSeg,
build_plugins,
)
from lama_cleaner.schema import InpaintRequest
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
is_mac,
get_image_ext, concat_alpha_channel,
)
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
except:
pass
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
global_config = GlobalConfig()
def diffuser_callback(i, t, latents):
socketio.emit("diffusion_progress", {"step": i})
def start(
host: str,
port: int,
model: str,
no_half: bool,
cpu_offload: bool,
disable_nsfw_checker,
cpu_textencoder: bool,
device: Device,
gui: bool,
disable_model_switch: bool,
input: Path,
output_dir: Path,
quality: int,
enable_interactive_seg: bool,
interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device,
enable_remove_bg: bool,
enable_anime_seg: bool,
enable_realesrgan: bool,
realesrgan_device: Device,
realesrgan_model: RealESRGANModel,
enable_gfpgan: bool,
gfpgan_device: Device,
enable_restoreformer: bool,
restoreformer_device: Device,
):
if input:
if not input.exists():
logger.error(f"invalid --input: {input} not exists")
exit()
if input.is_dir():
logger.info(f"Initialize file manager")
file_manager = FileManager(app)
app.config["THUMBNAIL_MEDIA_ROOT"] = input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
output_dir, "lama_cleaner_thumbnails"
)
file_manager.output_dir = output_dir
global_config.file_manager = file_manager
else:
global_config.input_image_path = input
global_config.image_quality = quality
global_config.disable_model_switch = disable_model_switch
global_config.is_desktop = gui
build_plugins(
global_config,
enable_interactive_seg,
interactive_seg_model,
interactive_seg_device,
enable_remove_bg,
enable_anime_seg,
enable_realesrgan,
realesrgan_device,
realesrgan_model,
enable_gfpgan,
gfpgan_device,
enable_restoreformer,
restoreformer_device,
no_half,
)
if output_dir:
output_dir = output_dir.expanduser().absolute()
logger.info(f"Image will auto save to output dir: {output_dir}")
if not output_dir.exists():
logger.info(f"Create output dir: {output_dir}")
output_dir.mkdir(parents=True)
global_config.output_dir = output_dir
global_config.model_manager = ModelManager(
name=model,
device=torch.device(device),
no_half=no_half,
disable_nsfw=disable_nsfw_checker,
sd_cpu_textencoder=cpu_textencoder,
cpu_offload=cpu_offload,
callback=diffuser_callback,
)
if gui:
from flaskwebgui import FlaskUI
ui = FlaskUI(
app,
socketio=socketio,
width=1200,
height=800,
host=host,
port=port,
close_server_on_exit=True,
idle_interval=60,
)
ui.run()
else:
socketio.run(
app,
host=host,
port=port,
allow_unsafe_werkzeug=True,
)

View File

@ -48,7 +48,7 @@ const SORT_BY_NAME = "Name"
const SORT_BY_CREATED_TIME = "Created time" const SORT_BY_CREATED_TIME = "Created time"
const SORT_BY_MODIFIED_TIME = "Modified time" const SORT_BY_MODIFIED_TIME = "Modified time"
const IMAGE_TAB = "image" const IMAGE_TAB = "input"
const OUTPUT_TAB = "output" const OUTPUT_TAB = "output"
const SortByMap = { const SortByMap = {
@ -158,7 +158,9 @@ export default function FileManager(props: Props) {
const newPhotos = filteredFilenames.map((filename: Filename) => { const newPhotos = filteredFilenames.map((filename: Filename) => {
const width = photoWidth const width = photoWidth
const height = filename.height * (width / filename.width) const height = filename.height * (width / filename.width)
const src = `${API_ENDPOINT}/media_thumbnail/${tab}/${filename.name}?width=${width}&height=${height}` const src = `${API_ENDPOINT}/media_thumbnail_file?tab=${tab}&filename=${encodeURIComponent(
filename.name
)}&width=${Math.ceil(width)}&height=${Math.ceil(height)}`
return { src, height, width, name: filename.name } return { src, height, width, name: filename.name }
}) })
setPhotos(newPhotos) setPhotos(newPhotos)

View File

@ -71,8 +71,16 @@ const Header = () => {
<FileManager <FileManager
photoWidth={512} photoWidth={512}
onPhotoClick={async (tab: string, filename: string) => { onPhotoClick={async (tab: string, filename: string) => {
try {
const newFile = await getMediaFile(tab, filename) const newFile = await getMediaFile(tab, filename)
setFile(newFile) setFile(newFile)
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
return
}
}} }}
/> />
) : ( ) : (

View File

@ -115,19 +115,15 @@ export function SettingsDialog() {
updateAppState({ disableShortCuts: true }) updateAppState({ disableShortCuts: true })
switchModel(model.name) switchModel(model.name)
.then((res) => { .then((res) => {
if (res.ok) {
toast({ toast({
title: `Switch to ${model.name} success`, title: `Switch to ${model.name} success`,
}) })
setAppModel(model) setAppModel(model)
} else {
throw new Error("Server error")
}
}) })
.catch(() => { .catch((error: any) => {
toast({ toast({
variant: "destructive", variant: "destructive",
title: `Switch to ${model.name} failed`, title: `Switch to ${model.name} failed: ${error}`,
}) })
setModel(settings.model) setModel(settings.model)
}) })
@ -168,17 +164,21 @@ export function SettingsDialog() {
.filter((info) => model_types.includes(info.model_type)) .filter((info) => model_types.includes(info.model_type))
.map((info: ModelInfo) => { .map((info: ModelInfo) => {
return ( return (
<div key={info.name} onClick={() => onModelSelect(info)}> <div
key={info.name}
onClick={() => onModelSelect(info)}
className="px-2"
>
<div <div
className={cn([ className={cn([
info.name === model.name ? "bg-muted" : "hover:bg-muted", info.name === model.name ? "bg-muted" : "hover:bg-muted",
"rounded-md px-2 py-1 my-1", "rounded-md px-2 py-2",
"cursor-default", "cursor-default",
])} ])}
> >
<div className="text-base">{info.name}</div> <div className="text-base">{info.name}</div>
</div> </div>
<Separator /> <Separator className="my-1" />
</div> </div>
) )
}) })

View File

@ -1,7 +1,7 @@
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types" import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
import { Settings } from "@/lib/states" import { Settings } from "@/lib/states"
import { convertToBase64, srcToFile } from "@/lib/utils" import { convertToBase64, srcToFile } from "@/lib/utils"
import axios, { AxiosError } from "axios" import axios from "axios"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND ? import.meta.env.VITE_BACKEND
@ -93,12 +93,7 @@ export function getServerConfig() {
} }
export function switchModel(name: string) { export function switchModel(name: string) {
const fd = new FormData() return axios.post(`${API_ENDPOINT}/model`, { name })
fd.append("name", name)
return fetch(`${API_ENDPOINT}/model`, {
method: "POST",
body: fd,
})
} }
export function currentModel() { export function currentModel() {
@ -151,14 +146,18 @@ export async function runPlugin(
export async function getMediaFile(tab: string, filename: string) { export async function getMediaFile(tab: string, filename: string) {
const res = await fetch( const res = await fetch(
`${API_ENDPOINT}/media/${tab}/${encodeURIComponent(filename)}`, `${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent(
filename
)}`,
{ {
method: "GET", method: "GET",
} }
) )
if (res.ok) { if (res.ok) {
const blob = await res.blob() const blob = await res.blob()
const file = new File([blob], filename) const file = new File([blob], filename, {
type: res.headers.get("Content-Type") ?? "image/png",
})
return file return file
} }
const errMsg = await res.text() const errMsg = await res.text()
@ -166,15 +165,8 @@ export async function getMediaFile(tab: string, filename: string) {
} }
export async function getMedias(tab: string): Promise<Filename[]> { export async function getMedias(tab: string): Promise<Filename[]> {
const res = await fetch(`${API_ENDPOINT}/medias/${tab}`, { const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } })
method: "GET", return res.data
})
if (res.ok) {
const filenames = await res.json()
return filenames
}
const errMsg = await res.text()
throw new Error(errMsg)
} }
export async function downloadToOutput( export async function downloadToOutput(
@ -192,7 +184,6 @@ export async function downloadToOutput(
method: "POST", method: "POST",
body: fd, body: fd,
}) })
console.log(res.ok)
if (!res.ok) { if (!res.ok) {
const errMsg = await res.text() const errMsg = await res.text()
throw new Error(errMsg) throw new Error(errMsg)