Skip to content

Commit bcaed6a

Browse files
committed
Merge branch 'chahak13-2d'
2 parents 66f34e1 + 0225856 commit bcaed6a

18 files changed

Lines changed: 1149 additions & 443 deletions

diffmpm/constraint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from jax.tree_util import register_pytree_node_class
2+
3+
4+
@register_pytree_node_class
5+
class Constraint:
6+
def __init__(self, dir, velocity):
7+
self.dir = dir
8+
self.velocity = velocity
9+
10+
def tree_flatten(self):
11+
return ((), (self.dir, self.velocity))
12+
13+
@classmethod
14+
def tree_unflatten(cls, aux_data, children):
15+
del children
16+
return cls(*aux_data)
17+
18+
def apply(self, obj, ids):
19+
"""
20+
Apply constraint values to the passed object.
21+
22+
Arguments
23+
---------
24+
obj : diffmpm.node.Nodes, diffmpm.particle.Particles
25+
Object on which the constraint is applied
26+
ids : array_like
27+
The indices of the container `obj` on which the constraint
28+
will be applied.
29+
"""
30+
obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
31+
obj.momentum = obj.momentum.at[ids, :, self.dir].set(
32+
obj.mass[ids, :, 0] * self.velocity
33+
)
34+
obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0)

0 commit comments

Comments
 (0)