Skip to content

Commit fbbfdd8

Browse files
committed
Add test_planar_cc
1 parent d8c466c commit fbbfdd8

1 file changed

Lines changed: 75 additions & 0 deletions

File tree

tests/test_planar_pcs.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dill
12
import jax
23

34
jax.config.update("jax_enable_x64", True) # double precision
@@ -7,6 +8,7 @@
78
from functools import partial
89
from numpy.testing import assert_allclose
910
from pathlib import Path
11+
import sympy as sp
1012

1113
from jsrm.systems import planar_pcs, euler_lagrangian
1214
from jsrm.utils.tolerance import Tolerance
@@ -30,6 +32,78 @@ def constant_strain_inverse_kinematics_fn(params, xi_eq, chi, s) -> Array:
3032
q = xi - xi_eq
3133
return q
3234

35+
def test_planar_cc():
36+
sym_exp_filepath = (
37+
Path(jsrm.__file__).parent / "symbolic_expressions" / "planar_pcs_ns-1.dill"
38+
)
39+
sym_exps = dill.load(open(str(sym_exp_filepath), "rb"))
40+
41+
xi_syms = sym_exps["state_syms"]["xi"]
42+
xi_d_syms = sym_exps["state_syms"]["xi_d"]
43+
44+
num_segments = len(xi_syms) // 3
45+
shear_indices = [3 * i + 1 for i in range(num_segments)]
46+
axial_indices = [3 * i + 2 for i in range(num_segments)]
47+
48+
substitutions = {}
49+
for idx in shear_indices + axial_indices:
50+
substitutions[xi_syms[idx]] = 0
51+
substitutions[xi_d_syms[idx]] = 0
52+
53+
forbidden_syms = set(substitutions.keys())
54+
55+
def remove_rows_cols(mat: sp.Matrix, remove_idxs):
56+
mat_mutable = sp.Matrix(mat)
57+
keep_rows = [i for i in range(mat_mutable.rows) if i not in remove_idxs]
58+
keep_cols = [i for i in range(mat_mutable.cols) if i not in remove_idxs]
59+
return mat_mutable.extract(keep_rows, keep_cols)
60+
61+
def remove_cols(mat: sp.Matrix, remove_idxs):
62+
mat_mutable = sp.Matrix(mat)
63+
keep_cols = [i for i in range(mat_mutable.cols) if i not in remove_idxs]
64+
keep_rows = list(range(mat_mutable.rows))
65+
return mat_mutable.extract(keep_rows, keep_cols)
66+
67+
def remove_rows(mat: sp.Matrix, remove_idxs):
68+
mat_mutable = sp.Matrix(mat)
69+
keep_rows = [i for i in range(mat_mutable.rows) if i not in remove_idxs]
70+
return mat_mutable.extract(keep_rows, [0])
71+
72+
simplified_exps = {}
73+
expected_dim = len(xi_syms) - len(shear_indices) - len(axial_indices)
74+
expected_j_cols = len(xi_syms) // 3 # one bending DOF per segment
75+
76+
for exp_key, exp_val in sym_exps["exps"].items():
77+
def simplify_and_reduce(expr: sp.Expr) -> sp.Expr:
78+
simplified_expr = sp.simplify(expr.subs(substitutions))
79+
if exp_key in {"B", "C"}:
80+
simplified_expr = remove_rows_cols(simplified_expr, shear_indices + axial_indices)
81+
assert simplified_expr.shape == (expected_dim, expected_dim)
82+
elif exp_key == "G":
83+
simplified_expr = remove_rows(simplified_expr, shear_indices + axial_indices)
84+
assert simplified_expr.shape == (expected_dim, 1)
85+
elif exp_key in {"J_sms", "J_d_sms", "Jee", "Jee_d"}:
86+
simplified_expr = remove_cols(simplified_expr, shear_indices + axial_indices)
87+
assert simplified_expr.shape == (simplified_expr.rows, expected_j_cols)
88+
elif exp_key == "J_tend_sms":
89+
simplified_expr = remove_cols(simplified_expr, shear_indices + axial_indices)
90+
assert simplified_expr.shape == (simplified_expr.rows, expected_j_cols)
91+
return simplified_expr
92+
93+
if isinstance(exp_val, list):
94+
simplified_list = []
95+
for idx, exp_item in enumerate(exp_val):
96+
simplified_item = simplify_and_reduce(exp_item)
97+
simplified_list.append(simplified_item)
98+
print(f"{exp_key}[{idx}] =\n{simplified_item}")
99+
assert forbidden_syms.isdisjoint(simplified_item.free_symbols)
100+
simplified_exps[exp_key] = simplified_list
101+
else:
102+
simplified_item = simplify_and_reduce(exp_val)
103+
simplified_exps[exp_key] = simplified_item
104+
print(f"{exp_key} =\n{simplified_item}")
105+
assert forbidden_syms.isdisjoint(simplified_item.free_symbols)
106+
33107
def test_planar_cs():
34108
sym_exp_filepath = (
35109
Path(jsrm.__file__).parent / "symbolic_expressions" / "planar_pcs_ns-1.dill"
@@ -124,4 +198,5 @@ def test_planar_cs():
124198

125199

126200
if __name__ == "__main__":
201+
test_planar_cc()
127202
test_planar_cs()

0 commit comments

Comments
 (0)