IOPaint/iopaint/api.py

397 lines
15 KiB
Python
Raw Normal View History

2024-01-03 02:03:04 +01:00
import asyncio
2023-12-30 16:36:44 +01:00
import os
import threading
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, List
import cv2
2024-01-03 02:03:04 +01:00
import numpy as np
2024-01-02 10:13:11 +01:00
import socketio
2023-12-30 16:36:44 +01:00
import torch
# Turn off TorchScript kernel fusion (CPU/GPU fusers, the TensorExpr fuser and
# nvFuser) by passing False to each private ``torch._C`` hook.
# These hooks are private torch APIs and are not guaranteed to exist on every
# torch version, so this is strictly best-effort.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    # Best-effort only. Narrowed from a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit are no longer swallowed at import time.
    pass
import uvicorn
2024-01-03 02:03:04 +01:00
from PIL import Image
2023-12-30 16:36:44 +01:00
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.staticfiles import StaticFiles
2024-01-03 02:03:04 +01:00
from loguru import logger
from socketio import AsyncServer
2023-12-30 16:36:44 +01:00
2024-01-05 08:19:23 +01:00
from iopaint.file_manager import FileManager
from iopaint.helper import (
2023-12-30 16:36:44 +01:00
load_img,
decode_base64_to_image,
pil_to_bytes,
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
2024-01-05 07:57:30 +01:00
adjust_mask,
2023-12-30 16:36:44 +01:00
)
2024-01-05 08:19:23 +01:00
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
2024-02-10 05:34:56 +01:00
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
2024-01-05 08:19:23 +01:00
from iopaint.plugins.base_plugin import BasePlugin
2024-02-08 09:49:54 +01:00
from iopaint.plugins.remove_bg import RemoveBG
2024-01-05 08:19:23 +01:00
from iopaint.schema import (
2023-12-30 16:36:44 +01:00
GenInfoResponse,
ApiConfig,
ServerConfigResponse,
SwitchModelRequest,
InpaintRequest,
RunPluginRequest,
2024-01-02 07:34:36 +01:00
SDSampler,
PluginInfo,
2024-01-05 07:57:30 +01:00
AdjustMaskRequest,
2024-02-08 09:49:54 +01:00
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
2024-02-10 05:34:56 +01:00
InteractiveSegModel,
2024-02-08 10:16:57 +01:00
RealESRGANModel,
2023-12-30 16:36:44 +01:00
)
# Absolute, symlink-resolved directory that contains this module.
CURRENT_DIR: Path = Path(__file__).parent.absolute().resolve()
# Bundled frontend; mounted as static files at "/" by ``Api``.
WEB_APP_DIR: Path = CURRENT_DIR.joinpath("web_app")
def api_middleware(app: FastAPI):
    """Install error handling and CORS middleware on ``app`` in place.

    Every exception that escapes a route (``HTTPException`` included) is
    converted into a JSON body ``{"error", "detail", "body", "errors"}``.
    The status code comes from the exception's ``status_code`` attribute,
    defaulting to 500. Non-HTTP exceptions are additionally printed to the
    console together with the request method and URL.

    When the ``WEBUI_RICH_EXCEPTIONS`` environment variable is set and
    ``rich`` can be imported, a rich traceback is printed; otherwise a plain
    one is.

    Args:
        app: The FastAPI application to configure.
    """
    rich_available = False
    try:
        if "WEBUI_RICH_EXCEPTIONS" in os.environ:
            import anyio  # imported only so it can be placed on the silent list
            import starlette  # imported only so it can be placed on the silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # Rich tracebacks are optional; fall back to plain tracebacks.
        pass

    def handle_exception(request: Request, e: Exception):
        """Build the JSON error response for ``e`` and log unexpected errors."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        # Known HTTP errors are expected control flow: no backtrace for them.
        if not isinstance(e, HTTPException):
            message = f"API error: {request.method}: {request.url} {err}"
            # Print the request context in both branches so plain logs are as
            # informative as rich ones.
            print(message)
            if rich_available:
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min(console.width, 200),
                )
            else:
                # Print the exception we were handed instead of relying on
                # whatever sys.exc_info() happens to hold at this point.
                traceback.print_exception(type(e), e, e.__traceback__)
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    # Fully permissive CORS (any origin, method and header, with credentials).
    # NOTE(review): fine for a local tool; reconsider if the server is exposed
    # beyond localhost.
    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)
# Socket.IO server used by ``diffuser_callback`` to push progress events.
# Assigned in ``Api.__init__``; it is ``None`` until an ``Api`` has been created.
global_sio: Optional[AsyncServer] = None
def diffuser_callback(
    pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None
):
    """Forward diffusion progress to connected Socket.IO clients.

    Handed to ``ModelManager`` as its ``callback``. The signature follows
    ``(pipeline, step, timestep, callback_kwargs)``. Each call emits a
    ``"diffusion_progress"`` event whose payload is ``{"step": step}``.

    Args:
        pipe: The pipeline invoking the callback (unused).
        step: Index of the step that just finished.
        timestep: Scheduler timestep of that step (unused).
        callback_kwargs: Extra values supplied by the pipeline (unused).
            Defaults to ``None`` rather than a shared mutable ``{}``.

    Returns:
        An empty dict; nothing is handed back to the pipeline.
    """
    if global_sio is None:
        # No Api has been constructed yet, so there is nobody to notify.
        return {}
    # We use asyncio loops for task processing. Perhaps in the future we can
    # add a processing queue similar to InvokeAI, but for now we just start a
    # separate, short-lived event loop for each emit. This callback is invoked
    # from synchronous code (presumably a worker thread without a running
    # loop), and for single-user use this makes no practical difference.
    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
class Api:
    """HTTP + Socket.IO front end for IOPaint.

    Wires the inpainting ``ModelManager``, the configured plugins and the
    optional ``FileManager`` into REST endpoints under ``/api/v1``. It also
    serves the bundled web app from ``WEB_APP_DIR`` at ``/`` and creates the
    Socket.IO server used to push diffusion progress events.
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        # Serialises use of the inpainting model. Sync endpoints may run
        # concurrently, so a model switch must not overlap an inference that
        # is still using the current model.
        self.queue_lock = threading.Lock()
        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        # Expose the server to the module-level diffusion progress callback.
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` at ``path`` directly on the FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Write an uploaded image into ``config.output_dir``.

        Only the final component of the client-supplied filename is used, so
        names like ``../../x.png`` or ``/etc/x`` cannot escape the output
        directory.

        Raises:
            HTTPException: 422 if no output directory is configured; 400 if
                the filename has no usable name component.
        """
        output_dir = self.config.output_dir
        if output_dir is None:
            raise HTTPException(
                status_code=422, detail="Output directory is not configured"
            )
        # The filename is untrusted client input: strip any directory parts
        # to prevent path traversal outside of output_dir.
        filename = Path(file.filename or "").name
        if filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")
        origin_image_bytes = file.file.read()
        with open(Path(output_dir) / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return info about the currently loaded inpainting model."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Switch to model ``req.name`` unless it is already active.

        Returns:
            Info about the model that is active afterwards.
        """
        with self.queue_lock:
            if req.name != self.model_manager.name:
                self.model_manager.switch(req.name)
            return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model used by plugin ``req.plugin_name``.

        Requests naming a plugin that is not enabled are ignored. For the
        RemoveBG, RealESRGAN and InteractiveSeg plugins, the choice is also
        stored in ``config`` so that ``/server-config`` reports it.
        """
        if req.plugin_name not in self.plugins:
            return
        self.plugins[req.plugin_name].switch_model(req.model_name)
        if req.plugin_name == RemoveBG.name:
            self.config.remove_bg_model = req.model_name
        if req.plugin_name == RealESRGANUpscaler.name:
            self.config.realesrgan_model = req.model_name
        if req.plugin_name == InteractiveSeg.name:
            self.config.interactive_seg_model = req.model_name
        torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe the server's capabilities for the frontend.

        This covers enabled plugins, available and selected plugin models,
        scanned inpainting models, feature flags and available samplers.
        """
        plugins = [
            PluginInfo(
                name=it.name,
                support_gen_image=it.support_gen_image,
                support_gen_mask=it.support_gen_mask,
            )
            for it in self.plugins.values()
        ]
        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve ``config.input`` when it is a single file.

        Raises:
            HTTPException: 404 if no input file is configured.
        """
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompts from the uploaded image's ``parameters`` metadata.

        The metadata text is split on ``"Negative prompt: "``. The part before
        it is the prompt; the first line after it is the negative prompt.
        """
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Inpaint ``req.image`` within ``req.mask`` and return a PNG.

        Processing steps:

        - The mask is binarised: values above 127 become 255, the rest 0.
        - The image's alpha channel, if any, is re-attached to the result.
        - The response carries the request seed in the ``X-Seed`` header.
        - A ``"diffusion_finish"`` Socket.IO event is emitted once done.

        Raises:
            HTTPException: 400 if image and mask sizes differ.
        """
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )

        # The paint-by-example image travels inside ``req`` to the model
        # manager; it used to be decoded here into an unused local as well.
        with self.queue_lock:
            start = time.time()
            try:
                rgb_np_img = self.model_manager(image, mask, req)
                logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
            finally:
                # Free cached GPU memory even when inference fails.
                torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run image-producing plugin ``req.name`` on ``req.image``; return PNG.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output images.
        """
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            # The plugin produced its own 4th (alpha) channel: use it as-is.
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            # Restore the input image's alpha channel, if it had one.
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run mask-producing plugin ``req.name`` on ``req.image``; return PNG mask.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output masks.
        """
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """Return the value of every ``SDSampler`` member."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Apply ``adjust_mask`` (``req.operate`` with ``req.kernel_size``); return PNG."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Serve the combined Socket.IO + FastAPI app with uvicorn (blocking)."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a ``FileManager`` only when ``config.input`` is a directory."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )
            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Build the enabled plugins, keyed by plugin name, from ``config``."""
        # build_plugins takes its arguments positionally; keep this order.
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Create the inpainting ``ModelManager`` from ``config``."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            # Module-level callback that forwards diffusion progress over Socket.IO.
            callback=diffuser_callback,
        )