fix GFPGAN face detect

This commit is contained in:
Qing 2023-04-01 21:26:40 +08:00
parent f9727e1af6
commit 674c60f5a8

View File

@ -17,11 +17,10 @@ class GFPGANPlugin(BasePlugin):
model_path = download_model(url, model_md5) model_path = download_model(url, model_md5)
logger.info(f"GFPGAN model path: {model_path}") logger.info(f"GFPGAN model path: {model_path}")
face_det_device = "cpu" if "cuda" in str(device) else device
import facexlib import facexlib
if hasattr(facexlib.detection.retinaface, "device"): if hasattr(facexlib.detection.retinaface, "device"):
facexlib.detection.retinaface.device = face_det_device facexlib.detection.retinaface.device = device
# Use GFPGAN for face enhancement # Use GFPGAN for face enhancement
self.face_enhancer = MyGFPGANer( self.face_enhancer = MyGFPGANer(
@ -32,9 +31,9 @@ class GFPGANPlugin(BasePlugin):
device=device, device=device,
bg_upsampler=upscaler.model if upscaler is not None else None, bg_upsampler=upscaler.model if upscaler is not None else None,
) )
self.face_enhancer.face_helper.face_det.mean_tensor.to(face_det_device) self.face_enhancer.face_helper.face_det.mean_tensor.to(device)
self.face_enhancer.face_helper.face_det = ( self.face_enhancer.face_helper.face_det = (
self.face_enhancer.face_helper.face_det.to(face_det_device) self.face_enhancer.face_helper.face_det.to(device)
) )
def __call__(self, rgb_np_img, files, form): def __call__(self, rgb_np_img, files, form):