Skip to content

Commit f92462c

Browse files
committed
replacing get_axes with principal_axes for all beads/levels
1 parent 6f929a5 commit f92462c

2 files changed

Lines changed: 39 additions & 498 deletions

File tree

CodeEntropy/levels.py

Lines changed: 18 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def get_matrices(
8080
self,
8181
data_container,
8282
level,
83-
number_frames,
8483
highest_level,
8584
force_matrix,
8685
torque_matrix,
@@ -92,7 +91,6 @@ def get_matrices(
9291
Parameters:
9392
data_container (MDAnalysis.Universe): Data for a molecule or residue.
9493
level (str): 'polymer', 'residue', or 'united_atom'.
95-
number_frames (int): Number of frames being processed.
9694
highest_level (bool): Whether this is the top (largest bead size) level.
9795
force_matrix, torque_matrix (np.ndarray or None): Accumulated matrices to add
9896
to.
@@ -116,21 +114,23 @@ def get_matrices(
116114

117115
# Calculate forces/torques for each bead
118116
for bead_index in range(number_beads):
117+
bead = list_of_beads[bead_index]
119118
# Set up axes
120119
# translation and rotation use different axes
121120
# how the axes are defined depends on the level
122-
trans_axes, rot_axes = self.get_axes(data_container, level, bead_index)
121+
trans_axes = data_container.atoms.principal_axes()
122+
rot_axes = np.real(bead.principal_axes())
123123

124124
# Sort out coordinates, forces, and torques for each atom in the bead
125125
weighted_forces[bead_index] = self.get_weighted_forces(
126126
data_container,
127-
list_of_beads[bead_index],
127+
bead,
128128
trans_axes,
129129
highest_level,
130130
force_partitioning,
131131
)
132132
weighted_torques[bead_index] = self.get_weighted_torques(
133-
data_container, list_of_beads[bead_index], rot_axes, force_partitioning
133+
data_container, bead, rot_axes, force_partitioning
134134
)
135135

136136
# Create covariance submatrices
@@ -233,217 +233,6 @@ def get_beads(self, data_container, level):
233233

234234
return list_of_beads
235235

236-
def get_axes(self, data_container, level, index=0):
237-
"""
238-
Function to set the translational and rotational axes.
239-
The translational axes are based on the principal axes of the unit
240-
one level larger than the level we are interested in (except for
241-
the polymer level where there is no larger unit). The rotational
242-
axes use the covalent links between residues or atoms where possible
243-
to define the axes, or if the unit is not bonded to others of the
244-
same level the prinicpal axes of the unit are used.
245-
246-
Args:
247-
data_container (MDAnalysis.Universe): the molecule and trajectory data
248-
level (str): the level (united atom, residue, or polymer) of interest
249-
index (int): residue index
250-
251-
Returns:
252-
trans_axes : translational axes
253-
rot_axes : rotational axes
254-
"""
255-
index = int(index)
256-
257-
if level == "polymer":
258-
# for polymer use principle axis for both translation and rotation
259-
trans_axes = data_container.atoms.principal_axes()
260-
rot_axes = data_container.atoms.principal_axes()
261-
262-
elif level == "residue":
263-
# Translation
264-
# for residues use principal axes of whole molecule for translation
265-
trans_axes = data_container.atoms.principal_axes()
266-
267-
# Rotation
268-
# find bonds between atoms in residue of interest and other residues
269-
# we are assuming bonds only exist between adjacent residues
270-
# (linear chains of residues)
271-
# TODO refine selection so that it will work for branched polymers
272-
index_prev = index - 1
273-
index_next = index + 1
274-
atom_set = data_container.select_atoms(
275-
f"(resindex {index_prev} or resindex {index_next}) "
276-
f"and bonded resid {index}"
277-
)
278-
residue = data_container.select_atoms(f"resindex {index}")
279-
280-
if len(atom_set) == 0:
281-
# if no bonds to other residues use pricipal axes of residue
282-
rot_axes = residue.atoms.principal_axes()
283-
284-
else:
285-
# set center of rotation to center of mass of the residue
286-
center = residue.atoms.center_of_mass()
287-
288-
# get vector for average position of bonded atoms
289-
vector = self.get_avg_pos(atom_set, center)
290-
291-
# use spherical coordinates function to get rotational axes
292-
rot_axes = self.get_sphCoord_axes(vector)
293-
294-
elif level == "united_atom":
295-
# Translation
296-
# for united atoms use principal axes of residue for translation
297-
trans_axes = data_container.residues.principal_axes()
298-
299-
# Rotation
300-
# for united atoms use heavy atoms bonded to the heavy atom
301-
atom_set = data_container.select_atoms(
302-
f"(prop mass > 1.1) and bonded index {index}"
303-
)
304-
305-
if len(atom_set) == 0:
306-
# if no bonds to other residues use pricipal axes of residue
307-
rot_axes = data_container.residues.principal_axes()
308-
else:
309-
# center at position of heavy atom
310-
atom_group = data_container.select_atoms(f"index {index}")
311-
center = atom_group.positions[0]
312-
313-
# get vector for average position of bonded atoms
314-
vector = self.get_avg_pos(atom_set, center)
315-
316-
# use spherical coordinates function to get rotational axes
317-
rot_axes = self.get_sphCoord_axes(vector)
318-
319-
logger.debug(f"Translational Axes: {trans_axes}")
320-
logger.debug(f"Rotational Axes: {rot_axes}")
321-
322-
return trans_axes, rot_axes
323-
324-
def get_avg_pos(self, atom_set, center):
325-
"""
326-
Function to get the average position of a set of atoms.
327-
328-
Args:
329-
atom_set : MDAnalysis atom group
330-
center : position for center of rotation
331-
332-
Returns:
333-
avg_position : three dimensional vector
334-
"""
335-
# start with an empty vector
336-
avg_position = np.zeros((3))
337-
338-
# get number of atoms
339-
number_atoms = len(atom_set.names)
340-
341-
if number_atoms != 0:
342-
# sum positions for all atoms in the given set
343-
for atom_index in range(number_atoms):
344-
atom_position = atom_set.atoms[atom_index].position
345-
346-
avg_position += atom_position
347-
348-
avg_position /= number_atoms # divide by number of atoms to get average
349-
350-
else:
351-
# if no atoms in set the unit has no bonds to restrict its rotational
352-
# motion, so we can use a random vector to get spherical
353-
# coordinate axes
354-
avg_position = np.random.random(3)
355-
356-
# transform the average position to a coordinate system with the origin
357-
# at center
358-
avg_position = avg_position - center
359-
360-
logger.debug(f"Average Position: {avg_position}")
361-
362-
return avg_position
363-
364-
def get_sphCoord_axes(self, arg_r):
365-
"""
366-
For a given vector in space, treat it is a radial vector rooted at
367-
0,0,0 and derive a curvilinear coordinate system according to the
368-
rules of polar spherical coordinates
369-
370-
Args:
371-
arg_r: 3 dimensional vector
372-
373-
Returns:
374-
spherical_basis: axes set (3 vectors)
375-
"""
376-
377-
x2y2 = arg_r[0] ** 2 + arg_r[1] ** 2
378-
r2 = x2y2 + arg_r[2] ** 2
379-
380-
# Check for division by zero
381-
if r2 == 0.0:
382-
raise ValueError("r2 is zero, cannot compute spherical coordinates.")
383-
384-
if x2y2 == 0.0:
385-
raise ValueError("x2y2 is zero, cannot compute sin_phi and cos_phi.")
386-
387-
# These conditions are mathematically unreachable for real-valued vectors.
388-
# Marked as no cover to avoid false negatives in coverage reports.
389-
390-
# Check for non-negative values inside the square root
391-
if x2y2 / r2 < 0: # pragma: no cover
392-
raise ValueError(
393-
f"Negative value encountered for sin_theta calculation: {x2y2 / r2}. "
394-
f"Cannot take square root."
395-
)
396-
397-
if x2y2 < 0: # pragma: no cover
398-
raise ValueError(
399-
f"Negative value encountered for sin_phi and cos_phi "
400-
f"calculation: {x2y2}. "
401-
f"Cannot take square root."
402-
)
403-
404-
if x2y2 != 0.0:
405-
sin_theta = np.sqrt(x2y2 / r2)
406-
cos_theta = arg_r[2] / np.sqrt(r2)
407-
408-
sin_phi = arg_r[1] / np.sqrt(x2y2)
409-
cos_phi = arg_r[0] / np.sqrt(x2y2)
410-
411-
else: # pragma: no cover
412-
sin_theta = 0.0
413-
cos_theta = 1
414-
415-
sin_phi = 0.0
416-
cos_phi = 1
417-
418-
# if abs(sin_theta) > 1 or abs(sin_phi) > 1:
419-
# print('Bad sine : T {} , P {}'.format(sin_theta, sin_phi))
420-
421-
# cos_theta = np.sqrt(1 - sin_theta*sin_theta)
422-
# cos_phi = np.sqrt(1 - sin_phi*sin_phi)
423-
424-
# print('{} {} {}'.format(*arg_r))
425-
# print('Sin T : {}, cos T : {}'.format(sin_theta, cos_theta))
426-
# print('Sin P : {}, cos P : {}'.format(sin_phi, cos_phi))
427-
428-
spherical_basis = np.zeros((3, 3))
429-
430-
# r^
431-
spherical_basis[0, :] = np.asarray(
432-
[sin_theta * cos_phi, sin_theta * sin_phi, cos_theta]
433-
)
434-
435-
# Theta^
436-
spherical_basis[1, :] = np.asarray(
437-
[cos_theta * cos_phi, cos_theta * sin_phi, -sin_theta]
438-
)
439-
440-
# Phi^
441-
spherical_basis[2, :] = np.asarray([-sin_phi, cos_phi, 0.0])
442-
443-
logger.debug(f"Spherical Basis: {spherical_basis}")
444-
445-
return spherical_basis
446-
447236
def get_weighted_forces(
448237
self, data_container, bead, trans_axes, highest_level, force_partitioning
449238
):
@@ -526,9 +315,9 @@ def get_weighted_torques(self, data_container, bead, rot_axes, force_partitionin
526315

527316
torques = np.zeros((3,))
528317
weighted_torque = np.zeros((3,))
318+
moment_of_inertia = np.zeros(3)
529319

530320
for atom in bead.atoms:
531-
532321
# update local coordinates in rotational axes
533322
coords_rot = (
534323
data_container.atoms[atom.index].position - bead.center_of_mass()
@@ -547,10 +336,15 @@ def get_weighted_torques(self, data_container, bead, rot_axes, force_partitionin
547336
torques_local = np.cross(coords_rot, forces_rot)
548337
torques += torques_local
549338

339+
# multiply by the force_partitioning parameter to avoid double counting
340+
# of the forces on weakly correlated atoms
341+
# the default value of force_partitioning is 0.5 (dividing by two)
342+
# torques = force_partitioning * torques
343+
550344
# divide by moment of inertia to get weighted torques
551345
# moment of inertia is a 3x3 tensor
552-
# the weighting is done in each dimension (x,y,z) using the diagonal
553-
# elements of the moment of inertia tensor
346+
# the weighting is done in each dimension (x,y,z) using
347+
# the diagonal elements of the moment of inertia tensor
554348
moment_of_inertia = bead.moment_of_inertia()
555349

556350
for dimension in range(3):
@@ -561,16 +355,16 @@ def get_weighted_torques(self, data_container, bead, rot_axes, force_partitionin
561355

562356
# Check for zero moment of inertia
563357
if np.isclose(moment_of_inertia[dimension, dimension], 0):
564-
raise ZeroDivisionError(
565-
f"Attempted to divide by zero moment of inertia in dimension "
566-
f"{dimension}."
567-
)
358+
# If moment of inertia is 0 there should be 0 torque
359+
weighted_torque[dimension] = 0
360+
logger.warning("Zero moment of inertia. Setting torque to 0")
361+
continue
568362

569363
# Check for negative moment of inertia
570364
if moment_of_inertia[dimension, dimension] < 0:
571365
raise ValueError(
572366
f"Negative value encountered for moment of inertia: "
573-
f"{moment_of_inertia[dimension, dimension]} "
367+
f"{moment_of_inertia[dimension]} "
574368
f"Cannot compute weighted torque."
575369
)
576370

@@ -817,7 +611,6 @@ def update_force_torque_matrices(
817611
f_mat, t_mat = self.get_matrices(
818612
res,
819613
level,
820-
num_frames,
821614
highest,
822615
None if key not in force_avg["ua"] else force_avg["ua"][key],
823616
None if key not in torque_avg["ua"] else torque_avg["ua"][key],
@@ -847,7 +640,6 @@ def update_force_torque_matrices(
847640
f_mat, t_mat = self.get_matrices(
848641
mol,
849642
level,
850-
num_frames,
851643
highest,
852644
None if force_avg[key][group_id] is None else force_avg[key][group_id],
853645
(

0 commit comments

Comments
 (0)