Skip to content

Commit 1ca29c6

Browse files
authored
Add integration tests for Transcribe Streaming (#38)
1 parent fe0c6af commit 1ca29c6

5 files changed

Lines changed: 354 additions & 0 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from pathlib import Path
5+
6+
from smithy_aws_core.identity import EnvironmentCredentialsResolver
7+
8+
from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient
9+
from aws_sdk_transcribe_streaming.config import Config
10+
11+
AUDIO_FILE = Path(__file__).parent / "assets" / "test.wav"
12+
13+
14+
def create_transcribe_client(region: str) -> TranscribeStreamingClient:
15+
"""Helper to create a TranscribeStreamingClient for a given region."""
16+
return TranscribeStreamingClient(
17+
config=Config(
18+
endpoint_uri=f"https://transcribestreaming.{region}.amazonaws.com",
19+
region=region,
20+
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
21+
)
22+
)
Binary file not shown.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Test bidirectional event stream handling."""
5+
6+
import asyncio
7+
import time
8+
9+
from smithy_core.aio.eventstream import DuplexEventStream
10+
11+
from aws_sdk_transcribe_streaming.models import (
12+
AudioEvent,
13+
AudioStream,
14+
AudioStreamAudioEvent,
15+
LanguageCode,
16+
MediaEncoding,
17+
StartStreamTranscriptionInput,
18+
StartStreamTranscriptionOutput,
19+
TranscriptResultStream,
20+
TranscriptResultStreamTranscriptEvent,
21+
)
22+
23+
from . import AUDIO_FILE, create_transcribe_client
24+
25+
26+
SAMPLE_RATE = 16000
27+
BYTES_PER_SAMPLE = 2
28+
CHANNEL_NUMS = 1
29+
CHUNK_SIZE = 1024 * 8
30+
31+
32+
async def _send_audio_chunks(
33+
stream: DuplexEventStream[
34+
AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
35+
],
36+
) -> None:
37+
"""Send audio chunks from file simulating real-time delay."""
38+
start_time = time.time()
39+
elapsed_audio_time = 0.0
40+
41+
with AUDIO_FILE.open("rb") as f:
42+
while chunk := f.read(CHUNK_SIZE):
43+
await stream.input_stream.send(
44+
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=chunk))
45+
)
46+
elapsed_audio_time += len(chunk) / (
47+
BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS
48+
)
49+
wait_time = start_time + elapsed_audio_time - time.time()
50+
await asyncio.sleep(wait_time)
51+
52+
# Send an empty audio event to signal end of input
53+
await stream.input_stream.send(
54+
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
55+
)
56+
await asyncio.sleep(0.4)
57+
await stream.input_stream.close()
58+
59+
60+
async def _receive_transcription_output(
61+
stream: DuplexEventStream[
62+
AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput
63+
],
64+
) -> tuple[bool, list[str]]:
65+
"""Receive and collect transcription output from the stream.
66+
67+
Returns:
68+
Tuple of (got_transcript_events, transcripts)
69+
"""
70+
got_transcript_events = False
71+
transcripts: list[str] = []
72+
73+
_, output_stream = await stream.await_output()
74+
if output_stream is None:
75+
return got_transcript_events, transcripts
76+
77+
async for event in output_stream:
78+
if not isinstance(event, TranscriptResultStreamTranscriptEvent):
79+
raise RuntimeError(
80+
f"Received unexpected event type in stream: {type(event).__name__}"
81+
)
82+
83+
got_transcript_events = True
84+
if event.value.transcript and event.value.transcript.results:
85+
for result in event.value.transcript.results:
86+
if result.alternatives:
87+
for alt in result.alternatives:
88+
if alt.transcript:
89+
transcripts.append(alt.transcript)
90+
91+
return got_transcript_events, transcripts
92+
93+
94+
async def test_start_stream_transcription() -> None:
95+
"""Test bidirectional streaming with audio input and transcription output."""
96+
transcribe_client = create_transcribe_client("us-west-2")
97+
98+
stream = await transcribe_client.start_stream_transcription(
99+
input=StartStreamTranscriptionInput(
100+
language_code=LanguageCode.EN_US,
101+
media_sample_rate_hertz=SAMPLE_RATE,
102+
media_encoding=MediaEncoding.PCM,
103+
)
104+
)
105+
106+
results = await asyncio.gather(
107+
_send_audio_chunks(stream), _receive_transcription_output(stream)
108+
)
109+
got_transcript_events, transcripts = results[1]
110+
111+
assert got_transcript_events, "Expected to receive transcript events"
112+
assert len(transcripts) > 0, "Expected to receive at least one transcript"
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Test non-streaming output type handling.
5+
6+
This test requires AWS resources (an IAM role and an S3 bucket).
7+
To set them up locally, run:
8+
9+
uv run scripts/setup_resources.py
10+
11+
Then export the environment variables shown in the output.
12+
"""
13+
14+
import asyncio
15+
import os
16+
import time
17+
import uuid
18+
19+
import pytest
20+
21+
from aws_sdk_transcribe_streaming.models import (
22+
ClinicalNoteGenerationSettings,
23+
GetMedicalScribeStreamInput,
24+
GetMedicalScribeStreamOutput,
25+
LanguageCode,
26+
MedicalScribeAudioEvent,
27+
MedicalScribeConfigurationEvent,
28+
MedicalScribeInputStreamAudioEvent,
29+
MedicalScribeInputStreamConfigurationEvent,
30+
MedicalScribeInputStreamSessionControlEvent,
31+
MedicalScribePostStreamAnalyticsSettings,
32+
MedicalScribeSessionControlEvent,
33+
MedicalScribeSessionControlEventType,
34+
MediaEncoding,
35+
StartMedicalScribeStreamInput,
36+
)
37+
38+
from . import AUDIO_FILE, create_transcribe_client
39+
40+
SAMPLE_RATE = 16000
41+
BYTES_PER_SAMPLE = 2
42+
CHANNEL_NUMS = 1
43+
CHUNK_SIZE = 1024 * 8
44+
45+
46+
async def test_get_medical_scribe_stream() -> None:
47+
role_arn = os.environ.get("HEALTHSCRIBE_ROLE_ARN")
48+
s3_bucket = os.environ.get("HEALTHSCRIBE_S3_BUCKET")
49+
50+
if not role_arn or not s3_bucket:
51+
pytest.fail("HEALTHSCRIBE_ROLE_ARN or HEALTHSCRIBE_S3_BUCKET not set")
52+
53+
transcribe_client = create_transcribe_client("us-east-1")
54+
session_id = str(uuid.uuid4())
55+
56+
stream = await transcribe_client.start_medical_scribe_stream(
57+
input=StartMedicalScribeStreamInput(
58+
language_code=LanguageCode.EN_US,
59+
media_sample_rate_hertz=SAMPLE_RATE,
60+
media_encoding=MediaEncoding.PCM,
61+
session_id=session_id,
62+
)
63+
)
64+
65+
await stream.input_stream.send(
66+
MedicalScribeInputStreamConfigurationEvent(
67+
value=MedicalScribeConfigurationEvent(
68+
resource_access_role_arn=role_arn,
69+
post_stream_analytics_settings=MedicalScribePostStreamAnalyticsSettings(
70+
clinical_note_generation_settings=ClinicalNoteGenerationSettings(
71+
output_bucket_name=s3_bucket
72+
)
73+
),
74+
)
75+
)
76+
)
77+
78+
start_time = time.time()
79+
elapsed_audio_time = 0.0
80+
81+
with AUDIO_FILE.open("rb") as f:
82+
while chunk := f.read(CHUNK_SIZE):
83+
await stream.input_stream.send(
84+
MedicalScribeInputStreamAudioEvent(
85+
value=MedicalScribeAudioEvent(audio_chunk=chunk)
86+
)
87+
)
88+
elapsed_audio_time += len(chunk) / (
89+
BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS
90+
)
91+
wait_time = start_time + elapsed_audio_time - time.time()
92+
if wait_time > 0:
93+
await asyncio.sleep(wait_time)
94+
95+
await stream.input_stream.send(
96+
MedicalScribeInputStreamSessionControlEvent(
97+
value=MedicalScribeSessionControlEvent(
98+
type=MedicalScribeSessionControlEventType.END_OF_SESSION
99+
)
100+
)
101+
)
102+
await stream.input_stream.close()
103+
104+
await stream.await_output()
105+
106+
# Consume output stream events to properly close the connection
107+
if stream.output_stream:
108+
async for _ in stream.output_stream:
109+
pass
110+
111+
response = await transcribe_client.get_medical_scribe_stream(
112+
input=GetMedicalScribeStreamInput(session_id=session_id)
113+
)
114+
115+
assert isinstance(response, GetMedicalScribeStreamOutput)
116+
assert response.medical_scribe_stream_details is not None
117+
118+
details = response.medical_scribe_stream_details
119+
assert details.session_id == session_id
120+
assert details.stream_status == "COMPLETED"
121+
assert details.language_code == "en-US"
122+
assert details.media_encoding == "pcm"
123+
assert details.media_sample_rate_hertz == SAMPLE_RATE
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# /// script
2+
# requires-python = ">=3.12"
3+
# dependencies = [
4+
# "boto3",
5+
# ]
6+
# ///
7+
#
8+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
9+
# SPDX-License-Identifier: Apache-2.0
10+
11+
"""Setup script to create AWS resources for integration tests.
12+
13+
Creates an IAM role and S3 bucket needed for medical scribe integration tests.
14+
15+
Note:
16+
This script is intended for local testing only and should not be used for
17+
production setups.
18+
19+
Usage:
20+
uv run scripts/setup_resources.py
21+
"""
22+
23+
import json
24+
from typing import Any
25+
26+
import boto3
27+
28+
29+
def create_iam_role(iam_client: Any, role_name: str, bucket_name: str) -> None:
30+
trust_policy = {
31+
"Version": "2012-10-17",
32+
"Statement": [
33+
{
34+
"Effect": "Allow",
35+
"Principal": {
36+
"Service": [
37+
"transcribe.streaming.amazonaws.com"
38+
]
39+
},
40+
"Action": "sts:AssumeRole",
41+
}
42+
]
43+
}
44+
45+
try:
46+
iam_client.create_role(
47+
RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy)
48+
)
49+
except iam_client.exceptions.EntityAlreadyExistsException:
50+
pass
51+
52+
permissions_policy = {
53+
"Version": "2012-10-17",
54+
"Statement": [
55+
{
56+
"Action": [
57+
"s3:PutObject"
58+
],
59+
"Resource": [
60+
f"arn:aws:s3:::{bucket_name}",
61+
f"arn:aws:s3:::{bucket_name}/*",
62+
],
63+
"Effect": "Allow"
64+
}
65+
]
66+
}
67+
68+
iam_client.put_role_policy(
69+
RoleName=role_name,
70+
PolicyName="HealthScribeS3Access",
71+
PolicyDocument=json.dumps(permissions_policy),
72+
)
73+
74+
75+
def setup_healthscribe_resources() -> tuple[str, str]:
76+
region = "us-east-1"
77+
iam = boto3.client("iam")
78+
s3 = boto3.client("s3", region_name=region)
79+
sts = boto3.client("sts")
80+
81+
account_id = sts.get_caller_identity()["Account"]
82+
bucket_name = f"healthscribe-test-{account_id}-{region}"
83+
role_name = "HealthScribeIntegrationTestRole"
84+
85+
s3.create_bucket(Bucket=bucket_name)
86+
create_iam_role(iam, role_name, bucket_name)
87+
88+
role_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
89+
return role_arn, bucket_name
90+
91+
92+
if __name__ == "__main__":
93+
role_arn, bucket_name = setup_healthscribe_resources()
94+
95+
print("Setup complete. Export these environment variables before running tests:")
96+
print(f"export HEALTHSCRIBE_ROLE_ARN={role_arn}")
97+
print(f"export HEALTHSCRIBE_S3_BUCKET={bucket_name}")

0 commit comments

Comments
 (0)