From 0b1204e149d0b063ce0edd5e6a558753077410e1 Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Wed, 11 Feb 2026 23:59:45 +0530 Subject: [PATCH] Fix hardcoded CUDA device in api.py to support MPS and CPU fallback The API server crashes on non-CUDA systems (e.g., macOS with Apple Silicon or CPU-only machines) because DEVICE is hardcoded to "cuda" and the model loading uses .cuda() directly. Changes: - Auto-detect the best available device (CUDA > MPS > CPU) - Replace .cuda() with .to(DEVICE) for portable device placement - Update torch_gc() to handle MPS cache clearing and skip CUDA-specific cleanup when not on a CUDA device --- api.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/api.py b/api.py index 693c70ac..65624e3f 100644 --- a/api.py +++ b/api.py @@ -3,16 +3,18 @@ import uvicorn, json, datetime import torch -DEVICE = "cuda" -DEVICE_ID = "0" -CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE +DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +DEVICE_ID = "0" if DEVICE == "cuda" else None +CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE == "cuda" and DEVICE_ID else DEVICE def torch_gc(): - if torch.cuda.is_available(): + if DEVICE == "cuda" and torch.cuda.is_available(): with torch.cuda.device(CUDA_DEVICE): torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif DEVICE == "mps": + torch.mps.empty_cache() app = FastAPI() @@ -51,6 +53,6 @@ async def create_item(request: Request): if __name__ == '__main__': tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) - model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() + model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to(DEVICE) model.eval() uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)