fix pnginfo data

This commit is contained in:
Qing 2023-05-07 16:58:55 +08:00
parent 7841d63c90
commit 18a2498688
4 changed files with 58 additions and 42 deletions

View File

@ -5,7 +5,7 @@ from typing import List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import cv2 import cv2
from PIL import Image, ImageOps from PIL import Image, ImageOps, PngImagePlugin
import numpy as np import numpy as np
import torch import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS from lama_cleaner.const import MPS_SUPPORT_MODELS
@ -135,9 +135,20 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
return image_bytes return image_bytes
def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif=None) -> bytes: def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif_infos={}) -> bytes:
with io.BytesIO() as output: with io.BytesIO() as output:
pil_img.save(output, format=ext, exif=exif, quality=quality) kwargs = {k: v for k, v in exif_infos.items() if v is not None}
if ext == "png" and "parameters" in kwargs:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text("parameters", kwargs["parameters"])
kwargs["pnginfo"] = pnginfo_data
pil_img.save(
output,
format=ext,
quality=quality,
**kwargs,
)
image_bytes = output.getvalue() image_bytes = output.getvalue()
return image_bytes return image_bytes
@ -146,12 +157,9 @@ def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
alpha_channel = None alpha_channel = None
image = Image.open(io.BytesIO(img_bytes)) image = Image.open(io.BytesIO(img_bytes))
try:
if return_exif: if return_exif:
exif = image.getexif() info = image.info or {}
except: exif_infos = {"exif": image.getexif(), "parameters": info.get("parameters")}
exif = None
logger.error("Failed to extract exif from image")
try: try:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
@ -171,7 +179,7 @@ def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
np_img = np.array(image) np_img = np.array(image)
if return_exif: if return_exif:
return np_img, alpha_channel, exif return np_img, alpha_channel, exif_infos
return np_img, alpha_channel return np_img, alpha_channel

View File

@ -90,7 +90,7 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"]) CORS(app, expose_headers=["Content-Disposition"])
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading') socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
model: ModelManager = None model: ModelManager = None
thumb: FileManager = None thumb: FileManager = None
@ -114,7 +114,7 @@ def get_image_ext(img_bytes):
def diffuser_callback(i, t, latents): def diffuser_callback(i, t, latents):
socketio.emit('diffusion_progress', {'step': i}) socketio.emit("diffusion_progress", {"step": i})
@app.route("/save_image", methods=["POST"]) @app.route("/save_image", methods=["POST"])
@ -188,7 +188,7 @@ def process():
input = request.files input = request.files
# RGB # RGB
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True) image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
mask, _ = load_img(input["mask"].read(), gray=True) mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
@ -288,12 +288,14 @@ def process():
ext = get_image_ext(origin_image_bytes) ext = get_image_ext(origin_image_bytes)
# fmt: off bytes_io = io.BytesIO(
if exif is not None: pil_to_bytes(
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality, exif=exif)) Image.fromarray(res_np_img),
else: ext,
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality)) quality=image_quality,
# fmt: on exif_infos=exif_infos,
)
)
response = make_response( response = make_response(
send_file( send_file(
@ -304,7 +306,7 @@ def process():
) )
response.headers["X-Seed"] = str(config.sd_seed) response.headers["X-Seed"] = str(config.sd_seed)
socketio.emit('diffusion_finish') socketio.emit("diffusion_finish")
return response return response
@ -317,7 +319,9 @@ def run_plugin():
return "Plugin not found", 500 return "Plugin not found", 500
origin_image_bytes = files["image"].read() # RGB origin_image_bytes = files["image"].read() # RGB
rgb_np_img, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True) rgb_np_img, alpha_channel, exif_infos = load_img(
origin_image_bytes, return_exif=True
)
start = time.time() start = time.time()
try: try:
@ -372,7 +376,10 @@ def run_plugin():
send_file( send_file(
io.BytesIO( io.BytesIO(
pil_to_bytes( pil_to_bytes(
Image.fromarray(rgb_res), ext, quality=image_quality, exif=exif Image.fromarray(rgb_res),
ext,
quality=image_quality,
exif_infos=exif_infos,
) )
), ),
mimetype=f"image/{ext}", mimetype=f"image/{ext}",

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

View File

@ -3,12 +3,9 @@ from pathlib import Path
from PIL import Image from PIL import Image
from lama_cleaner.helper import pil_to_bytes from lama_cleaner.helper import pil_to_bytes, load_img
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
png_img_p = current_dir / "image.png"
jpg_img_p = current_dir / "bunny.jpeg"
def print_exif(exif): def print_exif(exif):
@ -16,27 +13,31 @@ def print_exif(exif):
print(f"{k}: {v}") print(f"{k}: {v}")
def test_png(): def run_test(img_p: Path):
img = Image.open(png_img_p) print(img_p)
exif = img.getexif() ext = img_p.suffix.strip(".")
print_exif(exif) img_bytes = img_p.read_bytes()
np_img, _, exif_infos = load_img(img_bytes, False, True)
print(exif_infos)
print("Original exif_infos")
print_exif(exif_infos["exif"])
pil_bytes = pil_to_bytes(img, ext="png", exif=exif) pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos={})
pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos=exif_infos)
res_img = Image.open(io.BytesIO(pil_bytes)) res_img = Image.open(io.BytesIO(pil_bytes))
print(f"Result img info: {res_img.info}")
res_exif = res_img.getexif() res_exif = res_img.getexif()
print_exif(res_exif)
assert res_exif == exif_infos["exif"]
assert exif_infos["parameters"] == res_img.info.get("parameters")
assert dict(exif) == dict(res_exif)
def test_png():
run_test(current_dir / "image.png")
run_test(current_dir / "pnginfo_test.png")
def test_jpeg(): def test_jpeg():
img = Image.open(jpg_img_p) jpg_img_p = current_dir / "bunny.jpeg"
exif = img.getexif() run_test(jpg_img_p)
print_exif(exif)
pil_bytes = pil_to_bytes(img, ext="jpeg", exif=exif)
res_img = Image.open(io.BytesIO(pil_bytes))
res_exif = res_img.getexif()
assert dict(exif) == dict(res_exif)