forked from LiveCarta/ContentGeneration
Define and set device
This commit is contained in:
@@ -11,7 +11,6 @@ from logging_config import configure_logging, debug_log_lifecycle
|
|||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
||||||
MODEL_ID = "Qwen/Qwen3-14B"
|
MODEL_ID = "Qwen/Qwen3-14B"
|
||||||
WORDS_PER_SECOND = 2.5
|
WORDS_PER_SECOND = 2.5
|
||||||
MAX_DEAD_AIR_SECONDS = 1
|
MAX_DEAD_AIR_SECONDS = 1
|
||||||
@@ -30,6 +29,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_device():
    """Pick the compute device for this process.

    Returns:
        str: ``'cuda'`` when a CUDA-capable GPU is visible to torch,
        otherwise ``'cpu'``. The choice is also logged for traceability.
    """
    if torch.cuda.is_available():
        chosen = 'cuda'
    else:
        chosen = 'cpu'
    LOGGER.info("Using device: %s", chosen)
    return chosen
|
||||||
|
|
||||||
|
|
||||||
@debug_log_lifecycle
|
@debug_log_lifecycle
|
||||||
def load_model(model_id: str = MODEL_ID):
|
def load_model(model_id: str = MODEL_ID):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||||
@@ -42,7 +47,7 @@ def load_model(model_id: str = MODEL_ID):
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
quantization_config=bnb_config,
|
quantization_config=bnb_config,
|
||||||
device_map="auto",
|
device_map=get_device(),
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
).eval()
|
).eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user