Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitleaks.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ description = "Empty environment variables with KEY pattern"
regex = '''os\.environ\[".*?KEY"\]\s*=\s*".+"'''

[allowlist]
paths = ["requirements.txt", "tests"]
paths = ["requirements.txt", "tests", "veadk/realtime/client.py", "veadk/realtime/live.py"]
62 changes: 62 additions & 0 deletions tests/config/test_model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import TestCase, mock
from veadk.configs.model_configs import RealtimeModelConfig


class TestRealtimeModelConfig(TestCase):
def test_default_values(self):
"""Test that default values are set correctly"""
config = RealtimeModelConfig()
self.assertEqual(config.name, "doubao_realtime_voice_model")
self.assertEqual(
config.api_base, "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
)

@mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_api_key"})
def test_api_key_from_env(self):
"""Test api_key is retrieved from environment variable"""
config = RealtimeModelConfig()
self.assertEqual(config.api_key, "test_api_key")

@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch(
"veadk.configs.model_configs.get_speech_token", return_value="mocked_token"
)
def test_api_key_from_get_speech_token(self, mock_get_token):
"""Test api_key falls back to get_speech_token when env var is not set"""
config = RealtimeModelConfig()
self.assertEqual(config.api_key, "mocked_token")
mock_get_token.assert_called_once()

@mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": ""})
@mock.patch(
"veadk.configs.model_configs.get_speech_token", return_value="mocked_token"
)
def test_api_key_empty_env_var(self, mock_get_token):
"""Test api_key falls back when env var is empty string"""
config = RealtimeModelConfig()
self.assertEqual(config.api_key, "mocked_token")
mock_get_token.assert_called_once()

def test_api_key_caching(self):
"""Test that api_key is properly cached"""
with mock.patch.dict(os.environ, {"MODEL_REALTIME_API_KEY": "test_key"}):
config = RealtimeModelConfig()
first_call = config.api_key
second_call = config.api_key
self.assertEqual(first_call, second_call)
self.assertEqual(first_call, "test_key")
82 changes: 82 additions & 0 deletions tests/realtime/test_doubao_realtime_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest.mock import patch, MagicMock
from google.genai._api_client import BaseApiClient
from veadk.realtime.client import DoubaoClient, DoubaoAsyncClient
from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class TestDoubaoAsyncClient(unittest.TestCase):
def setUp(self):
self.mock_api_client = MagicMock(spec=BaseApiClient)
self.async_client = DoubaoAsyncClient(self.mock_api_client)

def test_initialization(self):
self.assertIsInstance(self.async_client, DoubaoAsyncClient)
self.assertEqual(self.async_client._api_client, self.mock_api_client)

def test_live_property(self):
from veadk.realtime.live import DoubaoAsyncLive

live_instance = self.async_client.live
self.assertIsInstance(live_instance, DoubaoAsyncLive)
self.assertEqual(live_instance._api_client, self.mock_api_client)


class TestDoubaoClient(unittest.TestCase):
def setUp(self):
self.patcher = patch.dict("os.environ", {}, clear=True)
self.patcher.start()

def tearDown(self):
self.patcher.stop()

def test_initialization_without_google_key(self):
# Test when GOOGLE_API_KEY is not set
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
client = DoubaoClient()
self.assertEqual(os.environ["GOOGLE_API_KEY"], "hack_google_api_key")
self.assertIsNotNone(client._aio)

def test_initialization_with_google_key(self):
# Test when GOOGLE_API_KEY is already set
os.environ["GOOGLE_API_KEY"] = "existing_key"
os.environ["REALTIME_API_KEY"] = "existing_key"
client = DoubaoClient()
self.assertEqual(os.environ["GOOGLE_API_KEY"], "existing_key")
self.assertIsNotNone(client._aio)

@patch(
"veadk.realtime.client.DoubaoAsyncClient", side_effect=Exception("Test error")
)
def test_initialization_failure(self, mock_async_client):
# Test when DoubaoAsyncClient initialization fails
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
client = DoubaoClient()
self.assertIsNone(client._aio)

def test_aio_property(self):
os.environ["REALTIME_API_KEY"] = "hack_google_api_key"
client = DoubaoClient()
aio_client = client.aio
self.assertIsInstance(aio_client, DoubaoAsyncClient)


if __name__ == "__main__":
unittest.main()
119 changes: 119 additions & 0 deletions tests/realtime/test_doubao_realtime_voice_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from google.genai import types
from veadk.realtime.doubao_realtime_voice_llm import DoubaoRealtimeVoice
from google.adk.models.llm_request import LlmRequest
from google.adk.models.base_llm_connection import BaseLlmConnection
from google.genai.types import GenerateContentConfig
import os
from veadk.realtime.client import DoubaoClient
from veadk.realtime.doubao_realtime_voice_llm import (
_AGENT_ENGINE_TELEMETRY_TAG,
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
)


class TestDoubaoRealtimeVoice:
@pytest.fixture
def mock_llm_request(self):
request = MagicMock(spec=LlmRequest)
request.model = "doubao_realtime_voice"
request.config = GenerateContentConfig()
request.config.system_instruction = "Test instruction"
request.config.tools = []
request.live_connect_config = types.LiveConnectConfig(
http_options=types.HttpOptions()
)
return request

