@@ -205,12 +205,19 @@ def __setitem__(self, index, value):
205205 the specified index positions.
206206 """
207207 if self .partition is Partition .BROADCAST :
208+ ncp = get_module (self .engine )
208209 view = self .local_array [index ]
210+ buf = ncp .empty (view .shape , dtype = self .dtype )
209211 if self .rank == 0 :
210- view [...] = value
211- view = self ._bcast (self .base_comm , self .base_comm_nccl ,
212- view , root = 0 , engine = self .engine )
213- self .local_array [index ] = view
212+ buf [...] = value
213+ buf = self ._bcast (
214+ self .base_comm ,
215+ self .base_comm_nccl ,
216+ buf ,
217+ root = 0 ,
218+ engine = self .engine
219+ )
220+ self .local_array [index ] = buf
214221 else :
215222 self .local_array [index ] = value
216223
@@ -644,8 +651,25 @@ def multiply(self, dist_array):
644651 ProductArray [:] = self .local_array * dist_array
645652 return ProductArray
646653
647- def dot (self , dist_array ):
648- """Distributed Dot Product
654+ def dot (self , dist_array , vdot : bool = False ):
655+ """
656+ Compute the distributed dot product between this array and another
657+ distributed array.
658+
659+ Parameters
660+ ----------
661+ dist_array : :obj:`pylops_mpi.DistributedArray`
662+ The distributed array with which to compute the dot product.
663+ It must have a compatible shape and partitioning.
664+ vdot : bool, optional, Defaults to `False`
665+ If True, compute the complex conjugate dot product (like numpy.vdot),
666+ where the first argument is conjugated. If False, compute the standard
667+ dot product.
668+
669+ Returns
670+ -------
671+ result : float
672+ The result of the dot product across all ranks. This is reduced across all processes.
649673 """
650674 self ._check_partition_shape (dist_array )
651675 self ._check_mask (dist_array )
@@ -656,8 +680,9 @@ def dot(self, dist_array):
656680 y = DistributedArray .to_dist (x = dist_array .local_array , base_comm = self .base_comm , base_comm_nccl = self .base_comm_nccl ) \
657681 if self .partition in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ] else dist_array
658682 # Flatten the local arrays and calculate dot product
683+ dot_func = ncp .vdot if vdot else ncp .dot
659684 return self ._allreduce_subcomm (self .sub_comm , self .base_comm_nccl ,
660- ncp . dot (x .local_array .flatten (), y .local_array .flatten ()),
685+ dot_func (x .local_array .flatten (), y .local_array .flatten ()),
661686 engine = self .engine )
662687
663688 def _compute_vector_norm (self , local_array : NDArray ,
@@ -839,6 +864,15 @@ def ravel(self, order: Optional[str] = "C"):
839864 arr [:] = x
840865 return arr
841866
867+ def empty_like (self ):
868+ """Creates an empty like DistributedArray with uninitialized values
869+ """
870+ dist = DistributedArray (global_shape = self .global_shape , base_comm = self .base_comm ,
871+ base_comm_nccl = self .base_comm_nccl , partition = self .partition ,
872+ axis = self .axis , local_shapes = self .local_shapes , mask = self .mask ,
873+ engine = self .engine , dtype = self .dtype )
874+ return dist
875+
842876 def add_ghost_cells (self , cells_front : Optional [int ] = None ,
843877 cells_back : Optional [int ] = None ):
844878 """Add ghost cells to the DistributedArray along the axis
@@ -1038,13 +1072,30 @@ def multiply(self, stacked_array):
10381072 ProductArray [iarr ][:] = (self [iarr ] * stacked_array )[:]
10391073 return ProductArray
10401074
1041- def dot (self , stacked_array ):
1042- """Dot Product of Stacked Distributed Arrays
1075+ def dot (self , stacked_array , vdot : bool = False ):
1076+ """
1077+ Compute the distributed dot product between this array and another
1078+ distributed array.
1079+
1080+ Parameters
1081+ ----------
1082+ stacked_array : :obj:`pylops_mpi.StackedDistributedArray`
1083+ The distributed array with which to compute the dot product.
1084+ It must have a compatible shape and partitioning.
1085+ vdot : bool, optional, Defaults to `False`
1086+ If True, compute the complex conjugate dot product (like numpy.vdot),
1087+ where the first argument is conjugated. If False, compute the standard
1088+ dot product.
1089+
1090+ Returns
1091+ -------
1092+ result : float
1093+ The result of the dot product across all ranks. This is reduced across all processes.
10431094 """
10441095 self ._check_stacked_size (stacked_array )
10451096 dotprod = 0.
10461097 for iarr in range (self .narrays ):
1047- dotprod += self [iarr ].dot (stacked_array [iarr ])
1098+ dotprod += self [iarr ].dot (stacked_array [iarr ], vdot = vdot )
10481099 return dotprod
10491100
10501101 def norm (self , ord : Optional [int ] = None ):
@@ -1085,6 +1136,19 @@ def copy(self):
10851136 arr = StackedDistributedArray ([distarray .copy () for distarray in self .distarrays ])
10861137 return arr
10871138
1139+ def empty_like (self ):
1140+ """Creates an empty like StackedDistributedArray with uninitialized values
1141+ """
1142+ dists = []
1143+ for iarr in range (self .narrays ):
1144+ distarray = self .distarrays [iarr ]
1145+ dist = DistributedArray (global_shape = distarray .global_shape , base_comm = distarray .base_comm ,
1146+ base_comm_nccl = distarray .base_comm_nccl , partition = distarray .partition ,
1147+ axis = distarray .axis , local_shapes = distarray .local_shapes , mask = distarray .mask ,
1148+ engine = distarray .engine , dtype = distarray .dtype )
1149+ dists .append (dist )
1150+ return StackedDistributedArray (distarrays = dists )
1151+
10881152 def __repr__ (self ):
10891153 repr_dist = "\n " .join ([distarray .__repr__ () for distarray in self .distarrays ])
10901154 return f"<StackedDistributedArray with { self .narrays } distributed arrays: \n " + repr_dist
0 commit comments