1
0

Refactor pipeline to call script entrypoints directly

This commit is contained in:
2026-04-03 16:15:19 +02:00
parent 74f8159eff
commit 008ee18ba8
8 changed files with 188 additions and 140 deletions

View File

@@ -6,16 +6,15 @@ from __future__ import annotations
import argparse import argparse
import logging import logging
import os import os
import subprocess
import sys
from pathlib import Path from pathlib import Path
from typing import Callable
from src import concat_merged, generate_audios, generate_images, generate_script, generate_videos, merge_audio_video
from src.logging_config import configure_logging, debug_log_lifecycle from src.logging_config import configure_logging, debug_log_lifecycle
from src.s3_video_storage import S3VideoStorage from src.s3_video_storage import S3VideoStorage
PROJECT_ROOT = Path(__file__).resolve().parent PROJECT_ROOT = Path(__file__).resolve().parent
SCRIPT_DIR = PROJECT_ROOT / "src"
DEFAULT_BASE_DIR = PROJECT_ROOT DEFAULT_BASE_DIR = PROJECT_ROOT
DEFAULT_HUNYUAN_DIR = DEFAULT_BASE_DIR / "HunyuanVideo-1.5" DEFAULT_HUNYUAN_DIR = DEFAULT_BASE_DIR / "HunyuanVideo-1.5"
DEFAULT_REEL_SCRIPT = DEFAULT_BASE_DIR / "reel_script.json" DEFAULT_REEL_SCRIPT = DEFAULT_BASE_DIR / "reel_script.json"
@@ -53,18 +52,17 @@ def parse_args() -> argparse.Namespace:
@debug_log_lifecycle @debug_log_lifecycle
def run_step(name: str, cmd: list[str], cwd: Path | None = None) -> None: def run_step(name: str, step: Callable[[], int]) -> None:
LOGGER.info("=== %s ===", name) LOGGER.info("=== %s ===", name)
LOGGER.info("$ %s", " ".join(str(part) for part in cmd)) rc = step()
if cwd is not None: if rc != 0:
LOGGER.info("(cwd: %s)", cwd) raise RuntimeError(f"Step '{name}' failed with exit code {rc}")
subprocess.run(cmd, check=True, cwd=str(cwd) if cwd else None)
def _with_log_level(cmd: list[str], log_level: str | None) -> list[str]: def _with_log_level(argv: list[str], log_level: str | None) -> list[str]:
if not log_level: if not log_level:
return cmd return argv
return [*cmd, "--log-level", log_level] return [*argv, "--log-level", log_level]
@debug_log_lifecycle @debug_log_lifecycle
@@ -114,11 +112,17 @@ def main() -> int:
if not args.skip_generate and not args.reel_script.exists(): if not args.skip_generate and not args.reel_script.exists():
run_step( run_step(
"Generate Reel Script", "Generate Reel Script",
_with_log_level([ lambda: generate_script.main(
sys.executable, _with_log_level(
str(SCRIPT_DIR / "generate_script.py"), [
], args.log_level), "--topic-description",
cwd=args.base_dir, str(args.base_dir / "topic_description.txt"),
"--output-script",
str(args.reel_script),
],
args.log_level,
)
),
) )
if not args.reel_script.exists(): if not args.reel_script.exists():
LOGGER.error("Reel script was not generated at %s", args.reel_script) LOGGER.error("Reel script was not generated at %s", args.reel_script)
@@ -127,48 +131,61 @@ def main() -> int:
if not args.skip_generate and not args.skip_audio_generate: if not args.skip_generate and not args.skip_audio_generate:
run_step( run_step(
"Generate Audios", "Generate Audios",
_with_log_level([ lambda: generate_audios.main(
sys.executable, _with_log_level(
str(SCRIPT_DIR / "generate_audios.py"), [
], args.log_level), "--reel-script",
cwd=args.base_dir, str(args.reel_script),
"--audios-dir",
str(args.audios_dir),
],
args.log_level,
)
),
) )
if not args.skip_generate: if not args.skip_generate:
run_step( run_step(
"Generate Images", "Generate Images",
_with_log_level([ lambda: generate_images.main(
sys.executable, _with_log_level(
str(SCRIPT_DIR / "generate_images.py"), [
], args.log_level), "--reel-script",
cwd=args.base_dir, str(args.reel_script),
"--images-dir",
str(args.images_dir),
],
args.log_level,
)
),
) )
if not args.skip_generate: if not args.skip_generate:
run_step( run_step(
"Generate Videos", "Generate Videos",
_with_log_level([ lambda: generate_videos.main(
sys.executable, _with_log_level(
str(SCRIPT_DIR / "generate_videos.py"), [
"--hunyuan-dir", "--hunyuan-dir",
str(args.hunyuan_dir), str(args.hunyuan_dir),
"--reel-script", "--reel-script",
str(args.reel_script), str(args.reel_script),
"--images-dir", "--images-dir",
str(args.images_dir), str(args.images_dir),
"--videos-dir", "--videos-dir",
str(args.videos_dir), str(args.videos_dir),
"--audios-dir", "--audios-dir",
str(args.audios_dir), str(args.audios_dir),
"--seed", "--seed",
str(args.seed), str(args.seed),
], args.log_level), ],
args.log_level,
)
),
) )
if not args.skip_merge: if not args.skip_merge:
merge_cmd = [ merge_argv = [
sys.executable,
str(SCRIPT_DIR / "merge_audio_video.py"),
"--videos-dir", "--videos-dir",
str(args.videos_dir), str(args.videos_dir),
"--audios-dir", "--audios-dir",
@@ -177,28 +194,31 @@ def main() -> int:
str(args.merged_dir), str(args.merged_dir),
] ]
if args.skip_audio_generate: if args.skip_audio_generate:
merge_cmd.append("--allow-missing-audio") merge_argv.append("--allow-missing-audio")
run_step( run_step(
"Merge Audio + Video", "Merge Audio + Video",
_with_log_level(merge_cmd, args.log_level), lambda: merge_audio_video.main(_with_log_level(merge_argv, args.log_level)),
) )
if not args.skip_concat: if not args.skip_concat:
run_step( run_step(
"Concatenate Merged Videos", "Concatenate Merged Videos",
_with_log_level([ lambda: concat_merged.main(
sys.executable, _with_log_level(
str(SCRIPT_DIR / "concat_merged.py"), [
"--merged-dir", "--merged-dir",
str(args.merged_dir), str(args.merged_dir),
"--output", "--output",
str(args.output), str(args.output),
], args.log_level), ],
args.log_level,
)
),
) )
except subprocess.CalledProcessError as exc: except Exception:
LOGGER.exception("Pipeline failed at command: %s", exc.cmd) LOGGER.exception("Pipeline failed")
return exc.returncode return 1
if not args.skip_s3_upload: if not args.skip_s3_upload:
try: try:

