diff --git a/FirebaseAI/Sources/Extensions/Internal/LanguageModelSession+ModelSession.swift b/FirebaseAI/Sources/Extensions/Internal/LanguageModelSession+ModelSession.swift new file mode 100644 index 00000000000..5e783091737 --- /dev/null +++ b/FirebaseAI/Sources/Extensions/Internal/LanguageModelSession+ModelSession.swift @@ -0,0 +1,174 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) && canImport(FoundationModels) + import Foundation + import FoundationModels + + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + extension FoundationModels.LanguageModelSession: _ModelSession { + public var _hasHistory: Bool { + if transcript.isEmpty { + return false + } + + for entry in transcript { + switch entry { + case .instructions: + continue + case .prompt: + return true + case .toolCalls: + return true + case .toolOutput: + return true + case .response: + return true + @unknown default: + // Unknown entry type, assuming that it is session history. + return true + } + } + + return false + } + + public func _respond(to prompt: [any Part], schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, options: GenerationConfig?) 
async throws + -> _ModelSessionResponse { + let prompt = try prompt.toFoundationModelsPrompt() + + let response: FoundationModels.LanguageModelSession + .Response + if let schema { + response = try await respond( + to: prompt, + schema: schema.generationSchema, + includeSchemaInPrompt: includeSchemaInPrompt + // TODO: Add options: GenerationOptions + ) + } else { + response = try await respond( + to: prompt, + schema: String.generationSchema + // TODO: Add options: GenerationOptions + ) + } + + // TODO: Extract common response handling code into a helper method. + let responseText: String + if schema == nil, case let .string(text) = response.rawContent.kind { + responseText = text + } else { + responseText = response.rawContent.jsonString + } + + let generatedContent = response.rawContent.firebaseGeneratedContent + let modelContent = ModelContent( + role: "model", + parts: [InternalPart(.text(responseText), isThought: false, thoughtSignature: nil)] + ) + let candidate = Candidate( + content: modelContent, + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let rawResponse = GenerateContentResponse( + candidates: [candidate], + modelVersion: FirebaseAI.SystemLanguageModel.modelName + ) + + return _ModelSessionResponse(rawContent: generatedContent, rawResponse: rawResponse) + } + + public func _streamResponse(to prompt: [any Part], + schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?) 
+ -> sending AsyncThrowingStream<_ModelSessionResponse, any Error> { + return AsyncThrowingStream { continuation in + let foundationModelsPrompt: Prompt + do { + foundationModelsPrompt = try prompt.toFoundationModelsPrompt() + } catch { + continuation.finish(throwing: error) + return + } + + let stream: FoundationModels.LanguageModelSession + .ResponseStream + if let schema { + stream = streamResponse( + to: foundationModelsPrompt, + schema: schema.generationSchema, + includeSchemaInPrompt: includeSchemaInPrompt + // TODO: Add options: GenerationOptions + ) + } else { + stream = streamResponse( + to: foundationModelsPrompt, + schema: String.generationSchema + // TODO: Check `includeSchemaInPrompt: includeSchemaInPrompt` behaviour with `String` + // TODO: Add options: GenerationOptions + ) + } + + let task = Task { + do { + for try await snapshot in stream { + // TODO: Extract common response handling code into a helper method. + let responseText: String + if schema == nil, case let .string(text) = snapshot.rawContent.kind { + responseText = text + } else { + responseText = snapshot.rawContent.jsonString + } + + let generatedContent = snapshot.rawContent.firebaseGeneratedContent + let modelContent = ModelContent( + role: "model", + parts: [InternalPart(.text(responseText), isThought: false, thoughtSignature: nil)] + ) + let candidate = Candidate( + content: modelContent, + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let rawResponse = GenerateContentResponse( + candidates: [candidate], + modelVersion: FirebaseAI.SystemLanguageModel.modelName + ) + + let response = _ModelSessionResponse( + rawContent: generatedContent, + rawResponse: rawResponse + ) + + continuation.yield(response) + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + return + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + } +#endif // compiler(>=6.2.3) && canImport(FoundationModels) diff --git 
a/FirebaseAI/Sources/Extensions/Internal/SystemLanguageModel+LanguageModel.swift b/FirebaseAI/Sources/Extensions/Internal/SystemLanguageModel+LanguageModel.swift new file mode 100644 index 00000000000..cf4afb5528a --- /dev/null +++ b/FirebaseAI/Sources/Extensions/Internal/SystemLanguageModel+LanguageModel.swift @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) && canImport(FoundationModels) + import FoundationModels + + extension FirebaseAI.SystemLanguageModel: LanguageModel { + static let modelName = "apple-foundation-models-system-language-model" + + public var _modelName: String { + return FirebaseAI.SystemLanguageModel.modelName + } + + public func _startSession(tools: [any ToolRepresentable]?, + instructions: String?) throws -> any _ModelSession { + switch availability { + case .available: + break + case let .unavailable(reason): + throw GenerativeModelSession.GenerationError.assetsUnavailable( + GenerativeModelSession.GenerationError.Context(debugDescription: """ + The Foundation Models `SystemLanguageModel` is unavailable: \(reason) + """) + ) + } + + #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + var afmTools = [any FoundationModels.Tool]() + // Only function calling tools are supported by Foundation Models. + for tool in tools ?? 
[] { + // Skips any unsupported tools such as `GoogleMaps` or `CodeExecution` since they are + // only + // supported by Gemini models. + // TODO: Decide whether to throw for unsupported `FirebaseAILogic.Tool` types or ignore. + let functionDeclarations = tool.toolRepresentation.functionDeclarations ?? [] + for functionDeclaration in functionDeclarations { + switch functionDeclaration.kind { + case .manual: + // TODO: Decide whether ignore manual function calling declarations, throw or assert. + continue + case let .foundationModels(afmTool): + guard let afmTool = afmTool as? (any FoundationModels.Tool) else { + assertionFailure(""" + The function declaration "\(afmTool)" in the tool "\(tool)" is not a + `FoundationModels.Tool` type. + """) + continue + } + afmTools.append(afmTool) + } + } + } + return LanguageModelSession(tools: afmTools, instructions: instructions) + } + #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + + throw GenerativeModelSession.GenerationError.assetsUnavailable( + GenerativeModelSession.GenerationError.Context(debugDescription: """ + Failed to start a `LanguageModelSession`. The Foundation Models `SystemLanguageModel` is not + available on the current platform. + """) + ) + } + } +#endif // compiler(>=6.2.3) && canImport(FoundationModels) diff --git a/FirebaseAI/Sources/FirebaseAI.swift b/FirebaseAI/Sources/FirebaseAI.swift index 31ad264c45b..65670be4175 100644 --- a/FirebaseAI/Sources/FirebaseAI.swift +++ b/FirebaseAI/Sources/FirebaseAI.swift @@ -108,6 +108,9 @@ public final class FirebaseAI: Sendable { // TODO: Remove the `#if compiler(>=6.2.3)` when Xcode 26.2 is the minimum supported version. #if compiler(>=6.2.3) + + // TODO: Add public API for instantiating models to use with hybrid GenerativeModelSession. + /// Creates a new `GenerativeModelSession` with the given model. /// /// - Important: **Public Preview** - This API is a public preview and may be subject to change. 
@@ -122,14 +125,28 @@ public final class FirebaseAI: Sendable { /// - instructions: System instructions that direct the model's behavior. public func generativeModelSession(model: String, tools: [any ToolRepresentable]? = nil, instructions: String? = nil) -> GenerativeModelSession { - let tools = tools?.map { $0.toolRepresentation } - let model = generativeModel( - modelName: model, - tools: tools, - systemInstruction: instructions.map { ModelContent(role: "system", parts: $0) } + let geminiModel = geminiModel(modelName: model) + + return generativeModelSession(model: geminiModel, tools: tools, instructions: instructions) + } + + // TODO: Update this testing API for hybrid GenerativeModelSession. + func geminiModel(modelName: String, safetySettings: [SafetySetting]? = nil, + toolConfig: ToolConfig? = nil) -> any LanguageModel { + return GeminiModel( + modelName: modelName, + modelResourceName: modelResourceName(modelName: modelName), + firebaseInfo: firebaseInfo, + apiConfig: apiConfig, + safetySettings: safetySettings, + toolConfig: toolConfig ) + } - return GenerativeModelSession(model: model) + // TODO: Update this testing API for hybrid GenerativeModelSession. + func generativeModelSession(model: any LanguageModel, tools: [any ToolRepresentable]? = nil, + instructions: String? = nil) -> GenerativeModelSession { + return GenerativeModelSession(model: model, tools: tools, instructions: instructions) } #if canImport(FoundationModels) diff --git a/FirebaseAI/Sources/GenerateContentResponse.swift b/FirebaseAI/Sources/GenerateContentResponse.swift index 9915acc4287..6435c021b16 100644 --- a/FirebaseAI/Sources/GenerateContentResponse.swift +++ b/FirebaseAI/Sources/GenerateContentResponse.swift @@ -71,6 +71,8 @@ public struct GenerateContentResponse: Sendable { let responseID: String? + let modelVersion: String? + /// The response's content as text, if it exists. /// /// - Note: This does not include thought summaries; see ``thoughtSummary`` for more details. 
@@ -124,6 +126,17 @@ public struct GenerateContentResponse: Sendable { self.promptFeedback = promptFeedback self.usageMetadata = usageMetadata responseID = nil + modelVersion = nil + } + + init(candidates: [Candidate], promptFeedback: PromptFeedback? = nil, + usageMetadata: UsageMetadata? = nil, responseID: String? = nil, + modelVersion: String? = nil) { + self.candidates = candidates + self.promptFeedback = promptFeedback + self.usageMetadata = usageMetadata + self.responseID = responseID + self.modelVersion = modelVersion } func text(isThought: Bool) -> String? { @@ -448,6 +461,7 @@ extension GenerateContentResponse: Decodable { case promptFeedback case usageMetadata case responseID = "responseId" + case modelVersion } public init(from decoder: Decoder) throws { @@ -474,6 +488,7 @@ extension GenerateContentResponse: Decodable { promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback) usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata) responseID = try container.decodeIfPresent(String.self, forKey: .responseID) + modelVersion = try container.decodeIfPresent(String.self, forKey: .modelVersion) } } diff --git a/FirebaseAI/Sources/GenerativeModelSession.swift b/FirebaseAI/Sources/GenerativeModelSession.swift index cf00df7a6c4..ed2a4c402c9 100644 --- a/FirebaseAI/Sources/GenerativeModelSession.swift +++ b/FirebaseAI/Sources/GenerativeModelSession.swift @@ -14,6 +14,7 @@ // TODO: Remove the `#if compiler(>=6.2.3)` when Xcode 26.2 is the minimum supported version. 
#if compiler(>=6.2.3) + private import FirebaseCoreInternal import Foundation #if canImport(FoundationModels) import FoundationModels @@ -51,8 +52,8 @@ /// print("Favorite Topics: \(response.content.favoriteTopics.joined(separator: ", "))") /// ``` public final class GenerativeModelSession: Sendable { - let session: Chat - let functionDeclarations: [String: FunctionDeclaration] + let sessionManager: SessionManager + let instructions: String? // The maximum number of automatic back-and-forth turns the session will perform to resolve // function calls. @@ -66,9 +67,9 @@ /// /// **Public Preview**: This API is a public preview and may be subject to change. /// - Parameter model: The `GenerativeModel` to use for generating content. - init(model: GenerativeModel) { - session = model.startChat() - functionDeclarations = model.functionDeclarationsByName() + init(model: any LanguageModel, tools: [any ToolRepresentable]?, instructions: String?) { + sessionManager = SessionManager(model: model, tools: tools) + self.instructions = instructions } /// Sends a new prompt to the model and returns a `Response` containing the generated content as @@ -245,284 +246,78 @@ generating type: Content.Type, includeSchemaInPrompt: Bool, options: GenerationConfig?) async throws -> GenerativeModelSession.Response { - let parts = [ModelContent(parts: prompt)] - let config = try buildConfig( - options: options, - schema: schema, - includeSchemaInPrompt: includeSchemaInPrompt - ) - - var response = try await session.sendMessage(parts, generationConfig: config) - - var autoFunctionCallTurns = 0 - while !response.functionCalls.isEmpty { - guard autoFunctionCallTurns < GenerativeModelSession.maxAutoFunctionCallTurns else { - throw GenerationError.internalError( - GenerationError.Context( - debugDescription: """ - The model exceeded the maximum allowed automatic function call iterations \ - (\(GenerativeModelSession.maxAutoFunctionCallTurns)). 
- """ - ), - underlyingError: FunctionCallingError.maxFunctionCallTurnsExceeded - ) - } - - let functionResponses = try await execute(functionCalls: response.functionCalls) - - guard !functionResponses.isEmpty else { break } - response = try await session.sendMessage( - [ModelContent(role: "user", parts: functionResponses)], - generationConfig: config - ) - - autoFunctionCallTurns += 1 + try sessionManager.startResponding() + defer { + self.sessionManager.finishResponding() } - let text: String - if let responseText = response.text { - text = responseText - } else if let parts = response.candidates.first?.content.parts, !parts.isEmpty { - text = "" - } else { - throw GenerationError.decodingFailure( - GenerationError.Context(debugDescription: "No parts in response: \(response)") - ) - } - let generationID = response.responseID.map { - #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { - return FirebaseAI.GenerationID(responseID: $0, generationID: GenerationID()) - } - #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - - return FirebaseAI.GenerationID(responseID: $0, generationID: nil) - } + let session = try sessionManager.getOrStartSession( + instructions: instructions + ) - let rawContent = try Self.makeRawContent( - from: text, - generationID: generationID, - hasSchema: schema != nil, - isComplete: true + let response = try await session._respond( + to: prompt.partsValue, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options ) - let content: Content = try Self.resolveContent(from: rawContent) - return GenerativeModelSession.Response( - content: content, rawContent: rawContent, rawResponse: response + return try GenerativeModelSession.Response( + content: Self.resolveContent(from: response.rawContent), + rawContent: response.rawContent, + rawResponse: response.rawResponse ) } @available(macOS 12.0, watchOS 8.0, *) - 
private func streamResponse(to prompt: [PartsRepresentable], + private func streamResponse(to prompt: any PartsRepresentable, schema: FirebaseAI.GenerationSchema?, generating type: Content.Type, includeSchemaInPrompt: Bool, options: GenerationConfig?) -> sending GenerativeModelSession.ResponseStream { - let initialParts = [ModelContent(parts: prompt)] - return GenerativeModelSession.ResponseStream { context in + let parts = prompt.partsValue + return GenerativeModelSession.ResponseStream { context in do { - let config = try self.buildConfig( - options: options, - schema: schema, - includeSchemaInPrompt: includeSchemaInPrompt - ) - - var currentParts = initialParts - var generationID: FirebaseAI.GenerationID? - var autoFunctionCallTurns = 0 - - functionCallingLoop: while true { - let stream = try self.session.sendMessageStream(currentParts, generationConfig: config) - - var streamedText = "" - var functionCalls = [FunctionCallPart]() - - // 1. Create a buffer to hold the previous iteration's data in order to differentiate - // the last chunk to accurately set `isComplete`. - var pendingChunkData: ( - text: String, - id: FirebaseAI.GenerationID?, - response: GenerateContentResponse - )? - - for try await chunk in stream { - functionCalls.append(contentsOf: chunk.functionCalls) - - let text: String - if let responseText = chunk.text { - text = responseText - } else if let parts = chunk.candidates.first?.content.parts, !parts.isEmpty { - text = "" - } else { - throw GenerationError.decodingFailure( - GenerationError.Context(debugDescription: "No parts in response: \(chunk)") - ) - } - - // 2. If we have pending data, we now know it wasn't the last chunk. 
- if let pending = pendingChunkData, - !pending.text.isEmpty || pending.response.thoughtSummary != nil { - let rawContent = try Self.makeRawContent( - from: pending.text, - generationID: pending.id, - hasSchema: schema != nil, - isComplete: false - ) - let rawResult = GenerativeModelSession.ResponseStream - .RawResult( - rawContent: rawContent, - rawResponse: pending.response - ) - await context.yield(rawResult) - } - - // 3. Update our cumulative state for the current chunk - streamedText.append(text) - if generationID == nil { - generationID = chunk.responseID.map { - #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { - return FirebaseAI.GenerationID( - responseID: $0, generationID: FoundationModels.GenerationID() - ) - } - #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - - return FirebaseAI.GenerationID(responseID: $0, generationID: nil) - } - } - - // 4. Save the current state as the new pending chunk. - pendingChunkData = (text: streamedText, id: generationID, response: chunk) - } + try self.sessionManager.startResponding() + } catch { + await context.finish(throwing: error) + return + } + defer { + self.sessionManager.finishResponding() + } - // Stream for the current turn finished. Check if there are function calls to handle. - if !functionCalls.isEmpty { - guard autoFunctionCallTurns < GenerativeModelSession.maxAutoFunctionCallTurns else { - throw GenerationError.internalError( - GenerationError.Context( - debugDescription: """ - The model exceeded the maximum allowed automatic function call iterations \ - (\(GenerativeModelSession.maxAutoFunctionCallTurns)). - """ - ), - underlyingError: FunctionCallingError.maxFunctionCallTurnsExceeded - ) - } - let functionResponses = try await self.execute(functionCalls: functionCalls) - - if !functionResponses.isEmpty { - // Yield any pending text if it's not empty, but mark it as NOT complete yet. 
- if let pending = pendingChunkData, - !pending.text.isEmpty || pending.response.thoughtSummary != nil { - let rawContent = try Self.makeRawContent( - from: pending.text, - generationID: pending.id, - hasSchema: schema != nil, - isComplete: false - ) - let rawResult = GenerativeModelSession.ResponseStream - .RawResult( - rawContent: rawContent, - rawResponse: pending.response - ) - await context.yield(rawResult) - } - - currentParts = [ModelContent(role: "user", parts: functionResponses)] - autoFunctionCallTurns += 1 - continue functionCallingLoop - } - } + do { + let session = try self.sessionManager.getOrStartSession( + instructions: self.instructions + ) - // 5. The remaining pending chunk is the final one. - if let finalChunk = pendingChunkData { - let rawContent = try Self.makeRawContent( - from: finalChunk.text, - generationID: finalChunk.id, - hasSchema: schema != nil, - isComplete: true + let stream = session._streamResponse( + to: parts, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options + ) + for try await response in stream { + let rawResult = GenerativeModelSession.ResponseStream + .RawResult( + rawContent: response.rawContent, + rawResponse: response.rawResponse ) - let rawResult = GenerativeModelSession.ResponseStream - .RawResult( - rawContent: rawContent, - rawResponse: finalChunk.response - ) - await context.yield(rawResult) - } - - break functionCallingLoop + await context.yield(rawResult) } await context.finish() + return } catch { await context.finish(throwing: error) } } } - private func execute(functionCalls: [FunctionCallPart]) async throws -> [FunctionResponsePart] { - var functionResponses = [FunctionResponsePart]() - for functionCall in functionCalls { - guard let functionDeclaration = functionDeclarations[functionCall.name] else { - throw GenerationError.internalError( - GenerationError.Context(debugDescription: """ - No function named "\(functionCall.name)" was declared. 
- """), - underlyingError: FunctionCallingError.invalidFunctionCall - ) - } - - switch functionDeclaration.kind { - case .manual: - continue - case let .foundationModels(tool): - #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { - guard let tool = tool as? (any FoundationModels.Tool) else { - assertionFailure("The value '\(tool)' is not a Foundation Models `Tool`.") - throw TypeConversionError( - from: (any Sendable).self, to: (any FoundationModels.Tool).self - ) - } - try functionResponses.append(await FunctionDeclaration.call( - tool: tool, - functionCall: functionCall - )) - continue - } - #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM - assertionFailure(""" - A Foundation Models `Tool` '\(tool)' was provided but not running on a supported platform. - """) - } - } - - return functionResponses - } - - private func buildConfig(options: GenerationConfig?, - schema: FirebaseAI.GenerationSchema?, - includeSchemaInPrompt: Bool) throws -> GenerationConfig { - var config = GenerationConfig.merge( - session.generationConfig, with: options - ) ?? GenerationConfig() - - if let schema { - config.responseMIMEType = "application/json" - config.responseJSONSchema = includeSchemaInPrompt ? 
try schema.toGeminiJSONSchema() : nil - config.responseSchema = nil // `responseSchema` must not be set with `responseJSONSchema` - } - - config.responseModalities = nil // Override to the default (text only) - config.candidateCount = nil // Override to the default (one candidate) - - return config - } - - private static func makeRawContent(from text: String, generationID: FirebaseAI.GenerationID?, - hasSchema: Bool, isComplete: Bool) throws + static func makeRawContent(from text: String, generationID: FirebaseAI.GenerationID?, + hasSchema: Bool, isComplete: Bool) throws -> FirebaseAI.GeneratedContent { if hasSchema { if text.isEmpty && !isComplete { @@ -555,7 +350,8 @@ ) } - static func resolveContent(from rawContent: FirebaseAI.GeneratedContent) throws -> T { + private static func resolveContent(from rawContent: FirebaseAI.GeneratedContent) throws + -> T { if let content = rawContent as? T { return content } @@ -582,6 +378,62 @@ } } + extension GenerativeModelSession { + final class SessionManager: @unchecked Sendable { + // TODO: Track when sessions have permanent failures. + // TODO: Track and propagate history status (`Transcript`) for `modelSessions`. + + private let model: any LanguageModel + private let tools: [any ToolRepresentable]? + + private let _isResponding = UnfairLock(false) + private(set) var _activeSession: (any _ModelSession)? + + init(model: any LanguageModel, tools: [any ToolRepresentable]?) { + self.model = model + self.tools = tools + } + + var isResponding: Bool { + _isResponding.value() + } + + func startResponding() throws { + try _isResponding.withLock { isResponding in + guard !isResponding else { + throw GenerativeModelSession.GenerationError.concurrentRequests( + GenerativeModelSession.GenerationError.Context(debugDescription: """ + Attempted to start a new generation request while one was already in progress. \ + Create an additional session to perform concurrent requests. 
+ """) + ) + } + + isResponding = true + } + } + + func finishResponding() { + _isResponding.withLock { isResponding in + assert(isResponding, "`finishResponding` called but `isResponding` is false.") + isResponding = false + } + } + + func getOrStartSession(instructions: String?) throws -> any _ModelSession { + try _isResponding.withLock { isResponding in + if let currentSession = _activeSession { + return currentSession + } else { + let newSession = try model._startSession(tools: tools, instructions: instructions) + _activeSession = newSession + return newSession + } + } + } + } + } + // MARK: - Response Types public extension GenerativeModelSession { @@ -624,9 +476,12 @@ let context = StreamContext(continuation: extractedContinuation) self.context = context - Task { + let task = Task { await builder(context) } + extractedContinuation.onTermination = { _ in + task.cancel() + } } /// An iterator that provides snapshots of the model's response. @@ -731,6 +586,11 @@ private var waitingContinuations: [CheckedContinuation] = [] private var latestRaw: RawResult? + // Returns `true` if the stream has yielded one or more values. + var hasYielded: Bool { + return latestRaw != nil + } + init(continuation: AsyncThrowingStream.Continuation) { self.continuation = continuation } @@ -842,9 +702,13 @@ } } + case assetsUnavailable(GenerativeModelSession.GenerationError.Context) + /// The model's response could not be decoded. case decodingFailure(GenerativeModelSession.GenerationError.Context) + case concurrentRequests(GenerativeModelSession.GenerationError.Context) + case internalError(GenerativeModelSession.GenerationError.Context, underlyingError: any Error) } diff --git a/FirebaseAI/Sources/PartsRepresentable.swift b/FirebaseAI/Sources/PartsRepresentable.swift index bdd4a61f000..5cf7d17659b 100644 --- a/FirebaseAI/Sources/PartsRepresentable.swift +++ b/FirebaseAI/Sources/PartsRepresentable.swift @@ -13,6 +13,9 @@ // limitations under the License. 
import Foundation +#if canImport(FoundationModels) + import FoundationModels +#endif // canImport(FoundationModels) /// A protocol describing any data that could be serialized to model-interpretable input data, /// where the serialization process cannot fail with an error. @@ -44,3 +47,14 @@ extension String: PartsRepresentable { return [TextPart(self)] } } + +#if compiler(>=6.2.3) && canImport(FoundationModels) + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + extension PartsRepresentable { + func toFoundationModelsPrompt() throws -> FoundationModels.Prompt { + return try partsValue.toFoundationModelsPrompt() + } + } +#endif // compiler(>=6.2.3) && canImport(FoundationModels) diff --git a/FirebaseAI/Sources/Protocols/Internal/ConvertibleToGeneratedContent.swift b/FirebaseAI/Sources/Protocols/Internal/ConvertibleToGeneratedContent.swift index c14695b52a7..3dde2f3e580 100644 --- a/FirebaseAI/Sources/Protocols/Internal/ConvertibleToGeneratedContent.swift +++ b/FirebaseAI/Sources/Protocols/Internal/ConvertibleToGeneratedContent.swift @@ -33,8 +33,7 @@ @available(iOS 26.0, macOS 26.0, *) @available(tvOS, unavailable) @available(watchOS, unavailable) - extension FirebaseAI.ConvertibleToGeneratedContent - where Self: FoundationModels.ConvertibleToGeneratedContent { + extension FoundationModels.ConvertibleToGeneratedContent { var firebaseGeneratedContent: FirebaseAI.GeneratedContent { return FirebaseAI.GeneratedContent( kind: generatedContent.kind, diff --git a/FirebaseAI/Sources/Protocols/Public/LanguageModel.swift b/FirebaseAI/Sources/Protocols/Public/LanguageModel.swift new file mode 100644 index 00000000000..2ea5604743a --- /dev/null +++ b/FirebaseAI/Sources/Protocols/Public/LanguageModel.swift @@ -0,0 +1,23 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) + public protocol LanguageModel: Sendable { + var _modelName: String { get } + + // TODO: Replace `instructions` with `Transcript` for session history. + func _startSession(tools: [any ToolRepresentable]?, + instructions: String?) throws -> any _ModelSession + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Protocols/Public/ModelSession.swift b/FirebaseAI/Sources/Protocols/Public/ModelSession.swift new file mode 100644 index 00000000000..c290813b4ff --- /dev/null +++ b/FirebaseAI/Sources/Protocols/Public/ModelSession.swift @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) + public protocol _ModelSession: Sendable { + var _hasHistory: Bool { get } + + nonisolated(nonsending) func _respond(to prompt: [any Part], + schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?) 
async throws + -> _ModelSessionResponse + + @available(macOS 12.0, watchOS 8.0, *) + func _streamResponse(to prompt: [any Part], + schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?) + -> sending AsyncThrowingStream<_ModelSessionResponse, any Error> + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Protocols/Public/ModelSessionResponse.swift b/FirebaseAI/Sources/Protocols/Public/ModelSessionResponse.swift new file mode 100644 index 00000000000..e26db73d438 --- /dev/null +++ b/FirebaseAI/Sources/Protocols/Public/ModelSessionResponse.swift @@ -0,0 +1,20 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) + public struct _ModelSessionResponse: Sendable { + let rawContent: FirebaseAI.GeneratedContent + let rawResponse: GenerateContentResponse + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Types/Internal/GeminiModel.swift b/FirebaseAI/Sources/Types/Internal/GeminiModel.swift new file mode 100644 index 00000000000..bde76bca505 --- /dev/null +++ b/FirebaseAI/Sources/Types/Internal/GeminiModel.swift @@ -0,0 +1,67 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +#if compiler(>=6.2.3) + struct GeminiModel: LanguageModel { + let modelName: String + let modelResourceName: String + let firebaseInfo: FirebaseInfo + let apiConfig: APIConfig + let safetySettings: [SafetySetting]? + let toolConfig: ToolConfig? + let requestOptions: RequestOptions + let urlSession: URLSession + + init(modelName: String, + modelResourceName: String, + firebaseInfo: FirebaseInfo, + apiConfig: APIConfig, + safetySettings: [SafetySetting]? = nil, + toolConfig: ToolConfig? = nil, + requestOptions: RequestOptions = RequestOptions(), + urlSession: URLSession = GenAIURLSession.default) { + self.modelName = modelName + self.modelResourceName = modelResourceName + self.firebaseInfo = firebaseInfo + self.apiConfig = apiConfig + self.safetySettings = safetySettings + self.toolConfig = toolConfig + self.requestOptions = requestOptions + self.urlSession = urlSession + } + + var _modelName: String { modelName } + + func _startSession(tools: [any ToolRepresentable]?, + instructions: String?) 
-> any _ModelSession { + let model = GenerativeModel( + modelName: modelName, + modelResourceName: modelResourceName, + firebaseInfo: firebaseInfo, + apiConfig: apiConfig, + generationConfig: nil, + safetySettings: safetySettings, + tools: tools?.map { $0.toolRepresentation }, + toolConfig: toolConfig, + systemInstruction: instructions.map { ModelContent(role: "system", parts: $0) }, + requestOptions: requestOptions, + urlSession: urlSession + ) + + return GeminiModelSession(model: model, history: []) + } + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Types/Internal/GeminiModelSession.swift b/FirebaseAI/Sources/Types/Internal/GeminiModelSession.swift new file mode 100644 index 00000000000..4e299918eae --- /dev/null +++ b/FirebaseAI/Sources/Types/Internal/GeminiModelSession.swift @@ -0,0 +1,318 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +#if canImport(FoundationModels) + import FoundationModels +#endif // canImport(FoundationModels) + +#if compiler(>=6.2.3) + /// An object that represents a back-and-forth chat with a model, capturing the history and saving + /// the context in memory between each message sent. 
+ final class GeminiModelSession: _ModelSession { + let chat: Chat + private let functionDeclarations: [String: FunctionDeclaration] + + init(model: GenerativeModel, history: [ModelContent]) { + chat = model.startChat(history: history) + functionDeclarations = model.functionDeclarationsByName() + } + + // MARK: ModelSession Conformance + + var _hasHistory: Bool { + return !chat.history.isEmpty + } + + nonisolated(nonsending) + func _respond(to prompt: [any Part], schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, options: GenerationConfig?) async throws + -> _ModelSessionResponse { + let parts = [ModelContent(parts: prompt)] + let config = try buildConfig( + options: options, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt + ) + + var response = try await chat.sendMessage(parts, generationConfig: config) + + var autoFunctionCallTurns = 0 + while !response.functionCalls.isEmpty { + guard autoFunctionCallTurns < GenerativeModelSession.maxAutoFunctionCallTurns else { + throw GenerativeModelSession.GenerationError.internalError( + GenerativeModelSession.GenerationError.Context( + debugDescription: """ + The model exceeded the maximum allowed automatic function call iterations \ + (\(GenerativeModelSession.maxAutoFunctionCallTurns)). 
+ """ + ), + underlyingError: GenerativeModelSession.FunctionCallingError + .maxFunctionCallTurnsExceeded + ) + } + + let functionResponses = try await execute(functionCalls: response.functionCalls) + + guard !functionResponses.isEmpty else { break } + response = try await chat.sendMessage( + [ModelContent(role: "user", parts: functionResponses)], + generationConfig: config + ) + + autoFunctionCallTurns += 1 + } + + let text: String + if let responseText = response.text { + text = responseText + } else if let parts = response.candidates.first?.content.parts, !parts.isEmpty { + text = "" + } else { + throw GenerativeModelSession.GenerationError.decodingFailure( + GenerativeModelSession.GenerationError + .Context(debugDescription: "No parts in response: \(response)") + ) + } + let generationID = response.responseID.map { + #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + return FirebaseAI.GenerationID(responseID: $0, generationID: GenerationID()) + } + #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + + return FirebaseAI.GenerationID(responseID: $0, generationID: nil) + } + + let rawContent = try GenerativeModelSession.makeRawContent( + from: text, + generationID: generationID, + hasSchema: schema != nil, + isComplete: true + ) + + return _ModelSessionResponse(rawContent: rawContent, rawResponse: response) + } + + @available(macOS 12.0, watchOS 8.0, *) + func _streamResponse(to prompt: [any Part], + schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?) 
+ -> sending AsyncThrowingStream<_ModelSessionResponse, any Error> { + let initialParts = [ModelContent(parts: prompt)] + return AsyncThrowingStream { continuation in + let task = Task { + do { + let config = try self.buildConfig( + options: options, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt + ) + + var currentParts = initialParts + var generationID: FirebaseAI.GenerationID? + var autoFunctionCallTurns = 0 + + functionCallingLoop: while true { + let stream = try self.chat.sendMessageStream(currentParts, generationConfig: config) + + var streamedText = "" + var functionCalls = [FunctionCallPart]() + + // 1. Create a buffer to hold the previous iteration's data in order to differentiate + // the last chunk to accurately set `isComplete`. + var pendingChunkData: ( + text: String, + id: FirebaseAI.GenerationID?, + response: GenerateContentResponse + )? + + for try await chunk in stream { + functionCalls.append(contentsOf: chunk.functionCalls) + + let text: String + if let responseText = chunk.text { + text = responseText + } else if let parts = chunk.candidates.first?.content.parts, !parts.isEmpty { + text = "" + } else { + throw GenerativeModelSession.GenerationError.decodingFailure( + GenerativeModelSession.GenerationError + .Context(debugDescription: "No parts in response: \(chunk)") + ) + } + + // 2. If we have pending data, we now know it wasn't the last chunk. + if let pending = pendingChunkData, + !pending.text.isEmpty || pending.response.thoughtSummary != nil { + let rawContent = try GenerativeModelSession.makeRawContent( + from: pending.text, + generationID: pending.id, + hasSchema: schema != nil, + isComplete: false + ) + let response = _ModelSessionResponse( + rawContent: rawContent, + rawResponse: pending.response + ) + continuation.yield(response) + } + + // 3. 
Update our cumulative state for the current chunk + streamedText.append(text) + if generationID == nil { + generationID = chunk.responseID.map { + #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + return FirebaseAI.GenerationID( + responseID: $0, generationID: FoundationModels.GenerationID() + ) + } + #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + + return FirebaseAI.GenerationID(responseID: $0, generationID: nil) + } + } + + // 4. Save the current state as the new pending chunk. + pendingChunkData = (text: streamedText, id: generationID, response: chunk) + } + + // Stream for the current turn finished. Check if there are function calls to handle. + if !functionCalls.isEmpty { + guard autoFunctionCallTurns < GenerativeModelSession.maxAutoFunctionCallTurns else { + throw GenerativeModelSession.GenerationError.internalError( + GenerativeModelSession.GenerationError.Context( + debugDescription: """ + The model exceeded the maximum allowed automatic function call iterations \ + (\(GenerativeModelSession.maxAutoFunctionCallTurns)). + """ + ), + underlyingError: GenerativeModelSession.FunctionCallingError + .maxFunctionCallTurnsExceeded + ) + } + let functionResponses = try await self.execute(functionCalls: functionCalls) + + if !functionResponses.isEmpty { + // Yield any pending text if it's not empty, but mark it as NOT complete yet. 
+ if let pending = pendingChunkData, + !pending.text.isEmpty || pending.response.thoughtSummary != nil { + let rawContent = try GenerativeModelSession.makeRawContent( + from: pending.text, + generationID: pending.id, + hasSchema: schema != nil, + isComplete: false + ) + let response = _ModelSessionResponse( + rawContent: rawContent, + rawResponse: pending.response + ) + continuation.yield(response) + } + + currentParts = [ModelContent(role: "user", parts: functionResponses)] + autoFunctionCallTurns += 1 + continue functionCallingLoop + } + } + + // 5. The remaining pending chunk is the final one. + if let finalChunk = pendingChunkData { + let rawContent = try GenerativeModelSession.makeRawContent( + from: finalChunk.text, + generationID: finalChunk.id, + hasSchema: schema != nil, + isComplete: true + ) + let response = _ModelSessionResponse( + rawContent: rawContent, + rawResponse: finalChunk.response + ) + continuation.yield(response) + } + + break functionCallingLoop + } + + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + + private func execute(functionCalls: [FunctionCallPart]) async throws -> [FunctionResponsePart] { + var functionResponses = [FunctionResponsePart]() + for functionCall in functionCalls { + guard let functionDeclaration = functionDeclarations[functionCall.name] else { + throw GenerativeModelSession.GenerationError.internalError( + GenerativeModelSession.GenerationError.Context(debugDescription: """ + No function named "\(functionCall.name)" was declared. + """), + underlyingError: GenerativeModelSession.FunctionCallingError.invalidFunctionCall + ) + } + + switch functionDeclaration.kind { + case .manual: + continue + case let .foundationModels(tool): + #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + guard let tool = tool as? 
(any FoundationModels.Tool) else { + assertionFailure("The value '\(tool)' is not a Foundation Models `Tool`.") + throw GenerativeModelSession.TypeConversionError( + from: (any Sendable).self, to: (any FoundationModels.Tool).self + ) + } + try functionResponses.append(await FunctionDeclaration.call( + tool: tool, + functionCall: functionCall + )) + continue + } + #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + assertionFailure(""" + A Foundation Models `Tool` '\(tool)' was provided but not running on a supported platform. + """) + } + } + + return functionResponses + } + + private func buildConfig(options: GenerationConfig?, + schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool) throws -> GenerationConfig { + var config = GenerationConfig.merge( + chat.generationConfig, with: options + ) ?? GenerationConfig() + + if let schema { + config.responseMIMEType = "application/json" + config.responseJSONSchema = includeSchemaInPrompt ? try schema.toGeminiJSONSchema() : nil + config.responseSchema = nil // `responseSchema` must not be set with `responseJSONSchema` + } + + config.responseModalities = nil // Override to the default (text only) + config.candidateCount = nil // Override to the default (one candidate) + + return config + } + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Types/Internal/HybridModelSession.swift b/FirebaseAI/Sources/Types/Internal/HybridModelSession.swift new file mode 100644 index 00000000000..0054cb0accd --- /dev/null +++ b/FirebaseAI/Sources/Types/Internal/HybridModelSession.swift @@ -0,0 +1,98 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if compiler(>=6.2.3) + final class HybridModelSession: _ModelSession { + private let primary: any _ModelSession + private let secondary: any _ModelSession + + init(primary: any _ModelSession, secondary: any _ModelSession) { + self.primary = primary + self.secondary = secondary + } + + var _hasHistory: Bool { + return primary._hasHistory || secondary._hasHistory + } + + func _respond(to prompt: [any Part], schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?) async throws -> _ModelSessionResponse { + do { + // Try the primary model + return try await primary._respond( + to: prompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options + ) + } catch { + // Do not fall back to other sessions if the current session contains history. + if primary._hasHistory { + throw error + } + + return try await secondary._respond( + to: prompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options + ) + } + } + + @available(macOS 12.0, watchOS 8.0, *) + func _streamResponse(to prompt: [any Part], schema: FirebaseAI.GenerationSchema?, + includeSchemaInPrompt: Bool, + options: GenerationConfig?)
+ -> sending AsyncThrowingStream<_ModelSessionResponse, any Error> { + return AsyncThrowingStream { continuation in + let task = Task { + let stream = primary._streamResponse( + to: prompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options + ) + + do { + for try await snapshot in stream { + continuation.yield(snapshot) + } + continuation.finish() + } catch { + // Do not fall back to other sessions if the current session contains history. + if primary._hasHistory { + continuation.finish(throwing: error) + return + } + + let stream = secondary._streamResponse( + to: prompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: options + ) + + for try await snapshot in stream { + continuation.yield(snapshot) + } + continuation.finish() + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Sources/Types/Public/HybridModel.swift b/FirebaseAI/Sources/Types/Public/HybridModel.swift new file mode 100644 index 00000000000..51f8964da0a --- /dev/null +++ b/FirebaseAI/Sources/Types/Public/HybridModel.swift @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +public enum InferenceMode: Sendable { case preferOnDevice, preferInCloud } + +public struct HybridModel: LanguageModel { + let cloud: any LanguageModel + let onDevice: any LanguageModel + let mode: InferenceMode + + public var _modelName: String { + return "hybrid:\(cloud._modelName),\(onDevice._modelName)" + } + + public func _startSession(tools: [any ToolRepresentable]?, + instructions: String?) throws -> any _ModelSession { + let cloudSession = try cloud._startSession(tools: tools, instructions: instructions) + let deviceSession = try onDevice._startSession(tools: tools, instructions: instructions) + + switch mode { + case .preferOnDevice: + return HybridModelSession(primary: deviceSession, secondary: cloudSession) + case .preferInCloud: + return HybridModelSession(primary: cloudSession, secondary: deviceSession) + } + } +} diff --git a/FirebaseAI/Sources/Types/Public/Part.swift b/FirebaseAI/Sources/Types/Public/Part.swift index df492c7f75d..c92c6298a8c 100644 --- a/FirebaseAI/Sources/Types/Public/Part.swift +++ b/FirebaseAI/Sources/Types/Public/Part.swift @@ -13,6 +13,9 @@ // limitations under the License. import Foundation +#if canImport(FoundationModels) + import FoundationModels +#endif // canImport(FoundationModels) /// A discrete piece of data in a media format interpretable by an AI model. /// @@ -346,3 +349,44 @@ public struct CodeExecutionResultPart: Part { self.thoughtSignature = thoughtSignature } } + +#if compiler(>=6.2.3) && canImport(FoundationModels) + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + extension [any Part] { + func toFoundationModelsPrompt() throws -> FoundationModels.Prompt { + let parts = ModelContent(parts: self) + let promptParts: [any FoundationModels.PromptRepresentable] = try parts.internalParts + .compactMap { part in + // Skip any `thought` parts since they are unused by Foundation Models. + guard !(part.isThought ?? 
false) else { return nil } + + // Skip any parts without `data`, for example a `Part` containing only a thought + // signature, since they are unused by Foundation Models. + guard let data = part.data else { return nil } + + // Currently only string types are supported. + guard case let .text(string) = data else { + // TODO: Create a custom error type for unsupported prompt part types. + throw GenerativeModelSession.GenerationError.internalError( + GenerativeModelSession.GenerationError.Context( + debugDescription: """ + Prompt data type "\(data)" is not supported by Foundation Models. + """ + ), + underlyingError: NSError(domain: Constants.baseErrorDomain, code: 0) + ) + } + + return string + } + + return Prompt { + for part in promptParts { + part.promptRepresentation + } + } + } + } +#endif // compiler(>=6.2.3) && canImport(FoundationModels) diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionHybridTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionHybridTests.swift new file mode 100644 index 00000000000..f2b3f011014 --- /dev/null +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionHybridTests.swift @@ -0,0 +1,254 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: Remove the `#if compiler(>=6.2.3)` when Xcode 26.2 is the minimum supported version. 
+#if compiler(>=6.2.3) + @testable import FirebaseAILogic + import FirebaseAITestApp + import Foundation + #if canImport(FoundationModels) + import FoundationModels + #endif // canImport(FoundationModels) + import Testing + + @Suite(.serialized) + struct GenerativeModelSessionHybridTests { + @Test(arguments: [InstanceConfig.vertexAI_v1beta_global]) + func respondText_fallbackOnGeminiModelError(_ config: InstanceConfig) async throws { + let firebaseAI = FirebaseAI.componentInstance(config) + let invalidModel1 = firebaseAI.geminiModel(modelName: "invalid-model-name-1") + let invalidModel2 = firebaseAI.geminiModel(modelName: "invalid-model-name-2") + let validModel = firebaseAI.geminiModel(modelName: ModelNames.gemini2_5_FlashLite) + let session = firebaseAI.generativeModelSession( + model: HybridModel( + cloud: invalidModel1, + onDevice: HybridModel(cloud: invalidModel2, onDevice: validModel, mode: .preferInCloud), + mode: .preferInCloud + ) + ) + let prompt = "Why is the sky blue?" + + let response = try await session.respond(to: prompt) + + let content = response.content + #expect(!content.isEmpty) + #expect(response.rawContent.isComplete) + #if canImport(FoundationModels) + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + #expect(response.rawContent.kind == .string(content)) + } + #endif // canImport(FoundationModels) + #expect(response.rawContent.generationID != nil) + #expect(response.rawResponse.text == content) + #expect(response.rawResponse.modelVersion == validModel._modelName) + } + + @Test(arguments: [InstanceConfig.vertexAI_v1beta_global]) + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + func respondText_fallbackOnFoundationModelsError(_ config: InstanceConfig) async throws { + let firebaseAI = FirebaseAI.componentInstance(config) + let onDeviceModel = SystemLanguageModel.default + let cloudModel = firebaseAI.geminiModel(modelName: ModelNames.gemini2_5_FlashLite) + let session = 
firebaseAI.generativeModelSession( + model: HybridModel(cloud: cloudModel, onDevice: onDeviceModel, mode: .preferOnDevice) + ) + let prompt = "In one sentence, why is the sky blue?" + + let response = try await session.respond(to: prompt) + + let content = response.content + #expect(!content.isEmpty) + #expect(response.rawContent.isComplete) + #expect(response.rawContent.kind == .string(content)) + #expect(response.rawContent.generationID != nil) + #expect(response.rawResponse.text == content) + // Check for the on-device model name when running on Apple Intelligence supported devices; in + // this case, no fallback occurs. When running on devices that do not support Apple + // Intelligence, including GitHub Runner Images, check for the cloud (Gemini) model name. + if await foundationModelsIsAvailable() { + #expect(response.rawResponse.modelVersion == onDeviceModel._modelName) + } else { + #expect(response.rawResponse.modelVersion == cloudModel._modelName) + } + } + + @Test(arguments: [InstanceConfig.vertexAI_v1beta_global]) + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + func respondGenerable_fallbackOnFoundationModelsError(_ config: InstanceConfig) async throws { + let firebaseAI = FirebaseAI.componentInstance(config) + let onDeviceModel = SystemLanguageModel.default + let cloudModel = firebaseAI.geminiModel(modelName: ModelNames.gemini2_5_FlashLite) + let session = firebaseAI.generativeModelSession( + model: HybridModel(cloud: cloudModel, onDevice: onDeviceModel, mode: .preferOnDevice) + ) + let prompt = "Generate a cute rescue cat" + + let response = try await session.respond( + to: prompt, + generating: GenerativeModelSessionTests.CatProfile.self + ) + + let catProfile = response.content + #expect(!catProfile.name.isEmpty) + #expect(catProfile.age >= 1) + #expect(catProfile.age <= 20) + #expect(!catProfile.profile.isEmpty) + #expect(response.rawContent.isComplete) + 
#expect(response.rawContent.generationID != nil) + // Check for the on-device model name when running on Apple Intelligence supported devices; in + // this case, no fallback occurs. When running on devices that do not support Apple + // Intelligence, including GitHub Runner Images, check for the cloud (Gemini) model name. + if await foundationModelsIsAvailable() { + #expect(response.rawResponse.modelVersion == onDeviceModel._modelName) + } else { + #expect(response.rawResponse.modelVersion == cloudModel._modelName) + } + } + + @Test(arguments: [InstanceConfig.vertexAI_v1beta_global]) + func streamResponseText_fallbackOnGeminiModelError(_ config: InstanceConfig) async throws { + let firebaseAI = FirebaseAI.componentInstance(config) + let invalidModel = firebaseAI.geminiModel(modelName: "invalid-model-name") + let validModel = firebaseAI.geminiModel(modelName: ModelNames.gemini2_5_FlashLite) + let session = firebaseAI.generativeModelSession( + model: HybridModel(cloud: invalidModel, onDevice: validModel, mode: .preferInCloud) + ) + let prompt = "In one sentence, why is the sky blue?" + + let stream = session.streamResponse(to: prompt) + + var generationID: FirebaseAI.GenerationID? + var isComplete = false + for try await snapshot in stream { + #expect(!isComplete, "Stream yielded more elements after a snapshot was marked complete.") + let partial = snapshot.content + #expect(!partial.isEmpty) + if let generationID { + #expect( + generationID == snapshot.rawContent.generationID, + "The generation ID was not stable for the duration of the response." 
+ ) + } else { + #expect(snapshot.rawContent.generationID != nil) + generationID = snapshot.rawContent.generationID + } + isComplete = snapshot.rawContent.isComplete + } + #expect(isComplete, "The stream finished, but the final snapshot was not marked as complete.") + + let response = try await stream.collect() + let content = response.content + #expect(!content.isEmpty) + #expect(response.rawContent.isComplete, "The final response was not marked as complete.") + #expect(response.rawContent.generationID == generationID) + #if canImport(FoundationModels) + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + #expect(response.rawContent.kind == .string(content)) + } + #endif // canImport(FoundationModels) + if let text = response.rawResponse.text { + #expect(content.hasSuffix(text)) + } + #expect(response.rawResponse.modelVersion == validModel._modelName) + } + + @Test(arguments: [InstanceConfig.vertexAI_v1beta_global]) + @available(iOS 26.0, macOS 26.0, *) + @available(tvOS, unavailable) + @available(watchOS, unavailable) + func streamResponseText_fallbackOnFoundationModelsError(_ config: InstanceConfig) async throws { + let firebaseAI = FirebaseAI.componentInstance(config) + let onDeviceModel = SystemLanguageModel.default + let cloudModel = firebaseAI.geminiModel(modelName: ModelNames.gemini2_5_FlashLite) + let session = firebaseAI.generativeModelSession( + model: HybridModel(cloud: cloudModel, onDevice: onDeviceModel, mode: .preferOnDevice) + ) + let prompt = "In one sentence, why is the sky blue?" + + let stream = session.streamResponse(to: prompt) + + var generationID: FirebaseAI.GenerationID? 
+ var isComplete = false + for try await snapshot in stream { + #expect(!isComplete, "Stream yielded more elements after a snapshot was marked complete.") + let partial = snapshot.content + #expect(!partial.isEmpty) + if let generationID { + #expect( + generationID == snapshot.rawContent.generationID, + "The generation ID was not stable for the duration of the response." + ) + } else { + #expect(snapshot.rawContent.generationID != nil) + generationID = snapshot.rawContent.generationID + } + isComplete = snapshot.rawContent.isComplete + } + #expect(isComplete, "The stream finished, but the final snapshot was not marked as complete.") + + let response = try await stream.collect() + let content = response.content + #expect(!content.isEmpty) + #expect(response.rawContent.isComplete, "The final response was not marked as complete.") + #expect(response.rawContent.generationID == generationID) + #expect(response.rawContent.kind == .string(content)) + if let text = response.rawResponse.text { + #expect(content.hasSuffix(text)) + } + // Check for the on-device model name when running on Apple Intelligence supported devices; in + // this case, no fallback occurs. When running on devices that do not support Apple + // Intelligence, including GitHub Runner Images, check for the cloud (Gemini) model name. + if await foundationModelsIsAvailable() { + #expect(response.rawResponse.modelVersion == onDeviceModel._modelName) + } else { + #expect(response.rawResponse.modelVersion == cloudModel._modelName) + } + } + + /// Returns `true` if `FoundationModels.SystemLanguageModel` is available. + /// + /// This is a workaround for `SystemLanguageModel.isAvailable`, which returns `true` if *any* + /// version of the model is available. However, calls to `LanguageModelSession().respond(to:)` + /// throw a `ModelManagerError` if the simulator's model version does not match the host macOS + /// version. A new version of the model was introduced in Xcode/macOS/iOS 26.4. 
+ func foundationModelsIsAvailable() async -> Bool { + #if canImport(FoundationModels) + if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { + let model = SystemLanguageModel.default + guard model.isAvailable else { + return false + } + + let session = LanguageModelSession(model: model) + do { + _ = try await session.respond( + to: "Hello", + options: GenerationOptions(sampling: .greedy, temperature: 0) + ) + + return true + } catch { + return false + } + } + #endif // canImport(FoundationModels) + + return false + } + } +#endif // compiler(>=6.2.3) diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionTests.swift index 82ce3ee905f..435f2e1359b 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionTests.swift @@ -35,11 +35,11 @@ let content = response.content #expect(!content.isEmpty) #expect(response.rawContent.isComplete) - #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + #if canImport(FoundationModels) if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { #expect(response.rawContent.kind == .string(content)) } - #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + #endif // canImport(FoundationModels) #expect(response.rawContent.generationID != nil) #expect(response.rawResponse.text == content) } @@ -376,11 +376,11 @@ #expect(!content.isEmpty) #expect(response.rawContent.isComplete, "The final response was not marked as complete.") #expect(response.rawContent.generationID == generationID) - #if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + #if canImport(FoundationModels) if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) { #expect(response.rawContent.kind == .string(content)) } - #endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM + 
#endif // canImport(FoundationModels) if let text = response.rawResponse.text { #expect(content.hasSuffix(text)) } diff --git a/FirebaseAI/Tests/Unit/GenerativeModelSessionTests.swift b/FirebaseAI/Tests/Unit/GenerativeModelSessionTests.swift index bc944105778..e312a11c504 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelSessionTests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelSessionTests.swift @@ -74,15 +74,23 @@ ), ] let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) let response = try await session.respond(to: testPrompt) XCTAssertEqual(response.content, "Mountain View") var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? 
GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -120,15 +128,23 @@ subdirectory: googleAISubdirectory )) let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) let response = try await session.respond(to: testPrompt) XCTAssertEqual(response.content, "Mountain View") var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -152,6 +168,94 @@ XCTAssertEqual(functionResponse.response, ["result": .string(CurrentTimeTool.currentTime)]) } + func testRespondTo_activeSessionIndexPreventsFallback() async throws { + let model1 = try mockGeminiModel(modelName: "gemini-2.5-flash") + let model2 = try mockGeminiModel(modelName: "gemini-2.0-flash") + + let session = GenerativeModelSession( + model: HybridModel(cloud: model1, onDevice: model2, mode: .preferInCloud), + tools: nil, + instructions: nil + ) + let expectedStatusCode = 400 + try MockURLProtocol.requestHandlersQueue.append(contentsOf: [ + GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-thinking-reply-thought-summary", + withExtension: "json", + subdirectory: googleAISubdirectory + ), + GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-failure-api-key", + withExtension: "json", + subdirectory: 
googleAISubdirectory, + statusCode: expectedStatusCode + ), + GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + subdirectory: googleAISubdirectory + ), + ]) + + // Verify first request succeeds. + let response1 = try await session.respond(to: testPrompt) + XCTAssertEqual(response1.content, "Mountain View") + XCTAssertEqual(response1.rawResponse.modelVersion, model1.modelName) + + // Verify no fallback to model2 after successful request. + await XCTAssertThrowsError({ + try await session.respond(to: testPrompt) + }, "Expected an error but request succeeded.") { error in + guard case let GenerateContentError.internalError(underlying: underlyingError) = error + else { + return XCTFail("Unexpected error type: \(error)") + } + guard let backendError = underlyingError as? BackendError else { + return XCTFail("Unexpected underlying error type: \(underlyingError)") + } + XCTAssertEqual(backendError.status, .invalidArgument) + XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode) + XCTAssertTrue(backendError.message.hasPrefix("API key not valid.")) + } + XCTAssertEqual(MockURLProtocol.requestHandlersQueue.count, 1, """ + Expected 'unary-success-basic-reply-short' to remain the queue since falling back to 'model2' + is not supported after a successful request using model1. 
+ """) + } + + func testRespondTo_fallbackAfterFailure() async throws { + let model1 = try mockGeminiModel(modelName: "gemini-5.0-flash") + let model2 = try mockGeminiModel(modelName: "gemini-2.5-flash") + let session = GenerativeModelSession( + model: HybridModel(cloud: model1, onDevice: model2, mode: .preferInCloud), + tools: nil, + instructions: nil + ) + let expectedStatusCode = 404 + try MockURLProtocol.requestHandlersQueue.append(contentsOf: [ + GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-failure-unknown-model", + withExtension: "json", + subdirectory: googleAISubdirectory, + statusCode: expectedStatusCode + ), + GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-thinking-reply-thought-summary", + withExtension: "json", + subdirectory: googleAISubdirectory + ), + ]) + + // Verify request falls back to model2 after initial failure with model1. + let response1 = try await session.respond(to: testPrompt) + + XCTAssertEqual(response1.content, "Mountain View") + XCTAssertEqual(response1.rawResponse.modelVersion, model2.modelName) + XCTAssertTrue(MockURLProtocol.requestHandlersQueue.isEmpty, """ + Expected the queue to be empty after automatically falling back to model2. 
+ """) + } + func testRespondTo_functionCall_maxFunctionCallTurnsExceeded() async throws { let functionCallCount = GenerativeModelSession.maxAutoFunctionCallTurns + 1 MockURLProtocol.requestHandlersQueue = try Array( @@ -167,8 +271,12 @@ subdirectory: googleAISubdirectory )) let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) await XCTAssertThrowsError { try await session.respond(to: testPrompt) @@ -193,7 +301,11 @@ var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -231,8 +343,12 @@ ), ] let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) let stream = session.streamResponse(to: testPrompt) let response = try await stream.collect() @@ -243,7 +359,11 @@ """) var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? 
GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -282,8 +402,12 @@ subdirectory: googleAISubdirectory )) let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) let stream = session.streamResponse(to: testPrompt) let response = try await stream.collect() @@ -294,7 +418,11 @@ """) var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -334,8 +462,12 @@ subdirectory: googleAISubdirectory )) let currentTimeTool = CurrentTimeTool() - let model = try mockGenerativeModel(tools: .autoFunctionDeclaration(currentTimeTool)) - let session = GenerativeModelSession(model: model) + let model = try mockGeminiModel() + let session = GenerativeModelSession( + model: model, + tools: [.autoFunctionDeclaration(currentTimeTool)], + instructions: nil + ) await XCTAssertThrowsError { let stream = session.streamResponse(to: testPrompt) @@ -361,7 +493,11 @@ var functionCalls = [FunctionCall]() var functionResponses = [FunctionResponse]() - for content in session.session.history { + let modelSession = try XCTUnwrap( + session.sessionManager.getOrStartSession(instructions: nil) + ) + let geminiSession = try XCTUnwrap(modelSession as? 
GeminiModelSession) + for content in geminiSession.chat.history { for part in content.internalParts { switch part.data { case let .functionCall(functionCall): @@ -387,16 +523,18 @@ // MARK: - Helper Utilities - func mockGenerativeModel(modelName: String? = nil, modelResourceName: String? = nil, - firebaseInfo: FirebaseInfo? = nil, apiConfig: APIConfig? = nil, - tools: ToolRepresentable..., requestOptions: RequestOptions? = nil, - urlSession: URLSession? = nil) throws -> GenerativeModel { - return GenerativeModel( + func mockGeminiModel(modelName: String? = nil, modelResourceName: String? = nil, + firebaseInfo: FirebaseInfo? = nil, apiConfig: APIConfig? = nil, + safetySettings: [SafetySetting]? = nil, toolConfig: ToolConfig? = nil, + requestOptions: RequestOptions? = nil, + urlSession: URLSession? = nil) throws -> GeminiModel { + return GeminiModel( modelName: modelName ?? testModelName, modelResourceName: modelResourceName ?? testModelResourceName, firebaseInfo: firebaseInfo ?? GenerativeModelTestUtil.testFirebaseInfo(), apiConfig: apiConfig ?? self.apiConfig, - tools: tools.isEmpty ? nil : tools.asFirebaseTools(), + safetySettings: safetySettings, + toolConfig: toolConfig, requestOptions: requestOptions ?? RequestOptions(), urlSession: urlSession ?? self.urlSession )