1
0

Refactor pipeline to call script entrypoints directly

This commit is contained in:
2026-04-03 16:15:19 +02:00
parent 74f8159eff
commit 008ee18ba8
8 changed files with 188 additions and 140 deletions

View File

@@ -26,7 +26,7 @@ def shot_number(path: Path) -> int:
return int(match.group(1)) if match else -1
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--merged-dir", type=Path, default=DEFAULT_MERGED_DIR)
parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
@@ -35,12 +35,12 @@ def parse_args() -> argparse.Namespace:
default=None,
help="Logging level (overrides LOG_LEVEL env var)",
)
return parser.parse_args()
return parser.parse_args(argv)
@debug_log_lifecycle
def main() -> int:
args = parse_args()
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
videos = sorted(args.merged_dir.glob("merged_*.mp4"), key=shot_number)

View File

@@ -19,29 +19,29 @@ load_dotenv(PROJECT_ROOT / ".env")
LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--reel-script", type=Path, default=PROJECT_ROOT / "reel_script.json")
parser.add_argument("--audios-dir", type=Path, default=PROJECT_ROOT / "audios")
parser.add_argument(
"--log-level",
default=None,
help="Logging level (overrides LOG_LEVEL env var)",
)
return parser.parse_args()
return parser.parse_args(argv)
@debug_log_lifecycle
def main() -> int:
args = parse_args()
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
api_key = os.getenv("ELEVENLABS_API_KEY")
if not api_key:
raise RuntimeError("ELEVENLABS_API_KEY is not set")
reel_script = PROJECT_ROOT / "reel_script.json"
audios_dir = PROJECT_ROOT / "audios"
audios_dir.mkdir(parents=True, exist_ok=True)
args.audios_dir.mkdir(parents=True, exist_ok=True)
reel_data = json.loads(reel_script.read_text())
reel_data = json.loads(args.reel_script.read_text())
client = ElevenLabs(api_key=api_key)
for shot in reel_data["shots"]:
@@ -57,7 +57,7 @@ def main() -> int:
)
audio_bytes = b"".join(audio)
out_path = audios_dir / f"output_{shot_num}.mp3"
out_path = args.audios_dir / f"output_{shot_num}.mp3"
out_path.write_bytes(audio_bytes)
return 0

View File

@@ -16,25 +16,25 @@ PROJECT_ROOT = SCRIPT_DIR.parent
LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--reel-script", type=Path, default=PROJECT_ROOT / "reel_script.json")
parser.add_argument("--images-dir", type=Path, default=PROJECT_ROOT / "images")
parser.add_argument(
"--log-level",
default=None,
help="Logging level (overrides LOG_LEVEL env var)",
)
return parser.parse_args()
return parser.parse_args(argv)
@debug_log_lifecycle
def main() -> int:
args = parse_args()
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
reel_script = PROJECT_ROOT / "reel_script.json"
images_dir = PROJECT_ROOT / "images"
images_dir.mkdir(parents=True, exist_ok=True)
args.images_dir.mkdir(parents=True, exist_ok=True)
reel_data = json.loads(reel_script.read_text())
reel_data = json.loads(args.reel_script.read_text())
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
@@ -54,7 +54,7 @@ def main() -> int:
max_sequence_length=256,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save(images_dir / f"shot_{shot_num}.png")
image.save(args.images_dir / f"shot_{shot_num}.png")
return 0

View File

@@ -5,12 +5,16 @@ import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import re
from typing import Optional
from pathlib import Path
from logging_config import configure_logging, debug_log_lifecycle
LOGGER = logging.getLogger(__name__)
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
MODEL_ID = "Qwen/Qwen3-14B"
WORDS_PER_SECOND = 2.5
MAX_DEAD_AIR_SECONDS = 1
@@ -19,14 +23,24 @@ MAX_VOICEOVER_WORDS = int(MAX_VOICEOVER_SECONDS * WORDS_PER_SECOND)
MIN_VOICEOVER_WORDS = 5
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--topic-description",
type=Path,
default=PROJECT_ROOT / "topic_description.txt",
)
parser.add_argument(
"--output-script",
type=Path,
default=PROJECT_ROOT / "reel_script.json",
)
parser.add_argument(
"--log-level",
default=None,
help="Logging level (overrides LOG_LEVEL env var)",
)
return parser.parse_args()
return parser.parse_args(argv)
def get_device():
@@ -354,18 +368,22 @@ def parse_reel_scenario(raw_scenario: str) -> dict:
return result
if __name__ == '__main__':
args = parse_args()
@debug_log_lifecycle
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
with open("topic_description.txt", "r") as f:
topic = f.read()
topic = args.topic_description.read_text()
model, tokenizer = load_model()
scenario_raw = generate_reel_scenario(model, tokenizer, topic)
parsed = parse_reel_scenario(scenario_raw)
with open("reel_script.json", "w") as f:
json.dump(parsed, f)
args.output_script.write_text(json.dumps(parsed))
return 0
if __name__ == '__main__':
raise SystemExit(main())

View File

@@ -24,7 +24,7 @@ DEFAULT_AUDIOS_DIR = DEFAULT_BASE_DIR / "audios"
LOGGER = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--hunyuan-dir", type=Path, default=DEFAULT_HUNYUAN_DIR)
parser.add_argument("--reel-script", type=Path, default=DEFAULT_REEL_SCRIPT)
@@ -37,7 +37,7 @@ def parse_args() -> argparse.Namespace:
default=None,
help="Logging level (overrides LOG_LEVEL env var)",
)
return parser.parse_args()
return parser.parse_args(argv)
@debug_log_lifecycle
@@ -69,8 +69,8 @@ def duration_to_video_length(duration: float) -> int:
@debug_log_lifecycle
def main() -> int:
args = parse_args()
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
model_path = args.hunyuan_dir / "ckpts"

View File

@@ -27,7 +27,7 @@ def shot_number(path: Path) -> int:
return int(match.group(1)) if match else -1
def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--videos-dir", type=Path, default=DEFAULT_VIDEOS_DIR)
parser.add_argument("--audios-dir", type=Path, default=DEFAULT_AUDIOS_DIR)
@@ -38,11 +38,11 @@ def parse_args() -> argparse.Namespace:
help="If set, create merged output from video only when audio is missing.",
)
parser.add_argument("--log-level", default="INFO")
return parser.parse_args()
return parser.parse_args(argv)
def main() -> int:
args = parse_args()
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
configure_logging(args.log_level)
args.output_dir.mkdir(parents=True, exist_ok=True)