def test_supported_models(self):
"""Test supported_models returns correct model patterns"""
models = DoubaoRealtimeVoice.supported_models()
assert isinstance(models, list)
assert len(models) == 2
assert r"doubao_realtime_voice.*" in models
assert r"Doubao_scene_SLM_Doubao_realtime_voice_model.*" in models

def test_api_client_property(self):
"""Test api_client property returns DoubaoClient with correct options"""
model = DoubaoRealtimeVoice()
client = model.api_client
assert isinstance(client, DoubaoClient)
assert client._api_client._http_options.retry_options == model.retry_options

def test_live_api_client_property(self):
"""Test _live_api_client property returns DoubaoClient with correct version"""
model = DoubaoRealtimeVoice()
client = model._live_api_client
assert isinstance(client, DoubaoClient)
assert client._api_client._http_options.api_version == model._live_api_version

def test_tracking_headers_without_env(self):
"""Test _tracking_headers without environment variable"""
model = DoubaoRealtimeVoice()
headers = model._tracking_headers
assert "x-volcengine-api-client" in headers
assert "user-agent" in headers
assert _AGENT_ENGINE_TELEMETRY_TAG not in headers["x-volcengine-api-client"]

@patch.dict(os.environ, {_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME: "test_id"})
def test_tracking_headers_with_env(self):
"""Test _tracking_headers with environment variable set"""
model = DoubaoRealtimeVoice()
headers = model._tracking_headers
assert _AGENT_ENGINE_TELEMETRY_TAG in headers["x-volcengine-api-client"]

@pytest.mark.asyncio
async def test_connect_with_speech_config(self, mock_llm_request):
"""Test connect method with speech config"""
speech_config = types.SpeechConfig()
model = DoubaoRealtimeVoice(speech_config=speech_config)

# 修正异步上下文管理器的 mock 设置
with patch.object(model._live_api_client.aio.live, "connect") as mock_connect:
# 创建模拟的异步上下文管理器
mock_session = AsyncMock()
mock_connect.return_value.__aenter__.return_value = mock_session

async with model.connect(mock_llm_request) as connection:
assert isinstance(connection, BaseLlmConnection)
assert (
mock_llm_request.live_connect_config.speech_config == speech_config
)
mock_connect.assert_called_once_with(
model=mock_llm_request.model,
config=mock_llm_request.live_connect_config,
)

@pytest.mark.asyncio
async def test_connect_without_speech_config(self, mock_llm_request):
"""Test connect method without speech config"""
model = DoubaoRealtimeVoice()

with patch.object(model._live_api_client.aio.live, "connect") as mock_connect:
# 使用AsyncMock模拟会话对象,更贴近真实场景
mock_session = AsyncMock()
mock_connect.return_value.__aenter__.return_value = mock_session

async with model.connect(mock_llm_request) as connection:
assert isinstance(connection, BaseLlmConnection)
# 验证speech_config为None而非检查属性是否存在
assert mock_llm_request.live_connect_config.speech_config is None
mock_connect.assert_called_once_with(
model=mock_llm_request.model,
config=mock_llm_request.live_connect_config,
)
90 changes: 90 additions & 0 deletions tests/realtime/test_doubao_realtime_voice_llm_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest.mock import AsyncMock
from veadk.realtime.doubao_realtime_voice_llm_connection import (
DoubaoRealtimeVoiceLlmConnection,
)
from google.genai import types


@pytest.mark.asyncio
async def test_send_realtime_with_blob():
"""Test sending Blob input."""
# Setup
mock_session = AsyncMock()
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
connection._gemini_session = mock_session

blob_input = types.Blob()

# Execute
await connection.send_realtime(blob_input)

# Verify
mock_session.send_realtime_input.assert_called_once_with(media=blob_input)


@pytest.mark.asyncio
async def test_send_realtime_with_activity_start():
"""Test sending ActivityStart input."""
# Setup
mock_session = AsyncMock()
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
connection._gemini_session = mock_session

activity_start = types.ActivityStart()

# Execute
await connection.send_realtime(activity_start)

# Verify
mock_session.send_realtime_input.assert_called_once_with(
activity_start=activity_start
)


@pytest.mark.asyncio
async def test_send_realtime_with_activity_end():
"""Test sending ActivityEnd input."""
# Setup
mock_session = AsyncMock()
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
connection._gemini_session = mock_session

activity_end = types.ActivityEnd()

# Execute
await connection.send_realtime(activity_end)

# Verify
mock_session.send_realtime_input.assert_called_once_with(activity_end=activity_end)


@pytest.mark.asyncio
async def test_send_realtime_with_unsupported_type():
"""Test sending unsupported input type."""
# Setup
mock_session = AsyncMock()
connection = DoubaoRealtimeVoiceLlmConnection(gemini_session=mock_session)
connection._gemini_session = mock_session

unsupported_input = "unsupported_type"

# Execute & Verify
with pytest.raises(ValueError) as excinfo:
await connection.send_realtime(unsupported_input)

assert "Unsupported input type" in str(excinfo.value)
Loading