View File

@@ -26,7 +26,7 @@ def shot_number(path: Path) -> int:
return int(match.group(1)) if match else -1 return int(match.group(1)) if match else -1
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--merged-dir", type=Path, default=DEFAULT_MERGED_DIR) parser.add_argument("--merged-dir", type=Path, default=DEFAULT_MERGED_DIR)
parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT) parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
@@ -35,12 +35,12 @@ def parse_args() -> argparse.Namespace:
default=None, default=None,
help="Logging level (overrides LOG_LEVEL env var)", help="Logging level (overrides LOG_LEVEL env var)",
) )
return parser.parse_args() return parser.parse_args(argv)
@debug_log_lifecycle @debug_log_lifecycle
def main() -> int: def main(argv: list[str] | None = None) -> int:
args = parse_args() args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
videos = sorted(args.merged_dir.glob("merged_*.mp4"), key=shot_number) videos = sorted(args.merged_dir.glob("merged_*.mp4"), key=shot_number)

View File

@@ -19,29 +19,29 @@ load_dotenv(PROJECT_ROOT / ".env")
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--reel-script", type=Path, default=PROJECT_ROOT / "reel_script.json")
parser.add_argument("--audios-dir", type=Path, default=PROJECT_ROOT / "audios")
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
default=None, default=None,
help="Logging level (overrides LOG_LEVEL env var)", help="Logging level (overrides LOG_LEVEL env var)",
) )
return parser.parse_args() return parser.parse_args(argv)
@debug_log_lifecycle @debug_log_lifecycle
def main() -> int: def main(argv: list[str] | None = None) -> int:
args = parse_args() args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
api_key = os.getenv("ELEVENLABS_API_KEY") api_key = os.getenv("ELEVENLABS_API_KEY")
if not api_key: if not api_key:
raise RuntimeError("ELEVENLABS_API_KEY is not set") raise RuntimeError("ELEVENLABS_API_KEY is not set")
reel_script = PROJECT_ROOT / "reel_script.json" args.audios_dir.mkdir(parents=True, exist_ok=True)
audios_dir = PROJECT_ROOT / "audios"
audios_dir.mkdir(parents=True, exist_ok=True)
reel_data = json.loads(reel_script.read_text()) reel_data = json.loads(args.reel_script.read_text())
client = ElevenLabs(api_key=api_key) client = ElevenLabs(api_key=api_key)
for shot in reel_data["shots"]: for shot in reel_data["shots"]:
@@ -57,7 +57,7 @@ def main() -> int:
) )
audio_bytes = b"".join(audio) audio_bytes = b"".join(audio)
out_path = audios_dir / f"output_{shot_num}.mp3" out_path = args.audios_dir / f"output_{shot_num}.mp3"
out_path.write_bytes(audio_bytes) out_path.write_bytes(audio_bytes)
return 0 return 0

