We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f23f15b commit 6312d50Copy full SHA for 6312d50
1 file changed
src/arm_pytorch_utilities/tensor_utils.py
@@ -116,3 +116,7 @@ def ensure_diagonal(Q, dim):
116
raise RuntimeError("Expect {} sized diagonal vector but given {}".format(dim, Q.shape[0]))
117
Q = torch.diag(Q)
118
return Q
119
+
120
+def first_positive(x, dim=0):
121
+ nonz = (x > 0)
122
+ return ((nonz.cumsum(dim) == 1) & nonz).max(dim)
0 commit comments