fix image exif rotation
This commit is contained in:
parent
a3275fc0dc
commit
ff50421003
@ -1,9 +1,11 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
from urllib.parse import urlparse
|
||||
import cv2
|
||||
from PIL import Image, ImageOps
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
@ -85,16 +87,23 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||
|
||||
def load_img(img_bytes, gray: bool = False):
|
||||
alpha_channel = None
|
||||
nparr = np.frombuffer(img_bytes, np.uint8)
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if gray:
|
||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
||||
image = image.convert('L')
|
||||
np_img = np.array(image)
|
||||
else:
|
||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
||||
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
||||
if image.mode == 'RGBA':
|
||||
np_img = np.array(image)
|
||||
alpha_channel = np_img[:, :, -1]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||
else:
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||
image = image.convert('RGB')
|
||||
np_img = np.array(image)
|
||||
|
||||
return np_img, alpha_channel
|
||||
|
||||
|
21
lama_cleaner/tests/test_load_img.py
Normal file
21
lama_cleaner/tests/test_load_img.py
Normal file
@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
from lama_cleaner.helper import load_img
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
png_img_p = current_dir / "image.png"
|
||||
jpg_img_p = current_dir / "bunny.jpeg"
|
||||
|
||||
|
||||
def test_load_png_image():
|
||||
with open(png_img_p, "rb") as f:
|
||||
np_img, alpha_channel = load_img(f.read())
|
||||
assert np_img.shape == (256, 256, 3)
|
||||
assert alpha_channel.shape == (256, 256)
|
||||
|
||||
|
||||
def test_load_jpg_image():
|
||||
with open(jpg_img_p, "rb") as f:
|
||||
np_img, alpha_channel = load_img(f.read())
|
||||
assert np_img.shape == (394, 448, 3)
|
||||
assert alpha_channel is None
|
Loading…
Reference in New Issue
Block a user