Skip to content

Commit c0fe37e

Browse files
gustavor101p-zubieta
authored andcommitted
Add optional weights and add quaternion component of rotation
1 parent 3f3c2ee commit c0fe37e

2 files changed

Lines changed: 136 additions & 44 deletions

File tree

pysages/colvars/orientation.py

Lines changed: 134 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
from jax import numpy as np
1414
from jax.numpy import linalg
1515

16-
from pysages.collective_variables.core import CollectiveVariable, AxisCV
17-
from pysages.collective_variables.coordinates import barycenter
16+
from pysages.colvars.core import CollectiveVariable, AxisCV
17+
from pysages.colvars.coordinates import barycenter, weighted_barycenter
1818

1919

20-
QUATERNION_BASES = {
20+
_QUATERNION_BASES = {
2121
0: (0, 1, 0, 0),
2222
1: (0, 0, 1, 0),
2323
2: (0, 0, 0, 1),
2424
}
2525

2626

27-
def quaternion_matrix(positions, references):
27+
def quaternion_matrix(positions, references, weights):
2828
"""
2929
Function to construct the quaternion matrix based on the reference positions.
3030
@@ -35,12 +35,24 @@ def quaternion_matrix(positions, references):
3535
references: np.array
3636
Cartesian coordinates of the reference positions of the atoms in indices.
3737
The number of coordinates must match the ones of the atoms used in indices.
38+
weights: np.array
39+
Weights for the barycenter calculation
3840
"""
39-
pos_b = barycenter(positions)
40-
ref_b = barycenter(references)
41+
if len(positions) != len(references):
42+
raise RuntimeError("References must be of the same length as the positions")
43+
pos_b = np.where(weights is None,
44+
barycenter(positions),
45+
weighted_barycenter(positions, weights))
46+
ref_b = np.where(weights is None,
47+
barycenter(references),
48+
weighted_barycenter(references, weights))
4149
R = np.zeros((3, 3))
42-
for pos, ref in zip(positions, references):
43-
R += np.outer(pos - pos_b, ref - ref_b)
50+
if weights is None:
51+
for pos, ref in zip(positions, references):
52+
R += np.outer(pos - pos_b, ref - ref_b)
53+
else:
54+
for pos, ref, w in zip(positions, references, weights):
55+
R += np.outer(w*pos - pos_b, w*ref - ref_b)
4456
S_00 = R[0, 0] + R[1, 1] + R[2, 2]
4557
S_01 = R[1, 2] - R[2, 1]
4658
S_02 = R[2, 0] - R[0, 2]
@@ -77,17 +89,20 @@ class Tilt(AxisCV):
7789
of coordinates must match the ones of the atoms used in indices.
7890
"""
7991

80-
def __init__(self, indices, axis, references):
92+
def __init__(self, indices, axis, references, weights=None):
93+
if weights is not None and len(indices) != len(weights):
94+
raise RuntimeError("Indices and weights must be of the same length")
8195
super().__init__(indices, axis)
8296
self.references = np.asarray(references)
83-
self.e = np.asarray(QUATERNION_BASES[axis])
97+
self.weights = np.asarray(weights)
98+
self.e = np.asarray(_QUATERNION_BASES[axis])
8499

85100
@property
86101
def function(self):
87-
return lambda rs: tilt(rs, self.references, self.axis)
102+
return lambda rs: tilt(rs, self.references, self.axis, self.weights)
88103

89104

90-
def tilt(r1, r2, e):
105+
def tilt(r1, r2, e, w):
91106
"""
92107
Function to calculate the tilt rotation respect to an axis.
93108
@@ -100,8 +115,10 @@ def tilt(r1, r2, e):
100115
The number of coordinates must match the ones of the atoms used in indices.
101116
e: JaxArray
102117
Quaternion rotation axis.
118+
w: JaxArray
119+
Atomic weights.
103120
"""
104-
S = quaternion_matrix(r1, r2)
121+
S = quaternion_matrix(r1, r2, w)
105122
_, v = linalg.eigh(S)
106123
v_dot_e = np.dot(v[:, 3], e)
107124
cs = np.cos(np.arctan2(v_dot_e, v[0, 3]))
@@ -124,17 +141,20 @@ class RotationEigenvalues(CollectiveVariable):
124141
The number of coordinates must match the ones of the atoms used in indices.
125142
"""
126143

127-
def __init__(self, indices, axis, references):
144+
def __init__(self, indices, axis, references, weights=None):
145+
if weights is not None and len(indices) != len(weights):
146+
raise RuntimeError("Indices and weights must be of the same length")
128147
super().__init__(indices)
129148
self.axis = axis
130149
self.references = np.asarray(references)
150+
self.weights = np.asarray(weights)
131151

132152
@property
133153
def function(self):
134-
return lambda r: rotation_eigvals(r, self.references)[self.axis]
154+
return lambda r: rotation_eigvals(r, self.references, self.weights)[self.axis]
135155

136156

137-
def rotation_eigvals(r1, r2):
157+
def rotation_eigvals(r1, r2, w):
138158
"""
139159
Returns the eigenvalue of the quaternion matrix rspect to an axis.
140160
@@ -147,7 +167,7 @@ def rotation_eigvals(r1, r2):
147167
r2: np.array
148168
Cartesian coordinates of the reference position of the atoms in r1.
149169
"""
150-
S = quaternion_matrix(r1, r2)
170+
S = quaternion_matrix(r1, r2, w)
151171
return linalg.eigvalsh(S)
152172

153173

@@ -166,17 +186,20 @@ class SpinAngle(AxisCV):
166186
The coordinates must match the ones of the atoms used in indices.
167187
"""
168188

169-
def __init__(self, indices, axis, references):
189+
def __init__(self, indices, axis, references, weights=None):
190+
if weights is not None and len(indices) != len(weights):
191+
raise RuntimeError("Indices and weights must be of the same length")
170192
super().__init__(indices, axis)
171193
self.references = np.asarray(references)
172-
self.e = np.array(QUATERNION_BASES[axis])
194+
self.e = np.array(_QUATERNION_BASES[axis])
195+
self.weights = np.asarray(weights)
173196

174197
@property
175198
def function(self):
176-
return lambda r: spin(r, self.references, self.e)
199+
return lambda r: spin(r, self.references, self.e, self.weights)
177200

178201

179-
def spin(r1, r2, e):
202+
def spin(r1, r2, e, w):
180203
"""
181204
Calculate the spin angle rotation respect to an axis.
182205
@@ -189,7 +212,7 @@ def spin(r1, r2, e):
189212
e: np.array
190213
Rotation axis.
191214
"""
192-
S = quaternion_matrix(r1, r2)
215+
S = quaternion_matrix(r1, r2, w)
193216
_, v = linalg.eigh(S)
194217
v_dot_e = np.dot(v[:, 3], e)
195218
return 2 * np.arctan2(v_dot_e, v[0, 3])
@@ -209,16 +232,19 @@ class RotationAngle(CollectiveVariable):
209232
The coordinates must match the ones of the atoms used in indices.
210233
"""
211234

212-
def __init__(self, indices, references):
235+
def __init__(self, indices, references, weights=None):
236+
if weights is not None and len(indices) != len(weights):
237+
raise RuntimeError("Indices and weights must be of the same length")
213238
super().__init__(indices)
214239
self.references = np.asarray(references)
240+
self.weights = np.asarray(weights)
215241

216242
@property
217243
def function(self):
218-
return lambda r: rotation_angle(r, self.references)
244+
return lambda r: rotation_angle(r, self.references, self.weights)
219245

220246

221-
def rotation_angle(r1, r2):
247+
def rotation_angle(r1, r2, w):
222248
"""
223249
Calculate the rotation angle respect to a reference.
224250
@@ -229,7 +255,7 @@ def rotation_angle(r1, r2):
229255
r2: np.array
230256
Cartesian coordinates of the reference position of the atoms.
231257
"""
232-
S = quaternion_matrix(r1, r2)
258+
S = quaternion_matrix(r1, r2, w)
233259
_, v = linalg.eigh(S)
234260
return 2 * np.arccos(v[0, 3])
235261

@@ -248,16 +274,19 @@ class RotationProjection(CollectiveVariable):
248274
The coordinates must match the ones of the atoms used in indices.
249275
"""
250276

251-
def __init__(self, indices, references):
277+
def __init__(self, indices, references, weights=None):
278+
if weights is not None and len(indices) != len(weights):
279+
raise RuntimeError("Indices and weights must be of the same length")
252280
super().__init__(indices)
253281
self.references = np.asarray(references)
282+
self.weights = np.asarray(weights)
254283

255284
@property
256285
def function(self):
257-
return lambda r: rotation_projection(r, self.references)
286+
return lambda r: rotation_projection(r, self.references, self.weights)
258287

259288

260-
def rotation_projection(r1, r2):
289+
def rotation_projection(r1, r2, w):
261290
"""
262291
Calculate the rotation angle projection.
263292
@@ -268,7 +297,7 @@ def rotation_projection(r1, r2):
268297
r2: np.array
269298
Cartesian coordinates of the reference position of the atoms.
270299
"""
271-
S = quaternion_matrix(r1, r2)
300+
S = quaternion_matrix(r1, r2, w)
272301
_, v = linalg.eigh(S)
273302
return 2 * v[0, 3] * v[0, 3] - 1
274303

@@ -287,16 +316,19 @@ class RMSD(CollectiveVariable):
287316
The coordinates must match the ones of the atoms used in indices.
288317
"""
289318

290-
def __init__(self, indices, references):
319+
def __init__(self, indices, references, weights=None):
320+
if weights is not None and len(indices) != len(weights):
321+
raise RuntimeError("Indices and weights must be of the same length")
291322
super().__init__(indices)
292323
self.references = np.asarray(references)
324+
self.weights = np.asarray(weights)
293325

294326
@property
295327
def function(self):
296-
return lambda r: rmsd(r, self.references)
328+
return lambda r: rmsd(r, self.references, self.weights)
297329

298330

299-
def sq_norm_rotation(positions, references):
331+
def sq_norm_rotation(positions, references, weights):
300332
"""
301333
Calculate the squared norm of the atomic positions and references respect to barycenters.
302334
@@ -307,17 +339,31 @@ def sq_norm_rotation(positions, references):
307339
references: np.array
308340
Cartesian coordinates of the reference position of the atoms.
309341
"""
310-
pos_b = barycenter(positions)
311-
ref_b = barycenter(references)
342+
if len(positions) != len(references):
343+
raise RuntimeError("References must be of the same length as the positions")
344+
pos_b = np.where(weights is None,
345+
barycenter(positions),
346+
weighted_barycenter(positions, weights))
347+
ref_b = np.where(weights is None,
348+
barycenter(references),
349+
weighted_barycenter(references, weights))
312350
R = 0.0
313-
for pos, ref in zip(positions, references):
314-
pos -= pos_b
315-
ref -= ref_b
316-
R += np.dot(pos, pos) + np.dot(ref, ref)
351+
if weights is None:
352+
for pos, ref in zip(positions, references):
353+
pos -= pos_b
354+
ref -= ref_b
355+
R += np.dot(pos, pos) + np.dot(ref, ref)
356+
else:
357+
for pos, ref, w in zip(positions, references, weights):
358+
pos *= w
359+
ref *= w
360+
pos -= pos_b
361+
ref -= ref_b
362+
R += np.dot(pos, pos) + np.dot(ref, ref)
317363
return R
318364

319365

320-
def rmsd(r1, r2):
366+
def rmsd(r1, r2, w_0):
321367
"""
322368
Calculate the rmsd respect to a reference using quaternions.
323369
@@ -329,7 +375,53 @@ def rmsd(r1, r2):
329375
Cartesian coordinates of the reference position of the atoms.
330376
"""
331377
N = r1.shape[0]
332-
S = quaternion_matrix(r1, r2)
378+
S = quaternion_matrix(r1, r2, w_0)
333379
w = linalg.eigvalsh(S)
334-
norm_sq = sq_norm_rotation(r1, r2)
380+
norm_sq = sq_norm_rotation(r1, r2, w_0)
335381
return np.sqrt((norm_sq - 2 * np.max(w)) / N)
382+
383+
384+
class QuaternionComponent(CollectiveVariable):
385+
"""
386+
Calculate the quaternion component of the rotation respect to a reference.
387+
388+
Parameters
389+
----------
390+
indices: list[int], list[tuple(int)]
391+
Select atom groups via indices.
392+
axis: int
393+
Index for the component of the quaternion (a value form `0` to `3`).
394+
references: list[tuple(float)]
395+
Cartesian coordinates of the reference position of the atoms in indices.
396+
The number of coordinates must match the ones of the atoms used in indices.
397+
"""
398+
399+
def __init__(self, indices, axis, references, weights=None):
400+
if weights is not None and len(indices) != len(weights):
401+
raise RuntimeError("Indices and weights must be of the same length")
402+
super().__init__(indices)
403+
self.axis = axis
404+
self.references = np.asarray(references)
405+
self.weights = np.asarray(weights)
406+
407+
@property
408+
def function(self):
409+
return lambda r: quaternion_component(r, self.references, self.weights)[self.axis]
410+
411+
412+
def quaternion_component(r1, r2, w):
413+
"""
414+
Returns the eigenvalue of the quaternion matrix rspect to an axis.
415+
416+
Parameters
417+
----------
418+
r1: np.array
419+
Atomic positions.
420+
axis: int
421+
select the component of the quaternion for rotation.
422+
r2: np.array
423+
Cartesian coordinates of the reference position of the atoms in r1.
424+
"""
425+
S = quaternion_matrix(r1, r2, w)
426+
_, v = linalg.eigh(S)
427+
return v

pysages/colvars/shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class RadiusOfGyration(CollectiveVariable):
2525
squared: Optional[bool]
2626
Indicates whether to return the squared value.
2727
weights: Optional[JaxArray]
28-
If providede the weighted radius_of_gyration will be computed.
28+
If providede the weighted radius of gyration will be computed.
2929
"""
3030

3131
def __init__(self, indices, squared=False, weights=None):
@@ -48,7 +48,7 @@ def function(self):
4848
else:
4949
rog = radius_of_gyration
5050

51-
return rog if self.squared else jit(lambda rs: np.sqrt(rog(rs)))
51+
return jit(rog) if self.squared else jit(lambda rs: np.sqrt(rog(rs)))
5252

5353

5454
def radius_of_gyration(positions):

0 commit comments

Comments
 (0)