diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index d5a1ae9..5afb0b2 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -175,6 +175,7 @@ class Api: def api_inpaint(self, req: InpaintRequest): image, alpha_channel, infos = decode_base64_to_image(req.image) mask, _, _ = decode_base64_to_image(req.mask, gray=True) + mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] if image.shape[:2] != mask.shape[:2]: raise HTTPException( diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py index 6a38cfa..cb33278 100644 --- a/lama_cleaner/file_manager/file_manager.py +++ b/lama_cleaner/file_manager/file_manager.py @@ -7,13 +7,7 @@ from PIL import Image, ImageOps, PngImagePlugin from fastapi import FastAPI, UploadFile, HTTPException from starlette.responses import FileResponse -from ..schema import ( - MediasResponse, - MediasRequest, - MediaFileRequest, - MediaTab, - MediaThumbnailFileRequest, -) +from ..schema import MediasResponse, MediaTab LARGE_ENOUGH_NUMBER = 100 PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) @@ -34,9 +28,9 @@ class FileManager: # fmt: off self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) - self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse]) - self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None) - self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None) + self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse]) + self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"]) + self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"]) # fmt: on def api_save_image(self, file: UploadFile): @@ -45,18 +39,21 @@ class FileManager: with open(self.output_dir / filename, "wb") as fw: fw.write(origin_image_bytes) - def api_medias(self, req: MediasRequest) -> List[MediasResponse]: - img_dir = self._get_dir(req.tab) + def api_medias(self, tab: MediaTab) -> List[MediasResponse]: + img_dir = self._get_dir(tab) return self._media_names(img_dir) - def api_media_file(self, req: MediaFileRequest) -> FileResponse: - file_path = self._get_file(req.tab, req.filename) - return FileResponse(file_path) + def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse: + file_path = self._get_file(tab, filename) + return FileResponse(file_path, media_type="image/png") - def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse: - img_dir = self._get_dir(req.tab) + # tab=${tab}?filename=${filename.name}?width=${width}&height=${height} + def api_media_thumbnail_file( + self, tab: MediaTab, filename: str, width: int, height: int + ) -> FileResponse: + img_dir = self._get_dir(tab) thumb_filename, (width, height) = self.get_thumbnail( - img_dir, req.filename, width=req.width, height=req.height + img_dir, filename, width=width, height=height ) thumbnail_filepath = self.thumbnail_directory / thumb_filename return FileResponse( @@ -65,6 +62,7 @@ class FileManager: "X-Width": str(width), "X-Height": str(height), }, + media_type="image/jpeg", ) def _get_dir(self, tab: MediaTab) -> Path: diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index acdee85..0cedefd 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -236,10 +236,6 @@ class RunPluginRequest(BaseModel): MediaTab = Literal["input", "output"] -class MediasRequest(BaseModel): - tab: MediaTab - - class MediasResponse(BaseModel): name: str height: int @@ -248,18 +244,6 @@ class MediasResponse(BaseModel): mtime: float -class MediaFileRequest(BaseModel): - tab: MediaTab - filename: str - - -class MediaThumbnailFileRequest(BaseModel): - tab: MediaTab - filename: str - width: int = 0 - height: int = 0 - - class GenInfoResponse(BaseModel): prompt: str = "" negative_prompt: str = "" diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py deleted file mode 100644 index 2c1d6ed..0000000 --- a/lama_cleaner/server.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python3 -import multiprocessing -import os - -import cv2 - -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - -NUM_THREADS = str(multiprocessing.cpu_count()) -cv2.setNumThreads(NUM_THREADS) - -# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" - -os.environ["OMP_NUM_THREADS"] = NUM_THREADS -os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS -os.environ["MKL_NUM_THREADS"] = NUM_THREADS -os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS -os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS - -import hashlib -import traceback -from dataclasses import dataclass - -import io -import random -import time -from pathlib import Path - -import cv2 -import numpy as np -import torch -from PIL import Image -from loguru import logger - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from fastapi.responses import FileResponse - -from lama_cleaner.const import * -from lama_cleaner.file_manager import FileManager -from lama_cleaner.model.utils import torch_gc -from lama_cleaner.model_manager import ModelManager -from lama_cleaner.plugins import ( - InteractiveSeg, - RemoveBG, - AnimeSeg, - build_plugins, -) -from lama_cleaner.schema import InpaintRequest -from lama_cleaner.helper import ( - load_img, - numpy_to_bytes, - resize_max_size, - pil_to_bytes, - is_mac, - get_image_ext, concat_alpha_channel, -) - -try: - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(False) -except: - pass - - -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") - -global_config = GlobalConfig() - -def diffuser_callback(i, t, latents): - socketio.emit("diffusion_progress", {"step": i}) - - -def start( - host: str, - port: int, - model: str, - no_half: bool, - cpu_offload: bool, - disable_nsfw_checker, - cpu_textencoder: bool, - device: Device, - gui: bool, - disable_model_switch: bool, - input: Path, - output_dir: Path, - quality: int, - enable_interactive_seg: bool, - interactive_seg_model: InteractiveSegModel, - interactive_seg_device: Device, - enable_remove_bg: bool, - enable_anime_seg: bool, - enable_realesrgan: bool, - realesrgan_device: Device, - realesrgan_model: RealESRGANModel, - enable_gfpgan: bool, - gfpgan_device: Device, - enable_restoreformer: bool, - restoreformer_device: Device, -): - if input: - if not input.exists(): - logger.error(f"invalid --input: {input} not exists") - exit() - if input.is_dir(): - logger.info(f"Initialize file manager") - file_manager = FileManager(app) - app.config["THUMBNAIL_MEDIA_ROOT"] = input - app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join( - output_dir, "lama_cleaner_thumbnails" - ) - file_manager.output_dir = output_dir - global_config.file_manager = file_manager - else: - global_config.input_image_path = input - - global_config.image_quality = quality - global_config.disable_model_switch = disable_model_switch - global_config.is_desktop = gui - build_plugins( - global_config, - enable_interactive_seg, - interactive_seg_model, - interactive_seg_device, - enable_remove_bg, - enable_anime_seg, - enable_realesrgan, - realesrgan_device, - realesrgan_model, - enable_gfpgan, - gfpgan_device, - enable_restoreformer, - restoreformer_device, - no_half, - ) - if output_dir: - output_dir = output_dir.expanduser().absolute() - logger.info(f"Image will auto save to output dir: {output_dir}") - if not output_dir.exists(): - logger.info(f"Create output dir: {output_dir}") - output_dir.mkdir(parents=True) - global_config.output_dir = output_dir - - global_config.model_manager = ModelManager( - name=model, - device=torch.device(device), - no_half=no_half, - disable_nsfw=disable_nsfw_checker, - sd_cpu_textencoder=cpu_textencoder, - cpu_offload=cpu_offload, - callback=diffuser_callback, - ) - - if gui: - from flaskwebgui import FlaskUI - - ui = FlaskUI( - app, - socketio=socketio, - width=1200, - height=800, - host=host, - port=port, - close_server_on_exit=True, - idle_interval=60, - ) - ui.run() - else: - socketio.run( - app, - host=host, - port=port, - allow_unsafe_werkzeug=True, - ) diff --git a/web_app/src/components/FileManager.tsx b/web_app/src/components/FileManager.tsx index 54a634c..68f7fb7 100644 --- a/web_app/src/components/FileManager.tsx +++ b/web_app/src/components/FileManager.tsx @@ -48,7 +48,7 @@ const SORT_BY_NAME = "Name" const SORT_BY_CREATED_TIME = "Created time" const SORT_BY_MODIFIED_TIME = "Modified time" -const IMAGE_TAB = "image" +const IMAGE_TAB = "input" const OUTPUT_TAB = "output" const SortByMap = { @@ -158,7 +158,9 @@ export default function FileManager(props: Props) { const newPhotos = filteredFilenames.map((filename: Filename) => { const width = photoWidth const height = filename.height * (width / filename.width) - const src = `${API_ENDPOINT}/media_thumbnail/${tab}/${filename.name}?width=${width}&height=${height}` + const src = `${API_ENDPOINT}/media_thumbnail_file?tab=${tab}&filename=${encodeURIComponent( + filename.name + )}&width=${Math.ceil(width)}&height=${Math.ceil(height)}` return { src, height, width, name: filename.name } }) setPhotos(newPhotos) diff --git a/web_app/src/components/Header.tsx b/web_app/src/components/Header.tsx index 1b4ae5a..05acce6 100644 --- a/web_app/src/components/Header.tsx +++ b/web_app/src/components/Header.tsx @@ -71,8 +71,16 @@ const Header = () => { { - const newFile = await getMediaFile(tab, filename) - setFile(newFile) + try { + const newFile = await getMediaFile(tab, filename) + setFile(newFile) + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + return + } }} /> ) : ( diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 8422da0..a450c4f 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -115,19 +115,15 @@ export function SettingsDialog() { updateAppState({ disableShortCuts: true }) switchModel(model.name) .then((res) => { - if (res.ok) { - toast({ - title: `Switch to ${model.name} success`, - }) - setAppModel(model) - } else { - throw new Error("Server error") - } + toast({ + title: `Switch to ${model.name} success`, + }) + setAppModel(model) }) - .catch(() => { + .catch((error: any) => { toast({ variant: "destructive", - title: `Switch to ${model.name} failed`, + title: `Switch to ${model.name} failed: ${error}`, }) setModel(settings.model) }) @@ -168,17 +164,21 @@ export function SettingsDialog() { .filter((info) => model_types.includes(info.model_type)) .map((info: ModelInfo) => { return ( -
onModelSelect(info)}> +
onModelSelect(info)} + className="px-2" + >
{info.name}
- +
) }) diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 09e41e0..4dc5b4e 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -1,7 +1,7 @@ import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types" import { Settings } from "@/lib/states" import { convertToBase64, srcToFile } from "@/lib/utils" -import axios, { AxiosError } from "axios" +import axios from "axios" export const API_ENDPOINT = import.meta.env.VITE_BACKEND ? import.meta.env.VITE_BACKEND @@ -93,12 +93,7 @@ export function getServerConfig() { } export function switchModel(name: string) { - const fd = new FormData() - fd.append("name", name) - return fetch(`${API_ENDPOINT}/model`, { - method: "POST", - body: fd, - }) + return axios.post(`${API_ENDPOINT}/model`, { name }) } export function currentModel() { @@ -151,14 +146,18 @@ export async function runPlugin( export async function getMediaFile(tab: string, filename: string) { const res = await fetch( - `${API_ENDPOINT}/media/${tab}/${encodeURIComponent(filename)}`, + `${API_ENDPOINT}/media_file?tab=${tab}&filename=${encodeURIComponent( + filename + )}`, { method: "GET", } ) if (res.ok) { const blob = await res.blob() - const file = new File([blob], filename) + const file = new File([blob], filename, { + type: res.headers.get("Content-Type") ?? "image/png", + }) return file } const errMsg = await res.text() @@ -166,15 +165,8 @@ export async function getMediaFile(tab: string, filename: string) { } export async function getMedias(tab: string): Promise { - const res = await fetch(`${API_ENDPOINT}/medias/${tab}`, { - method: "GET", - }) - if (res.ok) { - const filenames = await res.json() - return filenames - } - const errMsg = await res.text() - throw new Error(errMsg) + const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } }) + return res.data } export async function downloadToOutput( @@ -192,7 +184,6 @@ export async function downloadToOutput( method: "POST", body: fd, }) - console.log(res.ok) if (!res.ok) { const errMsg = await res.text() throw new Error(errMsg)