22from pytest import mark , raises
33from torch .nn import Linear , MSELoss , ReLU , Sequential
44
5- from torchjd .autojac ._utils import _get_leaf_tensors
5+ from torchjd .autojac ._utils import get_leaf_tensors
66
77
88def test_simple_get_leaf_tensors ():
@@ -14,7 +14,7 @@ def test_simple_get_leaf_tensors():
1414 y1 = torch .tensor ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
1515 y2 = (a1 ** 2 ).sum () + a2 .norm ()
1616
17- leaves = _get_leaf_tensors (tensors = [y1 , y2 ], excluded = set ())
17+ leaves = get_leaf_tensors (tensors = [y1 , y2 ], excluded = set ())
1818 assert set (leaves ) == {a1 , a2 }
1919
2020
@@ -35,7 +35,7 @@ def test_get_leaf_tensors_excluded_1():
3535 y1 = torch .tensor ([- 1.0 , 1.0 ]) @ a1 + b2
3636 y2 = b1
3737
38- leaves = _get_leaf_tensors (tensors = [y1 , y2 ], excluded = {b1 , b2 })
38+ leaves = get_leaf_tensors (tensors = [y1 , y2 ], excluded = {b1 , b2 })
3939 assert set (leaves ) == {a1 }
4040
4141
@@ -56,7 +56,7 @@ def test_get_leaf_tensors_excluded_2():
5656 y1 = torch .tensor ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
5757 y2 = b1
5858
59- leaves = _get_leaf_tensors (tensors = [y1 , y2 ], excluded = {b1 , b2 })
59+ leaves = get_leaf_tensors (tensors = [y1 , y2 ], excluded = {b1 , b2 })
6060 assert set (leaves ) == {a1 , a2 }
6161
6262
@@ -71,7 +71,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad():
7171 y1 = torch .tensor ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
7272 y2 = (a1 ** 2 ).sum () + a2 .norm ()
7373
74- leaves = _get_leaf_tensors (tensors = [y1 , y2 ], excluded = set ())
74+ leaves = get_leaf_tensors (tensors = [y1 , y2 ], excluded = set ())
7575 assert set (leaves ) == {a1 }
7676
7777
@@ -90,7 +90,7 @@ def test_get_leaf_tensors_model():
9090 y_hat = model (x )
9191 losses = loss_fn (y_hat , y )
9292
93- leaves = _get_leaf_tensors (tensors = [losses ], excluded = set ())
93+ leaves = get_leaf_tensors (tensors = [losses ], excluded = set ())
9494 assert set (leaves ) == set (model .parameters ())
9595
9696
@@ -111,7 +111,7 @@ def test_get_leaf_tensors_model_excluded_2():
111111 z_hat = model2 (y )
112112 losses = loss_fn (z_hat , z )
113113
114- leaves = _get_leaf_tensors (tensors = [losses ], excluded = {y })
114+ leaves = get_leaf_tensors (tensors = [losses ], excluded = {y })
115115 assert set (leaves ) == set (model2 .parameters ())
116116
117117
@@ -121,14 +121,14 @@ def test_get_leaf_tensors_single_root():
121121 p = torch .tensor ([1.0 , 2.0 ], requires_grad = True )
122122 y = p * 2
123123
124- leaves = _get_leaf_tensors (tensors = [y ], excluded = set ())
124+ leaves = get_leaf_tensors (tensors = [y ], excluded = set ())
125125 assert set (leaves ) == {p }
126126
127127
128128def test_get_leaf_tensors_empty_roots ():
129129 """Tests that _get_leaf_tensors returns no leaves when roots is the empty set."""
130130
131- leaves = _get_leaf_tensors (tensors = [], excluded = set ())
131+ leaves = get_leaf_tensors (tensors = [], excluded = set ())
132132 assert set (leaves ) == set ()
133133
134134
@@ -141,7 +141,7 @@ def test_get_leaf_tensors_excluded_root():
141141 y1 = torch .tensor ([- 1.0 , 1.0 ]) @ a1 + a2 .sum ()
142142 y2 = (a1 ** 2 ).sum ()
143143
144- leaves = _get_leaf_tensors (tensors = [y1 , y2 ], excluded = {y1 })
144+ leaves = get_leaf_tensors (tensors = [y1 , y2 ], excluded = {y1 })
145145 assert set (leaves ) == {a1 }
146146
147147
@@ -154,7 +154,7 @@ def test_get_leaf_tensors_deep(depth: int):
154154 for i in range (depth ):
155155 sum_ = sum_ + one
156156
157- leaves = _get_leaf_tensors (tensors = [sum_ ], excluded = set ())
157+ leaves = get_leaf_tensors (tensors = [sum_ ], excluded = set ())
158158 assert set (leaves ) == {one }
159159
160160
@@ -163,7 +163,7 @@ def test_get_leaf_tensors_leaf():
163163
164164 a = torch .tensor (1.0 , requires_grad = True )
165165 with raises (ValueError ):
166- _ = _get_leaf_tensors (tensors = [a ], excluded = set ())
166+ _ = get_leaf_tensors (tensors = [a ], excluded = set ())
167167
168168
169169def test_get_leaf_tensors_tensor_not_requiring_grad ():
@@ -173,7 +173,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad():
173173
174174 a = torch .tensor (1.0 , requires_grad = False ) * 2
175175 with raises (ValueError ):
176- _ = _get_leaf_tensors (tensors = [a ], excluded = set ())
176+ _ = get_leaf_tensors (tensors = [a ], excluded = set ())
177177
178178
179179def test_get_leaf_tensors_excluded_leaf ():
@@ -182,7 +182,7 @@ def test_get_leaf_tensors_excluded_leaf():
182182 a = torch .tensor (1.0 , requires_grad = True ) * 2
183183 b = torch .tensor (2.0 , requires_grad = True )
184184 with raises (ValueError ):
185- _ = _get_leaf_tensors (tensors = [a ], excluded = {b })
185+ _ = get_leaf_tensors (tensors = [a ], excluded = {b })
186186
187187
188188def test_get_leaf_tensors_excluded_not_requiring_grad ():
@@ -193,4 +193,4 @@ def test_get_leaf_tensors_excluded_not_requiring_grad():
193193 a = torch .tensor (1.0 , requires_grad = True ) * 2
194194 b = torch .tensor (2.0 , requires_grad = False ) * 2
195195 with raises (ValueError ):
196- _ = _get_leaf_tensors (tensors = [a ], excluded = {b })
196+ _ = get_leaf_tensors (tensors = [a ], excluded = {b })
0 commit comments