wip
This commit is contained in:
parent
85c3397b97
commit
c4abda3942
323
lama_cleaner/api.py
Normal file
323
lama_cleaner/api.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, FileResponse, Response
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from lama_cleaner.helper import (
|
||||||
|
load_img,
|
||||||
|
decode_base64_to_image,
|
||||||
|
pil_to_bytes,
|
||||||
|
numpy_to_bytes,
|
||||||
|
concat_alpha_channel,
|
||||||
|
)
|
||||||
|
from lama_cleaner.model.utils import torch_gc
|
||||||
|
from lama_cleaner.model_info import ModelInfo
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
|
||||||
|
from lama_cleaner.schema import (
|
||||||
|
GenInfoResponse,
|
||||||
|
ApiConfig,
|
||||||
|
ServerConfigResponse,
|
||||||
|
SwitchModelRequest,
|
||||||
|
InpaintRequest,
|
||||||
|
RunPluginRequest,
|
||||||
|
)
|
||||||
|
from lama_cleaner.file_manager import FileManager
|
||||||
|
|
||||||
|
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||||
|
WEB_APP_DIR = CURRENT_DIR / "web_app"
|
||||||
|
|
||||||
|
|
||||||
|
def api_middleware(app: FastAPI):
|
||||||
|
rich_available = False
|
||||||
|
try:
|
||||||
|
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
|
||||||
|
import anyio # importing just so it can be placed on silent list
|
||||||
|
import starlette # importing just so it can be placed on silent list
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
rich_available = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_exception(request: Request, e: Exception):
|
||||||
|
err = {
|
||||||
|
"error": type(e).__name__,
|
||||||
|
"detail": vars(e).get("detail", ""),
|
||||||
|
"body": vars(e).get("body", ""),
|
||||||
|
"errors": str(e),
|
||||||
|
}
|
||||||
|
if not isinstance(
|
||||||
|
e, HTTPException
|
||||||
|
): # do not print backtrace on known httpexceptions
|
||||||
|
message = f"API error: {request.method}: {request.url} {err}"
|
||||||
|
if rich_available:
|
||||||
|
print(message)
|
||||||
|
console.print_exception(
|
||||||
|
show_locals=True,
|
||||||
|
max_frames=2,
|
||||||
|
extra_lines=1,
|
||||||
|
suppress=[anyio, starlette],
|
||||||
|
word_wrap=False,
|
||||||
|
width=min([console.width, 200]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
traceback.print_exc()
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def exception_handling(request: Request, call_next):
|
||||||
|
try:
|
||||||
|
return await call_next(request)
|
||||||
|
except Exception as e:
|
||||||
|
return handle_exception(request, e)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def fastapi_exception_handler(request: Request, e: Exception):
|
||||||
|
return handle_exception(request, e)
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request: Request, e: HTTPException):
|
||||||
|
return handle_exception(request, e)
|
||||||
|
|
||||||
|
cors_options = {
|
||||||
|
"allow_methods": ["*"],
|
||||||
|
"allow_headers": ["*"],
|
||||||
|
"allow_origins": ["*"],
|
||||||
|
"allow_credentials": True,
|
||||||
|
}
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_options)
|
||||||
|
|
||||||
|
|
||||||
|
class Api:
|
||||||
|
def __init__(self, app: FastAPI, config: ApiConfig):
|
||||||
|
self.app = app
|
||||||
|
self.config = config
|
||||||
|
self.router = APIRouter()
|
||||||
|
self.queue_lock = threading.Lock()
|
||||||
|
api_middleware(self.app)
|
||||||
|
|
||||||
|
self.file_manager = self._build_file_manager()
|
||||||
|
self.plugins = self._build_plugins()
|
||||||
|
self.model_manager = self._build_model_manager()
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||||
|
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
||||||
|
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
|
||||||
|
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||||
|
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||||
|
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||||
|
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||||
|
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
|
||||||
|
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||||
|
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||||
|
|
||||||
|
def api_models(self) -> List[ModelInfo]:
|
||||||
|
return self.model_manager.scan_models()
|
||||||
|
|
||||||
|
def api_current_model(self) -> ModelInfo:
|
||||||
|
return self.model_manager.current_model
|
||||||
|
|
||||||
|
def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
|
||||||
|
if req.name == self.model_manager.name:
|
||||||
|
return self.model_manager.current_model
|
||||||
|
self.model_manager.switch(req.name)
|
||||||
|
return self.model_manager.current_model
|
||||||
|
|
||||||
|
def api_server_config(self) -> ServerConfigResponse:
|
||||||
|
return ServerConfigResponse(
|
||||||
|
plugins=list(self.plugins.keys()),
|
||||||
|
enableFileManager=self.file_manager is not None,
|
||||||
|
enableAutoSaving=self.config.output_dir is not None,
|
||||||
|
enableControlnet=self.model_manager.enable_controlnet,
|
||||||
|
controlnetMethod=self.model_manager.controlnet_method,
|
||||||
|
disableModelSwitch=self.config.disable_model_switch,
|
||||||
|
isDesktop=self.config.gui,
|
||||||
|
)
|
||||||
|
|
||||||
|
def api_input_image(self) -> FileResponse:
|
||||||
|
if self.config.input and self.config.input.is_file():
|
||||||
|
return FileResponse(self.config.input)
|
||||||
|
raise HTTPException(status_code=404, detail="Input image not found")
|
||||||
|
|
||||||
|
def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
|
||||||
|
_, _, info = load_img(file.file.read(), return_info=True)
|
||||||
|
parts = info.get("parameters", "").split("Negative prompt: ")
|
||||||
|
prompt = parts[0].strip()
|
||||||
|
negative_prompt = ""
|
||||||
|
if len(parts) > 1:
|
||||||
|
negative_prompt = parts[1].split("\n")[0].strip()
|
||||||
|
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
||||||
|
|
||||||
|
def api_inpaint(self, req: InpaintRequest):
|
||||||
|
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||||
|
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||||
|
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||||
|
if image.shape[:2] != mask.shape[:2]:
|
||||||
|
raise HTTPException(
|
||||||
|
400,
|
||||||
|
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if req.paint_by_example_example_image:
|
||||||
|
paint_by_example_image, _, _ = decode_base64_to_image(
|
||||||
|
req.paint_by_example_example_image
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
rgb_np_img = self.model_manager(image, mask, req)
|
||||||
|
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
|
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||||
|
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
||||||
|
|
||||||
|
ext = "png"
|
||||||
|
res_img_bytes = pil_to_bytes(
|
||||||
|
Image.fromarray(rgb_res),
|
||||||
|
ext=ext,
|
||||||
|
quality=self.config.quality,
|
||||||
|
infos=infos,
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
content=res_img_bytes,
|
||||||
|
media_type=f"image/{ext}",
|
||||||
|
headers={"X-Seed": str(req.sd_seed)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def api_run_plugin(self, req: RunPluginRequest):
|
||||||
|
if req.name not in self.plugins:
|
||||||
|
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||||
|
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||||
|
bgr_res = self.plugins[req.name].run(image, req)
|
||||||
|
torch_gc()
|
||||||
|
if req.name == InteractiveSeg.name:
|
||||||
|
return Response(
|
||||||
|
content=numpy_to_bytes(bgr_res, "png"),
|
||||||
|
media_type="image/png",
|
||||||
|
)
|
||||||
|
ext = "png"
|
||||||
|
if req.name in [RemoveBG.name, AnimeSeg.name]:
|
||||||
|
rgb_res = bgr_res
|
||||||
|
else:
|
||||||
|
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
|
||||||
|
rgb_res = concat_alpha_channel(rgb_res, alpha_channel)
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=pil_to_bytes(
|
||||||
|
Image.fromarray(rgb_res),
|
||||||
|
ext=ext,
|
||||||
|
quality=self.config.quality,
|
||||||
|
infos=infos,
|
||||||
|
),
|
||||||
|
media_type=f"image/{ext}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def launch(self):
|
||||||
|
self.app.include_router(self.router)
|
||||||
|
uvicorn.run(
|
||||||
|
self.app,
|
||||||
|
host=self.config.host,
|
||||||
|
port=self.config.port,
|
||||||
|
timeout_keep_alive=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_file_manager(self) -> Optional[FileManager]:
|
||||||
|
if self.config.input and self.config.input.is_dir():
|
||||||
|
logger.info(
|
||||||
|
f"Input is directory, initialize file manager {self.config.input}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return FileManager(
|
||||||
|
app=self.app,
|
||||||
|
input_dir=self.config.input,
|
||||||
|
output_dir=self.config.output_dir,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_plugins(self) -> Dict:
|
||||||
|
return build_plugins(
|
||||||
|
self.config.enable_interactive_seg,
|
||||||
|
self.config.interactive_seg_model,
|
||||||
|
self.config.interactive_seg_device,
|
||||||
|
self.config.enable_remove_bg,
|
||||||
|
self.config.enable_anime_seg,
|
||||||
|
self.config.enable_realesrgan,
|
||||||
|
self.config.realesrgan_device,
|
||||||
|
self.config.realesrgan_model,
|
||||||
|
self.config.enable_gfpgan,
|
||||||
|
self.config.gfpgan_device,
|
||||||
|
self.config.enable_restoreformer,
|
||||||
|
self.config.restoreformer_device,
|
||||||
|
self.config.no_half,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_model_manager(self):
|
||||||
|
return ModelManager(
|
||||||
|
name=self.config.model,
|
||||||
|
device=torch.device(self.config.device),
|
||||||
|
no_half=self.config.no_half,
|
||||||
|
disable_nsfw=self.config.disable_nsfw_checker,
|
||||||
|
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||||
|
cpu_offload=self.config.cpu_offload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from lama_cleaner.schema import InteractiveSegModel, RealESRGANModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
api = Api(
|
||||||
|
app,
|
||||||
|
ApiConfig(
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=8080,
|
||||||
|
model="lama",
|
||||||
|
no_half=False,
|
||||||
|
cpu_offload=False,
|
||||||
|
disable_nsfw_checker=False,
|
||||||
|
cpu_textencoder=False,
|
||||||
|
device="cpu",
|
||||||
|
gui=False,
|
||||||
|
disable_model_switch=False,
|
||||||
|
input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
|
||||||
|
output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
|
||||||
|
quality=100,
|
||||||
|
enable_interactive_seg=False,
|
||||||
|
interactive_seg_model=InteractiveSegModel.vit_b,
|
||||||
|
interactive_seg_device="cpu",
|
||||||
|
enable_remove_bg=False,
|
||||||
|
enable_anime_seg=False,
|
||||||
|
enable_realesrgan=False,
|
||||||
|
realesrgan_device="cpu",
|
||||||
|
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
||||||
|
enable_gfpgan=False,
|
||||||
|
gfpgan_device="cpu",
|
||||||
|
enable_restoreformer=False,
|
||||||
|
restoreformer_device="cpu",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
api.launch()
|
@ -10,7 +10,7 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lama_cleaner.model_manager import ModelManager
|
from lama_cleaner.model_manager import ModelManager
|
||||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
@ -36,7 +36,7 @@ def run_model(model, size):
|
|||||||
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
||||||
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
||||||
|
|
||||||
config = Config(
|
config = InpaintRequest(
|
||||||
ldm_steps=2,
|
ldm_steps=2,
|
||||||
hd_strategy=HDStrategy.ORIGINAL,
|
hd_strategy=HDStrategy.ORIGINAL,
|
||||||
hd_strategy_crop_margin=128,
|
hd_strategy_crop_margin=128,
|
||||||
@ -44,7 +44,7 @@ def run_model(model, size):
|
|||||||
hd_strategy_resize_limit=128,
|
hd_strategy_resize_limit=128,
|
||||||
prompt="a fox is sitting on a bench",
|
prompt="a fox is sitting on a bench",
|
||||||
sd_steps=5,
|
sd_steps=5,
|
||||||
sd_sampler=SDSampler.ddim
|
sd_sampler=SDSampler.ddim,
|
||||||
)
|
)
|
||||||
model(image, mask, config)
|
model(image, mask, config)
|
||||||
|
|
||||||
@ -75,7 +75,9 @@ def benchmark(model, times: int, empty_cache: bool):
|
|||||||
# cpu_metrics.append(process.cpu_percent())
|
# cpu_metrics.append(process.cpu_percent())
|
||||||
time_metrics.append((time.time() - start) * 1000)
|
time_metrics.append((time.time() - start) * 1000)
|
||||||
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
||||||
gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
|
gpu_memory_metrics.append(
|
||||||
|
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
|
||||||
|
)
|
||||||
|
|
||||||
print(f"size: {size}".center(80, "-"))
|
print(f"size: {size}".center(80, "-"))
|
||||||
# print(f"cpu: {format(cpu_metrics)}")
|
# print(f"cpu: {format(cpu_metrics)}")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from fastapi import FastAPI
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typer import Option
|
from typer import Option
|
||||||
|
|
||||||
@ -38,6 +39,17 @@ def list_model(
|
|||||||
print(it.name)
|
print(it.name)
|
||||||
|
|
||||||
|
|
||||||
|
@typer_app.command(help="Processing image with lama cleaner")
|
||||||
|
def run(
|
||||||
|
input: Path = Option(..., help="Image file or folder containing images"),
|
||||||
|
output_dir: Path = Option(..., help="Output directory"),
|
||||||
|
config_path: Path = Option(..., help="Config file path"),
|
||||||
|
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||||
|
):
|
||||||
|
setup_model_dir(model_dir)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@typer_app.command(help="Start lama cleaner server")
|
@typer_app.command(help="Start lama cleaner server")
|
||||||
def start(
|
def start(
|
||||||
host: str = Option("127.0.0.1"),
|
host: str = Option("127.0.0.1"),
|
||||||
@ -80,6 +92,16 @@ def start(
|
|||||||
):
|
):
|
||||||
dump_environment_info()
|
dump_environment_info()
|
||||||
device = check_device(device)
|
device = check_device(device)
|
||||||
|
if input and not input.exists():
|
||||||
|
logger.error(f"invalid --input: {input} not exists")
|
||||||
|
exit()
|
||||||
|
if output_dir:
|
||||||
|
output_dir = output_dir.expanduser().absolute()
|
||||||
|
logger.info(f"Image will be saved to {output_dir}")
|
||||||
|
if not output_dir.exists():
|
||||||
|
logger.info(f"Create output directory {output_dir}")
|
||||||
|
output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
model_dir = model_dir.expanduser().absolute()
|
model_dir = model_dir.expanduser().absolute()
|
||||||
setup_model_dir(model_dir)
|
setup_model_dir(model_dir)
|
||||||
|
|
||||||
@ -92,32 +114,38 @@ def start(
|
|||||||
logger.info(f"{model} not found in {model_dir}, try to downloading")
|
logger.info(f"{model} not found in {model_dir}, try to downloading")
|
||||||
cli_download_model(model, model_dir)
|
cli_download_model(model, model_dir)
|
||||||
|
|
||||||
from lama_cleaner.server import start
|
from lama_cleaner.api import Api
|
||||||
|
from lama_cleaner.schema import ApiConfig
|
||||||
|
|
||||||
start(
|
app = FastAPI()
|
||||||
host=host,
|
api = Api(
|
||||||
port=port,
|
app,
|
||||||
model=model,
|
ApiConfig(
|
||||||
no_half=no_half,
|
host=host,
|
||||||
cpu_offload=cpu_offload,
|
port=port,
|
||||||
disable_nsfw_checker=disable_nsfw_checker,
|
model=model,
|
||||||
cpu_textencoder=cpu_textencoder,
|
no_half=no_half,
|
||||||
device=device,
|
cpu_offload=cpu_offload,
|
||||||
gui=gui,
|
disable_nsfw_checker=disable_nsfw_checker,
|
||||||
disable_model_switch=disable_model_switch,
|
cpu_textencoder=cpu_textencoder,
|
||||||
input=input,
|
device=device,
|
||||||
output_dir=output_dir,
|
gui=gui,
|
||||||
quality=quality,
|
disable_model_switch=disable_model_switch,
|
||||||
enable_interactive_seg=enable_interactive_seg,
|
input=input,
|
||||||
interactive_seg_model=interactive_seg_model,
|
output_dir=output_dir,
|
||||||
interactive_seg_device=interactive_seg_device,
|
quality=quality,
|
||||||
enable_remove_bg=enable_remove_bg,
|
enable_interactive_seg=enable_interactive_seg,
|
||||||
enable_anime_seg=enable_anime_seg,
|
interactive_seg_model=interactive_seg_model,
|
||||||
enable_realesrgan=enable_realesrgan,
|
interactive_seg_device=interactive_seg_device,
|
||||||
realesrgan_device=realesrgan_device,
|
enable_remove_bg=enable_remove_bg,
|
||||||
realesrgan_model=realesrgan_model,
|
enable_anime_seg=enable_anime_seg,
|
||||||
enable_gfpgan=enable_gfpgan,
|
enable_realesrgan=enable_realesrgan,
|
||||||
gfpgan_device=gfpgan_device,
|
realesrgan_device=realesrgan_device,
|
||||||
enable_restoreformer=enable_restoreformer,
|
realesrgan_model=realesrgan_model,
|
||||||
restoreformer_device=restoreformer_device,
|
enable_gfpgan=enable_gfpgan,
|
||||||
|
gfpgan_device=gfpgan_device,
|
||||||
|
enable_restoreformer=enable_restoreformer,
|
||||||
|
restoreformer_device=restoreformer_device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
api.launch()
|
||||||
|
@ -1,18 +1,19 @@
|
|||||||
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import time
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
from typing import List
|
||||||
|
|
||||||
# from watchdog.events import FileSystemEventHandler
|
|
||||||
# from watchdog.observers import Observer
|
|
||||||
|
|
||||||
from PIL import Image, ImageOps, PngImagePlugin
|
from PIL import Image, ImageOps, PngImagePlugin
|
||||||
from loguru import logger
|
from fastapi import FastAPI, UploadFile, HTTPException
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
|
||||||
|
from ..schema import (
|
||||||
|
MediasResponse,
|
||||||
|
MediasRequest,
|
||||||
|
MediaFileRequest,
|
||||||
|
MediaTab,
|
||||||
|
MediaThumbnailFileRequest,
|
||||||
|
)
|
||||||
|
|
||||||
LARGE_ENOUGH_NUMBER = 100
|
LARGE_ENOUGH_NUMBER = 100
|
||||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||||
@ -21,132 +22,87 @@ from .utils import aspect_to_string, generate_filename, glob_img
|
|||||||
|
|
||||||
|
|
||||||
class FileManager:
|
class FileManager:
|
||||||
def __init__(self, app=None):
|
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
|
||||||
self.app = app
|
self.app = app
|
||||||
self._default_root_directory = "media"
|
self.input_dir: Path = input_dir
|
||||||
self._default_thumbnail_directory = "media"
|
self.output_dir: Path = output_dir
|
||||||
self._default_root_url = "/"
|
|
||||||
self._default_thumbnail_root_url = "/"
|
|
||||||
self._default_format = "JPEG"
|
|
||||||
self.output_dir: Path = None
|
|
||||||
|
|
||||||
if app is not None:
|
|
||||||
self.init_app(app)
|
|
||||||
|
|
||||||
self.image_dir_filenames = []
|
self.image_dir_filenames = []
|
||||||
self.output_dir_filenames = []
|
self.output_dir_filenames = []
|
||||||
|
if not self.thumbnail_directory.exists():
|
||||||
|
self.thumbnail_directory.mkdir(parents=True)
|
||||||
|
|
||||||
self.image_dir_observer = None
|
# fmt: off
|
||||||
self.output_dir_observer = None
|
self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
||||||
|
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse])
|
||||||
|
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None)
|
||||||
|
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
self.modified_time = {
|
def api_save_image(self, file: UploadFile):
|
||||||
"image": datetime.utcnow(),
|
filename = file.filename
|
||||||
"output": datetime.utcnow(),
|
origin_image_bytes = file.file.read()
|
||||||
}
|
with open(self.output_dir / filename, "wb") as fw:
|
||||||
|
fw.write(origin_image_bytes)
|
||||||
|
|
||||||
# def start(self):
|
def api_medias(self, req: MediasRequest) -> List[MediasResponse]:
|
||||||
# self.image_dir_filenames = self._media_names(self.root_directory)
|
img_dir = self._get_dir(req.tab)
|
||||||
# self.output_dir_filenames = self._media_names(self.output_dir)
|
return self._media_names(img_dir)
|
||||||
#
|
|
||||||
# logger.info(f"Start watching image directory: {self.root_directory}")
|
|
||||||
# self.image_dir_observer = Observer()
|
|
||||||
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
|
||||||
# self.image_dir_observer.start()
|
|
||||||
#
|
|
||||||
# logger.info(f"Start watching output directory: {self.output_dir}")
|
|
||||||
# self.output_dir_observer = Observer()
|
|
||||||
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
|
||||||
# self.output_dir_observer.start()
|
|
||||||
|
|
||||||
def on_modified(self, event):
|
def api_media_file(self, req: MediaFileRequest) -> FileResponse:
|
||||||
if not os.path.isdir(event.src_path):
|
file_path = self._get_file(req.tab, req.filename)
|
||||||
return
|
return FileResponse(file_path)
|
||||||
if event.src_path == str(self.root_directory):
|
|
||||||
logger.info(f"Image directory {event.src_path} modified")
|
|
||||||
self.image_dir_filenames = self._media_names(self.root_directory)
|
|
||||||
self.modified_time["image"] = datetime.utcnow()
|
|
||||||
elif event.src_path == str(self.output_dir):
|
|
||||||
logger.info(f"Output directory {event.src_path} modified")
|
|
||||||
self.output_dir_filenames = self._media_names(self.output_dir)
|
|
||||||
self.modified_time["output"] = datetime.utcnow()
|
|
||||||
|
|
||||||
def init_app(self, app):
|
def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse:
|
||||||
if self.app is None:
|
img_dir = self._get_dir(req.tab)
|
||||||
self.app = app
|
thumb_filename, (width, height) = self.get_thumbnail(
|
||||||
app.thumbnail_instance = self
|
img_dir, req.filename, width=req.width, height=req.height
|
||||||
|
|
||||||
if not hasattr(app, "extensions"):
|
|
||||||
app.extensions = {}
|
|
||||||
|
|
||||||
if "thumbnail" in app.extensions:
|
|
||||||
raise RuntimeError("Flask-thumbnail extension already initialized")
|
|
||||||
|
|
||||||
app.extensions["thumbnail"] = self
|
|
||||||
|
|
||||||
app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
|
|
||||||
app.config.setdefault(
|
|
||||||
"THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory
|
|
||||||
)
|
)
|
||||||
app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
|
thumbnail_filepath = self.thumbnail_directory / thumb_filename
|
||||||
app.config.setdefault(
|
return FileResponse(
|
||||||
"THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url
|
thumbnail_filepath,
|
||||||
|
headers={
|
||||||
|
"X-Width": str(width),
|
||||||
|
"X-Height": str(height),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
|
||||||
|
|
||||||
@property
|
def _get_dir(self, tab: MediaTab) -> Path:
|
||||||
def root_directory(self):
|
if tab == "input":
|
||||||
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
return self.input_dir
|
||||||
|
elif tab == "output":
|
||||||
if os.path.isabs(path):
|
return self.output_dir
|
||||||
return path
|
|
||||||
else:
|
else:
|
||||||
return os.path.join(self.app.root_path, path)
|
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
||||||
|
|
||||||
|
def _get_file(self, tab: MediaTab, filename: str) -> Path:
|
||||||
|
file_path = self._get_dir(tab) / filename
|
||||||
|
if not file_path.exists():
|
||||||
|
raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
|
||||||
|
return file_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def thumbnail_directory(self):
|
def thumbnail_directory(self) -> Path:
|
||||||
path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]
|
return self.output_dir / "thumbnails"
|
||||||
|
|
||||||
if os.path.isabs(path):
|
|
||||||
return path
|
|
||||||
else:
|
|
||||||
return os.path.join(self.app.root_path, path)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def root_url(self):
|
|
||||||
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def media_names(self):
|
|
||||||
# return self.image_dir_filenames
|
|
||||||
return self._media_names(self.root_directory)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_media_names(self):
|
|
||||||
return self._media_names(self.output_dir)
|
|
||||||
# return self.output_dir_filenames
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _media_names(directory: Path):
|
def _media_names(directory: Path) -> List[MediasResponse]:
|
||||||
names = sorted([it.name for it in glob_img(directory)])
|
names = sorted([it.name for it in glob_img(directory)])
|
||||||
res = []
|
res = []
|
||||||
for name in names:
|
for name in names:
|
||||||
path = os.path.join(directory, name)
|
path = os.path.join(directory, name)
|
||||||
img = Image.open(path)
|
img = Image.open(path)
|
||||||
res.append(
|
res.append(
|
||||||
{
|
MediasResponse(
|
||||||
"name": name,
|
name=name,
|
||||||
"height": img.height,
|
height=img.height,
|
||||||
"width": img.width,
|
width=img.width,
|
||||||
"ctime": os.path.getctime(path),
|
ctime=os.path.getctime(path),
|
||||||
"mtime": os.path.getmtime(path),
|
mtime=os.path.getmtime(path),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@property
|
|
||||||
def thumbnail_url(self):
|
|
||||||
return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]
|
|
||||||
|
|
||||||
def get_thumbnail(
|
def get_thumbnail(
|
||||||
self, directory: Path, original_filename: str, width, height, **options
|
self, directory: Path, original_filename: str, width, height, **options
|
||||||
):
|
):
|
||||||
@ -161,7 +117,10 @@ class FileManager:
|
|||||||
image = Image.open(BytesIO(storage.read(original_filepath)))
|
image = Image.open(BytesIO(storage.read(original_filepath)))
|
||||||
|
|
||||||
# keep ratio resize
|
# keep ratio resize
|
||||||
if width is not None:
|
if not width and not height:
|
||||||
|
width = 256
|
||||||
|
|
||||||
|
if width != 0:
|
||||||
height = int(image.height * width / image.width)
|
height = int(image.height * width / image.width)
|
||||||
else:
|
else:
|
||||||
width = int(image.width * height / image.height)
|
width = int(image.width * height / image.height)
|
||||||
@ -180,18 +139,15 @@ class FileManager:
|
|||||||
thumbnail_filepath = os.path.join(
|
thumbnail_filepath = os.path.join(
|
||||||
self.thumbnail_directory, original_path, thumbnail_filename
|
self.thumbnail_directory, original_path, thumbnail_filename
|
||||||
)
|
)
|
||||||
thumbnail_url = os.path.join(
|
|
||||||
self.thumbnail_url, original_path, thumbnail_filename
|
|
||||||
)
|
|
||||||
|
|
||||||
if storage.exists(thumbnail_filepath):
|
if storage.exists(thumbnail_filepath):
|
||||||
return thumbnail_url, (width, height)
|
return thumbnail_filepath, (width, height)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image.load()
|
image.load()
|
||||||
except (IOError, OSError):
|
except (IOError, OSError):
|
||||||
self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
|
self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
|
||||||
return thumbnail_url, (width, height)
|
return thumbnail_filepath, (width, height)
|
||||||
|
|
||||||
# get original image format
|
# get original image format
|
||||||
options["format"] = options.get("format", image.format)
|
options["format"] = options.get("format", image.format)
|
||||||
@ -203,7 +159,7 @@ class FileManager:
|
|||||||
raw_data = self.get_raw_data(image, **options)
|
raw_data = self.get_raw_data(image, **options)
|
||||||
storage.save(thumbnail_filepath, raw_data)
|
storage.save(thumbnail_filepath, raw_data)
|
||||||
|
|
||||||
return thumbnail_url, (width, height)
|
return thumbnail_filepath, (width, height)
|
||||||
|
|
||||||
def get_raw_data(self, image, **options):
|
def get_raw_data(self, image, **options):
|
||||||
data = {
|
data = {
|
||||||
@ -246,7 +202,7 @@ class FileManager:
|
|||||||
if image.format:
|
if image.format:
|
||||||
return image.format
|
return image.format
|
||||||
|
|
||||||
return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]
|
return "JPEG"
|
||||||
|
|
||||||
def _create_thumbnail(self, image, size, crop="fit", background=None):
|
def _create_thumbnail(self, image, size, crop="fit", background=None):
|
||||||
try:
|
try:
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import base64
|
||||||
|
import imghdr
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Tuple
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import cv2
|
import cv2
|
||||||
@ -138,8 +140,8 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
|||||||
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
|
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
|
||||||
with io.BytesIO() as output:
|
with io.BytesIO() as output:
|
||||||
kwargs = {k: v for k, v in infos.items() if v is not None}
|
kwargs = {k: v for k, v in infos.items() if v is not None}
|
||||||
if ext == 'jpg':
|
if ext == "jpg":
|
||||||
ext = 'jpeg'
|
ext = "jpeg"
|
||||||
if "png" == ext.lower() and "parameters" in kwargs:
|
if "png" == ext.lower() and "parameters" in kwargs:
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
pnginfo_data.add_text("parameters", kwargs["parameters"])
|
pnginfo_data.add_text("parameters", kwargs["parameters"])
|
||||||
@ -290,3 +292,61 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
|||||||
|
|
||||||
def is_mac():
|
def is_mac():
|
||||||
return sys.platform == "darwin"
|
return sys.platform == "darwin"
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_ext(img_bytes):
|
||||||
|
w = imghdr.what("", img_bytes)
|
||||||
|
if w is None:
|
||||||
|
w = "jpeg"
|
||||||
|
return w
|
||||||
|
|
||||||
|
|
||||||
|
def decode_base64_to_image(
|
||||||
|
encoding: str, gray=False
|
||||||
|
) -> Tuple[np.array, Optional[np.array], Dict]:
|
||||||
|
if encoding.startswith("data:image/"):
|
||||||
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
|
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
|
||||||
|
|
||||||
|
alpha_channel = None
|
||||||
|
infos = image.info
|
||||||
|
try:
|
||||||
|
image = ImageOps.exif_transpose(image)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if gray:
|
||||||
|
image = image.convert("L")
|
||||||
|
np_img = np.array(image)
|
||||||
|
else:
|
||||||
|
if image.mode == "RGBA":
|
||||||
|
np_img = np.array(image)
|
||||||
|
alpha_channel = np_img[:, :, -1]
|
||||||
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||||
|
else:
|
||||||
|
image = image.convert("RGB")
|
||||||
|
np_img = np.array(image)
|
||||||
|
|
||||||
|
return np_img, alpha_channel, infos
|
||||||
|
|
||||||
|
|
||||||
|
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
|
||||||
|
img_bytes = pil_to_bytes(
|
||||||
|
image,
|
||||||
|
"png",
|
||||||
|
quality=quality,
|
||||||
|
infos=infos,
|
||||||
|
)
|
||||||
|
return base64.b64encode(img_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
|
||||||
|
if alpha_channel is not None:
|
||||||
|
if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
|
||||||
|
alpha_channel = cv2.resize(
|
||||||
|
alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
|
||||||
|
)
|
||||||
|
rgb_np_img = np.concatenate(
|
||||||
|
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||||
|
)
|
||||||
|
return rgb_np_img
|
||||||
|
@ -14,7 +14,7 @@ from lama_cleaner.helper import (
|
|||||||
)
|
)
|
||||||
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
||||||
from lama_cleaner.model.utils import get_scheduler
|
from lama_cleaner.model.utils import get_scheduler
|
||||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler
|
||||||
|
|
||||||
|
|
||||||
class InpaintModel:
|
class InpaintModel:
|
||||||
@ -44,7 +44,7 @@ class InpaintModel:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input images and output images have same size
|
"""Input images and output images have same size
|
||||||
images: [H, W, C] RGB
|
images: [H, W, C] RGB
|
||||||
masks: [H, W, 1] 255 为 masks 区域
|
masks: [H, W, 1] 255 为 masks 区域
|
||||||
@ -56,7 +56,7 @@ class InpaintModel:
|
|||||||
def download():
|
def download():
|
||||||
...
|
...
|
||||||
|
|
||||||
def _pad_forward(self, image, mask, config: Config):
|
def _pad_forward(self, image, mask, config: InpaintRequest):
|
||||||
origin_height, origin_width = image.shape[:2]
|
origin_height, origin_width = image.shape[:2]
|
||||||
pad_image = pad_img_to_modulo(
|
pad_image = pad_img_to_modulo(
|
||||||
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
||||||
@ -74,7 +74,7 @@ class InpaintModel:
|
|||||||
|
|
||||||
result, image, mask = self.forward_post_process(result, image, mask, config)
|
result, image, mask = self.forward_post_process(result, image, mask, config)
|
||||||
|
|
||||||
if config.sd_prevent_unmasked_area:
|
if config.sd_keep_unmasked_area:
|
||||||
mask = mask[:, :, np.newaxis]
|
mask = mask[:, :, np.newaxis]
|
||||||
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
||||||
return result
|
return result
|
||||||
@ -86,7 +86,7 @@ class InpaintModel:
|
|||||||
return result, image, mask
|
return result, image, mask
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
images: [H, W, C] RGB, not normalized
|
images: [H, W, C] RGB, not normalized
|
||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
@ -141,7 +141,7 @@ class InpaintModel:
|
|||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def _crop_box(self, image, mask, box, config: Config):
|
def _crop_box(self, image, mask, box, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -233,7 +233,7 @@ class InpaintModel:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _apply_cropper(self, image, mask, config: Config):
|
def _apply_cropper(self, image, mask, config: InpaintRequest):
|
||||||
img_h, img_w = image.shape[:2]
|
img_h, img_w = image.shape[:2]
|
||||||
l, t, w, h = (
|
l, t, w, h = (
|
||||||
config.croper_x,
|
config.croper_x,
|
||||||
@ -253,7 +253,7 @@ class InpaintModel:
|
|||||||
crop_mask = mask[t:b, l:r]
|
crop_mask = mask[t:b, l:r]
|
||||||
return crop_img, crop_mask, (l, t, r, b)
|
return crop_img, crop_mask, (l, t, r, b)
|
||||||
|
|
||||||
def _run_box(self, image, mask, box, config: Config):
|
def _run_box(self, image, mask, box, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -276,7 +276,7 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
super().__init__(device, **kwargs)
|
super().__init__(device, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
images: [H, W, C] RGB, not normalized
|
images: [H, W, C] RGB, not normalized
|
||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
@ -295,7 +295,7 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def _do_outpainting(self, image, config: Config):
|
def _do_outpainting(self, image, config: InpaintRequest):
|
||||||
# cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
|
# cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
|
||||||
# 从 image 中 crop 出 outpainting 区域
|
# 从 image 中 crop 出 outpainting 区域
|
||||||
image_h, image_w = image.shape[:2]
|
image_h, image_w = image.shape[:2]
|
||||||
@ -368,7 +368,7 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
] = expanded_cropped_result_image
|
] = expanded_cropped_result_image
|
||||||
return outpainting_image
|
return outpainting_image
|
||||||
|
|
||||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
|
||||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||||
origin_size = image.shape[:2]
|
origin_size = image.shape[:2]
|
||||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||||
@ -396,7 +396,7 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
# ]
|
# ]
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def set_scheduler(self, config: Config):
|
def set_scheduler(self, config: InpaintRequest):
|
||||||
scheduler_config = self.model.scheduler.config
|
scheduler_config = self.model.scheduler.config
|
||||||
sd_sampler = config.sd_sampler
|
sd_sampler = config.sd_sampler
|
||||||
if config.sd_lcm_lora:
|
if config.sd_lcm_lora:
|
||||||
|
@ -14,7 +14,7 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
|
|||||||
)
|
)
|
||||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
|
from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
|
||||||
from lama_cleaner.schema import Config, ModelType
|
from lama_cleaner.schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(DiffusionInpaintModel):
|
class ControlNet(DiffusionInpaintModel):
|
||||||
@ -130,7 +130,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
raise NotImplementedError(f"{self.controlnet_method} not implemented")
|
raise NotImplementedError(f"{self.controlnet_method} not implemented")
|
||||||
return control_image
|
return control_image
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.fft as fft
|
import torch.fft as fft
|
||||||
|
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
from lama_cleaner.helper import (
|
from lama_cleaner.helper import (
|
||||||
load_model,
|
load_model,
|
||||||
@ -1665,7 +1665,7 @@ class FcF(InpaintModel):
|
|||||||
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
|
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
images: [H, W, C] RGB, not normalized
|
images: [H, W, C] RGB, not normalized
|
||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
@ -1705,7 +1705,7 @@ class FcF(InpaintModel):
|
|||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input images and output images have same size
|
"""Input images and output images have same size
|
||||||
images: [H, W, C] RGB
|
images: [H, W, C] RGB
|
||||||
masks: [H, W] mask area == 255
|
masks: [H, W] mask area == 255
|
||||||
|
@ -4,7 +4,7 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
class InstructPix2Pix(DiffusionInpaintModel):
|
class InstructPix2Pix(DiffusionInpaintModel):
|
||||||
@ -40,7 +40,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
|||||||
else:
|
else:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.utils import get_scheduler
|
from lama_cleaner.model.utils import get_scheduler
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
class Kandinsky(DiffusionInpaintModel):
|
class Kandinsky(DiffusionInpaintModel):
|
||||||
@ -29,7 +29,7 @@ class Kandinsky(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -11,7 +11,7 @@ from lama_cleaner.helper import (
|
|||||||
download_model,
|
download_model,
|
||||||
)
|
)
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
LAMA_MODEL_URL = os.environ.get(
|
LAMA_MODEL_URL = os.environ.get(
|
||||||
"LAMA_MODEL_URL",
|
"LAMA_MODEL_URL",
|
||||||
@ -36,7 +36,7 @@ class LaMa(InpaintModel):
|
|||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W]
|
mask: [H, W]
|
||||||
|
@ -7,7 +7,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
||||||
from lama_cleaner.model.plms_sampler import PLMSSampler
|
from lama_cleaner.model.plms_sampler import PLMSSampler
|
||||||
from lama_cleaner.schema import Config, LDMSampler
|
from lama_cleaner.schema import InpaintRequest, LDMSampler
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -277,7 +277,7 @@ class LDM(InpaintModel):
|
|||||||
return all([os.path.exists(it) for it in model_paths])
|
return all([os.path.exists(it) for it in model_paths])
|
||||||
|
|
||||||
@torch.cuda.amp.autocast()
|
@torch.cuda.amp.autocast()
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1]
|
mask: [H, W, 1]
|
||||||
|
@ -9,7 +9,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
|
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
||||||
@ -56,7 +56,7 @@ class Manga(InpaintModel):
|
|||||||
]
|
]
|
||||||
return all([os.path.exists(it) for it in model_paths])
|
return all([os.path.exists(it) for it in model_paths])
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1]
|
mask: [H, W, 1]
|
||||||
|
@ -28,7 +28,7 @@ from lama_cleaner.model.utils import (
|
|||||||
normalize_2nd_moment,
|
normalize_2nd_moment,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
class ModulatedConv2d(nn.Module):
|
class ModulatedConv2d(nn.Module):
|
||||||
@ -1912,7 +1912,7 @@ class MAT(InpaintModel):
|
|||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input images and output images have same size
|
"""Input images and output images have same size
|
||||||
images: [H, W, C] RGB
|
images: [H, W, C] RGB
|
||||||
masks: [H, W] mask area == 255
|
masks: [H, W] mask area == 255
|
||||||
|
@ -3,7 +3,6 @@ import os
|
|||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lama_cleaner.const import Config
|
|
||||||
from lama_cleaner.helper import (
|
from lama_cleaner.helper import (
|
||||||
load_jit_model,
|
load_jit_model,
|
||||||
download_model,
|
download_model,
|
||||||
@ -13,6 +12,7 @@ from lama_cleaner.helper import (
|
|||||||
norm_img,
|
norm_img,
|
||||||
)
|
)
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
MIGAN_MODEL_URL = os.environ.get(
|
MIGAN_MODEL_URL = os.environ.get(
|
||||||
"MIGAN_MODEL_URL",
|
"MIGAN_MODEL_URL",
|
||||||
@ -40,7 +40,7 @@ class MIGAN(InpaintModel):
|
|||||||
return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
|
return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: InpaintRequest):
|
||||||
"""
|
"""
|
||||||
images: [H, W, C] RGB, not normalized
|
images: [H, W, C] RGB, not normalized
|
||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
@ -80,7 +80,7 @@ class MIGAN(InpaintModel):
|
|||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input images and output images have same size
|
"""Input images and output images have same size
|
||||||
images: [H, W, C] RGB
|
images: [H, W, C] RGB
|
||||||
masks: [H, W] mask area == 255
|
masks: [H, W] mask area == 255
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import cv2
|
import cv2
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class OpenCV2(InpaintModel):
|
|||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1]
|
mask: [H, W, 1]
|
||||||
|
@ -4,8 +4,9 @@ import cv2
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.helper import decode_base64_to_image
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
class PaintByExample(DiffusionInpaintModel):
|
class PaintByExample(DiffusionInpaintModel):
|
||||||
@ -38,16 +39,21 @@ class PaintByExample(DiffusionInpaintModel):
|
|||||||
else:
|
else:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
|
if config.paint_by_example_example_image is None:
|
||||||
|
raise ValueError("paint_by_example_example_image is required")
|
||||||
|
example_image, _, _ = decode_base64_to_image(
|
||||||
|
config.paint_by_example_example_image
|
||||||
|
)
|
||||||
output = self.model(
|
output = self.model(
|
||||||
image=PIL.Image.fromarray(image),
|
image=PIL.Image.fromarray(image),
|
||||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||||
example_image=config.paint_by_example_example_image,
|
example_image=PIL.Image.fromarray(example_image),
|
||||||
num_inference_steps=config.sd_steps,
|
num_inference_steps=config.sd_steps,
|
||||||
guidance_scale=config.sd_guidance_scale,
|
guidance_scale=config.sd_guidance_scale,
|
||||||
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
||||||
|
@ -7,7 +7,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
from .powerpaint_tokenizer import add_task_to_prompt
|
from .powerpaint_tokenizer import add_task_to_prompt
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class PowerPaint(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||||
from lama_cleaner.schema import Config, ModelType
|
from lama_cleaner.schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
class SD(DiffusionInpaintModel):
|
class SD(DiffusionInpaintModel):
|
||||||
@ -64,7 +64,7 @@ class SD(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -9,7 +9,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||||
from lama_cleaner.schema import Config, ModelType
|
from lama_cleaner.schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
class SDXL(DiffusionInpaintModel):
|
class SDXL(DiffusionInpaintModel):
|
||||||
@ -60,7 +60,7 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
|
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
@ -343,7 +343,7 @@ class ZITS(InpaintModel):
|
|||||||
items["line"] = line_pred.detach()
|
items["line"] = line_pred.detach()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
"""Input images and output images have same size
|
"""Input images and output images have same size
|
||||||
images: [H, W, C] RGB
|
images: [H, W, C] RGB
|
||||||
masks: [H, W]
|
masks: [H, W]
|
||||||
|
@ -8,7 +8,7 @@ from lama_cleaner.helper import switch_mps_device
|
|||||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.model_info import ModelInfo, ModelType
|
from lama_cleaner.model_info import ModelInfo, ModelType
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
@ -31,13 +31,15 @@ class ModelManager:
|
|||||||
self.model = self.init_model(name, device, **kwargs)
|
self.model = self.init_model(name, device, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_model(self) -> Dict:
|
def current_model(self) -> ModelInfo:
|
||||||
return self.available_models[self.name].model_dump()
|
return self.available_models[self.name]
|
||||||
|
|
||||||
def init_model(self, name: str, device, **kwargs):
|
def init_model(self, name: str, device, **kwargs):
|
||||||
logger.info(f"Loading model: {name}")
|
logger.info(f"Loading model: {name}")
|
||||||
if name not in self.available_models:
|
if name not in self.available_models:
|
||||||
raise NotImplementedError(f"Unsupported model: {name}. Available models: {self.available_models.keys()}")
|
raise NotImplementedError(
|
||||||
|
f"Unsupported model: {name}. Available models: {self.available_models.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
model_info = self.available_models[name]
|
model_info = self.available_models[name]
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -66,7 +68,17 @@ class ModelManager:
|
|||||||
|
|
||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: InpaintRequest):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
config:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
self.switch_controlnet_method(config)
|
self.switch_controlnet_method(config)
|
||||||
self.enable_disable_freeu(config)
|
self.enable_disable_freeu(config)
|
||||||
self.enable_disable_lcm_lora(config)
|
self.enable_disable_lcm_lora(config)
|
||||||
@ -135,7 +147,7 @@ class ModelManager:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||||
|
|
||||||
def enable_disable_freeu(self, config: Config):
|
def enable_disable_freeu(self, config: InpaintRequest):
|
||||||
if str(self.model.device) == "mps":
|
if str(self.model.device) == "mps":
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -151,7 +163,7 @@ class ModelManager:
|
|||||||
else:
|
else:
|
||||||
self.model.model.disable_freeu()
|
self.model.model.disable_freeu()
|
||||||
|
|
||||||
def enable_disable_lcm_lora(self, config: Config):
|
def enable_disable_lcm_lora(self, config: InpaintRequest):
|
||||||
if self.available_models[self.name].support_lcm_lora:
|
if self.available_models[self.name].support_lcm_lora:
|
||||||
if config.sd_lcm_lora:
|
if config.sd_lcm_lora:
|
||||||
if not self.model.model.get_list_adapters():
|
if not self.model.model.get_list_adapters():
|
||||||
|
@ -10,7 +10,6 @@ from ..const import InteractiveSegModel, Device, RealESRGANModel
|
|||||||
|
|
||||||
|
|
||||||
def build_plugins(
|
def build_plugins(
|
||||||
global_config,
|
|
||||||
enable_interactive_seg: bool,
|
enable_interactive_seg: bool,
|
||||||
interactive_seg_model: InteractiveSegModel,
|
interactive_seg_model: InteractiveSegModel,
|
||||||
interactive_seg_device: Device,
|
interactive_seg_device: Device,
|
||||||
@ -25,25 +24,26 @@ def build_plugins(
|
|||||||
restoreformer_device: Device,
|
restoreformer_device: Device,
|
||||||
no_half: bool,
|
no_half: bool,
|
||||||
):
|
):
|
||||||
|
plugins = {}
|
||||||
if enable_interactive_seg:
|
if enable_interactive_seg:
|
||||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||||
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
|
plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||||
interactive_seg_model, interactive_seg_device
|
interactive_seg_model, interactive_seg_device
|
||||||
)
|
)
|
||||||
|
|
||||||
if enable_remove_bg:
|
if enable_remove_bg:
|
||||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||||
global_config.plugins[RemoveBG.name] = RemoveBG()
|
plugins[RemoveBG.name] = RemoveBG()
|
||||||
|
|
||||||
if enable_anime_seg:
|
if enable_anime_seg:
|
||||||
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
||||||
global_config.plugins[AnimeSeg.name] = AnimeSeg()
|
plugins[AnimeSeg.name] = AnimeSeg()
|
||||||
|
|
||||||
if enable_realesrgan:
|
if enable_realesrgan:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
||||||
)
|
)
|
||||||
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||||
realesrgan_model,
|
realesrgan_model,
|
||||||
realesrgan_device,
|
realesrgan_device,
|
||||||
no_half=no_half,
|
no_half=no_half,
|
||||||
@ -57,14 +57,15 @@ def build_plugins(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
||||||
)
|
)
|
||||||
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||||
gfpgan_device,
|
gfpgan_device,
|
||||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
if enable_restoreformer:
|
if enable_restoreformer:
|
||||||
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
||||||
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||||
restoreformer_device,
|
restoreformer_device,
|
||||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
||||||
)
|
)
|
||||||
|
return plugins
|
||||||
|
@ -1,8 +1,17 @@
|
|||||||
|
import random
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from pathlib import Path
|
||||||
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, validator, field_validator
|
||||||
|
|
||||||
|
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
|
||||||
|
|
||||||
|
|
||||||
|
class CV2Flag(str, Enum):
|
||||||
|
INPAINT_NS = "INPAINT_NS"
|
||||||
|
INPAINT_TELEA = "INPAINT_TELEA"
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
@ -56,93 +65,215 @@ class PowerPaintTask(str, Enum):
|
|||||||
outpainting = "outpainting"
|
outpainting = "outpainting"
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class ApiConfig(BaseModel):
|
||||||
class Config:
|
host: str
|
||||||
arbitrary_types_allowed = True
|
port: int
|
||||||
|
model: str
|
||||||
|
no_half: bool
|
||||||
|
cpu_offload: bool
|
||||||
|
disable_nsfw_checker: bool
|
||||||
|
cpu_textencoder: bool
|
||||||
|
device: Device
|
||||||
|
gui: bool
|
||||||
|
disable_model_switch: bool
|
||||||
|
input: Path
|
||||||
|
output_dir: Path
|
||||||
|
quality: int
|
||||||
|
enable_interactive_seg: bool
|
||||||
|
interactive_seg_model: InteractiveSegModel
|
||||||
|
interactive_seg_device: Device
|
||||||
|
enable_remove_bg: bool
|
||||||
|
enable_anime_seg: bool
|
||||||
|
enable_realesrgan: bool
|
||||||
|
realesrgan_device: Device
|
||||||
|
realesrgan_model: RealESRGANModel
|
||||||
|
enable_gfpgan: bool
|
||||||
|
gfpgan_device: Device
|
||||||
|
enable_restoreformer: bool
|
||||||
|
restoreformer_device: Device
|
||||||
|
|
||||||
# Configs for ldm model
|
|
||||||
ldm_steps: int = 20
|
|
||||||
ldm_sampler: str = LDMSampler.plms
|
|
||||||
|
|
||||||
# Configs for zits model
|
class InpaintRequest(BaseModel):
|
||||||
zits_wireframe: bool = True
|
image: Optional[str] = Field(..., description="base64 encoded image")
|
||||||
|
mask: Optional[str] = Field(..., description="base64 encoded mask")
|
||||||
|
|
||||||
# Configs for High Resolution Strategy(different way to preprocess image)
|
ldm_steps: int = Field(20, description="Steps for ldm model.")
|
||||||
hd_strategy: str = HDStrategy.CROP # See HDStrategy Enum
|
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
|
||||||
hd_strategy_crop_margin: int = 128
|
zits_wireframe: bool = Field(True, description="Enable wireframe for zits model.")
|
||||||
# If the longer side of the image is larger than this value, use crop strategy
|
|
||||||
hd_strategy_crop_trigger_size: int = 800
|
|
||||||
hd_strategy_resize_limit: int = 1280
|
|
||||||
|
|
||||||
# Configs for Stable Diffusion 1.5
|
hd_strategy: str = Field(
|
||||||
prompt: str = ""
|
HDStrategy.CROP,
|
||||||
negative_prompt: str = ""
|
description="Different way to preprocess image, only used by erase models(e.g. lama/mat)",
|
||||||
# Crop image to this size before doing sd inpainting
|
)
|
||||||
# The value is always on the original image scale
|
hd_strategy_crop_trigger_size: int = Field(
|
||||||
use_croper: bool = False
|
800,
|
||||||
croper_x: int = None
|
description="Crop trigger size for hd_strategy=CROP, if the longer side of the image is larger than this value, use crop strategy",
|
||||||
croper_y: int = None
|
)
|
||||||
croper_height: int = None
|
hd_strategy_crop_margin: int = Field(
|
||||||
croper_width: int = None
|
128, description="Crop margin for hd_strategy=CROP"
|
||||||
use_extender: bool = False
|
)
|
||||||
extender_x: int = None
|
hd_strategy_resize_limit: int = Field(
|
||||||
extender_y: int = None
|
1280, description="Resize limit for hd_strategy=RESIZE"
|
||||||
extender_height: int = None
|
)
|
||||||
extender_width: int = None
|
|
||||||
|
|
||||||
# Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
|
prompt: str = Field("", description="Prompt for diffusion models.")
|
||||||
# Used by sd models and paint_by_example model
|
negative_prompt: str = Field(
|
||||||
sd_scale: float = 1.0
|
"", description="Negative prompt for diffusion models."
|
||||||
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
)
|
||||||
sd_mask_blur: int = 0
|
use_croper: bool = Field(
|
||||||
# Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
False, description="Crop image before doing diffusion inpainting"
|
||||||
# starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
)
|
||||||
# on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
croper_x: int = Field(0, description="Crop x for croper")
|
||||||
# process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
croper_y: int = Field(0, description="Crop y for croper")
|
||||||
# essentially ignores `image`.
|
croper_height: int = Field(512, description="Crop height for croper")
|
||||||
sd_strength: float = 1.0
|
croper_width: int = Field(512, description="Crop width for croper")
|
||||||
# The number of denoising steps. More denoising steps usually lead to a
|
|
||||||
# higher quality image at the expense of slower inference.
|
|
||||||
sd_steps: int = 50
|
|
||||||
# Higher guidance scale encourages to generate images that are closely linked
|
|
||||||
# to the text prompt, usually at the expense of lower image quality.
|
|
||||||
sd_guidance_scale: float = 7.5
|
|
||||||
sd_sampler: str = SDSampler.uni_pc
|
|
||||||
# -1 mean random seed
|
|
||||||
sd_seed: int = 42
|
|
||||||
sd_match_histograms: bool = False
|
|
||||||
|
|
||||||
# out-painting
|
use_extender: bool = Field(
|
||||||
sd_outpainting_softness: float = 20.0
|
False, description="Extend image before doing sd outpainting"
|
||||||
sd_outpainting_space: float = 20.0
|
)
|
||||||
|
extender_x: int = Field(0, description="Extend x for extender")
|
||||||
|
extender_y: int = Field(0, description="Extend y for extender")
|
||||||
|
extender_height: int = Field(640, description="Extend height for extender")
|
||||||
|
extender_width: int = Field(640, description="Extend width for extender")
|
||||||
|
|
||||||
# freeu
|
sd_mask_blur: int = Field(
|
||||||
sd_freeu: bool = False
|
33,
|
||||||
|
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
|
||||||
|
)
|
||||||
|
sd_strength: float = Field(
|
||||||
|
1.0,
|
||||||
|
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
|
||||||
|
)
|
||||||
|
sd_steps: int = Field(
|
||||||
|
50,
|
||||||
|
description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
|
||||||
|
)
|
||||||
|
sd_guidance_scale: float = Field(
|
||||||
|
7.5,
|
||||||
|
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.",
|
||||||
|
)
|
||||||
|
sd_sampler: str = Field(
|
||||||
|
SDSampler.uni_pc, description="Sampler for diffusion model."
|
||||||
|
)
|
||||||
|
sd_seed: int = Field(
|
||||||
|
42,
|
||||||
|
description="Seed for diffusion model. -1 mean random seed",
|
||||||
|
validate_default=True,
|
||||||
|
)
|
||||||
|
sd_match_histograms: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Match histograms between inpainting area and original image.",
|
||||||
|
)
|
||||||
|
|
||||||
|
sd_outpainting_softness: float = Field(20.0)
|
||||||
|
sd_outpainting_space: float = Field(20.0)
|
||||||
|
|
||||||
|
sd_freeu: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Enable freeu mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu",
|
||||||
|
)
|
||||||
sd_freeu_config: FREEUConfig = FREEUConfig()
|
sd_freeu_config: FREEUConfig = FREEUConfig()
|
||||||
|
|
||||||
# lcm-lora
|
sd_lcm_lora: bool = Field(
|
||||||
sd_lcm_lora: bool = False
|
False,
|
||||||
|
description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage",
|
||||||
|
)
|
||||||
|
|
||||||
# preserving the unmasked area at the expense of some more unnatural transitions between the masked and unmasked areas.
|
sd_keep_unmasked_area: bool = Field(
|
||||||
sd_prevent_unmasked_area: bool = True
|
True, description="Keep unmasked area unchanged"
|
||||||
|
)
|
||||||
|
|
||||||
# Configs for opencv inpainting
|
cv2_flag: CV2Flag = Field(
|
||||||
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
|
CV2Flag.INPAINT_NS,
|
||||||
cv2_flag: str = "INPAINT_NS"
|
description="Flag for opencv inpainting: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07",
|
||||||
cv2_radius: int = 4
|
)
|
||||||
|
cv2_radius: int = Field(
|
||||||
|
4,
|
||||||
|
description="Radius of a circular neighborhood of each point inpainted that is considered by the algorithm",
|
||||||
|
)
|
||||||
|
|
||||||
# Paint by Example
|
# Paint by Example
|
||||||
paint_by_example_example_image: Optional[Image] = None
|
paint_by_example_example_image: Optional[str] = Field(
|
||||||
|
None, description="Base64 encoded example image for paint by example model"
|
||||||
|
)
|
||||||
|
|
||||||
# InstructPix2Pix
|
# InstructPix2Pix
|
||||||
p2p_image_guidance_scale: float = 1.5
|
p2p_image_guidance_scale: float = Field(1.5, description="Image guidance scale")
|
||||||
|
|
||||||
# ControlNet
|
# ControlNet
|
||||||
enable_controlnet: bool = False
|
enable_controlnet: bool = Field(False, description="Enable controlnet")
|
||||||
controlnet_conditioning_scale: float = 0.4
|
controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale")
|
||||||
controlnet_method: str = "lllyasviel/control_v11p_sd15_canny"
|
controlnet_method: str = Field(
|
||||||
|
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||||
|
)
|
||||||
|
|
||||||
# PowerPaint
|
# PowerPaint
|
||||||
powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided
|
powerpaint_task: PowerPaintTask = Field(
|
||||||
# control the fitting degree of the generated objects to the mask shape.
|
PowerPaintTask.text_guided, description="PowerPaint task"
|
||||||
fitting_degree: float = 1.0
|
)
|
||||||
|
fitting_degree: float = Field(
|
||||||
|
1.0,
|
||||||
|
description="Control the fitting degree of the generated objects to the mask shape.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("sd_seed")
|
||||||
|
@classmethod
|
||||||
|
def sd_seed_validator(cls, v: int) -> int:
|
||||||
|
if v == -1:
|
||||||
|
return random.randint(1, 99999999)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class RunPluginRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
image: Optional[str] = Field(..., description="base64 encoded image")
|
||||||
|
clicks: List[List[int]] = Field(
|
||||||
|
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
|
||||||
|
)
|
||||||
|
scale: float = Field(2.0, description="Scale for upscaling")
|
||||||
|
|
||||||
|
|
||||||
|
MediaTab = Literal["input", "output"]
|
||||||
|
|
||||||
|
|
||||||
|
class MediasRequest(BaseModel):
|
||||||
|
tab: MediaTab
|
||||||
|
|
||||||
|
|
||||||
|
class MediasResponse(BaseModel):
|
||||||
|
name: str
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
ctime: float
|
||||||
|
mtime: float
|
||||||
|
|
||||||
|
|
||||||
|
class MediaFileRequest(BaseModel):
|
||||||
|
tab: MediaTab
|
||||||
|
filename: str
|
||||||
|
|
||||||
|
|
||||||
|
class MediaThumbnailFileRequest(BaseModel):
|
||||||
|
tab: MediaTab
|
||||||
|
filename: str
|
||||||
|
width: int = 0
|
||||||
|
height: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class GenInfoResponse(BaseModel):
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ServerConfigResponse(BaseModel):
|
||||||
|
plugins: List[str]
|
||||||
|
enableFileManager: bool
|
||||||
|
enableAutoSaving: bool
|
||||||
|
enableControlnet: bool
|
||||||
|
controlnetMethod: Optional[str]
|
||||||
|
disableModelSwitch: bool
|
||||||
|
isDesktop: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchModelRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
@ -1,16 +1,28 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
|
||||||
|
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||||
|
cv2.setNumThreads(NUM_THREADS)
|
||||||
|
|
||||||
|
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
|
||||||
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
||||||
|
|
||||||
|
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||||
|
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||||
|
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||||
|
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||||
|
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
import imghdr
|
|
||||||
import io
|
import io
|
||||||
import logging
|
|
||||||
import multiprocessing
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -21,6 +33,11 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from lama_cleaner.const import *
|
from lama_cleaner.const import *
|
||||||
from lama_cleaner.file_manager import FileManager
|
from lama_cleaner.file_manager import FileManager
|
||||||
from lama_cleaner.model.utils import torch_gc
|
from lama_cleaner.model.utils import torch_gc
|
||||||
@ -31,8 +48,15 @@ from lama_cleaner.plugins import (
|
|||||||
AnimeSeg,
|
AnimeSeg,
|
||||||
build_plugins,
|
build_plugins,
|
||||||
)
|
)
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
from lama_cleaner.helper import (
|
||||||
|
load_img,
|
||||||
|
numpy_to_bytes,
|
||||||
|
resize_max_size,
|
||||||
|
pil_to_bytes,
|
||||||
|
is_mac,
|
||||||
|
get_image_ext, concat_alpha_channel,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
@ -42,454 +66,23 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from flask import (
|
|
||||||
Flask,
|
app = FastAPI()
|
||||||
request,
|
app.add_middleware(
|
||||||
send_file,
|
CORSMiddleware,
|
||||||
cli,
|
allow_origins=["*"],
|
||||||
make_response,
|
allow_credentials=True,
|
||||||
send_from_directory,
|
allow_methods=["*"],
|
||||||
jsonify,
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
from flask_socketio import SocketIO
|
|
||||||
|
|
||||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
|
||||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
|
||||||
cli.show_server_banner = lambda *_: None
|
|
||||||
from flask_cors import CORS
|
|
||||||
|
|
||||||
from lama_cleaner.helper import (
|
|
||||||
load_img,
|
|
||||||
numpy_to_bytes,
|
|
||||||
resize_max_size,
|
|
||||||
pil_to_bytes,
|
|
||||||
is_mac,
|
|
||||||
)
|
|
||||||
|
|
||||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
|
||||||
|
|
||||||
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
|
||||||
|
|
||||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
|
||||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
|
||||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
|
||||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
|
||||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
|
||||||
if os.environ.get("CACHE_DIR"):
|
|
||||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
|
||||||
|
|
||||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||||
|
|
||||||
|
|
||||||
class NoFlaskwebgui(logging.Filter):
|
|
||||||
def filter(self, record):
|
|
||||||
msg = record.getMessage()
|
|
||||||
if "Running on http:" in msg:
|
|
||||||
print(msg[msg.index("Running on http:") :])
|
|
||||||
|
|
||||||
return (
|
|
||||||
"flaskwebgui-keep-server-alive" not in msg
|
|
||||||
and "socket.io" not in msg
|
|
||||||
and "This is a development server." not in msg
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
|
||||||
|
|
||||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
|
||||||
app.config["JSON_AS_ASCII"] = False
|
|
||||||
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
|
|
||||||
|
|
||||||
sio_logger = logging.getLogger("sio-logger")
|
|
||||||
sio_logger.setLevel(logging.ERROR)
|
|
||||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GlobalConfig:
|
|
||||||
model_manager: ModelManager = None
|
|
||||||
file_manager: FileManager = None
|
|
||||||
output_dir: Path = None
|
|
||||||
input_image_path: Path = None
|
|
||||||
disable_model_switch: bool = False
|
|
||||||
is_desktop: bool = False
|
|
||||||
image_quality: int = 95
|
|
||||||
plugins = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enable_auto_saving(self) -> bool:
|
|
||||||
return self.output_dir is not None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enable_file_manager(self) -> bool:
|
|
||||||
return self.file_manager is not None
|
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|
||||||
|
|
||||||
def get_image_ext(img_bytes):
|
|
||||||
w = imghdr.what("", img_bytes)
|
|
||||||
if w is None:
|
|
||||||
w = "jpeg"
|
|
||||||
return w
|
|
||||||
|
|
||||||
|
|
||||||
def diffuser_callback(i, t, latents):
|
def diffuser_callback(i, t, latents):
|
||||||
socketio.emit("diffusion_progress", {"step": i})
|
socketio.emit("diffusion_progress", {"step": i})
|
||||||
|
|
||||||
|
|
||||||
@app.route("/save_image", methods=["POST"])
|
|
||||||
def save_image():
|
|
||||||
if global_config.output_dir is None:
|
|
||||||
return "--output-dir is None", 500
|
|
||||||
|
|
||||||
input = request.files
|
|
||||||
filename = request.form["filename"]
|
|
||||||
origin_image_bytes = input["image"].read() # RGB
|
|
||||||
# ext = get_image_ext(origin_image_bytes)
|
|
||||||
ext = "png"
|
|
||||||
image, alpha_channel, infos = load_img(origin_image_bytes, return_info=True)
|
|
||||||
save_path = (global_config.output_dir / filename).with_suffix(f".{ext}")
|
|
||||||
|
|
||||||
if alpha_channel is not None:
|
|
||||||
if alpha_channel.shape[:2] != image.shape[:2]:
|
|
||||||
alpha_channel = cv2.resize(
|
|
||||||
alpha_channel, dsize=(image.shape[1], image.shape[0])
|
|
||||||
)
|
|
||||||
image = np.concatenate((image, alpha_channel[:, :, np.newaxis]), axis=-1)
|
|
||||||
|
|
||||||
pil_image = Image.fromarray(image).convert("RGBA")
|
|
||||||
|
|
||||||
img_bytes = pil_to_bytes(
|
|
||||||
pil_image,
|
|
||||||
ext,
|
|
||||||
quality=global_config.image_quality,
|
|
||||||
infos=infos,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with open(save_path, "wb") as fw:
|
|
||||||
fw.write(img_bytes)
|
|
||||||
except:
|
|
||||||
return f"Save image failed: {traceback.format_exc()}", 500
|
|
||||||
|
|
||||||
return "ok", 200
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/medias/<tab>")
|
|
||||||
def medias(tab):
|
|
||||||
if tab == "image":
|
|
||||||
response = make_response(jsonify(global_config.file_manager.media_names), 200)
|
|
||||||
else:
|
|
||||||
response = make_response(
|
|
||||||
jsonify(global_config.file_manager.output_media_names), 200
|
|
||||||
)
|
|
||||||
# response.last_modified = thumb.modified_time[tab]
|
|
||||||
# response.cache_control.no_cache = True
|
|
||||||
# response.cache_control.max_age = 0
|
|
||||||
# response.make_conditional(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/media/<tab>/<filename>")
|
|
||||||
def media_file(tab, filename):
|
|
||||||
if tab == "image":
|
|
||||||
return send_from_directory(global_config.file_manager.root_directory, filename)
|
|
||||||
return send_from_directory(global_config.file_manager.output_dir, filename)
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/media_thumbnail/<tab>/<filename>")
|
|
||||||
def media_thumbnail_file(tab, filename):
|
|
||||||
args = request.args
|
|
||||||
width = args.get("width")
|
|
||||||
height = args.get("height")
|
|
||||||
if width is None and height is None:
|
|
||||||
width = 256
|
|
||||||
if width:
|
|
||||||
width = int(float(width))
|
|
||||||
if height:
|
|
||||||
height = int(float(height))
|
|
||||||
|
|
||||||
directory = global_config.file_manager.root_directory
|
|
||||||
if tab == "output":
|
|
||||||
directory = global_config.file_manager.output_dir
|
|
||||||
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
|
|
||||||
directory, filename, width, height
|
|
||||||
)
|
|
||||||
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
|
||||||
|
|
||||||
response = make_response(send_file(thumb_filepath))
|
|
||||||
response.headers["X-Width"] = str(width)
|
|
||||||
response.headers["X-Height"] = str(height)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/inpaint", methods=["POST"])
|
|
||||||
def process():
|
|
||||||
input = request.files
|
|
||||||
# RGB
|
|
||||||
origin_image_bytes = input["image"].read()
|
|
||||||
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_info=True)
|
|
||||||
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
|
||||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
|
||||||
|
|
||||||
if image.shape[:2] != mask.shape[:2]:
|
|
||||||
return (
|
|
||||||
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
|
|
||||||
400,
|
|
||||||
)
|
|
||||||
|
|
||||||
original_shape = image.shape
|
|
||||||
interpolation = cv2.INTER_CUBIC
|
|
||||||
|
|
||||||
form = request.form
|
|
||||||
size_limit = max(image.shape)
|
|
||||||
|
|
||||||
if "paintByExampleImage" in input:
|
|
||||||
paint_by_example_example_image, _ = load_img(
|
|
||||||
input["paintByExampleImage"].read()
|
|
||||||
)
|
|
||||||
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
|
||||||
else:
|
|
||||||
paint_by_example_example_image = None
|
|
||||||
|
|
||||||
config = Config(
|
|
||||||
ldm_steps=form["ldmSteps"],
|
|
||||||
ldm_sampler=form["ldmSampler"],
|
|
||||||
hd_strategy=form["hdStrategy"],
|
|
||||||
zits_wireframe=form["zitsWireframe"],
|
|
||||||
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
|
|
||||||
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
|
|
||||||
hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
|
|
||||||
prompt=form["prompt"],
|
|
||||||
negative_prompt=form["negativePrompt"],
|
|
||||||
use_croper=form["useCroper"],
|
|
||||||
croper_x=form["croperX"],
|
|
||||||
croper_y=form["croperY"],
|
|
||||||
croper_height=form["croperHeight"],
|
|
||||||
croper_width=form["croperWidth"],
|
|
||||||
use_extender=form["useExtender"],
|
|
||||||
extender_x=form["extenderX"],
|
|
||||||
extender_y=form["extenderY"],
|
|
||||||
extender_height=form["extenderHeight"],
|
|
||||||
extender_width=form["extenderWidth"],
|
|
||||||
sd_scale=form["sdScale"],
|
|
||||||
sd_mask_blur=form["sdMaskBlur"],
|
|
||||||
sd_strength=form["sdStrength"],
|
|
||||||
sd_steps=form["sdSteps"],
|
|
||||||
sd_guidance_scale=form["sdGuidanceScale"],
|
|
||||||
sd_sampler=form["sdSampler"],
|
|
||||||
sd_seed=form["sdSeed"],
|
|
||||||
sd_freeu=form["enableFreeu"],
|
|
||||||
sd_freeu_config=json.loads(form["freeuConfig"]),
|
|
||||||
sd_lcm_lora=form["enableLCMLora"],
|
|
||||||
sd_match_histograms=form["sdMatchHistograms"],
|
|
||||||
cv2_flag=form["cv2Flag"],
|
|
||||||
cv2_radius=form["cv2Radius"],
|
|
||||||
paint_by_example_example_image=paint_by_example_example_image,
|
|
||||||
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
|
|
||||||
enable_controlnet=form["enable_controlnet"],
|
|
||||||
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
|
|
||||||
controlnet_method=form["controlnet_method"],
|
|
||||||
powerpaint_task=form["powerpaintTask"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.sd_seed == -1:
|
|
||||||
config.sd_seed = random.randint(1, 99999999)
|
|
||||||
|
|
||||||
logger.info(f"Origin image shape: {original_shape}")
|
|
||||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
|
||||||
|
|
||||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
try:
|
|
||||||
res_np_img = global_config.model_manager(image, mask, config)
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "CUDA out of memory. " in str(e):
|
|
||||||
# NOTE: the string may change?
|
|
||||||
return "CUDA out of memory", 500
|
|
||||||
elif "Invalid buffer size" in str(e) and is_mac():
|
|
||||||
return "Out of memory", 500
|
|
||||||
else:
|
|
||||||
logger.exception(e)
|
|
||||||
return f"{str(e)}", 500
|
|
||||||
finally:
|
|
||||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
|
||||||
torch_gc()
|
|
||||||
|
|
||||||
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
|
||||||
if alpha_channel is not None:
|
|
||||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
|
||||||
alpha_channel = cv2.resize(
|
|
||||||
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
|
||||||
)
|
|
||||||
res_np_img = np.concatenate(
|
|
||||||
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
ext = get_image_ext(origin_image_bytes)
|
|
||||||
|
|
||||||
bytes_io = io.BytesIO(
|
|
||||||
pil_to_bytes(
|
|
||||||
Image.fromarray(res_np_img),
|
|
||||||
ext,
|
|
||||||
quality=global_config.image_quality,
|
|
||||||
infos=exif_infos,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = make_response(
|
|
||||||
send_file(
|
|
||||||
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
|
||||||
bytes_io,
|
|
||||||
mimetype=f"image/{ext}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
response.headers["X-Seed"] = str(config.sd_seed)
|
|
||||||
|
|
||||||
socketio.emit("diffusion_finish")
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/run_plugin", methods=["POST"])
|
|
||||||
def run_plugin():
|
|
||||||
form = request.form
|
|
||||||
files = request.files
|
|
||||||
name = form["name"]
|
|
||||||
if name not in global_config.plugins:
|
|
||||||
return "Plugin not found", 500
|
|
||||||
|
|
||||||
origin_image_bytes = files["image"].read() # RGB
|
|
||||||
rgb_np_img, alpha_channel, infos = load_img(
|
|
||||||
origin_image_bytes, return_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
try:
|
|
||||||
form = dict(form)
|
|
||||||
if name == InteractiveSeg.name:
|
|
||||||
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
|
|
||||||
form["img_md5"] = img_md5
|
|
||||||
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
|
|
||||||
except RuntimeError as e:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
if "CUDA out of memory. " in str(e):
|
|
||||||
# NOTE: the string may change?
|
|
||||||
return "CUDA out of memory", 500
|
|
||||||
else:
|
|
||||||
logger.exception(e)
|
|
||||||
return "Internal Server Error", 500
|
|
||||||
|
|
||||||
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
|
|
||||||
torch_gc()
|
|
||||||
|
|
||||||
if name == InteractiveSeg.name:
|
|
||||||
return make_response(
|
|
||||||
send_file(
|
|
||||||
io.BytesIO(numpy_to_bytes(bgr_res, "png")),
|
|
||||||
mimetype="image/png",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if name in [RemoveBG.name, AnimeSeg.name]:
|
|
||||||
rgb_res = bgr_res
|
|
||||||
ext = "png"
|
|
||||||
else:
|
|
||||||
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
|
|
||||||
ext = get_image_ext(origin_image_bytes)
|
|
||||||
if alpha_channel is not None:
|
|
||||||
if alpha_channel.shape[:2] != rgb_res.shape[:2]:
|
|
||||||
alpha_channel = cv2.resize(
|
|
||||||
alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0])
|
|
||||||
)
|
|
||||||
rgb_res = np.concatenate(
|
|
||||||
(rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
response = make_response(
|
|
||||||
send_file(
|
|
||||||
io.BytesIO(
|
|
||||||
pil_to_bytes(
|
|
||||||
Image.fromarray(rgb_res),
|
|
||||||
ext,
|
|
||||||
quality=global_config.image_quality,
|
|
||||||
infos=infos,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
mimetype=f"image/{ext}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/server_config", methods=["GET"])
|
|
||||||
def get_server_config():
|
|
||||||
return {
|
|
||||||
"plugins": list(global_config.plugins.keys()),
|
|
||||||
"enableFileManager": global_config.enable_file_manager,
|
|
||||||
"enableAutoSaving": global_config.enable_auto_saving,
|
|
||||||
"enableControlnet": global_config.model_manager.enable_controlnet,
|
|
||||||
"controlnetMethod": global_config.model_manager.controlnet_method,
|
|
||||||
"disableModelSwitch": global_config.disable_model_switch,
|
|
||||||
"isDesktop": global_config.is_desktop,
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/models", methods=["GET"])
|
|
||||||
def get_models():
|
|
||||||
return [it.model_dump() for it in global_config.model_manager.scan_models()]
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/model")
|
|
||||||
def current_model():
|
|
||||||
return (
|
|
||||||
global_config.model_manager.current_model,
|
|
||||||
200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/model", methods=["POST"])
|
|
||||||
def switch_model():
|
|
||||||
if global_config.disable_model_switch:
|
|
||||||
return "Switch model is disabled", 400
|
|
||||||
|
|
||||||
new_name = request.form.get("name")
|
|
||||||
if new_name == global_config.model_manager.name:
|
|
||||||
return "Same model", 200
|
|
||||||
|
|
||||||
try:
|
|
||||||
global_config.model_manager.switch(new_name)
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
error_message = f"{type(e).__name__} - {str(e)}"
|
|
||||||
logger.error(error_message)
|
|
||||||
return f"Switch model failed: {error_message}", 500
|
|
||||||
return f"ok, switch to {new_name}", 200
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/")
|
|
||||||
def index():
|
|
||||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/inputimage")
|
|
||||||
def get_cli_input_image():
|
|
||||||
if global_config.input_image_path:
|
|
||||||
with open(global_config.input_image_path, "rb") as f:
|
|
||||||
image_in_bytes = f.read()
|
|
||||||
return send_file(
|
|
||||||
global_config.input_image_path,
|
|
||||||
as_attachment=True,
|
|
||||||
download_name=Path(global_config.input_image_path).name,
|
|
||||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return "No Input Image"
|
|
||||||
|
|
||||||
|
|
||||||
def start(
|
def start(
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import InpaintRequest
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ def test_controlnet_switch_onoff(caplog):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model.switch_controlnet_method(
|
model.switch_controlnet_method(
|
||||||
Config(
|
InpaintRequest(
|
||||||
name=name,
|
name=name,
|
||||||
enable_controlnet=False,
|
enable_controlnet=False,
|
||||||
)
|
)
|
||||||
@ -63,7 +63,7 @@ def test_switch_controlnet_method(caplog):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model.switch_controlnet_method(
|
model.switch_controlnet_method(
|
||||||
Config(
|
InpaintRequest(
|
||||||
name=name,
|
name=name,
|
||||||
enable_controlnet=True,
|
enable_controlnet=True,
|
||||||
controlnet_method=new_method,
|
controlnet_method=new_method,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -18,9 +19,9 @@ def extra_info(img_p: Path):
|
|||||||
ext = img_p.suffix.strip(".")
|
ext = img_p.suffix.strip(".")
|
||||||
img_bytes = img_p.read_bytes()
|
img_bytes = img_p.read_bytes()
|
||||||
np_img, _, infos = load_img(img_bytes, False, True)
|
np_img, _, infos = load_img(img_bytes, False, True)
|
||||||
pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos)
|
res_pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos)
|
||||||
res_img = Image.open(io.BytesIO(pil_bytes))
|
res_img = Image.open(io.BytesIO(res_pil_bytes))
|
||||||
return infos, res_img.info
|
return infos, res_img.info, res_pil_bytes
|
||||||
|
|
||||||
|
|
||||||
def assert_keys(keys: List[str], infos, res_infos):
|
def assert_keys(keys: List[str], infos, res_infos):
|
||||||
@ -30,23 +31,29 @@ def assert_keys(keys: List[str], infos, res_infos):
|
|||||||
assert infos[k] == res_infos[k]
|
assert infos[k] == res_infos[k]
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(file_path, keys):
|
||||||
|
infos, res_infos, res_pil_bytes = extra_info(file_path)
|
||||||
|
assert_keys(keys, infos, res_infos)
|
||||||
|
with tempfile.NamedTemporaryFile("wb", suffix=file_path.suffix) as temp_file:
|
||||||
|
temp_file.write(res_pil_bytes)
|
||||||
|
temp_file.flush()
|
||||||
|
infos, res_infos, res_pil_bytes = extra_info(Path(temp_file.name))
|
||||||
|
assert_keys(keys, infos, res_infos)
|
||||||
|
|
||||||
|
|
||||||
def test_png_icc_profile_png():
|
def test_png_icc_profile_png():
|
||||||
infos, res_infos = extra_info(current_dir / "icc_profile_test.png")
|
run_test(current_dir / "icc_profile_test.png", ["icc_profile", "exif"])
|
||||||
assert_keys(["icc_profile", "exif"], infos, res_infos)
|
|
||||||
|
|
||||||
|
|
||||||
def test_png_icc_profile_jpeg():
|
def test_png_icc_profile_jpeg():
|
||||||
infos, res_infos = extra_info(current_dir / "icc_profile_test.jpg")
|
run_test(current_dir / "icc_profile_test.jpg", ["icc_profile", "exif"])
|
||||||
assert_keys(["icc_profile", "exif"], infos, res_infos)
|
|
||||||
|
|
||||||
|
|
||||||
def test_jpeg():
|
def test_jpeg():
|
||||||
jpg_img_p = current_dir / "bunny.jpeg"
|
jpg_img_p = current_dir / "bunny.jpeg"
|
||||||
infos, res_infos = extra_info(jpg_img_p)
|
run_test(jpg_img_p, ["dpi", "exif"])
|
||||||
assert_keys(["dpi", "exif"], infos, res_infos)
|
|
||||||
|
|
||||||
|
|
||||||
def test_png_parameter():
|
def test_png_parameter():
|
||||||
jpg_img_p = current_dir / "png_parameter_test.png"
|
jpg_img_p = current_dir / "png_parameter_test.png"
|
||||||
infos, res_infos = extra_info(jpg_img_p)
|
run_test(jpg_img_p, ["parameters"])
|
||||||
assert_keys(["parameters"], infos, res_infos)
|
|
||||||
|
@ -3,7 +3,7 @@ import cv2
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lama_cleaner.schema import LDMSampler, HDStrategy, Config, SDSampler
|
from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
save_dir = current_dir / "result"
|
save_dir = current_dir / "result"
|
||||||
@ -72,4 +72,4 @@ def get_config(**kwargs):
|
|||||||
hd_strategy_resize_limit=200,
|
hd_strategy_resize_limit=200,
|
||||||
)
|
)
|
||||||
data.update(**kwargs)
|
data.update(**kwargs)
|
||||||
return Config(**data)
|
return InpaintRequest(**data)
|
||||||
|
@ -43,7 +43,7 @@ def save_config(
|
|||||||
restoreformer_device,
|
restoreformer_device,
|
||||||
enable_gif,
|
enable_gif,
|
||||||
):
|
):
|
||||||
config = Config(**locals())
|
config = InpaintRequest(**locals())
|
||||||
print(config)
|
print(config)
|
||||||
if config.input and not os.path.exists(config.input):
|
if config.input and not os.path.exists(config.input):
|
||||||
return "[Error] Input file or directory does not exist"
|
return "[Error] Input file or directory does not exist"
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
pushd ./lama_cleaner/app
|
pushd ./web_app
|
||||||
yarn run build
|
npm run build
|
||||||
popd
|
popd
|
||||||
|
|
||||||
rm -r -f dist
|
rm -r -f dist
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
typer
|
typer
|
||||||
opencv-python
|
opencv-python
|
||||||
flask==2.2.3
|
fastapi==0.108.0
|
||||||
flask-socketio
|
python-multipart
|
||||||
simple-websocket
|
simple-websocket
|
||||||
flask_cors
|
flask_cors
|
||||||
flaskwebgui==0.3.5
|
flaskwebgui==0.3.5
|
||||||
|
@ -9,21 +9,28 @@ export default function useInputImage() {
|
|||||||
headers.append("pragma", "no-cache")
|
headers.append("pragma", "no-cache")
|
||||||
headers.append("cache-control", "no-cache")
|
headers.append("cache-control", "no-cache")
|
||||||
|
|
||||||
fetch(`${API_ENDPOINT}/inputimage`, { headers }).then(async (res) => {
|
fetch(`${API_ENDPOINT}/inputimage`, { headers })
|
||||||
const filename = res.headers
|
.then(async (res) => {
|
||||||
.get("content-disposition")
|
if (!res.ok) {
|
||||||
?.split("filename=")[1]
|
throw new Error("No input image found")
|
||||||
.split(";")[0]
|
}
|
||||||
|
const filename = res.headers
|
||||||
|
.get("content-disposition")
|
||||||
|
?.split("filename=")[1]
|
||||||
|
.split(";")[0]
|
||||||
|
|
||||||
const data = await res.blob()
|
const data = await res.blob()
|
||||||
if (data && data.type.startsWith("image")) {
|
if (data && data.type.startsWith("image")) {
|
||||||
const userInput = new File(
|
const userInput = new File(
|
||||||
[data],
|
[data],
|
||||||
filename !== undefined ? filename : "inputImage"
|
filename !== undefined ? filename : "inputImage"
|
||||||
)
|
)
|
||||||
setInputImage(userInput)
|
setInputImage(userInput)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err)
|
||||||
|
})
|
||||||
}, [setInputImage])
|
}, [setInputImage])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
|
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
|
||||||
import { Settings } from "@/lib/states"
|
import { Settings } from "@/lib/states"
|
||||||
import { srcToFile } from "@/lib/utils"
|
import { convertToBase64, srcToFile } from "@/lib/utils"
|
||||||
import axios from "axios"
|
import axios, { AxiosError } from "axios"
|
||||||
|
|
||||||
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
|
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
|
||||||
? import.meta.env.VITE_BACKEND
|
? import.meta.env.VITE_BACKEND
|
||||||
: ""
|
: "/api/v1"
|
||||||
|
|
||||||
const api = axios.create({
|
const api = axios.create({
|
||||||
baseURL: API_ENDPOINT,
|
baseURL: API_ENDPOINT,
|
||||||
@ -19,96 +19,75 @@ export default async function inpaint(
|
|||||||
mask: File | Blob,
|
mask: File | Blob,
|
||||||
paintByExampleImage: File | null = null
|
paintByExampleImage: File | null = null
|
||||||
) {
|
) {
|
||||||
const fd = new FormData()
|
const imageBase64 = await convertToBase64(imageFile)
|
||||||
fd.append("image", imageFile)
|
const maskBase64 = await convertToBase64(mask)
|
||||||
fd.append("mask", mask)
|
const exampleImageBase64 = paintByExampleImage
|
||||||
fd.append("ldmSteps", settings.ldmSteps.toString())
|
? await convertToBase64(paintByExampleImage)
|
||||||
fd.append("ldmSampler", settings.ldmSampler.toString())
|
: null
|
||||||
fd.append("zitsWireframe", settings.zitsWireframe.toString())
|
|
||||||
fd.append("hdStrategy", "Crop")
|
|
||||||
fd.append("hdStrategyCropMargin", "128")
|
|
||||||
fd.append("hdStrategyCropTrigerSize", "640")
|
|
||||||
fd.append("hdStrategyResizeLimit", "2048")
|
|
||||||
|
|
||||||
fd.append("prompt", settings.prompt)
|
|
||||||
fd.append("negativePrompt", settings.negativePrompt)
|
|
||||||
|
|
||||||
fd.append("useCroper", settings.showCropper ? "true" : "false")
|
|
||||||
fd.append("croperX", croperRect.x.toString())
|
|
||||||
fd.append("croperY", croperRect.y.toString())
|
|
||||||
fd.append("croperHeight", croperRect.height.toString())
|
|
||||||
fd.append("croperWidth", croperRect.width.toString())
|
|
||||||
|
|
||||||
fd.append("useExtender", settings.showExtender ? "true" : "false")
|
|
||||||
fd.append("extenderX", extenderState.x.toString())
|
|
||||||
fd.append("extenderY", extenderState.y.toString())
|
|
||||||
fd.append("extenderHeight", extenderState.height.toString())
|
|
||||||
fd.append("extenderWidth", extenderState.width.toString())
|
|
||||||
|
|
||||||
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
|
|
||||||
fd.append("sdStrength", settings.sdStrength.toString())
|
|
||||||
fd.append("sdSteps", settings.sdSteps.toString())
|
|
||||||
fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString())
|
|
||||||
fd.append("sdSampler", settings.sdSampler.toString())
|
|
||||||
|
|
||||||
if (settings.seedFixed) {
|
|
||||||
fd.append("sdSeed", settings.seed.toString())
|
|
||||||
} else {
|
|
||||||
fd.append("sdSeed", "-1")
|
|
||||||
}
|
|
||||||
|
|
||||||
fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false")
|
|
||||||
fd.append("sdScale", (settings.sdScale / 100).toString())
|
|
||||||
fd.append("enableFreeu", settings.enableFreeu.toString())
|
|
||||||
fd.append("freeuConfig", JSON.stringify(settings.freeuConfig))
|
|
||||||
fd.append("enableLCMLora", settings.enableLCMLora.toString())
|
|
||||||
|
|
||||||
fd.append("cv2Radius", settings.cv2Radius.toString())
|
|
||||||
fd.append("cv2Flag", settings.cv2Flag.toString())
|
|
||||||
|
|
||||||
// TODO: resize image's shortest_edge to 224 before pass to backend, save network time?
|
|
||||||
// https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
|
|
||||||
if (paintByExampleImage) {
|
|
||||||
fd.append("paintByExampleImage", paintByExampleImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InstructPix2Pix
|
|
||||||
fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString())
|
|
||||||
|
|
||||||
// ControlNet
|
|
||||||
fd.append("enable_controlnet", settings.enableControlnet.toString())
|
|
||||||
fd.append(
|
|
||||||
"controlnet_conditioning_scale",
|
|
||||||
settings.controlnetConditioningScale.toString()
|
|
||||||
)
|
|
||||||
fd.append("controlnet_method", settings.controlnetMethod?.toString())
|
|
||||||
|
|
||||||
// PowerPaint
|
|
||||||
if (settings.showExtender) {
|
|
||||||
fd.append("powerpaintTask", PowerPaintTask.outpainting)
|
|
||||||
} else {
|
|
||||||
fd.append("powerpaintTask", settings.powerpaintTask)
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: fd,
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
image: imageBase64,
|
||||||
|
mask: maskBase64,
|
||||||
|
ldm_steps: settings.ldmSteps,
|
||||||
|
ldm_sampler: settings.ldmSampler,
|
||||||
|
zits_wireframe: settings.zitsWireframe,
|
||||||
|
cv2_flag: settings.cv2Flag,
|
||||||
|
cv2_radius: settings.cv2Radius,
|
||||||
|
hd_strategy: "Crop",
|
||||||
|
hd_strategy_crop_triger_size: 640,
|
||||||
|
hd_strategy_crop_margin: 128,
|
||||||
|
hd_trategy_resize_imit: 2048,
|
||||||
|
prompt: settings.prompt,
|
||||||
|
negative_prompt: settings.negativePrompt,
|
||||||
|
use_croper: settings.showCropper,
|
||||||
|
croper_x: croperRect.x,
|
||||||
|
croper_y: croperRect.y,
|
||||||
|
croper_height: croperRect.height,
|
||||||
|
croper_width: croperRect.width,
|
||||||
|
use_extender: settings.showExtender,
|
||||||
|
extender_x: extenderState.x,
|
||||||
|
extender_y: extenderState.y,
|
||||||
|
extender_height: extenderState.height,
|
||||||
|
extender_width: extenderState.width,
|
||||||
|
sd_mask_blur: settings.sdMaskBlur,
|
||||||
|
sd_strength: settings.sdStrength,
|
||||||
|
sd_steps: settings.sdSteps,
|
||||||
|
sd_guidance_scale: settings.sdGuidanceScale,
|
||||||
|
sd_sampler: settings.sdSampler,
|
||||||
|
sd_seed: settings.seedFixed ? settings.seed : -1,
|
||||||
|
sd_match_histograms: settings.sdMatchHistograms,
|
||||||
|
sd_freeu: settings.enableFreeu,
|
||||||
|
sd_freeu_config: settings.freeuConfig,
|
||||||
|
sd_lcm_lora: settings.enableLCMLora,
|
||||||
|
paint_by_example_example_image: exampleImageBase64,
|
||||||
|
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
|
||||||
|
enable_controlnet: settings.enableControlnet,
|
||||||
|
controlnet_conditioning_scale: settings.controlnetConditioningScale,
|
||||||
|
controlnet_method: settings.controlnetMethod
|
||||||
|
? settings.controlnetMethod
|
||||||
|
: "",
|
||||||
|
powerpaint_task: settings.showExtender
|
||||||
|
? PowerPaintTask.outpainting
|
||||||
|
: settings.powerpaintTask,
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
if (res.ok) {
|
const blob = await res.blob()
|
||||||
const blob = await res.blob()
|
return {
|
||||||
const newSeed = res.headers.get("X-seed")
|
blob: URL.createObjectURL(blob),
|
||||||
return { blob: URL.createObjectURL(blob), seed: newSeed }
|
seed: res.headers.get("X-Seed"),
|
||||||
}
|
}
|
||||||
const errMsg = await res.text()
|
} catch (error: any) {
|
||||||
throw new Error(errMsg)
|
throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`)
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Something went wrong: ${error}`)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getServerConfig() {
|
export function getServerConfig() {
|
||||||
return fetch(`${API_ENDPOINT}/server_config`, {
|
return fetch(`${API_ENDPOINT}/server-config`, {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -491,10 +491,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
paintByExampleFile
|
paintByExampleFile
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!res) {
|
|
||||||
throw new Error("Something went wrong on server side.")
|
|
||||||
}
|
|
||||||
|
|
||||||
const { blob, seed } = res
|
const { blob, seed } = res
|
||||||
if (seed) {
|
if (seed) {
|
||||||
get().setSeed(parseInt(seed, 10))
|
get().setSeed(parseInt(seed, 10))
|
||||||
|
@ -223,3 +223,17 @@ export const generateMask = (
|
|||||||
|
|
||||||
return maskCanvas
|
return maskCanvas
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const convertToBase64 = (fileOrBlob: File | Blob): Promise<string> => {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const reader = new FileReader()
|
||||||
|
reader.onload = (event) => {
|
||||||
|
const base64String = event.target?.result as string
|
||||||
|
resolve(base64String)
|
||||||
|
}
|
||||||
|
reader.onerror = (error) => {
|
||||||
|
reject(error)
|
||||||
|
}
|
||||||
|
reader.readAsDataURL(fileOrBlob)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user