Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added missing `gen_ai.response.id` attribute to span and event.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from typing import (
Any,
Dict,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Union,
)

Expand All @@ -24,41 +20,13 @@
FlattenedDict = Dict[str, FlattenedValue]


class FlattenFunc(Protocol):
def __call__(
self,
key: str,
value: Any,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, "FlattenFunc"],
**kwargs: Any,
) -> Any:
return None


_logger = logging.getLogger(__name__)


def _concat_key(prefix: Optional[str], suffix: str):
if not prefix:
return suffix
return f"{prefix}.{suffix}"


def _is_primitive(v):
for t in [str, bool, int, float]:
if isinstance(v, t):
return True
return False


def _is_homogenous_primitive_list(v):
if not isinstance(v, list):
return False
if len(v) == 0:
return True
if not _is_primitive(v[0]):
if not isinstance(v[0], (str, bool, int, float)):
return False
first_entry_value_type = type(v[0])
for entry in v[1:]:
Expand All @@ -67,49 +35,10 @@ def _is_homogenous_primitive_list(v):
return True


def _get_flatten_func(
flatten_functions: Dict[str, FlattenFunc], key_names: set[str]
) -> Optional[FlattenFunc]:
for key in key_names:
flatten_func = flatten_functions.get(key)
if flatten_func is not None:
return flatten_func
return None


def _flatten_with_flatten_func(
key: str,
value: Any,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
key_names: Set[str],
) -> Tuple[bool, Any]:
flatten_func = _get_flatten_func(flatten_functions, key_names)
if flatten_func is None:
return False, value
func_output = flatten_func(
key,
value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
if func_output is None:
return True, {}
if _is_primitive(func_output) or _is_homogenous_primitive_list(
func_output
):
return True, {key: func_output}
return False, func_output


def _flatten_compound_value_using_json(
key: str,
value: Any,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
_from_json=False,
) -> FlattenedDict:
if _from_json:
Expand All @@ -126,168 +55,64 @@ def _flatten_compound_value_using_json(
value,
)
return {}
json_value = json.loads(json_string)
return _flatten_value(
key,
json_value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
json.loads(json_string),
exclude_keys,
# Ensure that we don't recurse indefinitely if "json.loads()" somehow returns
# a complex, compound object that does not get handled by the "primitive", "list",
# or "dict" cases. Prevents falling back on the JSON serialization fallback path.
_from_json=True,
True,
)


def _flatten_compound_value( # pylint: disable=too-many-return-statements
def _flatten_compound_value(
key: str,
value: Any,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
key_names: Set[str],
_from_json=False,
) -> FlattenedDict:
fully_flattened_with_flatten_func, value = _flatten_with_flatten_func(
key=key,
value=value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
key_names=key_names,
)
if fully_flattened_with_flatten_func:
return value
if isinstance(value, dict):
return _flatten_dict(
value,
key_prefix=key,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
return flatten_dict(value, key, exclude_keys)
if isinstance(value, list):
if _is_homogenous_primitive_list(value):
return {key: value}
return _flatten_list(
value,
key_prefix=key,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
result = {f"{key}.length": len(value)}
for idx, val in enumerate(value):
result.update(_flatten_value(f"{key}[{idx}]", val, exclude_keys))
return result
if hasattr(value, "model_dump"):
try:
return _flatten_dict(
value.model_dump(),
key_prefix=key,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
return flatten_dict(value.model_dump(), key, exclude_keys)
except TypeError:
return {key: str(value)}
return _flatten_compound_value_using_json(
key,
value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
_from_json=_from_json,
key, value, exclude_keys, _from_json
)


def _flatten_value(
key: str,
value: Any,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
_from_json=False,
) -> FlattenedDict:
if value is None:
return {}
key_names = set([key])
renamed_key = rename_keys.get(key)
if renamed_key is not None:
key_names.add(renamed_key)
key = renamed_key
if key_names & exclude_keys:
if value is None or key in exclude_keys:
return {}
if _is_primitive(value):
if isinstance(value, (str, bool, int, float)):
return {key: value}
return _flatten_compound_value(
key=key,
value=value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
key_names=key_names,
_from_json=_from_json,
)
return _flatten_compound_value(key, value, exclude_keys, _from_json)


def _flatten_dict(
def flatten_dict(
d: Dict[str, Any],
key_prefix: str,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
) -> FlattenedDict:
result = {}
for key, value in d.items():
if key in exclude_keys:
continue
full_key = _concat_key(key_prefix, key)
flattened = _flatten_value(
full_key,
value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
result.update(flattened)
return result


def _flatten_list(
lst: list[Any],
key_prefix: str,
exclude_keys: Set[str],
rename_keys: Dict[str, str],
flatten_functions: Dict[str, FlattenFunc],
) -> FlattenedDict:
result = {}
result[_concat_key(key_prefix, "length")] = len(lst)
for index, value in enumerate(lst):
full_key = f"{key_prefix}[{index}]"
flattened = _flatten_value(
full_key,
value,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
result.update(flattened)
if key not in exclude_keys:
result.update(
_flatten_value(f"{key_prefix}.{key}", value, exclude_keys)
)
return result


def flatten_dict(
d: Dict[str, Any],
key_prefix: Optional[str] = None,
exclude_keys: Optional[Sequence[str]] = None,
rename_keys: Optional[Dict[str, str]] = None,
flatten_functions: Optional[Dict[str, FlattenFunc]] = None,
):
key_prefix = key_prefix or ""
exclude_keys = set(exclude_keys or [])
rename_keys = rename_keys or {}
flatten_functions = flatten_functions or {}
return _flatten_dict(
d,
key_prefix=key_prefix,
exclude_keys=exclude_keys,
rename_keys=rename_keys,
flatten_functions=flatten_functions,
)
Loading