diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py new file mode 100644 index 0000000..d5a1ae9 --- /dev/null +++ b/lama_cleaner/api.py @@ -0,0 +1,323 @@ +import os +import threading +import time +import traceback +from pathlib import Path +from typing import Optional, Dict, List + +import cv2 +import torch +import numpy as np +from loguru import logger +from PIL import Image + +import uvicorn +from fastapi import APIRouter, FastAPI, Request, UploadFile +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, FileResponse, Response +from fastapi.staticfiles import StaticFiles + +from lama_cleaner.helper import ( + load_img, + decode_base64_to_image, + pil_to_bytes, + numpy_to_bytes, + concat_alpha_channel, +) +from lama_cleaner.model.utils import torch_gc +from lama_cleaner.model_info import ModelInfo +from lama_cleaner.model_manager import ModelManager +from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg +from lama_cleaner.schema import ( + GenInfoResponse, + ApiConfig, + ServerConfigResponse, + SwitchModelRequest, + InpaintRequest, + RunPluginRequest, +) +from lama_cleaner.file_manager import FileManager + +CURRENT_DIR = Path(__file__).parent.absolute().resolve() +WEB_APP_DIR = CURRENT_DIR / "web_app" + + +def api_middleware(app: FastAPI): + rich_available = False + try: + if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + + console = Console() + rich_available = True + except Exception: + pass + + def handle_exception(request: Request, e: Exception): + err = { + "error": type(e).__name__, + "detail": vars(e).get("detail", ""), + "body": vars(e).get("body", ""), + "errors": str(e), + } + if not isinstance( + e, HTTPException + ): # do not print backtrace on known httpexceptions + message = f"API error: {request.method}: {request.url} {err}" + if rich_available: + print(message) + console.print_exception( + show_locals=True, + max_frames=2, + extra_lines=1, + suppress=[anyio, starlette], + word_wrap=False, + width=min([console.width, 200]), + ) + else: + traceback.print_exc() + return JSONResponse( + status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err) + ) + + @app.middleware("http") + async def exception_handling(request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + return handle_exception(request, e) + + @app.exception_handler(Exception) + async def fastapi_exception_handler(request: Request, e: Exception): + return handle_exception(request, e) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, e: HTTPException): + return handle_exception(request, e) + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_origins": ["*"], + "allow_credentials": True, + } + app.add_middleware(CORSMiddleware, **cors_options) + + +class Api: + def __init__(self, app: FastAPI, config: ApiConfig): + self.app = app + self.config = config + self.router = APIRouter() + self.queue_lock = threading.Lock() + api_middleware(self.app) + + self.file_manager = self._build_file_manager() + self.plugins = self._build_plugins() + self.model_manager = self._build_model_manager() + + # fmt: off + self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse) + self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse) + self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo]) + self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo) + self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo) + self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) + self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) + self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"]) + self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") + # fmt: on + + def add_api_route(self, path: str, endpoint, **kwargs): + return self.app.add_api_route(path, endpoint, **kwargs) + + def api_models(self) -> List[ModelInfo]: + return self.model_manager.scan_models() + + def api_current_model(self) -> ModelInfo: + return self.model_manager.current_model + + def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo: + if req.name == self.model_manager.name: + return self.model_manager.current_model + self.model_manager.switch(req.name) + return self.model_manager.current_model + + def api_server_config(self) -> ServerConfigResponse: + return ServerConfigResponse( + plugins=list(self.plugins.keys()), + enableFileManager=self.file_manager is not None, + enableAutoSaving=self.config.output_dir is not None, + enableControlnet=self.model_manager.enable_controlnet, + controlnetMethod=self.model_manager.controlnet_method, + disableModelSwitch=self.config.disable_model_switch, + isDesktop=self.config.gui, + ) + + def api_input_image(self) -> FileResponse: + if self.config.input and self.config.input.is_file(): + return FileResponse(self.config.input) + raise HTTPException(status_code=404, detail="Input image not found") + + def api_geninfo(self, file: UploadFile) -> GenInfoResponse: + _, _, info = load_img(file.file.read(), return_info=True) + parts = info.get("parameters", "").split("Negative prompt: ") + prompt = parts[0].strip() + negative_prompt = "" + if len(parts) > 1: + negative_prompt = parts[1].split("\n")[0].strip() + return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt) + + def api_inpaint(self, req: InpaintRequest): + image, alpha_channel, infos = decode_base64_to_image(req.image) + mask, _, _ = decode_base64_to_image(req.mask, gray=True) + mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] + if image.shape[:2] != mask.shape[:2]: + raise HTTPException( + 400, + detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.", + ) + + if req.paint_by_example_example_image: + paint_by_example_image, _, _ = decode_base64_to_image( + req.paint_by_example_example_image + ) + + start = time.time() + rgb_np_img = self.model_manager(image, mask, req) + logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms") + torch_gc() + + rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) + rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel) + + ext = "png" + res_img_bytes = pil_to_bytes( + Image.fromarray(rgb_res), + ext=ext, + quality=self.config.quality, + infos=infos, + ) + return Response( + content=res_img_bytes, + media_type=f"image/{ext}", + headers={"X-Seed": str(req.sd_seed)}, + ) + + def api_run_plugin(self, req: RunPluginRequest): + if req.name not in self.plugins: + raise HTTPException(status_code=404, detail="Plugin not found") + image, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_res = self.plugins[req.name].run(image, req) + torch_gc() + if req.name == InteractiveSeg.name: + return Response( + content=numpy_to_bytes(bgr_res, "png"), + media_type="image/png", + ) + ext = "png" + if req.name in [RemoveBG.name, AnimeSeg.name]: + rgb_res = bgr_res + else: + rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB) + rgb_res = concat_alpha_channel(rgb_res, alpha_channel) + + return Response( + content=pil_to_bytes( + Image.fromarray(rgb_res), + ext=ext, + quality=self.config.quality, + infos=infos, + ), + media_type=f"image/{ext}", + ) + + def launch(self): + self.app.include_router(self.router) + uvicorn.run( + self.app, + host=self.config.host, + port=self.config.port, + timeout_keep_alive=60, + ) + + def _build_file_manager(self) -> Optional[FileManager]: + if self.config.input and self.config.input.is_dir(): + logger.info( + f"Input is directory, initialize file manager {self.config.input}" + ) + + return FileManager( + app=self.app, + input_dir=self.config.input, + output_dir=self.config.output_dir, + ) + return None + + def _build_plugins(self) -> Dict: + return build_plugins( + self.config.enable_interactive_seg, + self.config.interactive_seg_model, + self.config.interactive_seg_device, + self.config.enable_remove_bg, + self.config.enable_anime_seg, + self.config.enable_realesrgan, + self.config.realesrgan_device, + self.config.realesrgan_model, + self.config.enable_gfpgan, + self.config.gfpgan_device, + self.config.enable_restoreformer, + self.config.restoreformer_device, + self.config.no_half, + ) + + def _build_model_manager(self): + return ModelManager( + name=self.config.model, + device=torch.device(self.config.device), + no_half=self.config.no_half, + disable_nsfw=self.config.disable_nsfw_checker, + sd_cpu_textencoder=self.config.cpu_textencoder, + cpu_offload=self.config.cpu_offload, + ) + + +if __name__ == "__main__": + from lama_cleaner.schema import InteractiveSegModel, RealESRGANModel + + app = FastAPI() + api = Api( + app, + ApiConfig( + host="127.0.0.1", + port=8080, + model="lama", + no_half=False, + cpu_offload=False, + disable_nsfw_checker=False, + cpu_textencoder=False, + device="cpu", + gui=False, + disable_model_switch=False, + input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images", + output_dir="/Users/cwq/code/github/lama-cleaner/tmp", + quality=100, + enable_interactive_seg=False, + interactive_seg_model=InteractiveSegModel.vit_b, + interactive_seg_device="cpu", + enable_remove_bg=False, + enable_anime_seg=False, + enable_realesrgan=False, + realesrgan_device="cpu", + realesrgan_model=RealESRGANModel.realesr_general_x4v3, + enable_gfpgan=False, + gfpgan_device="cpu", + enable_restoreformer=False, + restoreformer_device="cpu", + ), + ) + api.launch() diff --git a/lama_cleaner/benchmark.py b/lama_cleaner/benchmark.py index feb29e9..d8bcbbb 100644 --- a/lama_cleaner/benchmark.py +++ b/lama_cleaner/benchmark.py @@ -10,7 +10,7 @@ import psutil import torch from lama_cleaner.model_manager import ModelManager -from lama_cleaner.schema import Config, HDStrategy, SDSampler +from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler try: torch._C._jit_override_can_fuse_on_cpu(False) @@ -36,7 +36,7 @@ def run_model(model, size): image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8) mask = np.random.randint(0, 255, size).astype(np.uint8) - config = Config( + config = InpaintRequest( ldm_steps=2, hd_strategy=HDStrategy.ORIGINAL, hd_strategy_crop_margin=128, @@ -44,7 +44,7 @@ def run_model(model, size): hd_strategy_resize_limit=128, prompt="a fox is sitting on a bench", sd_steps=5, - sd_sampler=SDSampler.ddim + sd_sampler=SDSampler.ddim, ) model(image, mask, config) @@ -75,7 +75,9 @@ def benchmark(model, times: int, empty_cache: bool): # cpu_metrics.append(process.cpu_percent()) time_metrics.append((time.time() - start) * 1000) memory_metrics.append(process.memory_info().rss / 1024 / 1024) - gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024) + gpu_memory_metrics.append( + nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024 + ) print(f"size: {size}".center(80, "-")) # print(f"cpu: {format(cpu_metrics)}") diff --git a/lama_cleaner/cli.py b/lama_cleaner/cli.py index c4021bb..fbdd167 100644 --- a/lama_cleaner/cli.py +++ b/lama_cleaner/cli.py @@ -1,6 +1,7 @@ from pathlib import Path import typer +from fastapi import FastAPI from loguru import logger from typer import Option @@ -38,6 +39,17 @@ def list_model( print(it.name) +@typer_app.command(help="Processing image with lama cleaner") +def run( + input: Path = Option(..., help="Image file or folder containing images"), + output_dir: Path = Option(..., help="Output directory"), + config_path: Path = Option(..., help="Config file path"), + model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False), +): + setup_model_dir(model_dir) + pass + + @typer_app.command(help="Start lama cleaner server") def start( host: str = Option("127.0.0.1"), @@ -80,6 +92,16 @@ def start( ): dump_environment_info() device = check_device(device) + if input and not input.exists(): + logger.error(f"invalid --input: {input} not exists") + exit() + if output_dir: + output_dir = output_dir.expanduser().absolute() + logger.info(f"Image will be saved to {output_dir}") + if not output_dir.exists(): + logger.info(f"Create output directory {output_dir}") + output_dir.mkdir(parents=True) + model_dir = model_dir.expanduser().absolute() setup_model_dir(model_dir) @@ -92,32 +114,38 @@ def start( logger.info(f"{model} not found in {model_dir}, try to downloading") cli_download_model(model, model_dir) - from lama_cleaner.server import start + from lama_cleaner.api import Api + from lama_cleaner.schema import ApiConfig - start( - host=host, - port=port, - model=model, - no_half=no_half, - cpu_offload=cpu_offload, - disable_nsfw_checker=disable_nsfw_checker, - cpu_textencoder=cpu_textencoder, - device=device, - gui=gui, - disable_model_switch=disable_model_switch, - input=input, - output_dir=output_dir, - quality=quality, - enable_interactive_seg=enable_interactive_seg, - interactive_seg_model=interactive_seg_model, - interactive_seg_device=interactive_seg_device, - enable_remove_bg=enable_remove_bg, - enable_anime_seg=enable_anime_seg, - enable_realesrgan=enable_realesrgan, - realesrgan_device=realesrgan_device, - realesrgan_model=realesrgan_model, - enable_gfpgan=enable_gfpgan, - gfpgan_device=gfpgan_device, - enable_restoreformer=enable_restoreformer, - restoreformer_device=restoreformer_device, + app = FastAPI() + api = Api( + app, + ApiConfig( + host=host, + port=port, + model=model, + no_half=no_half, + cpu_offload=cpu_offload, + disable_nsfw_checker=disable_nsfw_checker, + cpu_textencoder=cpu_textencoder, + device=device, + gui=gui, + disable_model_switch=disable_model_switch, + input=input, + output_dir=output_dir, + quality=quality, + enable_interactive_seg=enable_interactive_seg, + interactive_seg_model=interactive_seg_model, + interactive_seg_device=interactive_seg_device, + enable_remove_bg=enable_remove_bg, + enable_anime_seg=enable_anime_seg, + enable_realesrgan=enable_realesrgan, + realesrgan_device=realesrgan_device, + realesrgan_model=realesrgan_model, + enable_gfpgan=enable_gfpgan, + gfpgan_device=gfpgan_device, + enable_restoreformer=enable_restoreformer, + restoreformer_device=restoreformer_device, + ), ) + api.launch() diff --git a/lama_cleaner/file_manager/file_manager.py b/lama_cleaner/file_manager/file_manager.py index 0b513c1..6a38cfa 100644 --- a/lama_cleaner/file_manager/file_manager.py +++ b/lama_cleaner/file_manager/file_manager.py @@ -1,18 +1,19 @@ -# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py import os -from datetime import datetime - -import cv2 -import time from io import BytesIO from pathlib import Path -import numpy as np - -# from watchdog.events import FileSystemEventHandler -# from watchdog.observers import Observer +from typing import List from PIL import Image, ImageOps, PngImagePlugin -from loguru import logger +from fastapi import FastAPI, UploadFile, HTTPException +from starlette.responses import FileResponse + +from ..schema import ( + MediasResponse, + MediasRequest, + MediaFileRequest, + MediaTab, + MediaThumbnailFileRequest, +) LARGE_ENOUGH_NUMBER = 100 PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) @@ -21,132 +22,87 @@ from .utils import aspect_to_string, generate_filename, glob_img class FileManager: - def __init__(self, app=None): + def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path): self.app = app - self._default_root_directory = "media" - self._default_thumbnail_directory = "media" - self._default_root_url = "/" - self._default_thumbnail_root_url = "/" - self._default_format = "JPEG" - self.output_dir: Path = None - - if app is not None: - self.init_app(app) + self.input_dir: Path = input_dir + self.output_dir: Path = output_dir self.image_dir_filenames = [] self.output_dir_filenames = [] + if not self.thumbnail_directory.exists(): + self.thumbnail_directory.mkdir(parents=True) - self.image_dir_observer = None - self.output_dir_observer = None + # fmt: off + self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) + self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse]) + self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None) + self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None) + # fmt: on - self.modified_time = { - "image": datetime.utcnow(), - "output": datetime.utcnow(), - } + def api_save_image(self, file: UploadFile): + filename = file.filename + origin_image_bytes = file.file.read() + with open(self.output_dir / filename, "wb") as fw: + fw.write(origin_image_bytes) - # def start(self): - # self.image_dir_filenames = self._media_names(self.root_directory) - # self.output_dir_filenames = self._media_names(self.output_dir) - # - # logger.info(f"Start watching image directory: {self.root_directory}") - # self.image_dir_observer = Observer() - # self.image_dir_observer.schedule(self, self.root_directory, recursive=False) - # self.image_dir_observer.start() - # - # logger.info(f"Start watching output directory: {self.output_dir}") - # self.output_dir_observer = Observer() - # self.output_dir_observer.schedule(self, self.output_dir, recursive=False) - # self.output_dir_observer.start() + def api_medias(self, req: MediasRequest) -> List[MediasResponse]: + img_dir = self._get_dir(req.tab) + return self._media_names(img_dir) - def on_modified(self, event): - if not os.path.isdir(event.src_path): - return - if event.src_path == str(self.root_directory): - logger.info(f"Image directory {event.src_path} modified") - self.image_dir_filenames = self._media_names(self.root_directory) - self.modified_time["image"] = datetime.utcnow() - elif event.src_path == str(self.output_dir): - logger.info(f"Output directory {event.src_path} modified") - self.output_dir_filenames = self._media_names(self.output_dir) - self.modified_time["output"] = datetime.utcnow() + def api_media_file(self, req: MediaFileRequest) -> FileResponse: + file_path = self._get_file(req.tab, req.filename) + return FileResponse(file_path) - def init_app(self, app): - if self.app is None: - self.app = app - app.thumbnail_instance = self - - if not hasattr(app, "extensions"): - app.extensions = {} - - if "thumbnail" in app.extensions: - raise RuntimeError("Flask-thumbnail extension already initialized") - - app.extensions["thumbnail"] = self - - app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory) - app.config.setdefault( - "THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory + def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse: + img_dir = self._get_dir(req.tab) + thumb_filename, (width, height) = self.get_thumbnail( + img_dir, req.filename, width=req.width, height=req.height ) - app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url) - app.config.setdefault( - "THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url + thumbnail_filepath = self.thumbnail_directory / thumb_filename + return FileResponse( + thumbnail_filepath, + headers={ + "X-Width": str(width), + "X-Height": str(height), + }, ) - app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format) - @property - def root_directory(self): - path = self.app.config["THUMBNAIL_MEDIA_ROOT"] - - if os.path.isabs(path): - return path + def _get_dir(self, tab: MediaTab) -> Path: + if tab == "input": + return self.input_dir + elif tab == "output": + return self.output_dir else: - return os.path.join(self.app.root_path, path) + raise HTTPException(status_code=422, detail=f"tab not found: {tab}") + + def _get_file(self, tab: MediaTab, filename: str) -> Path: + file_path = self._get_dir(tab) / filename + if not file_path.exists(): + raise HTTPException(status_code=422, detail=f"file not found: {file_path}") + return file_path @property - def thumbnail_directory(self): - path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] - - if os.path.isabs(path): - return path - else: - return os.path.join(self.app.root_path, path) - - @property - def root_url(self): - return self.app.config["THUMBNAIL_MEDIA_URL"] - - @property - def media_names(self): - # return self.image_dir_filenames - return self._media_names(self.root_directory) - - @property - def output_media_names(self): - return self._media_names(self.output_dir) - # return self.output_dir_filenames + def thumbnail_directory(self) -> Path: + return self.output_dir / "thumbnails" @staticmethod - def _media_names(directory: Path): + def _media_names(directory: Path) -> List[MediasResponse]: names = sorted([it.name for it in glob_img(directory)]) res = [] for name in names: path = os.path.join(directory, name) img = Image.open(path) res.append( - { - "name": name, - "height": img.height, - "width": img.width, - "ctime": os.path.getctime(path), - "mtime": os.path.getmtime(path), - } + MediasResponse( + name=name, + height=img.height, + width=img.width, + ctime=os.path.getctime(path), + mtime=os.path.getmtime(path), + ) ) return res - @property - def thumbnail_url(self): - return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"] - def get_thumbnail( self, directory: Path, original_filename: str, width, height, **options ): @@ -161,7 +117,10 @@ class FileManager: image = Image.open(BytesIO(storage.read(original_filepath))) # keep ratio resize - if width is not None: + if not width and not height: + width = 256 + + if width != 0: height = int(image.height * width / image.width) else: width = int(image.width * height / image.height) @@ -180,18 +139,15 @@ class FileManager: thumbnail_filepath = os.path.join( self.thumbnail_directory, original_path, thumbnail_filename ) - thumbnail_url = os.path.join( - self.thumbnail_url, original_path, thumbnail_filename - ) if storage.exists(thumbnail_filepath): - return thumbnail_url, (width, height) + return thumbnail_filepath, (width, height) try: image.load() except (IOError, OSError): self.app.logger.warning("Thumbnail not load image: %s", original_filepath) - return thumbnail_url, (width, height) + return thumbnail_filepath, (width, height) # get original image format options["format"] = options.get("format", image.format) @@ -203,7 +159,7 @@ class FileManager: raw_data = self.get_raw_data(image, **options) storage.save(thumbnail_filepath, raw_data) - return thumbnail_url, (width, height) + return thumbnail_filepath, (width, height) def get_raw_data(self, image, **options): data = { @@ -246,7 +202,7 @@ class FileManager: if image.format: return image.format - return self.app.config["THUMBNAIL_DEFAULT_FORMAT"] + return "JPEG" def _create_thumbnail(self, image, size, crop="fit", background=None): try: diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index e8a5f68..715ae51 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -1,7 +1,9 @@ +import base64 +import imghdr import io import os import sys -from typing import List, Optional +from typing import List, Optional, Dict, Tuple from urllib.parse import urlparse import cv2 @@ -138,8 +140,8 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes: with io.BytesIO() as output: kwargs = {k: v for k, v in infos.items() if v is not None} - if ext == 'jpg': - ext = 'jpeg' + if ext == "jpg": + ext = "jpeg" if "png" == ext.lower() and "parameters" in kwargs: pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data.add_text("parameters", kwargs["parameters"]) @@ -290,3 +292,61 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]: def is_mac(): return sys.platform == "darwin" + + +def get_image_ext(img_bytes): + w = imghdr.what("", img_bytes) + if w is None: + w = "jpeg" + return w + + +def decode_base64_to_image( + encoding: str, gray=False +) -> Tuple[np.array, Optional[np.array], Dict]: + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + + alpha_channel = None + infos = image.info + try: + image = ImageOps.exif_transpose(image) + except: + pass + + if gray: + image = image.convert("L") + np_img = np.array(image) + else: + if image.mode == "RGBA": + np_img = np.array(image) + alpha_channel = np_img[:, :, -1] + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) + else: + image = image.convert("RGB") + np_img = np.array(image) + + return np_img, alpha_channel, infos + + +def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes: + img_bytes = pil_to_bytes( + image, + "png", + quality=quality, + infos=infos, + ) + return base64.b64encode(img_bytes) + + +def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray: + if alpha_channel is not None: + if alpha_channel.shape[:2] != rgb_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0]) + ) + rgb_np_img = np.concatenate( + (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 + ) + return rgb_np_img diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index dd65d55..82c60c6 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -14,7 +14,7 @@ from lama_cleaner.helper import ( ) from lama_cleaner.model.helper.g_diffuser_bot import expand_image from lama_cleaner.model.utils import get_scheduler -from lama_cleaner.schema import Config, HDStrategy, SDSampler +from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler class InpaintModel: @@ -44,7 +44,7 @@ class InpaintModel: return False @abc.abstractmethod - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W, 1] 255 为 masks 区域 @@ -56,7 +56,7 @@ class InpaintModel: def download(): ... - def _pad_forward(self, image, mask, config: Config): + def _pad_forward(self, image, mask, config: InpaintRequest): origin_height, origin_width = image.shape[:2] pad_image = pad_img_to_modulo( image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size @@ -74,7 +74,7 @@ class InpaintModel: result, image, mask = self.forward_post_process(result, image, mask, config) - if config.sd_prevent_unmasked_area: + if config.sd_keep_unmasked_area: mask = mask[:, :, np.newaxis] result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) return result @@ -86,7 +86,7 @@ class InpaintModel: return result, image, mask @torch.no_grad() - def __call__(self, image, mask, config: Config): + def __call__(self, image, mask, config: InpaintRequest): """ images: [H, W, C] RGB, not normalized masks: [H, W] @@ -141,7 +141,7 @@ class InpaintModel: return inpaint_result - def _crop_box(self, image, mask, box, config: Config): + def _crop_box(self, image, mask, box, config: InpaintRequest): """ Args: @@ -233,7 +233,7 @@ class InpaintModel: return result - def _apply_cropper(self, image, mask, config: Config): + def _apply_cropper(self, image, mask, config: InpaintRequest): img_h, img_w = image.shape[:2] l, t, w, h = ( config.croper_x, @@ -253,7 +253,7 @@ class InpaintModel: crop_mask = mask[t:b, l:r] return crop_img, crop_mask, (l, t, r, b) - def _run_box(self, image, mask, box, config: Config): + def _run_box(self, image, mask, box, config: InpaintRequest): """ Args: @@ -276,7 +276,7 @@ class DiffusionInpaintModel(InpaintModel): super().__init__(device, **kwargs) @torch.no_grad() - def __call__(self, image, mask, config: Config): + def __call__(self, image, mask, config: InpaintRequest): """ images: [H, W, C] RGB, not normalized masks: [H, W] @@ -295,7 +295,7 @@ class DiffusionInpaintModel(InpaintModel): return inpaint_result - def _do_outpainting(self, image, config: Config): + def _do_outpainting(self, image, config: InpaintRequest): # cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数 # 从 image 中 crop 出 outpainting 区域 image_h, image_w = image.shape[:2] @@ -368,7 +368,7 @@ class DiffusionInpaintModel(InpaintModel): ] = expanded_cropped_result_image return outpainting_image - def _scaled_pad_forward(self, image, mask, config: Config): + def _scaled_pad_forward(self, image, mask, config: InpaintRequest): longer_side_length = int(config.sd_scale * max(image.shape[:2])) origin_size = image.shape[:2] downsize_image = resize_max_size(image, size_limit=longer_side_length) @@ -396,7 +396,7 @@ class DiffusionInpaintModel(InpaintModel): # ] return inpaint_result - def set_scheduler(self, config: Config): + def set_scheduler(self, config: InpaintRequest): scheduler_config = self.model.scheduler.config sd_sampler = config.sd_sampler if config.sd_lcm_lora: diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index b7288f8..cbe1e5f 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -14,7 +14,7 @@ from lama_cleaner.model.helper.controlnet_preprocess import ( ) from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions -from lama_cleaner.schema import Config, ModelType +from lama_cleaner.schema import InpaintRequest, ModelType class ControlNet(DiffusionInpaintModel): @@ -130,7 +130,7 @@ class ControlNet(DiffusionInpaintModel): raise NotImplementedError(f"{self.controlnet_method} not implemented") return control_image - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/fcf.py b/lama_cleaner/model/fcf.py index 9cfc2be..10e9113 100644 --- a/lama_cleaner/model/fcf.py +++ b/lama_cleaner/model/fcf.py @@ -6,7 +6,7 @@ import torch import numpy as np import torch.fft as fft -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest from lama_cleaner.helper import ( load_model, @@ -1665,7 +1665,7 @@ class FcF(InpaintModel): return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL)) @torch.no_grad() - def __call__(self, image, mask, config: Config): + def __call__(self, image, mask, config: InpaintRequest): """ images: [H, W, C] RGB, not normalized masks: [H, W] @@ -1705,7 +1705,7 @@ class FcF(InpaintModel): return inpaint_result - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W] mask area == 255 diff --git a/lama_cleaner/model/instruct_pix2pix.py b/lama_cleaner/model/instruct_pix2pix.py index 18990b1..bf03ff7 100644 --- a/lama_cleaner/model/instruct_pix2pix.py +++ b/lama_cleaner/model/instruct_pix2pix.py @@ -4,7 +4,7 @@ import torch from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest class InstructPix2Pix(DiffusionInpaintModel): @@ -40,7 +40,7 @@ class InstructPix2Pix(DiffusionInpaintModel): else: self.model = self.model.to(device) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 0645af7..965aa72 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -5,7 +5,7 @@ import torch from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.utils import get_scheduler -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest class Kandinsky(DiffusionInpaintModel): @@ -29,7 +29,7 @@ class Kandinsky(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index f1dd239..0a2e6f5 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -11,7 +11,7 @@ from lama_cleaner.helper import ( download_model, ) from lama_cleaner.model.base import InpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest LAMA_MODEL_URL = os.environ.get( "LAMA_MODEL_URL", @@ -36,7 +36,7 @@ class LaMa(InpaintModel): def is_downloaded() -> bool: return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W] diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 4066ad3..a338e8e 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -7,7 +7,7 @@ from loguru import logger from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.ddim_sampler import DDIMSampler from lama_cleaner.model.plms_sampler import PLMSSampler -from lama_cleaner.schema import Config, LDMSampler +from lama_cleaner.schema import InpaintRequest, LDMSampler torch.manual_seed(42) import torch.nn as nn @@ -277,7 +277,7 @@ class LDM(InpaintModel): return all([os.path.exists(it) for it in model_paths]) @torch.cuda.amp.autocast() - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """ image: [H, W, C] RGB mask: [H, W, 1] diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py index 76e4c86..b675e6a 100644 --- a/lama_cleaner/model/manga.py +++ b/lama_cleaner/model/manga.py @@ -9,7 +9,7 @@ from loguru import logger from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model from lama_cleaner.model.base import InpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest MANGA_INPAINTOR_MODEL_URL = os.environ.get( @@ -56,7 +56,7 @@ class Manga(InpaintModel): ] return all([os.path.exists(it) for it in model_paths]) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """ image: [H, W, C] RGB mask: [H, W, 1] diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py index 49ad54e..cc01284 100644 --- a/lama_cleaner/model/mat.py +++ b/lama_cleaner/model/mat.py @@ -28,7 +28,7 @@ from lama_cleaner.model.utils import ( normalize_2nd_moment, set_seed, ) -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest class ModulatedConv2d(nn.Module): @@ -1912,7 +1912,7 @@ class MAT(InpaintModel): def is_downloaded() -> bool: return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL)) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W] mask area == 255 diff --git a/lama_cleaner/model/mi_gan.py b/lama_cleaner/model/mi_gan.py index d8ec0fa..bd85ee2 100644 --- a/lama_cleaner/model/mi_gan.py +++ b/lama_cleaner/model/mi_gan.py @@ -3,7 +3,6 @@ import os import cv2 import torch -from lama_cleaner.const import Config from lama_cleaner.helper import ( load_jit_model, download_model, @@ -13,6 +12,7 @@ from lama_cleaner.helper import ( norm_img, ) from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import InpaintRequest MIGAN_MODEL_URL = os.environ.get( "MIGAN_MODEL_URL", @@ -40,7 +40,7 @@ class MIGAN(InpaintModel): return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL)) @torch.no_grad() - def __call__(self, image, mask, config: Config): + def __call__(self, image, mask, config: InpaintRequest): """ images: [H, W, C] RGB, not normalized masks: [H, W] @@ -80,7 +80,7 @@ class MIGAN(InpaintModel): return inpaint_result - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W] mask area == 255 diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py index cfbde9e..b13a13d 100644 --- a/lama_cleaner/model/opencv2.py +++ b/lama_cleaner/model/opencv2.py @@ -1,6 +1,6 @@ import cv2 from lama_cleaner.model.base import InpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} @@ -14,7 +14,7 @@ class OpenCV2(InpaintModel): def is_downloaded() -> bool: return True - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index 07d3842..a8cbdbb 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -4,8 +4,9 @@ import cv2 import torch from loguru import logger +from lama_cleaner.helper import decode_base64_to_image from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest class PaintByExample(DiffusionInpaintModel): @@ -38,16 +39,21 @@ class PaintByExample(DiffusionInpaintModel): else: self.model = self.model.to(device) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint return: BGR IMAGE """ + if config.paint_by_example_example_image is None: + raise ValueError("paint_by_example_example_image is required") + example_image, _, _ = decode_base64_to_image( + config.paint_by_example_example_image + ) output = self.model( image=PIL.Image.fromarray(image), mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), - example_image=config.paint_by_example_example_image, + example_image=PIL.Image.fromarray(example_image), num_inference_steps=config.sd_steps, guidance_scale=config.sd_guidance_scale, negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature", diff --git a/lama_cleaner/model/power_paint/power_paint.py b/lama_cleaner/model/power_paint/power_paint.py index 014f403..d10ad95 100644 --- a/lama_cleaner/model/power_paint/power_paint.py +++ b/lama_cleaner/model/power_paint/power_paint.py @@ -7,7 +7,7 @@ from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.utils import handle_from_pretrained_exceptions -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest from .powerpaint_tokenizer import add_task_to_prompt @@ -58,7 +58,7 @@ class PowerPaint(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 1c7fcdd..6f77abf 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -6,7 +6,7 @@ from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper from lama_cleaner.model.utils import handle_from_pretrained_exceptions -from lama_cleaner.schema import Config, ModelType +from lama_cleaner.schema import InpaintRequest, ModelType class SD(DiffusionInpaintModel): @@ -64,7 +64,7 @@ class SD(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index b25451f..f1bb60c 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -9,7 +9,7 @@ from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel from lama_cleaner.model.utils import handle_from_pretrained_exceptions -from lama_cleaner.schema import Config, ModelType +from lama_cleaner.schema import InpaintRequest, ModelType class SDXL(DiffusionInpaintModel): @@ -60,7 +60,7 @@ class SDXL(DiffusionInpaintModel): self.callback = kwargs.pop("callback", None) - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input image and output image have same size image: [H, W, C] RGB mask: [H, W, 1] 255 means area to repaint diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 22fae23..8858a88 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, download_model -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest import numpy as np from lama_cleaner.model.base import InpaintModel @@ -343,7 +343,7 @@ class ZITS(InpaintModel): items["line"] = line_pred.detach() @torch.no_grad() - def forward(self, image, mask, config: Config): + def forward(self, image, mask, config: InpaintRequest): """Input images and output images have same size images: [H, W, C] RGB masks: [H, W] diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index d790573..f31bdaa 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -8,7 +8,7 @@ from lama_cleaner.helper import switch_mps_device from lama_cleaner.model import models, ControlNet, SD, SDXL from lama_cleaner.model.utils import torch_gc from lama_cleaner.model_info import ModelInfo, ModelType -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest class ModelManager: @@ -31,13 +31,15 @@ class ModelManager: self.model = self.init_model(name, device, **kwargs) @property - def current_model(self) -> Dict: - return self.available_models[self.name].model_dump() + def current_model(self) -> ModelInfo: + return self.available_models[self.name] def init_model(self, name: str, device, **kwargs): logger.info(f"Loading model: {name}") if name not in self.available_models: - raise NotImplementedError(f"Unsupported model: {name}. Available models: {self.available_models.keys()}") + raise NotImplementedError( + f"Unsupported model: {name}. Available models: {self.available_models.keys()}" + ) model_info = self.available_models[name] kwargs = { @@ -66,7 +68,17 @@ class ModelManager: raise NotImplementedError(f"Unsupported model: {name}") - def __call__(self, image, mask, config: Config): + def __call__(self, image, mask, config: InpaintRequest): + """ + + Args: + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + config: + + Returns: + + """ self.switch_controlnet_method(config) self.enable_disable_freeu(config) self.enable_disable_lcm_lora(config) @@ -135,7 +147,7 @@ class ModelManager: else: logger.info(f"Enable controlnet: {config.controlnet_method}") - def enable_disable_freeu(self, config: Config): + def enable_disable_freeu(self, config: InpaintRequest): if str(self.model.device) == "mps": return @@ -151,7 +163,7 @@ class ModelManager: else: self.model.model.disable_freeu() - def enable_disable_lcm_lora(self, config: Config): + def enable_disable_lcm_lora(self, config: InpaintRequest): if self.available_models[self.name].support_lcm_lora: if config.sd_lcm_lora: if not self.model.model.get_list_adapters(): diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py index addd683..69ee2b5 100644 --- a/lama_cleaner/plugins/__init__.py +++ b/lama_cleaner/plugins/__init__.py @@ -10,7 +10,6 @@ from ..const import InteractiveSegModel, Device, RealESRGANModel def build_plugins( - global_config, enable_interactive_seg: bool, interactive_seg_model: InteractiveSegModel, interactive_seg_device: Device, @@ -25,25 +24,26 @@ def build_plugins( restoreformer_device: Device, no_half: bool, ): + plugins = {} if enable_interactive_seg: logger.info(f"Initialize {InteractiveSeg.name} plugin") - global_config.plugins[InteractiveSeg.name] = InteractiveSeg( + plugins[InteractiveSeg.name] = InteractiveSeg( interactive_seg_model, interactive_seg_device ) if enable_remove_bg: logger.info(f"Initialize {RemoveBG.name} plugin") - global_config.plugins[RemoveBG.name] = RemoveBG() + plugins[RemoveBG.name] = RemoveBG() if enable_anime_seg: logger.info(f"Initialize {AnimeSeg.name} plugin") - global_config.plugins[AnimeSeg.name] = AnimeSeg() + plugins[AnimeSeg.name] = AnimeSeg() if enable_realesrgan: logger.info( f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}" ) - global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( realesrgan_model, realesrgan_device, no_half=no_half, @@ -57,14 +57,15 @@ def build_plugins( logger.info( f"GFPGAN no background upscaler, use --enable-realesrgan to enable it" ) - global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin( + plugins[GFPGANPlugin.name] = GFPGANPlugin( gfpgan_device, - upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), + upscaler=plugins.get(RealESRGANUpscaler.name, None), ) if enable_restoreformer: logger.info(f"Initialize {RestoreFormerPlugin.name} plugin") - global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( + plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin( restoreformer_device, - upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None), + upscaler=plugins.get(RealESRGANUpscaler.name, None), ) + return plugins diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 3478cb6..acdee85 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -1,8 +1,17 @@ +import random from enum import Enum -from typing import Optional +from pathlib import Path +from typing import Optional, Literal, List from PIL.Image import Image -from pydantic import BaseModel +from pydantic import BaseModel, Field, validator, field_validator + +from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel + + +class CV2Flag(str, Enum): + INPAINT_NS = "INPAINT_NS" + INPAINT_TELEA = "INPAINT_TELEA" class ModelType(str, Enum): @@ -56,93 +65,215 @@ class PowerPaintTask(str, Enum): outpainting = "outpainting" -class Config(BaseModel): - class Config: - arbitrary_types_allowed = True +class ApiConfig(BaseModel): + host: str + port: int + model: str + no_half: bool + cpu_offload: bool + disable_nsfw_checker: bool + cpu_textencoder: bool + device: Device + gui: bool + disable_model_switch: bool + input: Path + output_dir: Path + quality: int + enable_interactive_seg: bool + interactive_seg_model: InteractiveSegModel + interactive_seg_device: Device + enable_remove_bg: bool + enable_anime_seg: bool + enable_realesrgan: bool + realesrgan_device: Device + realesrgan_model: RealESRGANModel + enable_gfpgan: bool + gfpgan_device: Device + enable_restoreformer: bool + restoreformer_device: Device - # Configs for ldm model - ldm_steps: int = 20 - ldm_sampler: str = LDMSampler.plms - # Configs for zits model - zits_wireframe: bool = True +class InpaintRequest(BaseModel): + image: Optional[str] = Field(..., description="base64 encoded image") + mask: Optional[str] = Field(..., description="base64 encoded mask") - # Configs for High Resolution Strategy(different way to preprocess image) - hd_strategy: str = HDStrategy.CROP # See HDStrategy Enum - hd_strategy_crop_margin: int = 128 - # If the longer side of the image is larger than this value, use crop strategy - hd_strategy_crop_trigger_size: int = 800 - hd_strategy_resize_limit: int = 1280 + ldm_steps: int = Field(20, description="Steps for ldm model.") + ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.") + zits_wireframe: bool = Field(True, description="Enable wireframe for zits model.") - # Configs for Stable Diffusion 1.5 - prompt: str = "" - negative_prompt: str = "" - # Crop image to this size before doing sd inpainting - # The value is always on the original image scale - use_croper: bool = False - croper_x: int = None - croper_y: int = None - croper_height: int = None - croper_width: int = None - use_extender: bool = False - extender_x: int = None - extender_y: int = None - extender_height: int = None - extender_width: int = None + hd_strategy: str = Field( + HDStrategy.CROP, + description="Different way to preprocess image, only used by erase models(e.g. lama/mat)", + ) + hd_strategy_crop_trigger_size: int = Field( + 800, + description="Crop trigger size for hd_strategy=CROP, if the longer side of the image is larger than this value, use crop strategy", + ) + hd_strategy_crop_margin: int = Field( + 128, description="Crop margin for hd_strategy=CROP" + ) + hd_strategy_resize_limit: int = Field( + 1280, description="Resize limit for hd_strategy=RESIZE" + ) - # Resize the image before doing sd inpainting, the area outside the mask will not lose quality. - # Used by sd models and paint_by_example model - sd_scale: float = 1.0 - # Blur the edge of mask area. The higher the number the smoother blend with the original image - sd_mask_blur: int = 0 - # Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - # starting point and more noise is added the higher the `strength`. The number of denoising steps depends - # on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - # process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - # essentially ignores `image`. - sd_strength: float = 1.0 - # The number of denoising steps. More denoising steps usually lead to a - # higher quality image at the expense of slower inference. - sd_steps: int = 50 - # Higher guidance scale encourages to generate images that are closely linked - # to the text prompt, usually at the expense of lower image quality. - sd_guidance_scale: float = 7.5 - sd_sampler: str = SDSampler.uni_pc - # -1 mean random seed - sd_seed: int = 42 - sd_match_histograms: bool = False + prompt: str = Field("", description="Prompt for diffusion models.") + negative_prompt: str = Field( + "", description="Negative prompt for diffusion models." + ) + use_croper: bool = Field( + False, description="Crop image before doing diffusion inpainting" + ) + croper_x: int = Field(0, description="Crop x for croper") + croper_y: int = Field(0, description="Crop y for croper") + croper_height: int = Field(512, description="Crop height for croper") + croper_width: int = Field(512, description="Crop width for croper") - # out-painting - sd_outpainting_softness: float = 20.0 - sd_outpainting_space: float = 20.0 + use_extender: bool = Field( + False, description="Extend image before doing sd outpainting" + ) + extender_x: int = Field(0, description="Extend x for extender") + extender_y: int = Field(0, description="Extend y for extender") + extender_height: int = Field(640, description="Extend height for extender") + extender_width: int = Field(640, description="Extend width for extender") - # freeu - sd_freeu: bool = False + sd_mask_blur: int = Field( + 33, + description="Blur the edge of mask area. The higher the number the smoother blend with the original image", + ) + sd_strength: float = Field( + 1.0, + description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image", + ) + sd_steps: int = Field( + 50, + description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.", + ) + sd_guidance_scale: float = Field( + 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.", + ) + sd_sampler: str = Field( + SDSampler.uni_pc, description="Sampler for diffusion model." + ) + sd_seed: int = Field( + 42, + description="Seed for diffusion model. -1 mean random seed", + validate_default=True, + ) + sd_match_histograms: bool = Field( + False, + description="Match histograms between inpainting area and original image.", + ) + + sd_outpainting_softness: float = Field(20.0) + sd_outpainting_space: float = Field(20.0) + + sd_freeu: bool = Field( + False, + description="Enable freeu mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu", + ) sd_freeu_config: FREEUConfig = FREEUConfig() - # lcm-lora - sd_lcm_lora: bool = False + sd_lcm_lora: bool = Field( + False, + description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage", + ) - # preserving the unmasked area at the expense of some more unnatural transitions between the masked and unmasked areas. - sd_prevent_unmasked_area: bool = True + sd_keep_unmasked_area: bool = Field( + True, description="Keep unmasked area unchanged" + ) - # Configs for opencv inpainting - # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 - cv2_flag: str = "INPAINT_NS" - cv2_radius: int = 4 + cv2_flag: CV2Flag = Field( + CV2Flag.INPAINT_NS, + description="Flag for opencv inpainting: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07", + ) + cv2_radius: int = Field( + 4, + description="Radius of a circular neighborhood of each point inpainted that is considered by the algorithm", + ) # Paint by Example - paint_by_example_example_image: Optional[Image] = None + paint_by_example_example_image: Optional[str] = Field( + None, description="Base64 encoded example image for paint by example model" + ) # InstructPix2Pix - p2p_image_guidance_scale: float = 1.5 + p2p_image_guidance_scale: float = Field(1.5, description="Image guidance scale") # ControlNet - enable_controlnet: bool = False - controlnet_conditioning_scale: float = 0.4 - controlnet_method: str = "lllyasviel/control_v11p_sd15_canny" + enable_controlnet: bool = Field(False, description="Enable controlnet") + controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale") + controlnet_method: str = Field( + "lllyasviel/control_v11p_sd15_canny", description="Controlnet method" + ) # PowerPaint - powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided - # control the fitting degree of the generated objects to the mask shape. - fitting_degree: float = 1.0 + powerpaint_task: PowerPaintTask = Field( + PowerPaintTask.text_guided, description="PowerPaint task" + ) + fitting_degree: float = Field( + 1.0, + description="Control the fitting degree of the generated objects to the mask shape.", + ) + + @field_validator("sd_seed") + @classmethod + def sd_seed_validator(cls, v: int) -> int: + if v == -1: + return random.randint(1, 99999999) + return v + + +class RunPluginRequest(BaseModel): + name: str + image: Optional[str] = Field(..., description="base64 encoded image") + clicks: List[List[int]] = Field( + [], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]" + ) + scale: float = Field(2.0, description="Scale for upscaling") + + +MediaTab = Literal["input", "output"] + + +class MediasRequest(BaseModel): + tab: MediaTab + + +class MediasResponse(BaseModel): + name: str + height: int + width: int + ctime: float + mtime: float + + +class MediaFileRequest(BaseModel): + tab: MediaTab + filename: str + + +class MediaThumbnailFileRequest(BaseModel): + tab: MediaTab + filename: str + width: int = 0 + height: int = 0 + + +class GenInfoResponse(BaseModel): + prompt: str = "" + negative_prompt: str = "" + + +class ServerConfigResponse(BaseModel): + plugins: List[str] + enableFileManager: bool + enableAutoSaving: bool + enableControlnet: bool + controlnetMethod: Optional[str] + disableModelSwitch: bool + isDesktop: bool + + +class SwitchModelRequest(BaseModel): + name: str diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 9451349..2c1d6ed 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,16 +1,28 @@ #!/usr/bin/env python3 +import multiprocessing import os +import cv2 + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +NUM_THREADS = str(multiprocessing.cpu_count()) +cv2.setNumThreads(NUM_THREADS) + +# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + +os.environ["OMP_NUM_THREADS"] = NUM_THREADS +os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS +os.environ["MKL_NUM_THREADS"] = NUM_THREADS +os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS +os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS + import hashlib import traceback from dataclasses import dataclass - -import imghdr import io -import logging -import multiprocessing import random import time from pathlib import Path @@ -21,6 +33,11 @@ import torch from PIL import Image from loguru import logger +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse + from lama_cleaner.const import * from lama_cleaner.file_manager import FileManager from lama_cleaner.model.utils import torch_gc @@ -31,8 +48,15 @@ from lama_cleaner.plugins import ( AnimeSeg, build_plugins, ) -from lama_cleaner.schema import Config - +from lama_cleaner.schema import InpaintRequest +from lama_cleaner.helper import ( + load_img, + numpy_to_bytes, + resize_max_size, + pil_to_bytes, + is_mac, + get_image_ext, concat_alpha_channel, +) try: torch._C._jit_override_can_fuse_on_cpu(False) @@ -42,454 +66,23 @@ try: except: pass -from flask import ( - Flask, - request, - send_file, - cli, - make_response, - send_from_directory, - jsonify, + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) -from flask_socketio import SocketIO - -# Disable ability for Flask to display warning about using a development server in a production environment. -# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 -cli.show_server_banner = lambda *_: None -from flask_cors import CORS - -from lama_cleaner.helper import ( - load_img, - numpy_to_bytes, - resize_max_size, - pil_to_bytes, - is_mac, -) - -NUM_THREADS = str(multiprocessing.cpu_count()) - -# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56 -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" - -os.environ["OMP_NUM_THREADS"] = NUM_THREADS -os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS -os.environ["MKL_NUM_THREADS"] = NUM_THREADS -os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS -os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS -if os.environ.get("CACHE_DIR"): - os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] - BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") - -class NoFlaskwebgui(logging.Filter): - def filter(self, record): - msg = record.getMessage() - if "Running on http:" in msg: - print(msg[msg.index("Running on http:") :]) - - return ( - "flaskwebgui-keep-server-alive" not in msg - and "socket.io" not in msg - and "This is a development server." not in msg - ) - - -logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) - -app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) -app.config["JSON_AS_ASCII"] = False -CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"]) - -sio_logger = logging.getLogger("sio-logger") -sio_logger.setLevel(logging.ERROR) -socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") - - -@dataclass -class GlobalConfig: - model_manager: ModelManager = None - file_manager: FileManager = None - output_dir: Path = None - input_image_path: Path = None - disable_model_switch: bool = False - is_desktop: bool = False - image_quality: int = 95 - plugins = {} - - @property - def enable_auto_saving(self) -> bool: - return self.output_dir is not None - - @property - def enable_file_manager(self) -> bool: - return self.file_manager is not None - - global_config = GlobalConfig() - -def get_image_ext(img_bytes): - w = imghdr.what("", img_bytes) - if w is None: - w = "jpeg" - return w - - def diffuser_callback(i, t, latents): socketio.emit("diffusion_progress", {"step": i}) -@app.route("/save_image", methods=["POST"]) -def save_image(): - if global_config.output_dir is None: - return "--output-dir is None", 500 - - input = request.files - filename = request.form["filename"] - origin_image_bytes = input["image"].read() # RGB - # ext = get_image_ext(origin_image_bytes) - ext = "png" - image, alpha_channel, infos = load_img(origin_image_bytes, return_info=True) - save_path = (global_config.output_dir / filename).with_suffix(f".{ext}") - - if alpha_channel is not None: - if alpha_channel.shape[:2] != image.shape[:2]: - alpha_channel = cv2.resize( - alpha_channel, dsize=(image.shape[1], image.shape[0]) - ) - image = np.concatenate((image, alpha_channel[:, :, np.newaxis]), axis=-1) - - pil_image = Image.fromarray(image).convert("RGBA") - - img_bytes = pil_to_bytes( - pil_image, - ext, - quality=global_config.image_quality, - infos=infos, - ) - try: - with open(save_path, "wb") as fw: - fw.write(img_bytes) - except: - return f"Save image failed: {traceback.format_exc()}", 500 - - return "ok", 200 - - -@app.route("/medias/") -def medias(tab): - if tab == "image": - response = make_response(jsonify(global_config.file_manager.media_names), 200) - else: - response = make_response( - jsonify(global_config.file_manager.output_media_names), 200 - ) - # response.last_modified = thumb.modified_time[tab] - # response.cache_control.no_cache = True - # response.cache_control.max_age = 0 - # response.make_conditional(request) - return response - - -@app.route("/media//") -def media_file(tab, filename): - if tab == "image": - return send_from_directory(global_config.file_manager.root_directory, filename) - return send_from_directory(global_config.file_manager.output_dir, filename) - - -@app.route("/media_thumbnail//") -def media_thumbnail_file(tab, filename): - args = request.args - width = args.get("width") - height = args.get("height") - if width is None and height is None: - width = 256 - if width: - width = int(float(width)) - if height: - height = int(float(height)) - - directory = global_config.file_manager.root_directory - if tab == "output": - directory = global_config.file_manager.output_dir - thumb_filename, (width, height) = global_config.file_manager.get_thumbnail( - directory, filename, width, height - ) - thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}" - - response = make_response(send_file(thumb_filepath)) - response.headers["X-Width"] = str(width) - response.headers["X-Height"] = str(height) - return response - - -@app.route("/inpaint", methods=["POST"]) -def process(): - input = request.files - # RGB - origin_image_bytes = input["image"].read() - image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_info=True) - - mask, _ = load_img(input["mask"].read(), gray=True) - mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] - - if image.shape[:2] != mask.shape[:2]: - return ( - f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", - 400, - ) - - original_shape = image.shape - interpolation = cv2.INTER_CUBIC - - form = request.form - size_limit = max(image.shape) - - if "paintByExampleImage" in input: - paint_by_example_example_image, _ = load_img( - input["paintByExampleImage"].read() - ) - paint_by_example_example_image = Image.fromarray(paint_by_example_example_image) - else: - paint_by_example_example_image = None - - config = Config( - ldm_steps=form["ldmSteps"], - ldm_sampler=form["ldmSampler"], - hd_strategy=form["hdStrategy"], - zits_wireframe=form["zitsWireframe"], - hd_strategy_crop_margin=form["hdStrategyCropMargin"], - hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"], - hd_strategy_resize_limit=form["hdStrategyResizeLimit"], - prompt=form["prompt"], - negative_prompt=form["negativePrompt"], - use_croper=form["useCroper"], - croper_x=form["croperX"], - croper_y=form["croperY"], - croper_height=form["croperHeight"], - croper_width=form["croperWidth"], - use_extender=form["useExtender"], - extender_x=form["extenderX"], - extender_y=form["extenderY"], - extender_height=form["extenderHeight"], - extender_width=form["extenderWidth"], - sd_scale=form["sdScale"], - sd_mask_blur=form["sdMaskBlur"], - sd_strength=form["sdStrength"], - sd_steps=form["sdSteps"], - sd_guidance_scale=form["sdGuidanceScale"], - sd_sampler=form["sdSampler"], - sd_seed=form["sdSeed"], - sd_freeu=form["enableFreeu"], - sd_freeu_config=json.loads(form["freeuConfig"]), - sd_lcm_lora=form["enableLCMLora"], - sd_match_histograms=form["sdMatchHistograms"], - cv2_flag=form["cv2Flag"], - cv2_radius=form["cv2Radius"], - paint_by_example_example_image=paint_by_example_example_image, - p2p_image_guidance_scale=form["p2pImageGuidanceScale"], - enable_controlnet=form["enable_controlnet"], - controlnet_conditioning_scale=form["controlnet_conditioning_scale"], - controlnet_method=form["controlnet_method"], - powerpaint_task=form["powerpaintTask"], - ) - - if config.sd_seed == -1: - config.sd_seed = random.randint(1, 99999999) - - logger.info(f"Origin image shape: {original_shape}") - image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) - - mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) - - start = time.time() - try: - res_np_img = global_config.model_manager(image, mask, config) - except RuntimeError as e: - if "CUDA out of memory. " in str(e): - # NOTE: the string may change? - return "CUDA out of memory", 500 - elif "Invalid buffer size" in str(e) and is_mac(): - return "Out of memory", 500 - else: - logger.exception(e) - return f"{str(e)}", 500 - finally: - logger.info(f"process time: {(time.time() - start) * 1000}ms") - torch_gc() - - res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) - if alpha_channel is not None: - if alpha_channel.shape[:2] != res_np_img.shape[:2]: - alpha_channel = cv2.resize( - alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) - ) - res_np_img = np.concatenate( - (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 - ) - - ext = get_image_ext(origin_image_bytes) - - bytes_io = io.BytesIO( - pil_to_bytes( - Image.fromarray(res_np_img), - ext, - quality=global_config.image_quality, - infos=exif_infos, - ) - ) - - response = make_response( - send_file( - # io.BytesIO(numpy_to_bytes(res_np_img, ext)), - bytes_io, - mimetype=f"image/{ext}", - ) - ) - response.headers["X-Seed"] = str(config.sd_seed) - - socketio.emit("diffusion_finish") - return response - - -@app.route("/run_plugin", methods=["POST"]) -def run_plugin(): - form = request.form - files = request.files - name = form["name"] - if name not in global_config.plugins: - return "Plugin not found", 500 - - origin_image_bytes = files["image"].read() # RGB - rgb_np_img, alpha_channel, infos = load_img( - origin_image_bytes, return_info=True - ) - - start = time.time() - try: - form = dict(form) - if name == InteractiveSeg.name: - img_md5 = hashlib.md5(origin_image_bytes).hexdigest() - form["img_md5"] = img_md5 - bgr_res = global_config.plugins[name](rgb_np_img, files, form) - except RuntimeError as e: - torch.cuda.empty_cache() - if "CUDA out of memory. " in str(e): - # NOTE: the string may change? - return "CUDA out of memory", 500 - else: - logger.exception(e) - return "Internal Server Error", 500 - - logger.info(f"{name} process time: {(time.time() - start) * 1000}ms") - torch_gc() - - if name == InteractiveSeg.name: - return make_response( - send_file( - io.BytesIO(numpy_to_bytes(bgr_res, "png")), - mimetype="image/png", - ) - ) - - if name in [RemoveBG.name, AnimeSeg.name]: - rgb_res = bgr_res - ext = "png" - else: - rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB) - ext = get_image_ext(origin_image_bytes) - if alpha_channel is not None: - if alpha_channel.shape[:2] != rgb_res.shape[:2]: - alpha_channel = cv2.resize( - alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0]) - ) - rgb_res = np.concatenate( - (rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1 - ) - - response = make_response( - send_file( - io.BytesIO( - pil_to_bytes( - Image.fromarray(rgb_res), - ext, - quality=global_config.image_quality, - infos=infos, - ) - ), - mimetype=f"image/{ext}", - ) - ) - return response - - -@app.route("/server_config", methods=["GET"]) -def get_server_config(): - return { - "plugins": list(global_config.plugins.keys()), - "enableFileManager": global_config.enable_file_manager, - "enableAutoSaving": global_config.enable_auto_saving, - "enableControlnet": global_config.model_manager.enable_controlnet, - "controlnetMethod": global_config.model_manager.controlnet_method, - "disableModelSwitch": global_config.disable_model_switch, - "isDesktop": global_config.is_desktop, - }, 200 - - -@app.route("/models", methods=["GET"]) -def get_models(): - return [it.model_dump() for it in global_config.model_manager.scan_models()] - - -@app.route("/model") -def current_model(): - return ( - global_config.model_manager.current_model, - 200, - ) - - -@app.route("/model", methods=["POST"]) -def switch_model(): - if global_config.disable_model_switch: - return "Switch model is disabled", 400 - - new_name = request.form.get("name") - if new_name == global_config.model_manager.name: - return "Same model", 200 - - try: - global_config.model_manager.switch(new_name) - except Exception as e: - traceback.print_exc() - error_message = f"{type(e).__name__} - {str(e)}" - logger.error(error_message) - return f"Switch model failed: {error_message}", 500 - return f"ok, switch to {new_name}", 200 - - -@app.route("/") -def index(): - return send_file(os.path.join(BUILD_DIR, "index.html")) - - -@app.route("/inputimage") -def get_cli_input_image(): - if global_config.input_image_path: - with open(global_config.input_image_path, "rb") as f: - image_in_bytes = f.read() - return send_file( - global_config.input_image_path, - as_attachment=True, - download_name=Path(global_config.input_image_path).name, - mimetype=f"image/{get_image_ext(image_in_bytes)}", - ) - else: - return "No Input Image" - - def start( host: str, port: int, diff --git a/lama_cleaner/tests/test_model_switch.py b/lama_cleaner/tests/test_model_switch.py index 6f732a1..588bf82 100644 --- a/lama_cleaner/tests/test_model_switch.py +++ b/lama_cleaner/tests/test_model_switch.py @@ -1,6 +1,6 @@ import os -from lama_cleaner.schema import Config +from lama_cleaner.schema import InpaintRequest os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -38,7 +38,7 @@ def test_controlnet_switch_onoff(caplog): ) model.switch_controlnet_method( - Config( + InpaintRequest( name=name, enable_controlnet=False, ) @@ -63,7 +63,7 @@ def test_switch_controlnet_method(caplog): ) model.switch_controlnet_method( - Config( + InpaintRequest( name=name, enable_controlnet=True, controlnet_method=new_method, diff --git a/lama_cleaner/tests/test_save_exif.py b/lama_cleaner/tests/test_save_exif.py index 36f48a5..31e76ca 100644 --- a/lama_cleaner/tests/test_save_exif.py +++ b/lama_cleaner/tests/test_save_exif.py @@ -1,4 +1,5 @@ import io +import tempfile from pathlib import Path from typing import List @@ -18,9 +19,9 @@ def extra_info(img_p: Path): ext = img_p.suffix.strip(".") img_bytes = img_p.read_bytes() np_img, _, infos = load_img(img_bytes, False, True) - pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos) - res_img = Image.open(io.BytesIO(pil_bytes)) - return infos, res_img.info + res_pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos) + res_img = Image.open(io.BytesIO(res_pil_bytes)) + return infos, res_img.info, res_pil_bytes def assert_keys(keys: List[str], infos, res_infos): @@ -30,23 +31,29 @@ def assert_keys(keys: List[str], infos, res_infos): assert infos[k] == res_infos[k] +def run_test(file_path, keys): + infos, res_infos, res_pil_bytes = extra_info(file_path) + assert_keys(keys, infos, res_infos) + with tempfile.NamedTemporaryFile("wb", suffix=file_path.suffix) as temp_file: + temp_file.write(res_pil_bytes) + temp_file.flush() + infos, res_infos, res_pil_bytes = extra_info(Path(temp_file.name)) + assert_keys(keys, infos, res_infos) + + def test_png_icc_profile_png(): - infos, res_infos = extra_info(current_dir / "icc_profile_test.png") - assert_keys(["icc_profile", "exif"], infos, res_infos) + run_test(current_dir / "icc_profile_test.png", ["icc_profile", "exif"]) def test_png_icc_profile_jpeg(): - infos, res_infos = extra_info(current_dir / "icc_profile_test.jpg") - assert_keys(["icc_profile", "exif"], infos, res_infos) + run_test(current_dir / "icc_profile_test.jpg", ["icc_profile", "exif"]) def test_jpeg(): jpg_img_p = current_dir / "bunny.jpeg" - infos, res_infos = extra_info(jpg_img_p) - assert_keys(["dpi", "exif"], infos, res_infos) + run_test(jpg_img_p, ["dpi", "exif"]) def test_png_parameter(): jpg_img_p = current_dir / "png_parameter_test.png" - infos, res_infos = extra_info(jpg_img_p) - assert_keys(["parameters"], infos, res_infos) + run_test(jpg_img_p, ["parameters"]) diff --git a/lama_cleaner/tests/utils.py b/lama_cleaner/tests/utils.py index b84cd54..00ebdf1 100644 --- a/lama_cleaner/tests/utils.py +++ b/lama_cleaner/tests/utils.py @@ -3,7 +3,7 @@ import cv2 import pytest import torch -from lama_cleaner.schema import LDMSampler, HDStrategy, Config, SDSampler +from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler current_dir = Path(__file__).parent.absolute().resolve() save_dir = current_dir / "result" @@ -72,4 +72,4 @@ def get_config(**kwargs): hd_strategy_resize_limit=200, ) data.update(**kwargs) - return Config(**data) + return InpaintRequest(**data) diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index 5107d27..49f724d 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -43,7 +43,7 @@ def save_config( restoreformer_device, enable_gif, ): - config = Config(**locals()) + config = InpaintRequest(**locals()) print(config) if config.input and not os.path.exists(config.input): return "[Error] Input file or directory does not exist" diff --git a/publish.sh b/publish.sh index e42d5f5..0e76da0 100644 --- a/publish.sh +++ b/publish.sh @@ -1,8 +1,8 @@ #!/usr/bin/env bash set -e -pushd ./lama_cleaner/app -yarn run build +pushd ./web_app +npm run build popd rm -r -f dist diff --git a/requirements.txt b/requirements.txt index 7f6e668..461b2e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch>=2.0.0 typer opencv-python -flask==2.2.3 -flask-socketio +fastapi==0.108.0 +python-multipart simple-websocket flask_cors flaskwebgui==0.3.5 diff --git a/web_app/src/hooks/useInputImage.tsx b/web_app/src/hooks/useInputImage.tsx index ba9a66e..bd4be4d 100644 --- a/web_app/src/hooks/useInputImage.tsx +++ b/web_app/src/hooks/useInputImage.tsx @@ -9,21 +9,28 @@ export default function useInputImage() { headers.append("pragma", "no-cache") headers.append("cache-control", "no-cache") - fetch(`${API_ENDPOINT}/inputimage`, { headers }).then(async (res) => { - const filename = res.headers - .get("content-disposition") - ?.split("filename=")[1] - .split(";")[0] + fetch(`${API_ENDPOINT}/inputimage`, { headers }) + .then(async (res) => { + if (!res.ok) { + throw new Error("No input image found") + } + const filename = res.headers + .get("content-disposition") + ?.split("filename=")[1] + .split(";")[0] - const data = await res.blob() - if (data && data.type.startsWith("image")) { - const userInput = new File( - [data], - filename !== undefined ? filename : "inputImage" - ) - setInputImage(userInput) - } - }) + const data = await res.blob() + if (data && data.type.startsWith("image")) { + const userInput = new File( + [data], + filename !== undefined ? filename : "inputImage" + ) + setInputImage(userInput) + } + }) + .catch((err) => { + console.log(err) + }) }, [setInputImage]) useEffect(() => { diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 8e13910..09e41e0 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -1,11 +1,11 @@ import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types" import { Settings } from "@/lib/states" -import { srcToFile } from "@/lib/utils" -import axios from "axios" +import { convertToBase64, srcToFile } from "@/lib/utils" +import axios, { AxiosError } from "axios" export const API_ENDPOINT = import.meta.env.VITE_BACKEND ? import.meta.env.VITE_BACKEND - : "" + : "/api/v1" const api = axios.create({ baseURL: API_ENDPOINT, @@ -19,96 +19,75 @@ export default async function inpaint( mask: File | Blob, paintByExampleImage: File | null = null ) { - const fd = new FormData() - fd.append("image", imageFile) - fd.append("mask", mask) - fd.append("ldmSteps", settings.ldmSteps.toString()) - fd.append("ldmSampler", settings.ldmSampler.toString()) - fd.append("zitsWireframe", settings.zitsWireframe.toString()) - fd.append("hdStrategy", "Crop") - fd.append("hdStrategyCropMargin", "128") - fd.append("hdStrategyCropTrigerSize", "640") - fd.append("hdStrategyResizeLimit", "2048") - - fd.append("prompt", settings.prompt) - fd.append("negativePrompt", settings.negativePrompt) - - fd.append("useCroper", settings.showCropper ? "true" : "false") - fd.append("croperX", croperRect.x.toString()) - fd.append("croperY", croperRect.y.toString()) - fd.append("croperHeight", croperRect.height.toString()) - fd.append("croperWidth", croperRect.width.toString()) - - fd.append("useExtender", settings.showExtender ? "true" : "false") - fd.append("extenderX", extenderState.x.toString()) - fd.append("extenderY", extenderState.y.toString()) - fd.append("extenderHeight", extenderState.height.toString()) - fd.append("extenderWidth", extenderState.width.toString()) - - fd.append("sdMaskBlur", settings.sdMaskBlur.toString()) - fd.append("sdStrength", settings.sdStrength.toString()) - fd.append("sdSteps", settings.sdSteps.toString()) - fd.append("sdGuidanceScale", settings.sdGuidanceScale.toString()) - fd.append("sdSampler", settings.sdSampler.toString()) - - if (settings.seedFixed) { - fd.append("sdSeed", settings.seed.toString()) - } else { - fd.append("sdSeed", "-1") - } - - fd.append("sdMatchHistograms", settings.sdMatchHistograms ? "true" : "false") - fd.append("sdScale", (settings.sdScale / 100).toString()) - fd.append("enableFreeu", settings.enableFreeu.toString()) - fd.append("freeuConfig", JSON.stringify(settings.freeuConfig)) - fd.append("enableLCMLora", settings.enableLCMLora.toString()) - - fd.append("cv2Radius", settings.cv2Radius.toString()) - fd.append("cv2Flag", settings.cv2Flag.toString()) - - // TODO: resize image's shortest_edge to 224 before pass to backend, save network time? - // https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor - if (paintByExampleImage) { - fd.append("paintByExampleImage", paintByExampleImage) - } - - // InstructPix2Pix - fd.append("p2pImageGuidanceScale", settings.p2pImageGuidanceScale.toString()) - - // ControlNet - fd.append("enable_controlnet", settings.enableControlnet.toString()) - fd.append( - "controlnet_conditioning_scale", - settings.controlnetConditioningScale.toString() - ) - fd.append("controlnet_method", settings.controlnetMethod?.toString()) - - // PowerPaint - if (settings.showExtender) { - fd.append("powerpaintTask", PowerPaintTask.outpainting) - } else { - fd.append("powerpaintTask", settings.powerpaintTask) - } - + const imageBase64 = await convertToBase64(imageFile) + const maskBase64 = await convertToBase64(mask) + const exampleImageBase64 = paintByExampleImage + ? await convertToBase64(paintByExampleImage) + : null try { const res = await fetch(`${API_ENDPOINT}/inpaint`, { method: "POST", - body: fd, + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + image: imageBase64, + mask: maskBase64, + ldm_steps: settings.ldmSteps, + ldm_sampler: settings.ldmSampler, + zits_wireframe: settings.zitsWireframe, + cv2_flag: settings.cv2Flag, + cv2_radius: settings.cv2Radius, + hd_strategy: "Crop", + hd_strategy_crop_triger_size: 640, + hd_strategy_crop_margin: 128, + hd_trategy_resize_imit: 2048, + prompt: settings.prompt, + negative_prompt: settings.negativePrompt, + use_croper: settings.showCropper, + croper_x: croperRect.x, + croper_y: croperRect.y, + croper_height: croperRect.height, + croper_width: croperRect.width, + use_extender: settings.showExtender, + extender_x: extenderState.x, + extender_y: extenderState.y, + extender_height: extenderState.height, + extender_width: extenderState.width, + sd_mask_blur: settings.sdMaskBlur, + sd_strength: settings.sdStrength, + sd_steps: settings.sdSteps, + sd_guidance_scale: settings.sdGuidanceScale, + sd_sampler: settings.sdSampler, + sd_seed: settings.seedFixed ? settings.seed : -1, + sd_match_histograms: settings.sdMatchHistograms, + sd_freeu: settings.enableFreeu, + sd_freeu_config: settings.freeuConfig, + sd_lcm_lora: settings.enableLCMLora, + paint_by_example_example_image: exampleImageBase64, + p2p_image_guidance_scale: settings.p2pImageGuidanceScale, + enable_controlnet: settings.enableControlnet, + controlnet_conditioning_scale: settings.controlnetConditioningScale, + controlnet_method: settings.controlnetMethod + ? settings.controlnetMethod + : "", + powerpaint_task: settings.showExtender + ? PowerPaintTask.outpainting + : settings.powerpaintTask, + }), }) - if (res.ok) { - const blob = await res.blob() - const newSeed = res.headers.get("X-seed") - return { blob: URL.createObjectURL(blob), seed: newSeed } + const blob = await res.blob() + return { + blob: URL.createObjectURL(blob), + seed: res.headers.get("X-Seed"), } - const errMsg = await res.text() - throw new Error(errMsg) - } catch (error) { - throw new Error(`Something went wrong: ${error}`) + } catch (error: any) { + throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`) } } export function getServerConfig() { - return fetch(`${API_ENDPOINT}/server_config`, { + return fetch(`${API_ENDPOINT}/server-config`, { method: "GET", }) } diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 55dc2a1..5cfbf8e 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -491,10 +491,6 @@ export const useStore = createWithEqualityFn()( paintByExampleFile ) - if (!res) { - throw new Error("Something went wrong on server side.") - } - const { blob, seed } = res if (seed) { get().setSeed(parseInt(seed, 10)) diff --git a/web_app/src/lib/utils.ts b/web_app/src/lib/utils.ts index 4c3e1fb..e2422a3 100644 --- a/web_app/src/lib/utils.ts +++ b/web_app/src/lib/utils.ts @@ -223,3 +223,17 @@ export const generateMask = ( return maskCanvas } + +export const convertToBase64 = (fileOrBlob: File | Blob): Promise => { + return new Promise((resolve, reject) => { + const reader = new FileReader() + reader.onload = (event) => { + const base64String = event.target?.result as string + resolve(base64String) + } + reader.onerror = (error) => { + reject(error) + } + reader.readAsDataURL(fileOrBlob) + }) +}