@@ -438,19 +438,18 @@ def trace_rays(
438438 )
439439
440440 flux_distributions = []
441+ global_active_indices = torch .nonzero (active_heliostats_mask , as_tuple = True )[0 ]
441442 for batch_index , (batch_u , batch_e ) in enumerate (self .distortions_loader ):
442443 sampler_indices = list (self .distortions_sampler )
443-
444+ batch_mask_indices = sampler_indices [
445+ batch_index * self .batch_size : (batch_index + 1 ) * self .batch_size
446+ ]
444447 active_heliostats_mask_batch = torch .zeros (
445448 self .heliostat_group .number_of_active_heliostats ,
446449 dtype = torch .bool ,
447450 device = device ,
448451 )
449- active_heliostats_mask_batch [
450- sampler_indices [
451- batch_index * self .batch_size : (batch_index + 1 ) * self .batch_size
452- ]
453- ] = True
452+ active_heliostats_mask_batch [batch_mask_indices ] = True
454453
455454 rays = self .scatter_rays (
456455 distortion_u = batch_u ,
@@ -486,9 +485,10 @@ def trace_rays(
486485 points_at_ray_origins = self .heliostat_group .active_surface_points [
487486 active_heliostats_mask_batch , None , :, :3
488487 ].expand (- 1 , self .light_source .number_of_rays , - 1 , - 1 )
489- ray_to_heliostat_mapping = torch .arange (
490- number_of_heliostats , device = device
491- ).repeat_interleave (number_of_rays * number_of_points )
488+ batch_global_indices = global_active_indices [batch_mask_indices ]
489+ ray_to_heliostat_mapping = batch_global_indices .repeat_interleave (
490+ number_of_rays * number_of_points
491+ )
492492
493493 # Filter out the blocking primitives that are relevant for blocking.
494494 filtered_blocking_primitive_indices = (
@@ -499,7 +499,7 @@ def trace_rays(
499499 ..., :3
500500 ],
501501 ray_to_heliostat_mapping = ray_to_heliostat_mapping ,
502- max_stack_size = 128 ,
502+ max_stack_size = 64 ,
503503 device = device ,
504504 )
505505 )
0 commit comments