View File

@@ -16,25 +16,25 @@ PROJECT_ROOT = SCRIPT_DIR.parent
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--reel-script", type=Path, default=PROJECT_ROOT / "reel_script.json")
parser.add_argument("--images-dir", type=Path, default=PROJECT_ROOT / "images")
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
default=None, default=None,
help="Logging level (overrides LOG_LEVEL env var)", help="Logging level (overrides LOG_LEVEL env var)",
) )
return parser.parse_args() return parser.parse_args(argv)
@debug_log_lifecycle @debug_log_lifecycle
def main() -> int: def main(argv: list[str] | None = None) -> int:
args = parse_args() args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
reel_script = PROJECT_ROOT / "reel_script.json" args.images_dir.mkdir(parents=True, exist_ok=True)
images_dir = PROJECT_ROOT / "images"
images_dir.mkdir(parents=True, exist_ok=True)
reel_data = json.loads(reel_script.read_text()) reel_data = json.loads(args.reel_script.read_text())
pipe = FluxPipeline.from_pretrained( pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-schnell",
@@ -54,7 +54,7 @@ def main() -> int:
max_sequence_length=256, max_sequence_length=256,
generator=torch.Generator("cpu").manual_seed(0), generator=torch.Generator("cpu").manual_seed(0),
).images[0] ).images[0]
image.save(images_dir / f"shot_{shot_num}.png") image.save(args.images_dir / f"shot_{shot_num}.png")
return 0 return 0

View File

