11import json
2- import warnings
3- from typing import Any , Dict , List , Tuple , Union
2+ from typing import Any , Callable , Dict , List , Set , Tuple , Union
43
54import click
65
76from .content_loader import ContentLoader
7+ from .logger import Logger
88from .models .openapi_models import (
99 HttpHeader ,
1010 HttpParameter ,
@@ -32,9 +32,19 @@ def __init__(self, openapi_input: str, tag: str = "", operation_id: str = ""):
3232 self .tag = tag
3333 self .operation_id = operation_id
3434 self .openapi_input = openapi_input
35+
3536 self .visited_refs = (
3637 set ()
3738 ) # Track visited references to prevent infinite recursion
39+ self .json_schema_key_words = ["$ref" ] # , "allOf", "oneOf", "anyOf", "not"]
40+ self .openapi_schema_groupings = [
41+ "allOf" ,
42+ "oneOf" ,
43+ "anyOf" ,
44+ "not" ,
45+ "properties" ,
46+ "items" ,
47+ ]
3848
3949 def parse (self ) -> OpenAPIMetadata :
4050 """
@@ -71,13 +81,12 @@ def _get_operations_tags_http_params_from_paths(
7181 self ,
7282 ) -> Tuple [List [OpenAPIOperation ], List [str ], List [HttpParameter ]]:
7383 operations = []
84+ unique_operation_ids = set ()
7485 tags = set ()
7586 http_params = []
7687 for path , methods in self .paths .items ():
7788 for method , details in methods .items ():
7889 # if statements are for filtering parser for easier debugging
79- if "operationId" not in details :
80- continue
8190 if self .operation_id != "" and self .operation_id != details .get (
8291 "operationId" , ""
8392 ):
@@ -86,9 +95,10 @@ def _get_operations_tags_http_params_from_paths(
8695 continue
8796
8897 operation : OpenAPIOperation = self ._parse_operation (
89- path , method , details
98+ path , method , details , unique_operation_ids
9099 )
91- tags .add (operation .tag )
100+ if operation .tag != "" :
101+ tags .add (operation .tag )
92102 for http_param in operation .parameters :
93103 self ._add_unique_http_param (
94104 http_params , http_param .model_dump (by_alias = True )
@@ -102,7 +112,7 @@ def _merge_path_and_spec_tags(self, tags_from_paths: List[str]) -> List[OpenAPIT
102112 spec_tag_names : List [OpenAPITag ] = []
103113 for tag in openapi_spec_tags :
104114 if "name" not in tag :
105- warnings .warn (f"Tag missing name: { tag } " )
115+ Logger .warn (f"Tag missing name: { tag } " )
106116 else :
107117 spec_tag_names .append (tag .get ("name" , "" ))
108118 for tag_name in tags_from_paths :
@@ -121,22 +131,58 @@ def _add_unique_http_param(
121131 headers_list .append (new_header )
122132
def _parse_operation(
    self,
    path: str,
    method: str,
    details: Dict[str, Any],
    unique_operation_ids: Set[str],
) -> OpenAPIOperation:
    """
    Extract relevant details for an API operation.

    :param path: The path key this operation is declared under.
    :param method: The HTTP method key as written in the spec; it is
        upper-cased on the returned operation.
    :param details: The operation object found at ``paths[path][method]``.
    :param unique_operation_ids: Operation ids assigned so far. It is handed
        to ``_parse_operation_id``, which records the id chosen here.
    :return: An ``OpenAPIOperation`` built from ``details``.
    """
    # Only the first tag is used; the rest are ignored. ``or [""]`` covers a
    # missing key as well as ``tags: []`` / ``tags: null``, both of which
    # would otherwise raise when indexed with [0].
    tags = details.get("tags") or [""]
    tag = tags[0]

    # The id is resolved before upper-casing so that a generated fallback id
    # keeps the method exactly as it is spelled in the spec.
    operation_id = self._parse_operation_id(
        path, method, details, unique_operation_ids
    )

    return OpenAPIOperation(
        tag=tag,
        operation_id=operation_id,
        method=method.upper(),
        path=path,
        summary=details.get("summary", ""),
        description=details.get("description", ""),
        parameters=self._parse_parameters(details.get("parameters", [])),
        request_body=self._parse_request_body(details.get("requestBody", {})),
    )
139160
161+ def _parse_operation_id (
162+ self ,
163+ path : str ,
164+ method : str ,
165+ details : Dict [str , Any ],
166+ unique_operation_ids : Set [str ],
167+ ) -> str :
168+ """
169+ Extract the operationId from the details.
170+ """
171+ # Handle case where no operationId exists
172+ operation_id = details .get ("operationId" , "" )
173+ if not operation_id :
174+ Logger .warn (
175+ f"No operationId found for { path } { method } . Using: { method } _{ path } "
176+ )
177+ operation_id = f"{ method } _{ path } "
178+ if operation_id in unique_operation_ids :
179+ Logger .warn (
180+ f"Duplicate operationId found: { operation_id } . Making unique by adding _duplicate"
181+ )
182+ operation_id = f"{ operation_id } _duplicate"
183+ unique_operation_ids .add (operation_id )
184+ return operation_id
185+
140186 def _resolve_param_ref (self , param : str ) -> Dict [str , Any ]:
141187 """
142188 Resolve a reference to a component in the OpenAPI spec.
@@ -190,15 +236,18 @@ def _parse_request_body(
190236 json_schema = content .get ("application/json" , {}).get ("schema" , {})
191237 return self ._schema_metadata (json_schema )
192238
193- def _schema_metadata (self , schema : Dict [str , Any ]) -> SchemaMetadata :
239+ def _schema_metadata (
240+ self , schema : Dict [str , Any ], count : int = 0
241+ ) -> SchemaMetadata :
194242 """
195243 Extract relevant metadata from a given schema.
196244 """
197245 required = schema .get ("required" )
198246 nullable = schema .get ("nullable" )
199247 json_schema_type = self ._resolve_type (schema )
200248 nested_json_schema_refs = self ._extract_refs (schema )
201- nested_json_schemas = self ._resolve_nested_types (schema )
249+
250+ nested_json_schemas = self ._resolve_nested_types (schema , count = count + 1 )
202251 type_is_schema = len (nested_json_schema_refs ) > 0
203252
204253 return SchemaMetadata (
@@ -240,15 +289,16 @@ def _extract_refs(self, schema: Dict[str, Any]) -> List[str]:
240289 refs .append (ref_name )
241290 if ref_name in self .schemas :
242291 refs .extend (self ._extract_refs (self .schemas [ref_name ]))
243- for key in ["allOf" , "oneOf" , "anyOf" , "not" , "properties" , "items" ]:
244- if key in schema :
245- for sub_schema in (
246- schema [key ] if isinstance (schema [key ], list ) else [schema [key ]]
247- ):
248- refs .extend (self ._extract_refs (sub_schema ))
292+ else :
293+ Logger .warn (f"Schema { ref_name } not found in schemas" )
294+
295+ self .loop_over_schema_groupings (schema , refs , self ._extract_refs )
296+
249297 return refs
250298
251- def _resolve_nested_types (self , schema : Dict [str , Any ]) -> List [Dict [str , Any ]]:
299+ def _resolve_nested_types (
300+ self , schema : Dict [str , Any ], count : int = 0
301+ ) -> List [Dict [str , Any ]]:
252302 """
253303 Recursively resolve nested types within a schema, including all properties and nested properties.
254304 Handles $ref, allOf, oneOf, anyOf, not, and properties within objects and arrays.
@@ -259,54 +309,121 @@ def _resolve_nested_types(self, schema: Dict[str, Any]) -> List[Dict[str, Any]]:
259309 Returns:
260310 List of resolved nested type schemas
261311 """
312+
262313 nested_types = []
263314 if not schema : # Handle None or empty schema
264315 return nested_types
265316
266317 if "type" in schema :
267- self ._traverse_dict (schema )
268- nested_types .append (schema )
318+ if schema ["type" ] in ["object" , "array" ]:
319+ self ._traverse_dict (schema , count = count + 1 )
320+ nested_types .append (schema )
321+ else :
322+ # type in schema is not object or array
323+ Logger .warn (f"Schema has type { schema ['type' ]} " )
269324
270325 if "$ref" in schema :
271326 ref_name = schema ["$ref" ].split ("/" )[- 1 ]
272- if ref_name in self .schemas and ref_name not in self . visited_refs :
273- self . visited_refs . add ( ref_name ) # Mark as visited
274- nested_types . extend ( self ._resolve_nested_types ( self . schemas [ ref_name ]))
327+ if ref_name in self .schemas :
328+ if ref_name not in self . visited_refs :
329+ self .visited_refs . add ( ref_name ) # Mark as visited
275330
276- for key in ["allOf" , "oneOf" , "anyOf" , "not" ]:
277- if key in schema :
278- for sub_schema in (
279- schema [key ] if isinstance (schema [key ], list ) else [schema [key ]]
280- ):
281- nested_types .extend (self ._resolve_nested_types (sub_schema ))
331+ nested_types .extend (
332+ self ._resolve_nested_types (
333+ self .schemas [ref_name ], count = count + 1
334+ )
335+ )
336+ else :
337+ Logger .warn (f"Schema { ref_name } has already been visited" )
338+ else :
339+ Logger .warn (f"Schema { ref_name } not found in schemas" )
340+
341+ self .loop_over_schema_groupings (
342+ schema ,
343+ nested_types ,
344+ self ._resolve_nested_types ,
345+ )
282346
283347 return nested_types
284348
def loop_over_schema_groupings(
    self,
    schema: Dict[str, Any],
    refs: List[Any],
    schema_type_handler: Callable[[Dict[str, Any]], List[Any]],
):
    """
    Apply ``schema_type_handler`` to every sub-schema held under a grouping key.

    For each key in ``self.openapi_schema_groupings`` that is present in
    ``schema``, the value (a single sub-schema, or a list of them) is passed
    to ``schema_type_handler`` and the list it returns is appended to ``refs``.

    :param schema: The schema whose grouping keys are inspected.
    :param refs: Accumulator list; extended in place with the handler results.
    :param schema_type_handler: Called with one sub-schema dict at a time and
        expected to return a list.
    """
    for key in self.openapi_schema_groupings:
        if key not in schema:
            continue

        value_at_key = schema[key]
        sub_schemas = value_at_key if isinstance(value_at_key, list) else [value_at_key]
        for sub_schema in sub_schemas:
            # Non-dict values (e.g. a boolean schema such as ``items: true``)
            # carry nothing to resolve and would break handlers that test
            # for keys with ``in``; skip them.
            if isinstance(sub_schema, dict):
                refs.extend(schema_type_handler(sub_schema))
363+
285364 def _traverse_dict (
286365 self ,
287366 d : Dict [str , Any ],
288- key : Union [str , int ] = None ,
367+ parent_key : Union [str , int ] = None ,
289368 parent : Union [Dict [str , Any ], List [Any ]] = None ,
369+ count : int = 0 ,
290370 ):
291371 """
292372 Traverses a dictionary, resolving any '$ref' values.
293373 :param d: The dictionary to traverse.
294374 :param resolve_ref: A function that resolves '$ref' values.
295375 """
296- for k , value in d .items ():
297- if k in ["$ref" , "allOf" , "oneOf" , "anyOf" , "not" ]:
298- schema_metadata = self ._schema_metadata (d )
299- parent [key ] = schema_metadata
376+
377+ for d_key , value in d .items ():
378+ if d_key in self .json_schema_key_words :
379+ self ._update_parent_schema_with_schema_metadata (
380+ d = d ,
381+ d_key = d_key ,
382+ parent_key = parent_key ,
383+ parent = parent ,
384+ count = count + 1 ,
385+ )
386+
300387 elif isinstance (value , dict ):
301- self ._traverse_dict (d = value , key = k , parent = d )
388+ self ._traverse_dict (
389+ d = value , parent_key = d_key , parent = d , count = count + 1
390+ )
302391 elif isinstance (value , list ):
303- self ._traverse_array (arr = value , key = k , parent = d )
392+ self ._traverse_array (
393+ arr = value , parent_key = d_key , parent = d , count = count + 1
394+ )
395+
396+ def _update_parent_schema_with_schema_metadata (
397+ self ,
398+ d : Dict [str , Any ],
399+ d_key : Union [str , int ],
400+ parent_key : Union [str , int ],
401+ parent : Dict [str , Any ],
402+ count : int = 0 ,
403+ ):
404+ num_of_keys = len (d .keys ())
405+ schema : Dict [str , Any ] = {}
406+ schema_parent : Dict [str , Any ] = parent
407+
408+ # If the dictionary has more than one key, we need to create a new dictionary with ONLY the key
409+ if num_of_keys > 1 :
410+ # schema_parent = d
411+ schema [d_key ] = d [d_key ]
412+
413+ else :
414+ # schema_parent = parent
415+ schema = d
416+
417+ schema_metadata = self ._schema_metadata (schema , count = count + 1 )
418+
419+ schema_parent [parent_key ] = schema_metadata
304420
305421 def _traverse_array (
306422 self ,
307423 arr : List [Any ],
308- key : Union [str , int ] = None ,
424+ parent_key : Union [str , int ] = None ,
309425 parent : Union [Dict [str , Any ], List [Any ]] = None ,
426+ count : int = 0 ,
310427 ):
311428 """
312429 Traverses an array (list), resolving any '$ref' values inside the array.
@@ -315,9 +432,11 @@ def _traverse_array(
315432 """
316433 for i , item in enumerate (arr ):
317434 if isinstance (item , dict ):
318- self ._traverse_dict (d = item , key = i , parent = arr )
435+ self ._traverse_dict (d = item , parent_key = i , parent = arr , count = count + 1 )
319436 elif isinstance (item , list ):
320- self ._traverse_array (arr = item , key = i , parent = arr )
437+ self ._traverse_array (
438+ arr = item , parent_key = i , parent = arr , count = count + 1
439+ )
321440
322441
323442@click .command ()
0 commit comments