Skip to content

Commit 829fd94

Browse files
authored
feat(firebaseai): live session resumption (#18038)
* add structure for live session resumption * session resumption config should be session based * init setup * a bit session management fix * some update for session management * refactor live_session connect and resume api * refactor for websocket session open * fix some error while handling the received message * some clean up, and more dev logs * session resume with toggle * some minor updates * fix analyzer * somehow worked resume session * add google search system tool for bidi page * some clean up * format * remove unnecessary logs * minor tweak * fix analyzer * review feedback * fix format * add the sliding window documentation * session resumption api update * address review comments * add more error logs
1 parent ab0b2f9 commit 829fd94

10 files changed

Lines changed: 1068 additions & 515 deletions

File tree

packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart

Lines changed: 609 additions & 450 deletions
Large diffs are not rendered by default.

packages/firebase_ai/firebase_ai/example/lib/pages/chat_page.dart

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class _ChatPageState extends State<ChatPage> {
152152
Icons.send,
153153
color: Theme.of(context).colorScheme.primary,
154154
),
155+
tooltip: 'Send',
155156
),
156157
IconButton(
157158
onPressed: () {
@@ -161,6 +162,7 @@ class _ChatPageState extends State<ChatPage> {
161162
Icons.stream,
162163
color: Theme.of(context).colorScheme.primary,
163164
),
165+
tooltip: 'Send Stream',
164166
),
165167
],
166168
)

