IOPaint/lama_cleaner/file_manager/file_manager.py

223 lines
7.3 KiB
Python
Raw Normal View History

2022-12-31 14:07:08 +01:00
import os
from io import BytesIO
2023-01-05 15:07:39 +01:00
from pathlib import Path
2023-12-30 16:36:44 +01:00
from typing import List
2022-12-31 14:07:08 +01:00
from PIL import Image, ImageOps, PngImagePlugin
2023-12-30 16:36:44 +01:00
from fastapi import FastAPI, UploadFile, HTTPException
from starlette.responses import FileResponse
2024-01-01 09:05:34 +01:00
from ..schema import MediasResponse, MediaTab
2023-01-05 15:07:39 +01:00
2022-12-31 14:07:08 +01:00
# Raise Pillow's cap on the size of a single PNG text chunk to 100 MiB
# (LARGE_ENOUGH_NUMBER * 1 MiB); Pillow refuses to read larger chunks, so
# without this, PNGs carrying large text metadata could fail to open.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
2022-12-31 14:07:08 +01:00
from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img
2023-03-22 05:57:18 +01:00
class FileManager:
    """Serve, save and thumbnail the media files of the input/output folders.

    Registers the ``/api/v1/save_image``, ``/api/v1/medias``,
    ``/api/v1/media_file`` and ``/api/v1/media_thumbnail_file`` routes on the
    given FastAPI app. Generated thumbnails are cached under
    ``<output_dir>/thumbnails``.
    """

    def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
        """Store the directories, create the thumbnail cache and register routes.

        :param app: FastAPI application the API routes are attached to.
        :param input_dir: directory backing the ``"input"`` media tab.
        :param output_dir: directory backing the ``"output"`` media tab; the
            thumbnail cache lives in its ``thumbnails`` sub-directory.
        """
        self.app = app
        self.input_dir: Path = input_dir
        self.output_dir: Path = output_dir

        # NOTE(review): these lists are not read anywhere in this class; kept
        # in case other modules use them.
        self.image_dir_filenames = []
        self.output_dir_filenames = []

        # exist_ok avoids the exists()/mkdir() race of a separate check.
        self.thumbnail_directory.mkdir(parents=True, exist_ok=True)
        # fmt: off
        self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
        self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
        self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
        # fmt: on

    def api_save_image(self, file: UploadFile):
        """Write an uploaded file into ``output_dir``.

        Only the final path component of the client-supplied filename is used,
        so a name such as ``../../x.png`` cannot escape ``output_dir``.

        :raises HTTPException: 422 if the upload carries no usable filename.
        """
        # The filename is client-controlled: strip any directory components.
        filename = Path(file.filename or "").name
        if filename in ("", os.curdir, os.pardir):
            raise HTTPException(status_code=422, detail="invalid filename")
        origin_image_bytes = file.file.read()
        with open(self.output_dir / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
        """List the images of the directory selected by ``tab``."""
        img_dir = self._get_dir(tab)
        return self._media_names(img_dir)

    def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
        """Return the file ``filename`` from the directory selected by ``tab``."""
        file_path = self._get_file(tab, filename)
        return FileResponse(file_path, media_type="image/png")

    # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
    def api_media_thumbnail_file(
        self, tab: MediaTab, filename: str, width: int, height: int
    ) -> FileResponse:
        """Return a (cached) thumbnail of ``filename``.

        The actual thumbnail size is reported in the ``X-Width`` and
        ``X-Height`` response headers.

        :raises HTTPException: 422 for an unknown tab, a filename escaping the
            tab's directory, or a missing file.
        """
        img_dir = self._get_dir(tab)
        # Validate the client-supplied name (and that it exists) before any
        # thumbnail work reads from disk.
        self._get_file(tab, filename)
        thumb_filename, (width, height) = self.get_thumbnail(
            img_dir, filename, width=width, height=height
        )
        # get_thumbnail already returns a path that starts with
        # thumbnail_directory; joining it onto thumbnail_directory again
        # duplicated that prefix whenever output_dir is a relative path.
        thumbnail_filepath = Path(thumb_filename)
        return FileResponse(
            thumbnail_filepath,
            headers={
                "X-Width": str(width),
                "X-Height": str(height),
            },
            media_type="image/jpeg",
        )

    def _get_dir(self, tab: MediaTab) -> Path:
        """Map a media tab to its directory.

        :raises HTTPException: 422 when ``tab`` is neither input nor output.
        """
        if tab == "input":
            return self.input_dir
        elif tab == "output":
            return self.output_dir
        else:
            raise HTTPException(status_code=422, detail=f"tab not found: {tab}")

    def _get_file(self, tab: MediaTab, filename: str) -> Path:
        """Return the path of ``filename`` inside the directory for ``tab``.

        :raises HTTPException: 422 if ``tab`` is unknown, ``filename`` is
            absolute or contains ``..`` components that would leave the
            directory, or the file does not exist.
        """
        directory = self._get_dir(tab)
        # filename comes from the query string: reject anything that does not
        # stay lexically inside `directory` (absolute paths, drive-prefixed
        # paths, leading "..", or names that collapse to the directory itself).
        normalized = os.path.normpath(filename)
        if (
            os.path.isabs(normalized)
            or os.path.splitdrive(normalized)[0]
            or normalized in (os.curdir, os.pardir)
            or normalized.startswith(os.pardir + os.sep)
        ):
            raise HTTPException(status_code=422, detail=f"invalid filename: {filename}")
        file_path = directory / filename
        if not file_path.exists():
            raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
        return file_path

    @property
    def thumbnail_directory(self) -> Path:
        """Directory where generated thumbnails are cached."""
        return self.output_dir / "thumbnails"

    @staticmethod
    def _media_names(directory: Path) -> List[MediasResponse]:
        """Describe every image in ``directory`` (sorted by file name).

        Each entry carries the name, pixel size, and ctime/mtime of the file.
        """
        names = sorted([it.name for it in glob_img(directory)])
        res = []
        for name in names:
            path = os.path.join(directory, name)
            # Context manager closes the file handle Image.open keeps open.
            with Image.open(path) as img:
                height, width = img.height, img.width
            res.append(
                MediasResponse(
                    name=name,
                    height=height,
                    width=width,
                    ctime=os.path.getctime(path),
                    mtime=os.path.getmtime(path),
                )
            )
        return res

    def get_thumbnail(
        self, directory: Path, original_filename: str, width, height, **options
    ):
        """Return the path and size of a cached thumbnail, creating it if needed.

        The source aspect ratio is kept: whichever of ``width`` / ``height`` is
        given (truthy) determines the other; if neither is given the width
        defaults to 256.

        :param directory: directory containing ``original_filename``.
        :param original_filename: image path relative to ``directory``; its
            sub-directory part is mirrored under ``thumbnail_directory``.
        :param options: ``crop`` (default ``"fit"``), ``background``,
            ``quality`` (default 90) and ``format`` (defaults to the source
            image's format).
        :returns: ``(thumbnail_filepath, (width, height))``; the path already
            includes ``thumbnail_directory``.
        """
        import logging

        directory = Path(directory)
        storage = FilesystemStorageBackend(self.app)
        crop = options.get("crop", "fit")
        background = options.get("background")
        quality = options.get("quality", 90)
        original_path, original_filename = os.path.split(original_filename)
        original_filepath = os.path.join(directory, original_path, original_filename)
        image = Image.open(BytesIO(storage.read(original_filepath)))

        # Keep-ratio resize. Truthiness (not `!= 0`) also treats width=None as
        # "not given" instead of failing on None arithmetic.
        if not width and not height:
            width = 256
        if width:
            height = int(image.height * width / image.width)
        else:
            width = int(image.width * height / image.height)
        thumbnail_size = (width, height)
        thumbnail_filename = generate_filename(
            directory,
            original_filename,
            aspect_to_string(thumbnail_size),
            crop,
            background,
            quality,
        )
        thumbnail_filepath = os.path.join(
            self.thumbnail_directory, original_path, thumbnail_filename
        )
        if storage.exists(thumbnail_filepath):
            return thumbnail_filepath, (width, height)

        try:
            image.load()
        except OSError:
            # FastAPI applications have no `.logger` attribute (a Flask API),
            # so the original `self.app.logger` call raised AttributeError.
            logging.getLogger(__name__).warning(
                "Thumbnail not load image: %s", original_filepath
            )
            return thumbnail_filepath, (width, height)

        # get original image format
        options["format"] = options.get("format", image.format)
        image = self._create_thumbnail(
            image, thumbnail_size, crop, background=background
        )

        raw_data = self.get_raw_data(image, **options)
        storage.save(thumbnail_filepath, raw_data)

        return thumbnail_filepath, (width, height)

    def get_raw_data(self, image, **options):
        """Encode ``image`` to bytes using ``format`` and ``quality`` options."""
        data = {
            "format": self._get_format(image, **options),
            "quality": options.get("quality", 90),
        }
        _file = BytesIO()
        image.save(_file, **data)
        return _file.getvalue()

    @staticmethod
    def colormode(image, colormode="RGB"):
        """Convert ``image`` to ``colormode``.

        For "RGB"/"RGBA" requests, RGBA images are returned unchanged and LA
        images become RGBA (keeping alpha); "GRAY" maps to Pillow mode "L".
        """
        if colormode == "RGB" or colormode == "RGBA":
            if image.mode == "RGBA":
                return image
            if image.mode == "LA":
                return image.convert("RGBA")
            return image.convert(colormode)
        if colormode == "GRAY":
            return image.convert("L")
        return image.convert(colormode)

    @staticmethod
    def background(original_image, color=0xFF):
        """Center ``original_image`` on a square mode-"L" canvas filled with ``color``.

        The canvas side is the larger of the image's two dimensions.
        """
        size = (max(original_image.size),) * 2
        image = Image.new("L", size, color)
        # Integer offsets: Image.paste rejects float box coordinates.
        image.paste(
            original_image,
            tuple((outer - inner) // 2 for outer, inner in zip(size, original_image.size)),
        )
        return image

    def _get_format(self, image, **options):
        """Pick the output format: explicit option, then image's own, else JPEG."""
        if options.get("format"):
            return options.get("format")
        if image.format:
            return image.format
        return "JPEG"

    def _create_thumbnail(self, image, size, crop="fit", background=None):
        """Resize ``image`` to ``size``.

        ``crop == "fit"`` crops to fill ``size`` exactly; otherwise the image
        is shrunk to fit inside ``size`` keeping its ratio. The result is
        optionally placed on a square background and converted via colormode.
        """
        try:
            resample = Image.Resampling.LANCZOS
        except AttributeError:
            # Older Pillow releases without the Image.Resampling enum.
            resample = Image.ANTIALIAS
        if crop == "fit":
            image = ImageOps.fit(image, size, resample)
        else:
            image = image.copy()
            image.thumbnail(size, resample=resample)
        if background is not None:
            image = self.background(image)
        image = self.colormode(image)
        return image