forked from LiveCarta/ContentGeneration
Run Hunyuan's generate script with torchrun utility
This commit is contained in:
@@ -14,6 +14,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.distributed.run import main as torch_run
|
||||||
|
|
||||||
from src.logging_config import configure_logging, debug_log_lifecycle
|
from src.logging_config import configure_logging, debug_log_lifecycle
|
||||||
|
|
||||||
@@ -48,17 +49,13 @@ def _run_hunyuan_generate_in_process(
|
|||||||
argv: list[str],
|
argv: list[str],
|
||||||
env_update: dict[str, str],
|
env_update: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
generate_script = hunyuan_dir / "generate.py"
|
|
||||||
if not generate_script.exists():
|
|
||||||
raise FileNotFoundError(f"Hunyuan generate script not found at {generate_script}")
|
|
||||||
|
|
||||||
old_argv = sys.argv[:]
|
old_argv = sys.argv[:]
|
||||||
old_cwd = Path.cwd()
|
old_cwd = Path.cwd()
|
||||||
try:
|
try:
|
||||||
with _temporary_environ(env_update):
|
with _temporary_environ(env_update):
|
||||||
os.chdir(hunyuan_dir)
|
os.chdir(hunyuan_dir)
|
||||||
sys.argv = ["generate.py", *argv]
|
sys.argv = argv
|
||||||
runpy.run_path(str(generate_script), run_name="__main__")
|
torch_run()
|
||||||
finally:
|
finally:
|
||||||
sys.argv = old_argv
|
sys.argv = old_argv
|
||||||
os.chdir(old_cwd)
|
os.chdir(old_cwd)
|
||||||
@@ -159,6 +156,8 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
LOGGER.info("GPU cache cleared")
|
LOGGER.info("GPU cache cleared")
|
||||||
|
|
||||||
run_argv = [
|
run_argv = [
|
||||||
|
"--nproc_per_node=1",
|
||||||
|
"generate.py",
|
||||||
"--prompt",
|
"--prompt",
|
||||||
prompt,
|
prompt,
|
||||||
"--image_path",
|
"--image_path",
|
||||||
|
|||||||
Reference in New Issue
Block a user