import asyncio
import os
import threading
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, List

import cv2
import numpy as np
import socketio
import torch
import uvicorn
from PIL import Image
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.staticfiles import StaticFiles
from loguru import logger
from socketio import AsyncServer

from iopaint.file_manager import FileManager
from iopaint.helper import (
    load_img,
    decode_base64_to_image,
    pil_to_bytes,
    numpy_to_bytes,
    concat_alpha_channel,
    gen_frontend_mask,
    adjust_mask,
)
from iopaint.model.utils import torch_gc
from iopaint.model_info import ModelInfo
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import (
    GenInfoResponse,
    ApiConfig,
    ServerConfigResponse,
    SwitchModelRequest,
    InpaintRequest,
    RunPluginRequest,
    SDSampler,
    PluginInfo,
    AdjustMaskRequest,
)

CURRENT_DIR = Path(__file__).parent.absolute().resolve()
WEB_APP_DIR = CURRENT_DIR / "web_app"

def api_middleware(app: FastAPI):
    # Rich tracebacks are opt-in: set the WEBUI_RICH_EXCEPTIONS environment
    # variable (to any value) before starting the server; `rich` must be installed.
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        pass

    def handle_exception(request: Request, e: Exception):
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        if not isinstance(
            e, HTTPException
        ):  # do not print a backtrace for expected HTTPExceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min([console.width, 200]),
                )
            else:
                traceback.print_exc()
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)


# Assigned in Api.__init__ once the Socket.IO server exists; diffuser_callback
# uses it to push progress events to connected frontends.
global_sio: Optional[AsyncServer] = None


def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
    # Signature follows the diffusers callback convention:
    # (pipe: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)
    # logger.info(f"diffusion callback: step={step}, timestep={timestep}")

    # We use an asyncio loop for task processing. Perhaps in the future we can add a
    # processing queue similar to InvokeAI, but for now we just start a separate
    # event loop per callback. It shouldn't make a difference for single-user use.
    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
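
# A minimal client-side sketch for the events emitted here (an assumption-laden
# example: server on 127.0.0.1:8080 and the python-socketio client installed;
# the Socket.IO app is mounted under /ws in Api.__init__, hence the socketio_path):
#
#   import socketio
#   sio = socketio.Client()
#   sio.on("diffusion_progress", lambda data: print("step:", data["step"]))
#   sio.on("diffusion_finish", lambda: print("done"))
#   sio.connect("http://127.0.0.1:8080", socketio_path="/ws/socket.io")
#   sio.wait()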


class Api:
    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        self.queue_lock = threading.Lock()
        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        # Wrap the FastAPI app in a Socket.IO ASGI app; diffuser_callback reaches
        # the server through the global_sio module global set here.
        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_models(self) -> List[ModelInfo]:
        return self.model_manager.scan_models()

    def api_current_model(self) -> ModelInfo:
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        if req.name == self.model_manager.name:
            return self.model_manager.current_model
        self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_server_config(self) -> ServerConfigResponse:
        return ServerConfigResponse(
            plugins=[
                PluginInfo(
                    name=it.name,
                    support_gen_image=it.support_gen_image,
                    support_gen_mask=it.support_gen_mask,
                )
                for it in self.plugins.values()
            ],
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=self.config.disable_model_switch,
            isDesktop=self.config.gui,
            samplers=self.api_samplers(),
        )
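
    # Quick smoke test from the command line (assuming the default host/port used
    # in the __main__ block below):
    #
    #   curl http://127.0.0.1:8080/api/v1/server-config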

    def api_input_image(self) -> FileResponse:
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
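
    # The "parameters" metadata parsed above follows the A1111 convention, e.g.
    # (illustrative values):
    #
    #   a photo of a cat, masterpiece
    #   Negative prompt: blurry, low quality
    #   Steps: 30, Sampler: Euler a, Seed: 42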

    def api_inpaint(self, req: InpaintRequest):
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) do not match.",
            )

        if req.paint_by_example_example_image:
            paint_by_example_image, _, _ = decode_base64_to_image(
                req.paint_by_example_example_image
            )

        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )
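
    # A minimal client sketch for this endpoint (hypothetical file names; the
    # base64 payloads follow InpaintRequest, which accepts many more optional
    # fields -- see iopaint.schema):
    #
    #   import base64, requests
    #   img = base64.b64encode(open("image.png", "rb").read()).decode()
    #   mask = base64.b64encode(open("mask.png", "rb").read()).decode()
    #   r = requests.post("http://127.0.0.1:8080/api/v1/inpaint",
    #                     json={"image": img, "mask": mask})
    #   open("result.png", "wb").write(r.content)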

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support generating images"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support generating masks"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            # Effectively never drop idle keep-alive connections.
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is a directory, initializing file manager for {self.config.input}"
            )
            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )


if __name__ == "__main__":
    from iopaint.schema import InteractiveSegModel, RealESRGANModel

    app = FastAPI()
    api = Api(
        app,
        ApiConfig(
            host="127.0.0.1",
            port=8080,
            model="lama",
            no_half=False,
            cpu_offload=False,
            disable_nsfw_checker=False,
            cpu_textencoder=False,
            device="cpu",
            gui=False,
            disable_model_switch=False,
            input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
            output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
            quality=100,
            enable_interactive_seg=False,
            interactive_seg_model=InteractiveSegModel.vit_b,
            interactive_seg_device="cpu",
            enable_remove_bg=False,
            enable_anime_seg=False,
            enable_realesrgan=False,
            realesrgan_device="cpu",
            realesrgan_model=RealESRGANModel.realesr_general_x4v3,
            enable_gfpgan=False,
            gfpgan_device="cpu",
            enable_restoreformer=False,
            restoreformer_device="cpu",
        ),
    )
    api.launch()