This commit is contained in:
Qing 2023-12-30 23:36:44 +08:00
parent 85c3397b97
commit c4abda3942
35 changed files with 969 additions and 854 deletions

lama_cleaner/api.py Normal file
View File

@ -0,0 +1,323 @@
import os
import threading
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, List
import cv2
import torch
import numpy as np
from loguru import logger
from PIL import Image
import uvicorn
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.staticfiles import StaticFiles
from lama_cleaner.helper import (
load_img,
decode_base64_to_image,
pil_to_bytes,
numpy_to_bytes,
concat_alpha_channel,
)
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
from lama_cleaner.schema import (
GenInfoResponse,
ApiConfig,
ServerConfigResponse,
SwitchModelRequest,
InpaintRequest,
RunPluginRequest,
)
from lama_cleaner.file_manager import FileManager
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
WEB_APP_DIR = CURRENT_DIR / "web_app"
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio  # imported only so it can be added to the suppress list below
import starlette  # imported only so it can be added to the suppress list below
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(
e, HTTPException
):  # do not print a backtrace for known HTTPExceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
traceback.print_exc()
return JSONResponse(
status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_origins": ["*"],
"allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **cors_options)
class Api:
def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app
self.config = config
self.router = APIRouter()
self.queue_lock = threading.Lock()
api_middleware(self.app)
self.file_manager = self._build_file_manager()
self.plugins = self._build_plugins()
self.model_manager = self._build_model_manager()
# fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
def api_models(self) -> List[ModelInfo]:
return self.model_manager.scan_models()
def api_current_model(self) -> ModelInfo:
return self.model_manager.current_model
def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
if req.name == self.model_manager.name:
return self.model_manager.current_model
self.model_manager.switch(req.name)
return self.model_manager.current_model
def api_server_config(self) -> ServerConfigResponse:
return ServerConfigResponse(
plugins=list(self.plugins.keys()),
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,
controlnetMethod=self.model_manager.controlnet_method,
disableModelSwitch=self.config.disable_model_switch,
isDesktop=self.config.gui,
)
def api_input_image(self) -> FileResponse:
if self.config.input and self.config.input.is_file():
return FileResponse(self.config.input)
raise HTTPException(status_code=404, detail="Input image not found")
def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
_, _, info = load_img(file.file.read(), return_info=True)
parts = info.get("parameters", "").split("Negative prompt: ")
prompt = parts[0].strip()
negative_prompt = ""
if len(parts) > 1:
negative_prompt = parts[1].split("\n")[0].strip()
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
def api_inpaint(self, req: InpaintRequest):
image, alpha_channel, infos = decode_base64_to_image(req.image)
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
raise HTTPException(
400,
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
)
if req.paint_by_example_example_image:
paint_by_example_image, _, _ = decode_base64_to_image(
req.paint_by_example_example_image
)
start = time.time()
bgr_np_img = self.model_manager(image, mask, req)  # model output is BGR
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
torch_gc()
rgb_np_img = cv2.cvtColor(bgr_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
ext = "png"
res_img_bytes = pil_to_bytes(
Image.fromarray(rgb_res),
ext=ext,
quality=self.config.quality,
infos=infos,
)
return Response(
content=res_img_bytes,
media_type=f"image/{ext}",
headers={"X-Seed": str(req.sd_seed)},
)
def api_run_plugin(self, req: RunPluginRequest):
if req.name not in self.plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
image, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_res = self.plugins[req.name].run(image, req)
torch_gc()
if req.name == InteractiveSeg.name:
return Response(
content=numpy_to_bytes(bgr_res, "png"),
media_type="image/png",
)
ext = "png"
if req.name in [RemoveBG.name, AnimeSeg.name]:
rgb_res = bgr_res
else:
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
rgb_res = concat_alpha_channel(rgb_res, alpha_channel)
return Response(
content=pil_to_bytes(
Image.fromarray(rgb_res),
ext=ext,
quality=self.config.quality,
infos=infos,
),
media_type=f"image/{ext}",
)
def launch(self):
self.app.include_router(self.router)
uvicorn.run(
self.app,
host=self.config.host,
port=self.config.port,
timeout_keep_alive=60,
)
def _build_file_manager(self) -> Optional[FileManager]:
if self.config.input and self.config.input.is_dir():
logger.info(
f"Input is directory, initialize file manager {self.config.input}"
)
return FileManager(
app=self.app,
input_dir=self.config.input,
output_dir=self.config.output_dir,
)
return None
def _build_plugins(self) -> Dict:
return build_plugins(
self.config.enable_interactive_seg,
self.config.interactive_seg_model,
self.config.interactive_seg_device,
self.config.enable_remove_bg,
self.config.enable_anime_seg,
self.config.enable_realesrgan,
self.config.realesrgan_device,
self.config.realesrgan_model,
self.config.enable_gfpgan,
self.config.gfpgan_device,
self.config.enable_restoreformer,
self.config.restoreformer_device,
self.config.no_half,
)
def _build_model_manager(self):
return ModelManager(
name=self.config.model,
device=torch.device(self.config.device),
no_half=self.config.no_half,
disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder,
cpu_offload=self.config.cpu_offload,
)
if __name__ == "__main__":
from lama_cleaner.schema import InteractiveSegModel, RealESRGANModel
app = FastAPI()
api = Api(
app,
ApiConfig(
host="127.0.0.1",
port=8080,
model="lama",
no_half=False,
cpu_offload=False,
disable_nsfw_checker=False,
cpu_textencoder=False,
device="cpu",
gui=False,
disable_model_switch=False,
input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
quality=100,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device="cpu",
enable_remove_bg=False,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device="cpu",
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device="cpu",
enable_restoreformer=False,
restoreformer_device="cpu",
),
)
api.launch()
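As a quick sanity check of the new HTTP surface, here is a minimal client sketch (not part of the commit): it assumes the server above is running on 127.0.0.1:8080 and that image.png and mask.png exist locally. The JSON field names come from InpaintRequest, and the seed is echoed back in the X-Seed header, as api_inpaint shows.

import base64
import requests  # client-side dependency, assumed installed

def to_b64(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:8080/api/v1/inpaint",
    json={"image": to_b64("image.png"), "mask": to_b64("mask.png")},
)
resp.raise_for_status()
with open("result.png", "wb") as f:
    f.write(resp.content)  # PNG bytes produced by api_inpaint
print("seed:", resp.headers.get("X-Seed"))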

View File

@ -10,7 +10,7 @@ import psutil
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, SDSampler
from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler
try:
torch._C._jit_override_can_fuse_on_cpu(False)
@ -36,7 +36,7 @@ def run_model(model, size):
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
mask = np.random.randint(0, 255, size).astype(np.uint8)
config = Config(
config = InpaintRequest(
ldm_steps=2,
hd_strategy=HDStrategy.ORIGINAL,
hd_strategy_crop_margin=128,
@ -44,7 +44,7 @@ def run_model(model, size):
hd_strategy_resize_limit=128,
prompt="a fox is sitting on a bench",
sd_steps=5,
sd_sampler=SDSampler.ddim
sd_sampler=SDSampler.ddim,
)
model(image, mask, config)
@ -75,7 +75,9 @@ def benchmark(model, times: int, empty_cache: bool):
# cpu_metrics.append(process.cpu_percent())
time_metrics.append((time.time() - start) * 1000)
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
gpu_memory_metrics.append(
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
)
print(f"size: {size}".center(80, "-"))
# print(f"cpu: {format(cpu_metrics)}")

View File

@ -1,6 +1,7 @@
from pathlib import Path
import typer
from fastapi import FastAPI
from loguru import logger
from typer import Option
@ -38,6 +39,17 @@ def list_model(
print(it.name)
@typer_app.command(help="Processing image with lama cleaner")
def run(
input: Path = Option(..., help="Image file or folder containing images"),
output_dir: Path = Option(..., help="Output directory"),
config_path: Path = Option(..., help="Config file path"),
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
setup_model_dir(model_dir)
pass
@typer_app.command(help="Start lama cleaner server")
def start(
host: str = Option("127.0.0.1"),
@ -80,6 +92,16 @@ def start(
):
dump_environment_info()
device = check_device(device)
if input and not input.exists():
logger.error(f"invalid --input: {input} not exists")
exit()
if output_dir:
output_dir = output_dir.expanduser().absolute()
logger.info(f"Image will be saved to {output_dir}")
if not output_dir.exists():
logger.info(f"Create output directory {output_dir}")
output_dir.mkdir(parents=True)
model_dir = model_dir.expanduser().absolute()
setup_model_dir(model_dir)
@ -92,9 +114,13 @@ def start(
logger.info(f"{model} not found in {model_dir}, try to downloading")
cli_download_model(model, model_dir)
from lama_cleaner.server import start
from lama_cleaner.api import Api
from lama_cleaner.schema import ApiConfig
start(
app = FastAPI()
api = Api(
app,
ApiConfig(
host=host,
port=port,
model=model,
@ -120,4 +146,6 @@ def start(
gfpgan_device=gfpgan_device,
enable_restoreformer=enable_restoreformer,
restoreformer_device=restoreformer_device,
),
)
api.launch()

View File

@ -1,18 +1,19 @@
# Copied from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
import os
from datetime import datetime
import cv2
import time
from io import BytesIO
from pathlib import Path
import numpy as np
# from watchdog.events import FileSystemEventHandler
# from watchdog.observers import Observer
from typing import List
from PIL import Image, ImageOps, PngImagePlugin
from loguru import logger
from fastapi import FastAPI, UploadFile, HTTPException
from starlette.responses import FileResponse
from ..schema import (
MediasResponse,
MediasRequest,
MediaFileRequest,
MediaTab,
MediaThumbnailFileRequest,
)
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
@ -21,132 +22,87 @@ from .utils import aspect_to_string, generate_filename, glob_img
class FileManager:
def __init__(self, app=None):
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
self.app = app
self._default_root_directory = "media"
self._default_thumbnail_directory = "media"
self._default_root_url = "/"
self._default_thumbnail_root_url = "/"
self._default_format = "JPEG"
self.output_dir: Path = None
if app is not None:
self.init_app(app)
self.input_dir: Path = input_dir
self.output_dir: Path = output_dir
self.image_dir_filenames = []
self.output_dir_filenames = []
if not self.thumbnail_directory.exists():
self.thumbnail_directory.mkdir(parents=True)
self.image_dir_observer = None
self.output_dir_observer = None
# fmt: off
self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse])
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None)
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None)
# fmt: on
self.modified_time = {
"image": datetime.utcnow(),
"output": datetime.utcnow(),
}
def api_save_image(self, file: UploadFile):
filename = file.filename
origin_image_bytes = file.file.read()
with open(self.output_dir / filename, "wb") as fw:
fw.write(origin_image_bytes)
# def start(self):
# self.image_dir_filenames = self._media_names(self.root_directory)
# self.output_dir_filenames = self._media_names(self.output_dir)
#
# logger.info(f"Start watching image directory: {self.root_directory}")
# self.image_dir_observer = Observer()
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
# self.image_dir_observer.start()
#
# logger.info(f"Start watching output directory: {self.output_dir}")
# self.output_dir_observer = Observer()
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
# self.output_dir_observer.start()
def api_medias(self, req: MediasRequest) -> List[MediasResponse]:
img_dir = self._get_dir(req.tab)
return self._media_names(img_dir)
def on_modified(self, event):
if not os.path.isdir(event.src_path):
return
if event.src_path == str(self.root_directory):
logger.info(f"Image directory {event.src_path} modified")
self.image_dir_filenames = self._media_names(self.root_directory)
self.modified_time["image"] = datetime.utcnow()
elif event.src_path == str(self.output_dir):
logger.info(f"Output directory {event.src_path} modified")
self.output_dir_filenames = self._media_names(self.output_dir)
self.modified_time["output"] = datetime.utcnow()
def api_media_file(self, req: MediaFileRequest) -> FileResponse:
file_path = self._get_file(req.tab, req.filename)
return FileResponse(file_path)
def init_app(self, app):
if self.app is None:
self.app = app
app.thumbnail_instance = self
if not hasattr(app, "extensions"):
app.extensions = {}
if "thumbnail" in app.extensions:
raise RuntimeError("Flask-thumbnail extension already initialized")
app.extensions["thumbnail"] = self
app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
app.config.setdefault(
"THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory
def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse:
img_dir = self._get_dir(req.tab)
thumb_filename, (width, height) = self.get_thumbnail(
img_dir, req.filename, width=req.width, height=req.height
)
app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
app.config.setdefault(
"THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url
thumbnail_filepath = self.thumbnail_directory / thumb_filename
return FileResponse(
thumbnail_filepath,
headers={
"X-Width": str(width),
"X-Height": str(height),
},
)
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
@property
def root_directory(self):
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
if os.path.isabs(path):
return path
def _get_dir(self, tab: MediaTab) -> Path:
if tab == "input":
return self.input_dir
elif tab == "output":
return self.output_dir
else:
return os.path.join(self.app.root_path, path)
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
def _get_file(self, tab: MediaTab, filename: str) -> Path:
file_path = self._get_dir(tab) / filename
if not file_path.exists():
raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
return file_path
@property
def thumbnail_directory(self):
path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]
if os.path.isabs(path):
return path
else:
return os.path.join(self.app.root_path, path)
@property
def root_url(self):
return self.app.config["THUMBNAIL_MEDIA_URL"]
@property
def media_names(self):
# return self.image_dir_filenames
return self._media_names(self.root_directory)
@property
def output_media_names(self):
return self._media_names(self.output_dir)
# return self.output_dir_filenames
def thumbnail_directory(self) -> Path:
return self.output_dir / "thumbnails"
@staticmethod
def _media_names(directory: Path):
def _media_names(directory: Path) -> List[MediasResponse]:
names = sorted([it.name for it in glob_img(directory)])
res = []
for name in names:
path = os.path.join(directory, name)
img = Image.open(path)
res.append(
{
"name": name,
"height": img.height,
"width": img.width,
"ctime": os.path.getctime(path),
"mtime": os.path.getmtime(path),
}
MediasResponse(
name=name,
height=img.height,
width=img.width,
ctime=os.path.getctime(path),
mtime=os.path.getmtime(path),
)
)
return res
@property
def thumbnail_url(self):
return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]
def get_thumbnail(
self, directory: Path, original_filename: str, width, height, **options
):
@ -161,7 +117,10 @@ class FileManager:
image = Image.open(BytesIO(storage.read(original_filepath)))
# keep ratio resize
if width is not None:
if not width and not height:
width = 256
if width != 0:
height = int(image.height * width / image.width)
else:
width = int(image.width * height / image.height)
@ -180,18 +139,15 @@ class FileManager:
thumbnail_filepath = os.path.join(
self.thumbnail_directory, original_path, thumbnail_filename
)
thumbnail_url = os.path.join(
self.thumbnail_url, original_path, thumbnail_filename
)
if storage.exists(thumbnail_filepath):
return thumbnail_url, (width, height)
return thumbnail_filepath, (width, height)
try:
image.load()
except (IOError, OSError):
logger.warning(f"Failed to load image for thumbnail: {original_filepath}")
return thumbnail_url, (width, height)
return thumbnail_filepath, (width, height)
# get original image format
options["format"] = options.get("format", image.format)
@ -203,7 +159,7 @@ class FileManager:
raw_data = self.get_raw_data(image, **options)
storage.save(thumbnail_filepath, raw_data)
return thumbnail_url, (width, height)
return thumbnail_filepath, (width, height)
def get_raw_data(self, image, **options):
data = {
@ -246,7 +202,7 @@ class FileManager:
if image.format:
return image.format
return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]
return "JPEG"
def _create_thumbnail(self, image, size, crop="fit", background=None):
try:
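To exercise the reworked FileManager routes end to end, a hedged client sketch (not part of the commit); it assumes the server was started with --input pointing at a non-empty directory of images so the FileManager is active. The field names follow MediasRequest and MediaThumbnailFileRequest; width=0 or height=0 means the missing side is derived from the image's aspect ratio, per get_thumbnail above.

import requests  # client-side dependency, assumed installed

# List media in the input directory, then fetch a 256px-wide thumbnail.
medias = requests.post(
    "http://127.0.0.1:8080/api/v1/medias", json={"tab": "input"}
).json()
resp = requests.post(
    "http://127.0.0.1:8080/api/v1/media_thumbnail_file",
    json={"tab": "input", "filename": medias[0]["name"], "width": 256, "height": 0},
)
print(resp.headers["X-Width"], resp.headers["X-Height"])  # actual thumbnail size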

View File

@ -1,7 +1,9 @@
import base64
import imghdr
import io
import os
import sys
from typing import List, Optional
from typing import List, Optional, Dict, Tuple
from urllib.parse import urlparse
import cv2
@ -138,8 +140,8 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos: Optional[Dict] = None) -> bytes:
with io.BytesIO() as output:
kwargs = {k: v for k, v in (infos or {}).items() if v is not None}
if ext == 'jpg':
ext = 'jpeg'
if ext == "jpg":
ext = "jpeg"
if "png" == ext.lower() and "parameters" in kwargs:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text("parameters", kwargs["parameters"])
@ -290,3 +292,61 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
def is_mac():
return sys.platform == "darwin"
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
def decode_base64_to_image(
encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict]:
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
alpha_channel = None
infos = image.info
try:
image = ImageOps.exif_transpose(image)
except Exception:
pass
if gray:
image = image.convert("L")
np_img = np.array(image)
else:
if image.mode == "RGBA":
np_img = np.array(image)
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else:
image = image.convert("RGB")
np_img = np.array(image)
return np_img, alpha_channel, infos
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
img_bytes = pil_to_bytes(
image,
"png",
quality=quality,
infos=infos,
)
return base64.b64encode(img_bytes)
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
if alpha_channel is not None:
if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
)
rgb_np_img = np.concatenate(
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
return rgb_np_img
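The new helpers pair up: decode_base64_to_image splits an upload into an RGB array plus its optional alpha channel, and concat_alpha_channel re-attaches the alpha after processing. A small round-trip sketch, assuming a local input.png:

import base64
from PIL import Image
from lama_cleaner.helper import decode_base64_to_image, concat_alpha_channel

with open("input.png", "rb") as f:
    encoding = "data:image/png;base64," + base64.b64encode(f.read()).decode()

np_img, alpha_channel, infos = decode_base64_to_image(encoding)
restored = concat_alpha_channel(np_img, alpha_channel)  # RGBA again if alpha existed
Image.fromarray(restored).save("roundtrip.png")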

View File

@ -14,7 +14,7 @@ from lama_cleaner.helper import (
)
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler
from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler
class InpaintModel:
@ -44,7 +44,7 @@ class InpaintModel:
return False
@abc.abstractmethod
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W, 1] 255 means the masked (inpaint) area
@ -56,7 +56,7 @@ class InpaintModel:
def download():
...
def _pad_forward(self, image, mask, config: Config):
def _pad_forward(self, image, mask, config: InpaintRequest):
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
@ -74,7 +74,7 @@ class InpaintModel:
result, image, mask = self.forward_post_process(result, image, mask, config)
if config.sd_prevent_unmasked_area:
if config.sd_keep_unmasked_area:
mask = mask[:, :, np.newaxis]
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
return result
@ -86,7 +86,7 @@ class InpaintModel:
return result, image, mask
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@ -141,7 +141,7 @@ class InpaintModel:
return inpaint_result
def _crop_box(self, image, mask, box, config: Config):
def _crop_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
@ -233,7 +233,7 @@ class InpaintModel:
return result
def _apply_cropper(self, image, mask, config: Config):
def _apply_cropper(self, image, mask, config: InpaintRequest):
img_h, img_w = image.shape[:2]
l, t, w, h = (
config.croper_x,
@ -253,7 +253,7 @@ class InpaintModel:
crop_mask = mask[t:b, l:r]
return crop_img, crop_mask, (l, t, r, b)
def _run_box(self, image, mask, box, config: Config):
def _run_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
@ -276,7 +276,7 @@ class DiffusionInpaintModel(InpaintModel):
super().__init__(device, **kwargs)
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@ -295,7 +295,7 @@ class DiffusionInpaintModel(InpaintModel):
return inpaint_result
def _do_outpainting(self, image, config: Config):
def _do_outpainting(self, image, config: InpaintRequest):
# The cropper and the image share the same coordinate system; croper_x/y may be negative
# Crop the outpainting region out of the image
image_h, image_w = image.shape[:2]
@ -368,7 +368,7 @@ class DiffusionInpaintModel(InpaintModel):
] = expanded_cropped_result_image
return outpainting_image
def _scaled_pad_forward(self, image, mask, config: Config):
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
@ -396,7 +396,7 @@ class DiffusionInpaintModel(InpaintModel):
# ]
return inpaint_result
def set_scheduler(self, config: Config):
def set_scheduler(self, config: InpaintRequest):
scheduler_config = self.model.scheduler.config
sd_sampler = config.sd_sampler
if config.sd_lcm_lora:
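To make the InpaintModel contract concrete, an illustrative subclass (not part of the commit): per the docstrings above, forward receives an RGB image with a 255-valued mask and returns a BGR result. The hook names (init_model, is_downloaded, download) are assumed from the model implementations elsewhere in this diff.

from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import InpaintRequest

class GrayFill(InpaintModel):
    """Toy model that paints the masked area mid-gray."""

    name = "grayfill"  # hypothetical name
    pad_mod = 8

    def init_model(self, device, **kwargs):
        pass  # nothing to load

    @staticmethod
    def is_downloaded() -> bool:
        return True

    @staticmethod
    def download():
        pass

    def forward(self, image, mask, config: InpaintRequest):
        result = image.copy()
        result[mask[:, :, 0] == 255] = 127  # fill masked pixels with gray
        return result[:, :, ::-1]  # RGB -> BGR, as the base class expects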

View File

@ -14,7 +14,7 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
)
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
from lama_cleaner.schema import Config, ModelType
from lama_cleaner.schema import InpaintRequest, ModelType
class ControlNet(DiffusionInpaintModel):
@ -130,7 +130,7 @@ class ControlNet(DiffusionInpaintModel):
raise NotImplementedError(f"{self.controlnet_method} not implemented")
return control_image
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -6,7 +6,7 @@ import torch
import numpy as np
import torch.fft as fft
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
from lama_cleaner.helper import (
load_model,
@ -1665,7 +1665,7 @@ class FcF(InpaintModel):
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@ -1705,7 +1705,7 @@ class FcF(InpaintModel):
return inpaint_result
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] mask area == 255

View File

@ -4,7 +4,7 @@ import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
class InstructPix2Pix(DiffusionInpaintModel):
@ -40,7 +40,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -5,7 +5,7 @@ import torch
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
class Kandinsky(DiffusionInpaintModel):
@ -29,7 +29,7 @@ class Kandinsky(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -11,7 +11,7 @@ from lama_cleaner.helper import (
download_model,
)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
@ -36,7 +36,7 @@ class LaMa(InpaintModel):
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]

View File

@ -7,7 +7,7 @@ from loguru import logger
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.ddim_sampler import DDIMSampler
from lama_cleaner.model.plms_sampler import PLMSSampler
from lama_cleaner.schema import Config, LDMSampler
from lama_cleaner.schema import InpaintRequest, LDMSampler
torch.manual_seed(42)
import torch.nn as nn
@ -277,7 +277,7 @@ class LDM(InpaintModel):
return all([os.path.exists(it) for it in model_paths])
@torch.cuda.amp.autocast()
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""
image: [H, W, C] RGB
mask: [H, W, 1]

View File

@ -9,7 +9,7 @@ from loguru import logger
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
@ -56,7 +56,7 @@ class Manga(InpaintModel):
]
return all([os.path.exists(it) for it in model_paths])
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""
image: [H, W, C] RGB
mask: [H, W, 1]

View File

@ -28,7 +28,7 @@ from lama_cleaner.model.utils import (
normalize_2nd_moment,
set_seed,
)
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
class ModulatedConv2d(nn.Module):
@ -1912,7 +1912,7 @@ class MAT(InpaintModel):
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] mask area == 255

View File

@ -3,7 +3,6 @@ import os
import cv2
import torch
from lama_cleaner.const import Config
from lama_cleaner.helper import (
load_jit_model,
download_model,
@ -13,6 +12,7 @@ from lama_cleaner.helper import (
norm_img,
)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import InpaintRequest
MIGAN_MODEL_URL = os.environ.get(
"MIGAN_MODEL_URL",
@ -40,7 +40,7 @@ class MIGAN(InpaintModel):
return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@ -80,7 +80,7 @@ class MIGAN(InpaintModel):
return inpaint_result
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] mask area == 255

View File

@ -1,6 +1,6 @@
import cv2
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
@ -14,7 +14,7 @@ class OpenCV2(InpaintModel):
def is_downloaded() -> bool:
return True
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1]

View File

@ -4,8 +4,9 @@ import cv2
import torch
from loguru import logger
from lama_cleaner.helper import decode_base64_to_image
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
class PaintByExample(DiffusionInpaintModel):
@ -38,16 +39,21 @@ class PaintByExample(DiffusionInpaintModel):
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
if config.paint_by_example_example_image is None:
raise ValueError("paint_by_example_example_image is required")
example_image, _, _ = decode_base64_to_image(
config.paint_by_example_example_image
)
output = self.model(
image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image,
example_image=PIL.Image.fromarray(example_image),
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",

View File

@ -7,7 +7,7 @@ from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
from .powerpaint_tokenizer import add_task_to_prompt
@ -58,7 +58,7 @@ class PowerPaint(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -6,7 +6,7 @@ from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config, ModelType
from lama_cleaner.schema import InpaintRequest, ModelType
class SD(DiffusionInpaintModel):
@ -64,7 +64,7 @@ class SD(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -9,7 +9,7 @@ from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.schema import Config, ModelType
from lama_cleaner.schema import InpaintRequest, ModelType
class SDXL(DiffusionInpaintModel):
@ -60,7 +60,7 @@ class SDXL(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint

View File

@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
import numpy as np
from lama_cleaner.model.base import InpaintModel
@ -343,7 +343,7 @@ class ZITS(InpaintModel):
items["line"] = line_pred.detach()
@torch.no_grad()
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W]

View File

@ -8,7 +8,7 @@ from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model import models, ControlNet, SD, SDXL
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo, ModelType
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
class ModelManager:
@ -31,13 +31,15 @@ class ModelManager:
self.model = self.init_model(name, device, **kwargs)
@property
def current_model(self) -> Dict:
return self.available_models[self.name].model_dump()
def current_model(self) -> ModelInfo:
return self.available_models[self.name]
def init_model(self, name: str, device, **kwargs):
logger.info(f"Loading model: {name}")
if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}. Available models: {self.available_models.keys()}")
raise NotImplementedError(
f"Unsupported model: {name}. Available models: {self.available_models.keys()}"
)
model_info = self.available_models[name]
kwargs = {
@ -66,7 +68,17 @@ class ModelManager:
raise NotImplementedError(f"Unsupported model: {name}")
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
config:
Returns:
"""
self.switch_controlnet_method(config)
self.enable_disable_freeu(config)
self.enable_disable_lcm_lora(config)
@ -135,7 +147,7 @@ class ModelManager:
else:
logger.info(f"Enable controlnet: {config.controlnet_method}")
def enable_disable_freeu(self, config: Config):
def enable_disable_freeu(self, config: InpaintRequest):
if str(self.model.device) == "mps":
return
@ -151,7 +163,7 @@ class ModelManager:
else:
self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config):
def enable_disable_lcm_lora(self, config: InpaintRequest):
if self.available_models[self.name].support_lcm_lora:
if config.sd_lcm_lora:
if not self.model.model.get_list_adapters():
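For callers outside the HTTP layer, a sketch of driving ModelManager directly, mirroring what Api.api_inpaint does (file names are illustrative; the extra constructor kwargs such as no_half are assumed optional):

import cv2
import numpy as np
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import InpaintRequest

manager = ModelManager(name="lama", device=torch.device("cpu"))
image = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)  # models expect RGB
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
req = InpaintRequest(image=None, mask=None)  # base64 fields unused on this path
bgr_result = manager(image, mask, req)  # output is BGR
cv2.imwrite("out.png", bgr_result.astype(np.uint8))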

View File

@ -10,7 +10,6 @@ from ..const import InteractiveSegModel, Device, RealESRGANModel
def build_plugins(
global_config,
enable_interactive_seg: bool,
interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device,
@ -25,25 +24,26 @@ def build_plugins(
restoreformer_device: Device,
no_half: bool,
):
plugins = {}
if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
plugins[InteractiveSeg.name] = InteractiveSeg(
interactive_seg_model, interactive_seg_device
)
if enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
global_config.plugins[RemoveBG.name] = RemoveBG()
plugins[RemoveBG.name] = RemoveBG()
if enable_anime_seg:
logger.info(f"Initialize {AnimeSeg.name} plugin")
global_config.plugins[AnimeSeg.name] = AnimeSeg()
plugins[AnimeSeg.name] = AnimeSeg()
if enable_realesrgan:
logger.info(
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
)
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
realesrgan_model,
realesrgan_device,
no_half=no_half,
@ -57,14 +57,15 @@ def build_plugins(
logger.info(
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
)
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
plugins[GFPGANPlugin.name] = GFPGANPlugin(
gfpgan_device,
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
upscaler=plugins.get(RealESRGANUpscaler.name, None),
)
if enable_restoreformer:
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
restoreformer_device,
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
upscaler=plugins.get(RealESRGANUpscaler.name, None),
)
return plugins
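Since build_plugins now returns the mapping instead of mutating a global config, callers own the dict. A construction sketch with arguments passed positionally in the same order as the _build_plugins call in api.py above:

from lama_cleaner.const import InteractiveSegModel, RealESRGANModel
from lama_cleaner.plugins import build_plugins

plugins = build_plugins(
    True,  # enable_interactive_seg
    InteractiveSegModel.vit_b,  # interactive_seg_model
    "cpu",  # interactive_seg_device
    False,  # enable_remove_bg
    False,  # enable_anime_seg
    False,  # enable_realesrgan
    "cpu",  # realesrgan_device
    RealESRGANModel.realesr_general_x4v3,  # realesrgan_model
    False,  # enable_gfpgan
    "cpu",  # gfpgan_device
    False,  # enable_restoreformer
    "cpu",  # restoreformer_device
    False,  # no_half
)
print(list(plugins.keys()))  # e.g. [InteractiveSeg.name] when enabled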

View File

@ -1,8 +1,17 @@
import random
from enum import Enum
from typing import Optional
from pathlib import Path
from typing import Optional, Literal, List
from PIL.Image import Image
from pydantic import BaseModel
from pydantic import BaseModel, Field, validator, field_validator
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
class CV2Flag(str, Enum):
INPAINT_NS = "INPAINT_NS"
INPAINT_TELEA = "INPAINT_TELEA"
class ModelType(str, Enum):
@ -56,93 +65,215 @@ class PowerPaintTask(str, Enum):
outpainting = "outpainting"
class Config(BaseModel):
class Config:
arbitrary_types_allowed = True
class ApiConfig(BaseModel):
host: str
port: int
model: str
no_half: bool
cpu_offload: bool
disable_nsfw_checker: bool
cpu_textencoder: bool
device: Device
gui: bool
disable_model_switch: bool
input: Path
output_dir: Path
quality: int
enable_interactive_seg: bool
interactive_seg_model: InteractiveSegModel
interactive_seg_device: Device
enable_remove_bg: bool
enable_anime_seg: bool
enable_realesrgan: bool
realesrgan_device: Device
realesrgan_model: RealESRGANModel
enable_gfpgan: bool
gfpgan_device: Device
enable_restoreformer: bool
restoreformer_device: Device
# Configs for ldm model
ldm_steps: int = 20
ldm_sampler: str = LDMSampler.plms
# Configs for zits model
zits_wireframe: bool = True
class InpaintRequest(BaseModel):
image: Optional[str] = Field(..., description="base64 encoded image")
mask: Optional[str] = Field(..., description="base64 encoded mask")
# Configs for High Resolution Strategy(different way to preprocess image)
hd_strategy: str = HDStrategy.CROP # See HDStrategy Enum
hd_strategy_crop_margin: int = 128
# If the longer side of the image is larger than this value, use crop strategy
hd_strategy_crop_trigger_size: int = 800
hd_strategy_resize_limit: int = 1280
ldm_steps: int = Field(20, description="Steps for ldm model.")
ldm_sampler: str = Field(LDMSampler.plms, description="Sampler for ldm model.")
zits_wireframe: bool = Field(True, description="Enable wireframe for zits model.")
# Configs for Stable Diffusion 1.5
prompt: str = ""
negative_prompt: str = ""
# Crop image to this size before doing sd inpainting
# The value is always on the original image scale
use_croper: bool = False
croper_x: int = None
croper_y: int = None
croper_height: int = None
croper_width: int = None
use_extender: bool = False
extender_x: int = None
extender_y: int = None
extender_height: int = None
extender_width: int = None
hd_strategy: str = Field(
HDStrategy.CROP,
description="Different way to preprocess image, only used by erase models(e.g. lama/mat)",
)
hd_strategy_crop_trigger_size: int = Field(
800,
description="Crop trigger size for hd_strategy=CROP, if the longer side of the image is larger than this value, use crop strategy",
)
hd_strategy_crop_margin: int = Field(
128, description="Crop margin for hd_strategy=CROP"
)
hd_strategy_resize_limit: int = Field(
1280, description="Resize limit for hd_strategy=RESIZE"
)
# Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
# Used by sd models and paint_by_example model
sd_scale: float = 1.0
# Blur the edge of mask area. The higher the number the smoother blend with the original image
sd_mask_blur: int = 0
# Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
# starting point and more noise is added the higher the `strength`. The number of denoising steps depends
# on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
# process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
# essentially ignores `image`.
sd_strength: float = 1.0
# The number of denoising steps. More denoising steps usually lead to a
# higher quality image at the expense of slower inference.
sd_steps: int = 50
# Higher guidance scale encourages to generate images that are closely linked
# to the text prompt, usually at the expense of lower image quality.
sd_guidance_scale: float = 7.5
sd_sampler: str = SDSampler.uni_pc
# -1 mean random seed
sd_seed: int = 42
sd_match_histograms: bool = False
prompt: str = Field("", description="Prompt for diffusion models.")
negative_prompt: str = Field(
"", description="Negative prompt for diffusion models."
)
use_croper: bool = Field(
False, description="Crop image before doing diffusion inpainting"
)
croper_x: int = Field(0, description="Crop x for croper")
croper_y: int = Field(0, description="Crop y for croper")
croper_height: int = Field(512, description="Crop height for croper")
croper_width: int = Field(512, description="Crop width for croper")
# out-painting
sd_outpainting_softness: float = 20.0
sd_outpainting_space: float = 20.0
use_extender: bool = Field(
False, description="Extend image before doing sd outpainting"
)
extender_x: int = Field(0, description="Extend x for extender")
extender_y: int = Field(0, description="Extend y for extender")
extender_height: int = Field(640, description="Extend height for extender")
extender_width: int = Field(640, description="Extend width for extender")
# freeu
sd_freeu: bool = False
sd_mask_blur: int = Field(
33,
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
)
sd_strength: float = Field(
1.0,
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
)
sd_steps: int = Field(
50,
description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
)
sd_guidance_scale: float = Field(
7.5,
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.",
)
sd_sampler: str = Field(
SDSampler.uni_pc, description="Sampler for diffusion model."
)
sd_seed: int = Field(
42,
description="Seed for diffusion model. -1 mean random seed",
validate_default=True,
)
sd_match_histograms: bool = Field(
False,
description="Match histograms between inpainting area and original image.",
)
sd_outpainting_softness: float = Field(20.0)
sd_outpainting_space: float = Field(20.0)
sd_freeu: bool = Field(
False,
description="Enable freeu mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu",
)
sd_freeu_config: FREEUConfig = FREEUConfig()
# lcm-lora
sd_lcm_lora: bool = False
sd_lcm_lora: bool = Field(
False,
description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage",
)
# preserving the unmasked area at the expense of some more unnatural transitions between the masked and unmasked areas.
sd_prevent_unmasked_area: bool = True
sd_keep_unmasked_area: bool = Field(
True, description="Keep unmasked area unchanged"
)
# Configs for opencv inpainting
# opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
cv2_flag: str = "INPAINT_NS"
cv2_radius: int = 4
cv2_flag: CV2Flag = Field(
CV2Flag.INPAINT_NS,
description="Flag for opencv inpainting: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07",
)
cv2_radius: int = Field(
4,
description="Radius of a circular neighborhood of each point inpainted that is considered by the algorithm",
)
# Paint by Example
paint_by_example_example_image: Optional[Image] = None
paint_by_example_example_image: Optional[str] = Field(
None, description="Base64 encoded example image for paint by example model"
)
# InstructPix2Pix
p2p_image_guidance_scale: float = 1.5
p2p_image_guidance_scale: float = Field(1.5, description="Image guidance scale")
# ControlNet
enable_controlnet: bool = False
controlnet_conditioning_scale: float = 0.4
controlnet_method: str = "lllyasviel/control_v11p_sd15_canny"
enable_controlnet: bool = Field(False, description="Enable controlnet")
controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale")
controlnet_method: str = Field(
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
)
# PowerPaint
powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided
# control the fitting degree of the generated objects to the mask shape.
fitting_degree: float = 1.0
powerpaint_task: PowerPaintTask = Field(
PowerPaintTask.text_guided, description="PowerPaint task"
)
fitting_degree: float = Field(
1.0,
description="Control the fitting degree of the generated objects to the mask shape.",
)
@field_validator("sd_seed")
@classmethod
def sd_seed_validator(cls, v: int) -> int:
if v == -1:
return random.randint(1, 99999999)
return v
class RunPluginRequest(BaseModel):
name: str
image: Optional[str] = Field(..., description="base64 encoded image")
clicks: List[List[int]] = Field(
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
)
scale: float = Field(2.0, description="Scale for upscaling")
MediaTab = Literal["input", "output"]
class MediasRequest(BaseModel):
tab: MediaTab
class MediasResponse(BaseModel):
name: str
height: int
width: int
ctime: float
mtime: float
class MediaFileRequest(BaseModel):
tab: MediaTab
filename: str
class MediaThumbnailFileRequest(BaseModel):
tab: MediaTab
filename: str
width: int = 0
height: int = 0
class GenInfoResponse(BaseModel):
prompt: str = ""
negative_prompt: str = ""
class ServerConfigResponse(BaseModel):
plugins: List[str]
enableFileManager: bool
enableAutoSaving: bool
enableControlnet: bool
controlnetMethod: Optional[str]
disableModelSwitch: bool
isDesktop: bool
class SwitchModelRequest(BaseModel):
name: str
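A quick illustration (not in the commit) of the new request models, including the sd_seed field_validator that replaces -1 with a random seed:

from lama_cleaner.schema import InpaintRequest, RunPluginRequest

req = InpaintRequest(image="<base64 image>", mask="<base64 mask>", sd_seed=-1)
assert req.sd_seed != -1  # sd_seed_validator swapped -1 for a random seed

# RunPluginRequest.name must match a key in the dict returned by
# build_plugins, e.g. RemoveBG.name (value assumed to be "RemoveBG").
plugin_req = RunPluginRequest(name="RemoveBG", image="<base64 image>")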

View File

@ -1,16 +1,28 @@
#!/usr/bin/env python3
import multiprocessing
import os
import cv2
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
NUM_THREADS = str(multiprocessing.cpu_count())
cv2.setNumThreads(int(NUM_THREADS))  # cv2 expects an int; NUM_THREADS stays a str for the env vars below
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
import hashlib
import traceback
from dataclasses import dataclass
import imghdr
import io
import logging
import multiprocessing
import random
import time
from pathlib import Path
@ -21,6 +33,11 @@ import torch
from PIL import Image
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
@ -31,8 +48,15 @@ from lama_cleaner.plugins import (
AnimeSeg,
build_plugins,
)
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
is_mac,
get_image_ext, concat_alpha_channel,
)
try:
torch._C._jit_override_can_fuse_on_cpu(False)
@ -42,454 +66,23 @@ try:
except:
pass
from flask import (
Flask,
request,
send_file,
cli,
make_response,
send_from_directory,
jsonify,
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from flask_socketio import SocketIO
# Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
cli.show_server_banner = lambda *_: None
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
is_mac,
)
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class NoFlaskwebgui(logging.Filter):
def filter(self, record):
msg = record.getMessage()
if "Running on http:" in msg:
print(msg[msg.index("Running on http:") :])
return (
"flaskwebgui-keep-server-alive" not in msg
and "socket.io" not in msg
and "This is a development server." not in msg
)
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
@dataclass
class GlobalConfig:
model_manager: ModelManager = None
file_manager: FileManager = None
output_dir: Path = None
input_image_path: Path = None
disable_model_switch: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
@property
def enable_auto_saving(self) -> bool:
return self.output_dir is not None
@property
def enable_file_manager(self) -> bool:
return self.file_manager is not None
global_config = GlobalConfig()
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
def diffuser_callback(i, t, latents):
socketio.emit("diffusion_progress", {"step": i})
@app.route("/save_image", methods=["POST"])
def save_image():
if global_config.output_dir is None:
return "--output-dir is None", 500
input = request.files
filename = request.form["filename"]
origin_image_bytes = input["image"].read() # RGB
# ext = get_image_ext(origin_image_bytes)
ext = "png"
image, alpha_channel, infos = load_img(origin_image_bytes, return_info=True)
save_path = (global_config.output_dir / filename).with_suffix(f".{ext}")
if alpha_channel is not None:
if alpha_channel.shape[:2] != image.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(image.shape[1], image.shape[0])
)
image = np.concatenate((image, alpha_channel[:, :, np.newaxis]), axis=-1)
pil_image = Image.fromarray(image).convert("RGBA")
img_bytes = pil_to_bytes(
pil_image,
ext,
quality=global_config.image_quality,
infos=infos,
)
try:
with open(save_path, "wb") as fw:
fw.write(img_bytes)
except:
return f"Save image failed: {traceback.format_exc()}", 500
return "ok", 200
@app.route("/medias/<tab>")
def medias(tab):
if tab == "image":
response = make_response(jsonify(global_config.file_manager.media_names), 200)
else:
response = make_response(
jsonify(global_config.file_manager.output_media_names), 200
)
# response.last_modified = thumb.modified_time[tab]
# response.cache_control.no_cache = True
# response.cache_control.max_age = 0
# response.make_conditional(request)
return response
@app.route("/media/<tab>/<filename>")
def media_file(tab, filename):
if tab == "image":
return send_from_directory(global_config.file_manager.root_directory, filename)
return send_from_directory(global_config.file_manager.output_dir, filename)
@app.route("/media_thumbnail/<tab>/<filename>")
def media_thumbnail_file(tab, filename):
args = request.args
width = args.get("width")
height = args.get("height")
if width is None and height is None:
width = 256
if width:
width = int(float(width))
if height:
height = int(float(height))
directory = global_config.file_manager.root_directory
if tab == "output":
directory = global_config.file_manager.output_dir
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath))
response.headers["X-Width"] = str(width)
response.headers["X-Height"] = str(height)
return response
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_info=True)
mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
        return (
            f"Mask shape {mask.shape[:2]} is not equal to image shape {image.shape[:2]}",
            400,
        )
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
form = request.form
size_limit = max(image.shape)
if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else:
paint_by_example_example_image = None
config = Config(
ldm_steps=form["ldmSteps"],
ldm_sampler=form["ldmSampler"],
hd_strategy=form["hdStrategy"],
zits_wireframe=form["zitsWireframe"],
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
prompt=form["prompt"],
negative_prompt=form["negativePrompt"],
use_croper=form["useCroper"],
croper_x=form["croperX"],
croper_y=form["croperY"],
croper_height=form["croperHeight"],
croper_width=form["croperWidth"],
use_extender=form["useExtender"],
extender_x=form["extenderX"],
extender_y=form["extenderY"],
extender_height=form["extenderHeight"],
extender_width=form["extenderWidth"],
sd_scale=form["sdScale"],
sd_mask_blur=form["sdMaskBlur"],
sd_strength=form["sdStrength"],
sd_steps=form["sdSteps"],
sd_guidance_scale=form["sdGuidanceScale"],
sd_sampler=form["sdSampler"],
sd_seed=form["sdSeed"],
sd_freeu=form["enableFreeu"],
sd_freeu_config=json.loads(form["freeuConfig"]),
sd_lcm_lora=form["enableLCMLora"],
sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"],
cv2_radius=form["cv2Radius"],
paint_by_example_example_image=paint_by_example_example_image,
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
enable_controlnet=form["enable_controlnet"],
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
controlnet_method=form["controlnet_method"],
powerpaint_task=form["powerpaintTask"],
)
if config.sd_seed == -1:
config.sd_seed = random.randint(1, 99999999)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time()
try:
res_np_img = global_config.model_manager(image, mask, config)
except RuntimeError as e:
if "CUDA out of memory. " in str(e):
# NOTE: the string may change?
return "CUDA out of memory", 500
elif "Invalid buffer size" in str(e) and is_mac():
return "Out of memory", 500
else:
logger.exception(e)
return f"{str(e)}", 500
finally:
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch_gc()
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
bytes_io = io.BytesIO(
pil_to_bytes(
Image.fromarray(res_np_img),
ext,
quality=global_config.image_quality,
infos=exif_infos,
)
)
response = make_response(
send_file(
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
mimetype=f"image/{ext}",
)
)
response.headers["X-Seed"] = str(config.sd_seed)
socketio.emit("diffusion_finish")
return response
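# A minimal client sketch for /inpaint. Every form field read into Config(...)
# above must be supplied; only a few are spelled out here, the rest are elided
# (host/port and all values are assumptions):
import requests

form = {
    "ldmSteps": "25",
    "hdStrategy": "Crop",
    "prompt": "",
    "sdSeed": "-1",
    # ...plus every other field the route reads from `form` above
}
with open("image.png", "rb") as img, open("mask.png", "rb") as msk:
    resp = requests.post(
        "http://127.0.0.1:8080/inpaint",
        files={"image": img, "mask": msk},
        data=form,
    )
print(resp.headers.get("X-Seed"))  # seed actually used for generation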
@app.route("/run_plugin", methods=["POST"])
def run_plugin():
form = request.form
files = request.files
name = form["name"]
    if name not in global_config.plugins:
        return "Plugin not found", 404
origin_image_bytes = files["image"].read() # RGB
rgb_np_img, alpha_channel, infos = load_img(
origin_image_bytes, return_info=True
)
start = time.time()
try:
form = dict(form)
if name == InteractiveSeg.name:
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
form["img_md5"] = img_md5
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
except RuntimeError as e:
torch.cuda.empty_cache()
if "CUDA out of memory. " in str(e):
# NOTE: the string may change?
return "CUDA out of memory", 500
else:
logger.exception(e)
return "Internal Server Error", 500
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
if name == InteractiveSeg.name:
return make_response(
send_file(
io.BytesIO(numpy_to_bytes(bgr_res, "png")),
mimetype="image/png",
)
)
if name in [RemoveBG.name, AnimeSeg.name]:
rgb_res = bgr_res
ext = "png"
else:
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
ext = get_image_ext(origin_image_bytes)
if alpha_channel is not None:
if alpha_channel.shape[:2] != rgb_res.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0])
)
rgb_res = np.concatenate(
(rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1
)
response = make_response(
send_file(
io.BytesIO(
pil_to_bytes(
Image.fromarray(rgb_res),
ext,
quality=global_config.image_quality,
infos=infos,
)
),
mimetype=f"image/{ext}",
)
)
return response
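# Sketch of invoking a plugin by name; "RemoveBG" and the host/port are
# assumptions (valid names are whatever /server_config reports):
import requests

with open("image.png", "rb") as fp:
    resp = requests.post(
        "http://127.0.0.1:8080/run_plugin",
        files={"image": fp},
        data={"name": "RemoveBG"},
    )
with open("image_no_bg.png", "wb") as out:
    out.write(resp.content)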
@app.route("/server_config", methods=["GET"])
def get_server_config():
return {
"plugins": list(global_config.plugins.keys()),
"enableFileManager": global_config.enable_file_manager,
"enableAutoSaving": global_config.enable_auto_saving,
"enableControlnet": global_config.model_manager.enable_controlnet,
"controlnetMethod": global_config.model_manager.controlnet_method,
"disableModelSwitch": global_config.disable_model_switch,
"isDesktop": global_config.is_desktop,
}, 200
@app.route("/models", methods=["GET"])
def get_models():
return [it.model_dump() for it in global_config.model_manager.scan_models()]
@app.route("/model")
def current_model():
return (
global_config.model_manager.current_model,
200,
)
@app.route("/model", methods=["POST"])
def switch_model():
if global_config.disable_model_switch:
return "Switch model is disabled", 400
new_name = request.form.get("name")
if new_name == global_config.model_manager.name:
return "Same model", 200
try:
global_config.model_manager.switch(new_name)
except Exception as e:
traceback.print_exc()
error_message = f"{type(e).__name__} - {str(e)}"
logger.error(error_message)
return f"Switch model failed: {error_message}", 500
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route("/inputimage")
def get_cli_input_image():
if global_config.input_image_path:
with open(global_config.input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
global_config.input_image_path,
as_attachment=True,
download_name=Path(global_config.input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def start(
host: str,
port: int,

View File

@@ -1,6 +1,6 @@
import os
from lama_cleaner.schema import Config
from lama_cleaner.schema import InpaintRequest
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -38,7 +38,7 @@ def test_controlnet_switch_onoff(caplog):
)
model.switch_controlnet_method(
Config(
InpaintRequest(
name=name,
enable_controlnet=False,
)
@@ -63,7 +63,7 @@ def test_switch_controlnet_method(caplog):
)
model.switch_controlnet_method(
Config(
InpaintRequest(
name=name,
enable_controlnet=True,
controlnet_method=new_method,

View File

@@ -1,4 +1,5 @@
import io
import tempfile
from pathlib import Path
from typing import List
@@ -18,9 +19,9 @@ def extra_info(img_p: Path):
ext = img_p.suffix.strip(".")
img_bytes = img_p.read_bytes()
np_img, _, infos = load_img(img_bytes, False, True)
pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos)
res_img = Image.open(io.BytesIO(pil_bytes))
return infos, res_img.info
res_pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos)
res_img = Image.open(io.BytesIO(res_pil_bytes))
return infos, res_img.info, res_pil_bytes
def assert_keys(keys: List[str], infos, res_infos):
@@ -30,23 +31,29 @@ def assert_keys(keys: List[str], infos, res_infos):
assert infos[k] == res_infos[k]
def run_test(file_path, keys):
infos, res_infos, res_pil_bytes = extra_info(file_path)
assert_keys(keys, infos, res_infos)
with tempfile.NamedTemporaryFile("wb", suffix=file_path.suffix) as temp_file:
temp_file.write(res_pil_bytes)
temp_file.flush()
infos, res_infos, res_pil_bytes = extra_info(Path(temp_file.name))
assert_keys(keys, infos, res_infos)
def test_png_icc_profile_png():
infos, res_infos = extra_info(current_dir / "icc_profile_test.png")
assert_keys(["icc_profile", "exif"], infos, res_infos)
run_test(current_dir / "icc_profile_test.png", ["icc_profile", "exif"])
def test_png_icc_profile_jpeg():
infos, res_infos = extra_info(current_dir / "icc_profile_test.jpg")
assert_keys(["icc_profile", "exif"], infos, res_infos)
run_test(current_dir / "icc_profile_test.jpg", ["icc_profile", "exif"])
def test_jpeg():
jpg_img_p = current_dir / "bunny.jpeg"
infos, res_infos = extra_info(jpg_img_p)
assert_keys(["dpi", "exif"], infos, res_infos)
run_test(jpg_img_p, ["dpi", "exif"])
def test_png_parameter():
jpg_img_p = current_dir / "png_parameter_test.png"
infos, res_infos = extra_info(jpg_img_p)
assert_keys(["parameters"], infos, res_infos)
run_test(jpg_img_p, ["parameters"])

View File

@@ -3,7 +3,7 @@ import cv2
import pytest
import torch
from lama_cleaner.schema import LDMSampler, HDStrategy, Config, SDSampler
from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
@@ -72,4 +72,4 @@ def get_config(**kwargs):
hd_strategy_resize_limit=200,
)
data.update(**kwargs)
return Config(**data)
return InpaintRequest(**data)

View File

@@ -43,7 +43,7 @@ def save_config(
restoreformer_device,
enable_gif,
):
config = Config(**locals())
config = InpaintRequest(**locals())
print(config)
if config.input and not os.path.exists(config.input):
return "[Error] Input file or directory does not exist"

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env bash
set -e
pushd ./lama_cleaner/app
yarn run build
pushd ./web_app
npm run build
popd
rm -r -f dist

View File

@@ -1,8 +1,8 @@
torch>=2.0.0
typer
opencv-python
flask==2.2.3
flask-socketio
fastapi==0.108.0
python-multipart
simple-websocket
flask_cors
flaskwebgui==0.3.5

View File

@@ -9,7 +9,11 @@ export default function useInputImage() {
headers.append("pragma", "no-cache")
headers.append("cache-control", "no-cache")
fetch(`${API_ENDPOINT}/inputimage`, { headers }).then(async (res) => {
fetch(`${API_ENDPOINT}/inputimage`, { headers })
.then(async (res) => {
if (!res.ok) {
throw new Error("No input image found")
}
const filename = res.headers
.get("content-disposition")
?.split("filename=")[1]
@@ -24,6 +28,9 @@
setInputImage(userInput)
}
})
.catch((err) => {
console.log(err)
})
}, [setInputImage])
useEffect(() => {

View File

@@ -1,11 +1,11 @@
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
import { Settings } from "@/lib/states"
import { srcToFile } from "@/lib/utils"
import axios from "axios"
import { convertToBase64, srcToFile } from "@/lib/utils"
import axios, { AxiosError } from "axios"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND
: ""
: "/api/v1"
const api = axios.create({
baseURL: API_ENDPOINT,
@@ -19,96 +19,75 @@ export default async function inpaint(
mask: File | Blob,
paintByExampleImage: File | null = null
) {
const fd = new FormData()
fd.append("image", imageFile)
fd.append("mask", mask)
fd.append("ldmSteps", settings.ldmSteps.toString())
fd.append("ldmSampler", settings.ldmSampler.toString())
fd.append("zitsWireframe", settings.zitsWireframe.toString())
fd.append("hdStrategy", "Crop")
fd.append("hdStrategyCropMargin", "128")
fd.append("hdStrategyCropTrigerSize", "640")
fd.append("hdStrategyResizeLimit", "2048")
fd.append("prompt", settings.prompt)
fd.append("negativePrompt", settings.negativePrompt)
fd.append("useCroper", settings.showCropper ? "true" : "false")
fd.append("croperX", croperRect.x.toString())
fd.append("croperY", croperRect.y.toString())
fd.append("croperHeight", croperRect.height.toString())
fd.append("croperWidth", croperRect.width.toString())
fd.append("useExtender", settings.showExtender ? "true" : "false")
fd.append("extenderX", extenderState.x.toString())
fd.append("extenderY", extenderState.y.toString())
fd.append("extenderHeight", extenderState.height.toString())
fd.append("extenderWidth", extenderState.width.toString())
fd.append("sdMaskBlur", settings.sdMaskBlur.toString())
fd.append("sdStrength", settings.sdStrength.toString())
fd.append("sdSteps", settings.sdSteps.toString())
fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString())
fd.append("sdSampler", settings.sdSampler.toString())
if (settings.seedFixed) {
fd.append("sdSeed", settings.seed.toString())
} else {
fd.append("sdSeed", "-1")
}
fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false")
fd.append("sdScale", (settings.sdScale / 100).toString())
fd.append("enableFreeu", settings.enableFreeu.toString())
fd.append("freeuConfig", JSON.stringify(settings.freeuConfig))
fd.append("enableLCMLora", settings.enableLCMLora.toString())
fd.append("cv2Radius", settings.cv2Radius.toString())
fd.append("cv2Flag", settings.cv2Flag.toString())
// TODO: resize image's shortest_edge to 224 before pass to backend, save network time?
// https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor
if (paintByExampleImage) {
fd.append("paintByExampleImage", paintByExampleImage)
}
// InstructPix2Pix
fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString())
// ControlNet
fd.append("enable_controlnet", settings.enableControlnet.toString())
fd.append(
"controlnet_conditioning_scale",
settings.controlnetConditioningScale.toString()
)
fd.append("controlnet_method", settings.controlnetMethod?.toString())
// PowerPaint
if (settings.showExtender) {
fd.append("powerpaintTask", PowerPaintTask.outpainting)
} else {
fd.append("powerpaintTask", settings.powerpaintTask)
}
const imageBase64 = await convertToBase64(imageFile)
const maskBase64 = await convertToBase64(mask)
const exampleImageBase64 = paintByExampleImage
? await convertToBase64(paintByExampleImage)
: null
try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: "POST",
body: fd,
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
image: imageBase64,
mask: maskBase64,
ldm_steps: settings.ldmSteps,
ldm_sampler: settings.ldmSampler,
zits_wireframe: settings.zitsWireframe,
cv2_flag: settings.cv2Flag,
cv2_radius: settings.cv2Radius,
hd_strategy: "Crop",
        hd_strategy_crop_trigger_size: 640,
        hd_strategy_crop_margin: 128,
        hd_strategy_resize_limit: 2048,
prompt: settings.prompt,
negative_prompt: settings.negativePrompt,
use_croper: settings.showCropper,
croper_x: croperRect.x,
croper_y: croperRect.y,
croper_height: croperRect.height,
croper_width: croperRect.width,
use_extender: settings.showExtender,
extender_x: extenderState.x,
extender_y: extenderState.y,
extender_height: extenderState.height,
extender_width: extenderState.width,
sd_mask_blur: settings.sdMaskBlur,
sd_strength: settings.sdStrength,
sd_steps: settings.sdSteps,
sd_guidance_scale: settings.sdGuidanceScale,
sd_sampler: settings.sdSampler,
sd_seed: settings.seedFixed ? settings.seed : -1,
sd_match_histograms: settings.sdMatchHistograms,
sd_freeu: settings.enableFreeu,
sd_freeu_config: settings.freeuConfig,
sd_lcm_lora: settings.enableLCMLora,
paint_by_example_example_image: exampleImageBase64,
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
enable_controlnet: settings.enableControlnet,
controlnet_conditioning_scale: settings.controlnetConditioningScale,
controlnet_method: settings.controlnetMethod
? settings.controlnetMethod
: "",
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,
}),
})
if (res.ok) {
const blob = await res.blob()
const newSeed = res.headers.get("X-seed")
return { blob: URL.createObjectURL(blob), seed: newSeed }
return {
blob: URL.createObjectURL(blob),
seed: res.headers.get("X-Seed"),
}
const errMsg = await res.text()
throw new Error(errMsg)
} catch (error) {
throw new Error(`Something went wrong: ${error}`)
} catch (error: any) {
throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`)
}
}
export function getServerConfig() {
return fetch(`${API_ENDPOINT}/server_config`, {
return fetch(`${API_ENDPOINT}/server-config`, {
method: "GET",
})
}

View File

@@ -491,10 +491,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
paintByExampleFile
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob, seed } = res
if (seed) {
get().setSeed(parseInt(seed, 10))

View File

@@ -223,3 +223,17 @@ export const generateMask = (
return maskCanvas
}
export const convertToBase64 = (fileOrBlob: File | Blob): Promise<string> => {
return new Promise((resolve, reject) => {
const reader = new FileReader()
reader.onload = (event) => {
const base64String = event.target?.result as string
resolve(base64String)
}
reader.onerror = (error) => {
reject(error)
}
reader.readAsDataURL(fileOrBlob)
})
}