@@ -5,12 +5,16 @@ import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import re import re
from typing import Optional from typing import Optional
from pathlib import Path
from logging_config import configure_logging, debug_log_lifecycle from logging_config import configure_logging, debug_log_lifecycle
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
MODEL_ID = "Qwen/Qwen3-14B" MODEL_ID = "Qwen/Qwen3-14B"
WORDS_PER_SECOND = 2.5 WORDS_PER_SECOND = 2.5
MAX_DEAD_AIR_SECONDS = 1 MAX_DEAD_AIR_SECONDS = 1
@@ -19,14 +23,24 @@ MAX_VOICEOVER_WORDS = int(MAX_VOICEOVER_SECONDS * WORDS_PER_SECOND)
MIN_VOICEOVER_WORDS = 5 MIN_VOICEOVER_WORDS = 5
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--topic-description",
type=Path,
default=PROJECT_ROOT / "topic_description.txt",
)
parser.add_argument(
"--output-script",
type=Path,
default=PROJECT_ROOT / "reel_script.json",
)
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
default=None, default=None,
help="Logging level (overrides LOG_LEVEL env var)", help="Logging level (overrides LOG_LEVEL env var)",
) )
return parser.parse_args() return parser.parse_args(argv)
def get_device(): def get_device():
@@ -354,18 +368,22 @@ def parse_reel_scenario(raw_scenario: str) -> dict:
return result return result
if __name__ == '__main__': @debug_log_lifecycle
args = parse_args() def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
with open("topic_description.txt", "r") as f: topic = args.topic_description.read_text()
topic = f.read()
model, tokenizer = load_model() model, tokenizer = load_model()
scenario_raw = generate_reel_scenario(model, tokenizer, topic) scenario_raw = generate_reel_scenario(model, tokenizer, topic)
parsed = parse_reel_scenario(scenario_raw) parsed = parse_reel_scenario(scenario_raw)
with open("reel_script.json", "w") as f: args.output_script.write_text(json.dumps(parsed))
json.dump(parsed, f) return 0
if __name__ == '__main__':
raise SystemExit(main())

View File

@@ -24,7 +24,7 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios"
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR) parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
parser.add_argument("--reel-script", type=Path, default=DEFAULT_REEL_SCRIPT) parser.add_argument("--reel-script", type=Path, default=DEFAULT_REEL_SCRIPT)
@@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace:
default=None, default=None,
help="Logging level (overrides LOG_LEVEL env var)", help="Logging level (overrides LOG_LEVEL env var)",
) )
return parser.parse_args() return parser.parse_args(argv)
@debug_log_lifecycle @debug_log_lifecycle
@@ -69,8 +69,8 @@ def duration_to_video_length(duration: float) -> int:
@debug_log_lifecycle @debug_log_lifecycle
def main() -> int: def main(argv: list[str] | None = None) -> int:
args = parse_args() args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
model_path = args.hunyuan_dir / "ckpts" model_path = args.hunyuan_dir / "ckpts"

View File

