diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 9eb77ee..babbeac 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -5,7 +5,7 @@ from typing import List, Optional from urllib.parse import urlparse import cv2 -from PIL import Image, ImageOps +from PIL import Image, ImageOps, PngImagePlugin import numpy as np import torch from lama_cleaner.const import MPS_SUPPORT_MODELS @@ -135,9 +135,20 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: return image_bytes -def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif=None) -> bytes: +def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif_infos={}) -> bytes: with io.BytesIO() as output: - pil_img.save(output, format=ext, exif=exif, quality=quality) + kwargs = {k: v for k, v in exif_infos.items() if v is not None} + if ext == "png" and "parameters" in kwargs: + pnginfo_data = PngImagePlugin.PngInfo() + pnginfo_data.add_text("parameters", kwargs["parameters"]) + kwargs["pnginfo"] = pnginfo_data + + pil_img.save( + output, + format=ext, + quality=quality, + **kwargs, + ) image_bytes = output.getvalue() return image_bytes @@ -146,12 +157,9 @@ def load_img(img_bytes, gray: bool = False, return_exif: bool = False): alpha_channel = None image = Image.open(io.BytesIO(img_bytes)) - try: - if return_exif: - exif = image.getexif() - except: - exif = None - logger.error("Failed to extract exif from image") + if return_exif: + info = image.info or {} + exif_infos = {"exif": image.getexif(), "parameters": info.get("parameters")} try: image = ImageOps.exif_transpose(image) @@ -171,7 +179,7 @@ def load_img(img_bytes, gray: bool = False, return_exif: bool = False): np_img = np.array(image) if return_exif: - return np_img, alpha_channel, exif + return np_img, alpha_channel, exif_infos return np_img, alpha_channel diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index f6509fb..02f9711 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -90,7 +90,7 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app.config["JSON_AS_ASCII"] = False CORS(app, expose_headers=["Content-Disposition"]) -socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading') +socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") model: ModelManager = None thumb: FileManager = None @@ -114,7 +114,7 @@ def get_image_ext(img_bytes): def diffuser_callback(i, t, latents): - socketio.emit('diffusion_progress', {'step': i}) + socketio.emit("diffusion_progress", {"step": i}) @app.route("/save_image", methods=["POST"]) @@ -188,7 +188,7 @@ def process(): input = request.files # RGB origin_image_bytes = input["image"].read() - image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True) + image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True) mask, _ = load_img(input["mask"].read(), gray=True) mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] @@ -288,12 +288,14 @@ def process(): ext = get_image_ext(origin_image_bytes) - # fmt: off - if exif is not None: - bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality, exif=exif)) - else: - bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality)) - # fmt: on + bytes_io = io.BytesIO( + pil_to_bytes( + Image.fromarray(res_np_img), + ext, + quality=image_quality, + exif_infos=exif_infos, + ) + ) response = make_response( send_file( @@ -304,7 +306,7 @@ def process(): ) response.headers["X-Seed"] = str(config.sd_seed) - socketio.emit('diffusion_finish') + socketio.emit("diffusion_finish") return response @@ -317,7 +319,9 @@ def run_plugin(): return "Plugin not found", 500 origin_image_bytes = files["image"].read() # RGB - rgb_np_img, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True) + rgb_np_img, alpha_channel, exif_infos = load_img( + origin_image_bytes, return_exif=True + ) start = time.time() try: @@ -372,7 +376,10 @@ def run_plugin(): send_file( io.BytesIO( pil_to_bytes( - Image.fromarray(rgb_res), ext, quality=image_quality, exif=exif + Image.fromarray(rgb_res), + ext, + quality=image_quality, + exif_infos=exif_infos, ) ), mimetype=f"image/{ext}", diff --git a/lama_cleaner/tests/pnginfo_test.png b/lama_cleaner/tests/pnginfo_test.png new file mode 100644 index 0000000..dc18bce Binary files /dev/null and b/lama_cleaner/tests/pnginfo_test.png differ diff --git a/lama_cleaner/tests/test_save_exif.py b/lama_cleaner/tests/test_save_exif.py index c22aedc..5921256 100644 --- a/lama_cleaner/tests/test_save_exif.py +++ b/lama_cleaner/tests/test_save_exif.py @@ -3,12 +3,9 @@ from pathlib import Path from PIL import Image -from lama_cleaner.helper import pil_to_bytes - +from lama_cleaner.helper import pil_to_bytes, load_img current_dir = Path(__file__).parent.absolute().resolve() -png_img_p = current_dir / "image.png" -jpg_img_p = current_dir / "bunny.jpeg" def print_exif(exif): @@ -16,27 +13,31 @@ def print_exif(exif): print(f"{k}: {v}") -def test_png(): - img = Image.open(png_img_p) - exif = img.getexif() - print_exif(exif) +def run_test(img_p: Path): + print(img_p) + ext = img_p.suffix.strip(".") + img_bytes = img_p.read_bytes() + np_img, _, exif_infos = load_img(img_bytes, False, True) + print(exif_infos) + print("Original exif_infos") + print_exif(exif_infos["exif"]) - pil_bytes = pil_to_bytes(img, ext="png", exif=exif) + pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos={}) + pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos=exif_infos) res_img = Image.open(io.BytesIO(pil_bytes)) + print(f"Result img info: {res_img.info}") res_exif = res_img.getexif() + print_exif(res_exif) + assert res_exif == exif_infos["exif"] + assert exif_infos["parameters"] == res_img.info.get("parameters") - assert dict(exif) == dict(res_exif) + +def test_png(): + run_test(current_dir / "image.png") + run_test(current_dir / "pnginfo_test.png") def test_jpeg(): - img = Image.open(jpg_img_p) - exif = img.getexif() - print_exif(exif) - - pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif) - - res_img = Image.open(io.BytesIO(pil_bytes)) - res_exif = res_img.getexif() - - assert dict(exif) == dict(res_exif) + jpg_img_p = current_dir / "bunny.jpeg" + run_test(jpg_img_p)