diff --git a/src/generate_script.py b/src/generate_script.py
index a074f64..b95ae17 100644
--- a/src/generate_script.py
+++ b/src/generate_script.py
@@ -11,7 +11,6 @@ from logging_config import configure_logging, debug_log_lifecycle
 
 LOGGER = logging.getLogger(__name__)
 
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
 MODEL_ID = "Qwen/Qwen3-14B"
 WORDS_PER_SECOND = 2.5
 MAX_DEAD_AIR_SECONDS = 1
@@ -30,6 +29,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def get_device():
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    LOGGER.info("Using device: %s", device)
+    return device
+
+
 @debug_log_lifecycle
 def load_model(model_id: str = MODEL_ID):
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
@@ -42,7 +47,7 @@ def load_model(model_id: str = MODEL_ID):
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         quantization_config=bnb_config,
-        device_map="auto",
+        device_map=get_device(),
         trust_remote_code=True,
     ).eval()
 