Skip to content

Commit 8afb545

Browse files
committed
Added web ui;
Added start.command for one-click setup; Updated readme.md for one-click setup; Added images of web-ui; Added new web viewer after ply is generated; Cleaned up project by moving styles and js into seperate files;
1 parent cdb4ddc commit 8afb545

11 files changed

Lines changed: 1711 additions & 0 deletions

File tree

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ If you find our work useful, please cite the following paper:
8989

9090
Our codebase is built using multiple opensource contributions, please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
9191

92+
## Web Interface (One-Click Setup)
93+
94+
For a simple web interface where you can upload images and download 3D Gaussians:
95+
96+
**macOS:** Double-click `start.command` in Finder.
97+
98+
This will automatically:
99+
- Create the conda environment if it doesn't exist
100+
- Install all dependencies if needed
101+
- Start the web server at http://localhost:8000
102+
103+
![](data/web-dark.png)
104+
![](data/web-light.png)
105+
106+
92107
## License
93108

94109
Please check out the repository [LICENSE](LICENSE) before using the provided code and

data/web-dark.png

4.4 MB
Loading

data/web-light.png

4.53 MB
Loading

src/sharp/web/README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Sharp Web Interface
2+
3+
This is a web interface for the Sharp 3D prediction model.
4+
5+
## Prerequisites
6+
7+
Make sure you have the `sharp` package installed (see root README).
8+
Install the web dependencies:
9+
10+
```bash
11+
pip install -r requirements.txt
12+
```
13+
14+
## Running the Server
15+
16+
Run the following command from the `web` directory:
17+
18+
```bash
19+
python app.py
20+
```
21+
22+
Or using uvicorn directly:
23+
24+
```bash
25+
uvicorn app:app --reload --host 0.0.0.0 --port 8000
26+
```
27+
28+
## Usage
29+
30+
1. Open your browser and navigate to `http://localhost:8000`.
31+
2. Drag and drop images or click to select them.
32+
3. Click "Predict 3D Gaussians".
33+
4. A zip file containing the resulting `.ply` files will be downloaded automatically.

