use onnx on face enhancer
This commit is contained in:
@@ -25,3 +25,4 @@ models/DMDNet.pth
|
||||
faceswap/
|
||||
.vscode/
|
||||
switch_states.json
|
||||
/models
|
||||
|
||||
+9
-4
@@ -11,7 +11,11 @@ import platform
|
||||
import signal
|
||||
import shutil
|
||||
import argparse
|
||||
import torch
|
||||
try:
|
||||
import torch
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
import onnxruntime
|
||||
import tensorflow
|
||||
|
||||
@@ -21,11 +25,12 @@ import modules.ui as ui
|
||||
from modules.processors.frame.core import get_frame_processors_modules
|
||||
from modules.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
|
||||
|
||||
if 'ROCMExecutionProvider' in modules.globals.execution_providers:
|
||||
if HAS_TORCH and 'ROCMExecutionProvider' in modules.globals.execution_providers:
|
||||
del torch
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
||||
if HAS_TORCH:
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
||||
|
||||
|
||||
def parse_args() -> None:
|
||||
@@ -167,7 +172,7 @@ def limit_resources() -> None:
|
||||
|
||||
|
||||
def release_resources() -> None:
|
||||
if 'CUDAExecutionProvider' in modules.globals.execution_providers:
|
||||
if 'CUDAExecutionProvider' in modules.globals.execution_providers and HAS_TORCH:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# --- START OF FILE face_enhancer.py ---
|
||||
# Uses ONNX Runtime for GFPGAN face enhancement (no torch/gfpgan dependency)
|
||||
|
||||
from typing import Any, List
|
||||
import cv2
|
||||
import threading
|
||||
import gfpgan
|
||||
import numpy as np
|
||||
import os
|
||||
import platform
|
||||
import torch # Make sure torch is imported
|
||||
|
||||
import onnxruntime
|
||||
|
||||
import modules.globals
|
||||
import modules.processors.frame.core
|
||||
from modules.core import update_status
|
||||
from modules.face_analyser import get_one_face
|
||||
from modules.face_analyser import get_one_face, get_many_faces
|
||||
from modules.typing import Frame, Face
|
||||
from modules.utilities import (
|
||||
conditional_download,
|
||||
is_image,
|
||||
is_video,
|
||||
)
|
||||
@@ -29,15 +29,29 @@ models_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(abs_dir))), "models"
|
||||
)
|
||||
|
||||
# Standard FFHQ 5-point face template for 512x512 resolution
|
||||
# Points: left_eye, right_eye, nose, left_mouth, right_mouth
|
||||
FFHQ_TEMPLATE_512 = np.array(
|
||||
[
|
||||
[192.98138, 239.94708],
|
||||
[318.90277, 240.19366],
|
||||
[256.63416, 314.01935],
|
||||
[201.26117, 371.41043],
|
||||
[313.08905, 371.15118],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
|
||||
def pre_check() -> bool:
|
||||
download_directory_path = models_dir
|
||||
conditional_download(
|
||||
download_directory_path,
|
||||
[
|
||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth"
|
||||
],
|
||||
)
|
||||
model_path = os.path.join(models_dir, "gfpgan-1024.onnx")
|
||||
if not os.path.exists(model_path):
|
||||
update_status(
|
||||
f"GFPGAN ONNX model not found at {model_path}. "
|
||||
"Please place gfpgan-1024.onnx in the models folder.",
|
||||
NAME,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -50,108 +64,257 @@ def pre_start() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_face_enhancer() -> Any:
|
||||
def get_face_enhancer() -> onnxruntime.InferenceSession:
|
||||
"""
|
||||
Initializes and returns the GFPGAN face enhancer instance,
|
||||
prioritizing CUDA, then MPS (Mac), then CPU.
|
||||
Initializes and returns the GFPGAN ONNX Runtime inference session,
|
||||
using the execution providers configured in modules.globals.
|
||||
"""
|
||||
global FACE_ENHANCER
|
||||
|
||||
with THREAD_LOCK:
|
||||
if FACE_ENHANCER is None:
|
||||
model_path = os.path.join(models_dir, "GFPGANv1.4.pth")
|
||||
device = None
|
||||
try:
|
||||
# Priority 1: CUDA
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print(f"{NAME}: Using CUDA device.")
|
||||
# Priority 2: MPS (Mac Silicon)
|
||||
elif platform.system() == "Darwin" and torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
print(f"{NAME}: Using MPS device.")
|
||||
# Priority 3: CPU
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"{NAME}: Using CPU device.")
|
||||
model_path = os.path.join(models_dir, "gfpgan-1024.onnx")
|
||||
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(
|
||||
model_path=model_path,
|
||||
upscale=1, # upscale=1 means enhancement only, no resizing
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=None,
|
||||
device=device
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"{NAME}: Model not found at {model_path}"
|
||||
)
|
||||
print(f"{NAME}: GFPGANer initialized successfully on {device}.")
|
||||
|
||||
try:
|
||||
providers = modules.globals.execution_providers
|
||||
|
||||
session_options = onnxruntime.SessionOptions()
|
||||
session_options.graph_optimization_level = (
|
||||
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
)
|
||||
|
||||
FACE_ENHANCER = onnxruntime.InferenceSession(
|
||||
model_path,
|
||||
sess_options=session_options,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
input_info = FACE_ENHANCER.get_inputs()[0]
|
||||
output_info = FACE_ENHANCER.get_outputs()[0]
|
||||
active_providers = FACE_ENHANCER.get_providers()
|
||||
print(
|
||||
f"{NAME}: GFPGAN ONNX model loaded successfully."
|
||||
)
|
||||
print(
|
||||
f"{NAME}: Input: {input_info.name}, "
|
||||
f"shape: {input_info.shape}, type: {input_info.type}"
|
||||
)
|
||||
print(
|
||||
f"{NAME}: Output: {output_info.name}, "
|
||||
f"shape: {output_info.shape}, type: {output_info.type}"
|
||||
)
|
||||
print(f"{NAME}: Active providers: {active_providers}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"{NAME}: Error initializing GFPGANer: {e}")
|
||||
# Fallback to CPU if initialization with GPU fails for some reason
|
||||
if device is not None and device.type != 'cpu':
|
||||
print(f"{NAME}: Falling back to CPU due to error.")
|
||||
try:
|
||||
device = torch.device("cpu")
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(
|
||||
model_path=model_path,
|
||||
upscale=1,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=None,
|
||||
device=device
|
||||
)
|
||||
print(f"{NAME}: GFPGANer initialized successfully on CPU after fallback.")
|
||||
except Exception as fallback_e:
|
||||
print(f"{NAME}: FATAL: Could not initialize GFPGANer even on CPU: {fallback_e}")
|
||||
FACE_ENHANCER = None # Ensure it's None if totally failed
|
||||
else:
|
||||
# If it failed even on the first CPU attempt or device was already CPU
|
||||
print(f"{NAME}: FATAL: Could not initialize GFPGANer on CPU: {e}")
|
||||
FACE_ENHANCER = None # Ensure it's None if totally failed
|
||||
print(f"{NAME}: Error loading GFPGAN ONNX model: {e}")
|
||||
FACE_ENHANCER = None
|
||||
raise RuntimeError(
|
||||
f"{NAME}: Failed to load GFPGAN ONNX model: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Check if enhancer is still None after attempting initialization
|
||||
if FACE_ENHANCER is None:
|
||||
raise RuntimeError(f"{NAME}: Failed to initialize GFPGANer. Check logs for errors.")
|
||||
raise RuntimeError(
|
||||
f"{NAME}: Failed to initialize GFPGAN ONNX session. Check logs."
|
||||
)
|
||||
|
||||
return FACE_ENHANCER
|
||||
|
||||
|
||||
def _align_face(
|
||||
frame: Frame, landmarks_5: np.ndarray, output_size: int
|
||||
) -> tuple:
|
||||
"""
|
||||
Align and crop a face from the frame using 5-point landmarks and the
|
||||
standard FFHQ template.
|
||||
|
||||
Returns:
|
||||
(aligned_face, affine_matrix) or (None, None) on failure.
|
||||
"""
|
||||
# Scale the 512-base template to the desired output size
|
||||
scale = output_size / 512.0
|
||||
template = FFHQ_TEMPLATE_512 * scale
|
||||
|
||||
# Estimate a similarity transform (4 DOF: rotation, scale, tx, ty)
|
||||
affine_matrix, _ = cv2.estimateAffinePartial2D(
|
||||
landmarks_5, template, method=cv2.LMEDS
|
||||
)
|
||||
if affine_matrix is None:
|
||||
return None, None
|
||||
|
||||
# Warp the face to the aligned position
|
||||
aligned_face = cv2.warpAffine(
|
||||
frame,
|
||||
affine_matrix,
|
||||
(output_size, output_size),
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(135, 133, 132),
|
||||
)
|
||||
|
||||
return aligned_face, affine_matrix
|
||||
|
||||
|
||||
def _paste_back(
|
||||
frame: Frame,
|
||||
enhanced_face: np.ndarray,
|
||||
affine_matrix: np.ndarray,
|
||||
output_size: int,
|
||||
) -> Frame:
|
||||
"""
|
||||
Paste an enhanced (aligned) face back onto the original frame using the
|
||||
inverse affine transform with feathered-edge blending.
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
# Inverse the affine warp
|
||||
inv_matrix = cv2.invertAffineTransform(affine_matrix)
|
||||
inv_restored = cv2.warpAffine(
|
||||
enhanced_face,
|
||||
inv_matrix,
|
||||
(w, h),
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(0, 0, 0),
|
||||
)
|
||||
|
||||
# Build a soft feathered mask in aligned space for edge blending
|
||||
face_mask = np.ones((output_size, output_size), dtype=np.float32)
|
||||
|
||||
# Feather the border (5 % of the size on each edge)
|
||||
border = max(1, int(output_size * 0.05))
|
||||
ramp_up = np.linspace(0.0, 1.0, border, dtype=np.float32)
|
||||
ramp_down = np.linspace(1.0, 0.0, border, dtype=np.float32)
|
||||
|
||||
# Top / bottom rows
|
||||
face_mask[:border, :] *= ramp_up[:, None]
|
||||
face_mask[-border:, :] *= ramp_down[:, None]
|
||||
# Left / right columns
|
||||
face_mask[:, :border] *= ramp_up[None, :]
|
||||
face_mask[:, -border:] *= ramp_down[None, :]
|
||||
|
||||
# Expand to 3-channel
|
||||
face_mask_3c = np.stack([face_mask] * 3, axis=-1)
|
||||
|
||||
# Warp mask back to original frame space
|
||||
inv_mask = cv2.warpAffine(
|
||||
face_mask_3c,
|
||||
inv_matrix,
|
||||
(w, h),
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(0, 0, 0),
|
||||
)
|
||||
inv_mask = np.clip(inv_mask, 0.0, 1.0)
|
||||
|
||||
# Alpha-blend
|
||||
result = (
|
||||
frame.astype(np.float32) * (1.0 - inv_mask)
|
||||
+ inv_restored.astype(np.float32) * inv_mask
|
||||
)
|
||||
return np.clip(result, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _preprocess_face(aligned_face: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert an aligned BGR uint8 face image to the ONNX model input tensor.
|
||||
Format: NCHW float32, normalised to [-1, 1].
|
||||
"""
|
||||
# BGR -> RGB
|
||||
rgb = cv2.cvtColor(aligned_face, cv2.COLOR_BGR2RGB).astype(np.float32)
|
||||
# [0, 255] -> [0, 1] -> [-1, 1]
|
||||
rgb = rgb / 255.0
|
||||
rgb = (rgb - 0.5) / 0.5
|
||||
# HWC -> CHW, add batch dim
|
||||
chw = np.transpose(rgb, (2, 0, 1))
|
||||
return np.expand_dims(chw, axis=0) # shape: (1, 3, H, W)
|
||||
|
||||
|
||||
def _postprocess_face(output: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert the ONNX model output tensor back to a BGR uint8 image.
|
||||
Expects input in NCHW format with values in [-1, 1].
|
||||
"""
|
||||
face = np.squeeze(output) # remove batch dim -> (3, H, W)
|
||||
face = np.transpose(face, (1, 2, 0)) # CHW -> HWC
|
||||
# [-1, 1] -> [0, 1] -> [0, 255]
|
||||
face = (face + 1.0) / 2.0
|
||||
face = np.clip(face * 255.0, 0, 255).astype(np.uint8)
|
||||
# RGB -> BGR
|
||||
return cv2.cvtColor(face, cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def enhance_face(temp_frame: Frame) -> Frame:
|
||||
"""Enhances faces in a single frame using the global GFPGANer instance."""
|
||||
# Ensure enhancer is ready
|
||||
enhancer = get_face_enhancer()
|
||||
"""Enhances all faces in a frame using the GFPGAN ONNX model."""
|
||||
session = get_face_enhancer()
|
||||
|
||||
# Determine model input resolution from the session metadata
|
||||
input_info = session.get_inputs()[0]
|
||||
input_name = input_info.name
|
||||
input_shape = input_info.shape # e.g. [1, 3, 512, 512]
|
||||
# Safely extract input size (handle dynamic / symbolic dimensions)
|
||||
try:
|
||||
with THREAD_SEMAPHORE:
|
||||
# The enhance method returns: _, restored_faces, restored_img
|
||||
_, _, restored_img = enhancer.enhance(
|
||||
temp_frame,
|
||||
has_aligned=False, # Assume faces are not pre-aligned
|
||||
only_center_face=False, # Enhance all detected faces
|
||||
paste_back=True # Paste enhanced faces back onto the original image
|
||||
)
|
||||
# GFPGAN might return None if no face is detected or an error occurs
|
||||
if restored_img is None:
|
||||
# print(f"{NAME}: Warning: GFPGAN enhancement returned None. Returning original frame.")
|
||||
return temp_frame
|
||||
return restored_img
|
||||
except Exception as e:
|
||||
print(f"{NAME}: Error during face enhancement: {e}")
|
||||
# Return the original frame in case of error during enhancement
|
||||
align_size = int(input_shape[2])
|
||||
if align_size <= 0:
|
||||
align_size = 512
|
||||
except (ValueError, TypeError, IndexError):
|
||||
align_size = 512
|
||||
|
||||
# Detect faces using InsightFace (already a project dependency)
|
||||
faces = get_many_faces(temp_frame)
|
||||
if not faces:
|
||||
return temp_frame
|
||||
|
||||
result_frame = temp_frame.copy()
|
||||
|
||||
for face in faces:
|
||||
# Need the 5-point key-points for alignment
|
||||
if not hasattr(face, "kps") or face.kps is None:
|
||||
continue
|
||||
|
||||
landmarks_5 = face.kps.astype(np.float32)
|
||||
if landmarks_5.shape[0] < 5:
|
||||
continue
|
||||
|
||||
# Align / crop the face at the model's INPUT resolution
|
||||
aligned_face, affine_matrix = _align_face(
|
||||
temp_frame, landmarks_5, output_size=align_size
|
||||
)
|
||||
if aligned_face is None or affine_matrix is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
with THREAD_SEMAPHORE:
|
||||
input_tensor = _preprocess_face(aligned_face)
|
||||
output_tensor = session.run(None, {input_name: input_tensor})[0]
|
||||
enhanced_bgr = _postprocess_face(output_tensor)
|
||||
|
||||
# The model may output at a different resolution than its input
|
||||
# (e.g. input 512x512 → output 1024x1024). Resize the enhanced
|
||||
# face back to the alignment size so the inverse affine maps
|
||||
# correctly.
|
||||
eh, ew = enhanced_bgr.shape[:2]
|
||||
if eh != align_size or ew != align_size:
|
||||
enhanced_bgr = cv2.resize(
|
||||
enhanced_bgr,
|
||||
(align_size, align_size),
|
||||
interpolation=cv2.INTER_LANCZOS4,
|
||||
)
|
||||
|
||||
# Paste enhanced face back onto the frame
|
||||
result_frame = _paste_back(
|
||||
result_frame, enhanced_bgr, affine_matrix, output_size=align_size
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{NAME}: Error enhancing a face: {e}")
|
||||
continue
|
||||
|
||||
return result_frame
|
||||
|
||||
|
||||
def process_frame(source_face: Face | None, temp_frame: Frame) -> Frame:
|
||||
"""Processes a frame: enhances face if detected."""
|
||||
# We don't strictly need source_face for enhancement only
|
||||
# Check if any face exists to potentially save processing time, though GFPGAN also does detection.
|
||||
# For simplicity and ensuring enhancement is attempted if possible, we can rely on enhance_face.
|
||||
# target_face = get_one_face(temp_frame) # This gets only ONE face
|
||||
# If you want to enhance ONLY if a face is detected by your *own* analyser first:
|
||||
# has_face = get_one_face(temp_frame) is not None # Or use get_many_faces
|
||||
# if has_face:
|
||||
# temp_frame = enhance_face(temp_frame)
|
||||
# else: # Enhance regardless, let GFPGAN handle detection
|
||||
temp_frame = enhance_face(temp_frame)
|
||||
return temp_frame
|
||||
|
||||
@@ -162,14 +325,18 @@ def process_frames(
|
||||
"""Processes multiple frames from file paths."""
|
||||
for temp_frame_path in temp_frame_paths:
|
||||
if not os.path.exists(temp_frame_path):
|
||||
print(f"{NAME}: Warning: Frame path not found {temp_frame_path}, skipping.")
|
||||
print(
|
||||
f"{NAME}: Warning: Frame path not found {temp_frame_path}, skipping."
|
||||
)
|
||||
if progress:
|
||||
progress.update(1)
|
||||
continue
|
||||
|
||||
temp_frame = cv2.imread(temp_frame_path)
|
||||
if temp_frame is None:
|
||||
print(f"{NAME}: Warning: Failed to read frame {temp_frame_path}, skipping.")
|
||||
print(
|
||||
f"{NAME}: Warning: Failed to read frame {temp_frame_path}, skipping."
|
||||
)
|
||||
if progress:
|
||||
progress.update(1)
|
||||
continue
|
||||
@@ -180,7 +347,9 @@ def process_frames(
|
||||
progress.update(1)
|
||||
|
||||
|
||||
def process_image(source_path: str | None, target_path: str, output_path: str) -> None:
|
||||
def process_image(
|
||||
source_path: str | None, target_path: str, output_path: str
|
||||
) -> None:
|
||||
"""Processes a single image file."""
|
||||
target_frame = cv2.imread(target_path)
|
||||
if target_frame is None:
|
||||
@@ -191,16 +360,13 @@ def process_image(source_path: str | None, target_path: str, output_path: str) -
|
||||
print(f"{NAME}: Enhanced image saved to {output_path}")
|
||||
|
||||
|
||||
def process_video(source_path: str | None, temp_frame_paths: List[str]) -> None:
|
||||
def process_video(
|
||||
source_path: str | None, temp_frame_paths: List[str]
|
||||
) -> None:
|
||||
"""Processes video frames using the frame processor core."""
|
||||
# source_path might be optional depending on how process_video is called
|
||||
modules.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames)
|
||||
modules.processors.frame.core.process_video(
|
||||
source_path, temp_frame_paths, process_frames
|
||||
)
|
||||
|
||||
# Optional: Keep process_frame_v2 if it's used elsewhere, otherwise it's redundant
|
||||
# def process_frame_v2(temp_frame: Frame) -> Frame:
|
||||
# target_face = get_one_face(temp_frame)
|
||||
# if target_face:
|
||||
# temp_frame = enhance_face(temp_frame)
|
||||
# return temp_frame
|
||||
|
||||
# --- END OF FILE face_enhancer.py ---
|
||||
# --- END OF FILE face_enhancer.py ---
|
||||
|
||||
+1
-9
@@ -1,5 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu128
|
||||
|
||||
numpy>=1.23.5,<2
|
||||
typing-extensions>=4.8.0
|
||||
opencv-python==4.10.0.84
|
||||
@@ -10,15 +8,9 @@ psutil==5.9.8
|
||||
tk==0.1.0
|
||||
customtkinter==5.2.2
|
||||
pillow==12.1.1
|
||||
torch; sys_platform != 'darwin'
|
||||
torch==2.8.0+cu128; sys_platform == 'darwin'
|
||||
torchvision; sys_platform != 'darwin'
|
||||
torchvision==0.20.1; sys_platform == 'darwin'
|
||||
onnxruntime-silicon==1.16.3; sys_platform == 'darwin' and platform_machine == 'arm64'
|
||||
onnxruntime-gpu==1.22.0; sys_platform != 'darwin'
|
||||
onnxruntime-gpu==1.24.2; sys_platform != 'darwin'
|
||||
tensorflow; sys_platform != 'darwin'
|
||||
opennsfw2==0.10.2
|
||||
protobuf==4.25.1
|
||||
git+https://github.com/xinntao/BasicSR.git@master
|
||||
git+https://github.com/TencentARC/GFPGAN.git@master
|
||||
pygrabber
|
||||
|
||||
Reference in New Issue
Block a user