2023-01-17 14:05:17 +01:00
import io
2021-11-15 08:22:34 +01:00
import os
import sys
2022-07-14 10:49:03 +02:00
from typing import List , Optional
2021-11-15 08:22:34 +01:00
from urllib . parse import urlparse
import cv2
2023-01-17 14:05:17 +01:00
from PIL import Image , ImageOps
2021-11-15 08:22:34 +01:00
import numpy as np
import torch
2023-02-11 06:30:09 +01:00
from lama_cleaner . const import MPS_SUPPORT_MODELS
2022-07-14 10:49:03 +02:00
from loguru import logger
2021-11-15 08:22:34 +01:00
from torch . hub import download_url_to_file , get_dir
2023-02-11 06:30:09 +01:00
def switch_mps_device(model_name, device):
    """Fall back to the CPU when ``model_name`` is not usable on the mps device.

    Args:
        model_name: Name that is looked up in ``MPS_SUPPORT_MODELS``.
        device: Requested device, compared against the string ``"mps"`` and
            against ``torch.device("mps")``.

    Returns:
        ``torch.device("cpu")`` when mps was requested for a model that is not
        listed in ``MPS_SUPPORT_MODELS``; otherwise ``device`` unchanged.
    """
    if model_name in MPS_SUPPORT_MODELS:
        return device

    requested_mps = device == "mps" or device == torch.device("mps")
    if not requested_mps:
        return device

    logger.info(f"{model_name} not support mps, switch to cpu")
    return torch.device("cpu")
2022-04-17 17:31:12 +02:00
def get_cache_path_by_url(url):
    """Return the local cache path for the file referenced by ``url``.

    The path lives in the ``checkpoints`` directory below the torch hub
    directory (``torch.hub.get_dir()``); that directory is created when it
    does not exist yet. Only the path component of the URL contributes to the
    file name, so query strings and fragments are ignored.

    Args:
        url: Download URL of the model file.

    Returns:
        ``<hub_dir>/checkpoints/<basename of the URL path>``. The file itself
        is neither created nor downloaded here.
    """
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok=True removes the check-then-create race: two processes that
    # start together could both see the directory missing and one of them
    # would then fail inside makedirs.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)
