55import tempfile
66import zipfile
77import io as python_io
8+ import base64
89
910from fastapi import FastAPI , Request , UploadFile , File
10- from fastapi .responses import HTMLResponse , StreamingResponse
11+ from fastapi .responses import HTMLResponse , StreamingResponse , JSONResponse
1112from fastapi .staticfiles import StaticFiles
1213from fastapi .templating import Jinja2Templates
1314import torch
@@ -76,6 +77,63 @@ async def read_root(request: Request):
7677
7778@app .post ("/predict" )
7879async def predict (files : list [UploadFile ] = File (...)):
80+ """Process images and return PLY data for viewing or download."""
81+ if not predictor :
82+ return JSONResponse ({"error" : "Model not loaded" }, status_code = 500 )
83+
84+ # Create a temporary directory to process files
85+ with tempfile .TemporaryDirectory () as temp_dir :
86+ temp_path = Path (temp_dir )
87+ results = []
88+
89+ for file in files :
90+ try :
91+ # Save uploaded file
92+ file_path = temp_path / file .filename
93+ with open (file_path , "wb" ) as buffer :
94+ shutil .copyfileobj (file .file , buffer )
95+
96+ LOGGER .info (f"Processing { file .filename } " )
97+
98+ # Load image using sharp's IO to get focal length and handle rotation
99+ image , _ , f_px = sharp_io .load_rgb (file_path )
100+
101+ # Run prediction
102+ gaussians = predict_image (predictor , image , f_px , device )
103+
104+ # Save PLY
105+ ply_filename = f"{ file_path .stem } .ply"
106+ ply_path = temp_path / ply_filename
107+
108+ height , width = image .shape [:2 ]
109+ save_ply (gaussians , f_px , (height , width ), ply_path )
110+
111+ # Read PLY file and encode as base64
112+ with open (ply_path , "rb" ) as f :
113+ ply_data = base64 .b64encode (f .read ()).decode ("utf-8" )
114+
115+ results .append ({
116+ "filename" : file .filename ,
117+ "ply_filename" : ply_filename ,
118+ "ply_data" : ply_data ,
119+ "width" : width ,
120+ "height" : height ,
121+ "focal_length" : f_px ,
122+ })
123+
124+ except Exception as e :
125+ LOGGER .error (f"Error processing { file .filename } : { e } " )
126+ results .append ({
127+ "filename" : file .filename ,
128+ "error" : str (e ),
129+ })
130+
131+ return JSONResponse ({"results" : results })
132+
133+
134+ @app .post ("/predict/download" )
135+ async def predict_download (files : list [UploadFile ] = File (...)):
136+ """Process images and return a ZIP file for download."""
79137 if not predictor :
80138 return HTMLResponse ("Model not loaded" , status_code = 500 )
81139
@@ -98,9 +156,6 @@ async def predict(files: list[UploadFile] = File(...)):
98156 image , _ , f_px = sharp_io .load_rgb (file_path )
99157
100158 # Run prediction
101- # We need to convert numpy image to what predict_image expects if needed
102- # predict_image expects numpy array (H, W, 3)
103-
104159 gaussians = predict_image (predictor , image , f_px , device )
105160
106161 # Save PLY
@@ -115,7 +170,6 @@ async def predict(files: list[UploadFile] = File(...)):
115170
116171 except Exception as e :
117172 LOGGER .error (f"Error processing { file .filename } : { e } " )
118- # We could add an error log to the zip or just skip
119173 continue
120174
121175 output_zip .seek (0 )
0 commit comments