diff --git a/fuse/traces.py b/fuse/traces.py index b5f9722..0719312 100644 --- a/fuse/traces.py +++ b/fuse/traces.py @@ -98,11 +98,16 @@ def tabulate(self, Qwts, trace_entity): return result def manipulate_basis(self, basis): - if basis.shape == (1, 1): + if basis.shape[-1] == 1: return basis elif basis.shape == (1, 2): result = np.matmul(basis, np.array([[0, -1], [1, 0]])) - elif basis.shape[0] == 2: + elif basis.shape == (2, 2): + # Two dim cross product - pad with zeros and take z component of result + zeros_row = np.zeros((basis.shape[0], 1), dtype=basis.dtype) + basis = np.hstack([basis, zeros_row]) + result = np.cross(basis[0], basis[1])[2] + elif basis.shape == (2, 3): result = np.cross(basis[0], basis[1]) else: raise ValueError("Immersion of HDiv edges not defined in 3D")