forked from LiveCarta/ContentGeneration
Define and set device
This commit is contained in:
@@ -11,7 +11,6 @@ from logging_config import configure_logging, debug_log_lifecycle
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
MODEL_ID = "Qwen/Qwen3-14B"
|
||||
WORDS_PER_SECOND = 2.5
|
||||
MAX_DEAD_AIR_SECONDS = 1
|
||||
@@ -30,6 +29,12 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_device():
    """Pick the torch device string to run on and log the choice.

    Returns ``'cuda'`` when ``torch.cuda.is_available()`` reports a usable
    CUDA device, otherwise ``'cpu'``. The selection is logged at INFO level
    through the module-level ``LOGGER``.
    """
    if torch.cuda.is_available():
        selected = 'cuda'
    else:
        selected = 'cpu'

    LOGGER.info("Using device: %s", selected)
    return selected
|
||||
|
||||
|
||||
@debug_log_lifecycle
|
||||
def load_model(model_id: str = MODEL_ID):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
@@ -42,7 +47,7 @@ def load_model(model_id: str = MODEL_ID):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
device_map=get_device(),
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user