Skip to content

Commit b4caa4d

Browse files
committed
Enhance neuron connectivity query: support direction filtering and accurate count retrieval
1 parent 60c4054 commit b4caa4d

1 file changed

Lines changed: 53 additions & 8 deletions

File tree

src/vfbquery/vfb_queries.py

Lines changed: 53 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2655,34 +2655,52 @@ def get_neuron_neuron_connectivity(short_form: str, return_dataframe=True, limit
26552655
"""
26562656
# Build Cypher query to get connected neurons using synapsed_to relationships
26572657
# XMI spec uses min_weight > 1, but we default to 0 to return all valid connections
2658-
cypher = f"""
2658+
base_cypher = f"""
26592659
MATCH (primary:Individual {{short_form: '{short_form}'}})
26602660
MATCH (oi:Individual)-[r:synapsed_to]-(primary)
26612661
WHERE exists(r.weight) AND r.weight[0] > {min_weight}
26622662
WITH primary, oi
26632663
OPTIONAL MATCH (oi)<-[down:synapsed_to]-(primary)
26642664
WITH down, oi, primary
26652665
OPTIONAL MATCH (primary)<-[up:synapsed_to]-(oi)
2666+
"""
2667+
2668+
if direction == 'upstream':
2669+
base_cypher += " WHERE up IS NOT NULL AND up.weight[0] > 0"
2670+
elif direction == 'downstream':
2671+
base_cypher += " WHERE down IS NOT NULL AND down.weight[0] > 0"
2672+
# for 'both', no additional WHERE
2673+
2674+
cypher = base_cypher + """
26662675
RETURN
26672676
oi.short_form AS id,
26682677
oi.label AS label,
26692678
coalesce(down.weight[0], 0) AS outputs,
26702679
coalesce(up.weight[0], 0) AS inputs,
26712680
oi.uniqueFacets AS tags
26722681
"""
2682+
26732683
if limit != -1:
26742684
cypher += f" LIMIT {limit}"
26752685

26762686
# Run query using Neo4j client
26772687
results = vc.nc.commit_list([cypher])
26782688
rows = get_dict_cursor()(results)
26792689

2680-
# Filter by direction if specified
2681-
if direction != 'both':
2690+
# Get total count if limit was applied
2691+
if limit != -1:
26822692
if direction == 'upstream':
2683-
rows = [row for row in rows if row.get('inputs', 0) > 0]
2693+
count_query = base_cypher + " WHERE up IS NOT NULL AND up.weight[0] > 0 RETURN count(DISTINCT oi)"
26842694
elif direction == 'downstream':
2685-
rows = [row for row in rows if row.get('outputs', 0) > 0]
2695+
count_query = base_cypher + " WHERE down IS NOT NULL AND down.weight[0] > 0 RETURN count(DISTINCT oi)"
2696+
else: # both
2697+
count_query = base_cypher + " RETURN count(DISTINCT oi)"
2698+
count_results = vc.nc.commit_list([count_query])
2699+
total_count = count_results[0][0] if count_results and count_results[0] else 0
2700+
else:
2701+
total_count = len(rows)
2702+
2703+
# No need to filter by direction, it's done in the query
26862704

26872705
# Format output
26882706
if return_dataframe:
@@ -2699,7 +2717,7 @@ def get_neuron_neuron_connectivity(short_form: str, return_dataframe=True, limit
26992717
return {
27002718
'headers': headers,
27012719
'rows': rows,
2702-
'count': len(rows)
2720+
'count': total_count
27032721
}
27042722

27052723

@@ -3871,7 +3889,7 @@ def fill_query_results(term_info):
38713889
def process_query(query):
38723890
# print(f"Query Keys:{query.keys()}")
38733891

3874-
if "preview" in query.keys() and (query['preview'] > 0 or query['count'] < 0) and query['count'] != 0:
3892+
if "preview" in query.keys() and query['preview'] > 0:
38753893
function = globals().get(query['function'])
38763894
summary_mode = query.get('output_format', 'table') == 'ribbon'
38773895

@@ -3916,6 +3934,11 @@ def process_query(query):
39163934
filtered_result = []
39173935
filtered_headers = {}
39183936

3937+
if result is None:
3938+
query['preview_results'] = {'headers': query.get('preview_columns', ['id', 'label', 'tags', 'thumbnail']), 'rows': []}
3939+
query['count'] = 0
3940+
return
3941+
39193942
if isinstance(result, dict) and 'rows' in result:
39203943
for item in result['rows']:
39213944
if 'preview_columns' in query.keys() and len(query['preview_columns']) > 0:
@@ -3958,8 +3981,30 @@ def process_query(query):
39583981
# Handle count extraction based on result type
39593982
if isinstance(result, dict) and 'count' in result:
39603983
result_count = result['count']
3984+
# If limit was applied, the count in dict may be wrong, get correct count
3985+
if query['preview'] > 0 and result_count == len(result['rows']):
3986+
try:
3987+
if function_args and takes_short_form:
3988+
short_form_value = list(function_args.values())[0]
3989+
full_dict = function(short_form_value, return_dataframe=False, limit=-1)
3990+
else:
3991+
full_dict = function(return_dataframe=False, limit=-1)
3992+
result_count = full_dict['count']
3993+
except Exception as e:
3994+
print(f"Error getting full count for {query['function']}: {e}")
3995+
result_count = result['count'] # Keep as is
39613996
elif isinstance(result, pd.DataFrame):
3962-
result_count = len(result)
3997+
# For DataFrame results, we need the full count even when preview is limited
3998+
try:
3999+
if function_args and takes_short_form:
4000+
short_form_value = list(function_args.values())[0]
4001+
full_result = function(short_form_value, return_dataframe=True, limit=-1)
4002+
else:
4003+
full_result = function(return_dataframe=True, limit=-1)
4004+
result_count = len(full_result)
4005+
except Exception as e:
4006+
print(f"Error getting full count for {query['function']}: {e}")
4007+
result_count = len(result) # Fallback to limited count
39634008
else:
39644009
result_count = 0
39654010

0 commit comments

Comments
 (0)