Skip to content

Commit 826caf6

Browse files
committed
get_connected_neurons_by_type should return all class combinations for group_by_class
1 parent 392a42a commit 826caf6

1 file changed

Lines changed: 132 additions & 22 deletions

File tree

src/vfb_connect/cross_server_tools.py

Lines changed: 132 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -523,14 +523,24 @@ def get_connected_neurons_by_type(self, weight, upstream_type=None, downstream_t
523523
group_by_class=False, exclude_dbs=['hb', 'fafb'], return_dataframe=True, verbose=False):
524524

525525
"""Get all synaptic connections between individual neurons of `upstream_type` and `downstream_type` where
526-
synapse count >= `weight`. At least one of 'upstream_type' or downstream_type must be specified.
526+
synapse count >= `weight`. At least one of `upstream_type` or `downstream_type` must be specified.
527527
528+
When `group_by_class=True`, results are aggregated by class with subclass closure: a connection
529+
contributes to the row for every (upstream_ancestor, downstream_ancestor) pair within the
530+
`upstream_type`/`downstream_type` scope, not only the directly asserted INSTANCEOF class. This
531+
means a single connection appears in multiple rows (once per ancestor pair) and the per-row
532+
`pairwise_connections` and `total_weight` values do not add up to the underlying connection count and total synapse weight.
533+
`total_upstream_count` is also computed with subclass closure, so `percent_connected` is
534+
internally consistent within each row.
535+
536+
:param weight: Minimum synapse count for a connection to be included.
528537
:param upstream_type: The upstream neuron type (e.g., 'GABAergic neuron').
529538
:param downstream_type: The downstream neuron type (e.g., 'Descending neuron').
530-
:param group_by_class: If `True`, return connectivity results aggregated by class rather than per neuron. Default `False`.
531539
:param query_by_label: Optional. Specify neuron type by label if `True` (default) or by short_form ID if `False`.
540+
:param group_by_class: Optional. If `True`, return connectivity results aggregated by class (with subclass closure, see note above) rather than per neuron. Default `False`.
532541
:param exclude_dbs: Optional. List of databases (short_forms or symbols) to exclude from results. Hemibrain and CATMAID FAFB are excluded by default.
533542
:param return_dataframe: Optional. Returns pandas DataFrame if `True`, otherwise returns list of dicts. Default `True`.
543+
:param verbose: Optional. If `True`, print the Cypher queries used. Default `False`.
534544
:return: A DataFrame or list of synaptic connections between specified neuron types.
535545
:rtype: pandas.DataFrame or list of dicts
536546
"""
@@ -584,26 +594,126 @@ def get_connected_neurons_by_type(self, weight, upstream_type=None, downstream_t
584594
"s2.short_form AS down_data_source, r2.accession[0] AS down_accession ")
585595

