Skip to content

Commit b866269

Browse files
committed
feat(aggregation): add estimate_dual_cone_spherical_volume
1 parent 4ad95e7 commit b866269

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/torchjd/aggregation/_utils/dual_cone.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,20 @@ def _to_array(tensor: Tensor) -> np.ndarray:
6060
"""Transforms a tensor into a numpy array with float64 dtype."""
6161

6262
return tensor.cpu().detach().numpy().astype(np.float64)
63+
64+
65+
def estimate_dual_cone_spherical_volume(G: Tensor, n_samples=1_000_000) -> Tensor:
66+
"""
67+
Estimates the spherical volume of the dual cone defined by the Gramian G.
68+
"""
69+
n = G.size(0)
70+
device = G.device
71+
72+
L = torch.linalg.cholesky(G + torch.eye(n, device=device) * 1e-10)
73+
ws = torch.randn(n_samples, n, device=device)
74+
zs = ws @ L.T
75+
76+
is_inside = torch.all(zs >= 0, dim=1)
77+
proportion = is_inside.float().mean()
78+
79+
return proportion

0 commit comments

Comments
 (0)