# Forked from LiveCarta/ContentGeneration.
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from diffusers import FluxPipeline
|
|
from logging_config import configure_logging, debug_log_lifecycle
|
|
|
|
|
|
# Logger for this module; handlers/levels are set up by configure_logging().
LOGGER = logging.getLogger(__name__)

# Absolute directory that holds this script, with symlinks resolved.
SCRIPT_DIR = Path(__file__).resolve().parent
# Project root: the directory two levels above SCRIPT_DIR. Input
# (reel_script.json) and output (images/) paths are built relative to it.
PROJECT_ROOT = SCRIPT_DIR.parents[1]
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
parser.add_argument(
|
|
"--log-level",
|
|
default=None,
|
|
help="Logging level (overrides LOG_LEVEL env var)",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
@debug_log_lifecycle
def main() -> int:
    """Generate one PNG per shot listed in ``reel_script.json``.

    Reads ``PROJECT_ROOT/reel_script.json`` and loads the FLUX.1-schnell
    pipeline. For every entry in the script's ``"shots"`` list it renders
    the entry's ``"image_description"`` and writes the result to
    ``PROJECT_ROOT/images/shot_<shot_number>.png``.

    Returns:
        Process exit status: ``0`` on success (including an empty shot
        list), ``1`` when the reel script is missing, cannot be decoded as
        UTF-8 JSON, or has no ``"shots"`` list.
    """
    args = parse_args()
    configure_logging(args.log_level)

    reel_script = PROJECT_ROOT / "reel_script.json"
    images_dir = PROJECT_ROOT / "images"

    # Validate the input before loading the model, so a bad or missing
    # script fails fast with a clear log message instead of a traceback.
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        reel_data = json.loads(reel_script.read_text(encoding="utf-8"))
    except FileNotFoundError:
        LOGGER.error("Reel script not found: %s", reel_script)
        return 1
    except (UnicodeDecodeError, json.JSONDecodeError) as err:
        LOGGER.error("Reel script %s is not valid UTF-8 JSON: %s", reel_script, err)
        return 1

    shots = reel_data.get("shots") if isinstance(reel_data, dict) else None
    if not isinstance(shots, list):
        LOGGER.error("Reel script %s has no 'shots' list", reel_script)
        return 1
    if not shots:
        # Nothing to render: skip loading the pipeline entirely.
        LOGGER.warning("Reel script %s contains no shots; nothing to do", reel_script)
        return 0

    images_dir.mkdir(parents=True, exist_ok=True)

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
    )
    # Per the diffusers docs, this keeps submodules on the CPU and moves each
    # one to the accelerator only while it runs, lowering peak GPU memory.
    pipe.enable_model_cpu_offload()

    for shot in shots:
        shot_num = shot["shot_number"]
        prompt = shot["image_description"]
        LOGGER.info("Generating image for shot %s: %s", shot_num, prompt)

        image = pipe(
            prompt,
            # Settings follow the diffusers FLUX.1-schnell example: no
            # classifier-free guidance, 4 denoising steps, and the prompt
            # sequence length limited to 256 tokens.
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            # A fresh CPU generator seeded with 0 for every shot, so each
            # image is reproducible regardless of the other shots.
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        image.save(images_dir / f"shot_{shot_num}.png")

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the script and hand main()'s integer status to the interpreter
    # as the process exit code.
    exit_code = main()
    raise SystemExit(exit_code)