From 17ac03372980f7437a407723219adb4637a3d5be Mon Sep 17 00:00:00 2001 From: Artsiom Siamashka Date: Fri, 3 Apr 2026 18:34:21 +0200 Subject: [PATCH] Run video generation model from python instead of calling subprocess --- Dockerfile | 1 + src/generate_videos.py | 138 ++++++++++++++++++++++++++--------------- 2 files changed, 89 insertions(+), 50 deletions(-) diff --git a/Dockerfile b/Dockerfile index c941dc8..998387a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive \ PYTHONUNBUFFERED=1 \ UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu121 \ + UV_NO_SYNC=true \ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128 # Base OS tools + media stack + Python toolchain. diff --git a/src/generate_videos.py b/src/generate_videos.py index f568948..6da195f 100644 --- a/src/generate_videos.py +++ b/src/generate_videos.py @@ -4,12 +4,17 @@ from __future__ import annotations import argparse +import contextlib import json import logging import os +import runpy import subprocess +import sys from pathlib import Path +import torch + from src.logging_config import configure_logging, debug_log_lifecycle @@ -24,6 +29,41 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios" LOGGER = logging.getLogger(__name__) +@contextlib.contextmanager +def _temporary_environ(update: dict[str, str]): + previous: dict[str, str | None] = {key: os.environ.get(key) for key in update} + os.environ.update(update) + try: + yield + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +def _run_hunyuan_generate_in_process( + hunyuan_dir: Path, + argv: list[str], + env_update: dict[str, str], +) -> None: + generate_script = hunyuan_dir / "generate.py" + if not generate_script.exists(): + raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}") + + old_argv = sys.argv[:] + old_cwd = Path.cwd() + try: + with _temporary_environ(env_update): + os.chdir(hunyuan_dir) + sys.argv = ["generate.py", *argv] + runpy.run_path(str(generate_script), run_name="__main__") + finally: + sys.argv = old_argv + os.chdir(old_cwd) + + def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR) @@ -114,58 +154,56 @@ def main(argv: list[str] | None = None) -> int: LOGGER.warning("Image not found at %s, skipped", image_path) continue - subprocess.run( - [ - "python3", - "-c", - "import torch; torch.cuda.empty_cache()", - ], - check=True, - env=env, - ) + with _temporary_environ(env): + torch.cuda.empty_cache() LOGGER.info("GPU cache cleared") - subprocess.run( - [ - "torchrun", - "--nproc_per_node=1", - "generate.py", - "--prompt", - prompt, - "--image_path", - str(image_path), - "--resolution", - "480p", - "--aspect_ratio", - "16:9", - "--seed", - str(args.seed), - "--video_length", - str(video_length), - "--rewrite", - "false", - "--cfg_distilled", - "true", - "--enable_step_distill", - "true", - "--sparse_attn", - "false", - "--use_sageattn", - "true", - "--enable_cache", - "false", - "--overlap_group_offloading", - "true", - "--sr", - "false", - "--output_path", - str(output_path), - "--model_path", - str(model_path), - ], - check=True, - cwd=args.hunyuan_dir, - env=env, + run_argv = [ + "--prompt", + prompt, + "--image_path", + str(image_path), + "--resolution", + "480p", + "--aspect_ratio", + "16:9", + "--seed", + str(args.seed), + "--video_length", + str(video_length), + "--rewrite", + "false", + "--cfg_distilled", + "true", + "--enable_step_distill", + "true", + "--sparse_attn", + "false", + "--use_sageattn", + "true", + "--enable_cache", + "false", + "--overlap_group_offloading", + "true", + "--sr", + "false", + "--output_path", + str(output_path), + "--model_path", + str(model_path), + ] + run_env = { + **env, + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": os.environ.get("MASTER_PORT", "29500"), + } + _run_hunyuan_generate_in_process( + hunyuan_dir=args.hunyuan_dir, + argv=run_argv, + env_update=run_env, ) LOGGER.info("Shot %s done", shot_number)