70af4845af
new file: inpaint/__main__.py new file: inpaint/api.py new file: inpaint/batch_processing.py new file: inpaint/benchmark.py new file: inpaint/cli.py new file: inpaint/const.py new file: inpaint/download.py new file: inpaint/file_manager/__init__.py new file: inpaint/file_manager/file_manager.py new file: inpaint/file_manager/storage_backends.py new file: inpaint/file_manager/utils.py new file: inpaint/helper.py new file: inpaint/installer.py new file: inpaint/model/__init__.py new file: inpaint/model/anytext/__init__.py new file: inpaint/model/anytext/anytext_model.py new file: inpaint/model/anytext/anytext_pipeline.py new file: inpaint/model/anytext/anytext_sd15.yaml new file: inpaint/model/anytext/cldm/__init__.py new file: inpaint/model/anytext/cldm/cldm.py new file: inpaint/model/anytext/cldm/ddim_hacked.py new file: inpaint/model/anytext/cldm/embedding_manager.py new file: inpaint/model/anytext/cldm/hack.py new file: inpaint/model/anytext/cldm/model.py new file: inpaint/model/anytext/cldm/recognizer.py new file: inpaint/model/anytext/ldm/__init__.py new file: inpaint/model/anytext/ldm/models/__init__.py new file: inpaint/model/anytext/ldm/models/autoencoder.py new file: inpaint/model/anytext/ldm/models/diffusion/__init__.py new file: inpaint/model/anytext/ldm/models/diffusion/ddim.py new file: inpaint/model/anytext/ldm/models/diffusion/ddpm.py new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py new file: inpaint/model/anytext/ldm/models/diffusion/plms.py new file: inpaint/model/anytext/ldm/models/diffusion/sampling_util.py new file: inpaint/model/anytext/ldm/modules/__init__.py new file: inpaint/model/anytext/ldm/modules/attention.py new file: inpaint/model/anytext/ldm/modules/diffusionmodules/__init__.py new file: inpaint/model/anytext/ldm/modules/diffusionmodules/model.py new file: 
inpaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py new file: inpaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py new file: inpaint/model/anytext/ldm/modules/diffusionmodules/util.py new file: inpaint/model/anytext/ldm/modules/distributions/__init__.py new file: inpaint/model/anytext/ldm/modules/distributions/distributions.py new file: inpaint/model/anytext/ldm/modules/ema.py new file: inpaint/model/anytext/ldm/modules/encoders/__init__.py new file: inpaint/model/anytext/ldm/modules/encoders/modules.py new file: inpaint/model/anytext/ldm/util.py new file: inpaint/model/anytext/main.py new file: inpaint/model/anytext/ocr_recog/RNN.py new file: inpaint/model/anytext/ocr_recog/RecCTCHead.py new file: inpaint/model/anytext/ocr_recog/RecModel.py new file: inpaint/model/anytext/ocr_recog/RecMv1_enhance.py new file: inpaint/model/anytext/ocr_recog/RecSVTR.py new file: inpaint/model/anytext/ocr_recog/__init__.py new file: inpaint/model/anytext/ocr_recog/common.py new file: inpaint/model/anytext/ocr_recog/en_dict.txt new file: inpaint/model/anytext/ocr_recog/ppocr_keys_v1.txt new file: inpaint/model/anytext/utils.py new file: inpaint/model/base.py new file: inpaint/model/brushnet/__init__.py new file: inpaint/model/brushnet/brushnet.py new file: inpaint/model/brushnet/brushnet_unet_forward.py new file: inpaint/model/brushnet/brushnet_wrapper.py new file: inpaint/model/brushnet/pipeline_brushnet.py new file: inpaint/model/brushnet/unet_2d_blocks.py new file: inpaint/model/controlnet.py new file: inpaint/model/ddim_sampler.py new file: inpaint/model/fcf.py new file: inpaint/model/helper/__init__.py new file: inpaint/model/helper/controlnet_preprocess.py new file: inpaint/model/helper/cpu_text_encoder.py new file: inpaint/model/helper/g_diffuser_bot.py new file: inpaint/model/instruct_pix2pix.py new file: inpaint/model/kandinsky.py new file: inpaint/model/lama.py new file: inpaint/model/ldm.py new file: inpaint/model/manga.py new file: 
inpaint/model/mat.py new file: inpaint/model/mi_gan.py new file: inpaint/model/opencv2.py new file: inpaint/model/original_sd_configs/__init__.py new file: inpaint/model/original_sd_configs/sd_xl_base.yaml new file: inpaint/model/original_sd_configs/sd_xl_refiner.yaml new file: inpaint/model/original_sd_configs/v1-inference.yaml new file: inpaint/model/original_sd_configs/v2-inference-v.yaml new file: inpaint/model/paint_by_example.py new file: inpaint/model/plms_sampler.py new file: inpaint/model/power_paint/__init__.py new file: inpaint/model/power_paint/pipeline_powerpaint.py new file: inpaint/model/power_paint/power_paint.py new file: inpaint/model/power_paint/power_paint_v2.py new file: inpaint/model/power_paint/powerpaint_tokenizer.py
399 lines
15 KiB
Python
399 lines
15 KiB
Python
import asyncio
|
|
import os
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, List
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import socketio
|
|
import torch
|
|
|
|
def _disable_torch_jit_fusers() -> None:
    """Turn off the TorchScript JIT fuser toggles, best effort.

    Each private ``torch._C`` hook is applied independently: a hook that is
    missing from, or rejected by, the installed torch build is skipped
    without preventing the remaining ones from being applied (a single shared
    ``try`` used to abandon every toggle after the first failure).
    """
    for hook_name in (
        "_jit_override_can_fuse_on_cpu",
        "_jit_override_can_fuse_on_gpu",
        "_jit_set_texpr_fuser_enabled",
        "_jit_set_nvfuser_enabled",
    ):
        toggle = getattr(torch._C, hook_name, None)
        if toggle is None:
            # Not present in this torch build; nothing to disable.
            continue
        try:
            toggle(False)
        except Exception:
            # Deliberately best effort: some builds reject this toggle.
            # (loguru is not imported yet at this point of the module.)
            pass


