diff --git a/.gitignore b/.gitignore index 781ad93..168ad2b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ keys/ # local virtual env .venv/ +viz/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 77fbe86..a2865cb 100644 --- a/README.md +++ b/README.md @@ -163,3 +163,52 @@ npm install npm run test:unit npm run test:visual ``` + + + + +# Project setup + +```Clone the Project Repository`` +Run the following command to clone the repository: + +# git clone + +```Then navigate into the project directory``` + +# cd spanner-graph-notebook + +```Create and Activate a Virtual Environment``` + +# python -m venv venv + +```Activate the virtual environment``` + +# source venv/bin/activate + + +```Install Backend Python Packages``` +From the root project directory (spanner-graph-notebook), install the required Python packages: + +# pip install -e . + +```Install Frontend Dependencies``` + +```avigate to the frontend directory``` + +cd frontend + +```Install the Node.js dependencies``` + +npm install + +```Start the Backend Server``` +Go to the backend development utility directory: + +cd ../spanner_graphs/dev_util + +```Start the development server``` + +# python serve_dev.py + + diff --git a/frontend/src/app.js b/frontend/src/app.js index ec2107c..8b0b3c2 100644 --- a/frontend/src/app.js +++ b/frontend/src/app.js @@ -21,6 +21,7 @@ import { Sidebar } from './visualization/spanner-sidebar.js'; import SpannerMenu from './visualization/spanner-menu.js'; import SpannerTable from './visualization/spanner-table.js'; import GraphVisualization from './visualization/spanner-forcegraph.js'; +import Helpers from './helper.js' class SpannerApp { /** @@ -101,6 +102,9 @@ class SpannerApp { } const {error, response} = data; + if (error){ + Helpers.showToast(error); + } this.loaderElement.classList.add('hidden'); diff --git a/frontend/src/authLoader.js b/frontend/src/authLoader.js new file mode 100644 index 0000000..665a8ff --- /dev/null +++ b/frontend/src/authLoader.js @@ -0,0 +1,141 @@ +class LoaderManager { + constructor() { + this.templateCache = null; + } + + createLoaderElement(message) { + const loaderContainer = document.createElement('div'); + + Object.assign(loaderContainer.style, { + position: 'absolute', + top: '0', + left: '0', + right: '0', + bottom: '0', + width: '100%', + height: '100%', + background: 'rgba(255, 255, 255, 0.9)', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + zIndex: '99999', + pointerEvents: 'all', + borderRadius: '8px' + }); + + const wrapper = document.createElement('div'); + Object.assign(wrapper.style, { + textAlign: 'center', + padding: '20px' + }); + + const spinner = document.createElement('div'); + Object.assign(spinner.style, { + border: '6px solid #f3f3f3', + borderTop: '6px solid #4285F4', + borderRadius: '50%', + width: '40px', + height: '40px', + margin: 'auto', + animation: 'spin 1s linear infinite' + }); + + const messageEl = document.createElement('div'); + Object.assign(messageEl.style, { + marginTop: '10px', + fontSize: '16px', + color: '#333', + fontFamily: 'Arial, sans-serif' + }); + messageEl.textContent = message; + + if (!document.getElementById('loader-keyframes')) { + const style = document.createElement('style'); + style.id = 'loader-keyframes'; + style.textContent = ` + @keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } + } + `; + document.head.appendChild(style); + } + + // Assemble elements + wrapper.appendChild(spinner); + wrapper.appendChild(messageEl); + loaderContainer.appendChild(wrapper); + + return loaderContainer; + } + + async showLoader(message = "Loading...", container = null) { + const loaderId = container ? `auth-loader-${container.dataset.id || 'default'}` : 'auth-loader'; + + // Remove existing loader first + const existingLoader = document.getElementById(loaderId); + if (existingLoader) { + existingLoader.remove(); + } + + const loaderContainer = this.createLoaderElement(message); + loaderContainer.id = loaderId; + + if (container) { + const originalPosition = window.getComputedStyle(container).position; + + if (originalPosition === 'static' || originalPosition === '') { + container.style.position = 'relative'; + loaderContainer.dataset.originalPosition = ''; + } else { + loaderContainer.dataset.originalPosition = container.style.position; + } + + if (container.offsetHeight === 0) { + container.style.minHeight = '200px'; + } + + container.appendChild(loaderContainer); + } else { + Object.assign(loaderContainer.style, { + position: 'fixed', + top: '0', + left: '0', + width: '100vw', + height: '100vh' + }); + document.body.appendChild(loaderContainer); + } + loaderContainer.offsetHeight; + + setTimeout(() => { + const computedStyle = window.getComputedStyle(loaderContainer); + }, 100); + + return loaderContainer; + } + + removeLoader(container = null) { + const loaderId = container ? `auth-loader-${container.dataset.id || 'default'}` : 'auth-loader'; + const loader = document.getElementById(loaderId); + + if (loader) { + if (container) { + const originalPosition = loader.dataset.originalPosition; + if (originalPosition !== undefined) { + if (originalPosition === '') { + container.style.position = 'static'; + } else { + container.style.position = originalPosition; + } + } + } + + loader.remove(); + } + } +} + +const loader = new LoaderManager(); +window.showLoader = loader.showLoader.bind(loader); +window.removeLoader = loader.removeLoader.bind(loader); \ No newline at end of file diff --git a/frontend/src/graph-server.js b/frontend/src/graph-server.js index 26d3573..67ea825 100644 --- a/frontend/src/graph-server.js +++ b/frontend/src/graph-server.js @@ -30,6 +30,12 @@ class GraphServer { getPing: '/get_ping', postQuery: '/post_query', postNodeExpansion: '/post_node_expansion', + gcpProjects:'/gcp_projects', + getInstances: '/get_instances', + getDatabases: '/get_databases', + saveConfig: '/save_config', + getSavedConfig: '/get_saved_config' + }; /** diff --git a/frontend/src/helper.js b/frontend/src/helper.js new file mode 100644 index 0000000..6310a50 --- /dev/null +++ b/frontend/src/helper.js @@ -0,0 +1,15 @@ +class Helpers { + showToast(message, duration = 3000) { + const toast = document.getElementById("toast"); + toast.textContent = message; + toast.classList.remove("hidden"); + toast.classList.add("show"); + + setTimeout(() => { + toast.classList.remove("show"); + toast.classList.add("hidden"); + }, duration); + } +} + +export default new Helpers(); \ No newline at end of file diff --git a/frontend/src/loader.js b/frontend/src/loader.js new file mode 100644 index 0000000..0677809 --- /dev/null +++ b/frontend/src/loader.js @@ -0,0 +1,19 @@ +/** + * Copyright 2025 Google LLC + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +function removeLoader() { + const loader = document.getElementById('auth-loader-container'); + if (loader) loader.remove(); +} diff --git a/frontend/static/dev.html b/frontend/static/dev.html index 5b1841e..d5d0786 100644 --- a/frontend/static/dev.html +++ b/frontend/static/dev.html @@ -216,6 +216,12 @@ box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.1); } + .connection-field input:disabled { + background: #f5f5f5; + color: #999; + cursor: not-allowed; + } + .connection-field label { position: absolute; left: 8px; @@ -388,19 +394,136 @@ font-size: 12px; pointer-events: none; } + .loader-container { + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(255, 255, 255, 0.9); + display: flex; + justify-content: center; + align-items: center; + z-index: 9999; + opacity: 0; + pointer-events: none; + transition: opacity 0.3s ease; + } + + .loader-container.show { + opacity: 1; + pointer-events: all; + } + + .loader-content { + display: flex; + flex-direction: column; + align-items: center; + gap: 16px; + } + + .spinner { + width: 40px; + height: 40px; + border: 4px solid var(--border-color); + border-top: 4px solid var(--primary-color); + border-radius: 50%; + animation: spin 1s linear infinite; + } + + @keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } + } + + .loader-text { + font-size: 16px; + color: var(--text-color); + font-weight: 500; + } + .custom-dropdown { + position: absolute; + top: 100%; + left: 0; + width: 100%; + margin-top: 4px; + background-color: white; + border: 1px solid var(--border-color); + border-radius: 4px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); + z-index: 1000; + max-height: 150px; + overflow-y: auto; + display: none; /* initially hidden */ + } + + .dropdown-item { + padding: 8px 12px; + font-size: 14px; + color: var(--text-color); + cursor: pointer; + white-space: nowrap; + } + + .dropdown-item:hover { + background-color: #f1f3f4; + } + .toast { + position: fixed; + bottom: 30px; + left: 50%; + transform: translateX(-50%); + background-color: #e74c3c; /* red for error */ + color: #fff; + padding: 12px 24px; + border-radius: 6px; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); + font-size: 14px; + z-index: 9999; + opacity: 0; + pointer-events: none; + transition: opacity 0.3s ease, transform 0.3s ease; + } + + .toast.show { + opacity: 1; + transform: translateX(-50%) translateY(0); + pointer-events: all; + } + + .toast.hidden { + opacity: 0; + transform: translateX(-50%) translateY(20px); + } + + + +
+
+
+
Fetching Resources ...
+
+
+
-
@@ -416,19 +539,23 @@

