-
Notifications
You must be signed in to change notification settings - Fork 595
Expand file tree
/
Copy pathtest_predict.py
More file actions
75 lines (55 loc) · 2.29 KB
/
test_predict.py
File metadata and controls
75 lines (55 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import pytest
import torch
from pathlib import Path
from sharp.cli.predict import predict_image
from sharp.models import PredictorParams, create_predictor
from sharp.utils.gaussians import Gaussians3D
# torch.backends.mps.is_available() is the documented MPS availability check;
# the torch.mps namespace may not expose is_available() on every release.
@pytest.mark.skipif(
    not (torch.cuda.is_available() or torch.backends.mps.is_available()),
    reason="Requires CUDA or MPS for model inference",
)
def test_predict_image_with_model(tmp_path):
    """Run ``predict_image`` end to end with a real, locally cached checkpoint.

    The test is skipped when no CUDA/MPS device is available, when the example
    image or the cached checkpoint is missing, or when the model cannot be set
    up from that checkpoint. Once the model is set up, errors raised by
    inference and failed assertions are reported as test failures instead of
    being converted into skips.
    """
    example_path = Path(__file__).parent / "data" / "example.jpg"
    if not example_path.exists():
        pytest.skip("Test image not found")

    from sharp.utils import io

    image, _, f_px = io.load_rgb(example_path)
    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    # Use the pre-cached model checkpoint. torch.hub.get_dir() defaults to
    # ~/.cache/torch/hub but also honours TORCH_HOME / XDG_CACHE_HOME.
    # Presumably the file was downloaded via torch.hub; verify if that changes.
    checkpoint_path = (
        Path(torch.hub.get_dir()) / "checkpoints" / "sharp_2572gikvuh.pt"
    )
    if not checkpoint_path.exists():
        pytest.skip("Model checkpoint not found in cache")

    # Only model setup is best-effort. Keeping inference and the assertions
    # outside this try block ensures an AssertionError (a subclass of
    # Exception) fails the test instead of being silently turned into a skip.
    try:
        predictor = create_predictor(
            PredictorParams(checkpoint_path=str(checkpoint_path))
        )
        predictor.eval()
        predictor.to(device)
    except Exception as e:
        pytest.skip(
            f"Model setup failed (likely missing or incompatible checkpoint): {e}"
        )

    gaussians = predict_image(predictor, image, f_px, device)

    assert isinstance(gaussians, Gaussians3D)
    # NOTE(review): these index checks assume a (batch, num_gaussians, 3)
    # layout for mean_vectors/colors with a batch of one image — confirm
    # against Gaussians3D.
    assert gaussians.mean_vectors.shape[0] == 1
    assert gaussians.mean_vectors.shape[2] == 3
    assert gaussians.colors.shape[2] == 3
    assert gaussians.opacities.shape[1] > 0
def test_predict_image_signature():
    """Verify ``predict_image`` declares predictor, image, f_px and device parameters."""
    import inspect

    parameter_names = inspect.signature(predict_image).parameters
    # Checked in the same order as the call site predict_image(predictor, image, f_px, device).
    for expected in ("predictor", "image", "f_px", "device"):
        assert expected in parameter_names
def test_create_predictor():
    """Build a predictor from default ``PredictorParams`` and check it exposes ``eval`` and ``to``."""
    predictor = create_predictor(PredictorParams())
    assert predictor is not None
    # eval() and to() are the two methods the inference test calls before predicting.
    for method_name in ("eval", "to"):
        assert hasattr(predictor, method_name)
def test_predictor_params():
    """Check that ``PredictorParams`` can be constructed with no arguments."""
    assert PredictorParams() is not None