def download_model(url):
    """Fetch the file at ``url`` into the local cache unless it is already there.

    Args:
        url: Download URL of the model file.

    Returns:
        Path of the cached file as computed by ``get_cache_path_by_url``.
    """
    target = get_cache_path_by_url(url)
    if os.path.exists(target):
        return target

    sys.stderr.write(f'Downloading: "{url}" to {target}\n')
    # The third argument is the hash prefix; None means no checksum is checked.
    download_url_to_file(url, target, None, progress=True)
    return target
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    A value that already is a multiple of ``mod`` is returned unchanged.
    """
    if x % mod:
        return (x // mod + 1) * mod
    return x
2022-07-14 10:49:03 +02:00
def load_jit_model(url_or_path, device):
    """Load a TorchScript model, move it to ``device`` and switch it to eval mode.

    Args:
        url_or_path: Path of an existing local file, or otherwise a URL that
            is downloaded into the cache through ``download_model``.
        device: Device the loaded module is moved to.

    Returns:
        The loaded module after ``eval()`` has been called on it.

    Raises:
        SystemExit: With exit code ``-1`` when loading fails; the failure is
            logged before exiting.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    logger.info(f"Loading model from: {model_path}")
    try:
        # Deserialize on the CPU first, then move to the requested device.
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        logger.error(
            f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n"
            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            f"If all above operations doesn't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the `site` module and is missing under `python -S` or in frozen apps.
        sys.exit(-1)
    model.eval()
    return model
def load_model(model: torch.nn.Module, url_or_path, device):
    """Load a state dict into ``model``, move it to ``device`` and set eval mode.

    Args:
        model: Module whose parameters are filled from the checkpoint; the
            checkpoint keys must match exactly (``strict=True``).
        url_or_path: Path of an existing local file, or otherwise a URL that
            is downloaded into the cache through ``download_model``.
        device: Device the module is moved to after loading.

    Returns:
        The same ``model`` object after ``eval()`` has been called on it.

    Raises:
        SystemExit: With exit code ``-1`` when loading fails; the failure is
            logged before exiting.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        # Deserialize on the CPU first, then move to the requested device.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except Exception as e:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; the cause is logged so failures can be diagnosed.
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner\n{e}"
        )
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the `site` module and is missing under `python -S` or in frozen apps.
        sys.exit(-1)
    model.eval()
    return model
2022-04-09 01:23:33 +02:00
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array with OpenCV and return the encoded file content.

    Args:
        image_numpy: Image array handed to ``cv2.imencode``.
        ext: File extension without the leading dot (it selects the encoder).

    Returns:
        The encoded image as bytes.
    """
    # JPEG quality 100 and PNG compression level 0 are both passed; the flag
    # for the other format is simply supplied alongside.
    encode_params = [
        int(cv2.IMWRITE_JPEG_QUALITY),
        100,
        int(cv2.IMWRITE_PNG_COMPRESSION),
        0,
    ]
    _, encoded = cv2.imencode(f".{ext}", image_numpy, encode_params)
    return encoded.tobytes()
2023-02-06 15:00:47 +01:00
def pil_to_bytes(pil_img, ext: str, exif=None, quality: int = 95) -> bytes:
    """Serialize an image object into the bytes of an ``ext`` image file.

    Args:
        pil_img: Object with a PIL-style ``save(fp, format=..., **params)``
            method.
        ext: Format name passed to ``save`` as ``format``.
        exif: Optional EXIF data to embed. When ``None`` the ``exif`` argument
            is not passed to ``save`` at all.
        quality: Encoder quality passed to ``save``; defaults to 95.

    Returns:
        The encoded image as bytes.
    """
    save_kwargs = {"format": ext, "quality": quality}
    if exif is not None:
        # Only forward real EXIF data: an explicit exif=None is not the same
        # as omitting the argument and can make an encoder fail.
        save_kwargs["exif"] = exif

    with io.BytesIO() as output:
        pil_img.save(output, **save_kwargs)
        return output.getvalue()
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
    """Decode encoded image bytes into a numpy array.

    Args:
        img_bytes: Content of an image file.
        gray: When true the image is converted to mode ``"L"`` (single
            channel); otherwise it is returned as RGB.
        return_exif: When true the image's EXIF data is returned as a third
            tuple element.

    Returns:
        ``(np_img, alpha_channel)``, or ``(np_img, alpha_channel, exif)`` when
        ``return_exif`` is true. ``alpha_channel`` is the last channel of an
        RGBA image when ``gray`` is false, otherwise ``None``. ``exif`` is
        ``None`` when extraction failed.
    """
    image = Image.open(io.BytesIO(img_bytes))

    exif = None
    if return_exif:
        try:
            exif = image.getexif()
        except Exception:
            logger.error("Failed to extract exif from image")

    try:
        # Apply the EXIF orientation so the pixel data is upright.
        image = ImageOps.exif_transpose(image)
    except Exception:
        # Best effort: keep the image as decoded if transposing fails.
        pass

    alpha_channel = None
    if gray:
        np_img = np.array(image.convert("L"))
    elif image.mode == "RGBA":
        np_img = np.array(image)
        # Keep the alpha plane separately before dropping it from the pixels.
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        np_img = np.array(image.convert("RGB"))

    if return_exif:
        return np_img, alpha_channel, exif
    return np_img, alpha_channel
2021-11-15 08:22:34 +01:00
2021-11-27 13:37:37 +01:00
def norm_img(np_img):
    """Convert an [H, W] or [H, W, C] image into a float32 [C, H, W] array.

    A 2-D input first gets a trailing channel axis. Values are divided by 255.
    """
    if np_img.ndim == 2:
        np_img = np_img[..., np.newaxis]
    channels_first = np_img.transpose(2, 0, 1)
    return channels_first.astype("float32") / 255
2021-11-27 13:37:37 +01:00
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Shrink ``np_img`` so that its longer side equals ``size_limit``.

    Images whose longer side does not exceed ``size_limit`` are returned
    unchanged (the same object). The aspect ratio is kept; the new side
    lengths are rounded to the nearest integer.

    Args:
        np_img: Image array whose first two axes are height and width.
        size_limit: Maximum allowed length of the longer side.
        interpolation: OpenCV interpolation flag used for resizing.
    """
    height, width = np_img.shape[:2]
    longest = max(height, width)
    if longest <= size_limit:
        return np_img

    scale = size_limit / longest
    # cv2.resize expects dsize as (width, height).
    target = (int(width * scale + 0.5), int(height * scale + 0.5))
    return cv2.resize(np_img, dsize=target, interpolation=interpolation)
2022-07-19 15:47:21 +02:00
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad the bottom and right of ``img`` until both sides are multiples of ``mod``.

    Args:
        img: [H, W, C]; a 2-D [H, W] input gets a trailing channel axis first
        mod: value the padded height and width must be divisible by
        square: whether the result is padded to a square
        min_size: optional lower bound for both padded sides; must itself be
            a multiple of ``mod``

    Returns:
        The padded array; new pixels are filled with numpy's "symmetric" mode.
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]
    src_h, src_w = img.shape[:2]

    target_h = ceil_modulo(src_h, mod)
    target_w = ceil_modulo(src_w, mod)

    if min_size is not None:
        assert min_size % mod == 0
        target_h = max(min_size, target_h)
        target_w = max(min_size, target_w)

    if square:
        target_h = target_w = max(target_h, target_w)

    # Pad only after the existing content: rows at the bottom, columns at the
    # right, channels untouched.
    pad_widths = ((0, target_h - src_h), (0, target_w - src_w), (0, 0))
    return np.pad(img, pad_widths, mode="symmetric")
2022-03-23 03:02:01 +01:00
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Compute one bounding box per external contour of a mask.

    Args:
        mask: (h, w, 1)  0~255

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]``, one per contour found
        after thresholding the mask at 127, clipped to the mask size.
    """
    height, width = mask.shape[:2]
    # Threshold type 0 is the plain binary threshold.
    _, binary = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for contour in contours:
        left, top, box_w, box_h = cv2.boundingRect(contour)
        corners = np.array([left, top, left + box_w, top + box_h]).astype(int)
        # Even positions hold x coordinates, odd positions hold y coordinates.
        corners[0::2] = np.clip(corners[0::2], 0, width)
        corners[1::2] = np.clip(corners[1::2], 0, height)
        boxes.append(corners)
    return boxes
2022-11-27 14:25:27 +01:00
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
    """Keep only the external contour with the largest area in a mask.

    Args:
        mask: (h, w)  0~255

    Returns:
        A mask of the same shape with the largest contour filled with 255 and
        everything else 0. When no contour has a positive area the input mask
        is returned unchanged. Note that a single array is returned, even
        though the annotation says ``List``.
    """
    # Threshold type 0 is the plain binary threshold.
    _, binary = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    areas = [cv2.contourArea(contour) for contour in contours]
    if not areas or max(areas) <= 0:
        return mask

    # index() yields the first contour with the maximum area.
    largest = areas.index(max(areas))
    canvas = np.zeros_like(mask)
    # Thickness -1 fills the contour interior.
    return cv2.drawContours(canvas, contours, largest, 255, -1)