From 5b0bf735b5952c84f7275ffd4b8a0a9e3e355036 Mon Sep 17 00:00:00 2001 From: Kenneth Estanislao Date: Mon, 23 Feb 2026 00:01:22 +0800 Subject: [PATCH] use onnx on face enhancer --- .gitignore | 1 + modules/core.py | 13 +- modules/processors/frame/face_enhancer.py | 374 ++++++++++++++++------ requirements.txt | 10 +- 4 files changed, 281 insertions(+), 117 deletions(-) diff --git a/.gitignore b/.gitignore index 6974d63..65636d7 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ models/DMDNet.pth faceswap/ .vscode/ switch_states.json +/models diff --git a/modules/core.py b/modules/core.py index 4f8d1b6..a85007a 100644 --- a/modules/core.py +++ b/modules/core.py @@ -11,7 +11,11 @@ import platform import signal import shutil import argparse -import torch +try: + import torch + HAS_TORCH = True +except ImportError: + HAS_TORCH = False import onnxruntime import tensorflow @@ -21,11 +25,12 @@ import modules.ui as ui from modules.processors.frame.core import get_frame_processors_modules from modules.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path -if 'ROCMExecutionProvider' in modules.globals.execution_providers: +if HAS_TORCH and 'ROCMExecutionProvider' in modules.globals.execution_providers: del torch warnings.filterwarnings('ignore', category=FutureWarning, module='insightface') -warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') +if HAS_TORCH: + warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') def parse_args() -> None: @@ -167,7 +172,7 @@ def limit_resources() -> None: def release_resources() -> None: - if 'CUDAExecutionProvider' in modules.globals.execution_providers: + if 'CUDAExecutionProvider' in modules.globals.execution_providers and HAS_TORCH: torch.cuda.empty_cache() diff --git a/modules/processors/frame/face_enhancer.py b/modules/processors/frame/face_enhancer.py index 22a4d52..bbc8276 100644 --- a/modules/processors/frame/face_enhancer.py +++ b/modules/processors/frame/face_enhancer.py @@ -1,20 +1,20 @@ # --- START OF FILE face_enhancer.py --- +# Uses ONNX Runtime for GFPGAN face enhancement (no torch/gfpgan dependency) from typing import Any, List import cv2 import threading -import gfpgan +import numpy as np import os -import platform -import torch # Make sure torch is imported + +import onnxruntime import modules.globals import modules.processors.frame.core from modules.core import update_status -from modules.face_analyser import get_one_face +from modules.face_analyser import get_one_face, get_many_faces from modules.typing import Frame, Face from modules.utilities import ( - conditional_download, is_image, is_video, ) @@ -29,15 +29,29 @@ models_dir = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(abs_dir))), "models" ) +# Standard FFHQ 5-point face template for 512x512 resolution +# Points: left_eye, right_eye, nose, left_mouth, right_mouth +FFHQ_TEMPLATE_512 = np.array( + [ + [192.98138, 239.94708], + [318.90277, 240.19366], + [256.63416, 314.01935], + [201.26117, 371.41043], + [313.08905, 371.15118], + ], + dtype=np.float32, +) + def pre_check() -> bool: - download_directory_path = models_dir - conditional_download( - download_directory_path, - [ - "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" - ], - ) + model_path = os.path.join(models_dir, "gfpgan-1024.onnx") + if not os.path.exists(model_path): + update_status( + f"GFPGAN ONNX model not found at {model_path}. " + "Please place gfpgan-1024.onnx in the models folder.", + NAME, + ) + return False return True @@ -50,108 +64,257 @@ def pre_start() -> bool: return True -def get_face_enhancer() -> Any: +def get_face_enhancer() -> onnxruntime.InferenceSession: """ - Initializes and returns the GFPGAN face enhancer instance, - prioritizing CUDA, then MPS (Mac), then CPU. + Initializes and returns the GFPGAN ONNX Runtime inference session, + using the execution providers configured in modules.globals. """ global FACE_ENHANCER with THREAD_LOCK: if FACE_ENHANCER is None: - model_path = os.path.join(models_dir, "GFPGANv1.4.pth") - device = None - try: - # Priority 1: CUDA - if torch.cuda.is_available(): - device = torch.device("cuda") - print(f"{NAME}: Using CUDA device.") - # Priority 2: MPS (Mac Silicon) - elif platform.system() == "Darwin" and torch.backends.mps.is_available(): - device = torch.device("mps") - print(f"{NAME}: Using MPS device.") - # Priority 3: CPU - else: - device = torch.device("cpu") - print(f"{NAME}: Using CPU device.") + model_path = os.path.join(models_dir, "gfpgan-1024.onnx") - FACE_ENHANCER = gfpgan.GFPGANer( - model_path=model_path, - upscale=1, # upscale=1 means enhancement only, no resizing - arch='clean', - channel_multiplier=2, - bg_upsampler=None, - device=device + if not os.path.exists(model_path): + raise FileNotFoundError( + f"{NAME}: Model not found at {model_path}" ) - print(f"{NAME}: GFPGANer initialized successfully on {device}.") + + try: + providers = modules.globals.execution_providers + + session_options = onnxruntime.SessionOptions() + session_options.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + + FACE_ENHANCER = onnxruntime.InferenceSession( + model_path, + sess_options=session_options, + providers=providers, + ) + + input_info = FACE_ENHANCER.get_inputs()[0] + output_info = FACE_ENHANCER.get_outputs()[0] + active_providers = FACE_ENHANCER.get_providers() + print( + f"{NAME}: GFPGAN ONNX model loaded successfully." + ) + print( + f"{NAME}: Input: {input_info.name}, " + f"shape: {input_info.shape}, type: {input_info.type}" + ) + print( + f"{NAME}: Output: {output_info.name}, " + f"shape: {output_info.shape}, type: {output_info.type}" + ) + print(f"{NAME}: Active providers: {active_providers}") except Exception as e: - print(f"{NAME}: Error initializing GFPGANer: {e}") - # Fallback to CPU if initialization with GPU fails for some reason - if device is not None and device.type != 'cpu': - print(f"{NAME}: Falling back to CPU due to error.") - try: - device = torch.device("cpu") - FACE_ENHANCER = gfpgan.GFPGANer( - model_path=model_path, - upscale=1, - arch='clean', - channel_multiplier=2, - bg_upsampler=None, - device=device - ) - print(f"{NAME}: GFPGANer initialized successfully on CPU after fallback.") - except Exception as fallback_e: - print(f"{NAME}: FATAL: Could not initialize GFPGANer even on CPU: {fallback_e}") - FACE_ENHANCER = None # Ensure it's None if totally failed - else: - # If it failed even on the first CPU attempt or device was already CPU - print(f"{NAME}: FATAL: Could not initialize GFPGANer on CPU: {e}") - FACE_ENHANCER = None # Ensure it's None if totally failed + print(f"{NAME}: Error loading GFPGAN ONNX model: {e}") + FACE_ENHANCER = None + raise RuntimeError( + f"{NAME}: Failed to load GFPGAN ONNX model: {e}" + ) - - # Check if enhancer is still None after attempting initialization if FACE_ENHANCER is None: - raise RuntimeError(f"{NAME}: Failed to initialize GFPGANer. Check logs for errors.") + raise RuntimeError( + f"{NAME}: Failed to initialize GFPGAN ONNX session. Check logs." + ) return FACE_ENHANCER +def _align_face( + frame: Frame, landmarks_5: np.ndarray, output_size: int +) -> tuple: + """ + Align and crop a face from the frame using 5-point landmarks and the + standard FFHQ template. + + Returns: + (aligned_face, affine_matrix) or (None, None) on failure. + """ + # Scale the 512-base template to the desired output size + scale = output_size / 512.0 + template = FFHQ_TEMPLATE_512 * scale + + # Estimate a similarity transform (4 DOF: rotation, scale, tx, ty) + affine_matrix, _ = cv2.estimateAffinePartial2D( + landmarks_5, template, method=cv2.LMEDS + ) + if affine_matrix is None: + return None, None + + # Warp the face to the aligned position + aligned_face = cv2.warpAffine( + frame, + affine_matrix, + (output_size, output_size), + borderMode=cv2.BORDER_CONSTANT, + borderValue=(135, 133, 132), + ) + + return aligned_face, affine_matrix + + +def _paste_back( + frame: Frame, + enhanced_face: np.ndarray, + affine_matrix: np.ndarray, + output_size: int, +) -> Frame: + """ + Paste an enhanced (aligned) face back onto the original frame using the + inverse affine transform with feathered-edge blending. + """ + h, w = frame.shape[:2] + + # Inverse the affine warp + inv_matrix = cv2.invertAffineTransform(affine_matrix) + inv_restored = cv2.warpAffine( + enhanced_face, + inv_matrix, + (w, h), + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0, 0, 0), + ) + + # Build a soft feathered mask in aligned space for edge blending + face_mask = np.ones((output_size, output_size), dtype=np.float32) + + # Feather the border (5 % of the size on each edge) + border = max(1, int(output_size * 0.05)) + ramp_up = np.linspace(0.0, 1.0, border, dtype=np.float32) + ramp_down = np.linspace(1.0, 0.0, border, dtype=np.float32) + + # Top / bottom rows + face_mask[:border, :] *= ramp_up[:, None] + face_mask[-border:, :] *= ramp_down[:, None] + # Left / right columns + face_mask[:, :border] *= ramp_up[None, :] + face_mask[:, -border:] *= ramp_down[None, :] + + # Expand to 3-channel + face_mask_3c = np.stack([face_mask] * 3, axis=-1) + + # Warp mask back to original frame space + inv_mask = cv2.warpAffine( + face_mask_3c, + inv_matrix, + (w, h), + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0, 0, 0), + ) + inv_mask = np.clip(inv_mask, 0.0, 1.0) + + # Alpha-blend + result = ( + frame.astype(np.float32) * (1.0 - inv_mask) + + inv_restored.astype(np.float32) * inv_mask + ) + return np.clip(result, 0, 255).astype(np.uint8) + + +def _preprocess_face(aligned_face: np.ndarray) -> np.ndarray: + """ + Convert an aligned BGR uint8 face image to the ONNX model input tensor. + Format: NCHW float32, normalised to [-1, 1]. + """ + # BGR -> RGB + rgb = cv2.cvtColor(aligned_face, cv2.COLOR_BGR2RGB).astype(np.float32) + # [0, 255] -> [0, 1] -> [-1, 1] + rgb = rgb / 255.0 + rgb = (rgb - 0.5) / 0.5 + # HWC -> CHW, add batch dim + chw = np.transpose(rgb, (2, 0, 1)) + return np.expand_dims(chw, axis=0) # shape: (1, 3, H, W) + + +def _postprocess_face(output: np.ndarray) -> np.ndarray: + """ + Convert the ONNX model output tensor back to a BGR uint8 image. + Expects input in NCHW format with values in [-1, 1]. + """ + face = np.squeeze(output) # remove batch dim -> (3, H, W) + face = np.transpose(face, (1, 2, 0)) # CHW -> HWC + # [-1, 1] -> [0, 1] -> [0, 255] + face = (face + 1.0) / 2.0 + face = np.clip(face * 255.0, 0, 255).astype(np.uint8) + # RGB -> BGR + return cv2.cvtColor(face, cv2.COLOR_RGB2BGR) + + def enhance_face(temp_frame: Frame) -> Frame: - """Enhances faces in a single frame using the global GFPGANer instance.""" - # Ensure enhancer is ready - enhancer = get_face_enhancer() + """Enhances all faces in a frame using the GFPGAN ONNX model.""" + session = get_face_enhancer() + + # Determine model input resolution from the session metadata + input_info = session.get_inputs()[0] + input_name = input_info.name + input_shape = input_info.shape # e.g. [1, 3, 512, 512] + # Safely extract input size (handle dynamic / symbolic dimensions) try: - with THREAD_SEMAPHORE: - # The enhance method returns: _, restored_faces, restored_img - _, _, restored_img = enhancer.enhance( - temp_frame, - has_aligned=False, # Assume faces are not pre-aligned - only_center_face=False, # Enhance all detected faces - paste_back=True # Paste enhanced faces back onto the original image - ) - # GFPGAN might return None if no face is detected or an error occurs - if restored_img is None: - # print(f"{NAME}: Warning: GFPGAN enhancement returned None. Returning original frame.") - return temp_frame - return restored_img - except Exception as e: - print(f"{NAME}: Error during face enhancement: {e}") - # Return the original frame in case of error during enhancement + align_size = int(input_shape[2]) + if align_size <= 0: + align_size = 512 + except (ValueError, TypeError, IndexError): + align_size = 512 + + # Detect faces using InsightFace (already a project dependency) + faces = get_many_faces(temp_frame) + if not faces: return temp_frame + result_frame = temp_frame.copy() + + for face in faces: + # Need the 5-point key-points for alignment + if not hasattr(face, "kps") or face.kps is None: + continue + + landmarks_5 = face.kps.astype(np.float32) + if landmarks_5.shape[0] < 5: + continue + + # Align / crop the face at the model's INPUT resolution + aligned_face, affine_matrix = _align_face( + temp_frame, landmarks_5, output_size=align_size + ) + if aligned_face is None or affine_matrix is None: + continue + + try: + with THREAD_SEMAPHORE: + input_tensor = _preprocess_face(aligned_face) + output_tensor = session.run(None, {input_name: input_tensor})[0] + enhanced_bgr = _postprocess_face(output_tensor) + + # The model may output at a different resolution than its input + # (e.g. input 512x512 → output 1024x1024). Resize the enhanced + # face back to the alignment size so the inverse affine maps + # correctly. + eh, ew = enhanced_bgr.shape[:2] + if eh != align_size or ew != align_size: + enhanced_bgr = cv2.resize( + enhanced_bgr, + (align_size, align_size), + interpolation=cv2.INTER_LANCZOS4, + ) + + # Paste enhanced face back onto the frame + result_frame = _paste_back( + result_frame, enhanced_bgr, affine_matrix, output_size=align_size + ) + except Exception as e: + print(f"{NAME}: Error enhancing a face: {e}") + continue + + return result_frame + def process_frame(source_face: Face | None, temp_frame: Frame) -> Frame: """Processes a frame: enhances face if detected.""" - # We don't strictly need source_face for enhancement only - # Check if any face exists to potentially save processing time, though GFPGAN also does detection. - # For simplicity and ensuring enhancement is attempted if possible, we can rely on enhance_face. - # target_face = get_one_face(temp_frame) # This gets only ONE face - # If you want to enhance ONLY if a face is detected by your *own* analyser first: - # has_face = get_one_face(temp_frame) is not None # Or use get_many_faces - # if has_face: - # temp_frame = enhance_face(temp_frame) - # else: # Enhance regardless, let GFPGAN handle detection temp_frame = enhance_face(temp_frame) return temp_frame @@ -162,14 +325,18 @@ def process_frames( """Processes multiple frames from file paths.""" for temp_frame_path in temp_frame_paths: if not os.path.exists(temp_frame_path): - print(f"{NAME}: Warning: Frame path not found {temp_frame_path}, skipping.") + print( + f"{NAME}: Warning: Frame path not found {temp_frame_path}, skipping." + ) if progress: progress.update(1) continue temp_frame = cv2.imread(temp_frame_path) if temp_frame is None: - print(f"{NAME}: Warning: Failed to read frame {temp_frame_path}, skipping.") + print( + f"{NAME}: Warning: Failed to read frame {temp_frame_path}, skipping." + ) if progress: progress.update(1) continue @@ -180,7 +347,9 @@ def process_frames( progress.update(1) -def process_image(source_path: str | None, target_path: str, output_path: str) -> None: +def process_image( + source_path: str | None, target_path: str, output_path: str +) -> None: """Processes a single image file.""" target_frame = cv2.imread(target_path) if target_frame is None: @@ -191,16 +360,13 @@ def process_image(source_path: str | None, target_path: str, output_path: str) - print(f"{NAME}: Enhanced image saved to {output_path}") -def process_video(source_path: str | None, temp_frame_paths: List[str]) -> None: +def process_video( + source_path: str | None, temp_frame_paths: List[str] +) -> None: """Processes video frames using the frame processor core.""" - # source_path might be optional depending on how process_video is called - modules.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames) + modules.processors.frame.core.process_video( + source_path, temp_frame_paths, process_frames + ) -# Optional: Keep process_frame_v2 if it's used elsewhere, otherwise it's redundant -# def process_frame_v2(temp_frame: Frame) -> Frame: -# target_face = get_one_face(temp_frame) -# if target_face: -# temp_frame = enhance_face(temp_frame) -# return temp_frame -# --- END OF FILE face_enhancer.py --- \ No newline at end of file +# --- END OF FILE face_enhancer.py --- diff --git a/requirements.txt b/requirements.txt index 1935a6c..19d1b31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ ---extra-index-url https://download.pytorch.org/whl/cu128 - numpy>=1.23.5,<2 typing-extensions>=4.8.0 opencv-python==4.10.0.84 @@ -10,15 +8,9 @@ psutil==5.9.8 tk==0.1.0 customtkinter==5.2.2 pillow==12.1.1 -torch; sys_platform != 'darwin' -torch==2.8.0+cu128; sys_platform == 'darwin' -torchvision; sys_platform != 'darwin' -torchvision==0.20.1; sys_platform == 'darwin' onnxruntime-silicon==1.16.3; sys_platform == 'darwin' and platform_machine == 'arm64' -onnxruntime-gpu==1.22.0; sys_platform != 'darwin' +onnxruntime-gpu==1.24.2; sys_platform != 'darwin' tensorflow; sys_platform != 'darwin' opennsfw2==0.10.2 protobuf==4.25.1 -git+https://github.com/xinntao/BasicSR.git@master -git+https://github.com/TencentARC/GFPGAN.git@master pygrabber