packages/firebase_ai/firebase_ai/example/lib/utils/audio_input.dart

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ class AudioInput extends ChangeNotifier {
110110
sampleRate: 24000,
111111
device: selectedDevice,
112112
numChannels: 1,
113-
echoCancel: true,
114-
noiseSuppress: true,
115113
androidConfig: const AndroidRecordConfig(
116114
audioSource: AndroidAudioSource.voiceCommunication,
117115
),

packages/firebase_ai/firebase_ai/lib/firebase_ai.dart

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,19 @@ export 'src/imagen/imagen_reference.dart'
102102
ImagenControlReference;
103103
export 'src/live_api.dart'
104104
show
105-
LiveGenerationConfig,
106-
SpeechConfig,
107105
AudioTranscriptionConfig,
106+
ContextWindowCompressionConfig,
107+
GoingAwayNotice,
108+
LiveGenerationConfig,
108109
LiveServerMessage,
109110
LiveServerContent,
110111
LiveServerToolCall,
111112
LiveServerToolCallCancellation,
112113
LiveServerResponse,
113-
GoingAwayNotice,
114+
SessionResumptionConfig,
115+
SessionResumptionUpdate,
116+
SlidingWindow,
117+
SpeechConfig,
114118
Transcription;
115119
export 'src/live_session.dart' show LiveSession;
116120
export 'src/schema.dart' show JSONSchema, Schema, SchemaType;

packages/firebase_ai/firebase_ai/lib/src/base_model.dart

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,12 @@
1313
// limitations under the License.
1414

1515
import 'dart:async';
16-
import 'dart:convert';
1716

1817
import 'package:firebase_app_check/firebase_app_check.dart';
1918
import 'package:firebase_auth/firebase_auth.dart';
2019
import 'package:firebase_core/firebase_core.dart';
21-
import 'package:flutter/foundation.dart';
2220
import 'package:http/http.dart' as http;
2321
import 'package:meta/meta.dart';
24-
import 'package:web_socket_channel/io.dart';
25-
import 'package:web_socket_channel/web_socket_channel.dart';
2622

2723
import 'api.dart';
2824
import 'client.dart';

packages/firebase_ai/firebase_ai/lib/src/live_api.dart

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,98 @@ class AudioTranscriptionConfig {
7777
Map<String, Object?> toJson() => {};
7878
}
7979

80+
/// Configures the sliding window context compression mechanism.
81+
///
82+
/// The SlidingWindow method operates by discarding content at the beginning of
83+
/// the context window. The resulting context will always begin at the start of
84+
/// a USER role turn. System instructions will always remain at the start of the
85+
/// result.
86+
class SlidingWindow {
87+
/// Creates a [SlidingWindow] instance.
88+
///
89+
/// [targetTokens] (optional): The target number of tokens to keep in the
90+
/// context window.
91+
SlidingWindow({this.targetTokens});
92+
93+
/// The session reduction target, i.e., how many tokens we should keep.
94+
final int? targetTokens;
95+
// ignore: public_member_api_docs
96+
Map<String, Object?> toJson() =>
97+
{if (targetTokens case final targetTokens?) 'targetTokens': targetTokens};
98+
}
99+
100+
/// Enables context window compression to manage the model's context window.
101+
///
102+
/// This mechanism prevents the context from exceeding a given length.
103+
class ContextWindowCompressionConfig {
104+
/// Creates a [ContextWindowCompressionConfig] instance.
105+
///
106+
/// [triggerTokens] (optional): The number of tokens that triggers the
107+
/// compression mechanism.
108+
/// [slidingWindow] (optional): The sliding window compression mechanism to
109+
/// use.
110+
ContextWindowCompressionConfig({this.triggerTokens, this.slidingWindow});
111+
112+
/// The number of tokens (before running a turn) that triggers the context
113+
/// window compression.
114+
final int? triggerTokens;
115+
116+
/// The sliding window compression mechanism.
117+
final SlidingWindow? slidingWindow;
118+
// ignore: public_member_api_docs
119+
Map<String, Object?> toJson() => {
120+
if (triggerTokens case final triggerTokens?)
121+
'triggerTokens': triggerTokens,
122+
if (slidingWindow case final slidingWindow?)
123+
'slidingWindow': slidingWindow.toJson()
124+
};
125+
}
126+
127+
/// Configuration for the session resumption mechanism.
128+
///
129+
/// When included in the session setup, the server will send
130+
/// [SessionResumptionUpdate] messages.
131+
class SessionResumptionConfig {
132+
/// Creates a [SessionResumptionConfig] to start a new resumable session.
133+
///
134+
/// When this is included in the session setup, the server will send
135+
/// [SessionResumptionUpdate] messages with handles that can be used to
136+
/// resume the session later.
137+
SessionResumptionConfig() : handle = null;
138+
139+
/// Creates a [SessionResumptionConfig] to resume a previous session.
140+
///
141+
/// [handle] is the session resumption handle received in a previous session's
142+
/// [SessionResumptionUpdate].
143+
SessionResumptionConfig.resume(String this.handle);
144+
145+
/// The session resumption handle of the previous session to restore.
146+
///
147+
/// If null, a new session will be started (and will be resumable if this
148+
/// config was included).
149+
final String? handle;
150+
151+
// ignore: public_member_api_docs
152+
Map<String, Object?> toJson() => {
153+
if (handle case final handle?) 'handle': handle,
154+
};
155+
}
156+
80157
/// Configures live generation settings.
81158
final class LiveGenerationConfig extends BaseGenerationConfig {
82159
// ignore: public_member_api_docs
83-
LiveGenerationConfig({
84-
this.speechConfig,
85-
this.inputAudioTranscription,
86-
this.outputAudioTranscription,
87-
super.responseModalities,
88-
super.maxOutputTokens,
89-
super.temperature,
90-
super.topP,
91-
super.topK,
92-
super.presencePenalty,
93-
super.frequencyPenalty,
94-
});
160+
LiveGenerationConfig(
161+
{this.speechConfig,
162+
this.inputAudioTranscription,
163+
this.outputAudioTranscription,
164+
this.contextWindowCompression,
165+
super.responseModalities,
166+
super.maxOutputTokens,
167+
super.temperature,
168+
super.topP,
169+
super.topK,
170+
super.presencePenalty,
171+
super.frequencyPenalty});
95172

96173
/// The speech configuration.
97174
final SpeechConfig? speechConfig;
@@ -103,6 +180,9 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
103180
/// the output audio.
104181
final AudioTranscriptionConfig? outputAudioTranscription;
105182

183+
/// The context window compression configuration.
184+
final ContextWindowCompressionConfig? contextWindowCompression;
185+
106186
@override
107187
Map<String, Object?> toJson() => {
108188
...super.toJson(),
@@ -222,6 +302,34 @@ class GoingAwayNotice implements LiveServerMessage {
222302
final String? timeLeft;
223303
}
224304

305+
/// An update of the session resumption state.
306+
///
307+
/// This message is only sent if [SessionResumptionConfig] was set in the
308+
/// session setup.
309+
class SessionResumptionUpdate implements LiveServerMessage {
310+
/// Creates a [SessionResumptionUpdate] instance.
311+
///
312+
/// [newHandle] (optional): The new handle that represents the state that can
313+
/// be resumed.
314+
/// [resumable] (optional): Indicates if the session can be resumed at this
315+
/// point.
316+
/// [lastConsumedClientMessageIndex] (optional): The index of the last client
317+
/// message that is included in the state represented by this update.
318+
SessionResumptionUpdate(
319+
{this.newHandle, this.resumable, this.lastConsumedClientMessageIndex});
320+
321+
/// The new handle that represents the state that can be resumed. Empty if
322+
/// `resumable` is false.
323+
final String? newHandle;
324+
325+
/// Indicates if the session can be resumed at this point.
326+
final bool? resumable;
327+
328+
/// The index of the last client message that is included in the state
329+
/// represented by this update.
330+
final int? lastConsumedClientMessageIndex;
331+
}
332+
225333
/// A single response chunk received during a live content generation.
226334
///
227335
/// It can contain generated content, function calls to be executed, or
@@ -449,8 +557,17 @@ LiveServerMessage _parseServerMessage(Object jsonObject) {
449557
} else if (json.containsKey('setupComplete')) {
450558
return LiveServerSetupComplete();
451559
} else if (json.containsKey('goAway')) {
452-
final goAwayJson = json['goAway'] as Map;
560+
final goAwayJson = json['goAway'] as Map<String, dynamic>;
453561
return GoingAwayNotice(timeLeft: goAwayJson['timeLeft'] as String?);
562+
} else if (json.containsKey('sessionResumptionUpdate')) {
563+
final sessionResumptionUpdateJson =
564+
json['sessionResumptionUpdate'] as Map<String, dynamic>;
565+
return SessionResumptionUpdate(
566+
newHandle: sessionResumptionUpdateJson['newHandle'] as String?,
567+
resumable: sessionResumptionUpdateJson['resumable'] as bool?,
568+
lastConsumedClientMessageIndex:
569+
sessionResumptionUpdateJson['lastConsumedClientMessageIndex'] as int?,
570+
);
454571
} else {
455572
throw unhandledFormat('LiveServerMessage', json);
456573
}

packages/firebase_ai/firebase_ai/lib/src/live_model.dart

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,48 +90,33 @@ final class LiveGenerativeModel extends BaseModel {
9090
///
9191
/// This function handles the WebSocket connection setup and returns an [LiveSession]
9292
/// object that can be used to communicate with the service.
93+
/// [sessionResumption] (optional): The configuration for session resumption,
94+
/// such as the handle to the previous session state to restore.
9395
///
9496
/// Returns a [Future] that resolves to an [LiveSession] object upon successful
9597
/// connection.
96-
Future<LiveSession> connect() async {
98+
Future<LiveSession> connect(
99+
{SessionResumptionConfig? sessionResumption}) async {
97100
final uri = _useVertexBackend ? _vertexAIUri() : _googleAIUri();
98101
final modelString =
99102
_useVertexBackend ? _vertexAIModelString() : _googleAIModelString();
100103

101-
final setupJson = {
102-
'setup': {
103-
'model': modelString,
104-
if (_systemInstruction != null)
105-
'system_instruction': _systemInstruction.toJson(),
106-
if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
107-
if (_liveGenerationConfig != null) ...{
108-
'generation_config': _liveGenerationConfig.toJson(),
109-
if (_liveGenerationConfig.inputAudioTranscription != null)
110-
'input_audio_transcription':
111-
_liveGenerationConfig.inputAudioTranscription!.toJson(),
112-
if (_liveGenerationConfig.outputAudioTranscription != null)
113-
'output_audio_transcription':
114-
_liveGenerationConfig.outputAudioTranscription!.toJson(),
115-
},
116-
}
117-
};
118-
119-
final request = jsonEncode(setupJson);
120104
final headers = await BaseModel.firebaseTokens(
121105
_appCheck,
122106
_auth,
123107
_app,
124108
_useLimitedUseAppCheckTokens,
125109
)();
126110

127-
var ws = kIsWeb
128-
? WebSocketChannel.connect(Uri.parse(uri))
129-
: IOWebSocketChannel.connect(Uri.parse(uri), headers: headers);
130-
await ws.ready;
131-
132-
ws.sink.add(request);
133-
134-
return LiveSession(ws);
111+
return LiveSession.create(
112+
uri: uri,
113+
headers: headers,
114+
modelString: modelString,
115+
systemInstruction: _systemInstruction,
116+
tools: _tools,
117+
sessionResumption: sessionResumption,
118+
liveGenerationConfig: _liveGenerationConfig,
119+
);
135120
}
136121
}
137122

0 commit comments

Comments
 (0)