586596
else:
587-
cypher_ql.append("WITH c1, c2, count(*) as pairwise_connections, sum(r.weight[0]) as total_weight, "
588-
"count(distinct n1) as connected_upstream_count \n\n"
589-
"MATCH (c1)<-[:INSTANCEOF]-(all_n1:Individual:has_neuron_connectivity)%s \n\n"
590-
"WITH c1, c2, pairwise_connections, total_weight, connected_upstream_count, "
591-
"count(distinct all_n1) as total_upstream_count \n\n"
592-
"RETURN c1.label AS upstream_class, "
593-
"c1.short_form AS upstream_class_id, "
594-
"c2.label AS downstream_class, "
595-
"c2.short_form AS downstream_class_id, "
596-
"total_upstream_count, "
597-
"connected_upstream_count, "
598-
"round((toFloat(connected_upstream_count)/toFloat(total_upstream_count))*100) as percent_connected, "
599-
"pairwise_connections, "
600-
"total_weight, "
601-
"total_weight/pairwise_connections as average_weight "
602-
"ORDER BY percent_connected DESC, average_weight DESC"
603-
% ("-[:database_cross_reference]->(s:Individual:Site {is_data_source:[True]}) \n"
604-
"WHERE NOT (s.short_form IN %s) \n"
605-
"AND NOT (s.symbol[0] IN %s) "
606-
% (exclude_dbs, exclude_dbs) if exclude_dbs else ""))
597+
# Aggregate in Python so that connections also count for ancestor classes
598+
# (within the upstream_type/downstream_type scope), not just the directly
599+
# asserted INSTANCEOF class. Doing the closure traversal in Cypher would
600+
# blow up the query plan; this keeps the connection match fast and pushes
601+
# hierarchy expansion to a small follow-up query restricted to the
602+
# neurons actually involved.
603+
cypher_ql.append(
604+
"RETURN DISTINCT n1.short_form AS upstream_neuron_id, "
605+
"n2.short_form AS downstream_neuron_id, "
606+
"r.weight[0] AS weight"
607+
)
608+
cypher_q = ' \n\n'.join(cypher_ql)
609+
print("Connectivity query:\n%s" % cypher_q) if verbose else None
610+
r = self.nc.commit_list([cypher_q])
611+
if not r:
612+
warnings.warn("No results returned")
613+
return False
614+
conns = dict_cursor(r)
615+
if not conns:
616+
warnings.warn("No results returned")
617+
return False
618+
619+
def _scope_prefix(type_id):
620+
if type_id:
621+
return '(:Class:Neuron {short_form:"%s"})<-[:SUBCLASSOF*0..]-' % type_id
622+
return ''
623+
624+
n1_ids = list({c['upstream_neuron_id'] for c in conns})
625+
n2_ids = list({c['downstream_neuron_id'] for c in conns})
626+
627+
ancestor_q = (
628+
"MATCH %s(c:Class:Neuron)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-(n:Individual) "
629+
"WHERE n.short_form IN %s "
630+
"RETURN n.short_form AS nid, "
631+
"collect(DISTINCT {id: c.short_form, label: c.label}) AS classes"
632+
)
633+
up_q = ancestor_q % (_scope_prefix(upstream_type), n1_ids)
634+
down_q = ancestor_q % (_scope_prefix(downstream_type), n2_ids)
635+
print("Upstream ancestor query:\n%s" % up_q) if verbose else None
636+
print("Downstream ancestor query:\n%s" % down_q) if verbose else None
637+
up_rows = dict_cursor(self.nc.commit_list([up_q]))
638+
down_rows = dict_cursor(self.nc.commit_list([down_q]))
639+
if not up_rows or not down_rows:
640+
raise RuntimeError(
641+
"Ancestor class lookup returned no rows for neurons that "
642+
"appeared in the connectivity query. This indicates a "
643+
"missing INSTANCEOF edge or a failed ancestor query."
644+
)
645+
n1_classes = {row['nid']: row['classes'] for row in up_rows}
646+
n2_classes = {row['nid']: row['classes'] for row in down_rows}
647+
648+
from collections import defaultdict
649+
pairwise = defaultdict(int)
650+
weight_sum = defaultdict(int)
651+
connected_n1s = defaultdict(set)
652+
class_labels = {}
653+
for c in conns:
654+
ups = n1_classes.get(c['upstream_neuron_id'], [])
655+
downs = n2_classes.get(c['downstream_neuron_id'], [])
656+
for a1 in ups:
657+
class_labels[a1['id']] = a1['label']
658+
for a2 in downs:
659+
class_labels[a2['id']] = a2['label']
660+
key = (a1['id'], a2['id'])
661+
pairwise[key] += 1
662+
weight_sum[key] += c['weight']
663+
connected_n1s[key].add(c['upstream_neuron_id'])
664+
665+
upstream_class_ids = list({k[0] for k in pairwise})
666+
db_filter = ""
667+
if exclude_dbs:
668+
db_filter = (
669+
"MATCH (all_n1)-[:database_cross_reference]->"
670+
"(s:Individual:Site {is_data_source:[True]}) "
671+
"WHERE NOT (s.short_form IN %s) "
672+
"AND NOT (s.symbol[0] IN %s) "
673+
% (exclude_dbs, exclude_dbs)
674+
)
675+
totals = {}
676+
if upstream_class_ids:
677+
total_q = (
678+
"MATCH %s(c:Class:Neuron) WHERE c.short_form IN %s "
679+
"MATCH (c)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-"
680+
"(all_n1:Individual:has_neuron_connectivity) "
681+
"%s"
682+
"RETURN c.short_form AS cid, count(DISTINCT all_n1) AS total"
683+
% (_scope_prefix(upstream_type), upstream_class_ids, db_filter)
684+
)
685+
print("Total upstream count query:\n%s" % total_q) if verbose else None
686+
total_rows = dict_cursor(self.nc.commit_list([total_q]))
687+
if not total_rows:
688+
raise RuntimeError(
689+
"total_upstream_count query returned no rows for "
690+
"upstream classes that appeared in the connectivity "
691+
"results. This indicates a failed query."
692+
)
693+
totals = {row['cid']: row['total'] for row in total_rows}
694+
695+
rows = []
696+
for (c1_id, c2_id), pw in pairwise.items():
697+
tw = weight_sum[(c1_id, c2_id)]
698+
cu = len(connected_n1s[(c1_id, c2_id)])
699+
tot = totals.get(c1_id, 0)
700+
pct = round((cu / tot) * 100) if tot else 0
701+
rows.append({
702+
'upstream_class': class_labels.get(c1_id),
703+
'upstream_class_id': c1_id,
704+
'downstream_class': class_labels.get(c2_id),
705+
'downstream_class_id': c2_id,
706+
'total_upstream_count': tot,
707+
'connected_upstream_count': cu,
708+
'percent_connected': pct,
709+
'pairwise_connections': pw,
710+
'total_weight': tw,
711+
'average_weight': tw // pw if pw else 0,
712+
})
713+
rows.sort(key=lambda r: (-r['percent_connected'], -r['average_weight']))
714+
if return_dataframe:
715+
return pd.DataFrame.from_records(rows)
716+
return rows
607717

608718
cypher_q = ' \n\n'.join(cypher_ql)
609719
print(cypher_q) if verbose else None

0 commit comments

Comments
 (0)