@@ -131,19 +131,18 @@ def test_respects_boundaries(self, sample_data_dict):
131131 num_sections = 1 ,
132132 )
133133
134+ original_first = sample_data_dict ["image" ][0 , 0 , :, :].clone ()
135+ original_last = sample_data_dict ["image" ][0 , - 1 , :, :].clone ()
134136 original_depth = sample_data_dict ["image" ].shape [1 ]
135137
136138 # Run multiple times
137139 for _ in range (10 ):
138140 result = transform (sample_data_dict .copy ())
139141 # Shape should be preserved
140142 assert result ["image" ].shape [1 ] == original_depth
141- # First and last sections should not be zero (preserved)
142- first_sum = result ["image" ][0 , 0 , :, :].sum ()
143- last_sum = result ["image" ][0 , - 1 , :, :].sum ()
144- # First and last should have non-zero values (preserved)
145- assert first_sum > 0
146- assert last_sum > 0
143+ # First and last sections should be unchanged (preserved)
144+ assert torch .equal (result ["image" ][0 , 0 , :, :], original_first )
145+ assert torch .equal (result ["image" ][0 , - 1 , :, :], original_last )
147146
148147 def test_probability_control (self , sample_data_dict ):
149148 """Test probability control."""
0 commit comments