1
0

Run Hunyuan's generate script with torchrun utility

This commit is contained in:
2026-04-03 20:12:31 +02:00
parent 9668088b27
commit dede1988e6

View File

@@ -14,6 +14,7 @@ import sys
from pathlib import Path from pathlib import Path
import torch import torch
from torch.distributed.run import main as torch_run
from src.logging_config import configure_logging, debug_log_lifecycle from src.logging_config import configure_logging, debug_log_lifecycle
@@ -48,17 +49,13 @@ def _run_hunyuan_generate_in_process(
argv: list[str], argv: list[str],
env_update: dict[str, str], env_update: dict[str, str],
) -> None: ) -> None:
generate_script = hunyuan_dir / "generate.py"
if not generate_script.exists():
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
old_argv = sys.argv[:] old_argv = sys.argv[:]
old_cwd = Path.cwd() old_cwd = Path.cwd()
try: try:
with _temporary_environ(env_update): with _temporary_environ(env_update):
os.chdir(hunyuan_dir) os.chdir(hunyuan_dir)
sys.argv = ["generate.py", *argv] sys.argv = argv
runpy.run_path(str(generate_script), run_name="__main__") torch_run()
finally: finally:
sys.argv = old_argv sys.argv = old_argv
os.chdir(old_cwd) os.chdir(old_cwd)
@@ -159,6 +156,8 @@ def main(argv: list[str] | None = None) -> int:
LOGGER.info("GPU cache cleared") LOGGER.info("GPU cache cleared")
run_argv = [ run_argv = [
"--nproc_per_node=1",
"generate.py",
"--prompt", "--prompt",
prompt, prompt,
"--image_path", "--image_path",