@@ -27,7 +27,7 @@ def shot_number(path: Path) -> int:
return int(match.group(1)) if match else -1 return int(match.group(1)) if match else -1
def parse_args() -> argparse.Namespace: def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--videos-dir", type=Path, default=DEFAULT_VIDEOS_DIR) parser.add_argument("--videos-dir", type=Path, default=DEFAULT_VIDEOS_DIR)
parser.add_argument("--audios-dir", type=Path, default=DEFAULT_AUDIOS_DIR) parser.add_argument("--audios-dir", type=Path, default=DEFAULT_AUDIOS_DIR)
@@ -38,11 +38,11 @@ def parse_args() -> argparse.Namespace:
help="If set, create merged output from video only when audio is missing.", help="If set, create merged output from video only when audio is missing.",
) )
parser.add_argument("--log-level", default="INFO") parser.add_argument("--log-level", default="INFO")
return parser.parse_args() return parser.parse_args(argv)
def main() -> int: def main(argv: list[str] | None = None) -> int:
args = parse_args() args = parse_args(argv)
configure_logging(args.log_level) configure_logging(args.log_level)
args.output_dir.mkdir(parents=True, exist_ok=True) args.output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -48,62 +48,72 @@ def test_full_generation_process_calls_all_scripts(monkeypatch) -> None:
log_level="DEBUG", log_level="DEBUG",
) )
executed_scripts: list[str] = [] executed_steps: list[str] = []
expected_scripts = [ expected_steps = [
"generate_script.py", "generate_script",
"generate_audios.py", "generate_audios",
"generate_images.py", "generate_images",
"generate_videos.py", "generate_videos",
"merge_audio_video.py", "merge_audio_video",
"concat_merged.py", "concat_merged",
] ]
def fake_subprocess_run(cmd: list[str], check: bool, cwd: str | None = None): def fake_generate_script_main(argv=None) -> int:
script_name = Path(cmd[1]).name if len(cmd) > 1 else "" executed_steps.append("generate_script")
if script_name not in expected_scripts: payload = {
pytest.fail(f"Unexpected external process call: {cmd}") "shots": [
{
"shot_number": 1,
"image_description": "A test image",
"voiceover": "A test voiceover",
}
]
}
reel_script.write_text(json.dumps(payload))
return 0
executed_scripts.append(script_name) def fake_generate_audios_main(argv=None) -> int:
executed_steps.append("generate_audios")
audios_dir.mkdir(parents=True, exist_ok=True)
(audios_dir / "output_1.mp3").write_bytes(b"audio")
return 0
if script_name == "generate_script.py": def fake_generate_images_main(argv=None) -> int:
payload = { executed_steps.append("generate_images")
"shots": [ images_dir.mkdir(parents=True, exist_ok=True)
{ (images_dir / "shot_1.png").write_bytes(b"image")
"shot_number": 1, return 0
"image_description": "A test image",
"voiceover": "A test voiceover",
}
]
}
reel_script.write_text(json.dumps(payload))
elif script_name == "generate_audios.py":
audios_dir.mkdir(parents=True, exist_ok=True)
(audios_dir / "output_1.mp3").write_bytes(b"audio")
elif script_name == "generate_images.py":
images_dir.mkdir(parents=True, exist_ok=True)
(images_dir / "shot_1.png").write_bytes(b"image")
elif script_name == "generate_videos.py":
videos_dir.mkdir(parents=True, exist_ok=True)
(videos_dir / "output_1.mp4").write_bytes(b"video")
elif script_name == "merge_audio_video.py":
merged_dir.mkdir(parents=True, exist_ok=True)
(merged_dir / "merged_1.mp4").write_bytes(b"merged")
elif script_name == "concat_merged.py":
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_bytes(b"final")
class Result: def fake_generate_videos_main(argv=None) -> int:
returncode = 0 executed_steps.append("generate_videos")
videos_dir.mkdir(parents=True, exist_ok=True)
(videos_dir / "output_1.mp4").write_bytes(b"video")
return 0
return Result() def fake_merge_audio_video_main(argv=None) -> int:
executed_steps.append("merge_audio_video")
merged_dir.mkdir(parents=True, exist_ok=True)
(merged_dir / "merged_1.mp4").write_bytes(b"merged")
return 0
def fake_concat_merged_main(argv=None) -> int:
executed_steps.append("concat_merged")
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_bytes(b"final")
return 0
monkeypatch.setattr(pipeline, "parse_args", lambda: args) monkeypatch.setattr(pipeline, "parse_args", lambda: args)
monkeypatch.setattr(pipeline.subprocess, "run", fake_subprocess_run) monkeypatch.setattr(pipeline.generate_script, "main", fake_generate_script_main)
monkeypatch.setattr(pipeline.generate_audios, "main", fake_generate_audios_main)
monkeypatch.setattr(pipeline.generate_images, "main", fake_generate_images_main)
monkeypatch.setattr(pipeline.generate_videos, "main", fake_generate_videos_main)
monkeypatch.setattr(pipeline.merge_audio_video, "main", fake_merge_audio_video_main)
monkeypatch.setattr(pipeline.concat_merged, "main", fake_concat_merged_main)
rc = pipeline.main() rc = pipeline.main()
assert rc == 0 assert rc == 0
assert output_path.exists() assert output_path.exists()
# Coverage check for orchestration: ensure every required script stage was called. # Coverage check for orchestration: ensure every required script stage was called.
assert executed_scripts == expected_scripts assert executed_steps == expected_steps