# Forked from LiveCarta/ContentGeneration.
import json
import os
from pathlib import Path

import torch
from diffusers import FluxPipeline

# Defaults for a plain `python <script>` run; override via main()'s parameters.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
SCRIPT_PATH = "reel_script.json"
OUTPUT_DIR = "images"


def load_shots(script_path):
    """Return the list stored under the "shots" key of the reel-script JSON.

    The file is decoded as UTF-8 explicitly, so non-ASCII image descriptions
    load identically on every platform (the locale default is not UTF-8 on
    e.g. Windows).

    Args:
        script_path: path (str or os.PathLike) to the reel-script JSON file.

    Raises:
        KeyError: if the JSON document has no "shots" key.
    """
    with open(script_path, "r", encoding="utf-8") as f:
        reel_data = json.load(f)
    return reel_data["shots"]


def shot_image_path(out_dir, shot_number):
    """Return the output path ``<out_dir>/shot_<shot_number>.png`` for a shot."""
    return Path(out_dir) / f"shot_{shot_number}.png"


def main(script_path=SCRIPT_PATH, out_dir=OUTPUT_DIR, seed=0):
    """Render one PNG per shot of the reel script with FLUX.1-schnell.

    Args:
        script_path: JSON file whose "shots" list holds dicts with
            "shot_number" and "image_description" keys.
        out_dir: directory the images are written to; created if missing.
        seed: seed for the CPU generator that is re-created for every shot,
            so each prompt renders reproducibly.
    """
    shots = load_shots(script_path)

    pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    # Offload sub-models to the CPU between uses to reduce peak GPU memory.
    pipe.enable_model_cpu_offload()

    # Create the output directory once, before the loop. exist_ok avoids the
    # check-then-create race of os.path.exists() followed by os.makedirs().
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for shot in shots:
        print(shot["shot_number"], shot["image_description"])
        prompt = shot["image_description"]
        # Sampling settings for the distilled schnell checkpoint: no
        # classifier-free guidance, 4 denoising steps, and 256 as the maximum
        # prompt sequence length schnell supports.
        image = pipe(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator("cpu").manual_seed(seed),
        ).images[0]
        # Single-quoted key inside the f-string in shot_image_path keeps this
        # valid on Python < 3.12 (nested same-quote reuse needs PEP 701).
        image.save(shot_image_path(out_dir, shot["shot_number"]))


if __name__ == '__main__':
    main()