1
0

Run video generation model from python instead of calling subprocess

This commit is contained in:
2026-04-03 18:34:21 +02:00
parent a4059aa4f8
commit 17ac033729
2 changed files with 89 additions and 50 deletions

View File

@@ -3,6 +3,7 @@ FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu121 \
    UV_NO_SYNC=true \
    PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
# Base OS tools + media stack + Python toolchain.

View File

@@ -4,12 +4,17 @@
from __future__ import annotations

import argparse
import contextlib
import json
import logging
import os
import runpy
import subprocess
import sys
from pathlib import Path

import torch

from src.logging_config import configure_logging, debug_log_lifecycle
@@ -24,6 +29,41 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios"
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@contextlib.contextmanager
def _temporary_environ(update: dict[str, str]):
previous: dict[str, str | None] = {key: os.environ.get(key) for key in update}
os.environ.update(update)
try:
yield
finally:
for key, value in previous.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def _run_hunyuan_generate_in_process(
    hunyuan_dir: Path,
    argv: list[str],
    env_update: dict[str, str],
) -> None:
    """Run Hunyuan's ``generate.py`` inside this interpreter as a script.

    Temporarily swaps ``sys.argv``, the working directory, and (via
    ``_temporary_environ``) the process environment, then executes the script
    with ``runpy`` under ``__main__``. All three are restored afterwards.

    Raises:
        FileNotFoundError: if ``generate.py`` is missing from *hunyuan_dir*.
    """
    generate_script = hunyuan_dir / "generate.py"
    if not generate_script.exists():
        raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")

    saved_argv = list(sys.argv)
    saved_cwd = Path.cwd()
    try:
        with _temporary_environ(env_update):
            # generate.py resolves relative paths against its own directory.
            os.chdir(hunyuan_dir)
            sys.argv = ["generate.py", *argv]
            runpy.run_path(str(generate_script), run_name="__main__")
    finally:
        # Always undo the argv/cwd swap, even if the script raised.
        sys.argv = saved_argv
        os.chdir(saved_cwd)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
@@ -114,22 +154,11 @@ def main(argv: list[str] | None = None) -> int:
LOGGER.warning("Image not found at %s, skipped", image_path) LOGGER.warning("Image not found at %s, skipped", image_path)
continue continue
subprocess.run( with _temporary_environ(env):
[ torch.cuda.empty_cache()
"python3",
"-c",
"import torch; torch.cuda.empty_cache()",
],
check=True,
env=env,
)
LOGGER.info("GPU cache cleared") LOGGER.info("GPU cache cleared")
subprocess.run( run_argv = [
[
"torchrun",
"--nproc_per_node=1",
"generate.py",
"--prompt", "--prompt",
prompt, prompt,
"--image_path", "--image_path",
@@ -162,10 +191,19 @@ def main(argv: list[str] | None = None) -> int:
str(output_path), str(output_path),
"--model_path", "--model_path",
str(model_path), str(model_path),
], ]
check=True, run_env = {
cwd=args.hunyuan_dir, **env,
env=env, "RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": os.environ.get("MASTER_PORT", "29500"),
}
_run_hunyuan_generate_in_process(
hunyuan_dir=args.hunyuan_dir,
argv=run_argv,
env_update=run_env,
) )
LOGGER.info("Shot %s done", shot_number) LOGGER.info("Shot %s done", shot_number)