# ContentGeneration/src/generate_images.py
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
import torch
from diffusers import FluxPipeline
from logging_config import configure_logging, debug_log_lifecycle
# Absolute, symlink-resolved location of this source file.
_THIS_FILE = Path(__file__).resolve()
# Directory that contains this script.
SCRIPT_DIR: Path = _THIS_FILE.parent
# Two directory levels above SCRIPT_DIR; main() resolves reel_script.json and
# the images/ output directory relative to this path.
PROJECT_ROOT: Path = SCRIPT_DIR.parents[1]
# Logger named after this module.
LOGGER: logging.Logger = logging.getLogger(name=__name__)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the image-generation script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``log_level``: the string given via ``--log-level``,
        or ``None`` when the option was not supplied.
    """
    parser = argparse.ArgumentParser(
        # Fall back to a fixed description when the module has no docstring
        # (``__doc__`` is None); otherwise ``--help`` shows no description.
        description=__doc__
        or "Generate one image per reel shot with FLUX.1-schnell.",
    )
    parser.add_argument(
        "--log-level",
        default=None,
        help="Logging level (overrides LOG_LEVEL env var)",
    )
    return parser.parse_args(argv)
def _load_reel_script(path: Path) -> dict:
    """Parse the reel script JSON file at *path*.

    The file is decoded explicitly as UTF-8. ``Path.read_text()`` without an
    encoding uses the locale's preferred encoding (e.g. cp1252 on Windows),
    which can mis-decode or reject non-ASCII prompt text.
    """
    return json.loads(path.read_text(encoding="utf-8"))


def _load_pipeline() -> FluxPipeline:
    """Load the FLUX.1-schnell pipeline in bfloat16 with model CPU offload."""
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
    )
    # Offload idle sub-models to CPU to reduce peak accelerator memory.
    pipe.enable_model_cpu_offload()
    return pipe


def _render_shot(pipe: FluxPipeline, prompt: str):
    """Run *pipe* on *prompt* and return the first generated image."""
    return pipe(
        prompt,
        # Few-step sampling settings used with the FLUX.1-schnell checkpoint.
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        # A fresh generator seeded with 0 on every call keeps each shot's
        # image reproducible, independent of how many shots were rendered
        # before it.
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]


@debug_log_lifecycle
def main() -> int:
    """Render one image per shot listed in the project's reel script.

    Reads ``PROJECT_ROOT / "reel_script.json"``. For each entry of its
    ``"shots"`` list, renders ``shot["image_description"]`` and saves the
    result as ``PROJECT_ROOT / "images" / "shot_<shot_number>.png"``.

    Returns:
        Process exit status: always 0. Failures propagate as exceptions,
        e.g. ``FileNotFoundError`` for a missing reel script or ``KeyError``
        for a missing ``"shots"``, ``"shot_number"`` or
        ``"image_description"`` field.
    """
    args = parse_args()
    configure_logging(args.log_level)

    reel_script = PROJECT_ROOT / "reel_script.json"
    images_dir = PROJECT_ROOT / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    shots = _load_reel_script(reel_script)["shots"]
    if not shots:
        # Nothing to render: skip the expensive model download/load entirely.
        LOGGER.warning("No shots found in %s; nothing to generate", reel_script)
        return 0

    pipe = _load_pipeline()
    for shot in shots:
        shot_num = shot["shot_number"]
        prompt = shot["image_description"]
        LOGGER.info("Generating image for shot %s: %s", shot_num, prompt)
        image = _render_shot(pipe, prompt)
        image.save(images_dir / f"shot_{shot_num}.png")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)