Skip to content

Commit 9d29d31

Browse files
committed
fix: adding add_group and add points from dataframe for geoh5 + some tests
1 parent df938d1 commit 9d29d31

2 files changed

Lines changed: 164 additions & 41 deletions

File tree

LoopStructural/export/geoh5.py

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,29 @@
55

66
from LoopStructural.datatypes import ValuePoints, VectorPoints
77

8-
9-
def add_surface_to_geoh5(filename, surface, overwrite=True, groupname="Loop"):
8+
def add_group_to_geoh5(filename, groupname="Loop", parent=None, overwrite=True):
109
with geoh5py.workspace.Workspace(filename) as workspace:
10+
1111
group = workspace.get_entity(groupname)[0]
12+
if group and overwrite:
13+
group.allow_delete = True
14+
workspace.remove_entity(group)
1215
if not group:
1316
group = geoh5py.groups.ContainerGroup.create(
1417
workspace, name=groupname, allow_delete=True
1518
)
19+
if parent is not None:
20+
parent = workspace.get_entity(parent)[0]
21+
if parent:
22+
parent.add_children(group)
23+
return group.uid
24+
def add_surface_to_geoh5(filename, surface, overwrite=True, group="Loop"):
25+
with geoh5py.workspace.Workspace(filename) as workspace:
26+
group = workspace.get_entity(group)[0]
27+
if not group:
28+
group = geoh5py.groups.ContainerGroup.create(
29+
workspace, name=group, allow_delete=True
30+
)
1631
if surface.name in workspace.list_entities_name.values():
1732
existing_surf = workspace.get_entity(surface.name)
1833
existing_surf[0].allow_delete = True
@@ -34,6 +49,7 @@ def add_surface_to_geoh5(filename, surface, overwrite=True, groupname="Loop"):
3449