src/sharp/web/app.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import sys
2+
from pathlib import Path
3+
import logging
4+
import shutil
5+
import tempfile
6+
import zipfile
7+
import io as python_io
8+
import base64
9+
10+
from fastapi import FastAPI, Request, UploadFile, File
11+
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
12+
from fastapi.staticfiles import StaticFiles
13+
from fastapi.templating import Jinja2Templates
14+
import torch
15+
import numpy as np
16+
17+
# Add src to path so we can import sharp
18+
sys.path.append(str(Path(__file__).parent.parent / "src"))
19+
20+
from sharp.models import (
21+
PredictorParams,
22+
RGBGaussianPredictor,
23+
create_predictor,
24+
)
25+
from sharp.utils import io as sharp_io
26+
from sharp.utils.gaussians import save_ply
27+
from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL
28+
29+
# Configure logging
30+
logging.basicConfig(level=logging.INFO)
31+
LOGGER = logging.getLogger(__name__)
32+
33+
app = FastAPI()
34+
35+
# Mount static files if needed (we created the dir)
36+
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
37+
38+
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
39+
40+
# Global variables for the model
41+
predictor: RGBGaussianPredictor = None
42+
device: torch.device = None
43+
44+
@app.on_event("startup")
45+
async def startup_event():
46+
global predictor, device
47+
48+
# Determine device
49+
if torch.cuda.is_available():
50+
device_str = "cuda"
51+
elif torch.mps.is_available():
52+
device_str = "mps"
53+
else:
54+
device_str = "cpu"
55+
56+
device = torch.device(device_str)
57+
LOGGER.info(f"Using device: {device}")
58+
59+
# Load model
60+
LOGGER.info("Loading model...")
61+
try:
62+
# Try to load from cache or download
63+
state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True, map_location=device)
64+
65+
predictor = create_predictor(PredictorParams())
66+
predictor.load_state_dict(state_dict)
67+
predictor.eval()
68+
predictor.to(device)
69+
LOGGER.info("Model loaded successfully.")
70+
except Exception as e:
71+
LOGGER.error(f"Failed to load model: {e}")
72+
raise e
73+
74+
@app.get("/", response_class=HTMLResponse)
75+
async def read_root(request: Request):
76+
return templates.TemplateResponse("index.html", {"request": request})
77+
78+
@app.post("/predict")
79+
async def predict(files: list[UploadFile] = File(...)):
80+
"""Process images and return PLY data for viewing or download."""
81+
if not predictor:
82+
return JSONResponse({"error": "Model not loaded"}, status_code=500)
83+
84+
# Create a temporary directory to process files
85+
with tempfile.TemporaryDirectory() as temp_dir:
86+
temp_path = Path(temp_dir)
87+
results = []
88+
89+
for file in files:
90+
try:
91+
# Save uploaded file
92+
file_path = temp_path / file.filename
93+
with open(file_path, "wb") as buffer:
94+
shutil.copyfileobj(file.file, buffer)
95+
96+
LOGGER.info(f"Processing {file.filename}")
97+
98+
# Load image using sharp's IO to get focal length and handle rotation
99+
image, _, f_px = sharp_io.load_rgb(file_path)
100+
101+
# Run prediction
102+
gaussians = predict_image(predictor, image, f_px, device)
103+
104+
# Save PLY
105+
ply_filename = f"{file_path.stem}.ply"
106+
ply_path = temp_path / ply_filename
107+
108+
height, width = image.shape[:2]
109+
save_ply(gaussians, f_px, (height, width), ply_path)
110+
111+
# Read PLY file and encode as base64
112+
with open(ply_path, "rb") as f:
113+
ply_data = base64.b64encode(f.read()).decode("utf-8")
114+
115+
results.append({
116+
"filename": file.filename,
117+
"ply_filename": ply_filename,
118+
"ply_data": ply_data,
119+
"width": width,
120+
"height": height,
121+
"focal_length": f_px,
122+
})
123+
124+
except Exception as e:
125+
LOGGER.error(f"Error processing {file.filename}: {e}")
126+
results.append({
127+
"filename": file.filename,
128+
"error": str(e),
129+
})
130+
131+
return JSONResponse({"results": results})
132+
133+
134+
@app.post("/predict/download")
135+
async def predict_download(files: list[UploadFile] = File(...)):
136+
"""Process images and return a ZIP file for download."""
137+
if not predictor:
138+
return HTMLResponse("Model not loaded", status_code=500)
139+
140+
# Create a temporary directory to process files
141+
with tempfile.TemporaryDirectory() as temp_dir:
142+
temp_path = Path(temp_dir)
143+
output_zip = python_io.BytesIO()
144+
145+
with zipfile.ZipFile(output_zip, "w") as zf:
146+
for file in files:
147+
try:
148+
# Save uploaded file
149+
file_path = temp_path / file.filename
150+
with open(file_path, "wb") as buffer:
151+
shutil.copyfileobj(file.file, buffer)
152+
153+
LOGGER.info(f"Processing {file.filename}")
154+
155+
# Load image using sharp's IO to get focal length and handle rotation
156+
image, _, f_px = sharp_io.load_rgb(file_path)
157+
158+
# Run prediction
159+
gaussians = predict_image(predictor, image, f_px, device)
160+
161+
# Save PLY
162+
ply_filename = f"{file_path.stem}.ply"
163+
ply_path = temp_path / ply_filename
164+
165+
height, width = image.shape[:2]
166+
save_ply(gaussians, f_px, (height, width), ply_path)
167+
168+
# Add to zip
169+
zf.write(ply_path, ply_filename)
170+
171+
except Exception as e:
172+
LOGGER.error(f"Error processing {file.filename}: {e}")
173+
continue
174+
175+
output_zip.seek(0)
176+
return StreamingResponse(
177+
output_zip,
178+
media_type="application/zip",
179+
headers={"Content-Disposition": "attachment; filename=gaussians.zip"}
180+
)
181+
182+
if __name__ == "__main__":
183+
import uvicorn
184+
uvicorn.run(app, host="0.0.0.0", port=8000)

src/sharp/web/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
fastapi
2+
uvicorn
3+
python-multipart
4+
jinja2

0 commit comments

Comments
 (0)