diff --git a/main.py b/main.py index f6051b1..fe501c4 100644 --- a/main.py +++ b/main.py @@ -92,7 +92,12 @@ def process(): print(f"process time: {(time.time() - start) * 1000}ms") torch.cuda.empty_cache() + if alpha_channel is not None: + if alpha_channel.shape[:2] != res_np_img.shape[:2]: + alpha_channel = cv2.resize( + alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) + ) res_np_img = np.concatenate( (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 )