forked from LiveCarta/ContentGeneration
Run Hunyuan's generate script with torchrun utility
This commit is contained in:
@@ -14,6 +14,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.distributed.run import main as torch_run
|
||||
|
||||
from src.logging_config import configure_logging, debug_log_lifecycle
|
||||
|
||||
@@ -48,17 +49,13 @@ def _run_hunyuan_generate_in_process(
|
||||
argv: list[str],
|
||||
env_update: dict[str, str],
|
||||
) -> None:
|
||||
generate_script = hunyuan_dir / "generate.py"
|
||||
if not generate_script.exists():
|
||||
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
|
||||
|
||||
old_argv = sys.argv[:]
|
||||
old_cwd = Path.cwd()
|
||||
try:
|
||||
with _temporary_environ(env_update):
|
||||
os.chdir(hunyuan_dir)
|
||||
sys.argv = ["generate.py", *argv]
|
||||
runpy.run_path(str(generate_script), run_name="__main__")
|
||||
sys.argv = argv
|
||||
torch_run()
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
os.chdir(old_cwd)
|
||||
@@ -159,6 +156,8 @@ def main(argv: list[str] | None = None) -> int:
|
||||
LOGGER.info("GPU cache cleared")
|
||||
|
||||
run_argv = [
|
||||
"--nproc_per_node=1",
|
||||
"generate.py",
|
||||
"--prompt",
|
||||
prompt,
|
||||
"--image_path",
|
||||
|
||||
Reference in New Issue
Block a user