3550
def add_points_to_geoh5(filename, point, overwrite=True, groupname="Loop"):
3651
with geoh5py.workspace.Workspace(filename) as workspace:
52+
3753
group = workspace.get_entity(groupname)[0]
3854
if not group:
3955
group = geoh5py.groups.ContainerGroup.create(
@@ -54,56 +70,73 @@ def add_points_to_geoh5(filename, point, overwrite=True, groupname="Loop"):
5470
data['vz'] = {'association': "VERTEX", "values": point.vectors[:, 2]}
5571

5672
if isinstance(point, ValuePoints):
57-
data['val'] = {'association': "VERTEX", "values": point.values}
73+
data['values'] = {'association': "VERTEX", "values": point.values}
5874
point = geoh5py.objects.Points.create(
5975
workspace,
6076
name=point.name,
6177
vertices=point.locations,
6278
parent=group,
6379
)
6480
point.add_data(data)
81+
82+
def overwrite_object(workspace, name, overwrite):
83+
if name in workspace.list_entities_name.values():
84+
existing_entity = workspace.get_entity(name)
85+
existing_entity[0].allow_delete = True
86+
if overwrite:
87+
workspace.remove_entity(existing_entity[0])
6588

66-
def add_points_from_df(filename, df, overwrite=True, child = None, parent = None,
67-
normal_cols=['nx','ny','nz']):
89+
def add_points_from_df(filename, df, name='pointset', overwrite=True, columns=None, groupname="Loop", x_col='X', y_col='Y', z_col='Z'):
90+
"""
91+
Add points to a geoh5 file from a pandas DataFrame. The DataFrame must have columns 'name', 'X', 'Y', 'Z' for the point locations.
92+
Additional columns can be added as data associated with the points.
93+
Parameters
94+
----------
95+
filename: str
96+
Path to the geoh5 file.
97+
df: pandas.DataFrame
98+
DataFrame containing point data. Must have columns 'name', 'X', 'Y', 'Z'. Additional columns will be added as data.
99+
overwrite: bool, optional
100+
Whether to overwrite existing points with the same name. Default is True.
101+
columns: list of str, optional
102+
List of columns in the DataFrame to add as data. If None, all columns except 'name', 'X', 'Y', 'Z' will be added. Default is None.
103+
104+
"""
105+
if columns is None:
106+
columns = df.columns.tolist()
107+
if x_col not in columns or y_col not in columns or z_col not in columns:
108+
raise ValueError("DataFrame must contain 'name', 'X', 'Y', 'Z' columns. " \
109+
"Specify the column names using x_col, y_col, z_col parameters if they are different.")
68110
with geoh5py.workspace.Workspace(filename) as workspace:
69-
entities = workspace.get_entity(child)
70-
child_name = child
71-
child = entities[0] if entities else None
72-
if not child:
73-
child = geoh5py.groups.ContainerGroup.create(
74-
workspace, name=child_name, allow_delete=True,
75-
)
76-
if parent:
77-
parent.add_children(child)
78-
79-
for _, row in df.iterrows():
80-
name = row['name']
81-
loc = np.array([[row['X'], row['Y'], row['Z']]]) # shape (1,3)
82-
83-
# remove existing entity if present and overwrite requested
84-
if name in workspace.list_entities_name.values():
85-
existing = workspace.get_entity(name)
86-
if existing:
87-
existing[0].allow_delete = True
88-
if overwrite:
89-
workspace.remove_entity(existing[0])
111+
if groupname:
112+
group = workspace.get_entity(groupname)
113+
group = group[0] if group else None
114+
if not group:
115+
group = geoh5py.groups.ContainerGroup.create(
116+
workspace, name=groupname, allow_delete=True,
117+
)
118+
119+
location = np.array(df[[x_col, y_col, z_col]].values) # shape (n,3)
120+
121+
overwrite_object(workspace, name, overwrite)
122+
90123

91-
pts = geoh5py.objects.Points.create(
92-
workspace,
93-
name=name,
94-
vertices=loc,
95-
parent=child,
96-
)
97-
98-
# build data dict from normal_cols (and any other columns you want)
99-
data = {}
100-
for col in normal_cols:
101-
if col in row and not pd.isna(row[col]):
102-
# association must be "VERTEX" and values length must match vertices (1)
103-
data[col] = {"association": "VERTEX", "values": np.array([row[col]])}
124+
pts = geoh5py.objects.Points.create(
125+
workspace,
126+
name=name,
127+
vertices=location,
128+
parent=group,
129+
)
130+
data = {}
131+
for col in columns:
132+
if col in ['name', x_col, y_col, z_col]:
133+
continue
134+
data[col] = {"association": "VERTEX", "values": np.array(df[col]).flatten()}
135+
104136

105-
if data:
106-
pts.add_data(data)
137+
if data:
138+
pts.add_data(data)
139+
107140

108141
def add_structured_grid_to_geoh5(filename, structured_grid, overwrite=True, groupname="Loop"):
109142
with geoh5py.workspace.Workspace(filename) as workspace:

tests/unit/io/test_geoh5.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from LoopStructural.export.geoh5 import add_group_to_geoh5, add_points_to_geoh5, add_points_from_df
2+
import geoh5py
3+
import pytest
4+
from pathlib import Path
5+
from LoopStructural.datatypes import ValuePoints, VectorPoints
6+
import numpy as np
7+
@pytest.fixture
8+
def tmp_path():
9+
import tempfile
10+
with tempfile.TemporaryDirectory() as tmpdir:
11+
yield Path(tmpdir)
12+
@pytest.fixture
13+
def test_setup(tmp_path):
14+
filename = tmp_path / "test.geoh5"
15+
with geoh5py.workspace.Workspace.create(filename) as workspace:
16+
yield filename
17+
workspace.close()
18+
19+
def test_add_group_to_geoh5(test_setup):
20+
filename = test_setup
21+
group_uid = add_group_to_geoh5(filename, groupname="TestGroup")
22+
23+
with geoh5py.workspace.Workspace(filename) as workspace:
24+
assert workspace.get_entity(group_uid)[0].name == "TestGroup"
25+
26+
def test_add_points_to_geoh5(test_setup):
27+
filename = test_setup
28+
group_uid = add_group_to_geoh5(filename, groupname="TestGroup")
29+
points = ValuePoints(
30+
name="TestPoints",
31+
locations=[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
32+
values=[10., 20, 30],
33+
)
34+
add_points_to_geoh5(filename, points, groupname=group_uid)
35+
with geoh5py.workspace.Workspace(filename) as workspace:
36+
point_entity = workspace.get_entity("TestPoints")[0]
37+
assert point_entity.name == "TestPoints"
38+
assert point_entity.vertices.tolist() == [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
39+
assert np.sum(point_entity.get_data("values")[0].values-np.array([10., 20., 30.])) == 0
40+
41+
def test_add_vector_points_to_geoh5(test_setup):
42+
filename = test_setup
43+
group_uid = add_group_to_geoh5(filename, groupname="TestGroup")
44+
points = VectorPoints(
45+
name="TestVectorPoints",
46+
locations=[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
47+
vectors=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
48+
)
49+
add_points_to_geoh5(filename, points, groupname=group_uid)
50+
with geoh5py.workspace.Workspace(filename) as workspace:
51+
point_entity = workspace.get_entity("TestVectorPoints")[0]
52+
assert point_entity.name == "TestVectorPoints"
53+
assert point_entity.vertices.tolist() == [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
54+
assert np.sum(point_entity.get_data("vx")[0].values-np.array([1., 0., 0.])) == 0
55+
assert np.sum(point_entity.get_data("vy")[0].values-np.array([0., 1., 0.])) == 0
56+
assert np.sum(point_entity.get_data("vz")[0].values-np.array([0., 0., 1.])) == 0
57+
58+
def test_add_df_to_geoh5(test_setup):
59+
import pandas as pd
60+
filename = test_setup
61+
group_uid = add_group_to_geoh5(filename, groupname="TestGroup")
62+
df = pd.DataFrame({
63+
'X': [0, 1, 2],
64+
'Y': [0, 1, 2],
65+
'Z': [0, 1, 2],
66+
'value': [10., 20., 30.],
67+
})
68+
add_points_from_df(filename, df, name='df_points', groupname=group_uid)
69+
with geoh5py.workspace.Workspace(filename) as workspace:
70+
point_entity = workspace.get_entity("TestGroup")[0].children[0]
71+
assert point_entity.name == "df_points"
72+
assert point_entity.vertices.tolist() == [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
73+
assert np.sum(point_entity.get_data("value")[0].values-np.array([10., 20., 30.])) == 0
74+
75+
def test_add_df_with_alternate_xyz_to_geoh5(test_setup):
76+
import pandas as pd
77+
filename = test_setup
78+
group_uid = add_group_to_geoh5(filename, groupname="TestGroup")
79+
df = pd.DataFrame({
80+
'EAST': [0, 1, 2],
81+
'NORTH': [0, 1, 2],
82+
'RL': [0, 1, 2],
83+
'value': [10., 20., 30.],
84+
})
85+
add_points_from_df(filename, df, name='df_points',groupname=group_uid, x_col='EAST', y_col='NORTH', z_col='RL')
86+
with geoh5py.workspace.Workspace(filename) as workspace:
87+
point_entity = workspace.get_entity("TestGroup")[0].children[0]
88+
assert point_entity.name == "df_points"
89+
assert point_entity.vertices.tolist() == [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
90+
assert np.sum(point_entity.get_data("value")[0].values-np.array([10., 20., 30.])) == 0

0 commit comments

Comments
 (0)