1
0
Files
ContentGeneration/src/generate_videos.py

216 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""Generate shot videos with HunyuanVideo based on reel script and audio durations."""
from __future__ import annotations
import argparse
import contextlib
import json
import logging
import os
import runpy
import subprocess
import sys
from pathlib import Path
import torch
from torch.distributed.run import main as torch_run
from src.logging_config import configure_logging, debug_log_lifecycle
# Directory that holds this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Project root: two directory levels above SCRIPT_DIR. Every default
# input/output location below hangs off it.
DEFAULT_BASE_DIR = SCRIPT_DIR.parents[1]

# HunyuanVideo checkout; main() uses its "ckpts" subfolder as the model path.
DEFAULT_HUNYUAN_DIR = DEFAULT_BASE_DIR.joinpath("HunyuanVideo-1.5")
# Reel script JSON; main() reads its top-level "shots" list.
DEFAULT_REEL_SCRIPT = DEFAULT_BASE_DIR.joinpath("reel_script.json")
# Per-shot input images (shot_<n>.png).
DEFAULT_IMAGES_DIR = DEFAULT_BASE_DIR.joinpath("images")
# Generated clips (output_<n>.mp4).
DEFAULT_VIDEOS_DIR = DEFAULT_BASE_DIR.joinpath("videos")
# Per-shot audio (output_<n>.mp3) whose duration sizes each clip.
DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR.joinpath("audios")

LOGGER = logging.getLogger(__name__)
@contextlib.contextmanager
def _temporary_environ(update: dict[str, str]):
previous: dict[str, str | None] = {key: os.environ.get(key) for key in update}
os.environ.update(update)
try:
yield
finally:
for key, value in previous.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def _run_hunyuan_generate_in_process(
    hunyuan_dir: Path,
    argv: list[str],
    env_update: dict[str, str],
) -> None:
    """Run torchrun in-process with *argv*, using *hunyuan_dir* as the working directory.

    Args:
        hunyuan_dir: Directory to ``chdir`` into for the duration of the run,
            so relative paths in *argv* (e.g. ``generate.py``) resolve against it.
        argv: torchrun arguments WITHOUT a leading program name, e.g.
            ``["--nproc_per_node=1", "generate.py", "--prompt", ...]``.
        env_update: Environment variables applied for the duration of the run
            and restored afterwards.

    The process-wide ``sys.argv``, working directory and environment are all
    restored even if torchrun raises.
    """
    saved_argv = sys.argv[:]
    saved_cwd = Path.cwd()
    try:
        with _temporary_environ(env_update):
            os.chdir(hunyuan_dir)
            # Keep a program name in sys.argv[0] for anything that inspects it.
            sys.argv = ["torchrun", *argv]
            # Pass the arguments explicitly. torch_run(None) falls back to
            # argparse's default of sys.argv[1:], which would skip argv[0]
            # (the --nproc_per_node flag) if argv were stored in sys.argv as-is.
            torch_run(argv)
    finally:
        sys.argv = saved_argv
        os.chdir(saved_cwd)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the video generator.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        Namespace with ``hunyuan_dir``, ``reel_script``, ``images_dir``,
        ``videos_dir``, ``audios_dir`` (all ``Path``), ``seed`` (int) and
        ``log_level`` (str or None).
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Path-valued options, registered in the order they appear in --help.
    path_options = (
        ("--hunyuan-dir", DEFAULT_HUNYUAN_DIR),
        ("--reel-script", DEFAULT_REEL_SCRIPT),
        ("--images-dir", DEFAULT_IMAGES_DIR),
        ("--videos-dir", DEFAULT_VIDEOS_DIR),
        ("--audios-dir", DEFAULT_AUDIOS_DIR),
    )
    for flag, default_path in path_options:
        cli.add_argument(flag, type=Path, default=default_path)

    cli.add_argument("--seed", type=int, default=1)
    cli.add_argument(
        "--log-level",
        default=None,
        help="Logging level (overrides LOG_LEVEL env var)",
    )
    return cli.parse_args(argv)
@debug_log_lifecycle
def get_audio_duration(audio_path: Path) -> float:
    """Return the container duration of *audio_path* in seconds, as reported by ffprobe.

    Raises:
        subprocess.CalledProcessError: ffprobe exited with a non-zero status.
        FileNotFoundError: the ``ffprobe`` executable could not be found.
        ValueError: ffprobe's output was not a number.
    """
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        # Print the bare value only: no section wrappers, no "duration=" key.
        "-of", "default=noprint_wrappers=1:nokey=1",
        str(audio_path),
    ]
    completed = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
    raw_seconds = completed.stdout.strip()
    return float(raw_seconds)
@debug_log_lifecycle
def duration_to_video_length(
    duration: float,
    *,
    fps: int = 24,
    min_frames: int = 49,
    max_frames: int = 169,
    frame_step: int = 4,
) -> int:
    """Convert a duration in seconds into a frame count for ``--video_length``.

    The count starts at ``int(duration * fps) + 1`` frames, which is enough to
    cover *duration* at *fps*. It is rounded UP so that ``frames - 1`` is a
    multiple of *frame_step*, then clamped to ``[min_frames, max_frames]``.

    NOTE(review): the default ``frame_step=4`` assumes generate.py expects
    frame counts of the form ``4k + 1``. The default clamp bounds 49 and 169
    both have that form; confirm against HunyuanVideo's generate.py. The
    previous rule only forced the count to be odd, which is what
    ``frame_step=2`` reproduces.

    Args:
        duration: Seconds the clip should cover.
        fps: Frames per second used for the seconds-to-frames conversion.
        min_frames: Inclusive lower bound on the result.
        max_frames: Inclusive upper bound on the result. Both bounds should
            themselves satisfy the step constraint, otherwise a clamped
            result will not.
        frame_step: Required spacing of ``frames - 1``. Must be >= 1.

    Returns:
        The frame count to pass as ``--video_length``.

    Raises:
        ValueError: if *frame_step* is less than 1.
    """
    if frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")
    frames = int(duration * fps) + 1
    # Round up rather than down so the clip never ends before the audio does.
    remainder = (frames - 1) % frame_step
    if remainder:
        frames += frame_step - remainder
    return max(min_frames, min(frames, max_frames))
@debug_log_lifecycle
def main(argv: list[str] | None = None) -> int:
    """Render one HunyuanVideo clip per shot listed in the reel script.

    For each entry in the script's ``shots`` list, this looks up
    ``<images-dir>/shot_<n>.png`` and ``<audios-dir>/output_<n>.mp3``. It sizes
    the clip from the audio duration, falling back to 5 s when the audio is
    missing, and runs HunyuanVideo's ``generate.py`` through torchrun to write
    ``<videos-dir>/output_<n>.mp4``. A shot is skipped when its output already
    exists or its image is missing.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status. It is always 0; failures propagate as exceptions.
    """
    args = parse_args(argv)
    configure_logging(args.log_level)

    # generate.py runs with the HunyuanVideo checkout as its working directory.
    # Resolve every path handed to it, so that relative CLI paths still point
    # where the user meant after the chdir.
    hunyuan_dir = args.hunyuan_dir.resolve()
    images_dir = args.images_dir.resolve()
    videos_dir = args.videos_dir.resolve()
    audios_dir = args.audios_dir.resolve()
    model_path = hunyuan_dir / "ckpts"
    videos_dir.mkdir(parents=True, exist_ok=True)

    env = os.environ.copy()
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
    # Loop-invariant: os.environ is restored after every generation run.
    run_env = _single_process_dist_env(env)

    data = json.loads(args.reel_script.read_text(encoding="utf-8"))
    shots = data.get("shots", [])
    LOGGER.info("Found %s shots to generate", len(shots))

    for shot in shots:
        shot_number = shot["shot_number"]
        prompt = str(shot["image_description"]).replace("\t", " ").replace("\n", " ")
        image_path = images_dir / f"shot_{shot_number}.png"
        output_path = videos_dir / f"output_{shot_number}.mp4"
        audio_path = audios_dir / f"output_{shot_number}.mp3"

        # Evaluate the skip conditions before probing audio, so shots that will
        # not be rendered never require ffprobe.
        if output_path.exists():
            LOGGER.info("Output %s already exists, skipping shot %s", output_path, shot_number)
            continue
        if not image_path.exists():
            LOGGER.warning("Image not found at %s, skipped", image_path)
            continue

        if audio_path.exists():
            duration = get_audio_duration(audio_path)
            LOGGER.info("Audio duration for shot %s: %ss", shot_number, duration)
        else:
            LOGGER.warning("No audio found at %s, falling back to 5s default", audio_path)
            duration = 5.0
        video_length = duration_to_video_length(duration)

        LOGGER.info("Shot %s | %ss -> %s frames", shot_number, duration, video_length)
        LOGGER.info("Prompt: %s", prompt)
        LOGGER.info("Image: %s", image_path)
        LOGGER.info("Audio: %s", audio_path)
        LOGGER.info("Output: %s", output_path)

        # NOTE(review): generation runs inside torchrun-launched workers, so
        # this process may never initialise CUDA. In that case empty_cache()
        # does nothing; confirm whether this call is still needed.
        with _temporary_environ(env):
            torch.cuda.empty_cache()
        LOGGER.info("GPU cache cleared")

        _run_hunyuan_generate_in_process(
            hunyuan_dir=hunyuan_dir,
            argv=_build_generate_argv(
                prompt=prompt,
                image_path=image_path,
                seed=args.seed,
                video_length=video_length,
                output_path=output_path,
                model_path=model_path,
            ),
            env_update=run_env,
        )
        LOGGER.info("Shot %s done", shot_number)

    LOGGER.info("Done")
    return 0


def _build_generate_argv(
    *,
    prompt: str,
    image_path: Path,
    seed: int,
    video_length: int,
    output_path: Path,
    model_path: Path,
) -> list[str]:
    """Return the torchrun argument list (no program name) that runs generate.py for one shot."""
    return [
        "--nproc_per_node=1",
        "generate.py",
        "--prompt",
        prompt,
        "--image_path",
        str(image_path),
        "--resolution",
        "480p",
        "--aspect_ratio",
        "16:9",
        "--seed",
        str(seed),
        "--video_length",
        str(video_length),
        "--rewrite",
        "false",
        "--cfg_distilled",
        "true",
        "--enable_step_distill",
        "true",
        "--sparse_attn",
        "false",
        "--use_sageattn",
        "true",
        "--enable_cache",
        "false",
        "--overlap_group_offloading",
        "true",
        "--sr",
        "false",
        "--output_path",
        str(output_path),
        "--model_path",
        str(model_path),
    ]


def _single_process_dist_env(base_env: dict[str, str]) -> dict[str, str]:
    """Return a copy of *base_env* plus rank/world-size/rendezvous variables for one local process."""
    return {
        **base_env,
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        # Use the caller's MASTER_PORT when set, 29500 otherwise.
        "MASTER_PORT": os.environ.get("MASTER_PORT", "29500"),
    }
if __name__ == "__main__":
    # main() returns the exit status; sys.exit() raises SystemExit with it.
    sys.exit(main())