_disable_torch_jit_fusers()
|
|
|
|
import uvicorn
|
|
from PIL import Image
|
|
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.exceptions import HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, FileResponse, Response
|
|
from fastapi.staticfiles import StaticFiles
|
|
from loguru import logger
|
|
from socketio import AsyncServer
|
|
|
|
from inpaint.file_manager import FileManager
|
|
from inpaint.helper import (
|
|
load_img,
|
|
decode_base64_to_image,
|
|
pil_to_bytes,
|
|
numpy_to_bytes,
|
|
concat_alpha_channel,
|
|
gen_frontend_mask,
|
|
adjust_mask,
|
|
)
|
|
from inpaint.model.utils import torch_gc
|
|
from inpaint.model_manager import ModelManager
|
|
from inpaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
|
|
from inpaint.plugins.base_plugin import BasePlugin
|
|
from inpaint.plugins.remove_bg import RemoveBG
|
|
from inpaint.schema import (
|
|
GenInfoResponse,
|
|
ApiConfig,
|
|
ServerConfigResponse,
|
|
SwitchModelRequest,
|
|
InpaintRequest,
|
|
RunPluginRequest,
|
|
SDSampler,
|
|
PluginInfo,
|
|
AdjustMaskRequest,
|
|
RemoveBGModel,
|
|
SwitchPluginModelRequest,
|
|
ModelInfo,
|
|
InteractiveSegModel,
|
|
RealESRGANModel,
|
|
)
|
|
|
|
# Absolute, symlink-resolved directory containing this module; resolve()
# already yields an absolute path, so no separate absolute() step is needed.
CURRENT_DIR = Path(__file__).parent.resolve()
# Built frontend assets served at "/" by StaticFiles.
WEB_APP_DIR = CURRENT_DIR.joinpath("web_app")
|
|
|
|
|
|
def api_middleware(app: FastAPI):
    """Install JSON error handling and CORS on ``app``.

    Any exception that escapes a route (via an HTTP middleware) or reaches
    FastAPI's exception handlers is turned into a ``JSONResponse`` whose body
    is ``{"error", "detail", "body", "errors"}`` and whose status code is the
    exception's ``status_code`` attribute, or 500 when it has none.
    Exceptions other than ``HTTPException`` are also logged: the request
    method/URL plus the error summary is printed, followed by a traceback.
    When the ``WEBUI_RICH_EXCEPTIONS`` environment variable is set and
    ``rich`` can be imported, the traceback is rendered with ``rich``;
    otherwise ``traceback.print_exc()`` is used.

    Args:
        app: The FastAPI application to configure; it is modified in place.
    """
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # rich and friends are optional; fall back to plain tracebacks.
        pass

    def handle_exception(request: Request, e: Exception):
        """Log ``e`` (unless it is an HTTPException) and wrap it as JSON."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        if not isinstance(
            e, HTTPException
        ):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            # Print the request context on both paths; it used to be emitted
            # only when rich was available, leaving bare tracebacks otherwise.
            print(message)
            if rich_available:
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min(console.width, 200),
                )
            else:
                traceback.print_exc()
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive; kept as-is because the bundled web UI relies on it.
    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
        # Lets browser clients read the seed header set by api_inpaint.
        "expose_headers": ["X-Seed"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)
|
|
|
|
|
|
# Module-wide handle to the Socket.IO server. It is None until Api.__init__
# assigns its AsyncServer here, so that the module-level diffuser_callback
# (which has no access to the Api instance) can emit progress events.
global_sio: Optional[AsyncServer] = None
|
|
|
|
|
|
def diffuser_callback(
    pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None
):
    """Broadcast a ``diffusion_progress`` Socket.IO event for a finished step.

    Passed as ``callback`` to ``ModelManager`` by ``Api._build_model_manager``.
    The signature mirrors a diffusers step-end callback
    ``(pipeline, step, timestep, callback_kwargs)``.

    Args:
        pipe: The calling pipeline (unused).
        step: Index of the step that just finished; sent as ``{"step": step}``.
        timestep: Scheduler timestep (unused).
        callback_kwargs: Values handed over by the pipeline (unused). Defaults
            to ``None`` rather than a shared mutable ``{}``.

    Returns:
        An empty dict, i.e. no values are overridden (presumably the pipeline
        keeps its own tensors for missing keys — confirm against diffusers).
    """
    # logger.info(f"diffusion callback: step={step}, timestep={timestep}")

    # We use an asyncio loop for task processing. Perhaps in the future we can
    # add a processing queue similar to InvokeAI, but for now we just start a
    # separate event loop. It shouldn't make a difference for single-person use.
    sio = global_sio
    if sio is None:
        # Api.__init__ has not published the Socket.IO server yet; there is
        # nobody to notify, so don't abort the generation over it.
        return {}
    asyncio.run(sio.emit("diffusion_progress", {"step": step}))
    return {}
|
|
|
|
|
|
class Api:
    """FastAPI + Socket.IO front end of the inpainting server.

    On construction it installs error handling/CORS (``api_middleware``),
    builds the optional file manager, the plugins and the model manager,
    registers the ``/api/v1/*`` routes, mounts the web UI from
    ``WEB_APP_DIR`` at ``/`` and wraps the app in a Socket.IO ASGI app that
    ``launch`` serves with uvicorn.
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        # NOTE(review): created but never acquired anywhere in this class.
        self.queue_lock = threading.Lock()
        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
                           response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        # Publish the server so the module-level diffuser_callback can emit.
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` at ``path`` on the FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Write an uploaded image into ``config.output_dir``.

        Only the final path component of the client-supplied filename is
        used, so a name such as ``../../x.png`` cannot escape the output
        directory.

        Raises:
            HTTPException: 400 when no output directory is configured or the
                filename has no usable final component.
        """
        output_dir = self.config.output_dir
        if output_dir is None:
            raise HTTPException(
                status_code=400, detail="Output directory is not configured"
            )
        # The filename is untrusted client input: drop any directory parts.
        filename = Path(file.filename or "").name
        if filename in ("", ".", ".."):
            raise HTTPException(
                status_code=400, detail=f"Invalid filename: {file.filename!r}"
            )
        origin_image_bytes = file.file.read()
        with open(output_dir / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return info about the currently loaded model."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Switch to ``req.name`` (no-op if already active) and return its info."""
        if req.name == self.model_manager.name:
            return self.model_manager.current_model
        self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model of a loaded plugin and remember it in the config.

        Requests naming a plugin that is not loaded are ignored.
        """
        if req.plugin_name in self.plugins:
            self.plugins[req.plugin_name].switch_model(req.model_name)
            # Keep the config in sync so api_server_config reports the choice.
            if req.plugin_name == RemoveBG.name:
                self.config.remove_bg_model = req.model_name
            if req.plugin_name == RealESRGANUpscaler.name:
                self.config.realesrgan_model = req.model_name
            if req.plugin_name == InteractiveSeg.name:
                self.config.interactive_seg_model = req.model_name
            torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe the server: plugins, available models and feature flags."""
        plugins = [
            PluginInfo(
                name=it.name,
                support_gen_image=it.support_gen_image,
                support_gen_mask=it.support_gen_mask,
            )
            for it in self.plugins.values()
        ]

        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve ``config.input`` when it is a file.

        Raises:
            HTTPException: 404 when no input file is configured.
        """
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompt / negative prompt from an image's ``parameters`` info.

        Text before ``"Negative prompt: "`` is the prompt; the first line after
        it is the negative prompt. Missing info yields empty strings.
        """
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Run the current model on a base64 image + mask and return a PNG.

        The mask is binarized (values > 127 become 255, the rest 0). The
        model output is converted from BGR to RGB, the input's alpha channel
        is re-attached, and the result is PNG-encoded with the input's
        metadata. A ``diffusion_finish`` Socket.IO event is emitted, and the
        request seed is echoed in the ``X-Seed`` header.

        Raises:
            HTTPException: 400 when image and mask sizes differ.
        """
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )

        if req.paint_by_example_example_image:
            # Decoded only to reject a malformed example image early; the
            # decoded value itself is not used in this method.
            decode_base64_to_image(req.paint_by_example_example_image)

        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run an image-producing plugin and return the result as PNG.

        A 4-channel plugin result is returned as-is (assumed RGBA); otherwise
        it is converted from BGR to RGB and the input's alpha is re-attached.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output images.
        """
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run a mask-producing plugin and return the frontend mask as PNG.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot output masks.
        """
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """Return the values of all ``SDSampler`` members."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Apply ``adjust_mask(mask, kernel_size, operate)`` and return a PNG."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Serve the combined Socket.IO + FastAPI app with uvicorn."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a FileManager when ``config.input`` is a directory, else None."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )

            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                mask_dir=self.config.mask_dir,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Build the enabled plugins (keyed by name) from the config flags."""
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Create the ModelManager, wiring ``diffuser_callback`` for progress."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )
|