1313from jax import numpy as np
1414from jax .numpy import linalg
1515
16- from pysages .collective_variables .core import CollectiveVariable , AxisCV
17- from pysages .collective_variables .coordinates import barycenter
16+ from pysages .colvars .core import CollectiveVariable , AxisCV
17+ from pysages .colvars .coordinates import barycenter , weighted_barycenter
1818
1919
20- QUATERNION_BASES = {
20+ _QUATERNION_BASES = {
2121 0 : (0 , 1 , 0 , 0 ),
2222 1 : (0 , 0 , 1 , 0 ),
2323 2 : (0 , 0 , 0 , 1 ),
2424}
2525
2626
27- def quaternion_matrix (positions , references ):
27+ def quaternion_matrix (positions , references , weights ):
2828 """
2929 Function to construct the quaternion matrix based on the reference positions.
3030
@@ -35,12 +35,24 @@ def quaternion_matrix(positions, references):
3535 references: np.array
3636 Cartesian coordinates of the reference positions of the atoms in indices.
3737 The number of coordinates must match the ones of the atoms used in indices.
38+ weights: np.array
39+ Weights for the barycenter calculation
3840 """
39- pos_b = barycenter (positions )
40- ref_b = barycenter (references )
41+ if len (positions ) != len (references ):
42+ raise RuntimeError ("References must be of the same length as the positions" )
43+ pos_b = np .where (weights is None ,
44+ barycenter (positions ),
45+ weighted_barycenter (positions , weights ))
46+ ref_b = np .where (weights is None ,
47+ barycenter (references ),
48+ weighted_barycenter (references , weights ))
4149 R = np .zeros ((3 , 3 ))
42- for pos , ref in zip (positions , references ):
43- R += np .outer (pos - pos_b , ref - ref_b )
50+ if weights is None :
51+ for pos , ref in zip (positions , references ):
52+ R += np .outer (pos - pos_b , ref - ref_b )
53+ else :
54+ for pos , ref , w in zip (positions , references , weights ):
55+ R += np .outer (w * pos - pos_b , w * ref - ref_b )
4456 S_00 = R [0 , 0 ] + R [1 , 1 ] + R [2 , 2 ]
4557 S_01 = R [1 , 2 ] - R [2 , 1 ]
4658 S_02 = R [2 , 0 ] - R [0 , 2 ]
@@ -77,17 +89,20 @@ class Tilt(AxisCV):
7789 of coordinates must match the ones of the atoms used in indices.
7890 """
7991
80- def __init__ (self , indices , axis , references ):
92+ def __init__ (self , indices , axis , references , weights = None ):
93+ if weights is not None and len (indices ) != len (weights ):
94+ raise RuntimeError ("Indices and weights must be of the same length" )
8195 super ().__init__ (indices , axis )
8296 self .references = np .asarray (references )
83- self .e = np .asarray (QUATERNION_BASES [axis ])
97+ self .weights = np .asarray (weights )
98+ self .e = np .asarray (_QUATERNION_BASES [axis ])
8499
85100 @property
86101 def function (self ):
87- return lambda rs : tilt (rs , self .references , self .axis )
102+ return lambda rs : tilt (rs , self .references , self .axis , self . weights )
88103
89104
90- def tilt (r1 , r2 , e ):
105+ def tilt (r1 , r2 , e , w ):
91106 """
92107 Function to calculate the tilt rotation respect to an axis.
93108
@@ -100,8 +115,10 @@ def tilt(r1, r2, e):
100115 The number of coordinates must match the ones of the atoms used in indices.
101116 e: JaxArray
102117 Quaternion rotation axis.
118+ w: JaxArray
119+ Atomic weights.
103120 """
104- S = quaternion_matrix (r1 , r2 )
121+ S = quaternion_matrix (r1 , r2 , w )
105122 _ , v = linalg .eigh (S )
106123 v_dot_e = np .dot (v [:, 3 ], e )
107124 cs = np .cos (np .arctan2 (v_dot_e , v [0 , 3 ]))
@@ -124,17 +141,20 @@ class RotationEigenvalues(CollectiveVariable):
124141 The number of coordinates must match the ones of the atoms used in indices.
125142 """
126143
127- def __init__ (self , indices , axis , references ):
144+ def __init__ (self , indices , axis , references , weights = None ):
145+ if weights is not None and len (indices ) != len (weights ):
146+ raise RuntimeError ("Indices and weights must be of the same length" )
128147 super ().__init__ (indices )
129148 self .axis = axis
130149 self .references = np .asarray (references )
150+ self .weights = np .asarray (weights )
131151
132152 @property
133153 def function (self ):
134- return lambda r : rotation_eigvals (r , self .references )[self .axis ]
154+ return lambda r : rotation_eigvals (r , self .references , self . weights )[self .axis ]
135155
136156
137- def rotation_eigvals (r1 , r2 ):
157+ def rotation_eigvals (r1 , r2 , w ):
138158 """
139159 Returns the eigenvalue of the quaternion matrix rspect to an axis.
140160
@@ -147,7 +167,7 @@ def rotation_eigvals(r1, r2):
147167 r2: np.array
148168 Cartesian coordinates of the reference position of the atoms in r1.
149169 """
150- S = quaternion_matrix (r1 , r2 )
170+ S = quaternion_matrix (r1 , r2 , w )
151171 return linalg .eigvalsh (S )
152172
153173
@@ -166,17 +186,20 @@ class SpinAngle(AxisCV):
166186 The coordinates must match the ones of the atoms used in indices.
167187 """
168188
169- def __init__ (self , indices , axis , references ):
189+ def __init__ (self , indices , axis , references , weights = None ):
190+ if weights is not None and len (indices ) != len (weights ):
191+ raise RuntimeError ("Indices and weights must be of the same length" )
170192 super ().__init__ (indices , axis )
171193 self .references = np .asarray (references )
172- self .e = np .array (QUATERNION_BASES [axis ])
194+ self .e = np .array (_QUATERNION_BASES [axis ])
195+ self .weights = np .asarray (weights )
173196
174197 @property
175198 def function (self ):
176- return lambda r : spin (r , self .references , self .e )
199+ return lambda r : spin (r , self .references , self .e , self . weights )
177200
178201
179- def spin (r1 , r2 , e ):
202+ def spin (r1 , r2 , e , w ):
180203 """
181204 Calculate the spin angle rotation respect to an axis.
182205
@@ -189,7 +212,7 @@ def spin(r1, r2, e):
189212 e: np.array
190213 Rotation axis.
191214 """
192- S = quaternion_matrix (r1 , r2 )
215+ S = quaternion_matrix (r1 , r2 , w )
193216 _ , v = linalg .eigh (S )
194217 v_dot_e = np .dot (v [:, 3 ], e )
195218 return 2 * np .arctan2 (v_dot_e , v [0 , 3 ])
@@ -209,16 +232,19 @@ class RotationAngle(CollectiveVariable):
209232 The coordinates must match the ones of the atoms used in indices.
210233 """
211234
212- def __init__ (self , indices , references ):
235+ def __init__ (self , indices , references , weights = None ):
236+ if weights is not None and len (indices ) != len (weights ):
237+ raise RuntimeError ("Indices and weights must be of the same length" )
213238 super ().__init__ (indices )
214239 self .references = np .asarray (references )
240+ self .weights = np .asarray (weights )
215241
216242 @property
217243 def function (self ):
218- return lambda r : rotation_angle (r , self .references )
244+ return lambda r : rotation_angle (r , self .references , self . weights )
219245
220246
221- def rotation_angle (r1 , r2 ):
247+ def rotation_angle (r1 , r2 , w ):
222248 """
223249 Calculate the rotation angle respect to a reference.
224250
@@ -229,7 +255,7 @@ def rotation_angle(r1, r2):
229255 r2: np.array
230256 Cartesian coordinates of the reference position of the atoms.
231257 """
232- S = quaternion_matrix (r1 , r2 )
258+ S = quaternion_matrix (r1 , r2 , w )
233259 _ , v = linalg .eigh (S )
234260 return 2 * np .arccos (v [0 , 3 ])
235261
@@ -248,16 +274,19 @@ class RotationProjection(CollectiveVariable):
248274 The coordinates must match the ones of the atoms used in indices.
249275 """
250276
251- def __init__ (self , indices , references ):
277+ def __init__ (self , indices , references , weights = None ):
278+ if weights is not None and len (indices ) != len (weights ):
279+ raise RuntimeError ("Indices and weights must be of the same length" )
252280 super ().__init__ (indices )
253281 self .references = np .asarray (references )
282+ self .weights = np .asarray (weights )
254283
255284 @property
256285 def function (self ):
257- return lambda r : rotation_projection (r , self .references )
286+ return lambda r : rotation_projection (r , self .references , self . weights )
258287
259288
260- def rotation_projection (r1 , r2 ):
289+ def rotation_projection (r1 , r2 , w ):
261290 """
262291 Calculate the rotation angle projection.
263292
@@ -268,7 +297,7 @@ def rotation_projection(r1, r2):
268297 r2: np.array
269298 Cartesian coordinates of the reference position of the atoms.
270299 """
271- S = quaternion_matrix (r1 , r2 )
300+ S = quaternion_matrix (r1 , r2 , w )
272301 _ , v = linalg .eigh (S )
273302 return 2 * v [0 , 3 ] * v [0 , 3 ] - 1
274303
@@ -287,16 +316,19 @@ class RMSD(CollectiveVariable):
287316 The coordinates must match the ones of the atoms used in indices.
288317 """
289318
290- def __init__ (self , indices , references ):
319+ def __init__ (self , indices , references , weights = None ):
320+ if weights is not None and len (indices ) != len (weights ):
321+ raise RuntimeError ("Indices and weights must be of the same length" )
291322 super ().__init__ (indices )
292323 self .references = np .asarray (references )
324+ self .weights = np .asarray (weights )
293325
294326 @property
295327 def function (self ):
296- return lambda r : rmsd (r , self .references )
328+ return lambda r : rmsd (r , self .references , self . weights )
297329
298330
299- def sq_norm_rotation (positions , references ):
331+ def sq_norm_rotation (positions , references , weights ):
300332 """
301333 Calculate the squared norm of the atomic positions and references respect to barycenters.
302334
@@ -307,17 +339,31 @@ def sq_norm_rotation(positions, references):
307339 references: np.array
308340 Cartesian coordinates of the reference position of the atoms.
309341 """
310- pos_b = barycenter (positions )
311- ref_b = barycenter (references )
342+ if len (positions ) != len (references ):
343+ raise RuntimeError ("References must be of the same length as the positions" )
344+ pos_b = np .where (weights is None ,
345+ barycenter (positions ),
346+ weighted_barycenter (positions , weights ))
347+ ref_b = np .where (weights is None ,
348+ barycenter (references ),
349+ weighted_barycenter (references , weights ))
312350 R = 0.0
313- for pos , ref in zip (positions , references ):
314- pos -= pos_b
315- ref -= ref_b
316- R += np .dot (pos , pos ) + np .dot (ref , ref )
351+ if weights is None :
352+ for pos , ref in zip (positions , references ):
353+ pos -= pos_b
354+ ref -= ref_b
355+ R += np .dot (pos , pos ) + np .dot (ref , ref )
356+ else :
357+ for pos , ref , w in zip (positions , references , weights ):
358+ pos *= w
359+ ref *= w
360+ pos -= pos_b
361+ ref -= ref_b
362+ R += np .dot (pos , pos ) + np .dot (ref , ref )
317363 return R
318364
319365
320- def rmsd (r1 , r2 ):
366+ def rmsd (r1 , r2 , w_0 ):
321367 """
322368 Calculate the rmsd respect to a reference using quaternions.
323369
@@ -329,7 +375,53 @@ def rmsd(r1, r2):
329375 Cartesian coordinates of the reference position of the atoms.
330376 """
331377 N = r1 .shape [0 ]
332- S = quaternion_matrix (r1 , r2 )
378+ S = quaternion_matrix (r1 , r2 , w_0 )
333379 w = linalg .eigvalsh (S )
334- norm_sq = sq_norm_rotation (r1 , r2 )
380+ norm_sq = sq_norm_rotation (r1 , r2 , w_0 )
335381 return np .sqrt ((norm_sq - 2 * np .max (w )) / N )
382+
383+
384+ class QuaternionComponent (CollectiveVariable ):
385+ """
386+ Calculate the quaternion component of the rotation respect to a reference.
387+
388+ Parameters
389+ ----------
390+ indices: list[int], list[tuple(int)]
391+ Select atom groups via indices.
392+ axis: int
393+ Index for the component of the quaternion (a value form `0` to `3`).
394+ references: list[tuple(float)]
395+ Cartesian coordinates of the reference position of the atoms in indices.
396+ The number of coordinates must match the ones of the atoms used in indices.
397+ """
398+
399+ def __init__ (self , indices , axis , references , weights = None ):
400+ if weights is not None and len (indices ) != len (weights ):
401+ raise RuntimeError ("Indices and weights must be of the same length" )
402+ super ().__init__ (indices )
403+ self .axis = axis
404+ self .references = np .asarray (references )
405+ self .weights = np .asarray (weights )
406+
407+ @property
408+ def function (self ):
409+ return lambda r : quaternion_component (r , self .references , self .weights )[self .axis ]
410+
411+
412+ def quaternion_component (r1 , r2 , w ):
413+ """
414+ Returns the eigenvalue of the quaternion matrix rspect to an axis.
415+
416+ Parameters
417+ ----------
418+ r1: np.array
419+ Atomic positions.
420+ axis: int
421+ select the component of the quaternion for rotation.
422+ r2: np.array
423+ Cartesian coordinates of the reference position of the atoms in r1.
424+ """
425+ S = quaternion_matrix (r1 , r2 , w )
426+ _ , v = linalg .eigh (S )
427+ return v
0 commit comments