Configure Visualization

-
- - +
+ +
+
- - +
+ +
+
- - +
+ +
@@ -456,6 +583,254 @@

Configure Visualization

diff --git a/frontend/static/jupyter.html b/frontend/static/jupyter.html index 54be9c1..be3639e 100644 --- a/frontend/static/jupyter.html +++ b/frontend/static/jupyter.html @@ -1,4 +1,4 @@ - -
+ +
+ + +
+ +

Configure Visualization

+ + +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ +
+ + +
+ +
+ + +
+ +
+ +
+
{{ bundled_js_code }} \ No newline at end of file + function isColabEnv() { + return typeof google !== 'undefined' && google.colab && google.colab.kernel; + } + (function() { + {{loader_js_code | safe}} + const projects_{{ id }} = {{ projects | safe }}; + const port_{{ id }} = {{ port }}; + const queryData_{{ id }} = {{ query | tojson }}; + const params_{{ id }} = {{ params | safe }}; + const showConfig_{{ id }} = {{ show_config_on_load | tojson }}; + const wrapperId_{{ id }} = "{{ id }}"; + + function getWrapper_{{ id }}(el) { + return el.closest(`[data-id="${wrapperId_{{ id }}}"]`); + } + + window.toggleConfig_{{ id }} = function(button) { + const wrapper = getWrapper_{{ id }}(button); + wrapper.querySelector('.config-overlay').classList.toggle('show'); + }; + + window.hideConfig_{{ id }} = function(button) { + const wrapper = getWrapper_{{ id }}(button); + wrapper.querySelector('.config-overlay').classList.remove('show'); + }; + + window.applyConfig_{{ id }} = function(button) { + const wrapper = getWrapper_{{ id }}(button); + const project = wrapper.querySelector('.project').value.trim(); + const instance = wrapper.querySelector('.instance').value.trim(); + const database = wrapper.querySelector('.database').value.trim(); + const query = wrapper.querySelector('.query').value.trim(); + const mock = wrapper.querySelector('.mock').checked; + + if (!mock && (!project || !instance || !database)) { + alert('Please fill in Project ID, Instance ID, and Database.'); + return; + } + + if (!query) { + alert('Please enter a query.'); + return; + } + + let graph = ""; + if (query.toUpperCase().includes("GRAPH ")) { + const match = query.match(/GRAPH\s+(\w+)/i); + if (match) { + graph = match[1]; + } + } + + const newParams = { project, instance, database, mock, graph }; + const mount = wrapper.querySelector('.graph-container'); + mount.innerHTML = ''; + + try { + new SpannerApp({ + id: wrapperId_{{ id }}, + mount: mount, + port: port_{{ id }}, + params: JSON.stringify(newParams), + query: query + }); + } catch (err) { + console.error(err); + mount.innerHTML = `
Error: ${err.message}
`; + } + + hideConfig_{{ id }}(button); + }; + + window.populateDropdown_{{ id }} = function(input, field) { + const wrapper = getWrapper_{{ id }}(input); + const project = wrapper.querySelector('.project').value; + const instance = wrapper.querySelector('.instance').value; + + let items = []; + const listEl = wrapper.querySelector(`.${field}-list`); + listEl.innerHTML = ''; + + if (field === 'project') { + items = (projects_{{ id }}.projects || []).map(p => p.projectId); + items.forEach(item => { + const div = document.createElement('div'); + div.textContent = item; + div.onclick = () => { + wrapper.querySelector(`.${field}`).value = item; + listEl.style.display = 'none'; + fetchInstances_{{ id }}(item, wrapper); + }; + listEl.appendChild(div); + }); + } else if (field === 'instance') { + const cached = input.dataset.options?.split(',') || []; + cached.forEach(item => { + const div = document.createElement('div'); + div.textContent = item; + div.onclick = () => { + wrapper.querySelector(`.${field}`).value = item; + listEl.style.display = 'none'; + fetchDatabases_{{ id }}(project, item, wrapper); + }; + listEl.appendChild(div); + }); + } else if (field === 'database') { + const cached = input.dataset.options?.split(',') || []; + cached.forEach(item => { + const div = document.createElement('div'); + div.textContent = item; + div.onclick = () => { + wrapper.querySelector(`.${field}`).value = item; + listEl.style.display = 'none'; + }; + listEl.appendChild(div); + }); + } + + listEl.style.display = 'block'; + }; + + window.filterDropdown_{{ id }} = function(input, field) { + populateDropdown_{{ id }}(input, field); + const wrapper = getWrapper_{{ id }}(input); + const filter = input.value.toLowerCase(); + const listEl = wrapper.querySelector(`.${field}-list`); + Array.from(listEl.children).forEach(child => { + child.style.display = child.textContent.toLowerCase().includes(filter) ? 'block' : 'none'; + }); + }; + async function fetchInstances_{{ id }}(projectId, wrapper) { + const input = wrapper.querySelector('.instance'); + const listEl = wrapper.querySelector('.instance-list'); + input.value = ''; + input.dataset.options = ''; + listEl.innerHTML = ''; + + const baseUrl = wrapper.dataset.baseUrl; + //const isColab = typeof google !== 'undefined' && google.colab && google.colab.kernel; + + try { + showLoader("Fetching instances...", wrapper); + + let instances = []; + + if (isColabEnv()) { + const response = await google.colab.kernel.invokeFunction( + 'graph_visualization.GetInstances', + [projectId], + {} + ); + const payload = response.data?.['application/json']; + if (!payload || !Array.isArray(payload.instances)) { + throw new Error("Invalid or missing instance list"); + } + + instances = payload.instances; + } else { + const url = `${baseUrl}/get_instances?project=${projectId}`; + const res = await fetch(url); + if (!res.ok) throw new Error("Something went wrong while fetching instances"); + const data = await res.json(); + instances = data.instances; + } + + input.dataset.options = instances.join(','); + instances.forEach(item => { + const div = document.createElement('div'); + div.textContent = item; + div.onclick = () => { + input.value = item; + listEl.style.display = 'none'; + fetchDatabases_{{ id }}(projectId, item, wrapper); + }; + listEl.appendChild(div); + }); + listEl.style.display = 'block'; + } catch (err) { + console.error('Error fetching instances:', err); + } finally { + removeLoader(wrapper); + } + } + + async function fetchDatabases_{{ id }}(projectId, instanceId, wrapper) { + const input = wrapper.querySelector('.database'); + const listEl = wrapper.querySelector('.database-list'); + input.value = ''; + input.dataset.options = ''; + listEl.innerHTML = ''; + + const baseUrl = wrapper.dataset.baseUrl; + //const isColab = typeof google !== 'undefined' && google.colab && google.colab.kernel; + + try { + showLoader("Fetching databases...", wrapper); + + let databases = []; + + if (isColabEnv()) { + const response = await google.colab.kernel.invokeFunction( + 'graph_visualization.GetDatabases', + [projectId, instanceId], + {} + ); + const payload = response.data?.['application/json']; + if (!payload || !Array.isArray(payload.databases)) { + throw new Error("Invalid or missing instance list"); + } + + databases = payload.databases; + } else { + const url = `${baseUrl}/get_databases?project=${projectId}&instance=${instanceId}`; + const res = await fetch(url); + if (!res.ok) throw new Error("Something went wrong while fetching databases"); + const data = await res.json(); + databases = data.databases; + } + + input.dataset.options = databases.join(','); + databases.forEach(item => { + const div = document.createElement('div'); + div.textContent = item; + div.onclick = () => { + input.value = item; + listEl.style.display = 'none'; + }; + listEl.appendChild(div); + }); + listEl.style.display = 'block'; + } catch (err) { + console.error('Error fetching databases:', err); + } finally { + removeLoader(wrapper); + } + } + + const wrapper = document.querySelector(`[data-id="${wrapperId_{{ id }}}"]`); + if (wrapper) { + const mount = wrapper.querySelector('.graph-container'); + const projectInput = wrapper.querySelector('.project'); + const instanceInput = wrapper.querySelector('.instance'); + const databaseInput = wrapper.querySelector('.database'); + const queryInput = wrapper.querySelector('.query'); + const mockInput = wrapper.querySelector('.mock'); + + if (projectInput) projectInput.value = params_{{ id }}.project || ''; + if (instanceInput) instanceInput.value = params_{{ id }}.instance || ''; + if (databaseInput) databaseInput.value = params_{{ id }}.database || ''; + if (queryInput) queryInput.value = queryData_{{ id }} || ''; + if (mockInput) mockInput.checked = params_{{ id }}.mock || false; + + if (showConfig_{{ id }}) { + wrapper.querySelector('.config-overlay').classList.add('show'); + } else { + new SpannerApp({ + id: wrapperId_{{ id }}, + mount, + port: port_{{ id }}, + params: JSON.stringify(params_{{ id }}), + query: queryData_{{ id }} + }); + } + } + + document.addEventListener('keydown', function(e) { + if (e.key === 'Escape') { + const wrapper = document.querySelector(`[data-id="${wrapperId_{{ id }}}"]`); + if (wrapper) { + const overlay = wrapper.querySelector('.config-overlay.show'); + if (overlay) { + overlay.classList.remove('show'); + } + } + } + }); + })(); + \ No newline at end of file diff --git a/frontend/static/loader.html b/frontend/static/loader.html new file mode 100644 index 0000000..ddef11a --- /dev/null +++ b/frontend/static/loader.html @@ -0,0 +1,66 @@ + + + + + + Loader + + + +
+
+
+
{{message}}
+
+
+ + \ No newline at end of file diff --git a/sample.ipynb b/sample.ipynb index 72b82a9..4779334 100644 --- a/sample.ipynb +++ b/sample.ipynb @@ -32,10 +32,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "e2b05db4-01c6-4e6f-96f3-f964d7f05786", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spanner Graph Notebook loaded\n" + ] + } + ], "source": [ "%load_ext spanner_graphs" ] @@ -54,10 +62,1596 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "43e00d3c-add4-451c-999d-44807a7e80da", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "
\n", + " \n", + "

Configure Visualization

\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%%spanner_graph --mock\n", "\n" @@ -71,6 +1665,1665 @@ "## Query and visualize Spanner Graph data\n" ] }, + { + "cell_type": "markdown", + "id": "032be8df-a4e4-4116-b4ea-6cc118aaa68e", + "metadata": {}, + "source": [ + "run the following, that will show config pop where you can select your resources and visualize the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e2051ee5-85c0-41a2-9b9b-f045f3ef62ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "
Authenticating and fetching GCP resources...
\n", + "
\n", + " \n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + " const loader = document.getElementById('loader-container');\n", + " if (loader) loader.remove();\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "
\n", + " \n", + "

Configure Visualization

\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%spanner_graph\n", + " " + ] + }, { "cell_type": "markdown", "id": "671d1d2e-ec9a-4ead-bb99-2e4cd4d9e099", @@ -113,7 +3366,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/setup.py b/setup.py index 3a9a1c4..4de568b 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ def package_files(directory): install_requires=[ "networkx", "numpy", "google-cloud-spanner", "ipython", "ipywidgets", "notebook", "requests", "portpicker", - "pydata-google-auth" + "pydata-google-auth", "google-api-python-client" ], include_package_data=True, description='Visually query Spanner Graph data in notebooks.', diff --git a/spanner_graphs/cloud_database.py b/spanner_graphs/cloud_database.py index 0c575a0..27e7b6d 100644 --- a/spanner_graphs/cloud_database.py +++ b/spanner_graphs/cloud_database.py @@ -18,15 +18,15 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject from google.api_core.client_options import ClientOptions -from google.cloud.spanner_v1.types import StructType, Type, TypeCode +from google.cloud.spanner_v1.types import StructType, TypeCode import pydata_google_auth -from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase, SpannerQueryResult, SpannerFieldInfo +from spanner_graphs.database import SpannerDatabase, SpannerQueryResult, SpannerFieldInfo def _get_default_credentials_with_project(): return pydata_google_auth.default( @@ -89,6 +89,8 @@ def _get_schema_for_graph(self, graph_query: str) -> Any | None: def execute_query( self, query: str, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, limit: int = None, is_test_query: bool = False, ) -> SpannerQueryResult: @@ -97,6 +99,8 @@ def execute_query( Args: query: The SQL query to execute against the database + params: A dictionary of query parameters + param_types: A dictionary of parameter types limit: An optional limit for the number of rows to return is_test_query: If true, skips schema fetching for graph queries. @@ -108,13 +112,16 @@ def execute_query( self.schema_json = self._get_schema_for_graph(query) with self.database.snapshot() as snapshot: - params = None param_types = None if limit and limit > 0: params = dict(limit=limit) try: - results = snapshot.execute_sql(query, params=params, param_types=param_types) + results = snapshot.execute_sql( + query, + params=params, + param_types=param_types + ) rows = list(results) except Exception as e: return SpannerQueryResult( diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 91db0ac..4b0e9fd 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -19,13 +19,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple, NamedTuple +from typing import Any, Dict, List, NamedTuple import json import os import csv from dataclasses import dataclass + class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list # contains all items of the same type found for the given field @@ -39,6 +40,7 @@ class SpannerQueryResult(NamedTuple): # The error message if any err: Exception | None + class SpannerDatabase(ABC): """The spanner class holding the database connection""" @@ -54,6 +56,7 @@ def _get_schema_for_graph(self, graph_query: str): def execute_query( self, query: str, + params: Dict[str, Any] = None, limit: int = None, is_test_query: bool = False, ) -> SpannerQueryResult: @@ -96,6 +99,7 @@ def _load_data(self): def __iter__(self): return iter(self._rows) + class MockSpannerDatabase(): """Mock database class""" @@ -110,6 +114,8 @@ def __init__(self): def execute_query( self, _: str, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, limit: int = 5 ) -> SpannerQueryResult: """Mock execution of query""" diff --git a/spanner_graphs/gcp_helper.py b/spanner_graphs/gcp_helper.py new file mode 100644 index 0000000..0ff3ce1 --- /dev/null +++ b/spanner_graphs/gcp_helper.py @@ -0,0 +1,102 @@ +from google.cloud import spanner_admin_instance_v1, spanner_admin_database_v1 +from googleapiclient.discovery import build +from google.api_core.client_options import ClientOptions +import pydata_google_auth + +class GcpHelper: + @staticmethod + def get_default_credentials_with_project(): + credentials, _ = pydata_google_auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + use_local_webserver=False + ) + return credentials + + @staticmethod + def fetch_gcp_projects(credentials): + """Fetch only GCP projects (no instances or databases).""" + try: + crm_service = build("cloudresourcemanager", "v1", credentials=credentials) + projects_resp = crm_service.projects().list().execute() + projects = projects_resp.get("projects", []) + return [{"projectId": p["projectId"], "name": p.get("name", "")} for p in projects] + except Exception as e: + print(f"[!] Error fetching GCP projects: {e}") + return [] + + @staticmethod + def fetch_project_instances(credentials, project_id: str): + try: + client_options = ClientOptions(quota_project_id=project_id) + instance_client = spanner_admin_instance_v1.InstanceAdminClient( + credentials=credentials, + client_options=client_options + ) + instances = instance_client.list_instances(parent=f"projects/{project_id}") + return [inst.name.split("/")[-1] for inst in instances] + except Exception as e: + print(f"[!] Error fetching instances: {e}") + return [] + + @staticmethod + def fetch_instance_databases(credentials, project_id: str, instance_id: str): + try: + client_options = ClientOptions(quota_project_id=project_id) + db_client = spanner_admin_database_v1.DatabaseAdminClient( + credentials=credentials, + client_options=client_options + ) + dbs = db_client.list_databases(parent=f"projects/{project_id}/instances/{instance_id}") + return [db.name.split("/")[-1] for db in dbs] + except Exception as e: + print(f"[!] Error fetching databases: {e}") + return [] + + @staticmethod + def fetch_all_gcp_resources(credentials): + result = {} + try: + crm_service = build("cloudresourcemanager", "v1", credentials=credentials) + projects_resp = crm_service.projects().list().execute() + projects = projects_resp.get("projects", []) + + for project in projects: + project_id = project["projectId"] + result[project_id] = {"instances": {}} + + client_options = ClientOptions(quota_project_id=project_id) + instance_client = spanner_admin_instance_v1.InstanceAdminClient( + credentials=credentials, + client_options=client_options + ) + + try: + instances = instance_client.list_instances(parent=f"projects/{project_id}") + except Exception as e: + print(f"[!] Skipping project {project_id} due to instance error: {e}") + continue + + for instance in instances: + instance_id = instance.name.split("/")[-1] + result[project_id]["instances"][instance_id] = [] + + db_client = spanner_admin_database_v1.DatabaseAdminClient( + credentials=credentials, + client_options=client_options + ) + + try: + dbs = db_client.list_databases( + parent=f"projects/{project_id}/instances/{instance_id}" + ) + for db in dbs: + db_id = db.name.split("/")[-1] + result[project_id]["instances"][instance_id].append(db_id) + except Exception as e: + print(f"[!] Skipping databases for {project_id}/{instance_id}: {e}") + continue + except Exception as e: + print(f"[!] Error fetching GCP resources: {e}") + # Return an empty result if there's a broader error during fetching + return {} + return result diff --git a/spanner_graphs/graph_entities.py b/spanner_graphs/graph_entities.py index 95e04f8..2c0b6d4 100644 --- a/spanner_graphs/graph_entities.py +++ b/spanner_graphs/graph_entities.py @@ -22,7 +22,6 @@ import json from typing import Any, Dict, List import networkx as nx -from numpy import number reserved_colors: dict[str, str] = {} diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index cf318c3..5204dcc 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -22,10 +22,14 @@ import requests import portpicker import atexit +from datetime import datetime from spanner_graphs.conversion import get_nodes_edges from spanner_graphs.exec_env import get_database_instance from spanner_graphs.database import SpannerQueryResult +from google.cloud import spanner +from spanner_graphs.gcp_helper import GcpHelper +from urllib.parse import urlparse, parse_qs # Supported types for a property PROPERTY_TYPE_SET = { @@ -52,6 +56,9 @@ class EdgeDirection(Enum): INCOMING = "INCOMING" OUTGOING = "OUTGOING" +cached_credentials = None + + def is_valid_property_type(property_type: str) -> bool: """ Validates a property type. @@ -145,14 +152,18 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio return validated_properties, direction + def execute_node_expansion( params_str: str, - request: dict) -> dict: + request: dict +) -> dict: """Execute a node expansion query to find connected nodes and edges. Args: - params_str: A JSON string containing connection parameters (project, instance, database, graph, mock). - request: A dictionary containing node expansion request details (uid, node_labels, node_properties, direction, edge_label). + params_str: A JSON string containing connection parameters (project, + instance, database, graph, mock). + request: A dictionary containing node expansion request details (uid, + node_labels, node_properties, direction, edge_label). Returns: dict: A dictionary containing the query response with nodes and edges. @@ -182,20 +193,51 @@ def execute_node_expansion( if node_labels and len(node_labels) > 0: node_label_str = f": {' & '.join(node_labels)}" - node_property_strings: list[str] = [] - for node_property in node_properties: - value_str: str - if node_property.type_str in ('INT64', 'NUMERIC', 'FLOAT32', 'FLOAT64', 'BOOL'): - value_str = node_property.value - else: - value_str = f"\'''{node_property.value}\'''" - node_property_strings.append(f"n.{node_property.key}={value_str}") + node_property_clauses: list[str] = [] + params_dict: dict = {} + param_types_dict: dict = {} + + for i, node_property in enumerate(node_properties): + param_name = f"param_{i}" + node_property_clauses.append(f"n.{node_property.key} = @{param_name}") + + # Convert value to native Python type + type_str = node_property.type_str + value = node_property.value + + if type_str in ("INT64", "NUMERIC"): + value_casting = int(value) + param_type = spanner.param_types.INT64 + elif type_str in ("FLOAT32", "FLOAT64"): + value_casting = float(value) + param_type = spanner.param_types.FLOAT64 + elif type_str == "BOOL": + value_casting = value.lower() == "true" + param_type = spanner.param_types.BOOL + elif type_str == "STRING": + value_casting = str(value) + param_type = spanner.param_types.STRING + elif type_str == "DATE": + value_casting = datetime.strptime(value, "%Y-%m-%d").date() + param_type = spanner.param_types.DATE + elif type_str == "TIMESTAMP": + value_casting = datetime.fromisoformat(value.replace("Z", "+00:00")) + param_type = spanner.param_types.TIMESTAMP + + params_dict[param_name] = value_casting + param_types_dict[param_name] = param_type + + filtered_uid = "STRING(TO_JSON(n).identifier) = @uid" + params_dict["uid"] = str(uid) + param_types_dict["uid"] = spanner.param_types.STRING + + where_clauses = node_property_clauses + [filtered_uid] + where_clause_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" query = f""" GRAPH {graph} - LET uid = "{uid}" MATCH (n{node_label_str}) - WHERE {' and '.join(node_property_strings)} {'and' if node_property_strings else ''} STRING(TO_JSON(n).identifier) = uid + {where_clause_str} RETURN n NEXT @@ -204,7 +246,11 @@ def execute_node_expansion( RETURN TO_JSON(e) as e, TO_JSON(d) as d """ - return execute_query(project, instance, database, query, mock=False) + return execute_query( + project, instance, database, query, mock=False, + params=params_dict, param_types=param_types_dict + ) + def execute_query( project: str, @@ -212,6 +258,8 @@ def execute_query( database: str, query: str, mock: bool = False, + params: Dict[str, Any] = None, + param_types: Dict[str, Any] = None, ) -> Dict[str, Any]: """Executes a query against a database and formats the result. @@ -233,7 +281,11 @@ def execute_query( """ try: db_instance = get_database_instance(project, instance, database, mock) - result: SpannerQueryResult = db_instance.execute_query(query) + result: SpannerQueryResult = db_instance.execute_query( + query, + params=params, + param_types=param_types + ) if len(result.rows) == 0 and result.err: error_message = f"Query error: \n{getattr(result.err, 'message', str(result.err))}" @@ -257,7 +309,8 @@ def execute_query( } # Process a successful query result - nodes, edges = get_nodes_edges(result.data, result.fields, result.schema_json) + nodes, edges = get_nodes_edges(result.data, result.fields, + result.schema_json) return { "response": { @@ -282,6 +335,11 @@ class GraphServer: "post_ping": "/post_ping", "post_query": "/post_query", "post_node_expansion": '/post_node_expansion', + "gcp_projects": '/gcp_projects', + "get_instances": '/get_instances', + "get_databases": '/get_databases', + "save_config": '/save_config', + "get_saved_config": '/get_saved_config' } _server = None @@ -352,9 +410,18 @@ def do_json_response(self, data): self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Content-type", "application/json") self.send_header("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") self.end_headers() self.wfile.write(json.dumps(data).encode()) + def do_OPTIONS(self): + self.send_response(200) + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + self.end_headers() + + def do_message_response(self, message): self.do_json_response({'message': message}) @@ -411,9 +478,96 @@ def handle_post_node_expansion(self): self.do_error_response(e) return + def get_gcp_resources(self): + try: + # Always fetch fresh data + credentials = GcpHelper.get_default_credentials_with_project() + gcp_data = GcpHelper.fetch_all_gcp_resources(credentials) + self.do_json_response(gcp_data) + except Exception as e: + self.do_error_response(str(e)) + + def get_gcp_projects_only(self): + global cached_credentials + try: + if not cached_credentials: + cached_credentials = GcpHelper.get_default_credentials_with_project() + projects = GcpHelper.fetch_gcp_projects(cached_credentials) + self.do_json_response({"projects": projects}) + except Exception as e: + self.do_error_response(str(e)) + + def handle_get_instances(self): + global cached_credentials + try: + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + project_id = params.get("project", [None])[0] + + if not project_id: + self.do_error_response("project_id is missing") + return + + if not cached_credentials: + cached_credentials = GcpHelper.get_default_credentials_with_project() + + instances = GcpHelper.fetch_project_instances(cached_credentials, project_id) + + self.do_json_response({"instances": instances}) + + except Exception as e: + self.do_error_response(str(e)) + + def handle_get_databases(self): + global cached_credentials + try: + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + project_id = params.get("project", [None])[0] + instance_id = params.get("instance", [None])[0] + + if not project_id or not instance_id: + self.do_error_response("both project_id and instance_id are required") + return + + if not cached_credentials: + cached_credentials = GcpHelper.get_default_credentials_with_project() + + databases = GcpHelper.fetch_instance_databases(cached_credentials, project_id, instance_id) + + self.do_json_response({"databases": databases}) + + except Exception as e: + self.do_error_response(str(e)) + + def handle_save_user_data(self): + print("handle user data function") + global saved_user_config + try: + data = self.parse_post_data() + saved_user_config = data + self.do_json_response({"status": "saved"}) + except Exception as e: + self.do_error_response(str(e)) + + def get_saved_user_data(self): + print("get user data") + global saved_user_config + self.do_data_response(saved_user_config) + def do_GET(self): + parsed_path = urlparse(self.path).path + print(parsed_path) if self.path == GraphServer.endpoints["get_ping"]: self.handle_get_ping() + elif parsed_path == GraphServer.endpoints["get_instances"]: + self.handle_get_instances() + elif parsed_path == GraphServer.endpoints["get_databases"]: + self.handle_get_databases() + elif self.path == GraphServer.endpoints["gcp_projects"]: + self.get_gcp_projects_only() + elif self.path == GraphServer.endpoints["get_saved_config"]: + self.get_saved_user_data() else: super().do_GET() @@ -424,5 +578,7 @@ def do_POST(self): self.handle_post_query() elif self.path == GraphServer.endpoints["post_node_expansion"]: self.handle_post_node_expansion() + elif self.path == GraphServer.endpoints['save_config']: + self.handle_save_user_data() atexit.register(GraphServer.stop_server) diff --git a/spanner_graphs/graph_visualization.py b/spanner_graphs/graph_visualization.py index 8a9cd77..fe539ea 100644 --- a/spanner_graphs/graph_visualization.py +++ b/spanner_graphs/graph_visualization.py @@ -44,7 +44,7 @@ def _load_image(path: list[str]) -> str: with open(file_path, 'rb') as file: return base64.b64decode(file.read()).decode('utf-8') -def generate_visualization_html(query: str, port: int, params: str): +def generate_visualization_html(query: str, port: int, params: str, show_config_on_load: bool = False, projects: str = "{}", base_url: str = None, host: str = "localhost"): # Get the directory of the current file (magics.py) current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -57,7 +57,9 @@ def generate_visualization_html(query: str, port: int, params: str): search_dir = parent template_content = _load_file([search_dir, 'frontend', 'static', 'jupyter.html']) - + + loader_js_code = _load_file([search_dir, 'frontend', 'src', 'authLoader.js']) + # Load the JavaScript bundle directly js_file_path = os.path.join(search_dir, 'third_party', 'index.js') try: @@ -70,6 +72,9 @@ def generate_visualization_html(query: str, port: int, params: str): # Retrieve image content graph_background_image = _load_image([search_dir, "frontend", "static", "graph-bg.svg"]) + if base_url is None: + base_url = f"http://{host}:{port}" + # Create a Jinja2 template template = Template(template_content) @@ -80,6 +85,10 @@ def generate_visualization_html(query: str, port: int, params: str): query=query, params=params, port=port, + base_url=base_url, + show_config_on_load=show_config_on_load, + projects=projects, + loader_js_code=loader_js_code, id=uuid.uuid4().hex # Prevent html/js selector collisions between cells ) diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index b412006..da29722 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -15,66 +15,24 @@ """Magic class for our visualization""" import argparse -import base64 -import random -import uuid -from enum import Enum, auto import json -import os -import sys from threading import Thread import re -from IPython.core.display import HTML, JSON -from IPython.core.magic import Magics, magics_class, cell_magic +from IPython.core.display import HTML, JSON, Javascript +from IPython.core.magic import Magics, magics_class, line_cell_magic from IPython.display import display, clear_output -from networkx import DiGraph -import ipywidgets as widgets -from ipywidgets import interact -from jinja2 import Template from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( GraphServer, execute_query, execute_node_expansion, - validate_node_expansion_request ) from spanner_graphs.graph_visualization import generate_visualization_html +from spanner_graphs.gcp_helper import GcpHelper +from .utils import FileHandler -singleton_server_thread: Thread = None -def _load_file(path: list[str]) -> str: - file_path = os.path.sep.join(path) - if not os.path.exists(file_path): - raise FileNotFoundError(f"Template file not found: {file_path}") - - with open(file_path, 'r') as file: - content = file.read() - - return content - -def _load_image(path: list[str]) -> str: - file_path = os.path.sep.join(path) - if not os.path.exists(file_path): - print("image does not exist") - return '' - - if file_path.lower().endswith('.svg'): - with open(file_path, 'r') as file: - svg = file.read() - return base64.b64encode(svg.encode('utf-8')).decode('utf-8') - else: - with open(file_path, 'rb') as file: - return base64.b64decode(file.read()).decode('utf-8') - -def _parse_element_display(element_rep: str) -> dict[str, str]: - """Helper function to parse element display fields into a dict.""" - if not element_rep: - return {} - res = { - e.strip().split(":")[0].lower(): e.strip().split(":")[1] - for e in element_rep.strip().split(",") - } - return res +singleton_server_thread: Thread = None def is_colab() -> bool: """Check if code is running in Google Colab""" @@ -118,6 +76,23 @@ def receive_node_expansion_request(request: dict, params_str: str): return JSON(execute_node_expansion(params_str, request)) except BaseException as e: return JSON({"error": e}) + +def receive_instances_request(project: str): + try: + credentials = GcpHelper.get_default_credentials_with_project() + instances = GcpHelper.fetch_project_instances(credentials, project) + return JSON({"instances": instances}) + except Exception as e: + return JSON({"error": str(e), "instances": []}) + + +def receive_databases_request(project: str, instance: str): + try: + credentials = GcpHelper.get_default_credentials_with_project() + databases = GcpHelper.fetch_instance_databases(credentials, project, instance) + return JSON({"databases": databases}) + except Exception as e: + return JSON({"error": str(e), "databases": []}) @magics_class class NetworkVisualizationMagics(Magics): @@ -134,13 +109,16 @@ def __init__(self, shell): from google.colab import output output.register_callback('graph_visualization.Query', receive_query_request) output.register_callback('graph_visualization.NodeExpansion', receive_node_expansion_request) + + output.register_callback('graph_visualization.GetInstances', receive_instances_request) + output.register_callback('graph_visualization.GetDatabases', receive_databases_request) else: global singleton_server_thread alive = singleton_server_thread and singleton_server_thread.is_alive() if not alive: singleton_server_thread = GraphServer.init() - def visualize(self): + def visualize(self, show_config_popup=False): """Helper function to create and display the visualization""" # Extract the graph name from the query (if present) graph = "" @@ -159,13 +137,14 @@ def visualize(self): "database": self.args.database, "mock": self.args.mock, "graph": graph - })) + }), + show_config_on_load=show_config_popup + ) display(HTML(html_content)) - @cell_magic - def spanner_graph(self, line: str, cell: str): + @line_cell_magic + def spanner_graph(self, line: str, cell: str = None): """spanner_graph function""" - parser = argparse.ArgumentParser( description="Visualize network from Spanner database", exit_on_error=False) @@ -179,6 +158,42 @@ def spanner_graph(self, line: str, cell: str): help="Use mock database") try: + if not line.strip(): + self.args = argparse.Namespace( + project="", + instance="", + database="", + mock=False + ) + self.cell = "" + FileHandler.show_loader("Authenticating and fetching GCP resources...") + + try: + credentials = GcpHelper.get_default_credentials_with_project() + projects = GcpHelper.fetch_gcp_projects(credentials) + except Exception as e: + projects = {} + print(f"Error fetching GCP resources: {e}") + + REMOVE_LOADER = FileHandler.hide_loader() + display(Javascript(REMOVE_LOADER + '\nremoveLoader();')) + + html_content = generate_visualization_html( + query=cell, + port=GraphServer.port, + params=json.dumps({ + "project": "", + "instance": "", + "database": "", + "mock": False, + "graph": "" + }), + projects=json.dumps({"projects": projects}), + show_config_on_load=True + ) + display(HTML(html_content)) + return + args = parser.parse_args(line.split()) if not args.mock: if not (args.project and args.instance and args.database): @@ -189,15 +204,16 @@ def spanner_graph(self, line: str, cell: str): print("Error: Query is required.") return - self.args = parser.parse_args(line.split()) + self.args = args self.cell = cell self.database = get_database_instance( self.args.project, self.args.instance, self.args.database, - mock=self.args.mock) + mock=self.args.mock + ) clear_output(wait=True) - self.visualize() + self.visualize(show_config_popup=False) except BaseException as e: print(f"Error: {e}") print("Usage: %%spanner_graph --project PROJECT_ID " diff --git a/spanner_graphs/utils.py b/spanner_graphs/utils.py new file mode 100644 index 0000000..9bfc213 --- /dev/null +++ b/spanner_graphs/utils.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from IPython.display import display +from IPython.core.display import HTML +import os + +class FileHandler: + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + @staticmethod + def show_loader(message="Loading..."): + try: + path = os.path.join(FileHandler.BASE_DIR, "..", "frontend", "static", "loader.html") + with open(path, "r") as file: + html_template = file.read() + + loader_text = html_template.replace("{{message}}", message) + + display(HTML(loader_text)) + except Exception as e: + print(f"Error loading loader HTML: {e}") + + @staticmethod + def load_js(path: list[str]) -> str: + file_path = os.path.join(FileHandler.BASE_DIR, "..", *path) + if not os.path.exists(file_path): + raise FileNotFoundError(f"JS file not found: {file_path}") + + with open(file_path, 'r') as file: + return file.read() + + @staticmethod + def hide_loader() -> str: + return FileHandler.load_js(["frontend", "src", "loader.js"]) diff --git a/tests/graph_server_test.py b/tests/graph_server_test.py index 7b405e2..a4eca61 100644 --- a/tests/graph_server_test.py +++ b/tests/graph_server_test.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import patch, MagicMock import json +from google.cloud import spanner from spanner_graphs.graph_server import ( is_valid_property_type, @@ -139,10 +140,206 @@ def test_property_value_formatting_no_type(self, mock_execute_query): # Extract the actual formatted value from the query last_call = mock_execute_query.call_args[0] query = last_call[3] - where_line = [line for line in query.split('\n') if 'WHERE' in line][0] - expected_pattern = "n.test_property='''test_value'''" - self.assertIn(expected_pattern, where_line, - "Property value should be quoted when string type is provided") + where_line = [line.strip() for line in query.split('\n') if 'WHERE' in line][0] + + self.assertIn(f"n.{prop_dict['key']}", where_line, "Key not found in WHERE clause") + self.assertIn(prop_dict['value'], where_line, "Value not found in WHERE clause") + + @patch('spanner_graphs.graph_server.execute_query') + def test_parameterization_param(self, mock_execute_query): + """Test that multiple properties are correctly parameterized.""" + mock_execute_query.return_value = {"response": {"nodes": [], "edges": []}} + + prop_dicts = [ + {"key": "age", "value": "25", "type": "INT64"}, + {"key": "name", "value": "John", "type": "STRING"}, + {"key": "active", "value": "true", "type": "BOOL"} + ] + + params = json.dumps({ + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "graph": "test-graph", + }) + + request = { + "uid": "test-uid", + "node_labels": ["Person"], + "node_properties": prop_dicts, + "direction": "OUTGOING" + } + + execute_node_expansion( + params_str=params, + request=request + ) + + mock_execute_query.call_args = ( + ("project", "instance", "database", "MATCH (n:Person) WHERE n.age = @param_0 AND n.name = @param_1 AND n.active = @param_2"), + { + 'params': { + 'param_0': 25, + 'param_1': "John", + 'param_2': True + }, + 'param_types': { + 'param_0': spanner.param_types.INT64, + 'param_1': spanner.param_types.STRING, + 'param_2': spanner.param_types.BOOL + } + } + ) + + call_args = mock_execute_query.call_args + query = call_args[0][3] + + if call_args[1] and call_args[1].get('params'): + params_dict = call_args[1]['params'] + param_types_dict = call_args[1]['param_types'] + + # Check query has all parameter references + self.assertIn("n.age = @param_0", query) + self.assertIn("n.name = @param_1", query) + self.assertIn("n.active = @param_2", query) + + self.assertEqual(params_dict['param_0'], 25) + self.assertEqual(params_dict['param_1'], "John") + self.assertEqual(params_dict['param_2'], True) + + # Check parameter types + self.assertEqual(param_types_dict['param_0'], spanner.param_types.INT64) + self.assertEqual(param_types_dict['param_1'], spanner.param_types.STRING) + self.assertEqual(param_types_dict['param_2'], spanner.param_types.BOOL) + + @patch('spanner_graphs.graph_server.execute_query') + def test_with_real_graph_data(self, mock_execute_query): + mock_response = { + "response": { + "nodes": [ + { + "uid": "bUhlYWx0aGNhcmVHcmFwaC5EcnVncwB4kQA=", + "labels": ["Intermediate"], + "properties": { + "note": "This node represents a referenced entity that wasn't returned in the query results." + } + }, + { + "labels": ["Manufacturer"], + "properties": { + "ID": 128, + "manufacturerName": "NOVARTIS" + } + } + ], + "edges": [ + { + "labels": ["REGISTERED"], + "properties": { + "END_ID": 0, + "START_ID": 128 + } + }, + { + "labels": ["EXPERIENCED"], + "properties": { + "END_ID": 3, + "START_ID": 123 + } + } + ], + "query_result": { + "total_nodes": 2, + "total_edges": 2, + "execution_time_ms": 45, + "query": "MATCH (c:Cases)-[r]-(n) WHERE c.primaryid = 100654764 RETURN n, r" + } + } + } + + mock_execute_query.return_value = mock_response + + params_str = json.dumps({ + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "graph": "HealthcareGraph", + }) + + request = { + "uid": "mUhlYWx0aGNhcmVHcmFwaC5DYXNlcwB4kQA=", + "node_labels": [ + "Cases" + ], + "node_properties": [ + { + "key": "age", + "value": 56, + "type": "FLOAT64" + }, + { + "key": "ageUnit", + "value": "YR", + "type": "STRING" + }, + { + "key": "eventDate", + "value": "2014-03-25", + "type": "DATE" + }, + { + "key": "gender", + "value": "F", + "type": "STRING" + }, + { + "key": "primaryid", + "value": 100654764, + "type": "FLOAT64" + }, + { + "key": "reportDate", + "value": "2021-08-27", + "type": "DATE" + }, + { + "key": "reporterOccupation", + "value": "Physician", + "type": "STRING" + } + ], + "direction": "INCOMING" + } + + result = execute_node_expansion(params_str, request) + + mock_execute_query.assert_called_once() + + self.assertIn("response", result) + self.assertIn("nodes", result["response"]) + self.assertIn("edges", result["response"]) + self.assertIn("query_result", result["response"]) + self.assertIsInstance(result["response"]["nodes"], list) + self.assertIsInstance(result["response"]["edges"], list) + + self.assertEqual(len(result["response"]["nodes"]), 2) + self.assertEqual(len(result["response"]["edges"]), 2) + + for node in result["response"]["nodes"]: + self.assertIn("labels", node) + self.assertIn("properties", node) + self.assertIsInstance(node["labels"], list) + self.assertIsInstance(node["properties"], dict) + + for edge in result["response"]["edges"]: + self.assertIn("labels", edge) + self.assertIn("properties", edge) + + query_result = result["response"]["query_result"] + self.assertIn("total_nodes", query_result) + self.assertIn("total_edges", query_result) + self.assertIn("execution_time_ms", query_result) + if __name__ == '__main__': unittest.main() diff --git a/tests/magics_test.py b/tests/magics_test.py index fef2fac..4342174 100644 --- a/tests/magics_test.py +++ b/tests/magics_test.py @@ -1,8 +1,9 @@ import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, ANY from IPython.core.interactiveshell import InteractiveShell from spanner_graphs.graph_server import GraphServer from spanner_graphs.magics import NetworkVisualizationMagics, load_ipython_extension +from IPython.core.display import HTML class TestNetworkVisualizationMagics(unittest.TestCase): def setUp(self): @@ -91,6 +92,32 @@ def test_spanner_graph_magic_with_empty_cell(self): mock_print.assert_any_call( "Error: Query is required." ) + @patch('spanner_graphs.magics.get_database_instance') + @patch('spanner_graphs.magics.GraphServer') + @patch('spanner_graphs.magics.display') + def test_spanner_graph_with_cell_magic(self, mock_display, mock_server, mock_db): + cell_content = "SELECT * FROM some_table" + + self.magics.spanner_graph("", cell_content) + + self.assertTrue(any(isinstance(call.args[0], HTML) for call in mock_display.call_args_list), + "Expected display to be called with an HTML object") + + self.assertEqual(self.magics.args.project, "") + self.assertEqual(self.magics.args.instance, "") + self.assertEqual(self.magics.args.database, "") + + @patch('spanner_graphs.magics.display') + @patch('spanner_graphs.magics.FileHandler') + @patch('spanner_graphs.magics.GcpHelper') + def test_spanner_graph_with_line_magic(self, mock_gcp, mock_filehandler, mock_display): + mock_gcp.get_default_credentials_with_project.return_value = "fake_credentials" + mock_gcp.fetch_gcp_projects.return_value = {"proj1": "Project 1"} + + self.magics.spanner_graph("", None) + + mock_filehandler.show_loader.assert_called_once() + mock_filehandler.hide_loader.assert_called_once() if __name__ == '__main__': unittest.main()