|
23 | 23 |
|
24 | 24 | from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput |
25 | 25 | from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme |
26 | | -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk |
| 26 | +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput |
27 | 27 | from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver |
28 | 28 |
|
29 | 29 | from ....types.content import Messages |
|
35 | 35 | BidirectionalConnectionStartEvent, |
36 | 36 | InterruptionDetectedEvent, |
37 | 37 | TextOutputEvent, |
38 | | - UsageMetricsEvent |
| 38 | + UsageMetricsEvent, |
39 | 39 | ) |
40 | | - |
41 | 40 | from .bidirectional_model import BidirectionalModel, BidirectionalModelSession |
42 | 41 |
|
43 | 42 | logger = logging.getLogger(__name__) |
@@ -81,11 +80,11 @@ class NovaSonicSession(BidirectionalModelSession): |
81 | 80 | interface. |
82 | 81 | """ |
83 | 82 |
|
84 | | - def __init__(self, stream: any, config: dict[str, any]) -> None: |
| 83 | + def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: |
85 | 84 | """Initialize Nova Sonic connection. |
86 | 85 |
|
87 | 86 | Args: |
88 | | - stream: Nova Sonic bidirectional stream. |
| 87 | + stream: Nova Sonic bidirectional stream operation output from AWS SDK. |
89 | 88 | config: Model configuration. |
90 | 89 | """ |
91 | 90 | self.stream = stream |
@@ -487,14 +486,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No |
487 | 486 |
|
488 | 487 | return {"interruptionDetected": interruption} |
489 | 488 |
|
490 | | - # Handle usage events (ignore) |
| 489 | + # Handle usage events - convert to standardized format |
491 | 490 | elif "usageEvent" in nova_event: |
492 | 491 | usage_data = nova_event["usageEvent"] |
493 | 492 | usage_metrics: UsageMetricsEvent = { |
494 | | - "totalTokens": usage_data.get("totalTokens"), |
495 | | - "inputTokens": usage_data.get("totalInputTokens"), |
496 | | - "outputTokens": usage_data.get("totalOutputTokens"), |
497 | | - "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens"), |
| 493 | + "totalTokens": usage_data.get("totalTokens", 0), |
| 494 | + "inputTokens": usage_data.get("totalInputTokens", 0), |
| 495 | + "outputTokens": usage_data.get("totalOutputTokens", 0), |
| 496 | + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) |
498 | 497 | } |
499 | 498 | return {"usageMetrics": usage_metrics} |
500 | 499 |
|
|
0 commit comments