forked from LiveCarta/ContentGeneration
Run the video generation model in-process from Python instead of spawning a subprocess
This commit is contained in:
@@ -3,6 +3,7 @@ FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu121 \
|
||||
UV_NO_SYNC=true \
|
||||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
|
||||
|
||||
# Base OS tools + media stack + Python toolchain.
|
||||
|
||||
@@ -4,12 +4,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import runpy
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from src.logging_config import configure_logging, debug_log_lifecycle
|
||||
|
||||
|
||||
@@ -24,6 +29,41 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios"
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temporary_environ(update: dict[str, str]):
|
||||
previous: dict[str, str | None] = {key: os.environ.get(key) for key in update}
|
||||
os.environ.update(update)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key, value in previous.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def _run_hunyuan_generate_in_process(
|
||||
hunyuan_dir: Path,
|
||||
argv: list[str],
|
||||
env_update: dict[str, str],
|
||||
) -> None:
|
||||
generate_script = hunyuan_dir / "generate.py"
|
||||
if not generate_script.exists():
|
||||
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
|
||||
|
||||
old_argv = sys.argv[:]
|
||||
old_cwd = Path.cwd()
|
||||
try:
|
||||
with _temporary_environ(env_update):
|
||||
os.chdir(hunyuan_dir)
|
||||
sys.argv = ["generate.py", *argv]
|
||||
runpy.run_path(str(generate_script), run_name="__main__")
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
os.chdir(old_cwd)
|
||||
|
||||
|
||||
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
|
||||
@@ -114,58 +154,56 @@ def main(argv: list[str] | None = None) -> int:
|
||||
LOGGER.warning("Image not found at %s, skipped", image_path)
|
||||
continue
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"python3",
|
||||
"-c",
|
||||
"import torch; torch.cuda.empty_cache()",
|
||||
],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
with _temporary_environ(env):
|
||||
torch.cuda.empty_cache()
|
||||
LOGGER.info("GPU cache cleared")
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"torchrun",
|
||||
"--nproc_per_node=1",
|
||||
"generate.py",
|
||||
"--prompt",
|
||||
prompt,
|
||||
"--image_path",
|
||||
str(image_path),
|
||||
"--resolution",
|
||||
"480p",
|
||||
"--aspect_ratio",
|
||||
"16:9",
|
||||
"--seed",
|
||||
str(args.seed),
|
||||
"--video_length",
|
||||
str(video_length),
|
||||
"--rewrite",
|
||||
"false",
|
||||
"--cfg_distilled",
|
||||
"true",
|
||||
"--enable_step_distill",
|
||||
"true",
|
||||
"--sparse_attn",
|
||||
"false",
|
||||
"--use_sageattn",
|
||||
"true",
|
||||
"--enable_cache",
|
||||
"false",
|
||||
"--overlap_group_offloading",
|
||||
"true",
|
||||
"--sr",
|
||||
"false",
|
||||
"--output_path",
|
||||
str(output_path),
|
||||
"--model_path",
|
||||
str(model_path),
|
||||
],
|
||||
check=True,
|
||||
cwd=args.hunyuan_dir,
|
||||
env=env,
|
||||
run_argv = [
|
||||
"--prompt",
|
||||
prompt,
|
||||
"--image_path",
|
||||
str(image_path),
|
||||
"--resolution",
|
||||
"480p",
|
||||
"--aspect_ratio",
|
||||
"16:9",
|
||||
"--seed",
|
||||
str(args.seed),
|
||||
"--video_length",
|
||||
str(video_length),
|
||||
"--rewrite",
|
||||
"false",
|
||||
"--cfg_distilled",
|
||||
"true",
|
||||
"--enable_step_distill",
|
||||
"true",
|
||||
"--sparse_attn",
|
||||
"false",
|
||||
"--use_sageattn",
|
||||
"true",
|
||||
"--enable_cache",
|
||||
"false",
|
||||
"--overlap_group_offloading",
|
||||
"true",
|
||||
"--sr",
|
||||
"false",
|
||||
"--output_path",
|
||||
str(output_path),
|
||||
"--model_path",
|
||||
str(model_path),
|
||||
]
|
||||
run_env = {
|
||||
**env,
|
||||
"RANK": "0",
|
||||
"LOCAL_RANK": "0",
|
||||
"WORLD_SIZE": "1",
|
||||
"MASTER_ADDR": "127.0.0.1",
|
||||
"MASTER_PORT": os.environ.get("MASTER_PORT", "29500"),
|
||||
}
|
||||
_run_hunyuan_generate_in_process(
|
||||
hunyuan_dir=args.hunyuan_dir,
|
||||
argv=run_argv,
|
||||
env_update=run_env,
|
||||
)
|
||||
|
||||
LOGGER.info("Shot %s done", shot_number)
|
||||
|
||||
Reference in New Issue
Block a user