optimize mem

This commit is contained in:
Qing 2024-01-09 22:54:20 +08:00
parent e94a94e3c2
commit db4a6f4547
3 changed files with 18 additions and 3 deletions

View File

@ -1,6 +1,14 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
os.environ["LRU_CACHE_CAPACITY"] = "1"
# prevent CPU memory leak when run model on GPU
# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
import warnings

View File

@ -1,9 +1,10 @@
import json
import cv2
from pathlib import Path
from typing import Dict, Optional
from PIL import Image
import cv2
import psutil
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
@ -14,10 +15,10 @@ from rich.progress import (
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
)
from iopaint.helper import pil_to_bytes
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest
@ -119,3 +120,8 @@ def batch_inpaint(
fw.write(img_bytes)
progress.update(task, advance=1)
torch_gc()
# pid = psutil.Process().pid
# memory_info = psutil.Process(pid).memory_info()
# memory_in_mb = memory_info.rss / (1024 * 1024)
# print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")

View File

@ -69,6 +69,7 @@ class ModelManager:
raise NotImplementedError(f"Unsupported model: {name}")
@torch.inference_mode()
def __call__(self, image, mask, config: InpaintRequest):
    """