@@ -19,7 +19,7 @@ class Loss:
1919
2020 def __init__ (self , loss_function : torch .nn .Module ) -> None :
2121 """
22- Initialize the the base loss.
22+ Initialize the base loss.
2323
2424 Parameters
2525 ----------
@@ -41,10 +41,10 @@ def __call__(
4141 ----------
4242 prediction : torch.Tensor
4343 The predicted values.
44- Tensor of variable shape .
44+ Shape is variable.
4545 ground_truth : torch.Tensor
4646 The ground truth.
47- Tensor of variable shape .
47+ Shape is variable.
4848 \*\*kwargs : Any
4949 Keyword arguments.
5050
@@ -87,13 +87,13 @@ def __call__(
8787 ----------
8888 prediction : torch.Tensor
8989 The predicted values.
90- Tensor of shape [number_of_samples, ...].
90+ Shape is `` [number_of_samples, ...]`` .
9191 ground_truth : torch.Tensor
9292 The ground truth.
93- Tensor of shape [number_of_samples, ...].
93+ Shape is `` [number_of_samples, ...]`` .
9494 \*\*kwargs : Any
9595 Keyword arguments.
96- The ``reduction_dimensions`` is an expected keyword argument for the vector loss.
96+ ``reduction_dimensions`` is an expected keyword argument for the vector loss.
9797
9898 Raises
9999 ------
@@ -104,7 +104,7 @@ def __call__(
104104 -------
105105 torch.Tensor
106106 The summed MSE vector loss reduced along the specified dimensions.
107- Tensor of shape [number_of_samples].
107+ Shape is `` [number_of_samples]`` .
108108 """
109109 expected_kwargs = ["reduction_dimensions" ]
110110 for key in expected_kwargs :
@@ -162,13 +162,13 @@ def __call__(
162162 ----------
163163 prediction : torch.Tensor
164164 The predicted values.
165- Tensor of shape [number_of_samples, bitmap_resolution_e, bitmap_resolution_u].
165+ Shape is `` [number_of_samples, bitmap_resolution_e, bitmap_resolution_u]`` .
166166 ground_truth : torch.Tensor
167167 The ground truth.
168- Tensor of shape [number_of_samples, 4].
168+ Shape is `` [number_of_samples, 4]`` .
169169 \*\*kwargs : Any
170170 Keyword arguments.
171- The ``reduction_dimensions``, ``target_area_indices`` and ``device`` are expected keyword arguments for the focal spot loss.
171+ ``target_area_indices`` and ``device`` are expected keyword arguments for the focal spot loss.
172172
173173 Raises
174174 ------
@@ -179,9 +179,9 @@ def __call__(
179179 -------
180180 torch.Tensor
181181 The focal spot loss.
182- Tensor of shape [number_of_samples].
182+ Shape is `` [number_of_samples]`` .
183183 """
184- expected_kwargs = ["reduction_dimensions" , " device" , "target_area_indices" ]
184+ expected_kwargs = ["device" , "target_area_indices" ]
185185 errors = []
186186 for key in expected_kwargs :
187187 if key not in kwargs :
@@ -196,21 +196,20 @@ def __call__(
196196
197197 target_area_indices = kwargs ["target_area_indices" ]
198198
199- focal_spot = utils .get_center_of_mass (
199+ focal_spots_bitmap = utils .get_center_of_mass (
200200 bitmaps = prediction ,
201- target_centers = self .scenario .target_areas .centers [target_area_indices ],
202- target_widths = self .scenario .target_areas .dimensions [target_area_indices ][
203- :, index_mapping .target_area_width
204- ],
205- target_heights = self .scenario .target_areas .dimensions [target_area_indices ][
206- :, index_mapping .target_area_height
207- ],
208201 device = device ,
209202 )
210203
211- loss = torch .norm (focal_spot [:, :3 ] - ground_truth [:, :3 ], dim = 1 )
204+ focal_spot_coordinates = utils .bitmap_coordinates_to_target_coordinates (
205+ bitmap_coordinates = focal_spots_bitmap ,
206+ bitmap_resolution = torch .tensor (prediction .shape [1 :]),
207+ solar_tower = self .scenario .solar_tower ,
208+ target_area_indices = target_area_indices ,
209+ device = device ,
210+ )
212211
213- return loss
212+ return torch . norm ( focal_spot_coordinates [:, : 3 ] - ground_truth [:, : 3 ], dim = 1 )
214213
215214
216215class PixelLoss (Loss ):
@@ -257,13 +256,13 @@ def __call__(
257256 ----------
258257 prediction : torch.Tensor
259258 The predicted values.
260- Tensor of shape [number_of_samples, bitmap_resolution_e, bitmap_resolution_u].
259+ Shape is `` [number_of_samples, bitmap_resolution_e, bitmap_resolution_u]`` .
261260 ground_truth : torch.Tensor
262261 The ground truth.
263- Tensor of shape [number_of_samples, bitmap_resolution_e, bitmap_resolution_u].
262+ Shape is `` [number_of_samples, bitmap_resolution_e, bitmap_resolution_u]`` .
264263 \*\*kwargs : Any
265264 Keyword arguments.
266- The ``reduction_dimensions``, ``target_area_indices`` and optionally ``device`` are expected keyword arguments for the pixel loss.
265+ ``reduction_dimensions``, ``target_area_indices``, and ``device`` are expected keyword arguments for the pixel loss.
267266
268267 Raises
269268 ------
@@ -334,13 +333,13 @@ def __call__(
334333 ----------
335334 prediction : torch.Tensor
336335 The predicted values.
337- Tensor of shape [number_of_samples, bitmap_resolution_e, bitmap_resolution_u].
336+ Shape is `` [number_of_samples, bitmap_resolution_e, bitmap_resolution_u]`` .
338337 ground_truth : torch.Tensor
339338 The ground truth.
340- Tensor of shape [number_of_samples, bitmap_resolution_e, bitmap_resolution_u].
339+ Shape is `` [number_of_samples, bitmap_resolution_e, bitmap_resolution_u]`` .
341340 \*\*kwargs : Any
342341 Keyword arguments.
343- The ``reduction_dimensions`` is an expected keyword argument for the KL-divergence loss.
342+ ``reduction_dimensions`` is an expected keyword argument for the KL-divergence loss.
344343
345344 Raises
346345 ------
@@ -351,7 +350,7 @@ def __call__(
351350 -------
352351 torch.Tensor
353352 The summed KL-divergence loss reduced along the specified dimensions.
354- Tensor of shape [number_of_samples].
353+ Shape is `` [number_of_samples]`` .
355354 """
356355 expected_kwargs = ["reduction_dimensions" ]
357356 for key in expected_kwargs :
@@ -413,21 +412,17 @@ def __call__(
413412 ----------
414413 prediction : torch.Tensor
415414 The predicted values.
416- Tensor of shape [number_of_samples, 4].
415+ Shape is `` [number_of_samples, 4]`` .
417416 ground_truth : torch.Tensor
418417 The ground truth.
419- Tensor of shape [number_of_samples, 4].
418+ Shape is `` [number_of_samples, 4]`` .
420419 \*\*kwargs : Any
421420 Keyword arguments.
422421
423422 Returns
424423 -------
425424 torch.Tensor
426425 The summed loss reduced along the specified dimensions.
427- Tensor of shape [number_of_samples].
426+ Shape is `` [number_of_samples]`` .
428427 """
429- cosine_similarity = self .loss_function (prediction , ground_truth )
430-
431- loss = 1.0 - cosine_similarity
432-
433- return loss
428+ return 1.0 - self .loss_function (prediction , ground_truth )
0 commit comments