1
0

Define and set device

This commit is contained in:
2026-04-03 15:50:11 +02:00
parent e767aca68c
commit 74f8159eff

View File

@@ -11,7 +11,6 @@ from logging_config import configure_logging, debug_log_lifecycle
LOGGER = logging.getLogger(__name__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ID = "Qwen/Qwen3-14B"
WORDS_PER_SECOND = 2.5
MAX_DEAD_AIR_SECONDS = 1
@@ -30,6 +29,12 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def get_device():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
LOGGER.info("Using device: %s", device)
return device
@debug_log_lifecycle
def load_model(model_id: str = MODEL_ID):
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
@@ -42,7 +47,7 @@ def load_model(model_id: str = MODEL_ID):
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
device_map=get_device(),
trust_remote_code=True,
).eval()