forked from LiveCarta/ContentGeneration
Run video generation model from python instead of calling subprocess
This commit is contained in:
@@ -3,6 +3,7 @@ FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
|||||||
ENV DEBIAN_FRONTEND=noninteractive \
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
PYTHONUNBUFFERED=1 \
|
PYTHONUNBUFFERED=1 \
|
||||||
UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu121 \
|
UV_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu121 \
|
||||||
|
UV_NO_SYNC=true \
|
||||||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
|
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
|
||||||
|
|
||||||
# Base OS tools + media stack + Python toolchain.
|
# Base OS tools + media stack + Python toolchain.
|
||||||
|
|||||||
@@ -4,12 +4,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import runpy
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from src.logging_config import configure_logging, debug_log_lifecycle
|
from src.logging_config import configure_logging, debug_log_lifecycle
|
||||||
|
|
||||||
|
|
||||||
@@ -24,6 +29,41 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios"
|
|||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _temporary_environ(update: dict[str, str]):
|
||||||
|
previous: dict[str, str | None] = {key: os.environ.get(key) for key in update}
|
||||||
|
os.environ.update(update)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for key, value in previous.items():
|
||||||
|
if value is None:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
else:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _run_hunyuan_generate_in_process(
|
||||||
|
hunyuan_dir: Path,
|
||||||
|
argv: list[str],
|
||||||
|
env_update: dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
generate_script = hunyuan_dir / "generate.py"
|
||||||
|
if not generate_script.exists():
|
||||||
|
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
|
||||||
|
|
||||||
|
old_argv = sys.argv[:]
|
||||||
|
old_cwd = Path.cwd()
|
||||||
|
try:
|
||||||
|
with _temporary_environ(env_update):
|
||||||
|
os.chdir(hunyuan_dir)
|
||||||
|
sys.argv = ["generate.py", *argv]
|
||||||
|
runpy.run_path(str(generate_script), run_name="__main__")
|
||||||
|
finally:
|
||||||
|
sys.argv = old_argv
|
||||||
|
os.chdir(old_cwd)
|
||||||
|
|
||||||
|
|
||||||
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
|
parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
|
||||||
@@ -114,22 +154,11 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
LOGGER.warning("Image not found at %s, skipped", image_path)
|
LOGGER.warning("Image not found at %s, skipped", image_path)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
subprocess.run(
|
with _temporary_environ(env):
|
||||||
[
|
torch.cuda.empty_cache()
|
||||||
"python3",
|
|
||||||
"-c",
|
|
||||||
"import torch; torch.cuda.empty_cache()",
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
env=env,
|
|
||||||
)
|
|
||||||
LOGGER.info("GPU cache cleared")
|
LOGGER.info("GPU cache cleared")
|
||||||
|
|
||||||
subprocess.run(
|
run_argv = [
|
||||||
[
|
|
||||||
"torchrun",
|
|
||||||
"--nproc_per_node=1",
|
|
||||||
"generate.py",
|
|
||||||
"--prompt",
|
"--prompt",
|
||||||
prompt,
|
prompt,
|
||||||
"--image_path",
|
"--image_path",
|
||||||
@@ -162,10 +191,19 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
str(output_path),
|
str(output_path),
|
||||||
"--model_path",
|
"--model_path",
|
||||||
str(model_path),
|
str(model_path),
|
||||||
],
|
]
|
||||||
check=True,
|
run_env = {
|
||||||
cwd=args.hunyuan_dir,
|
**env,
|
||||||
env=env,
|
"RANK": "0",
|
||||||
|
"LOCAL_RANK": "0",
|
||||||
|
"WORLD_SIZE": "1",
|
||||||
|
"MASTER_ADDR": "127.0.0.1",
|
||||||
|
"MASTER_PORT": os.environ.get("MASTER_PORT", "29500"),
|
||||||
|
}
|
||||||
|
_run_hunyuan_generate_in_process(
|
||||||
|
hunyuan_dir=args.hunyuan_dir,
|
||||||
|
argv=run_argv,
|
||||||
|
env_update=run_env,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOGGER.info("Shot %s done", shot_number)
|
LOGGER.info("Shot %s done", shot_number)
|
||||||
|
|||||||
Reference in New Issue
Block a user