Skip to content

Commit dbe3b1b

Browse files
authored
[AI] Add wrapper for FoundationModels.GenerationOptions (#16103)
1 parent a5172d2 commit dbe3b1b

2 files changed

Lines changed: 315 additions & 0 deletions

File tree

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if compiler(>=6.2.3)
16+
import Foundation
17+
#if canImport(FoundationModels)
18+
import FoundationModels
19+
#endif // canImport(FoundationModels)
20+
21+
public extension FirebaseAI {
22+
/// Options that control how the model generates its response to a prompt.
23+
///
24+
/// This is a thin wrapper for the `FoundationModels.GenerationOptions` struct that is
25+
/// available on a wider range of operating system versions.
26+
struct GenerationOptions: Sendable, Equatable {
27+
protocol GenerationOptionsProtocol: Sendable, Equatable {}
28+
29+
/// A type that defines how values are sampled from a probability distribution.
30+
public struct SamplingMode: Sendable, Equatable {
31+
protocol SamplingModeProtocol: Sendable, Equatable {}
32+
33+
enum Kind {
34+
case greedy
35+
case randomTopK(k: Int, seed: UInt64?)
36+
case randomProbabilityThreshold(probabilityThreshold: Double, seed: UInt64?)
37+
case foundationModelsSamplingMode(any SamplingModeProtocol)
38+
}
39+
40+
let kind: Kind
41+
42+
init(kind: Kind) {
43+
self.kind = kind
44+
}
45+
46+
/// A sampling mode that always chooses the most likely token.
47+
public static var greedy: GenerationOptions.SamplingMode {
48+
return SamplingMode(kind: .greedy)
49+
}
50+
51+
/// A sampling mode that considers a fixed number of high-probability tokens.
52+
public static func random(top k: Int, seed: UInt64? = nil) -> GenerationOptions
53+
.SamplingMode {
54+
return SamplingMode(kind: .randomTopK(k: k, seed: seed))
55+
}
56+
57+
/// A mode that considers a variable number of high-probability tokens based on the
58+
/// specified threshold.
59+
public static func random(probabilityThreshold: Double,
60+
seed: UInt64? = nil) -> GenerationOptions.SamplingMode {
61+
return SamplingMode(kind: .randomProbabilityThreshold(
62+
probabilityThreshold: probabilityThreshold,
63+
seed: seed
64+
))
65+
}
66+
67+
#if canImport(FoundationModels)
68+
@available(iOS 26.0, macOS 26.0, *)
69+
@available(tvOS, unavailable)
70+
@available(watchOS, unavailable)
71+
init(_ samplingMode: FoundationModels.GenerationOptions.SamplingMode) {
72+
kind = .foundationModelsSamplingMode(samplingMode)
73+
}
74+
75+
@available(iOS 26.0, macOS 26.0, *)
76+
@available(tvOS, unavailable)
77+
@available(watchOS, unavailable)
78+
var samplingMode: FoundationModels.GenerationOptions.SamplingMode {
79+
switch kind {
80+
case .greedy:
81+
return FoundationModels.GenerationOptions.SamplingMode.greedy
82+
case let .randomTopK(k, seed):
83+
return FoundationModels.GenerationOptions.SamplingMode.random(top: k, seed: seed)
84+
case let .randomProbabilityThreshold(prob, seed):
85+
return FoundationModels.GenerationOptions.SamplingMode.random(
86+
probabilityThreshold: prob,
87+
seed: seed
88+
)
89+
case let .foundationModelsSamplingMode(samplingMode):
90+
guard let samplingMode = samplingMode as? FoundationModels.GenerationOptions
91+
.SamplingMode else {
92+
preconditionFailure("""
93+
\(Self.self).#\(#function): `samplingMode` must be a
94+
`FoundationModels.GenerationOptions.SamplingMode`.
95+
""")
96+
}
97+
98+
return samplingMode
99+
}
100+
}
101+
#endif // canImport(FoundationModels)
102+
103+
public static func == (lhs: SamplingMode, rhs: SamplingMode) -> Bool {
104+
switch (lhs.kind, rhs.kind) {
105+
case (.greedy, .greedy):
106+
return true
107+
case let (.randomTopK(lhsK, lhsSeed), .randomTopK(rhsK, rhsSeed)):
108+
return lhsK == rhsK && lhsSeed == rhsSeed
109+
case let (
110+
.randomProbabilityThreshold(lhsP, lhsSeed),
111+
.randomProbabilityThreshold(rhsP, rhsSeed)
112+
):
113+
return lhsP == rhsP && lhsSeed == rhsSeed
114+
case let (.foundationModelsSamplingMode(lhsMode), .foundationModelsSamplingMode(rhsMode)):
115+
#if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM
116+
if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) {
117+
if let lhsMode = lhsMode as? FoundationModels.GenerationOptions.SamplingMode,
118+
let rhsMode = rhsMode as? FoundationModels.GenerationOptions.SamplingMode {
119+
return lhsMode == rhsMode
120+
}
121+
}
122+
#endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM
123+
return false
124+
default:
125+
return false
126+
}
127+
}
128+
}
129+
130+
/// A sampling strategy for how the model picks tokens when generating a response.
131+
public var sampling: GenerationOptions.SamplingMode?
132+
133+
/// Temperature influences the confidence of the model's response.
134+
public var temperature: Double?
135+
136+
/// The maximum number of tokens the model is allowed to produce in its response.
137+
public var maximumResponseTokens: Int?
138+
139+
// Opaque storage for Apple's type to support full round-tripping when created from it.
140+
private var _generationOptions: (any GenerationOptionsProtocol)?
141+
142+
/// Creates generation options that control token sampling behavior.
143+
public init(sampling: GenerationOptions.SamplingMode? = nil, temperature: Double? = nil,
144+
maximumResponseTokens: Int? = nil) {
145+
self.sampling = sampling
146+
self.temperature = temperature
147+
self.maximumResponseTokens = maximumResponseTokens
148+
_generationOptions = nil
149+
}
150+
151+
#if canImport(FoundationModels)
152+
/// Initializes a ``FirebaseAI/GenerationOptions`` from a
153+
/// `FoundationModels.GenerationOptions`.
154+
///
155+
/// - Parameter options: The `FoundationModels.GenerationOptions` to wrap.
156+
@available(iOS 26.0, macOS 26.0, *)
157+
@available(tvOS, unavailable)
158+
@available(watchOS, unavailable)
159+
public init(_ options: FoundationModels.GenerationOptions) {
160+
_generationOptions = options
161+
sampling = options.sampling.map { SamplingMode(kind: .foundationModelsSamplingMode($0)) }
162+
temperature = options.temperature
163+
maximumResponseTokens = options.maximumResponseTokens
164+
}
165+
166+
@available(iOS 26.0, macOS 26.0, *)
167+
@available(tvOS, unavailable)
168+
@available(watchOS, unavailable)
169+
func toFoundationModels() -> FoundationModels.GenerationOptions {
170+
if let generationOptions = _generationOptions as? FoundationModels.GenerationOptions {
171+
return generationOptions
172+
}
173+
174+
return FoundationModels.GenerationOptions(
175+
sampling: sampling?.samplingMode,
176+
temperature: temperature,
177+
maximumResponseTokens: maximumResponseTokens
178+
)
179+
}
180+
#endif // canImport(FoundationModels)
181+
182+
public static func == (lhs: GenerationOptions, rhs: GenerationOptions) -> Bool {
183+
#if canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM
184+
if #available(iOS 26.0, macOS 26.0, visionOS 26.0, *) {
185+
if let lhsOptions = lhs._generationOptions as? FoundationModels.GenerationOptions,
186+
let rhsOptions = rhs._generationOptions as? FoundationModels.GenerationOptions {
187+
return lhsOptions == rhsOptions
188+
}
189+
}
190+
#endif // canImport(FoundationModels) && IS_FOUNDATION_MODELS_SUPPORTED_PLATFORM
191+
192+
return lhs.sampling == rhs.sampling &&
193+
lhs.temperature == rhs.temperature &&
194+
lhs.maximumResponseTokens == rhs.maximumResponseTokens
195+
}
196+
}
197+
}
198+
199+
#if canImport(FoundationModels)
200+
@available(iOS 26.0, macOS 26.0, *)
201+
@available(tvOS, unavailable)
202+
@available(watchOS, unavailable)
203+
extension FoundationModels.GenerationOptions: FirebaseAI.GenerationOptions
204+
.GenerationOptionsProtocol {}
205+
206+
@available(iOS 26.0, macOS 26.0, *)
207+
@available(tvOS, unavailable)
208+
@available(watchOS, unavailable)
209+
extension FoundationModels.GenerationOptions.SamplingMode: FirebaseAI.GenerationOptions
210+
.SamplingMode.SamplingModeProtocol {}
211+
#endif // canImport(FoundationModels)
212+
#endif // compiler(>=6.2.3)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright 2026 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if compiler(>=6.2.3)
16+
@testable import FirebaseAILogic
17+
import XCTest
18+
19+
#if canImport(FoundationModels)
20+
import FoundationModels
21+
#endif
22+
23+
final class GenerationOptionsTests: XCTestCase {
24+
#if canImport(FoundationModels)
25+
@available(iOS 26.0, macOS 26.0, *)
26+
@available(tvOS, unavailable)
27+
@available(watchOS, unavailable)
28+
func testConversionToFoundationModels() throws {
29+
// Skip this test on platforms that do not support Foundation Models. This is a
30+
// workaround for XCTest ignoring the `@available` attributes. See
31+
// https://stackoverflow.com/q/59645536 for more details.
32+
try XCTSkipFoundationModelsUnsupported()
33+
34+
let options = FirebaseAI.GenerationOptions(
35+
sampling: .greedy,
36+
temperature: 0.5,
37+
maximumResponseTokens: 100
38+
)
39+
40+
let afmOptions = options.toFoundationModels()
41+
42+
XCTAssertEqual(afmOptions.temperature, 0.5)
43+
XCTAssertEqual(afmOptions.maximumResponseTokens, 100)
44+
XCTAssertNotNil(afmOptions.sampling)
45+
XCTAssertEqual(afmOptions.sampling, .greedy)
46+
}
47+
#endif // canImport(FoundationModels)
48+
49+
func testEquatable_emptyOptions() throws {
50+
let options = FirebaseAI.GenerationOptions()
51+
52+
XCTAssertNil(options.sampling)
53+
XCTAssertNil(options.temperature)
54+
XCTAssertNil(options.maximumResponseTokens)
55+
}
56+
57+
func testGenerationSchema_greedy() throws {
58+
let temperature = 0.9
59+
let maximumResponseTokens = 200
60+
let options = FirebaseAI.GenerationOptions(
61+
sampling: .greedy,
62+
temperature: temperature,
63+
maximumResponseTokens: maximumResponseTokens
64+
)
65+
66+
XCTAssertEqual(options.sampling, .greedy)
67+
XCTAssertEqual(options.temperature, temperature)
68+
XCTAssertEqual(options.maximumResponseTokens, maximumResponseTokens)
69+
}
70+
71+
func testGenerationSchema_probabilityThreshold() throws {
72+
let topP = 0.8
73+
let seed: UInt64 = 5_000_000_000
74+
let temperature = 0.6
75+
let maximumResponseTokens = 80
76+
let options = FirebaseAI.GenerationOptions(
77+
sampling: .random(probabilityThreshold: topP, seed: seed),
78+
temperature: temperature,
79+
maximumResponseTokens: maximumResponseTokens
80+
)
81+
82+
XCTAssertEqual(options.sampling, .random(probabilityThreshold: topP, seed: seed))
83+
XCTAssertEqual(options.temperature, temperature)
84+
XCTAssertEqual(options.maximumResponseTokens, maximumResponseTokens)
85+
}
86+
87+
func testGenerationSchema_topK() throws {
88+
let topK = 5
89+
let seed: UInt64 = 6_000_000_000
90+
let temperature = 0.4
91+
let maximumResponseTokens = 1000
92+
let options = FirebaseAI.GenerationOptions(
93+
sampling: .random(top: topK, seed: seed),
94+
temperature: temperature,
95+
maximumResponseTokens: maximumResponseTokens
96+
)
97+
98+
XCTAssertEqual(options.sampling, .random(top: topK, seed: seed))
99+
XCTAssertEqual(options.temperature, temperature)
100+
XCTAssertEqual(options.maximumResponseTokens, maximumResponseTokens)
101+
}
102+
}
103+
#endif // compiler(>=6.2.3)

0 commit comments

Comments
 (0)