상세 컨텐츠

본문 제목

업스케일링 코드 (환경에 맞게 코드 최적화하기)

업스케일링

by zmo 2025. 11. 24. 17:12

본문

이 코드는 [https://github.com/xinntao/Real-ESRGAN]를 기반으로 수정되었습니다.

또한 RTX 3090 (NVIDIA GPU)에 최적화되어 있습니다.

NVIDIA 그래픽카드의 NVENC 기능으로 인코딩 속도를 높였으므로, NVIDIA GPU가 없다면 코드의 h264_nvenc 부분을 libx264로 바꿔서 사용하세요! 이때 '-preset', 'p4'는 NVENC 전용 프리셋이라 libx264에서는 오류가 나므로, 'fast'나 'medium' 같은 libx264 프리셋으로 함께 바꿔야 합니다.

 

설치 환경

 

Python 버전:  Python 3.11 환경에서 테스트되었습니다.

필수 라이브러리: torch, torchvision, basicsr, realesrgan, opencv-python (두 번째 스크립트는 선택적으로 ffmpeg-python을, 이미지 폴더를 입력할 때는 Pillow도 사용)

외부 프로그램: FFmpeg (설치 후 환경 변수 PATH에 등록)

 

 

오늘은 지난주 업스케일링에 70시간이 걸려 좌절했던 쓴 경험을 발판으로 삼아, 환경에 맞게 코드를 최적화해 볼 것이다. 이 글에는 코드만 올리고, 테스트와 수정을 거친 뒤(한 번에 잘되면 좋겠다) 다음 글에서 한 번 더 다룰 예정이다.

 

 

자동 업스케일링 스크립트

import cv2
import os
import sys
import subprocess
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

# --- Settings ---
input_video = "video.mp4"                # source video file name
output_video = "output_4k_fast.mp4"      # result file name
model_name = 'RealESRNet_x4plus'         # weights loaded from weights/<model_name>.pth; NOTE(review): same RRDBNet config as RealESRGAN_x4plus, so it is presumably not faster — confirm the "fast model" label
outscale = 2                             # final size = input size x 2 (e.g. 1920x1080 -> 3840x2160, i.e. 4K for a 1080p source)
tile_size = 0                            # 0 = no tiling: the whole frame is processed at once (needs plenty of VRAM; chosen for an RTX 3090)
# -----------

def main():
    """Upscale ``input_video`` with Real-ESRGAN and encode the result with FFmpeg.

    Frames are read with OpenCV, enhanced by ``RealESRGANer`` and streamed as raw
    BGR24 bytes into FFmpeg's stdin, so no intermediate image files are written.
    The first audio stream of the source, if there is one, is copied unchanged.
    Exits with a non-zero status when the input cannot be opened or FFmpeg fails.
    """
    # 1. Prepare the model
    model_path = os.path.join('weights', model_name + '.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # FP16 only on CUDA: half precision is not reliably supported for these conv
    # layers on the CPU, so the CPU fallback must stay in FP32.
    use_half = device.type == 'cuda'

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=model,
        tile=tile_size,
        tile_pad=10,
        pre_pad=0,
        half=use_half,
        device=device,
    )

    # 2. Read the video properties
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print(f"Cannot open input video: {input_video}")
        sys.exit(1)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps <= 0:
        # Some containers do not report a frame rate; '-r 0' would make FFmpeg fail.
        print("Warning: source FPS unknown, assuming 30")
        fps = 30.0

    # Every frame written to the pipe must have exactly this size.
    target_width = int(width * outscale)
    target_height = int(height * outscale)

    print(f"Input: {width}x{height} @ {fps}fps")
    print(f"Output Target: {target_width}x{target_height}")

    # 3. Start FFmpeg: raw frames on stdin (input 0) + the original file (input 1) for audio
    ffmpeg_cmd = [
        'ffmpeg',
        '-y',                                      # overwrite the output file
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', f'{target_width}x{target_height}',   # size of the piped frames
        '-pix_fmt', 'bgr24',                       # OpenCV frame layout
        '-r', str(fps),
        '-i', '-',                                 # input 0: frames from stdin
        '-i', input_video,                         # input 1: original file (audio source)
        '-map', '0:v:0',                           # video from the pipe
        '-map', '1:a:0?',                          # audio from the original; '?' = skip if there is none
        '-c:v', 'h264_nvenc',                      # NVIDIA GPU encoder (use libx264 without an NVIDIA GPU)
        '-preset', 'p4',                           # NVENC-only preset (libx264 needs e.g. 'fast' instead)
        '-b:v', '20M',                             # 20 Mbps video bitrate
        '-pix_fmt', 'yuv420p',                     # 4:2:0 output that common players can decode
        '-c:a', 'copy',                            # copy audio without re-encoding
        '-shortest',                               # stop at the shorter of video/audio
        output_video,
    ]

    process = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE)

    # 4. Frame loop
    frame_count = 0
    stopped_by_user = False
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # The network is x4; with outscale=2, enhance() resizes its x4 result down to x2.
            output, _ = upsampler.enhance(frame, outscale=outscale)

            # Stream the frame to FFmpeg (nothing is saved to disk)
            process.stdin.write(output.tobytes())

            frame_count += 1
            if frame_count % 10 == 0:
                print(f"Processing frame: {frame_count}/{total_frames}", end='\r')

    except KeyboardInterrupt:
        stopped_by_user = True
        print("\nStopped by user.")
    except OSError as exc:
        # BrokenPipeError (or EINVAL on Windows): FFmpeg exited early, e.g. NVENC unavailable.
        print(f"\nFFmpeg stopped accepting frames ({exc}); see its error output above.")
    finally:
        # 5. Clean up even when an exception escapes (e.g. CUDA out of memory)
        cap.release()
        try:
            process.stdin.close()
        except OSError:
            pass
        return_code = process.wait()

    if return_code != 0 and not stopped_by_user:
        print(f"\nFFmpeg failed with exit code {return_code}")
        sys.exit(return_code)
    print(f"\nDone! Saved to {output_video}")


if __name__ == '__main__':
    main()

 

 

업스케일링 처리

# -------------------------------------------------------------------------
# [한국어 설명]
# 이 코드는 Real-ESRGAN 프로젝트를 기반으로 수정되었습니다.
# 원본 저작권자: XPixelGroup (https://github.com/xinntao/Real-ESRGAN)
# 라이선스: BSD 3-Clause License
# -------------------------------------------------------------------------
# [Original License Notice]
# Copyright (c) 2021, XPixelGroup
# All rights reserved.
# -------------------------------------------------------------------------
# Based on Real-ESRGAN (https://github.com/xinntao/Real-ESRGAN)
# Licensed under BSD 3-Clause License
#
# Optimized by: [zmo]
# Key Changes:
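#   1. Added FFmpeg pipeline for direct video encoding (no intermediate images)
#   2. Used NVENC (h264_nvenc) for faster encoding on NVIDIA GPUs
#      (first script only; the Writer in this script encodes with libx264)
#   3. Automatic audio copy from source video
#      (first script only; the Writer in this script writes video without audio)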
# -------------------------------------------------------------------------

import cv2
import ...
# -*- coding: utf-8 -*-
"""
Optimized Real-ESRGAN video inference script (multi-part).
Features:
 - GPU auto-detect and use (CUDA)
 - fp16 support for speed
 - tile auto-heuristic
 - multiprocessing-safe fallback (ensures at least 1 process)
 - progress logging and ETA estimation
 - robust ffmpeg integration
 - compatible with drag-and-drop .bat wrappers
"""

import os
import sys
import time
import math
import shutil
import argparse
import subprocess
from pathlib import Path
from datetime import datetime
from typing import Optional, Tuple, List

import numpy as np
import cv2
import torch

# basicsr / realesrgan imports (ensure installed in venv)
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# Try import ffmpeg-python; if missing, instruct user rather than pip install in-script
try:
    import ffmpeg
except Exception:
    ffmpeg = None

# Multiprocessing
import multiprocessing as mp

# ---------------------------
# Utility / helper functions
# ---------------------------

def print_header():
    """Print a banner with the tool name and the local start time."""
    rule = "=" * 60
    started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    banner = (
        rule,
        "Real-ESRGAN Optimized Video Inference",
        f"Start time: {started_at}",
        rule,
    )
    for line in banner:
        print(line)


def get_device(prefer_gpu: bool = True) -> torch.device:
    """
    Choose where inference runs and report the choice on stdout.

    Returns ``cuda`` when ``prefer_gpu`` is set and CUDA is available,
    otherwise ``cpu``.
    """
    use_cuda = prefer_gpu and torch.cuda.is_available()
    if not use_cuda:
        print("⚠️ GPU not available or disabled — using CPU")
        return torch.device('cpu')

    chosen = torch.device('cuda')
    try:
        message = f"✅ Using GPU: {torch.cuda.get_device_name(0)}"
    except Exception:
        message = "✅ Using GPU (device name unavailable)"
    print(message)
    return chosen


def ensure_ffmpeg(ffmpeg_bin: str = 'ffmpeg') -> None:
    """
    Check that ffmpeg-python is importable and that the ffmpeg executable runs.

    Only prints installation hints: it never installs anything and never exits,
    because some fallback paths may still work without ffmpeg.

    Args:
        ffmpeg_bin: executable name or path to probe. Defaults to ``'ffmpeg'`` on
            PATH; pass ``args.ffmpeg_bin`` to validate a custom ``--ffmpeg_bin``.
    """
    if ffmpeg is None:
        print("⚠ ffmpeg-python not found. Please install it inside your venv:")
        print("    pip install ffmpeg-python")
    # check the ffmpeg binary
    try:
        result = subprocess.run([ffmpeg_bin, '-version'],
                                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # FileNotFoundError / PermissionError: missing or not executable.
        print("⚠ ffmpeg binary not found in PATH. Please install ffmpeg and add to PATH.")
        print("Example: https://www.gyan.dev/ffmpeg/builds/ (Windows builds)")
        # don't exit here; some fallback paths may still work via OpenCV
        return
    if result.returncode != 0:
        print(f"⚠ '{ffmpeg_bin} -version' exited with code {result.returncode}; the ffmpeg install may be broken.")


def download_model_if_missing(model_name: str, target_dir: str = 'weights') -> str:
    """
    Make sure ``<target_dir>/<model_name>.pth`` exists, downloading it when possible.

    Args:
        model_name: model name, with or without a ``.pth`` suffix.
        target_dir: directory holding the weights (created if missing).

    Returns:
        The path of the weights file. If the model is unknown or the download
        fails, the expected (possibly non-existent) path is still returned, so
        the caller fails later with a clear "file not found" error.
    """
    os.makedirs(target_dir, exist_ok=True)
    model_base = model_name.split('.pth')[0]
    model_path = os.path.join(target_dir, model_base + '.pth')
    if os.path.isfile(model_path):
        return model_path

    # Official release URLs. Each file name matches its model name, so the file we
    # download is exactly the one requested. The 'realesr-general-wdn-x4v3' weights
    # are deliberately not listed: upstream only uses them for DNI blending (not
    # implemented here), and fetching them first made that *different* model get
    # loaded in place of realesr-general-x4v3.
    known_urls = {
        'RealESRGAN_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRNet_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth',
        'RealESRGAN_x4plus_anime_6B': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        'RealESRGAN_x2plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        'realesr-animevideov3': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth',
        'realesr-general-x4v3': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    }

    url = known_urls.get(model_base)
    if url is None:
        print(f"Model {model_name} not found locally and no known URL. Please place the .pth in {target_dir}.")
        return model_path

    print(f"Model {model_name} not found. Attempting to download to {target_dir} ...")
    try:
        downloaded = load_file_from_url(url=url, model_dir=target_dir, progress=True, file_name=None)
    except Exception as e:  # network/HTTP/disk errors: report and fall through
        print(f"Download failed for {url}: {e}")
    else:
        if downloaded and os.path.isfile(downloaded):
            print(f"Downloaded model to {downloaded}")
            return downloaded
    print("Automatic download failed. Please manually download the model and place it under 'weights/'.")
    return model_path


def human_readable_time(seconds: Optional[float]) -> str:
    """
    Format a duration in seconds as '1h 2m 5s', '3m 7s' or '42s' (fractions truncated).

    Returns 'Unknown' for None (e.g. no ETA because the total frame count is
    unknown) and for non-finite values such as inf or NaN. Negative values,
    which an ETA produces once more frames are decoded than the container
    reported, are shown as '0s' instead of a garbled negative breakdown.
    """
    if seconds is None or not math.isfinite(seconds):
        return 'Unknown'
    total = max(0, int(seconds))
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    if h:
        return f"{h}h {m}m {s}s"
    if m:
        return f"{m}m {s}s"
    return f"{s}s"


# ---------------------------
# Argparse and defaults
# ---------------------------

def get_argparser() -> argparse.ArgumentParser:
    """
    Build the full CLI parser for video/image upscaling.

    Returns:
        An ``ArgumentParser`` whose namespace provides every attribute read by
        ``run``/``inference_video``/``build_model_and_upsampler``.
    """
    parser = argparse.ArgumentParser(prog='inference_realesrgan_video_opt',
                                     description='Optimized Real-ESRGAN video inference (GPU/CPU compatible)')
    parser.add_argument('-i', '--input', type=str, required=True, help='Input video or folder')
    parser.add_argument('-n', '--model_name', type=str, default='RealESRGAN_x4plus',
                        help='Model name (RealESRGAN_x4plus, RealESRNet_x4plus, RealESRGAN_x2plus, realesr-animevideov3, etc.)')
    parser.add_argument('-o', '--output', type=str, default=None, help='Output folder (default: <input>_upscaled)')
    parser.add_argument('-s', '--outscale', type=float, default=4.0, help='Upscale factor (default 4)')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix for resulting file')
    # 0 is passed straight to RealESRGANer, where it disables tiling; there is no auto-sizing.
    parser.add_argument('--tile', type=int, default=0,
                        help='Tile size (0 = no tiling; whole frame at once). Larger tile = less tiling overhead, more memory.')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding')
    # Kept for CLI compatibility; nothing in this script reads args.face_enhance.
    parser.add_argument('--face_enhance', action='store_true',
                        help='Use GFPGAN for face enhancement (currently ignored by this script)')
    parser.add_argument('--fp16', action='store_true', help='Use fp16 (half) inference for speed (requires GPU)')
    parser.add_argument('--fps', type=float, default=None, help='Force output FPS')
    parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='FFmpeg binary path')
    parser.add_argument('--num_process_per_gpu', type=int, default=1, help='Workers per GPU for multiprocessing split')
    parser.add_argument('--max_workers', type=int, default=4, help='Max total workers fallback if GPU count uncertain')
    parser.add_argument('--ext', type=str, default='mp4', help='Output extension')
    return parser



def build_model_and_upsampler(args, device: torch.device):
    """
    Build the network for ``args.model_name`` and wrap it in a RealESRGANer.

    Weights are looked up under ``weights/`` and downloaded when missing.

    Args:
        args: parsed CLI namespace; reads model_name, outscale, fp16, tile,
            tile_pad and pre_pad.
        device: torch device to run on. FP16 is only enabled on CUDA.

    Returns:
        (upsampler, netscale): the RealESRGANer and the network's native scale.
        Callers pass ``args.outscale`` to ``upsampler.enhance`` for the final size.
    """
    model_name = args.model_name.split('.pth')[0]
    # Architectures must match the released weights exactly; the scale is the
    # scale those weights were trained for, not the requested outscale.
    if model_name in ('RealESRGAN_x4plus', 'RealESRNet_x4plus'):
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif model_name == 'realesr-animevideov3':
        # x4 VGG-style model (XS size: 16 conv layers)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    elif model_name == 'realesr-general-x4v3':
        # x4 VGG-style model (S size: 32 conv layers). Only the main weights are used
        # (no DNI blend with the wdn variant), i.e. full-strength denoising.
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
    elif model_name.startswith('realesr-'):
        # Other/custom realesr-* weights: assume a 16-conv SRVGG trained for the requested scale.
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=int(args.outscale), act_type='prelu')
        netscale = int(args.outscale)
    else:
        # Unknown name: assume custom RRDBNet weights trained for the requested scale.
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=int(args.outscale))
        netscale = int(args.outscale)

    # ensure model weights exist (may download)
    model_path = download_model_if_missing(model_name)

    # DNI (weight blending) is not supported; a single model_path is used.
    half = args.fp16 and device.type == 'cuda'
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=None,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=half,
        device=device
    )
    return upsampler, netscale


class Reader:
    """
    Frame source for one (sub-)video or an image sequence.

    Videos are decoded by an ffmpeg subprocess into raw BGR24 frames on its
    stdout. Images (a single file, or every .png/.jpg/.jpeg/.bmp in a folder,
    sorted by path) are loaded with OpenCV. ``total_workers``/``worker_idx`` are
    only stored: when multiprocessing, the caller passes an already-split clip.
    """
    def __init__(self, args, video_path: str, total_workers: int = 1, worker_idx: int = 0):
        self.args = args
        self.video_path = video_path
        self.total_workers = total_workers
        self.worker_idx = worker_idx

        # detect input type by MIME type / extension
        mime = None
        try:
            import mimetypes
            mime = mimetypes.guess_type(video_path)[0]
        except Exception:
            mime = None

        self.is_video = (mime is not None and mime.startswith('video')) or (str(video_path).lower().endswith(('.mp4', '.mov', '.mkv', '.avi', '.flv')))
        self.stream_proc = None
        self.width = None
        self.height = None
        self.nb_frames = None
        self.fps = None
        self.audio = None

        if self.is_video:
            # video_path may already be a sub-clip produced by the caller.
            # NOTE(review): get_video_meta_info() is not defined anywhere in the visible
            # source; it must return 'width', 'height', 'fps' (+ optional 'audio',
            # 'nb_frames') — confirm it exists, otherwise this raises NameError.
            meta = get_video_meta_info(video_path)
            self.width = meta['width']
            self.height = meta['height']
            self.fps = meta['fps']
            self.audio = meta.get('audio', None)
            self.nb_frames = meta.get('nb_frames', None)
            # decode to raw BGR24 frames on stdout
            cmd = [
                args.ffmpeg_bin,
                '-i', str(video_path),
                '-f', 'rawvideo',
                '-pix_fmt', 'bgr24',
                '-vsync', '0',
                '-loglevel', 'error',
                'pipe:1'
            ]
            # stderr is inherited rather than piped: nobody reads it, and a full
            # stderr pipe would block ffmpeg. Inherited, its errors stay visible.
            self.stream_proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        else:
            # image folder or single image
            if os.path.isdir(video_path):
                paths = sorted([os.path.join(video_path, p) for p in os.listdir(video_path)])
                self.paths = [p for p in paths if p.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
            else:
                self.paths = [video_path]
            if len(self.paths) == 0:
                raise RuntimeError("No images found in input")
            from PIL import Image
            # Close the file handle; only the header is needed for the size.
            with Image.open(self.paths[0]) as first_image:
                self.width, self.height = first_image.size
            self.nb_frames = len(self.paths)
            self.fps = args.fps or 24

        self._idx = 0

    def read_frame(self) -> Optional[np.ndarray]:
        """Return the next frame as an (H, W, 3) uint8 BGR array, or None at the end of input."""
        if self.is_video:
            raw_size = self.width * self.height * 3
            data = self.stream_proc.stdout.read(raw_size)
            # A short read means EOF (or a truncated last frame): stop.
            if not data or len(data) < raw_size:
                return None
            return np.frombuffer(data, dtype=np.uint8).reshape((self.height, self.width, 3))
        # Skip unreadable images instead of ending the whole sequence at the first one.
        while self._idx < len(self.paths):
            path = self.paths[self._idx]
            self._idx += 1
            img = cv2.imread(path)
            if img is not None:
                return img
            print(f"⚠ Skipping unreadable image: {path}")
        return None

    def close(self):
        """Stop the ffmpeg decoder (if any) and reap it so no zombie process is left."""
        if self.stream_proc is None:
            return
        proc = self.stream_proc
        try:
            proc.stdout.close()
            proc.terminate()
        except OSError:
            pass
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()


class Writer:
    """
    Encode frames into an H.264 video by piping raw BGR24 bytes into ffmpeg (libx264).

    The encoded size is ``int(width * outscale)`` x ``int(height * outscale)``;
    :meth:`write_frame` resizes any frame that does not match it. No audio is written.
    """
    def __init__(self, args, out_path: str, width: int, height: int, fps: float, outscale: float):
        out_w = int(width * outscale)
        out_h = int(height * outscale)
        self.args = args
        self.out_path = out_path
        self.width = out_w
        self.height = out_h
        self.fps = fps

        # ffmpeg reads rawvideo from stdin and writes H.264 (yuv420p, CRF 18)
        cmd = [
            args.ffmpeg_bin,
            '-y',
            '-loglevel', 'error',  # errors only: no progress lines interleaved across workers
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{out_w}x{out_h}',
            '-r', str(fps),
            '-i', 'pipe:0',
            # yuv420p needs an even width/height: pad by one pixel when a fractional
            # outscale yields an odd size (no-op for even sizes).
            '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-preset', 'fast',
            '-crf', '18',
            self.out_path
        ]
        # Only stdin is piped. stdout/stderr are inherited: an unread pipe fills up,
        # ffmpeg then blocks writing to it, stops draining stdin, and write_frame()
        # would hang forever.
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)

    def write_frame(self, frame: np.ndarray):
        """
        Send one HWC BGR frame to the encoder.

        Non-uint8 input is clipped to [0, 255] before conversion so out-of-range
        values saturate instead of wrapping around. A frame whose size differs
        from the configured output size is resized first; otherwise its byte
        count would shift every following frame in the raw stream.
        """
        if frame.dtype != np.uint8:
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        if frame.shape[:2] != (self.height, self.width):
            frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_LANCZOS4)
        self.proc.stdin.write(frame.tobytes())

    def close(self):
        """Close ffmpeg's stdin (end of stream) and wait until it has finished the file."""
        try:
            self.proc.stdin.close()
        except OSError:
            # ffmpeg already exited (broken pipe); its error is on the console.
            pass
        return_code = self.proc.wait()
        if return_code != 0:
            print(f"⚠ ffmpeg exited with code {return_code} while writing {self.out_path}")


def inference_video(args, video_input_path: str, video_save_path: str, device: Optional[torch.device] = None,
                    total_workers: int = 1, worker_idx: int = 0):
    """
    Upscale one video, sub-video or image sequence from start to finish.

    Frames come from ``Reader``, are enhanced by Real-ESRGAN and are encoded by
    ``Writer`` into ``video_save_path``. The reader and writer are always closed,
    so no ffmpeg subprocess is left running when an error escapes.

    Args:
        args: parsed CLI namespace (model, tile, fp16, outscale and ffmpeg options;
            ``fps`` is optional).
        video_input_path: input video, image file or image folder.
        video_save_path: output video path; its parent directory is created.
        device: torch device; auto-detected (GPU preferred) when None.
        total_workers: number of parallel workers, forwarded to Reader.
        worker_idx: this worker's index, used as the log prefix.
    """
    reader = Reader(args, video_input_path, total_workers=total_workers, worker_idx=worker_idx)
    writer = None
    processed = 0
    try:
        height, width = reader.height, reader.width
        # --fps is documented as "Force output FPS", so it overrides the source rate.
        # getattr(): not every CLI parser in this file defines --fps.
        fps = getattr(args, 'fps', None) or reader.fps or 24
        # device fallback
        if device is None:
            device = get_device(prefer_gpu=True)
        upsampler, _ = build_model_and_upsampler(args, device)

        os.makedirs(os.path.dirname(video_save_path) or '.', exist_ok=True)
        writer = Writer(args, video_save_path, width, height, fps, args.outscale)
        # Size every written frame must have (same formula as Writer).
        out_size = (int(width * args.outscale), int(height * args.outscale))

        total_frames = reader.nb_frames or 0
        t0 = time.time()
        while True:
            frame = reader.read_frame()
            if frame is None:
                break
            try:
                # enhance() returns (output_image, image_mode)
                output, _ = upsampler.enhance(frame, outscale=args.outscale)
            except RuntimeError as e:
                print(f"[worker {worker_idx}] RuntimeError during enhance: {e}. Writing resized original frame.")
                # Keep the raw stream's frame size constant: a smaller frame would
                # misalign every following frame in the encoder's input.
                output = cv2.resize(frame, out_size, interpolation=cv2.INTER_CUBIC)
            writer.write_frame(output)
            processed += 1
            # simple ETA print every 10 frames
            if processed % 10 == 0 or (total_frames and processed == total_frames):
                elapsed = time.time() - t0
                per_frame = elapsed / processed
                remaining = (total_frames - processed) * per_frame if total_frames else None
                print(f"[worker {worker_idx}] frames {processed}/{total_frames or '?'} elapsed {human_readable_time(elapsed)} ETA {human_readable_time(remaining)}")
    finally:
        reader.close()
        if writer is not None:
            writer.close()

    # sync cuda if used
    if device is not None and device.type == 'cuda' and torch.cuda.is_available():
        try:
            torch.cuda.synchronize(device)
        except Exception:
            pass

    print(f"[worker {worker_idx}] finished. processed {processed} frames in {human_readable_time(time.time()-t0)}")

# ---------------------------
# multiprocessing split/merge, safe run(), temp handling
# ---------------------------

def split_video_for_workers(args, input_path: str, output_dir: str, num_workers: int) -> List[str]:
    """
    Split the input video into ``num_workers`` consecutive clips by time.

    Clips are re-encoded, not stream-copied: with ``-c copy`` ffmpeg can only cut
    on keyframes, so neighbouring clips overlap or leave gaps and the merged
    result repeats or drops frames. x264 with ``-qp 0`` is lossless, so the
    upscaler still sees the original pixels. Audio is dropped (``-an``) because
    the Writer encodes video only.

    Returns:
        Clip paths (length == num_workers), or ``[input_path]`` when num_workers <= 1.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails to produce a clip.
    """
    if num_workers <= 1:
        return [input_path]

    ensure_ffmpeg()
    # NOTE(review): get_video_meta_info() is not defined anywhere in the visible
    # source; it must return 'fps' and 'nb_frames' — confirm it exists.
    meta = get_video_meta_info(input_path)
    duration = meta['nb_frames'] / meta['fps']
    # Fractional split points: integer seconds would give 0-length parts to every
    # worker but the last whenever the video is shorter than num_workers seconds.
    part_seconds = duration / num_workers
    os.makedirs(output_dir, exist_ok=True)
    parts = []
    for i in range(num_workers):
        out_path = os.path.join(output_dir, f'part_{i:03d}.mp4')
        cmd = [args.ffmpeg_bin, '-y', '-i', input_path, '-ss', f'{part_seconds * i:.6f}']
        if i != num_workers - 1:  # the last part runs to the end of the input
            cmd += ['-to', f'{part_seconds * (i + 1):.6f}']
        cmd += ['-an', '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast', out_path]
        print("Splitting:", " ".join(cmd))
        # Argument list without a shell: paths with spaces or shell metacharacters stay intact.
        subprocess.run(cmd, check=True)
        parts.append(out_path)
    return parts


def combine_subvideos(args, out_dir: str, video_name: str, num_workers: int, final_path: str):
    """
    Concatenate the worker outputs ``000.mp4 … NNN.mp4`` in ``out_dir`` into ``final_path``.

    Uses ffmpeg's concat demuxer with stream copy, so all parts must share codec
    parameters (presumably true, since every part comes from the same encoder
    settings). The temporary list file is removed afterwards, even on failure.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails to merge the parts.
    """
    list_file = os.path.join(out_dir, f'{video_name}_vidlist.txt')
    with open(list_file, 'w', encoding='utf-8') as f:
        for i in range(num_workers):
            # The concat demuxer resolves relative entries against the list file's own
            # directory, so a relative out_dir would be applied twice: use absolute paths.
            part = os.path.abspath(os.path.join(out_dir, f'{i:03d}.mp4'))
            # Inside single quotes only the quote itself is special; escape it as '\''.
            escaped = part.replace("'", "'\\''")
            f.write(f"file '{escaped}'\n")
    cmd = [args.ffmpeg_bin, '-y', '-f', 'concat', '-safe', '0', '-i', list_file, '-c', 'copy', final_path]
    print("Combining:", " ".join(cmd))
    try:
        # No shell: the argument list is passed to ffmpeg verbatim.
        subprocess.run(cmd, check=True)
    finally:
        try:
            os.remove(list_file)
        except OSError:
            pass


def safe_remove(path: str):
    """Best-effort deletion of a file or a whole directory tree; every error is swallowed."""
    try:
        remover = None
        if os.path.isdir(path):
            remover = shutil.rmtree
        elif os.path.isfile(path):
            remover = os.remove
        if remover is not None:
            remover(path)
    except Exception:
        pass


def run(args):
    """
    Upscale ``args.input`` and write ``<output>/<video_name>_<suffix>.<ext>``.

    Image inputs, CPU-only machines and single-worker setups run in this process.
    With more than one worker (``num_gpus * num_process_per_gpu``) the video is
    split into clips, each clip is upscaled in a spawned worker process (GPUs
    assigned round-robin), and the results are concatenated. Temporary clip
    folders are removed only after a successful merge.

    Raises:
        Exception: whatever a worker raised, re-raised here so that a failed clip
            is never silently merged into the final video.
    """
    args.video_name = os.path.splitext(os.path.basename(args.input))[0]
    if args.output is None:
        args.output = os.path.join(os.path.dirname(args.input), f"{args.video_name}_upscaled")
    os.makedirs(args.output, exist_ok=True)
    video_save_path = os.path.join(args.output, f'{args.video_name}_{args.suffix}.{args.ext}')

    # Folder / image inputs always run as a single worker.
    input_mime = None
    try:
        import mimetypes
        input_mime = mimetypes.guess_type(args.input)[0]
    except Exception:
        input_mime = None

    is_video = (input_mime is not None and input_mime.startswith('video')) or str(args.input).lower().endswith(('.mp4', '.mov', '.mkv', '.avi', '.flv'))

    if not is_video:
        # images / folder: single worker inference
        print("Input is not a video; running single-worker inference.")
        inference_video(args, args.input, video_save_path, device=get_device(prefer_gpu=True), total_workers=1, worker_idx=0)
        return

    # For videos: compute GPUs and desired processes
    num_gpus = torch.cuda.device_count()
    try:
        num_process = int(num_gpus) * int(args.num_process_per_gpu)
    except (AttributeError, TypeError, ValueError):
        num_process = None

    if num_process is None or num_process <= 1:
        # A single worker gains nothing from split + re-encode + concat: process the
        # input directly (GPU 0 if available).
        print("Falling back to single-process inference (no multiprocessing).")
        device = torch.device('cuda') if (num_gpus > 0 and torch.cuda.is_available()) else None
        inference_video(args, args.input, video_save_path, device=device, total_workers=1, worker_idx=0)
        return

    # Multiprocessing path: split video to subclips for each worker
    tmp_in_dir = os.path.join(args.output, f'{args.video_name}_inp_tmp_videos')
    tmp_out_dir = os.path.join(args.output, f'{args.video_name}_out_tmp_videos')
    safe_remove(tmp_in_dir)
    safe_remove(tmp_out_dir)
    os.makedirs(tmp_in_dir, exist_ok=True)
    os.makedirs(tmp_out_dir, exist_ok=True)

    subclips = split_video_for_workers(args, args.input, tmp_in_dir, num_process)

    # 'spawn' so each worker initialises CUDA in a fresh interpreter.
    ctx = mp.get_context('spawn')
    pool = ctx.Pool(processes=num_process)
    results = []
    try:
        for i, sc in enumerate(subclips):
            sub_save = os.path.join(tmp_out_dir, f'{i:03d}.mp4')
            device = torch.device(i % num_gpus) if (num_gpus > 0 and torch.cuda.is_available()) else None
            print(f"Launching worker {i} on device {device} for {sc} -> {sub_save}")
            res = pool.apply_async(inference_video, args=(args, sc, sub_save, device, num_process, i))
            results.append(res)
        pool.close()
        # .get() re-raises a worker's exception in this process; without it a crashed
        # worker would go unnoticed and the merge below would use a missing/partial clip.
        for res in results:
            res.get()
        pool.join()
    except KeyboardInterrupt:
        print("KeyboardInterrupt received: terminating pool.")
        pool.terminate()
        pool.join()
        raise
    except Exception as e:
        print("Exception in multiprocessing:", e)
        pool.terminate()
        pool.join()
        raise

    # Combine subvideos
    combine_subvideos(args, tmp_out_dir, args.video_name, num_process, video_save_path)

    # cleanup temp dirs
    safe_remove(tmp_in_dir)
    safe_remove(tmp_out_dir)

    print("All done. Output:", video_save_path)


# ---------------------------
# CLI entrypoint, main()
# ---------------------------

def main():
    """
    Command-line entry point: parse arguments, print a summary banner and call run().

    Exits with status 1 on an unexpected error (after printing the traceback) and
    with 130 when interrupted by Ctrl+C, so wrapper scripts can detect failures.
    """
    import traceback

    parser = argparse.ArgumentParser()

    # Required
    parser.add_argument('-i', '--input', type=str, required=True, help='Input video or image folder path')

    # Model
    parser.add_argument('-n', '--model_name', type=str, default='RealESRGAN_x4plus',
                        help='Model name: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B etc.')

    # Output. --outscale and --fps are read by inference_video(); without them the
    # run crashed with AttributeError.
    parser.add_argument('-o', '--output', type=str, default=None, help='Output directory (optional)')
    parser.add_argument('--suffix', type=str, default='upscaled', help='Output filename suffix')
    parser.add_argument('--ext', type=str, default='mp4', help='Output extension (mp4, mov…)')
    parser.add_argument('-s', '--outscale', type=float, default=4.0, help='Final upscale factor (default 4)')
    parser.add_argument('--fps', type=float, default=None, help='Force output FPS (default: source FPS)')

    # GPU / multiprocessing
    parser.add_argument('--num_process_per_gpu', type=int, default=1,
                        help='Number of processes per GPU (default=1)')
    parser.add_argument('--fp16', action='store_true', help='Use half precision')
    parser.add_argument('--tile', type=int, default=0, help='Tile size (0 means no tiling)')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    # Read by build_model_and_upsampler(); missing before, which also crashed the run.
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding')
    # NOTE(review): --max_tile_size is parsed but not used anywhere in this file.
    parser.add_argument('--max_tile_size', type=int, default=400, help='Max tile size for safe FP16 inference')

    # FFmpeg
    parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='Path to ffmpeg binary')

    args = parser.parse_args()

    print("=" * 55)
    print("Real-ESRGAN 비디오 업스케일링 시작")
    print("=" * 55)
    print("-----------------------------------------------")
    print("입력 파일:", args.input)
    print("모델:", args.model_name)
    print("출력 폴더:", args.output if args.output else "(자동 생성됨)")
    print("시작 시각:", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print("-----------------------------------------------")

    try:
        run(args)
    except KeyboardInterrupt:
        print("\n❌ 사용자에 의해 작업 중단됨 (Ctrl+C)")
        sys.exit(130)  # conventional exit status for SIGINT
    except Exception as e:
        print("❌ 오류 발생!")
        print(e)
        traceback.print_exc()  # keep the stack trace for debugging
        sys.exit(1)


if __name__ == '__main__':
    main()

'업스케일링' 카테고리의 다른 글

업스케일링 최적화 코드 오류 해결하기(1)  (0) 2025.11.30
업스케일링 해보기(2)  (0) 2025.11.15

관련글 더보기