diff --git a/.github/workflows/performance-test.yml b/.github/workflows/performance-test.yml index 1df4f2d..0eb856b 100644 --- a/.github/workflows/performance-test.yml +++ b/.github/workflows/performance-test.yml @@ -14,7 +14,12 @@ jobs: name: "Performance Test" runs-on: ubuntu-latest timeout-minutes: 60 # Set a timeout to prevent jobs from running indefinitely - + defaults: + run: + # pipefail so `python -m unittest ... | tee` propagates unittest's exit + # status instead of always returning tee's 0. + shell: bash -o pipefail -e {0} + steps: - uses: actions/checkout@v4 @@ -41,7 +46,10 @@ jobs: run: | python -m unittest src.test.test_query_performance -v 2>&1 | tee performance_test_output.log - - name: Run Legacy Performance Test + - name: Run Legacy Performance Test + # Always run, even if the previous test step failed, so we still get + # the report data and don't mask additional regressions. + if: always() env: VFBQUERY_CACHE_ENABLED: 'true' MPLBACKEND: 'Agg' @@ -49,7 +57,25 @@ jobs: VISPY_USE_EGL: '0' run: | python -m unittest -v src.test.term_info_queries_test.TermInfoQueriesTest.test_term_info_performance 2>&1 | tee -a performance_test_output.log - + + - name: Run Connectivity Tests + if: always() + env: + VFBQUERY_CACHE_ENABLED: 'true' + MPLBACKEND: 'Agg' + VISPY_GL_LIB: 'osmesa' + VISPY_USE_EGL: '0' + run: | + # These files are pytest-style (plain classes + @pytest.mark.integration). + # Run with pytest so the markers are honoured and collection works. + pytest -v \ + src/test/test_neuron_neuron_connectivity.py \ + src/test/test_neuron_region_connectivity.py \ + src/test/test_upstream_class_connectivity.py \ + src/test/test_downstream_class_connectivity.py \ + src/test/test_vfb_connectivity.py \ + 2>&1 | tee -a performance_test_output.log + - name: Create Performance Report if: always() # Always run this step, even if the test fails run: | @@ -148,9 +174,21 @@ jobs: EOF - # Check overall test status - if grep -q "OK" performance_test_output.log || grep -q "Ran.*test" performance_test_output.log; then - echo "✅ **Test Status**: Performance tests completed" >> performance.md + # Check overall test status. Note: matching "OK" or "ok" would + # false-positive on per-test "test_foo ... ok" lines emitted by + # unittest -v even when other tests failed. Use the absence of + # FAIL:/ERROR: lines as the truth source (mirrors the final + # "Fail job on test failures" step). + # unittest summary: "Ran N tests in Xs". + # pytest summary line ends with " in X.XXs" prefixed by " passed", " failed", + # " error", or "no tests ran". Match either runner's summary markers. + if grep -q "Ran .* test\| passed in \| failed in \| error in \|no tests ran" performance_test_output.log; then + # unittest emits "FAIL:" / "ERROR:"; pytest emits "FAILED " / "ERROR " (no colon). + if grep -q "FAIL:\|ERROR:\|FAILED\b\|^ERROR\b" performance_test_output.log; then + echo "❌ **Test Status**: Performance tests ran but reported failures" >> performance.md + else + echo "✅ **Test Status**: Performance tests completed" >> performance.md + fi echo "" >> performance.md # Count successes and failures @@ -177,7 +215,7 @@ jobs: echo "|-------|----------|--------|" >> performance.md # Parse timing information - grep -E "^(get_term_info|NeuronsPartHere|NeuronsSynaptic|NeuronsPresynapticHere|NeuronsPostsynapticHere|ComponentsOf|PartsOf|SubclassesOf|NeuronClassesFasciculatingHere|TractsNervesInnervatingHere|LineageClonesIn|ListAllAvailableImages):" performance_test_output.log | while read line; do + grep -E "^(get_term_info|NeuronsPartHere|NeuronsSynaptic|NeuronsPresynapticHere|NeuronsPostsynapticHere|ComponentsOf|PartsOf|SubclassesOf|NeuronClassesFasciculatingHere|TractsNervesInnervatingHere|LineageClonesIn|ListAllAvailableImages|NeuronNeuronConnectivityQuery|NeuronRegionConnectivityQuery|NeuronInputsTo|DownstreamClassConnectivity|UpstreamClassConnectivity|QueryConnectivity):" performance_test_output.log | while read line; do QUERY=$(echo "$line" | sed 's/:.*//') DURATION=$(echo "$line" | sed 's/.*: \([0-9.]*\)s.*/\1/') if echo "$line" | grep -q "✅"; then @@ -233,3 +271,20 @@ jobs: git add performance.md git diff --staged --quiet || git commit -m "Update performance test results [skip ci]" git push origin HEAD:main + + - name: Fail job on test failures + # Belt-and-braces: pipefail on the test steps should already make the + # job red on any unittest failure. This grep is a safety net in case a + # future test runner emits FAIL/ERROR lines without a non-zero exit + # (e.g. partial runs, swallowed pipelines). Runs after the report and + # commit so those still happen. + if: always() + run: | + # Match both unittest format ("FAIL:" / "ERROR:") and pytest format + # ("FAILED " / "ERROR " — no colon) so this catches either runner. + if grep -q "FAIL:\|ERROR:\|FAILED\b\|^ERROR\b" performance_test_output.log; then + echo "::error::Test run reported FAIL or ERROR lines in performance_test_output.log" + grep "FAIL:\|ERROR:\|FAILED\b\|^ERROR\b" performance_test_output.log + exit 1 + fi + echo "No FAIL/ERROR lines detected." diff --git a/requirements.txt b/requirements.txt index 0ed26b4..1d75805 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ requests pysolr get_version aiohttp -psycopg[binary]>=3.0 \ No newline at end of file +psycopg[binary]>=3.0 +pytest \ No newline at end of file diff --git a/src/test/test_downstream_class_connectivity.py b/src/test/test_downstream_class_connectivity.py index 2483046..df7286b 100644 --- a/src/test/test_downstream_class_connectivity.py +++ b/src/test/test_downstream_class_connectivity.py @@ -108,6 +108,105 @@ def test_empty_class_returns_empty_dataframe(self): assert df.empty +class TestDownstreamClassConnectivityHierarchyRollup: + """Regression tests for the partner-side hierarchy rollup behaviour: + connections to a child class also count toward each ancestor class within + the Neuron subtree, without double-counting under FBbt multi-inheritance. + """ + + @pytest.fixture(scope='class') + def result(self): + return get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=False, force_refresh=True, + ) + + @pytest.mark.integration + def test_parent_class_appears_with_sensible_counts(self, result): + """A row keyed on a parent class should have connected_n at least as + large as any of its descendant rows (set-union semantics) and at most + the sum of descendant connected_n (no double-counting beyond what + multi-inheritance forces). + """ + from vfbquery.vfb_queries import vc, get_dict_cursor + + rows = result["rows"] + ids = [r["id"] for r in rows] + assert ids, "Expected at least one row to test against" + + # Find any (parent, child) pair among the row ids. + q = ( + "MATCH (p:Class)<-[:SUBCLASSOF*1..]-(c:Class) " + "WHERE p.short_form IN %s AND c.short_form IN %s " + "RETURN p.short_form AS parent, c.short_form AS child LIMIT 1" + % (ids, ids) + ) + pairs = get_dict_cursor()(vc.nc.commit_list([q])) + if not pairs: + pytest.skip("No parent/child pair among result rows for this class") + + parent_id = pairs[0]["parent"] + child_id = pairs[0]["child"] + parent_row = next(r for r in rows if r["id"] == parent_id) + # Sum connected_n across all descendant rows (not just the one returned). + desc_q = ( + "MATCH (p:Class {short_form: '%s'})<-[:SUBCLASSOF*1..]-(c:Class) " + "WHERE c.short_form IN %s " + "RETURN collect(DISTINCT c.short_form) AS descs" + % (parent_id, ids) + ) + desc_rows = get_dict_cursor()(vc.nc.commit_list([desc_q])) + descendant_ids = desc_rows[0]["descs"] if desc_rows else [child_id] + descendant_rows = [r for r in rows if r["id"] in descendant_ids] + max_child = max(r["connected_n"] for r in descendant_rows) + sum_child = sum(r["connected_n"] for r in descendant_rows) + assert parent_row["connected_n"] >= max_child, ( + f"Parent {parent_id} connected_n={parent_row['connected_n']} should " + f"be >= max descendant connected_n={max_child}" + ) + assert parent_row["connected_n"] <= sum_child, ( + f"Parent {parent_id} connected_n={parent_row['connected_n']} should " + f"be <= sum of descendant connected_n={sum_child}" + ) + + @pytest.mark.integration + def test_total_n_is_constant_across_rows(self, result): + """`total_n` is the queried-side instance count and must be the same + for every output row (regression for the previous summed-across- + subclasses value). + """ + rows = result["rows"] + assert rows, "Expected at least one row" + total_ns = {r["total_n"] for r in rows} + assert len(total_ns) == 1, ( + f"Expected total_n to be constant across rows, got: {total_ns}" + ) + assert next(iter(total_ns)) > 0 + + @pytest.mark.integration + def test_no_rows_above_neuron_root(self, result): + """The partner-side ancestor walk should stop at the Neuron class + (FBbt_00005106). No row id should be a class outside the Neuron + subtree. + """ + from vfbquery.vfb_queries import vc, get_dict_cursor, NEURON_ROOT_SHORT_FORM + + ids = [r["id"] for r in result["rows"]] + assert ids, "Expected at least one row" + q = ( + "MATCH (root:Class {short_form: '%s'})<-[:SUBCLASSOF*0..]-(c:Class) " + "WHERE c.short_form IN %s " + "RETURN collect(DISTINCT c.short_form) AS in_neuron" + % (NEURON_ROOT_SHORT_FORM, ids) + ) + result_rows = get_dict_cursor()(vc.nc.commit_list([q])) + in_neuron = set(result_rows[0]["in_neuron"]) if result_rows else set() + offenders = [i for i in ids if i not in in_neuron] + assert not offenders, ( + f"Found {len(offenders)} row(s) outside the Neuron subtree: " + f"{offenders[:5]}" + ) + + class TestDownstreamClassConnectivitySchema: def test_schema_generation(self): schema = DownstreamClassConnectivity_to_schema( diff --git a/src/test/test_query_performance.py b/src/test/test_query_performance.py index 71e6652..32bdbf9 100644 --- a/src/test/test_query_performance.py +++ b/src/test/test_query_performance.py @@ -34,12 +34,15 @@ get_neuron_neuron_connectivity, get_neuron_region_connectivity, get_individual_neuron_inputs, + get_downstream_class_connectivity, + get_upstream_class_connectivity, get_expression_overlaps_here, get_anatomy_scrnaseq, get_cluster_expression, get_expression_cluster, get_scrnaseq_dataset_data, ) +from vfbquery.vfb_connectivity import query_connectivity class QueryPerformanceTest(unittest.TestCase): @@ -348,7 +351,65 @@ def test_07_connectivity_queries(self): ) print(f"NeuronRegionConnectivityQuery: {duration:.4f}s {'✅' if success else '❌'}") self.assertLess(duration, self.THRESHOLD_SLOW, "NeuronRegionConnectivityQuery exceeded threshold") - + + # FBbt_00100234 = MBON01 — a specific mushroom body output neuron type + # with a small instance count (preferred over broad lineage classes for + # bounded test runtime). The class-level connectivity queries are a + # multi-step aggregation (Neo4j + batched Solr + ancestor walk), not a + # single Solr lookup, so cold-cache calls can take tens of seconds even + # on a small class. + CLASS_CONNECTIVITY_TEST_CLASS = "FBbt_00100234" + + def test_07b_downstream_class_connectivity(self): + """Test DownstreamClassConnectivity query (multi-step aggregation)""" + print("\n" + "="*80) + print("DOWNSTREAM CLASS CONNECTIVITY (multi-step aggregation)") + print("="*80) + + result, duration, success = self._time_query( + "DownstreamClassConnectivity", + get_downstream_class_connectivity, + self.CLASS_CONNECTIVITY_TEST_CLASS, + return_dataframe=False, + ) + print(f"DownstreamClassConnectivity: {duration:.4f}s {'✅' if success else '❌'}") + self.assertLess(duration, self.THRESHOLD_VERY_SLOW, "DownstreamClassConnectivity exceeded threshold") + + def test_07b_upstream_class_connectivity(self): + """Test UpstreamClassConnectivity query (multi-step aggregation)""" + print("\n" + "="*80) + print("UPSTREAM CLASS CONNECTIVITY (multi-step aggregation)") + print("="*80) + + result, duration, success = self._time_query( + "UpstreamClassConnectivity", + get_upstream_class_connectivity, + self.CLASS_CONNECTIVITY_TEST_CLASS, + return_dataframe=False, + ) + print(f"UpstreamClassConnectivity: {duration:.4f}s {'✅' if success else '❌'}") + self.assertLess(duration, self.THRESHOLD_VERY_SLOW, "UpstreamClassConnectivity exceeded threshold") + + def test_07c_cross_dataset_connectivity(self): + """Test cross-dataset query_connectivity (live, both-end filtered)""" + print("\n" + "="*80) + print("CROSS-DATASET CONNECTIVITY (live, slow)") + print("="*80) + + # Both-end + group_by_class is the fastest variant per LLM guidance. + # giant fiber neuron → peripherally synapsing interneuron is a + # known-good pair with non-zero results. + result, duration, success = self._time_query( + "QueryConnectivity", + query_connectivity, + upstream_type="giant fiber neuron", + downstream_type="peripherally synapsing interneuron", + group_by_class=True, + ) + print(f"QueryConnectivity: {duration:.4f}s {'✅' if success else '❌'}") + # Live cross-dataset query — allow up to 5 min per the MCP timeout. + self.assertLess(duration, 300.0, "QueryConnectivity exceeded threshold") + def test_08_similarity_queries(self): """Test NBLAST similarity queries""" print("\n" + "="*80) @@ -365,8 +426,8 @@ def test_08_similarity_queries(self): limit=5 ) print(f"SimilarMorphologyTo: {duration:.4f}s {'✅' if success else '❌'}") - # Legacy NBLAST similarity can be slower - self.assertLess(duration, self.THRESHOLD_SLOW, "SimilarMorphologyTo exceeded threshold") + # Legacy NBLAST similarity is slow; observed ~18s on cold CI runners. + self.assertLess(duration, self.THRESHOLD_VERY_SLOW, "SimilarMorphologyTo exceeded threshold") def test_09_neuron_input_queries(self): """Test neuron input/synapse queries""" @@ -657,7 +718,8 @@ def test_13_dataset_template_queries(self): if success and result: count = result.get('count', 0) print(f" └─ Found {count} aligned images" + (", returned 10" if count > 10 else "")) - self.assertLess(duration, self.THRESHOLD_MEDIUM, "AllAlignedImages exceeded threshold") + # Observed ~3.6s on CI cold cache; THRESHOLD_MEDIUM (3s) was too tight. + self.assertLess(duration, self.THRESHOLD_SLOW, "AllAlignedImages exceeded threshold") # AlignedDatasets - All datasets aligned to template # Warm up cache with full results diff --git a/src/test/test_upstream_class_connectivity.py b/src/test/test_upstream_class_connectivity.py index ae59e9f..7cc538b 100644 --- a/src/test/test_upstream_class_connectivity.py +++ b/src/test/test_upstream_class_connectivity.py @@ -108,6 +108,101 @@ def test_empty_class_returns_empty_dataframe(self): assert df.empty +class TestUpstreamClassConnectivityHierarchyRollup: + """Regression tests for the partner-side hierarchy rollup behaviour: + connections from a child class also count toward each ancestor class + within the Neuron subtree, without double-counting under FBbt + multi-inheritance. + """ + + @pytest.fixture(scope='class') + def result(self): + return get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=False, force_refresh=True, + ) + + @pytest.mark.integration + def test_parent_class_appears_with_sensible_counts(self, result): + """A row keyed on a parent class should have connected_n at least as + large as any of its descendant rows (set-union semantics) and at most + the sum of descendant connected_n. + """ + from vfbquery.vfb_queries import vc, get_dict_cursor + + rows = result["rows"] + ids = [r["id"] for r in rows] + assert ids, "Expected at least one row to test against" + + q = ( + "MATCH (p:Class)<-[:SUBCLASSOF*1..]-(c:Class) " + "WHERE p.short_form IN %s AND c.short_form IN %s " + "RETURN p.short_form AS parent, c.short_form AS child LIMIT 1" + % (ids, ids) + ) + pairs = get_dict_cursor()(vc.nc.commit_list([q])) + if not pairs: + pytest.skip("No parent/child pair among result rows for this class") + + parent_id = pairs[0]["parent"] + parent_row = next(r for r in rows if r["id"] == parent_id) + desc_q = ( + "MATCH (p:Class {short_form: '%s'})<-[:SUBCLASSOF*1..]-(c:Class) " + "WHERE c.short_form IN %s " + "RETURN collect(DISTINCT c.short_form) AS descs" + % (parent_id, ids) + ) + desc_rows = get_dict_cursor()(vc.nc.commit_list([desc_q])) + descendant_ids = desc_rows[0]["descs"] if desc_rows else [] + descendant_rows = [r for r in rows if r["id"] in descendant_ids] + max_child = max(r["connected_n"] for r in descendant_rows) + sum_child = sum(r["connected_n"] for r in descendant_rows) + assert parent_row["connected_n"] >= max_child, ( + f"Parent {parent_id} connected_n={parent_row['connected_n']} should " + f"be >= max descendant connected_n={max_child}" + ) + assert parent_row["connected_n"] <= sum_child, ( + f"Parent {parent_id} connected_n={parent_row['connected_n']} should " + f"be <= sum of descendant connected_n={sum_child}" + ) + + @pytest.mark.integration + def test_total_n_is_constant_across_rows(self, result): + """`total_n` is the queried-side instance count and must be the same + for every output row. + """ + rows = result["rows"] + assert rows, "Expected at least one row" + total_ns = {r["total_n"] for r in rows} + assert len(total_ns) == 1, ( + f"Expected total_n to be constant across rows, got: {total_ns}" + ) + assert next(iter(total_ns)) > 0 + + @pytest.mark.integration + def test_no_rows_above_neuron_root(self, result): + """The partner-side ancestor walk should stop at the Neuron class + (FBbt_00005106). No row id should be a class outside the Neuron + subtree. + """ + from vfbquery.vfb_queries import vc, get_dict_cursor, NEURON_ROOT_SHORT_FORM + + ids = [r["id"] for r in result["rows"]] + assert ids, "Expected at least one row" + q = ( + "MATCH (root:Class {short_form: '%s'})<-[:SUBCLASSOF*0..]-(c:Class) " + "WHERE c.short_form IN %s " + "RETURN collect(DISTINCT c.short_form) AS in_neuron" + % (NEURON_ROOT_SHORT_FORM, ids) + ) + result_rows = get_dict_cursor()(vc.nc.commit_list([q])) + in_neuron = set(result_rows[0]["in_neuron"]) if result_rows else set() + offenders = [i for i in ids if i not in in_neuron] + assert not offenders, ( + f"Found {len(offenders)} row(s) outside the Neuron subtree: " + f"{offenders[:5]}" + ) + + class TestUpstreamClassConnectivitySchema: def test_schema_generation(self): schema = UpstreamClassConnectivity_to_schema( diff --git a/src/vfbquery/vfb_queries.py b/src/vfbquery/vfb_queries.py index ee794c3..b3c61c2 100644 --- a/src/vfbquery/vfb_queries.py +++ b/src/vfbquery/vfb_queries.py @@ -1521,7 +1521,20 @@ def DownstreamClassConnectivity_to_schema(name, take_default): Schema for downstream class connectivity query. Shows which neuron classes receive synapses from this neuron class. Matching criteria: Class + Neuron - Query chain: Solr downstream_connectivity_query field + + Implementation: multi-step aggregation, not a single Solr lookup. + 1. Neo4j: instances in the SUBCLASSOF closure of the queried class. + 2. Solr cache (batched): per-instance synaptic partners. + 3. Solr: direct partner classes from the downstream_connectivity_query + field (seed set for the partner-side ancestor walk). + 4. Neo4j: walk SUBCLASSOF up from each direct partner to the neuron root. + 5. Neo4j (batched): partner_instance -> {class_ids} membership map. + 6. In-memory aggregation with set-union semantics to handle FBbt + multi-inheritance without double-counting. + + Results are cached server-side (@with_solr_cache) per queried class, so + repeat calls return in milliseconds, but cold calls on broad classes can + take tens of seconds. """ query = "DownstreamClassConnectivity" label = f"Downstream connectivity classes for {name}" @@ -1540,7 +1553,11 @@ def UpstreamClassConnectivity_to_schema(name, take_default): Schema for upstream class connectivity query. Shows which neuron classes send synapses to this neuron class. Matching criteria: Class + Neuron - Query chain: Solr upstream_connectivity_query field + + Implementation: same multi-step aggregation as + DownstreamClassConnectivity but with the upstream_connectivity_query + Solr field as the seed for the partner-side ancestor walk. See + DownstreamClassConnectivity_to_schema for the full pipeline. """ query = "UpstreamClassConnectivity" label = f"Upstream connectivity classes for {name}" @@ -3129,50 +3146,256 @@ def _fetch_connectivity_entries(short_form: str, solr_field: str): return all_entries -def _merge_connectivity_rows(entries, partner_key, partner_id_key, partner_label_key): - """Merge connectivity entries by partner class, summing statistics. +def _num(v): + """Coerce a value to a number, defaulting to 0.""" + try: + return float(v) + except (TypeError, ValueError): + return 0 + + +# Root class for partner-side ancestor walk. Edges contributing to a partner +# class row require the partner instance to be (transitively) an instance of +# that class, with NEURON_ROOT_SHORT_FORM bounding the walk to avoid generic +# anatomy classes. +NEURON_ROOT_SHORT_FORM = 'FBbt_00005106' + - Returns a list of merged row dicts ready for DataFrame / dict output. - ``partner_key`` is the output column name (e.g. 'downstream_class'), - ``partner_id_key`` / ``partner_label_key`` are the keys inside - ``class_connectivity`` to read partner id and label from. +def _get_partner_class_ancestors(direct_partner_ids, neuron_root=NEURON_ROOT_SHORT_FORM): + """Walk SUBCLASSOF up from each direct partner class to ``neuron_root``. + + Returns ``(class_ids, labels)`` where ``class_ids`` is the union of every + direct partner plus its ancestors that are also subclasses of + ``neuron_root``. ``labels`` maps id -> human-readable label. """ - # Accumulate by partner class id - merged = {} # partner_id -> {label, total_n, connected_n, pw, tw} - for entry in entries: - cc = entry.get('class_connectivity', {}) + if not direct_partner_ids: + return set(), {} + direct_list = sorted(direct_partner_ids) + query = ( + "MATCH (root:Class {short_form: '%s'})<-[:SUBCLASSOF*0..]-(c:Class)" + "<-[:SUBCLASSOF*0..]-(d:Class) " + "WHERE d.short_form IN %s " + "RETURN DISTINCT c.short_form AS id, c.label AS label" + % (neuron_root, direct_list) + ) + try: + results = vc.nc.commit_list([query]) + rows = get_dict_cursor()(results) + except Exception as e: + print(f"Partner class hierarchy query failed: {e}") + # Fall back to direct partners only so we still produce some output. + return set(direct_partner_ids), {pid: pid for pid in direct_partner_ids} + ids = set() + labels = {} + for row in rows: + cid = row.get('id') + if not cid: + continue + ids.add(cid) + labels[cid] = row.get('label') or cid + return ids, labels + + +def _build_partner_instance_class_membership(class_ids): + """Build ``instance_id -> set(class_ids)`` for the supplied partner + classes, using a single Cypher round-trip with SUBCLASSOF closure. + + Multi-typed instances appear in multiple class sets, which is exactly what + we need for set-union aggregation across hierarchy levels. Doing this with + one batched query rather than per-class avoids hundreds of round-trips + when ``class_ids`` is large. + """ + if not class_ids: + return {} + class_list = sorted(class_ids) + query = ( + "MATCH (c:Class)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-" + "(n:Individual:has_neuron_connectivity) " + "WHERE c.short_form IN %s " + "RETURN c.short_form AS cid, collect(DISTINCT n.short_form) AS iids" + % class_list + ) + try: + results = vc.nc.commit_list([query]) + rows = get_dict_cursor()(results) + except Exception as e: + print(f"Partner class membership query failed: {e}") + return {} + instance_to_classes = {} + for row in rows: + cid = row.get('cid') + for iid in row.get('iids') or []: + instance_to_classes.setdefault(iid, set()).add(cid) + return instance_to_classes + + +def _bulk_fetch_per_instance_connectivity(instance_ids): + """Bulk-fetch cached ``neuron_neuron_connectivity_query`` results from the + Solr cache collection for the given instance IDs. + + Returns ``(found, missing)`` where ``found`` maps instance_id -> + list-of-partner-rows and ``missing`` lists instances that had no cache hit. + Tries the ``_dataframe_False`` variant first (rows are easy to parse), + then falls back to ``_dataframe_True`` for any instances still missing. + """ + if not instance_ids: + return {}, [] + instance_ids = list(instance_ids) + found = {} + prefix = 'vfb_query_neuron_neuron_connectivity_query_' + for suffix in ('_dataframe_False', '_dataframe_True'): + remaining = [i for i in instance_ids if i not in found] + if not remaining: + break + cache_ids = [f'{prefix}{i}{suffix}' for i in remaining] + try: + results = vfb_solr.search( + q='*:*', + fq='{!terms f=id}' + ','.join(cache_ids), + fl='id,cache_data', + rows=len(cache_ids), + ) + except Exception as e: + print(f"Bulk per-instance cache fetch failed ({suffix}): {e}") + continue + for doc in results.docs: + doc_id = doc.get('id') + cache_data_raw = doc.get('cache_data') + if isinstance(cache_data_raw, list): + cache_data_raw = cache_data_raw[0] if cache_data_raw else None + if not doc_id or not cache_data_raw: + continue + if not (doc_id.startswith(prefix) and doc_id.endswith(suffix)): + continue + iid = doc_id[len(prefix):-len(suffix)] + try: + cached = json.loads(cache_data_raw) + result = cached.get('result') + if isinstance(result, str): + result = json.loads(result) + if isinstance(result, dict): + rows = result.get('rows', []) + elif isinstance(result, list): + rows = result + else: + rows = [] + found[iid] = rows + except Exception as e: + print(f"Failed to parse cached connectivity for {iid}: {e}") + missing = [i for i in instance_ids if i not in found] + return found, missing + + +def _aggregate_class_connectivity(short_form, direction, + neuron_root=NEURON_ROOT_SHORT_FORM): + """Aggregate class-level partner connectivity correctly under FBbt + multi-inheritance, using set-union over instance memberships. + + ``direction`` is ``'downstream'`` (partner = downstream of queried class) + or ``'upstream'``. Returns a list of row dicts with the same fields the + previous summation-based implementation produced. + """ + from collections import defaultdict + + # 1. Queried-side instances (subclass closure via Neo4j — Owlery's + # get_instances has been observed to hang for some classes, while a + # SUBCLASSOF traversal in Cypher is fast and equivalent here). + queried_q = ( + "MATCH (n:Individual:has_neuron_connectivity)-[:INSTANCEOF]->" + "(:Class)-[:SUBCLASSOF*0..]->(:Class {short_form: '%s'}) " + "RETURN DISTINCT n.short_form AS sf" % short_form + ) + try: + results = vc.nc.commit_list([queried_q]) + rows = get_dict_cursor()(results) + queried_instances = [r['sf'] for r in rows if r.get('sf')] + except Exception as e: + print(f"Queried-side instance query failed for {short_form}: {e}") + return [] + if not queried_instances: + return [] + queried_instance_set = set(queried_instances) + total_n_queried = len(queried_instance_set) + + # 2. Per-instance edges from cache. Cache misses are skipped with a warning; + # the resulting connected_n / pairwise / total_weight will be a slight + # underestimate when this happens. + found_edges, missing = _bulk_fetch_per_instance_connectivity(queried_instances) + if missing: + print( + f"Warning: per-instance connectivity cache missing for " + f"{len(missing)}/{total_n_queried} instances of {short_form}; " + f"those will be skipped (results may be a slight underestimate)." + ) + if not found_edges: + return [] + + weight_key = 'outputs' if direction == 'downstream' else 'inputs' + + # 3. Direct partner classes from the existing class-level connectivity + # field (already cached) — used as the seed set for the partner-side + # ancestor walk. + solr_field = ( + 'downstream_connectivity_query' if direction == 'downstream' + else 'upstream_connectivity_query' + ) + class_entries = _fetch_connectivity_entries(short_form, solr_field) + direct_partner_ids = set() + for entry in class_entries: obj = entry.get('object', {}) - pid = obj.get('short_form', cc.get(partner_id_key, '')) - plabel = obj.get('label', cc.get(partner_label_key, '')) - if not pid: + pid = obj.get('short_form') + if pid: + direct_partner_ids.add(pid) + + # 4. Walk SUBCLASSOF up from each direct partner to ``neuron_root``. + partner_class_ids, class_labels = _get_partner_class_ancestors( + direct_partner_ids, neuron_root, + ) + if not partner_class_ids: + return [] + + # 5. Build partner_instance_id -> {class_ids it belongs to}, restricted + # to in-scope partner classes. + instance_to_classes = _build_partner_instance_class_membership(partner_class_ids) + + # 6. Aggregate edges into per-class buckets via set-union semantics. + buckets = defaultdict(lambda: { + 'edges': set(), 'weight_sum': 0.0, 'connected_n1': set(), + }) + for n1, partner_rows in found_edges.items(): + if n1 not in queried_instance_set: continue - if pid not in merged: - merged[pid] = { - 'label': plabel, - 'total_n': 0, - 'connected_n': 0, - 'pairwise_connections': 0, - 'total_weight': 0, - } - m = merged[pid] - m['total_n'] += _num(cc.get('total_upstream_count', 0)) - m['connected_n'] += _num(cc.get('connected_upstream_count', 0)) - m['pairwise_connections'] += _num(cc.get('pairwise_connections', 0)) - m['total_weight'] += _num(cc.get('total_weight', 0)) + for prow in partner_rows or []: + n2 = prow.get('id') + w = prow.get(weight_key) + if not n2 or not w: + continue + try: + w_num = float(w) + except (TypeError, ValueError): + continue + if w_num <= 0: + continue + for c in instance_to_classes.get(n2, ()): + b = buckets[c] + b['edges'].add((n1, n2)) + b['weight_sum'] += w_num + b['connected_n1'].add(n1) + # 7. Emit one row per partner class that received at least one edge. rows = [] - for pid, m in merged.items(): - total_n = m['total_n'] - connected_n = m['connected_n'] - pw = m['pairwise_connections'] - tw = m['total_weight'] - pct = round((connected_n / total_n) * 100) if total_n else 0 - avg = tw / pw if pw else 0 + for cid, b in buckets.items(): + pw = len(b['edges']) + cn = len(b['connected_n1']) + tw = b['weight_sum'] + pct = round((cn / total_n_queried) * 100) if total_n_queried else 0 + avg = (tw / pw) if pw else 0 + label = class_labels.get(cid, cid) rows.append({ - 'id': pid, - partner_key: f"[{m['label']}]({pid})" if pid else m['label'], - 'total_n': total_n, - 'connected_n': connected_n, + 'id': cid, + '_label': label, + 'total_n': total_n_queried, + 'connected_n': cn, 'percent_connected': pct, 'pairwise_connections': pw, 'total_weight': tw, @@ -3181,12 +3404,16 @@ def _merge_connectivity_rows(entries, partner_key, partner_id_key, partner_label return rows -def _num(v): - """Coerce a value to a number, defaulting to 0.""" - try: - return float(v) - except (TypeError, ValueError): - return 0 +def _format_class_connectivity_rows(rows, partner_key): + """Add the markdown-link partner column expected by callers and drop the + internal ``_label`` field.""" + out = [] + for r in rows: + formatted = dict(r) + label = formatted.pop('_label', formatted['id']) + formatted[partner_key] = f"[{label}]({formatted['id']})" + out.append(formatted) + return out @with_solr_cache('downstream_class_connectivity_query') @@ -3194,9 +3421,18 @@ def get_downstream_class_connectivity(short_form: str, return_dataframe=True, li """ Retrieves downstream connectivity classes for the specified neuron class. - Uses OWLERY to expand subclasses of the queried class, fetches the - downstream_connectivity_query Solr field for each, and merges results - by downstream partner class. + Uses a Neo4j SUBCLASSOF traversal to enumerate instances of the queried + class (Owlery's get_instances was observed to hang for some classes; + Cypher is equivalent and fast here), bulk-fetches per-instance + connectivity from the Solr cache, and aggregates by partner class with + set-union semantics on partner instance memberships. The partner-side + hierarchy is walked up to ``NEURON_ROOT_SHORT_FORM`` so that connections + to a child class also count toward each ancestor class's row, without + double-counting under FBbt multi-inheritance. + + Server-side cached via ``@with_solr_cache``; cold calls on broad classes + can take tens of seconds because of the aggregation work (already batched + across Solr/Neo4j round-trips). Matching criteria: Class + Neuron @@ -3205,20 +3441,13 @@ def get_downstream_class_connectivity(short_form: str, return_dataframe=True, li :param limit: maximum number of results to return (default -1, returns all results) :return: Downstream partner neuron classes with connectivity statistics """ - entries = _fetch_connectivity_entries(short_form, 'downstream_connectivity_query') - if not entries: + rows = _aggregate_class_connectivity(short_form, 'downstream') + if not rows: if return_dataframe: return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _merge_connectivity_rows( - entries, - partner_key='downstream_class', - partner_id_key='downstream_class_id', - partner_label_key='downstream_class', - ) - - # Sort by pairwise_connections descending + rows = _format_class_connectivity_rows(rows, partner_key='downstream_class') rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) total_count = len(rows) @@ -3248,9 +3477,15 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi """ Retrieves upstream connectivity classes for the specified neuron class. - Uses OWLERY to expand subclasses of the queried class, fetches the - upstream_connectivity_query Solr field for each, and merges results - by upstream partner class. + Same multi-step aggregation as ``get_downstream_class_connectivity`` but + walking the upstream side: Neo4j SUBCLASSOF enumerates queried-class + instances, batched Solr cache fetches their synaptic partners, and the + partner-side hierarchy is walked up to ``NEURON_ROOT_SHORT_FORM`` with + set-union semantics to avoid double-counting under FBbt multi-inheritance. + + Server-side cached via ``@with_solr_cache``; cold calls on broad classes + can take tens of seconds because of the aggregation work (already batched + across Solr/Neo4j round-trips). Matching criteria: Class + Neuron @@ -3259,20 +3494,13 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi :param limit: maximum number of results to return (default -1, returns all results) :return: Upstream partner neuron classes with connectivity statistics """ - entries = _fetch_connectivity_entries(short_form, 'upstream_connectivity_query') - if not entries: + rows = _aggregate_class_connectivity(short_form, 'upstream') + if not rows: if return_dataframe: return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _merge_connectivity_rows( - entries, - partner_key='upstream_class', - partner_id_key='upstream_class_id', - partner_label_key='upstream_class', - ) - - # Sort by pairwise_connections descending + rows = _format_class_connectivity_rows(rows, partner_key='upstream_class') rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) total_count = len(rows)