1
0

Run Hunyuan's generate script with torchrun utility

This commit is contained in:
2026-04-03 20:12:31 +02:00
parent 9668088b27
commit dede1988e6

View File

@@ -14,6 +14,7 @@ import sys
from pathlib import Path
import torch
from torch.distributed.run import main as torch_run
from src.logging_config import configure_logging, debug_log_lifecycle
@@ -48,17 +49,13 @@ def _run_hunyuan_generate_in_process(
argv: list[str],
env_update: dict[str, str],
) -> None:
generate_script = hunyuan_dir / "generate.py"
if not generate_script.exists():
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
old_argv = sys.argv[:]
old_cwd = Path.cwd()
try:
with _temporary_environ(env_update):
os.chdir(hunyuan_dir)
sys.argv = ["generate.py", *argv]
runpy.run_path(str(generate_script), run_name="__main__")
sys.argv = argv
torch_run()
finally:
sys.argv = old_argv
os.chdir(old_cwd)
@@ -159,6 +156,8 @@ def main(argv: list[str] | None = None) -> int:
LOGGER.info("GPU cache cleared")
run_argv = [
"--nproc_per_node=1",
"generate.py",
"--prompt",
prompt,
"--image_path",