|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 |
|
3 | 3 | import datetime as dt, logging, numpy as np, os, pandas as pd, pyarrow as pa, pytest |
| 4 | +from pandas.api.types import is_datetime64_any_dtype |
4 | 5 |
|
5 | 6 | from graphistry.pygraphistry import PyGraphistry |
6 | 7 | from graphistry.Engine import Engine, DataframeLike |
@@ -61,19 +62,18 @@ def honeypot_pdf() -> pd.DataFrame: |
61 | 62 | "time(max)": "float64", |
62 | 63 | "time(min)": "float64", |
63 | 64 | } |
64 | | - base_dtypes = { |
65 | | - **base_csv_dtypes, |
66 | | - "time(max)": "datetime64[ns]", |
67 | | - "time(min)": "datetime64[ns]", |
68 | | - } |
69 | 65 | df = pd.read_csv( |
70 | 66 | "graphistry/tests/data/honeypot.5.csv", |
71 | 67 | #'graphistry/tests/data/honeypot.csv', |
72 | 68 | dtype=base_csv_dtypes, |
73 | 69 | parse_dates=["time(max)", "time(min)"], |
74 | 70 | date_parser=lambda v: pd.to_datetime(int(float(v))), |
75 | 71 | ) |
76 | | - assert df.dtypes.to_dict() == base_dtypes |
| 72 | + dtypes = df.dtypes.to_dict() |
| 73 | + for col, dtype in base_csv_dtypes.items(): |
| 74 | + assert dtypes[col] == dtype |
| 75 | + for col in ("time(max)", "time(min)"): |
| 76 | + assert is_datetime64_any_dtype(dtypes[col]) |
77 | 77 | assert len(df) == HONEYPOT_ROWS |
78 | 78 | return df |
79 | 79 |
|
@@ -537,6 +537,23 @@ def test_hyper_to_pa_mixed2(self): |
537 | 537 | nodes_arr = pa.Table.from_pandas(hg.graph._nodes) |
538 | 538 | assert len(nodes_arr) == HONEYPOT_NODES |
539 | 539 |
|
| 540 | + @pytest.mark.parametrize("unit", ["ms", "us"]) |
| 541 | + def test_hyper_to_pa_mixed2_unit_variants(self, unit): |
| 542 | + df = honeypot_pdf().copy() |
| 543 | + for col in ("time(max)", "time(min)"): |
| 544 | + df[col] = df[col].astype(f"datetime64[{unit}]") |
| 545 | + |
| 546 | + hg = hypergraph(**honeypot_hyperparams(df)) |
| 547 | + |
| 548 | + for col in ("time(max)", "time(min)"): |
| 549 | + assert is_datetime64_any_dtype(hg.graph._edges.dtypes[col]) |
| 550 | + assert is_datetime64_any_dtype(hg.graph._nodes.dtypes[col]) |
| 551 | + |
| 552 | + edges_arr = pa.Table.from_pandas(hg.graph._edges) |
| 553 | + assert len(edges_arr) == HONEYPOT_EDGES |
| 554 | + nodes_arr = pa.Table.from_pandas(hg.graph._nodes) |
| 555 | + assert len(nodes_arr) == HONEYPOT_NODES |
| 556 | + |
540 | 557 | def test_hyper_to_pa_na(self): |
541 | 558 |
|
542 | 559 | df = pd.DataFrame({"x": ["a", None, "c"], "y": [1, 2, None]}) |
|
0 commit comments