Skip to content

Commit 9821745

Browse files
smukilSailesh Mukil
andauthored
Fix broken tests and get build back to green (#47)
1. MagicsTest.test_spanner_graph_magic_with_empty_cell: - Added check for empty cell in magics.py to fix - Hung test fixed by adding tearDown method 2. ConversionTest.test_get_nodes_edges(): - Incorrect use of DB.execute_query() API. Fixed 3. GraphEntitiesTest.test_add_edge_to_graph(): - Wrong type in assertion. Changed from 'int' to 'str' 4. SampleNotebookTest.test_notebook_cells(): - Wrong assertion (formatting) - Hung test fixed by adding tearDown method 5. GraphServerTest (multiple failures): - Formatting issue. Changed expected formatting in tests to include "'''" around properties Co-authored-by: Sailesh Mukil <mukil.sailesh@gmail.com>
1 parent 0f99a10 commit 9821745

7 files changed

Lines changed: 123 additions & 97 deletions

File tree

spanner_graphs/graph_server.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,27 @@ class EdgeDirection(Enum):
5656
def validate_property_type(property_type: str) -> TypeCode:
5757
"""
5858
Validates and converts a property type string to a Spanner TypeCode.
59-
59+
6060
Args:
6161
property_type: The property type string from the request
62-
62+
6363
Returns:
6464
The corresponding TypeCode enum value
65-
65+
6666
Raises:
6767
ValueError: If the property type is invalid
6868
"""
6969
if not property_type:
7070
raise ValueError("Property type must be provided")
71-
71+
7272
# Convert to uppercase for case-insensitive comparison
7373
property_type = property_type.upper()
74-
74+
7575
# Check if the type is valid
7676
if property_type not in PROPERTY_TYPE_MAP:
7777
valid_types = ', '.join(sorted(PROPERTY_TYPE_MAP.keys()))
7878
raise ValueError(f"Invalid property type: {property_type}. Allowed types are: {valid_types}")
79-
79+
8080
return PROPERTY_TYPE_MAP[property_type]
8181

8282
def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection):
@@ -149,11 +149,11 @@ def execute_node_expansion(
149149
params_str: str,
150150
request: dict) -> dict:
151151
"""Execute a node expansion query to find connected nodes and edges.
152-
152+
153153
Args:
154154
params_str: A JSON string containing connection parameters (project, instance, database, graph, mock).
155155
request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label).
156-
156+
157157
Returns:
158158
dict: A dictionary containing the query response with nodes and edges.
159159
"""
@@ -235,7 +235,7 @@ def execute_query(project: str, instance: str, database: str, query: str, mock =
235235
"error": f"Query error: \n{getattr(err, 'message', str(err))}"
236236
}
237237
nodes, edges = get_nodes_edges(query_result, fields, schema_json)
238-
238+
239239
return {
240240
"response": {
241241
"nodes": [node.to_json() for node in nodes],
@@ -360,7 +360,7 @@ def handle_post_ping(self):
360360
def handle_post_query(self):
361361
data = self.parse_post_data()
362362
params = json.loads(data["params"])
363-
response = execute_query(
363+
response = execute_query(
364364
project=params["project"],
365365
instance=params["instance"],
366366
database=params["database"],
@@ -371,7 +371,7 @@ def handle_post_query(self):
371371

372372
def handle_post_node_expansion(self):
373373
"""Handle POST requests for node expansion.
374-
374+
375375
Expects a JSON payload with:
376376
- params: A JSON string containing connection parameters (project, instance, database, graph)
377377
- request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label)

spanner_graphs/magics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def receive_query_request(query: str, params: str):
9494

9595
def receive_node_expansion_request(request: dict, params_str: str):
9696
"""Handle node expansion requests in Google Colab environment
97-
97+
9898
Args:
9999
request: A dictionary containing node expansion details including:
100100
- uid: str - Unique identifier of the node to expand
@@ -108,7 +108,7 @@ def receive_node_expansion_request(request: dict, params_str: str):
108108
- database: str - Spanner database ID
109109
- graph: str - Graph name
110110
- mock: bool - Whether to use mock data
111-
111+
112112
Returns:
113113
JSON: A JSON-serialized response containing either:
114114
- The query results with nodes and edges
@@ -165,6 +165,7 @@ def visualize(self):
165165
@cell_magic
166166
def spanner_graph(self, line: str, cell: str):
167167
"""spanner_graph function"""
168+
168169
parser = argparse.ArgumentParser(
169170
description="Visualize network from Spanner database",
170171
exit_on_error=False)
@@ -184,6 +185,9 @@ def spanner_graph(self, line: str, cell: str):
184185
raise ValueError(
185186
"Please provide `--project`, `--instance`, "
186187
"and `--database` values for your query.")
188+
if not cell or not cell.strip():
189+
print("Error: Query is required.")
190+
return
187191

188192
self.args = parser.parse_args(line.split())
189193
self.cell = cell

tests/conversion_test.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_get_nodes_edges(self) -> None:
3838
"""
3939
# Get data from mock database
4040
mock_db = MockSpannerDatabase()
41-
data, fields, _, schema_json = mock_db.execute_query("")
41+
data, fields, _, schema_json, _ = mock_db.execute_query("")
4242

4343
# Convert data to nodes and edges
4444
nodes, edges = get_nodes_edges(data, fields)
@@ -72,7 +72,7 @@ def test_get_nodes_edges(self) -> None:
7272
self.assertTrue(hasattr(edge, 'destination'), "Edge should have a destination")
7373
self.assertIsInstance(edge.labels, list, "Edge labels should be a list")
7474
self.assertIsInstance(edge.properties, dict, "Edge properties should be a dict")
75-
75+
7676
# Verify edge endpoints exist in nodes
7777
source_exists = any(node.identifier == edge.source for node in nodes)
7878
dest_exists = any(node.identifier == edge.destination for node in nodes)
@@ -94,32 +94,32 @@ def test_get_nodes_edges_with_missing_nodes(self) -> None:
9494
}),
9595
json.dumps({
9696
"kind": "node",
97-
"identifier": "node1",
97+
"identifier": "node1",
9898
"labels": ["Device"],
9999
"properties": {"name": "Router"}
100100
})
101101
# Note: node2 is intentionally missing
102102
]
103103
}
104-
104+
105105
# Create a mock field for the column
106106
field = StructType.Field(
107107
name="column1",
108108
type_=Type(code=TypeCode.JSON)
109109
)
110-
110+
111111
# Convert data to nodes and edges
112112
nodes, edges = get_nodes_edges(data, [field])
113-
113+
114114
# Verify we got the expected number of nodes and edges
115115
self.assertEqual(len(edges), 1, "Should have one edge")
116116
self.assertEqual(len(nodes), 2, "Should have two nodes (one real, one intermediate)")
117-
117+
118118
# Verify node identifiers
119119
node_ids = {node.identifier for node in nodes}
120120
self.assertIn("node1", node_ids, "Original node should exist")
121121
self.assertIn("node2", node_ids, "Missing node should be created as intermediate")
122-
122+
123123
# Find the intermediate node
124124
intermediate_node = next((node for node in nodes if node.identifier == "node2"), None)
125125
self.assertIsNotNone(intermediate_node, "Intermediate node should exist")
@@ -150,33 +150,33 @@ def test_get_nodes_edges_with_multiple_references(self) -> None:
150150
}),
151151
json.dumps({
152152
"kind": "node",
153-
"identifier": "node1",
153+
"identifier": "node1",
154154
"labels": ["Device"],
155155
"properties": {"name": "Router"}
156156
}),
157157
json.dumps({
158158
"kind": "node",
159-
"identifier": "node2",
159+
"identifier": "node2",
160160
"labels": ["Device"],
161161
"properties": {"name": "Switch"}
162162
})
163163
# Note: missing_node is intentionally missing
164164
]
165165
}
166-
166+
167167
# Create a mock field for the column
168168
field = StructType.Field(
169169
name="column1",
170170
type_=Type(code=TypeCode.JSON)
171171
)
172-
172+
173173
# Convert data to nodes and edges
174174
nodes, edges = get_nodes_edges(data, [field])
175-
175+
176176
# Verify we got the expected number of nodes and edges
177177
self.assertEqual(len(edges), 2, "Should have two edges")
178178
self.assertEqual(len(nodes), 3, "Should have three nodes (two real, one intermediate)")
179-
179+
180180
# Count intermediate nodes
181181
intermediate_nodes = [node for node in nodes if node.intermediate]
182182
self.assertEqual(len(intermediate_nodes), 1, "Should create only one intermediate node")
@@ -197,32 +197,32 @@ def test_get_nodes_edges_with_complete_data(self) -> None:
197197
}),
198198
json.dumps({
199199
"kind": "node",
200-
"identifier": "node1",
200+
"identifier": "node1",
201201
"labels": ["Device"],
202202
"properties": {"name": "Router"}
203203
}),
204204
json.dumps({
205205
"kind": "node",
206-
"identifier": "node2",
206+
"identifier": "node2",
207207
"labels": ["Device"],
208208
"properties": {"name": "Switch"}
209209
})
210210
]
211211
}
212-
212+
213213
# Create a mock field for the column
214214
field = StructType.Field(
215215
name="column1",
216216
type_=Type(code=TypeCode.JSON)
217217
)
218-
218+
219219
# Convert data to nodes and edges
220220
nodes, edges = get_nodes_edges(data, [field])
221-
221+
222222
# Verify we got the expected number of nodes and edges
223223
self.assertEqual(len(edges), 1, "Should have one edge")
224224
self.assertEqual(len(nodes), 2, "Should have exactly two nodes (no intermediates)")
225-
225+
226226
# Verify no intermediate nodes exist
227227
intermediate_nodes = [node for node in nodes if node.intermediate]
228228
self.assertEqual(len(intermediate_nodes), 0, "Should not create any intermediate nodes")

tests/graph_entities_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ def test_intermediate_flag_in_constructor(self):
4343
# Test with intermediate=True
4444
node1 = Node("1", ["Person"], {"name": "Emmanuel"}, intermediate=True)
4545
self.assertTrue(node1.intermediate, "Node should be marked as intermediate")
46-
46+
4747
# Test with intermediate=False
4848
node2 = Node("2", ["Person"], {"name": "John"}, intermediate=False)
4949
self.assertFalse(node2.intermediate, "Node should not be marked as intermediate")
50-
50+
5151
# Test with default (should be False)
5252
node3 = Node("3", ["Person"], {"name": "Alice"})
5353
self.assertFalse(node3.intermediate, "Node should default to not intermediate")
@@ -56,7 +56,7 @@ def test_make_intermediate(self):
5656
"""Test the make_intermediate static method"""
5757
test_identifier = "test123"
5858
node = Node.make_intermediate(test_identifier)
59-
59+
6060
self.assertEqual(node.identifier, test_identifier, "Identifier should match input")
6161
self.assertEqual(node.labels, ["Intermediate"], "Labels should include 'Intermediate'")
6262
self.assertTrue(node.intermediate, "Node should be marked as intermediate")
@@ -72,7 +72,7 @@ def test_to_json_with_intermediate(self):
7272
node1 = Node("1", ["Person"], {"name": "Jill"}, intermediate=True)
7373
json1 = node1.to_json()
7474
self.assertTrue(json1["intermediate"], "JSON should include intermediate=True")
75-
75+
7676
# Test with intermediate=False
7777
node2 = Node("2", ["Person"], {"name": "John"}, intermediate=False)
7878
json2 = node2.to_json()
@@ -117,7 +117,7 @@ def test_add_edge_to_graph(self):
117117
edge = Edge.from_json(data)
118118
edge.add_to_graph(graph)
119119

120-
self.assertIn((1, 2), graph.edges)
120+
self.assertIn(('1', '2'), graph.edges)
121121
# self.assertEqual(graph.edges[1, 2]["label"], "KNOWS")
122122
# self.assertEqual(graph.edges[1, 2]["title"],
123123
# "--- Edge Properties ---\nsince: 2020")

0 commit comments

Comments
 (0)