11"""Filter endpoint for plots."""
22
33import logging
4+ from collections .abc import Callable
45
56from fastapi import APIRouter , Depends , Request
67from sqlalchemy .exc import SQLAlchemyError
1920router = APIRouter (tags = ["plots" ])
2021
2122
23+ # =============================================================================
24+ # Filter Category Extractors - Unified dispatch pattern for filter logic
25+ # =============================================================================
26+ # These extractors define how to get filter-matching values for each category.
27+ # Each extractor takes (library, spec_tags, impl_tags) and returns a list of values.
28+
29+ FilterExtractor = Callable [[str , dict , dict ], list [str ]]
30+
31+ # Spec-level category extractors (match against spec tags)
32+ _SPEC_EXTRACTORS : dict [str , Callable [[dict ], list [str ]]] = {
33+ "plot" : lambda tags : tags .get ("plot_type" , []),
34+ "data" : lambda tags : tags .get ("data_type" , []),
35+ "dom" : lambda tags : tags .get ("domain" , []),
36+ "feat" : lambda tags : tags .get ("features" , []),
37+ }
38+
39+ # Impl-level category extractors (match against impl tags)
40+ _IMPL_EXTRACTORS : dict [str , Callable [[dict ], list [str ]]] = {
41+ "dep" : lambda tags : tags .get ("dependencies" , []),
42+ "tech" : lambda tags : tags .get ("techniques" , []),
43+ "pat" : lambda tags : tags .get ("patterns" , []),
44+ "prep" : lambda tags : tags .get ("dataprep" , []),
45+ "style" : lambda tags : tags .get ("styling" , []),
46+ }
47+
48+
49+ def _get_category_values (category : str , spec_id : str , library : str , spec_tags : dict , impl_tags : dict ) -> list [str ]:
50+ """
51+ Get the values for a category from the appropriate tag source.
52+
53+ This unified function replaces the repeated if/elif chains throughout the module.
54+
55+ Args:
56+ category: Filter category (lib, spec, plot, data, dom, feat, dep, tech, pat, prep, style)
57+ spec_id: Specification ID
58+ library: Library ID
59+ spec_tags: Spec-level tags dict
60+ impl_tags: Implementation-level tags dict
61+
62+ Returns:
63+ List of values that match this category for the given image/spec/impl
64+ """
65+ if category == "lib" :
66+ return [library ]
67+ if category == "spec" :
68+ return [spec_id ]
69+ if category in _SPEC_EXTRACTORS :
70+ return _SPEC_EXTRACTORS [category ](spec_tags )
71+ if category in _IMPL_EXTRACTORS :
72+ return _IMPL_EXTRACTORS [category ](impl_tags )
73+ return []
74+
75+
76+ def _category_matches_filter (
77+ category : str , values : list [str ], spec_id : str , library : str , spec_tags : dict , impl_tags : dict
78+ ) -> bool :
79+ """
80+ Check if any of the filter values match the category's values.
81+
82+ Args:
83+ category: Filter category
84+ values: Filter values to match against
85+ spec_id: Specification ID
86+ library: Library ID
87+ spec_tags: Spec-level tags dict
88+ impl_tags: Implementation-level tags dict
89+
90+ Returns:
91+ True if any filter value matches, False otherwise
92+ """
93+ category_values = _get_category_values (category , spec_id , library , spec_tags , impl_tags )
94+ return any (v in category_values for v in values )
95+
96+
2297def _image_matches_groups (spec_id : str , library : str , groups : list [dict ], spec_lookup : dict , impl_lookup : dict ) -> bool :
2398 """Check if an image matches a set of filter groups."""
2499 if spec_id not in spec_lookup :
@@ -30,55 +105,22 @@ def _image_matches_groups(spec_id: str, library: str, groups: list[dict], spec_l
30105 category = group ["category" ]
31106 values = group ["values" ]
32107
33- if category == "lib" :
34- if library not in values :
35- return False
36- elif category == "spec" :
37- if spec_id not in values :
38- return False
39- elif category == "plot" :
40- spec_plot_types = spec_tags .get ("plot_type" , [])
41- if not any (v in spec_plot_types for v in values ):
42- return False
43- elif category == "data" :
44- spec_data_types = spec_tags .get ("data_type" , [])
45- if not any (v in spec_data_types for v in values ):
46- return False
47- elif category == "dom" :
48- spec_domains = spec_tags .get ("domain" , [])
49- if not any (v in spec_domains for v in values ):
50- return False
51- elif category == "feat" :
52- spec_features = spec_tags .get ("features" , [])
53- if not any (v in spec_features for v in values ):
54- return False
55- # Impl-level tag filters (issue #2434)
56- elif category == "dep" :
57- impl_deps = impl_tags .get ("dependencies" , [])
58- if not any (v in impl_deps for v in values ):
59- return False
60- elif category == "tech" :
61- impl_techs = impl_tags .get ("techniques" , [])
62- if not any (v in impl_techs for v in values ):
63- return False
64- elif category == "pat" :
65- impl_pats = impl_tags .get ("patterns" , [])
66- if not any (v in impl_pats for v in values ):
67- return False
68- elif category == "prep" :
69- impl_preps = impl_tags .get ("dataprep" , [])
70- if not any (v in impl_preps for v in values ):
71- return False
72- elif category == "style" :
73- impl_styles = impl_tags .get ("styling" , [])
74- if not any (v in impl_styles for v in values ):
75- return False
108+ if not _category_matches_filter (category , values , spec_id , library , spec_tags , impl_tags ):
109+ return False
76110 return True
77111
78112
79- def _calculate_global_counts (all_specs : list ) -> dict :
80- """Calculate global counts for all filter categories."""
81- global_counts : dict = {
113+ def _increment_category_counts (counts : dict , spec_id : str , library : str , spec_tags : dict , impl_tags : dict ) -> None :
114+ """Increment counts for all categories based on an image's spec/impl tags."""
115+ all_categories = ["lib" , "spec" , "plot" , "data" , "dom" , "feat" , "dep" , "tech" , "pat" , "prep" , "style" ]
116+ for category in all_categories :
117+ for value in _get_category_values (category , spec_id , library , spec_tags , impl_tags ):
118+ counts [category ][value ] = counts [category ].get (value , 0 ) + 1
119+
120+
121+ def _create_empty_counts () -> dict :
122+ """Create an empty counts dictionary with all categories initialized."""
123+ return {
82124 "lib" : {},
83125 "spec" : {},
84126 "plot" : {},
@@ -93,6 +135,18 @@ def _calculate_global_counts(all_specs: list) -> dict:
93135 "style" : {},
94136 }
95137
138+
139+ def _sort_counts (counts : dict ) -> dict :
140+ """Sort counts by value descending, then key ascending."""
141+ for category in counts :
142+ counts [category ] = dict (sorted (counts [category ].items (), key = lambda x : (- x [1 ], x [0 ])))
143+ return counts
144+
145+
146+ def _calculate_global_counts (all_specs : list ) -> dict :
147+ """Calculate global counts for all filter categories."""
148+ global_counts = _create_empty_counts ()
149+
96150 for spec_obj in all_specs :
97151 if not spec_obj .impls :
98152 continue
@@ -102,104 +156,25 @@ def _calculate_global_counts(all_specs: list) -> dict:
102156 if not impl .preview_url :
103157 continue
104158
105- # Count library
106- global_counts ["lib" ][impl .library_id ] = global_counts ["lib" ].get (impl .library_id , 0 ) + 1
107-
108- # Count spec ID
109- global_counts ["spec" ][spec_obj .id ] = global_counts ["spec" ].get (spec_obj .id , 0 ) + 1
110-
111- # Count spec-level tags
112- for plot_type in spec_tags .get ("plot_type" , []):
113- global_counts ["plot" ][plot_type ] = global_counts ["plot" ].get (plot_type , 0 ) + 1
114-
115- for data_type in spec_tags .get ("data_type" , []):
116- global_counts ["data" ][data_type ] = global_counts ["data" ].get (data_type , 0 ) + 1
117-
118- for domain in spec_tags .get ("domain" , []):
119- global_counts ["dom" ][domain ] = global_counts ["dom" ].get (domain , 0 ) + 1
120-
121- for feature in spec_tags .get ("features" , []):
122- global_counts ["feat" ][feature ] = global_counts ["feat" ].get (feature , 0 ) + 1
123-
124- # Count impl-level tags (issue #2434)
125159 impl_tags = impl .impl_tags or {}
126- for dep in impl_tags .get ("dependencies" , []):
127- global_counts ["dep" ][dep ] = global_counts ["dep" ].get (dep , 0 ) + 1
128- for tech in impl_tags .get ("techniques" , []):
129- global_counts ["tech" ][tech ] = global_counts ["tech" ].get (tech , 0 ) + 1
130- for pat in impl_tags .get ("patterns" , []):
131- global_counts ["pat" ][pat ] = global_counts ["pat" ].get (pat , 0 ) + 1
132- for prep in impl_tags .get ("dataprep" , []):
133- global_counts ["prep" ][prep ] = global_counts ["prep" ].get (prep , 0 ) + 1
134- for style in impl_tags .get ("styling" , []):
135- global_counts ["style" ][style ] = global_counts ["style" ].get (style , 0 ) + 1
136-
137- # Sort counts
138- for category in global_counts :
139- global_counts [category ] = dict (sorted (global_counts [category ].items (), key = lambda x : (- x [1 ], x [0 ])))
160+ _increment_category_counts (global_counts , spec_obj .id , impl .library_id , spec_tags , impl_tags )
140161
141- return global_counts
162+ return _sort_counts ( global_counts )
142163
143164
144165def _calculate_contextual_counts (filtered_images : list [dict ], spec_id_to_tags : dict , impl_lookup : dict ) -> dict :
145166 """Calculate contextual counts from filtered images."""
146- counts : dict = {
147- "lib" : {},
148- "spec" : {},
149- "plot" : {},
150- "data" : {},
151- "dom" : {},
152- "feat" : {},
153- # Impl-level tag counts (issue #2434)
154- "dep" : {},
155- "tech" : {},
156- "pat" : {},
157- "prep" : {},
158- "style" : {},
159- }
167+ counts = _create_empty_counts ()
160168
161169 for img in filtered_images :
162170 spec_id = img ["spec_id" ]
163171 library = img ["library" ]
164172 spec_tags = spec_id_to_tags .get (spec_id , {})
165173 impl_tags = impl_lookup .get ((spec_id , library ), {})
166174
167- # Count library
168- counts ["lib" ][library ] = counts ["lib" ].get (library , 0 ) + 1
169-
170- # Count spec ID
171- counts ["spec" ][spec_id ] = counts ["spec" ].get (spec_id , 0 ) + 1
172-
173- # Count spec-level tags
174- for plot_type in spec_tags .get ("plot_type" , []):
175- counts ["plot" ][plot_type ] = counts ["plot" ].get (plot_type , 0 ) + 1
176-
177- for data_type in spec_tags .get ("data_type" , []):
178- counts ["data" ][data_type ] = counts ["data" ].get (data_type , 0 ) + 1
175+ _increment_category_counts (counts , spec_id , library , spec_tags , impl_tags )
179176
180- for domain in spec_tags .get ("domain" , []):
181- counts ["dom" ][domain ] = counts ["dom" ].get (domain , 0 ) + 1
182-
183- for feature in spec_tags .get ("features" , []):
184- counts ["feat" ][feature ] = counts ["feat" ].get (feature , 0 ) + 1
185-
186- # Count impl-level tags (issue #2434)
187- for dep in impl_tags .get ("dependencies" , []):
188- counts ["dep" ][dep ] = counts ["dep" ].get (dep , 0 ) + 1
189- for tech in impl_tags .get ("techniques" , []):
190- counts ["tech" ][tech ] = counts ["tech" ].get (tech , 0 ) + 1
191- for pat in impl_tags .get ("patterns" , []):
192- counts ["pat" ][pat ] = counts ["pat" ].get (pat , 0 ) + 1
193- for prep in impl_tags .get ("dataprep" , []):
194- counts ["prep" ][prep ] = counts ["prep" ].get (prep , 0 ) + 1
195- for style in impl_tags .get ("styling" , []):
196- counts ["style" ][style ] = counts ["style" ].get (style , 0 ) + 1
197-
198- # Sort counts
199- for category in counts :
200- counts [category ] = dict (sorted (counts [category ].items (), key = lambda x : (- x [1 ], x [0 ])))
201-
202- return counts
177+ return _sort_counts (counts )
203178
204179
205180def _calculate_or_counts (
@@ -240,38 +215,9 @@ def _calculate_or_counts(
240215 spec_tags = spec_id_to_tags .get (spec_id , {})
241216 impl_tags = impl_lookup .get ((spec_id , library ), {})
242217
243- if category == "lib" :
244- group_counts [library ] = group_counts .get (library , 0 ) + 1
245- elif category == "spec" :
246- group_counts [spec_id ] = group_counts .get (spec_id , 0 ) + 1
247- elif category == "plot" :
248- for v in spec_tags .get ("plot_type" , []):
249- group_counts [v ] = group_counts .get (v , 0 ) + 1
250- elif category == "data" :
251- for v in spec_tags .get ("data_type" , []):
252- group_counts [v ] = group_counts .get (v , 0 ) + 1
253- elif category == "dom" :
254- for v in spec_tags .get ("domain" , []):
255- group_counts [v ] = group_counts .get (v , 0 ) + 1
256- elif category == "feat" :
257- for v in spec_tags .get ("features" , []):
258- group_counts [v ] = group_counts .get (v , 0 ) + 1
259- # Impl-level tag counts (issue #2434)
260- elif category == "dep" :
261- for v in impl_tags .get ("dependencies" , []):
262- group_counts [v ] = group_counts .get (v , 0 ) + 1
263- elif category == "tech" :
264- for v in impl_tags .get ("techniques" , []):
265- group_counts [v ] = group_counts .get (v , 0 ) + 1
266- elif category == "pat" :
267- for v in impl_tags .get ("patterns" , []):
268- group_counts [v ] = group_counts .get (v , 0 ) + 1
269- elif category == "prep" :
270- for v in impl_tags .get ("dataprep" , []):
271- group_counts [v ] = group_counts .get (v , 0 ) + 1
272- elif category == "style" :
273- for v in impl_tags .get ("styling" , []):
274- group_counts [v ] = group_counts .get (v , 0 ) + 1
218+ # Use unified value extractor
219+ for value in _get_category_values (category , spec_id , library , spec_tags , impl_tags ):
220+ group_counts [value ] = group_counts .get (value , 0 ) + 1
275221
276222 # Sort by count descending
277223 group_counts = dict (sorted (group_counts .items (), key = lambda x : (- x [1 ], x [0 ])))
0 commit comments