diff --git a/api.py b/api.py
index 693c70ac..65624e3f 100644
--- a/api.py
+++ b/api.py
@@ -3,16 +3,32 @@
 import uvicorn, json, datetime
 import torch
 
-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+# Prefer CUDA, then Apple MPS, then CPU.  getattr() keeps this module
+# importable on torch builds that do not ship torch.backends.mps.
+_MPS_BACKEND = getattr(torch.backends, "mps", None)
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif _MPS_BACKEND is not None and _MPS_BACKEND.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
+# Only CUDA devices are addressed by index (e.g. "cuda:0").
+DEVICE_ID = "0" if DEVICE == "cuda" else None
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE == "cuda" and DEVICE_ID else DEVICE
 
 
 def torch_gc():
-    if torch.cuda.is_available():
+    """Release cached accelerator memory for the selected DEVICE; does nothing on CPU."""
+    if DEVICE == "cuda" and torch.cuda.is_available():
         with torch.cuda.device(CUDA_DEVICE):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
+    elif DEVICE == "mps":
+        # torch.mps.empty_cache may be absent on older torch releases; skip the
+        # cleanup there instead of raising AttributeError.
+        mps_empty_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
+        if mps_empty_cache is not None:
+            mps_empty_cache()
 
 
 app = FastAPI()
@@ -51,6 +67,9 @@ async def create_item(request: Request):
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    # fp16 on CUDA/MPS; float32 on CPU.  NOTE(review): fp16 kernels are presumably
+    # incomplete on CPU for this model - confirm against the target torch version.
+    model = (model.float() if DEVICE == "cpu" else model.half()).to(DEVICE)
     model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)