diff --git a/.github/workflows/cargo-deny.yml b/.github/workflows/cargo-deny.yml index ee7cecce5aed..7d0a2718bd59 100644 --- a/.github/workflows/cargo-deny.yml +++ b/.github/workflows/cargo-deny.yml @@ -24,6 +24,6 @@ jobs: steps: - uses: actions/checkout@v6 # https://github.com/EmbarkStudios/cargo-deny-action v2.0.15 - - uses: EmbarkStudios/cargo-deny-action@a531616d8ce3b9177443e48a1159bc945a099823 + - uses: EmbarkStudios/cargo-deny-action@bb137d7af7e4fb67e5f82a49c4fce4fad40782fe with: command: check advisories diff --git a/.github/workflows/publish-npm.yml b/.github/workflows/publish-npm.yml index 3fdac7136a56..360849c22c26 100644 --- a/.github/workflows/publish-npm.yml +++ b/.github/workflows/publish-npm.yml @@ -32,7 +32,7 @@ jobs: always-auth: true - name: Setup pnpm - uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5 + uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 with: version: 10.30.3 @@ -167,7 +167,7 @@ jobs: always-auth: true - name: Setup pnpm - uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5 + uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 with: version: 10.30.3 diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 56eb8cec8ab9..b0b4b1de3f4e 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -39,7 +39,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@f49aabe0b5af0936a0987cfb85d86b75731b0186 # v2.4.1 + uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3 with: results_file: results.sarif results_format: sarif diff --git a/Cargo.lock b/Cargo.lock index a962076cdd8a..199955151b5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -439,9 +439,9 @@ checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "aws-config" -version = "1.8.17" +version = "1.8.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "517aa062d8bd9015ee23d6daa5e1c1372328412fdae4e6c4c1be9b69c6ad37a2" +checksum = "e33f815b73a3899c03b380d543532e5865f230dce9678d108dc10732a8682275" dependencies = [ "aws-credential-types", "aws-runtime", @@ -531,10 +531,11 @@ dependencies = [ [[package]] name = "aws-sdk-bedrockruntime" -version = "1.131.0" +version = "1.132.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e494025b4c578bfefd025aada69c51ab1db6b7589f61cb78ae681f3115269209" +checksum = "41a2940faeb61f4f579a434bc3a546e9ab49a89596e94527d329281ef55fd44d" dependencies = [ + "arc-swap", "aws-credential-types", "aws-runtime", "aws-sigv4", @@ -558,10 +559,11 @@ dependencies = [ [[package]] name = "aws-sdk-sagemakerruntime" -version = "1.102.0" +version = "1.104.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb969c5c411b21a4bca5b52e9f5492af20e927fd4c410f748a3996462eb295c5" +checksum = "031fec5f68bdc840361ed09b98b6ef916cc9e49de0323bb4d584adad2cb87cc7" dependencies = [ + "arc-swap", "aws-credential-types", "aws-runtime", "aws-smithy-async", @@ -583,10 +585,11 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.99.0" +version = "1.101.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f4055e6099b2ec264abdc0d9bbfffce306c1601809275c861594779a0b04b45" +checksum = "b647baea49ff551960b904f905681e9b4765a6c4ea08631e89dc52d8bd3f5896" dependencies = [ + "arc-swap", "aws-credential-types", "aws-runtime", "aws-smithy-async", @@ -607,10 +610,11 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.101.0" +version = "1.103.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02f009ba0284c5d696425fd7b4dcc5b189f5726f4041b7a5794daecb3a68d598" +checksum = "7ae401c65ff288aa7873117fe535cd32b7b1bb0bc43751d28901a1d5f20636b9" dependencies = [ + "arc-swap", "aws-credential-types", "aws-runtime", "aws-smithy-async", @@ -631,10 +635,11 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.104.0" +version = "1.106.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aa6622798e19e6a76b690562085dd4771c736cd48343464a53ab4ae2f2c9f84" +checksum = "4c80de7bb7d03e9ca8c9fd7b489f20f3948d3f3be91a7953591347d238115408" dependencies = [ + "arc-swap", "aws-credential-types", "aws-runtime", "aws-smithy-async", @@ -723,9 +728,9 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.62.6" +version = "0.62.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "517089205f18ab4adc5a3e02888cb139bbbbb2e168eac9f396216925d1fbeaf5" +checksum = "701a947f4797e52a911e114a898667c746c39feea467bbd1abd7b3721f702ffa" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-schema", @@ -778,9 +783,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.12.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc117c179ecf39a62a0a3f49f600e9ac26a7ad7dd172177999f83933af776c32" +checksum = "9db177daa6ba8afb9ee1aefcf548c907abcf52065e394ee11a92780057fe0e8c" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api-macros", @@ -818,9 +823,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.4.8" +version = "1.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "056b66dbce2f81cc0c1e2b05bb402eb58f8a3530479d650efadd5bbae9a4050b" +checksum = "53f93074121a1be41317b9aa607143ae17900631f7f59a99f2b905d519d6783b" dependencies = [ "base64-simd", "bytes", @@ -1727,7 +1732,7 @@ dependencies = [ "rayon", "safetensors 0.7.0", "thiserror 2.0.18", - "tokenizers 0.22.2", + "tokenizers", "yoke 0.8.2", "zip 7.2.0", ] @@ -2030,9 +2035,9 @@ checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clap_mangen" -version = "0.2.33" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e30ffc187e2e3aeafcd1c6e2aa416e29739454c0ccaa419226d5ecd181f2d78" +checksum = "d82842b45bf9f6a3be090dd860095ac30728042c08e0d6261ca7259b5d850f07" dependencies = [ "clap", "roff", @@ -2093,8 +2098,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "961b955a666e25ee5a1091d219128d6e6401e3dab84efb1a2bf6b4035d797b39" dependencies = [ "crmf", - "der", - "spki", + "der 0.7.10", + "spki 0.7.3", "x509-cert", ] @@ -2105,8 +2110,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b77c319abfd5219629c45c34c89ba945ed3c5e49fcde9d16b6c3885f118a730" dependencies = [ "const-oid 0.9.6", - "der", - "spki", + "der 0.7.10", + "spki 0.7.3", "x509-cert", ] @@ -2395,8 +2400,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36fe21b96d5b87f5de4b5b7202ec41c00110ac817ce6728fe75fb2fe5962ed92" dependencies = [ "cms", - "der", - "spki", + "der 0.7.10", + "spki 0.7.3", "x509-cert", ] @@ -2668,20 +2673,6 @@ dependencies = [ "parking_lot_core", ] -[[package]] -name = "dashmap" -version = "6.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" -dependencies = [ - "cfg-if", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "data-encoding" version = "2.11.0" @@ -3107,6 +3098,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +dependencies = [ + "const-oid 0.10.2", + "zeroize", +] + [[package]] name = "der-parser" version = "10.0.0" @@ -3404,9 +3405,9 @@ checksum = "9bda8e21c04aca2ae33ffc2fd8c23134f3cac46db123ba97bd9d3f3b8a4a85e1" [[package]] name = "dtor" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2137ce22f50d4c43ce098daf41c904cc700de1ce8bc2daf53ed4e702180a464" +checksum = "6d738e43aa64edab57c983d56de890d65fea7dc05605490c74451ce721dfd84b" dependencies = [ "linktime-proc-macro", ] @@ -3454,12 +3455,12 @@ version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ - "der", + "der 0.7.10", "digest 0.10.7", "elliptic-curve", "rfc6979", "signature", - "spki", + "spki 0.7.3", ] [[package]] @@ -3477,7 +3478,7 @@ version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" dependencies = [ - "pkcs8", + "pkcs8 0.10.2", "signature", ] @@ -3518,7 +3519,7 @@ dependencies = [ "group", "hkdf", "pem-rfc7468", - "pkcs8", + "pkcs8 0.10.2", "rand_core 0.6.4", "sec1", "subtle", @@ -4449,7 +4450,6 @@ version = "1.37.0" dependencies = [ "agent-client-protocol", "agent-client-protocol-schema", - "ahash", "anyhow", "arboard", "async-stream", @@ -4469,7 +4469,6 @@ dependencies = [ "chrono", "clap", "ctor", - "dashmap 6.2.1", "dirs 5.0.1", "dotenvy", "dtor", @@ -4510,12 +4509,12 @@ dependencies = [ "opentelemetry-appender-tracing", "opentelemetry-otlp 0.32.0", "opentelemetry-stdout", - "opentelemetry_sdk 0.32.0", + "opentelemetry_sdk 0.32.1", "pastey", "pctx_code_mode", "pem", "pkcs1", - "pkcs8", + "pkcs8 0.11.0", "process-wrap", "pulldown-cmark", "rand 0.8.6", @@ -4537,14 +4536,14 @@ dependencies = [ "shellexpand", "smithy-transport-reqwest", "sqlx", - "strum 0.27.2", + "strum 0.28.0", "symphonia", "sys-info", "tempfile", "test-case", "thiserror 1.0.69", "tiktoken-rs", - "tokenizers 0.21.4", + "tokenizers", "tokio", "tokio-cron-scheduler", "tokio-stream", @@ -4624,7 +4623,7 @@ dependencies = [ "sha2 0.11.0", "shlex", "sigstore-verify", - "strum 0.27.2", + "strum 0.28.0", "tar", "tempfile", "test-case", @@ -5904,9 +5903,9 @@ dependencies = [ [[package]] name = "linktime-proc-macro" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44cd706ff0d503ee32b2071166510ca27e281228de10cd3aa8d35ff94560f81" +checksum = "8c7b0a3383c2a1002d11349c92c85a666a5fb679e96c79d782cf0dbe557fd6ee" [[package]] name = "linux-keyutils" @@ -6176,9 +6175,9 @@ dependencies = [ [[package]] name = "mockall" -version = "0.13.1" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +checksum = "f58d964098a5f9c6b63d0798e5372fd04708193510a7af313c22e9f29b7b620b" dependencies = [ "cfg-if", "downcast", @@ -6190,9 +6189,9 @@ dependencies = [ [[package]] name = "mockall_derive" -version = "0.13.1" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +checksum = "ca41ce716dda6a9be188b385aa78ee5260fc25cd3802cb2a8afdc6afbe6b6dbf" dependencies = [ "cfg-if", "proc-macro2", @@ -6313,7 +6312,7 @@ dependencies = [ "async-trait", "boxed_error", "capacity_builder", - "dashmap 5.5.3", + "dashmap", "deno_error", "deno_maybe_sync", "deno_media_type", @@ -6926,7 +6925,7 @@ dependencies = [ "opentelemetry 0.32.0", "opentelemetry-http 0.32.0", "opentelemetry-proto 0.32.0", - "opentelemetry_sdk 0.32.0", + "opentelemetry_sdk 0.32.1", "prost", "reqwest 0.13.4", "thiserror 2.0.18", @@ -6952,7 +6951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56d658ba1faf63f7b9c492cfbe6e0ec365440a16132d3270c1065f7b33f1b638" dependencies = [ "opentelemetry 0.32.0", - "opentelemetry_sdk 0.32.0", + "opentelemetry_sdk 0.32.1", "prost", ] @@ -6964,7 +6963,7 @@ checksum = "a1b1c6a247d79091f0062a5f4bd058589525cf987a8d4c169440d9c1be72f0ad" dependencies = [ "chrono", "opentelemetry 0.32.0", - "opentelemetry_sdk 0.32.0", + "opentelemetry_sdk 0.32.1", ] [[package]] @@ -6984,9 +6983,9 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.32.0" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368afaed344110f40b179bb8fbe54bc52d98f9bd2b281799ef32487c2650c956" +checksum = "9b59f80e1ac4d5ff7a2db8fb6c80badb7f0f3f858211fba08dd9aaec750894f9" dependencies = [ "futures-channel", "futures-executor", @@ -7446,9 +7445,9 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" dependencies = [ - "der", - "pkcs8", - "spki", + "der 0.7.10", + "pkcs8 0.10.2", + "spki 0.7.3", ] [[package]] @@ -7457,8 +7456,18 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der", - "spki", + "der 0.7.10", + "spki 0.7.3", +] + +[[package]] +name = "pkcs8" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451913da69c775a56034ea8d9003d27ee8948e12443eae7c038ba100a4f21cb7" +dependencies = [ + "der 0.8.0", + "spki 0.8.0", ] [[package]] @@ -7683,7 +7692,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.117", @@ -8363,10 +8372,10 @@ dependencies = [ "num-integer", "num-traits", "pkcs1", - "pkcs8", + "pkcs8 0.10.2", "rand_core 0.6.4", "signature", - "spki", + "spki 0.7.3", "subtle", "zeroize", ] @@ -8597,15 +8606,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scc" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" -dependencies = [ - "sdd", -] - [[package]] name = "schannel" version = "0.1.29" @@ -8702,12 +8702,6 @@ dependencies = [ "sha2 0.10.9", ] -[[package]] -name = "sdd" -version = "3.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" - [[package]] name = "sec1" version = "0.7.3" @@ -8715,9 +8709,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" dependencies = [ "base16ct", - "der", + "der 0.7.10", "generic-array", - "pkcs8", + "pkcs8 0.10.2", "subtle", "zeroize", ] @@ -8968,21 +8962,20 @@ dependencies = [ [[package]] name = "serial_test" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +checksum = "699f4197115b8a7e7ff19c9a315a4bd6fffec26cc4626ef45ecaea389e081c6d" dependencies = [ "once_cell", "parking_lot", - "scc", "serial_test_derive", ] [[package]] name = "serial_test_derive" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +checksum = "94e153fc76e1c6a068703d6d29c508a0b15c061c4b7e43da59cc097bc342673c" dependencies = [ "proc-macro2", "quote", @@ -9099,14 +9092,14 @@ dependencies = [ "aws-lc-rs", "base64 0.22.1", "const-oid 0.9.6", - "der", + "der 0.7.10", "digest 0.10.7", "pem", "rand_core 0.9.5", "sha2 0.10.9", "signature", "sigstore-types", - "spki", + "spki 0.7.3", "thiserror 2.0.18", "tracing", "x509-cert", @@ -9172,7 +9165,7 @@ dependencies = [ "cmpv2", "cms", "const-oid 0.9.6", - "der", + "der 0.7.10", "hex", "jiff", "rand 0.9.4", @@ -9334,9 +9327,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -9376,7 +9369,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der", + "der 0.7.10", +] + +[[package]] +name = "spki" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9efca8738c78ee9484207732f728b1ef517bbb1833d6fc0879ca898a522f6f" +dependencies = [ + "base64ct", + "der 0.8.0", ] [[package]] @@ -10760,39 +10763,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "tokenizers" -version = "0.21.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a620b996116a59e184c2fa2dfd8251ea34a36d0a514758c6f966386bd2e03476" -dependencies = [ - "ahash", - "aho-corasick", - "compact_str 0.9.0", - "dary_heap", - "derive_builder", - "esaxx-rs", - "getrandom 0.3.4", - "itertools 0.14.0", - "log", - "macro_rules_attribute", - "monostate", - "onig", - "paste", - "rand 0.9.4", - "rayon", - "rayon-cond", - "regex", - "regex-syntax", - "serde", - "serde_json", - "spm_precompiled", - "thiserror 2.0.18", - "unicode-normalization-alignments", - "unicode-segmentation", - "unicode_categories", -] - [[package]] name = "tokenizers" version = "0.22.2" @@ -11323,9 +11293,9 @@ dependencies = [ [[package]] name = "tree-sitter-swift" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3b98fb6bc8e6a6a10023f401aa6a1858115e849dfaf7de57dd8b8ea0f257bd9" +checksum = "fe36052155b9dd69ca82b3b8f1b4ccfb2d867125ac1a4db1dd7331829242668c" dependencies = [ "cc", "tree-sitter-language", @@ -11713,9 +11683,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -12692,10 +12662,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94" dependencies = [ "const-oid 0.9.6", - "der", + "der 0.7.10", "sha1", "signature", - "spki", + "spki 0.7.3", "tls_codec", ] @@ -12725,7 +12695,7 @@ checksum = "f5ceece934a21607055b7ac5c25adb56a2ff559804b10705dc674d1d838c15e1" dependencies = [ "cmpv2", "cms", - "der", + "der 0.7.10", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a2434589a8da..d21c8a3a38f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,7 @@ serde = { version = "1.0.228", default-features = false, features = ["derive", " serde_json = { version = "1.0.145", default-features = false, features = ["std"] } serde_yaml = { version = "0.9.32", default-features = false } shellexpand = { version = "3", default-features = false, features = ["base-0", "tilde"] } -strum = { version = "0.27.1", default-features = false, features = ["derive", "std"] } +strum = { version = "0.28.0", default-features = false, features = ["derive", "std"] } tempfile = { version = "3.10.1", default-features = false } thiserror = { version = "1.0.49", default-features = false } tokio = { version = "1.48", default-features = false } @@ -69,7 +69,7 @@ tracing-futures = { version = "0.2.4", default-features = false, features = ["fu tracing-subscriber = { version = "0.3.22", default-features = false, features = ["std"] } urlencoding = { version = "2.1", default-features = false } utoipa = { version = "4.2", default-features = false } -uuid = { version = "1.18", default-features = false, features = ["v4", "std"] } +uuid = { version = "1.23", default-features = false, features = ["v4", "std"] } webbrowser = { version = "1", default-features = false } which = { version = "8", default-features = false, features = ["real-sys"] } winapi = { version = "0.3.9", default-features = false, features = ["wincred", "std"] } diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 7dadbd32c8a3..15e2c5304809 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -20,7 +20,7 @@ name = "generate_manpages" path = "src/bin/generate_manpages.rs" [dependencies] -clap_mangen = { version = "0.2", default-features = false } +clap_mangen = { version = "0.3", default-features = false } goose = { path = "../goose", default-features = false } goose-mcp = { path = "../goose-mcp", default-features = false } rmcp = { workspace = true } diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index ecad829f88a2..6bf45d6da2f0 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -1944,11 +1944,16 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()> // Download let manager = goose::download_manager::get_download_manager(); + let hf_token = goose::providers::huggingface_auth::resolve_token_async() + .await + .ok() + .flatten(); manager - .download_model_sharded( + .download_model_sharded_with_bearer_token( format!("{}-model", model_id), download_files, file.size_bytes + mmproj_size_bytes, + hf_token, None, ) .await?; diff --git a/crates/goose-sdk/src/custom_notifications.rs b/crates/goose-sdk/src/custom_notifications.rs new file mode 100644 index 000000000000..44424fb3b3b1 --- /dev/null +++ b/crates/goose-sdk/src/custom_notifications.rs @@ -0,0 +1,173 @@ +use crate::custom_requests::CustomMethodSchema; +use agent_client_protocol::{JsonRpcMessage, JsonRpcNotification}; +use schemars::{JsonSchema, SchemaGenerator}; +use serde::{Deserialize, Serialize}; + +/// Goose-custom session update notification — a parallel to ACP's +/// `session/update` carrying goose-specific update variants. +#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcNotification)] +#[notification(method = "_goose/unstable/session/update")] +#[serde(rename_all = "camelCase")] +pub struct GooseSessionNotification { + pub session_id: String, + pub update: GooseSessionUpdate, +} + +/// Discriminated union of goose-specific session update payloads. +/// Variant tag matches ACP's convention (`sessionUpdate: ""`). +/// +/// `discriminator.mapping` is what makes TS codegen (`@hey-api/openapi-ts`) +/// emit the correct snake_case tag value even when this enum has a single +/// variant. Add a mapping entry per variant. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "sessionUpdate", rename_all = "snake_case")] +#[schemars(extend("discriminator" = { + "propertyName": "sessionUpdate", + "mapping": { + "usage_update": "#/$defs/SessionUsageUpdate", + "status_message": "#/$defs/StatusMessageUpdate", + "interaction_update": "#/$defs/InteractionUpdate" + } +}))] +pub enum GooseSessionUpdate { + UsageUpdate(SessionUsageUpdate), + StatusMessage(StatusMessageUpdate), + InteractionUpdate(InteractionUpdate), +} + +impl Default for GooseSessionUpdate { + fn default() -> Self { + GooseSessionUpdate::UsageUpdate(SessionUsageUpdate::default()) + } +} + +/// Streaming context-window usage update for a session. +#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageUpdate { + pub used: u64, + pub context_limit: u64, + pub accumulated_input_tokens: u64, + pub accumulated_output_tokens: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub accumulated_cost: Option, +} + +/// Live UI/session status. This is not conversation transcript content, and +/// should not be persisted or replayed as history. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct StatusMessageUpdate { + pub status: StatusMessage, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StatusMessage { + #[serde(rename_all = "camelCase")] + Notice { message: String }, + #[serde(rename_all = "camelCase")] + Progress { message: String }, +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct InteractionUpdate { + pub interaction: Interaction, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "_meta")] + pub meta: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Interaction { + #[serde(rename_all = "camelCase")] + Elicitation { + id: String, + state: InteractionState, + #[serde(skip_serializing_if = "Option::is_none")] + message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + requested_schema: Option, + }, +} + +impl Default for Interaction { + fn default() -> Self { + Self::Elicitation { + id: String::new(), + state: InteractionState::Pending, + message: None, + requested_schema: None, + } + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum InteractionState { + #[default] + Pending, + Submitted, +} + +fn notification_schema(generator: &mut SchemaGenerator) -> CustomMethodSchema +where + T: Default + JsonRpcMessage + JsonSchema, +{ + let dummy = T::default(); + let type_name = std::any::type_name::() + .rsplit("::") + .next() + .unwrap_or(std::any::type_name::()) + .to_string(); + CustomMethodSchema { + method: dummy.method().to_string(), + params_schema: Some(generator.subschema_for::()), + params_type_name: Some(type_name), + response_schema: None, + response_type_name: None, + } +} + +/// Schemas for every goose-custom outbound notification. To register a new +/// notification, define the struct above (with `JsonRpcNotification` + +/// `Default`) and add one line below. +pub fn custom_notification_schemas(generator: &mut SchemaGenerator) -> Vec { + vec![notification_schema::(generator)] +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn status_message_serializes_to_expected_wire_shape() { + let notification = GooseSessionNotification { + session_id: "s1".to_string(), + update: GooseSessionUpdate::StatusMessage(StatusMessageUpdate { + status: StatusMessage::Notice { + message: "Compaction complete".to_string(), + }, + }), + }; + + let value = serde_json::to_value(notification).unwrap(); + + assert_eq!( + value, + json!({ + "sessionId": "s1", + "update": { + "sessionUpdate": "status_message", + "status": { + "type": "notice", + "message": "Compaction complete" + } + } + }) + ); + } +} diff --git a/crates/goose-sdk/src/custom_requests.rs b/crates/goose-sdk/src/custom_requests.rs index 5b007574bbd5..672ff8b09003 100644 --- a/crates/goose-sdk/src/custom_requests.rs +++ b/crates/goose-sdk/src/custom_requests.rs @@ -436,6 +436,17 @@ pub struct ImportSessionResponse { pub message_count: u64, } +/// Submit a response for a pending MCP elicitation in an active session. +#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)] +#[request(method = "_goose/unstable/elicitation/respond", response = EmptyResponse)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationRespondRequest { + pub session_id: String, + pub elicitation_id: String, + #[serde(default)] + pub user_data: serde_json::Value, +} + #[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct ProviderConfigKey { diff --git a/crates/goose-sdk/src/lib.rs b/crates/goose-sdk/src/lib.rs index 6c1c1bf50c05..eb50a0179e87 100644 --- a/crates/goose-sdk/src/lib.rs +++ b/crates/goose-sdk/src/lib.rs @@ -1 +1,2 @@ +pub mod custom_notifications; pub mod custom_requests; diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index afff2d7c93af..3a0df42a07ff 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -393,6 +393,8 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::config_management::remove_config, super::routes::config_management::read_config, super::routes::config_management::read_all_config, + super::routes::config_management::list_provider_secrets, + super::routes::config_management::delete_provider_secret, super::routes::config_management::providers, super::routes::config_management::get_provider_models, super::routes::config_management::get_provider_model_info, @@ -482,6 +484,10 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::config_management::ConfigResponse, super::routes::config_management::ProvidersResponse, super::routes::config_management::ProviderDetails, + super::routes::config_management::ProviderSecretsResponse, + super::routes::config_management::ProviderSecret, + super::routes::config_management::ProviderSecretStorage, + super::routes::config_management::ProviderSecretStatus, super::routes::config_management::SlashCommandsResponse, super::routes::config_management::SlashCommand, super::routes::config_management::CommandType, diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index a31f079b5f5f..f556c39a6510 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -7,8 +7,10 @@ use axum::{ routing::{delete, get, post}, Json, Router, }; +use chrono::{DateTime, TimeZone, Utc}; use goose::config::declarative_providers::LoadedProvider; use goose::config::paths::Paths; +use goose::config::ExtensionEntry; use goose::config::{Config, ConfigError}; use goose::custom_requests::SourceType; use goose::model::ModelConfig; @@ -19,17 +21,35 @@ use goose::providers::catalog::{ ProviderTemplate, }; use goose::providers::create_with_default_model; +use goose::providers::huggingface_auth; use goose::providers::providers as get_providers; use goose::{ - agents::execute_commands, config::permission::PermissionLevel, + agents::execute_commands, agents::ExtensionConfig, config::permission::PermissionLevel, slash_commands::recipe_slash_command, }; use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_yaml; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use utoipa::ToSchema; +#[derive(Serialize, ToSchema)] +pub struct ExtensionResponse { + pub extensions: Vec, + #[serde(default)] + pub warnings: Vec, +} + +#[derive(Deserialize, ToSchema)] +pub struct ExtensionQuery { + pub name: String, + pub config: ExtensionConfig, + pub enabled: bool, +} + #[derive(Deserialize, ToSchema)] pub struct UpsertConfigQuery { pub key: String, @@ -126,6 +146,43 @@ pub enum ConfigValueResponse { MaskedValue(MaskedSecret), } +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ProviderSecretStorage { + SecretStore, + ProviderCache, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ProviderSecretStatus { + Valid, + Expired, + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct ProviderSecret { + pub id: String, + pub provider: String, + pub provider_display_name: String, + pub name: String, + pub storage: ProviderSecretStorage, + pub expires_at: Option>, + pub status: ProviderSecretStatus, + pub configured: bool, + pub has_secret: bool, + pub can_delete: bool, + pub can_configure: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub configure_provider: Option, +} + +#[derive(Debug, Serialize, ToSchema)] +pub struct ProviderSecretsResponse { + pub secrets: Vec, +} + #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub enum CommandType { Builtin, @@ -233,6 +290,343 @@ fn mask_secret(secret: Value) -> String { format!("{}{}", visible, mask) } +const SECRET_STORE_ID_PREFIX: &str = "secret_store:"; +const PROVIDER_CACHE_ID_PREFIX: &str = "provider_cache:"; + +fn provider_secret_status(expires_at: Option>) -> ProviderSecretStatus { + match expires_at { + Some(expires_at) if expires_at <= Utc::now() => ProviderSecretStatus::Expired, + Some(_) => ProviderSecretStatus::Valid, + None => ProviderSecretStatus::Unknown, + } +} + +fn parse_expiry_value(value: &Value) -> Option> { + match value { + Value::String(value) => DateTime::parse_from_rfc3339(value) + .ok() + .map(|dt| dt.with_timezone(&Utc)), + Value::Number(value) => value + .as_i64() + .and_then(|timestamp| Utc.timestamp_opt(timestamp, 0).single()), + _ => None, + } +} + +fn find_expires_at(value: &Value) -> Option> { + match value { + Value::Object(map) => { + if map + .get("refresh_token") + .and_then(Value::as_str) + .is_some_and(|token| !token.is_empty()) + { + return None; + } + if let Some(expires_at) = map.get("expires_at").and_then(parse_expiry_value) { + return Some(expires_at); + } + if let Some(expires_at) = map.get("expires_on").and_then(parse_expiry_value) { + return Some(expires_at); + } + map.values().find_map(find_expires_at) + } + Value::Array(values) => values.iter().find_map(find_expires_at), + _ => None, + } +} + +#[derive(Clone, Copy)] +struct ProviderCacheSecretDefinition { + provider: &'static str, + name: &'static str, + path: &'static str, + is_directory: bool, +} + +const PROVIDER_CACHE_SECRET_DEFINITIONS: &[ProviderCacheSecretDefinition] = &[ + ProviderCacheSecretDefinition { + provider: "gemini_oauth", + name: "OAuth token", + path: "gemini_oauth/tokens.json", + is_directory: false, + }, + ProviderCacheSecretDefinition { + provider: "chatgpt_codex", + name: "OAuth token", + path: "chatgpt_codex/tokens.json", + is_directory: false, + }, + ProviderCacheSecretDefinition { + provider: "kimi_code", + name: "OAuth token", + path: "kimicode/token.json", + is_directory: false, + }, + ProviderCacheSecretDefinition { + provider: "github_copilot", + name: "OAuth token", + path: "githubcopilot", + is_directory: true, + }, + ProviderCacheSecretDefinition { + provider: "xai_oauth", + name: "OAuth token", + path: "xai_oauth/tokens.json", + is_directory: false, + }, + ProviderCacheSecretDefinition { + provider: "databricks", + name: "OAuth token", + path: "databricks/oauth", + is_directory: true, + }, + ProviderCacheSecretDefinition { + provider: "databricks_v2", + name: "OAuth token", + path: "databricks/oauth", + is_directory: true, + }, +]; + +fn provider_cache_definitions_for_display() -> Vec { + let mut seen_paths = HashSet::new(); + PROVIDER_CACHE_SECRET_DEFINITIONS + .iter() + .copied() + .filter(|definition| seen_paths.insert(definition.path)) + .collect() +} + +fn provider_cache_definition(provider: &str) -> Option { + PROVIDER_CACHE_SECRET_DEFINITIONS + .iter() + .copied() + .find(|definition| definition.provider == provider) +} + +fn provider_cache_providers_sharing_cache(provider: &str) -> Vec<&'static str> { + let Some(definition) = provider_cache_definition(provider) else { + return Vec::new(); + }; + + PROVIDER_CACHE_SECRET_DEFINITIONS + .iter() + .filter(|other| other.path == definition.path) + .map(|definition| definition.provider) + .collect() +} + +fn read_json_file(path: &std::path::Path) -> Option { + std::fs::read_to_string(path) + .ok() + .and_then(|contents| serde_json::from_str(&contents).ok()) +} + +fn collect_json_expiries(path: &std::path::Path, is_directory: bool) -> Vec> { + if !is_directory { + return read_json_file(path) + .and_then(|value| find_expires_at(&value)) + .into_iter() + .collect(); + } + + let mut expiries = Vec::new(); + let mut stack = vec![path.to_path_buf()]; + + while let Some(current) = stack.pop() { + let Ok(entries) = std::fs::read_dir(current) else { + continue; + }; + + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + stack.push(path); + continue; + } + if path.extension().and_then(|ext| ext.to_str()) != Some("json") { + continue; + } + if let Some(expires_at) = + read_json_file(&path).and_then(|value| find_expires_at(&value)) + { + expiries.push(expires_at); + } + } + } + + expiries +} + +fn provider_cache_exists(path: &std::path::Path, is_directory: bool) -> bool { + if !is_directory { + return path.is_file(); + } + + let Ok(entries) = std::fs::read_dir(path) else { + return false; + }; + + entries.flatten().any(|entry| { + let path = entry.path(); + path.is_file() || provider_cache_exists(&path, true) + }) +} + +fn provider_cache_expiry(definition: ProviderCacheSecretDefinition) -> Option> { + let path = Paths::in_config_dir(definition.path); + let expiries = collect_json_expiries(&path, definition.is_directory); + expiries.into_iter().min() +} + +fn build_provider_cache_secret( + definition: ProviderCacheSecretDefinition, + display_names: &HashMap, +) -> Option { + let path = Paths::in_config_dir(definition.path); + if !provider_cache_exists(&path, definition.is_directory) { + return None; + } + + let expires_at = provider_cache_expiry(definition); + Some(ProviderSecret { + id: format!("{}{}", PROVIDER_CACHE_ID_PREFIX, definition.provider), + provider: definition.provider.to_string(), + provider_display_name: display_names + .get(definition.provider) + .cloned() + .unwrap_or_else(|| definition.provider.to_string()), + name: definition.name.to_string(), + storage: ProviderSecretStorage::ProviderCache, + expires_at, + status: provider_secret_status(expires_at), + configured: true, + has_secret: true, + can_delete: true, + can_configure: false, + configure_provider: None, + }) +} + +fn build_huggingface_oauth_secret( + token: Option, +) -> ProviderSecret { + let expires_at = token.as_ref().and_then(|token| token.expires_at); + let has_secret = token.is_some(); + + ProviderSecret { + id: format!( + "{}{}", + PROVIDER_CACHE_ID_PREFIX, + huggingface_auth::HUGGINGFACE_PROVIDER_NAME + ), + provider: huggingface_auth::HUGGINGFACE_PROVIDER_NAME.to_string(), + provider_display_name: huggingface_auth::HUGGINGFACE_DISPLAY_NAME.to_string(), + name: huggingface_auth::HUGGINGFACE_OAUTH_TOKEN_NAME.to_string(), + storage: ProviderSecretStorage::ProviderCache, + expires_at, + status: provider_secret_status(expires_at), + configured: has_secret, + has_secret, + can_delete: has_secret, + can_configure: true, + configure_provider: Some(huggingface_auth::HUGGINGFACE_PROVIDER_NAME.to_string()), + } +} + +fn build_secret_store_secrets( + stored_secrets: &HashMap, + providers: &[(ProviderMetadata, ProviderType)], +) -> Vec { + let mut secrets = Vec::new(); + + for (metadata, _) in providers { + for config_key in metadata.config_keys.iter().filter(|key| key.secret) { + if !stored_secrets.contains_key(&config_key.name) { + continue; + } + secrets.push(ProviderSecret { + id: format!( + "{}{}:{}", + SECRET_STORE_ID_PREFIX, metadata.name, config_key.name + ), + provider: metadata.name.clone(), + provider_display_name: metadata.display_name.clone(), + name: config_key.name.clone(), + storage: ProviderSecretStorage::SecretStore, + expires_at: None, + status: ProviderSecretStatus::Unknown, + configured: true, + has_secret: true, + can_delete: true, + can_configure: false, + configure_provider: None, + }); + } + } + + secrets +} + +fn is_known_provider_secret( + providers: &[(ProviderMetadata, ProviderType)], + provider: &str, + key: &str, +) -> bool { + providers + .iter() + .filter(|(metadata, _)| metadata.name == provider) + .flat_map(|(metadata, _)| metadata.config_keys.iter()) + .any(|config_key| config_key.secret && config_key.name == key) +} + +fn unconfigure_provider(config: &Config, provider_name: &str) -> Result<(), ConfigError> { + if let Some(mut entry) = goose::config::get_provider_entry(config, provider_name) { + entry.configured = false; + goose::config::set_provider_entry(config, provider_name, &entry)?; + } + + let configured_marker = format!("{}_configured", provider_name); + config.delete(&configured_marker)?; + Ok(()) +} + +fn mark_provider_configured(config: &Config, provider_name: &str) -> Result<(), ConfigError> { + if let Some(mut entry) = goose::config::get_provider_entry(config, provider_name) { + entry.configured = true; + goose::config::set_provider_entry(config, provider_name, &entry)?; + } else { + let model = if goose::config::get_active_provider(config).as_deref() == Some(provider_name) + { + config.get_goose_model().unwrap_or_default() + } else { + String::new() + }; + goose::config::set_provider_entry( + config, + provider_name, + &goose::config::ProviderEntry { + enabled: true, + model, + configured: true, + }, + )?; + } + + Ok(()) +} + +fn parse_secret_store_id(id: &str) -> Option<(&str, &str)> { + let rest = id.strip_prefix(SECRET_STORE_ID_PREFIX)?; + let (provider, key) = rest.split_once(':')?; + Some((provider, key)) +} + +fn parse_provider_cache_id(id: &str) -> Option<&str> { + id.strip_prefix(PROVIDER_CACHE_ID_PREFIX) +} + fn is_valid_provider_name(provider_name: &str) -> bool { !provider_name.is_empty() && provider_name @@ -240,6 +634,123 @@ fn is_valid_provider_name(provider_name: &str) -> bool { .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') } +fn should_unconfigure_after_secret_delete( + provider: &str, + key: &str, + has_usable_huggingface_oauth_token: impl FnOnce() -> bool, +) -> bool { + provider == huggingface_auth::HUGGINGFACE_PROVIDER_NAME + && key == huggingface_auth::HUGGINGFACE_TOKEN_SECRET_KEY + && !has_usable_huggingface_oauth_token() +} + +#[utoipa::path( + get, + path = "/config/provider-secrets", + responses( + (status = 200, description = "Provider secrets retrieved successfully", body = ProviderSecretsResponse), + (status = 500, description = "Internal server error") + ) +)] +pub async fn list_provider_secrets() -> Result, ErrorResponse> { + let config = Config::global(); + let stored_secrets = config.all_secrets()?; + let providers = get_providers().await; + let display_names: HashMap = providers + .iter() + .map(|(metadata, _)| (metadata.name.clone(), metadata.display_name.clone())) + .collect(); + + let mut secrets = build_secret_store_secrets(&stored_secrets, &providers); + + for definition in provider_cache_definitions_for_display() { + if let Some(secret) = build_provider_cache_secret(definition, &display_names) { + if !secrets.iter().any(|existing| existing.id == secret.id) { + secrets.push(secret); + } + } + } + + let huggingface_secret = build_huggingface_oauth_secret(huggingface_auth::load_oauth_token()); + if let Some(existing) = secrets + .iter_mut() + .find(|existing| existing.id == huggingface_secret.id) + { + *existing = huggingface_secret; + } else { + secrets.push(huggingface_secret); + } + + secrets.sort_by(|a, b| { + a.provider_display_name + .cmp(&b.provider_display_name) + .then_with(|| a.name.cmp(&b.name)) + }); + + Ok(Json(ProviderSecretsResponse { secrets })) +} + +#[utoipa::path( + delete, + path = "/config/provider-secrets/{id}", + params( + ("id" = String, Path, description = "Provider secret identifier") + ), + responses( + (status = 200, description = "Provider secret deleted successfully", body = String), + (status = 400, description = "Invalid provider secret identifier"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn delete_provider_secret(Path(id): Path) -> Result, ErrorResponse> { + let config = Config::global(); + + if let Some((provider, key)) = parse_secret_store_id(&id) { + let providers = get_providers().await; + if !is_known_provider_secret(&providers, provider, key) { + return Err(ErrorResponse::bad_request(format!( + "Invalid provider secret id: '{}'", + id + ))); + } + + config.delete_secret(key)?; + if should_unconfigure_after_secret_delete(provider, key, || { + huggingface_auth::has_configured_token().unwrap_or(false) + }) { + unconfigure_provider(config, provider)?; + } + return Ok(Json(format!("Deleted provider secret {}", id))); + } + + if let Some(provider) = parse_provider_cache_id(&id) { + if provider == huggingface_auth::HUGGINGFACE_PROVIDER_NAME { + huggingface_auth::clear_oauth_token()?; + unconfigure_provider(config, provider)?; + return Ok(Json(format!("Deleted provider secret {}", id))); + } + + let cache_definition = provider_cache_definition(provider); + + if !is_valid_provider_name(provider) || cache_definition.is_none() { + return Err(ErrorResponse::bad_request(format!( + "Invalid provider name: '{}'", + provider + ))); + } + goose::providers::cleanup_provider(provider).await?; + for shared_provider in provider_cache_providers_sharing_cache(provider) { + unconfigure_provider(config, shared_provider)?; + } + return Ok(Json(format!("Deleted provider secret {}", id))); + } + + Err(ErrorResponse::bad_request(format!( + "Invalid provider secret id: '{}'", + id + ))) +} + #[utoipa::path( post, path = "/config/read", @@ -284,6 +795,72 @@ pub async fn read_config( Ok(Json(response_value)) } +#[utoipa::path( + get, + path = "/config/extensions", + responses( + (status = 200, description = "All extensions retrieved successfully", body = ExtensionResponse), + (status = 500, description = "Internal server error") + ) +)] +pub async fn get_extensions() -> Result, ErrorResponse> { + let extensions = goose::config::get_all_extensions() + .into_iter() + .filter(|ext| !goose::agents::extension_manager::is_hidden_extension(&ext.config.name())) + .collect(); + let warnings = goose::config::get_warnings(); + Ok(Json(ExtensionResponse { + extensions, + warnings, + })) +} + +#[utoipa::path( + post, + path = "/config/extensions", + request_body = ExtensionQuery, + responses( + (status = 200, description = "Extension added or updated successfully", body = String), + (status = 400, description = "Invalid request"), + (status = 422, description = "Could not serialize config.yaml"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn add_extension( + Json(extension_query): Json, +) -> Result, ErrorResponse> { + let extensions = goose::config::get_all_extensions(); + let key = goose::config::extensions::name_to_key(&extension_query.name); + + let is_update = extensions.iter().any(|e| e.config.key() == key); + + goose::config::set_extension(ExtensionEntry { + enabled: extension_query.enabled, + config: extension_query.config, + }); + + if is_update { + Ok(Json(format!("Updated extension {}", extension_query.name))) + } else { + Ok(Json(format!("Added extension {}", extension_query.name))) + } +} + +#[utoipa::path( + delete, + path = "/config/extensions/{name}", + responses( + (status = 200, description = "Extension removed successfully", body = String), + (status = 404, description = "Extension not found"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn remove_extension(Path(name): Path) -> Result, ErrorResponse> { + let key = goose::config::extensions::name_to_key(&name); + goose::config::remove_extension(&key); + Ok(Json(format!("Removed extension {}", name))) +} + #[utoipa::path( get, path = "/config", @@ -852,6 +1429,17 @@ pub async fn configure_provider_oauth( ))); } + if provider_name == huggingface_auth::HUGGINGFACE_PROVIDER_NAME { + huggingface_auth::configure_oauth().await.map_err(|e| { + ErrorResponse::bad_request(format!( + "OAuth configuration failed for provider '{}': {}", + provider_name, e + )) + })?; + mark_provider_configured(goose::config::Config::global(), &provider_name)?; + return Ok(Json("OAuth configuration completed".to_string())); + } + let temp_model = ModelConfig::new("temp") .map_err(|e| { ErrorResponse::bad_request(format!("Failed to create temporary model config: {}", e)) @@ -875,29 +1463,7 @@ pub async fn configure_provider_oauth( )) })?; - // Mark the provider as configured after successful OAuth - let config = goose::config::Config::global(); - if let Some(mut entry) = goose::config::get_provider_entry(config, &provider_name) { - entry.configured = true; - goose::config::set_provider_entry(config, &provider_name, &entry)?; - } else { - let model = if goose::config::get_active_provider(config).as_deref() - == Some(provider_name.as_str()) - { - config.get_goose_model().unwrap_or_default() - } else { - String::new() - }; - goose::config::set_provider_entry( - config, - &provider_name, - &goose::config::ProviderEntry { - enabled: true, - model, - configured: true, - }, - )?; - } + mark_provider_configured(goose::config::Config::global(), &provider_name)?; Ok(Json("OAuth configuration completed".to_string())) } @@ -908,6 +1474,14 @@ pub fn routes(state: Arc) -> Router { .route("/config/upsert", post(upsert_config)) .route("/config/remove", post(remove_config)) .route("/config/read", post(read_config)) + .route("/config/provider-secrets", get(list_provider_secrets)) + .route( + "/config/provider-secrets/{id}", + delete(delete_provider_secret), + ) + .route("/config/extensions", get(get_extensions)) + .route("/config/extensions", post(add_extension)) + .route("/config/extensions/{name}", delete(remove_extension)) .route("/config/providers", get(providers)) .route("/config/providers/{name}/models", get(get_provider_models)) .route( @@ -947,4 +1521,281 @@ pub fn routes(state: Arc) -> Router { } #[cfg(test)] -mod tests {} +mod tests { + use super::*; + use goose::config::ProviderEntry; + use goose::providers::base::ConfigKey; + use serde_json::json; + + fn new_test_config() -> Config { + let unique = format!( + "goose-server-config-test-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + ); + let config_path = std::env::temp_dir().join(format!("{unique}-config.yaml")); + let secrets_path = std::env::temp_dir().join(format!("{unique}-secrets.yaml")); + Config::new_with_file_secrets(config_path, secrets_path).unwrap() + } + + #[test] + fn secret_store_listing_only_includes_provider_secret_keys() { + let metadata = ProviderMetadata::new( + "openai", + "OpenAI", + "OpenAI provider", + "gpt-4o", + vec![], + "https://example.com", + vec![ + ConfigKey::new("OPENAI_API_KEY", true, true, None, true), + ConfigKey::new("OPENAI_HOST", false, false, None, false), + ], + ); + let providers = vec![(metadata, ProviderType::Builtin)]; + let stored_secrets = HashMap::from([ + ( + "OPENAI_API_KEY".to_string(), + Value::String("secret-value".to_string()), + ), + ( + "UNRELATED_SECRET".to_string(), + Value::String("other-secret".to_string()), + ), + ( + "OPENAI_HOST".to_string(), + Value::String("https://api.openai.com".to_string()), + ), + ]); + + let secrets = build_secret_store_secrets(&stored_secrets, &providers); + + assert_eq!(secrets.len(), 1); + assert_eq!(secrets[0].id, "secret_store:openai:OPENAI_API_KEY"); + assert_eq!(secrets[0].provider_display_name, "OpenAI"); + assert_eq!(secrets[0].name, "OPENAI_API_KEY"); + assert_eq!(secrets[0].storage, ProviderSecretStorage::SecretStore); + assert_eq!(secrets[0].status, ProviderSecretStatus::Unknown); + } + + #[test] + fn provider_secret_delete_validation_requires_provider_secret_key() { + let metadata = ProviderMetadata::new( + "openai", + "OpenAI", + "OpenAI provider", + "gpt-4o", + vec![], + "https://example.com", + vec![ + ConfigKey::new("OPENAI_API_KEY", true, true, None, true), + ConfigKey::new("OPENAI_HOST", false, false, None, false), + ], + ); + let providers = vec![(metadata, ProviderType::Builtin)]; + + assert!(is_known_provider_secret( + &providers, + "openai", + "OPENAI_API_KEY" + )); + assert!(!is_known_provider_secret( + &providers, + "openai", + "OPENAI_HOST" + )); + assert!(!is_known_provider_secret( + &providers, + "openai", + "UNRELATED_SECRET" + )); + assert!(!is_known_provider_secret( + &providers, + "anthropic", + "OPENAI_API_KEY" + )); + } + + #[test] + fn expiry_extraction_handles_nested_rfc3339_values() { + let expires_at = Utc::now() + chrono::Duration::hours(1); + let value = json!({ + "project_id": "project", + "token": { + "access_token": "secret", + "expires_at": expires_at.to_rfc3339(), + } + }); + + let parsed = find_expires_at(&value).expect("expected expiry"); + + assert_eq!(parsed.timestamp(), expires_at.timestamp()); + assert_eq!( + provider_secret_status(Some(parsed)), + ProviderSecretStatus::Valid + ); + } + + #[test] + fn expiry_extraction_ignores_refreshable_access_tokens() { + let expires_at = Utc::now() - chrono::Duration::hours(1); + let value = json!({ + "access_token": "access", + "refresh_token": "refresh", + "expires_at": expires_at.to_rfc3339(), + }); + + assert_eq!(find_expires_at(&value), None); + } + + #[test] + fn expiry_extraction_handles_expired_unix_timestamps() { + let value = json!({ + "info": { + "expires_at": 1 + } + }); + + let parsed = find_expires_at(&value).expect("expected expiry"); + + assert_eq!(parsed.timestamp(), 1); + assert_eq!( + provider_secret_status(Some(parsed)), + ProviderSecretStatus::Expired + ); + } + + #[test] + fn provider_secret_ids_parse_expected_prefixes() { + assert_eq!( + parse_secret_store_id("secret_store:openai:OPENAI_API_KEY"), + Some(("openai", "OPENAI_API_KEY")) + ); + assert_eq!( + parse_provider_cache_id("provider_cache:gemini_oauth"), + Some("gemini_oauth") + ); + assert_eq!(parse_secret_store_id("provider_cache:openai"), None); + assert_eq!(parse_provider_cache_id("secret_store:openai:key"), None); + } + + #[test] + fn shared_databricks_cache_is_displayed_once() { + let databricks_definitions: Vec<_> = provider_cache_definitions_for_display() + .into_iter() + .filter(|definition| definition.path == "databricks/oauth") + .collect(); + + assert_eq!(databricks_definitions.len(), 1); + assert_eq!(databricks_definitions[0].provider, "databricks"); + } + + #[test] + fn shared_databricks_cache_unconfigures_both_providers() { + assert_eq!( + provider_cache_providers_sharing_cache("databricks"), + vec!["databricks", "databricks_v2"] + ); + assert_eq!( + provider_cache_providers_sharing_cache("databricks_v2"), + vec!["databricks", "databricks_v2"] + ); + } + + #[test] + fn unconfigure_provider_clears_structured_entry() { + let config = new_test_config(); + goose::config::set_provider_entry( + &config, + "huggingface", + &ProviderEntry { + enabled: true, + model: "Qwen/Qwen3-Coder-480B-A35B-Instruct".to_string(), + configured: true, + }, + ) + .unwrap(); + + unconfigure_provider(&config, "huggingface").unwrap(); + + let entry = goose::config::get_provider_entry(&config, "huggingface").unwrap(); + assert!(entry.enabled); + assert_eq!(entry.model, "Qwen/Qwen3-Coder-480B-A35B-Instruct"); + assert!(!entry.configured); + } + + #[test] + fn unconfigure_provider_deletes_legacy_configured_marker() { + let config = new_test_config(); + config.set_param("huggingface_configured", true).unwrap(); + + unconfigure_provider(&config, "huggingface").unwrap(); + + assert!(config.get_param::("huggingface_configured").is_err()); + } + + #[test] + fn deleting_huggingface_token_unconfigures_without_oauth() { + assert!(should_unconfigure_after_secret_delete( + "huggingface", + "HF_TOKEN", + || false + )); + } + + #[test] + fn deleting_huggingface_token_keeps_configured_with_oauth() { + assert!(!should_unconfigure_after_secret_delete( + "huggingface", + "HF_TOKEN", + || true + )); + } + + #[test] + fn deleting_other_provider_secret_does_not_unconfigure_huggingface() { + assert!(!should_unconfigure_after_secret_delete( + "openai", + "OPENAI_API_KEY", + || false + )); + } + + #[test] + fn huggingface_oauth_secret_is_permanent_without_token() { + let secret = build_huggingface_oauth_secret(None); + + assert_eq!(secret.id, "provider_cache:huggingface"); + assert_eq!(secret.provider_display_name, "Hugging Face"); + assert_eq!(secret.name, "OAuth token"); + assert_eq!(secret.storage, ProviderSecretStorage::ProviderCache); + assert_eq!(secret.status, ProviderSecretStatus::Unknown); + assert!(!secret.configured); + assert!(!secret.has_secret); + assert!(!secret.can_delete); + assert!(secret.can_configure); + assert_eq!(secret.configure_provider.as_deref(), Some("huggingface")); + } + + #[test] + fn huggingface_oauth_secret_reports_cached_token_metadata() { + let expires_at = Utc::now() + chrono::Duration::hours(1); + let secret = build_huggingface_oauth_secret(Some(huggingface_auth::HuggingFaceTokenData { + access_token: "hidden".to_string(), + refresh_token: None, + expires_at: Some(expires_at), + })); + + assert_eq!( + secret.expires_at.map(|value| value.timestamp()), + Some(expires_at.timestamp()) + ); + assert_eq!(secret.status, ProviderSecretStatus::Valid); + assert!(secret.configured); + assert!(secret.has_secret); + assert!(secret.can_delete); + } +} diff --git a/crates/goose-server/src/routes/local_inference.rs b/crates/goose-server/src/routes/local_inference.rs index 8ee9bc4eebcf..c53733f2c30d 100644 --- a/crates/goose-server/src/routes/local_inference.rs +++ b/crates/goose-server/src/routes/local_inference.rs @@ -11,6 +11,7 @@ use axum::{ use futures::future::join_all; use goose::config::paths::Paths; use goose::download_manager::{get_download_manager, DownloadProgress}; +use goose::providers::huggingface_auth; use goose::providers::local_inference::hf_models::{self, HfModelInfo, HfQuantVariant}; use goose::providers::local_inference::{ available_inference_memory_bytes, builtin_chat_template_names, @@ -251,6 +252,7 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { // Auto-download mmproj files for models that are already downloaded. // Deduplicate by path since multiple quants share one mmproj file. let dm = get_download_manager(); + let hf_token = huggingface_auth::resolve_token_async().await.ok().flatten(); let mut started_paths = std::collections::HashSet::new(); for (model_id, url, path) in mmproj_downloads_needed { if !path.exists() && started_paths.insert(path.clone()) { @@ -260,7 +262,16 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { .is_some_and(|p| p.status == goose::download_manager::DownloadStatus::Downloading); if !dominated_by_active { tracing::info!(model_id = %model_id, "Auto-downloading vision encoder for existing model"); - if let Err(e) = dm.download_model(download_id, url, path, None).await { + if let Err(e) = dm + .download_model_with_bearer_token( + download_id, + url, + path, + hf_token.clone(), + None, + ) + .await + { tracing::warn!(model_id = %model_id, error = %e, "Failed to start mmproj download"); } } @@ -471,6 +482,7 @@ pub async fn download_hf_model( let (_repo, resolved) = resolve_model_spec_full(&req.spec) .await .map_err(|e| ErrorResponse::bad_request(format!("Invalid spec: {}", e)))?; + let hf_token = huggingface_auth::resolve_token_async().await.ok().flatten(); let model_id = model_id_from_repo(&repo_id, &quantization); let models_dir = Paths::in_data_dir("models"); @@ -545,10 +557,11 @@ pub async fn download_hf_model( .map(|f| (f.download_url.clone(), models_dir.join(&f.filename))) .collect(); - dm.download_model_sharded( + dm.download_model_sharded_with_bearer_token( format!("{}-model", model_id), all_files, resolved.total_size, + hf_token.clone(), None, ) .await @@ -556,10 +569,11 @@ pub async fn download_hf_model( if let Some((mmproj_path, mmproj_url)) = mmproj_path { if !mmproj_path.exists() { - dm.download_model( + dm.download_model_with_bearer_token( format!("{}-mmproj", model_id), mmproj_url, mmproj_path, + hf_token, None, ) .await diff --git a/crates/goose-server/src/routes/recipe_utils.rs b/crates/goose-server/src/routes/recipe_utils.rs index 800283ddb46a..085fe3542088 100644 --- a/crates/goose-server/src/routes/recipe_utils.rs +++ b/crates/goose-server/src/routes/recipe_utils.rs @@ -1,7 +1,4 @@ use std::collections::HashMap; -use std::fs; -use std::hash::DefaultHasher; -use std::hash::{Hash, Hasher}; use std::path::PathBuf; use std::sync::Arc; @@ -10,10 +7,10 @@ use crate::state::AppState; use anyhow::Result; use axum::http::StatusCode; use goose::agents::Agent; -use goose::recipe::build_recipe::{ - build_recipe_from_template, resolve_sub_recipe_path, RecipeError, -}; -use goose::recipe::local_recipes::{get_recipe_library_dir, list_local_recipes}; +use goose::recipe::build_recipe::{build_recipe_from_template, RecipeError}; +use goose::recipe::local_recipes::get_recipe_library_dir; +pub use goose::recipe::manifest::short_id_from_path; +use goose::recipe::manifest::{list_recipe_file_manifests, load_recipe_from_path}; use goose::recipe::validate_recipe::validate_recipe_template_from_content; use goose::recipe::Recipe; use serde::Serialize; @@ -36,44 +33,18 @@ pub struct RecipeManifest { pub slash_command: Option, } -pub fn short_id_from_path(path: &str) -> String { - let mut hasher = DefaultHasher::new(); - path.hash(&mut hasher); - let h = hasher.finish(); - format!("{:016x}", h) -} - pub fn get_all_recipes_manifests() -> Result> { - let recipes_with_path = list_local_recipes()?; - let mut recipe_manifests_with_path = Vec::new(); - for (file_path, mut recipe) in recipes_with_path { - let Ok(last_modified) = fs::metadata(file_path.clone()) - .map(|m| chrono::DateTime::::from(m.modified().unwrap()).to_rfc3339()) - else { - continue; - }; - - if let Some(recipe_dir) = file_path.parent() { - if let Some(ref mut sub_recipes) = recipe.sub_recipes { - for sr in sub_recipes.iter_mut() { - if let Ok(resolved) = resolve_sub_recipe_path(&sr.path, recipe_dir) { - sr.path = resolved; - } - } - } - } - - let manifest_with_path = RecipeManifest { - id: short_id_from_path(file_path.to_string_lossy().as_ref()), - recipe, - file_path, - last_modified, + let recipe_manifests_with_path = list_recipe_file_manifests()? + .into_iter() + .map(|manifest| RecipeManifest { + id: manifest.id, + recipe: manifest.recipe, + file_path: manifest.file_path, + last_modified: manifest.last_modified, schedule_cron: None, slash_command: None, - }; - recipe_manifests_with_path.push(manifest_with_path); - } - recipe_manifests_with_path.sort_by(|a, b| b.last_modified.cmp(&a.last_modified)); + }) + .collect(); Ok(recipe_manifests_with_path) } @@ -138,22 +109,10 @@ pub async fn get_recipe_file_path_by_id( pub async fn load_recipe_by_id(state: &AppState, id: &str) -> Result { let path = get_recipe_file_path_by_id(state, id).await?; - let mut recipe = Recipe::from_file_path(&path).map_err(|err| ErrorResponse { + load_recipe_from_path(&path).map_err(|err| ErrorResponse { message: format!("Failed to load recipe: {}", err), status: StatusCode::INTERNAL_SERVER_ERROR, - })?; - - if let Some(recipe_dir) = path.parent() { - if let Some(ref mut sub_recipes) = recipe.sub_recipes { - for sr in sub_recipes.iter_mut() { - if let Ok(resolved) = resolve_sub_recipe_path(&sr.path, recipe_dir) { - sr.path = resolved; - } - } - } - } - - Ok(recipe) + }) } pub async fn build_recipe_with_parameter_values( diff --git a/crates/goose-server/src/routes/utils.rs b/crates/goose-server/src/routes/utils.rs index 2c4a7d2a70d2..51cd8710b0ad 100644 --- a/crates/goose-server/src/routes/utils.rs +++ b/crates/goose-server/src/routes/utils.rs @@ -1,6 +1,7 @@ -use goose::config::declarative_providers::load_provider; +use goose::config::declarative_providers::{load_provider, LoadedProvider}; use goose::config::Config; use goose::providers::base::{ConfigKey, ProviderMetadata, ProviderType}; +use goose::providers::huggingface_auth; use serde::{Deserialize, Serialize}; use std::env; use std::error::Error; @@ -92,15 +93,37 @@ pub fn inspect_keys( } pub fn check_provider_configured(metadata: &ProviderMetadata, provider_type: ProviderType) -> bool { - let config = Config::global(); + check_provider_configured_with_huggingface_oauth(metadata, provider_type, || { + huggingface_auth::has_usable_or_refreshable_oauth_token() + }) +} +fn check_provider_configured_with_huggingface_oauth( + metadata: &ProviderMetadata, + provider_type: ProviderType, + has_usable_huggingface_oauth_token: impl Fn() -> bool, +) -> bool { // Special override if metadata.name == "local" { return true; } + if accepts_huggingface_oauth(metadata, None, &has_usable_huggingface_oauth_token) { + return true; + } + + let config = Config::global(); + if provider_type == ProviderType::Custom || provider_type == ProviderType::Declarative { if let Ok(loaded_provider) = load_provider(metadata.name.as_str()) { + if accepts_huggingface_oauth( + metadata, + Some(&loaded_provider), + &has_usable_huggingface_oauth_token, + ) { + return true; + } + if !loaded_provider.config.requires_auth { return true; } @@ -211,3 +234,89 @@ pub fn check_provider_configured(metadata: &ProviderMetadata, provider_type: Pro is_set_in_env || is_set_in_config }) } + +fn accepts_huggingface_oauth( + metadata: &ProviderMetadata, + loaded_provider: Option<&LoadedProvider>, + has_usable_huggingface_oauth_token: &impl Fn() -> bool, +) -> bool { + let is_huggingface_provider = metadata.name == huggingface_auth::HUGGINGFACE_PROVIDER_NAME + || loaded_provider.is_some_and(|provider| { + provider.config.catalog_provider_id.as_deref() + == Some(huggingface_auth::HUGGINGFACE_PROVIDER_NAME) + }); + + is_huggingface_provider && has_usable_huggingface_oauth_token() +} + +#[cfg(test)] +mod tests { + use super::*; + use goose::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine}; + use goose::providers::base::ModelInfo; + + fn huggingface_metadata() -> ProviderMetadata { + ProviderMetadata::new( + huggingface_auth::HUGGINGFACE_PROVIDER_NAME, + huggingface_auth::HUGGINGFACE_DISPLAY_NAME, + "Hugging Face provider", + "Qwen/Qwen3-Coder-480B-A35B-Instruct", + vec![], + "https://huggingface.co/docs/inference-providers", + vec![ConfigKey::new( + huggingface_auth::HUGGINGFACE_TOKEN_SECRET_KEY, + true, + true, + None, + true, + )], + ) + } + + #[test] + fn huggingface_oauth_token_counts_as_configured_without_hf_token() { + assert!(check_provider_configured_with_huggingface_oauth( + &huggingface_metadata(), + ProviderType::Builtin, + || true, + )); + } + + #[test] + fn huggingface_catalog_provider_oauth_counts_as_configured() { + let mut metadata = huggingface_metadata(); + metadata.name = "custom-huggingface".to_string(); + + let loaded_provider = LoadedProvider { + config: DeclarativeProviderConfig { + name: metadata.name.clone(), + engine: ProviderEngine::OpenAI, + display_name: "Custom Hugging Face".to_string(), + description: None, + api_key_env: String::new(), + base_url: "https://router.huggingface.co/v1".to_string(), + models: vec![ModelInfo::new("test-model", 128_000)], + headers: None, + timeout_seconds: None, + supports_streaming: None, + requires_auth: true, + catalog_provider_id: Some(huggingface_auth::HUGGINGFACE_PROVIDER_NAME.to_string()), + base_path: None, + env_vars: None, + dynamic_models: None, + skip_canonical_filtering: false, + model_doc_link: None, + setup_steps: vec![], + fast_model: None, + preserves_thinking: false, + }, + is_editable: false, + }; + + assert!(accepts_huggingface_oauth( + &metadata, + Some(&loaded_provider), + &|| true, + )); + } +} diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index fccd7dbbc6d8..909dcb60836b 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -137,13 +137,13 @@ sqlx = { version = "0.8.5", default-features = false, features = [ ] } # For Bedrock provider (optional, behind "aws-providers" feature) -aws-config = { version = "1.6", default-features = false, features = ["credentials-process", "rt-tokio", "sso", "behavior-version-latest"], optional = true } +aws-config = { version = "1.8", default-features = false, features = ["credentials-process", "rt-tokio", "sso", "behavior-version-latest"], optional = true } aws-smithy-types = { version = "1.3.4", default-features = false, features = ["rt-tokio"], optional = true } -aws-sdk-bedrockruntime = { version = "1.119", default-features = false, features = ["rt-tokio"], optional = true } +aws-sdk-bedrockruntime = { version = "1.132", default-features = false, features = ["rt-tokio"], optional = true } smithy-transport-reqwest = { version = "0.1", default-features = false, features = ["http2", "system-proxy"], optional = true } # For SageMaker TGI provider (optional, behind "aws-providers" feature) -aws-sdk-sagemakerruntime = { version = "1.64", default-features = false, features = ["rt-tokio"], optional = true } +aws-sdk-sagemakerruntime = { version = "1.104", default-features = false, features = ["rt-tokio"], optional = true } # For GCP Vertex AI provider auth jsonwebtoken = { version = "10.2", default-features = false, features = ["use_pem"] } @@ -152,8 +152,6 @@ blake3 = { version = "1", default-features = false, features = ["std"] } fs2 = { workspace = true } tokio-stream = { workspace = true, features = ["io-util"] } tempfile = { workspace = true } -dashmap = { version = "6", default-features = false } -ahash = { version = "0.8.11", default-features = false, features = ["std"] } tokio-util = { workspace = true, features = ["compat"] } agent-client-protocol-schema = { workspace = true } agent-client-protocol = { workspace = true, features = ["unstable"] } @@ -164,7 +162,7 @@ candle-core = { workspace = true, optional = true } candle-nn = { workspace = true, optional = true } candle-transformers = { version = "0.10", default-features = false, optional = true } byteorder = { version = "1.5", default-features = false, features = ["std"], optional = true } -tokenizers = { version = "0.21", default-features = false, features = ["onig"], optional = true } +tokenizers = { version = "0.22", default-features = false, features = ["onig"], optional = true } symphonia = { version = "0.5", default-features = false, features = ["aac", "adpcm", "alac", "isomp4", "mkv", "mp3", "pcm", "vorbis", "wav"], optional = true } rubato = { version = "0.16", default-features = false, optional = true } zip = { workspace = true } @@ -196,7 +194,7 @@ pastey = { version = "0.2", default-features = false } shell-words = { workspace = true } pem = { version = "3.0.2", default-features = false, features = ["std"], optional = true } pkcs1 = { version = "0.7.5", default-features = false, features = ["pkcs8", "std"], optional = true } -pkcs8 = { version = "0.10.2", default-features = false, features = ["alloc", "std"], optional = true } +pkcs8 = { version = "0.11.0", default-features = false, features = ["alloc", "std"], optional = true } sec1 = { version = "0.7", default-features = false, features = ["der", "pkcs8", "std"], optional = true } goose-acp-macros = { path = "../goose-acp-macros", default-features = false } tower-http = { workspace = true, features = ["cors"] } @@ -233,7 +231,7 @@ libc = { version = "0.2.182", default-features = false, features = ["std"] } [dev-dependencies] serial_test = { workspace = true } -mockall = { version = "0.13", default-features = false } +mockall = { version = "0.14", default-features = false } wiremock = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true, features = ["compat"] } @@ -248,7 +246,7 @@ bytes = { workspace = true } http = { workspace = true } goose-mcp = { path = "../goose-mcp", default-features = false } insta = { version = "1", default-features = false } -dtor = { version = "1.0.3", default-features = false, features = ["proc_macro"] } +dtor = { version = "1.0.5", default-features = false, features = ["proc_macro"] } [[example]] name = "agent" diff --git a/crates/goose/acp-meta.json b/crates/goose/acp-meta.json index e1d6679ef976..e72f437015fd 100644 --- a/crates/goose/acp-meta.json +++ b/crates/goose/acp-meta.json @@ -185,6 +185,11 @@ "requestType": "ImportSessionRequest_unstable", "responseType": "ImportSessionResponse_unstable" }, + { + "method": "_goose/unstable/elicitation/respond", + "requestType": "ElicitationRespondRequest_unstable", + "responseType": "EmptyResponse" + }, { "method": "_goose/unstable/session/project/update", "requestType": "UpdateSessionProjectRequest_unstable", @@ -285,5 +290,11 @@ "requestType": "DictationModelSelectRequest_unstable", "responseType": "EmptyResponse" } + ], + "notifications": [ + { + "method": "_goose/unstable/session/update", + "paramsType": "GooseSessionNotification_unstable" + } ] } diff --git a/crates/goose/acp-schema.json b/crates/goose/acp-schema.json index 9afa77e6c4a0..5980ff34e68d 100644 --- a/crates/goose/acp-schema.json +++ b/crates/goose/acp-schema.json @@ -1887,6 +1887,27 @@ "x-side": "agent", "x-method": "_goose/unstable/session/import" }, + "ElicitationRespondRequest_unstable": { + "type": "object", + "properties": { + "sessionId": { + "type": "string" + }, + "elicitationId": { + "type": "string" + }, + "userData": { + "default": null + } + }, + "required": [ + "sessionId", + "elicitationId" + ], + "description": "Submit a response for a pending MCP elicitation in an active session.", + "x-side": "agent", + "x-method": "_goose/unstable/elicitation/respond" + }, "UpdateSessionProjectRequest_unstable": { "type": "object", "properties": { @@ -2651,6 +2672,209 @@ "x-side": "agent", "x-method": "_goose/unstable/dictation/models/select" }, + "GooseSessionNotification_unstable": { + "type": "object", + "properties": { + "sessionId": { + "type": "string" + }, + "update": { + "$ref": "#/$defs/GooseSessionUpdate" + } + }, + "required": [ + "sessionId", + "update" + ], + "description": "Goose-custom session update notification — a parallel to ACP's\n`session/update` carrying goose-specific update variants.", + "x-side": "agent", + "x-method": "_goose/unstable/session/update" + }, + "GooseSessionUpdate": { + "oneOf": [ + { + "$ref": "#/$defs/SessionUsageUpdate", + "type": "object", + "properties": { + "sessionUpdate": { + "type": "string", + "const": "usage_update" + } + }, + "required": [ + "sessionUpdate" + ] + }, + { + "$ref": "#/$defs/StatusMessageUpdate", + "type": "object", + "properties": { + "sessionUpdate": { + "type": "string", + "const": "status_message" + } + }, + "required": [ + "sessionUpdate" + ] + }, + { + "$ref": "#/$defs/InteractionUpdate", + "type": "object", + "properties": { + "sessionUpdate": { + "type": "string", + "const": "interaction_update" + } + }, + "required": [ + "sessionUpdate" + ] + } + ], + "description": "Discriminated union of goose-specific session update payloads.\nVariant tag matches ACP's convention (`sessionUpdate: \"\"`).\n\n`discriminator.mapping` is what makes TS codegen (`@hey-api/openapi-ts`)\nemit the correct snake_case tag value even when this enum has a single\nvariant. Add a mapping entry per variant.", + "discriminator": { + "propertyName": "sessionUpdate", + "mapping": { + "usage_update": "#/$defs/SessionUsageUpdate", + "status_message": "#/$defs/StatusMessageUpdate", + "interaction_update": "#/$defs/InteractionUpdate" + } + } + }, + "SessionUsageUpdate": { + "type": "object", + "properties": { + "used": { + "type": "integer", + "minimum": 0 + }, + "contextLimit": { + "type": "integer", + "minimum": 0 + }, + "accumulatedInputTokens": { + "type": "integer", + "minimum": 0 + }, + "accumulatedOutputTokens": { + "type": "integer", + "minimum": 0 + }, + "accumulatedCost": { + "type": [ + "number", + "null" + ], + "format": "double" + } + }, + "required": [ + "used", + "contextLimit", + "accumulatedInputTokens", + "accumulatedOutputTokens" + ], + "description": "Streaming context-window usage update for a session." + }, + "StatusMessage": { + "oneOf": [ + { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "type": { + "type": "string", + "const": "notice" + } + }, + "required": [ + "type", + "message" + ] + }, + { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "type": { + "type": "string", + "const": "progress" + } + }, + "required": [ + "type", + "message" + ] + } + ] + }, + "StatusMessageUpdate": { + "type": "object", + "properties": { + "status": { + "$ref": "#/$defs/StatusMessage" + } + }, + "required": [ + "status" + ], + "description": "Live UI/session status. This is not conversation transcript content, and\nshould not be persisted or replayed as history." + }, + "Interaction": { + "oneOf": [ + { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/InteractionState" + }, + "message": { + "type": [ + "string", + "null" + ] + }, + "requestedSchema": {}, + "type": { + "type": "string", + "const": "elicitation" + } + }, + "required": [ + "type", + "id", + "state" + ] + } + ] + }, + "InteractionState": { + "type": "string", + "enum": [ + "pending", + "submitted" + ] + }, + "InteractionUpdate": { + "type": "object", + "properties": { + "interaction": { + "$ref": "#/$defs/Interaction" + }, + "_meta": {} + }, + "required": [ + "interaction" + ] + }, "ExtRequest": { "properties": { "id": { @@ -2996,6 +3220,15 @@ "description": "Params for _goose/unstable/session/import", "title": "ImportSessionRequest_unstable" }, + { + "allOf": [ + { + "$ref": "#/$defs/ElicitationRespondRequest_unstable" + } + ], + "description": "Params for _goose/unstable/elicitation/respond", + "title": "ElicitationRespondRequest_unstable" + }, { "allOf": [ { @@ -3523,6 +3756,42 @@ } ], "x-docs-ignore": true + }, + "ExtNotification": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "anyOf": [ + { + "anyOf": [ + { + "allOf": [ + { + "$ref": "#/$defs/GooseSessionNotification_unstable" + } + ], + "description": "Params for _goose/unstable/session/update", + "title": "GooseSessionNotification_unstable" + } + ] + }, + { + "description": "Untyped params", + "type": [ + "object", + "null" + ] + } + ] + } + }, + "required": [ + "method" + ], + "type": "object", + "x-docs-ignore": true } }, "anyOf": [ @@ -3543,6 +3812,15 @@ ], "description": "Extension response (agent → client)", "title": "Response" + }, + { + "allOf": [ + { + "$ref": "#/$defs/ExtNotification" + } + ], + "description": "Extension notification (agent → client, fire-and-forget)", + "title": "Notification" } ] } diff --git a/crates/goose/src/acp/mod.rs b/crates/goose/src/acp/mod.rs index 594c14cac10e..1a0396a47714 100644 --- a/crates/goose/src/acp/mod.rs +++ b/crates/goose/src/acp/mod.rs @@ -3,12 +3,14 @@ mod common; pub(crate) mod fs; mod mcp_app_proxy; mod provider; +mod response_builder; pub mod server; pub mod server_factory; pub(crate) mod tools; pub mod transport; pub use common::{map_permission_response, PermissionDecision}; +pub use goose_sdk::custom_notifications; pub use goose_sdk::custom_requests; pub use provider::{ extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig, ACP_CURRENT_MODEL, diff --git a/crates/goose/src/acp/response_builder.rs b/crates/goose/src/acp/response_builder.rs new file mode 100644 index 000000000000..7a93bce5580e --- /dev/null +++ b/crates/goose/src/acp/response_builder.rs @@ -0,0 +1,400 @@ +use crate::config::GooseMode; +use crate::providers::inventory::{ProviderInventoryEntry, ProviderInventoryService}; +use crate::session::Session; +use agent_client_protocol::schema::{ + AvailableCommand, AvailableCommandInput, AvailableCommandsUpdate, ModelId, ModelInfo, + SessionConfigOption, SessionConfigOptionCategory, SessionConfigSelectOption, SessionId, + SessionMode, SessionModeId, SessionModeState, SessionModelState, SessionNotification, + SessionUpdate, UnstructuredCommandInput, +}; +use agent_client_protocol::{Client, ConnectionTo}; +use strum::{EnumMessage, VariantNames}; + +use super::server::{build_usage_updates, DEFAULT_PROVIDER_ID, DEFAULT_PROVIDER_LABEL}; + +pub(super) fn session_provider_selection(session: &Session) -> &str { + session + .provider_name + .as_deref() + .unwrap_or(DEFAULT_PROVIDER_ID) +} + +pub(super) fn build_model_state( + current_model: &str, + inventory: &ProviderInventoryEntry, +) -> SessionModelState { + let mut available_models = inventory + .models + .iter() + .map(|model| ModelInfo::new(ModelId::new(model.id.as_str()), model.name.as_str())) + .collect::>(); + if !available_models + .iter() + .any(|model| model.model_id.0.as_ref() == current_model) + { + available_models.insert( + 0, + ModelInfo::new(ModelId::new(current_model), current_model), + ); + } + SessionModelState::new(ModelId::new(current_model), available_models) +} + +struct ProviderOptionEntry { + id: String, + label: String, +} + +async fn list_provider_entries(current_provider: Option<&str>) -> Vec { + let mut providers = crate::providers::providers() + .await + .into_iter() + .map(|(metadata, _)| ProviderOptionEntry { + id: metadata.name, + label: metadata.display_name, + }) + .collect::>(); + providers.sort_by(|left, right| left.id.cmp(&right.id)); + providers.dedup_by(|left, right| left.id == right.id); + + if let Some(current_provider) = current_provider { + if current_provider != DEFAULT_PROVIDER_ID + && !providers + .iter() + .any(|provider| provider.id == current_provider) + { + providers.push(ProviderOptionEntry { + id: current_provider.to_string(), + label: current_provider.to_string(), + }); + providers.sort_by(|left, right| left.id.cmp(&right.id)); + } + } + + let mut entries = Vec::with_capacity(providers.len() + 1); + entries.push(ProviderOptionEntry { + id: DEFAULT_PROVIDER_ID.to_string(), + label: DEFAULT_PROVIDER_LABEL.to_string(), + }); + entries.extend(providers); + entries +} + +pub(super) async fn build_provider_options( + current_provider: Option<&str>, +) -> Vec { + list_provider_entries(current_provider) + .await + .into_iter() + .map(|provider| SessionConfigSelectOption::new(provider.id, provider.label)) + .collect() +} + +pub(super) fn should_refresh_inventory_for_session_init(entry: &ProviderInventoryEntry) -> bool { + entry.configured + && entry.supports_refresh + && (entry.last_updated_at.is_none() || ProviderInventoryService::is_stale(entry)) +} + +pub(super) fn build_mode_state( + current_mode: GooseMode, +) -> Result { + let mut available = Vec::with_capacity(GooseMode::VARIANTS.len()); + for &name in GooseMode::VARIANTS { + let goose_mode: GooseMode = name.parse().map_err(|_| { + agent_client_protocol::Error::internal_error() // impossible but satisfy linters + .data(format!("Failed to parse GooseMode variant: {}", name)) + })?; + let mut mode = SessionMode::new(SessionModeId::new(name), name); + mode.description = goose_mode.get_message().map(Into::into); + available.push(mode); + } + Ok(SessionModeState::new( + SessionModeId::new(current_mode.to_string()), + available, + )) +} + +pub(super) async fn build_session_setup_config( + provider_inventory: &ProviderInventoryService, + session: &Session, +) -> Result< + ( + SessionModeState, + Option, + Option>, + ), + agent_client_protocol::Error, +> { + let mode_state = build_mode_state(session.goose_mode)?; + + let (Some(provider_name), Some(model_config)) = ( + session.provider_name.as_deref(), + session.model_config.as_ref(), + ) else { + return Ok((mode_state, None, None)); + }; + let Some(inventory) = provider_inventory + .find_entry_for_provider(provider_name) + .await + else { + return Ok((mode_state, None, None)); + }; + let model_state = build_model_state(model_config.model_name.as_str(), &inventory); + let provider_selection = session_provider_selection(session); + let provider_options = build_provider_options(Some(provider_name)).await; + let config_options = build_config_options( + &mode_state, + &model_state, + provider_selection, + provider_options, + ); + Ok((mode_state, Some(model_state), Some(config_options))) +} + +pub(super) fn build_config_options( + mode_state: &SessionModeState, + model_state: &SessionModelState, + provider_selection: &str, + provider_options: Vec, +) -> Vec { + let mode_options: Vec = mode_state + .available_modes + .iter() + .map(|m| { + SessionConfigSelectOption::new(m.id.0.clone(), m.name.clone()) + .description(m.description.clone()) + }) + .collect(); + let model_options: Vec = model_state + .available_models + .iter() + .map(|m| SessionConfigSelectOption::new(m.model_id.0.clone(), m.name.clone())) + .collect(); + vec![ + SessionConfigOption::select( + "provider", + "Provider", + provider_selection.to_string(), + provider_options, + ), + SessionConfigOption::select( + "mode", + "Mode", + mode_state.current_mode_id.0.clone(), + mode_options, + ) + .category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", + "Model", + model_state.current_model_id.0.clone(), + model_options, + ) + .category(SessionConfigOptionCategory::Model), + ] +} + +fn available_commands_update(working_dir: &std::path::Path) -> AvailableCommandsUpdate { + let commands = crate::slash_commands::slash_command::list_acp_commands(Some(working_dir)) + .into_iter() + .map(|entry| { + let mut command = AvailableCommand::new(entry.name, entry.description); + if let Some(input_hint) = entry.input_hint { + command = command.input(AvailableCommandInput::Unstructured( + UnstructuredCommandInput::new(input_hint), + )); + } + command + }) + .collect(); + + AvailableCommandsUpdate::new(commands) +} + +pub(super) fn send_session_setup_notifications( + cx: &ConnectionTo, + session: &Session, +) -> Result<(), agent_client_protocol::Error> { + let session_id = SessionId::new(session.id.clone()); + if let Some(updates) = build_usage_updates(session) { + cx.send_notification(updates.custom)?; + cx.send_notification(SessionNotification::new( + session_id.clone(), + SessionUpdate::UsageUpdate(updates.standard), + ))?; + } + cx.send_notification(SessionNotification::new( + session_id, + SessionUpdate::AvailableCommandsUpdate(available_commands_update(&session.working_dir)), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use test_case::test_case; + + #[test_case( + vec!["model-a".into(), "model-b".into()] + => SessionModelState::new( + ModelId::new("unused"), + vec![ModelInfo::new(ModelId::new("unused"), "unused"), + ModelInfo::new(ModelId::new("model-a"), "model-a"), + ModelInfo::new(ModelId::new("model-b"), "model-b")], + ) + ; "returns current and available models" + )] + #[test_case( + vec![] + => SessionModelState::new( + ModelId::new("unused"), + vec![ModelInfo::new(ModelId::new("unused"), "unused")], + ) + ; "empty model list" + )] + fn test_build_model_state(models: Vec) -> SessionModelState { + let inventory = ProviderInventoryEntry { + provider_id: "mock".to_string(), + provider_name: "Mock".to_string(), + description: "Mock".to_string(), + default_model: "unused".to_string(), + configured: true, + provider_type: crate::providers::base::ProviderType::Builtin, + category: crate::providers::catalog::ProviderSetupCategory::Model, + config_keys: vec![], + setup_steps: vec![], + supports_refresh: true, + refreshing: false, + models: models + .into_iter() + .map(|id| crate::providers::inventory::InventoryModel { + name: id.clone(), + id, + family: None, + context_limit: None, + reasoning: None, + recommended: false, + }) + .collect(), + last_updated_at: None, + last_refresh_attempt_at: None, + last_refresh_error: None, + model_selection_hint: None, + }; + build_model_state("unused", &inventory) + } + + #[test_case( + GooseMode::Auto + => Ok(SessionModeState::new( + SessionModeId::new("auto"), + vec![ + SessionMode::new(SessionModeId::new("auto"), "auto") + .description("Automatically approve tool calls"), + SessionMode::new(SessionModeId::new("approve"), "approve") + .description("Ask before every tool call"), + SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") + .description("Ask only for sensitive tool calls"), + SessionMode::new(SessionModeId::new("chat"), "chat") + .description("Chat only, no tool calls"), + ], + )) + ; "auto mode" + )] + #[test_case( + GooseMode::Approve + => Ok(SessionModeState::new( + SessionModeId::new("approve"), + vec![ + SessionMode::new(SessionModeId::new("auto"), "auto") + .description("Automatically approve tool calls"), + SessionMode::new(SessionModeId::new("approve"), "approve") + .description("Ask before every tool call"), + SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") + .description("Ask only for sensitive tool calls"), + SessionMode::new(SessionModeId::new("chat"), "chat") + .description("Chat only, no tool calls"), + ], + )) + ; "approve mode" + )] + fn test_build_mode_state( + current_mode: GooseMode, + ) -> Result { + build_mode_state(current_mode) + } + + #[test_case( + build_mode_state(GooseMode::Auto).unwrap(), + "openai", + vec![ + SessionConfigSelectOption::new("anthropic", "anthropic"), + SessionConfigSelectOption::new("openai", "openai"), + ], + SessionModelState::new( + ModelId::new("gpt-4"), + vec![ModelInfo::new(ModelId::new("gpt-4"), "gpt-4"), ModelInfo::new(ModelId::new("gpt-3.5"), "gpt-3.5")], + ) + => vec![ + SessionConfigOption::select( + "provider", "Provider", "openai", + vec![ + SessionConfigSelectOption::new("anthropic", "anthropic"), + SessionConfigSelectOption::new("openai", "openai"), + ], + ), + SessionConfigOption::select( + "mode", "Mode", "auto", + vec![ + SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), + SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), + SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), + SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), + ], + ).category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", "Model", "gpt-4", + vec![ + SessionConfigSelectOption::new("gpt-4", "gpt-4"), + SessionConfigSelectOption::new("gpt-3.5", "gpt-3.5"), + ], + ).category(SessionConfigOptionCategory::Model), + ] + ; "auto mode with multiple models" + )] + #[test_case( + build_mode_state(GooseMode::Approve).unwrap(), + "openai", + vec![SessionConfigSelectOption::new("openai", "openai")], + SessionModelState::new(ModelId::new("only-model"), vec![ModelInfo::new(ModelId::new("only-model"), "only-model")]) + => vec![ + SessionConfigOption::select( + "provider", "Provider", "openai", + vec![SessionConfigSelectOption::new("openai", "openai")], + ), + SessionConfigOption::select( + "mode", "Mode", "approve", + vec![ + SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), + SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), + SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), + SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), + ], + ).category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", "Model", "only-model", + vec![SessionConfigSelectOption::new("only-model", "only-model")], + ).category(SessionConfigOptionCategory::Model), + ] + ; "approve mode with single model" + )] + fn test_build_config_options( + mode_state: SessionModeState, + provider_name: &'static str, + provider_options: Vec, + model_state: SessionModelState, + ) -> Vec { + build_config_options(&mode_state, &model_state, provider_name, provider_options) + } +} diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index b95010df96df..d6bd5f7f1d97 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -1,49 +1,61 @@ +use crate::acp::custom_notifications::*; use crate::acp::custom_requests::*; use crate::acp::fs::AcpTools; +pub(super) use crate::acp::response_builder::{ + build_config_options, build_mode_state, build_model_state, build_provider_options, + build_session_setup_config, send_session_setup_notifications, session_provider_selection, + should_refresh_inventory_for_session_init, +}; use crate::acp::tools::AcpAwareToolMeta; use crate::acp::{PermissionDecision, ACP_CURRENT_MODEL}; +use crate::action_required_manager::ActionRequiredManager; use crate::agents::extension::{Envs, PLATFORM_EXTENSIONS}; use crate::agents::extension_manager::TRUSTED_TOOL_UPDATE_META_KEY; use crate::agents::mcp_client::{GooseMcpHostInfo, McpClientTrait}; use crate::agents::platform_extensions::developer::DeveloperClient; -use crate::agents::{Agent, AgentConfig, ExtensionConfig, GoosePlatform, SessionConfig}; +use crate::agents::{ + Agent, AgentConfig, ExtensionConfig, ExtensionLoadResult, GoosePlatform, SessionConfig, +}; use crate::config::base::CONFIG_YAML_NAME; use crate::config::extensions::get_enabled_extensions_with_config; use crate::config::paths::Paths; use crate::config::permission::PermissionManager; use crate::config::{Config, GooseMode}; -use crate::conversation::message::{ActionRequiredData, Message, MessageContent, ToolRequest}; +use crate::conversation::message::{ + ActionRequiredData, Message, MessageContent, SystemNotificationContent, SystemNotificationType, + ToolRequest, +}; +use crate::execution::manager::{AgentManager, AgentManagerGetResult, RuntimeContext}; use crate::mcp_utils::ToolResult; use crate::permission::permission_confirmation::PrincipalType; use crate::permission::{Permission, PermissionConfirmation}; use crate::providers::base::Provider; use crate::providers::inventory::{ - InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan, - RefreshPlan, RefreshSkipReason, + ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan, RefreshPlan, + RefreshSkipReason, }; use crate::session::session_manager::{SessionListCursor, SessionType}; -use crate::session::{EnabledExtensionsState, Session, SessionManager}; +use crate::session::{ + EnabledExtensionsState, ExtensionData, ExtensionState, Session, SessionManager, +}; use crate::source_roots::SourceRoot; use crate::utils::sanitize_unicode_tags; use agent_client_protocol::schema::{ AgentCapabilities, Annotations, AuthMethod, AuthMethodAgent, AuthenticateRequest, - AuthenticateResponse, AvailableCommand, AvailableCommandInput, AvailableCommandsUpdate, - BlobResourceContents, CancelNotification, CloseSessionRequest, CloseSessionResponse, - ConfigOptionUpdate, Content, ContentBlock, ContentChunk, CurrentModeUpdate, EmbeddedResource, - EmbeddedResourceResource, FileSystemCapabilities, ForkSessionRequest, ForkSessionResponse, - ImageContent, InitializeRequest, InitializeResponse, ListSessionsRequest, ListSessionsResponse, - LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, Meta, ModelId, ModelInfo, - NewSessionRequest, NewSessionResponse, PermissionOption, PermissionOptionKind, - PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome, - RequestPermissionRequest, ResourceLink, SessionCapabilities, SessionCloseCapabilities, - SessionConfigOption, SessionConfigOptionCategory, SessionConfigSelectOption, SessionId, - SessionInfo, SessionInfoUpdate, SessionListCapabilities, SessionMode, SessionModeId, - SessionModeState, SessionModelState, SessionNotification, SessionUpdate, - SetSessionConfigOptionRequest, SetSessionConfigOptionResponse, SetSessionModeRequest, - SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse, StopReason, - TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, - ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind, UnstructuredCommandInput, - Usage, UsageUpdate, + AuthenticateResponse, BlobResourceContents, CancelNotification, CloseSessionRequest, + CloseSessionResponse, ConfigOptionUpdate, Content, ContentBlock, ContentChunk, + CurrentModeUpdate, EmbeddedResource, EmbeddedResourceResource, FileSystemCapabilities, + ForkSessionRequest, ForkSessionResponse, ImageContent, InitializeRequest, InitializeResponse, + ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, + McpCapabilities, McpServer, Meta, NewSessionRequest, NewSessionResponse, PermissionOption, + PermissionOptionKind, PromptCapabilities, PromptRequest, PromptResponse, + RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, SessionCapabilities, + SessionCloseCapabilities, SessionConfigOption, SessionId, SessionInfo, SessionInfoUpdate, + SessionListCapabilities, SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, + SetSessionConfigOptionResponse, SetSessionModeRequest, SetSessionModeResponse, + SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent, TextResourceContents, + ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, + ToolCallUpdateFields, ToolKind, Usage, UsageUpdate, }; use agent_client_protocol::util::MatchDispatchFrom; use agent_client_protocol::{ @@ -53,7 +65,7 @@ use agent_client_protocol::{ use anyhow::Result; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use fs_err as fs; -use futures::future::{BoxFuture, Either}; +use futures::future::BoxFuture; use futures::stream::{self, StreamExt}; use futures::FutureExt; use rmcp::model::{ @@ -65,7 +77,6 @@ use std::collections::{HashMap, HashSet}; use std::panic::AssertUnwindSafe; use std::path::{Path, PathBuf}; use std::sync::Arc; -use strum::{EnumMessage, VariantNames}; use tokio::sync::{Mutex, OnceCell}; use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _}; use tokio_util::sync::CancellationToken; @@ -77,10 +88,13 @@ mod custom_dispatch; mod dictation; mod dispatch; mod extensions; +mod fork_session; +mod load_session; +mod manage_sessions; +mod new_session; mod onboarding; mod providers; mod resources; -mod sessions; mod sources; mod tools; @@ -131,24 +145,10 @@ impl ResultExt for Result { } } -const DEFAULT_PROVIDER_ID: &str = "goose"; -const DEFAULT_PROVIDER_LABEL: &str = "Goose (Default)"; +pub(super) const DEFAULT_PROVIDER_ID: &str = "goose"; +pub(super) const DEFAULT_PROVIDER_LABEL: &str = "Goose (Default)"; const PROVIDER_CONFIG_STATUS_CHECK_CONCURRENCY: usize = 16; -async fn ensure_refresh_identity_current( - provider_id: &str, - planned_identity: &InventoryIdentity, -) -> Result<()> { - let current_identity = crate::providers::inventory_identity(provider_id) - .await? - .into_identity()?; - if current_identity != *planned_identity { - anyhow::bail!("provider inventory identity changed before refresh completed"); - } - - Ok(()) -} - /// In-memory state for an active ACP session. /// /// ## Terminology (temporary, until all clients migrate to ACP) @@ -161,7 +161,7 @@ async fn ensure_refresh_identity_current( /// The ACP session ID maps directly to a `sessions` row. The `sessions` HashMap /// below is keyed by session ID. struct GooseAcpSession { - agent: AgentHandle, + agent: Arc, tool_requests: HashMap, /// For each tool_call_id that belongs to a multi-tool chain (run of /// consecutive ToolRequest blocks within one assistant message), the chain @@ -176,9 +176,6 @@ struct GooseAcpSession { /// Idempotence guard so we summarize each chain at most once. summarized_chains: HashSet, cancel_token: Option, - /// Working directory set while the agent was still loading. - /// Applied once the agent becomes ready. - pending_working_dir: Option, } /// A run of consecutive ToolRequest blocks within one assistant message, @@ -193,45 +190,11 @@ struct ToolChain { message_id: String, } -/// Progress stages signalled by the background agent setup task via the watch -/// channel. `ProviderReady` fires as soon as the provider (and goose-mode) -/// are initialized — before extensions finish loading. `FullyReady` fires -/// once every extension has been loaded (or failed). -#[derive(Clone)] -enum AgentSetupProgress { - /// Provider is initialized; extensions are still loading in the background. - ProviderReady(Arc), - /// Provider *and* all extensions are initialized. - FullyReady(Arc), -} - -type AgentSetupSignal = Option>; - -/// The agent may still be initializing in the background (extension loading, -/// provider setup). Callers that need the live agent (e.g. `on_prompt`) await -/// the handle; callers that only need the session metadata can proceed without it. -enum AgentHandle { - Ready(Arc), - Loading(tokio::sync::watch::Receiver), -} - -struct AgentSetupRequest { - session_id: SessionId, - goose_session: Session, - mcp_servers: Vec, - /// Pre-resolved provider name + model config (from config, no network). - /// When present the spawn skips re-deriving these from config. - resolved_provider: Option<(String, crate::model::ModelConfig)>, - /// Pre-instantiated provider reused from synchronous session initialization. - prebuilt_provider: Option>, -} - pub struct GooseAcpAgentOptions { pub provider_factory: AcpProviderFactory, pub builtins: Vec, pub data_dir: std::path::PathBuf, pub config_dir: std::path::PathBuf, - pub goose_mode: GooseMode, pub disable_session_naming: bool, pub goose_platform: GoosePlatform, pub additional_source_roots: Vec, @@ -239,26 +202,26 @@ pub struct GooseAcpAgentOptions { pub struct GooseAcpAgent { sessions: Arc>>, + agent_manager: Arc, provider_factory: AcpProviderFactory, builtins: Vec, client_fs_capabilities: OnceCell, client_terminal: OnceCell, client_mcp_host_info: OnceCell, use_login_shell_path: OnceCell, + client_cx: OnceCell>, config_dir: std::path::PathBuf, session_manager: Arc, permission_manager: Arc, - goose_mode: GooseMode, disable_session_naming: bool, provider_inventory: ProviderInventoryService, - goose_platform: GoosePlatform, additional_source_roots: Vec, } /// Shorten a session/thread id for perf log correlation. /// All `perf:` logs use `sid=<8-char-prefix>` so a single session's activity /// can be extracted with `grep 'perf:' | grep 'sid=abc12345'`. -fn sid_short(id: &str) -> String { +pub(super) fn sid_short(id: &str) -> String { id.chars().take(8).collect() } @@ -350,7 +313,20 @@ fn encode_session_list_cursor( Ok(URL_SAFE_NO_PAD.encode(bytes)) } -fn session_meta(session: &Session) -> serde_json::Map { +fn display_title(s: &Session) -> Option { + if !s.user_set_name { + if let Some(recipe) = &s.recipe { + return Some(recipe.title.clone()); + } + } + if s.name.is_empty() { + None + } else { + Some(s.name.clone()) + } +} + +pub(super) fn session_meta(session: &Session) -> serde_json::Map { let mut meta = serde_json::Map::new(); meta.insert( "messageCount".to_string(), @@ -370,6 +346,10 @@ fn session_meta(session: &Session) -> serde_json::Map "userSetName".to_string(), serde_json::Value::Bool(session.user_set_name), ); + meta.insert( + "hasRecipe".to_string(), + serde_json::Value::Bool(session.recipe.is_some()), + ); if let Some(ref pid) = session.project_id { meta.insert( @@ -392,6 +372,12 @@ fn session_meta(session: &Session) -> serde_json::Map meta } +fn meta_string(meta: Option<&Meta>, key: &str) -> Option { + meta.and_then(|m| m.get(key)) + .and_then(|v| v.as_str()) + .map(ToString::to_string) +} + fn spawn_session_name_update_notifier( cx: ConnectionTo, ) -> tokio::sync::mpsc::UnboundedSender { @@ -528,6 +514,54 @@ fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result, extension: ExtensionConfig) { + let name = extension.name().to_string(); + if let Some(index) = extensions + .iter() + .position(|existing| existing.name() == name) + { + extensions.remove(index); + } + extensions.push(extension); +} + +fn resolve_default_provider_model_config( + config: &Config, +) -> Result<(String, crate::model::ModelConfig), agent_client_protocol::Error> { + let resolved_provider = config.get_goose_provider().map_err(|error| { + agent_client_protocol::Error::internal_error() + .data(format!("Failed to resolve provider: {}", error)) + })?; + let resolved_model = config.get_goose_model().map_err(|error| { + agent_client_protocol::Error::internal_error() + .data(format!("Failed to resolve model: {}", error)) + })?; + let resolved_model_config = crate::model::ModelConfig::new(&resolved_model) + .map(|model_config| model_config.with_canonical_limits(&resolved_provider)) + .map_err(|error| { + agent_client_protocol::Error::internal_error() + .data(format!("Failed to resolve model: {}", error)) + })?; + Ok((resolved_provider, resolved_model_config)) +} + +async fn resolve_provider_default_model_config( + provider_name: &str, +) -> Result { + let entry = crate::providers::get_from_registry(provider_name) + .await + .map_err(|error| { + agent_client_protocol::Error::invalid_params() + .data(format!("Unknown provider '{}': {}", provider_name, error)) + })?; + crate::model::ModelConfig::new(&entry.metadata().default_model) + .map(|model_config| model_config.with_canonical_limits(provider_name)) + .map_err(|error| { + agent_client_protocol::Error::internal_error() + .data(format!("Failed to resolve model: {}", error)) + }) +} + fn get_requested_line(arguments: Option<&rmcp::model::JsonObject>) -> Option { arguments .and_then(|args| args.get("line")) @@ -795,50 +829,6 @@ struct PendingToolCall { fallback_title: String, } -/// Extract chains (runs of consecutive `MessageContent::ToolRequest` blocks) -/// from a single message's content. Mirrors the frontend's chain detection in -/// `MessageBubble.groupContentSections`: any non-tool block (text, thinking, -/// image, etc.) breaks the run. -/// -/// Returns one inner Vec per detected chain, holding the tool_call_ids in -/// document order. Single-tool runs are included; callers (chain -/// summarization) gate on `chain.len() >= 2`. -/// -/// Note: this is the per-message view, kept around for tests and potential -/// replay use. The live runtime path uses a streaming buffer fed by -/// [`register_chain_buffer`] so chains that span multiple `AgentEvent::Message` -/// events (e.g. Bedrock-style streaming, where one LLM message is split across -/// rows — see `f087fa63c`) are still detected. -#[allow(dead_code)] -fn extract_tool_chains( - content: &[crate::conversation::message::MessageContent], -) -> Vec> { - use crate::conversation::message::MessageContent; - let mut chains: Vec> = Vec::new(); - let mut current: Vec = Vec::new(); - - for block in content { - match block { - MessageContent::ToolRequest(tr) => current.push(tr.id.clone()), - MessageContent::ToolResponse(_) => { - // Server-side, assistant messages don't carry responses; - // responses arrive in subsequent messages. Treat as - // chain-neutral so a stray response doesn't split a chain - // if the data shape ever changes. - } - _ => { - if !current.is_empty() { - chains.push(std::mem::take(&mut current)); - } - } - } - } - if !current.is_empty() { - chains.push(current); - } - chains -} - /// If `buffer` holds a multi-tool run (≥ 2 tool requests), (re)register a /// [`ToolChain`] in `chain_membership` anchored on the **first** tool's /// message_id (the row [`SessionManager::update_tool_request_meta`] will patch @@ -849,11 +839,10 @@ fn extract_tool_chains( /// The buffer contains `(tool_call_id, message_id)` pairs in arrival order, /// fed by the prompt stream loop. Sequential tool use (Bedrock/Anthropic) /// interleaves request → response → request → response across separate -/// `AgentEvent::Message` events, so per-event `extract_tool_chains` only -/// sees length-1 chains and would miss the run. Tool responses are -/// chain-neutral (they don't split the run); only non-tool content (text, -/// thinking, image, etc.) does, matching the frontend's -/// `groupContentSections` behavior. +/// `AgentEvent::Message` events, so a per-event view would only see length-1 +/// chains and miss the run. Tool responses are chain-neutral (they don't +/// split the run); only non-tool content (text, thinking, image, etc.) does, +/// matching the frontend's `groupContentSections` behavior. fn extend_chain_membership( buffer: &[(String, String)], chain_membership: &mut HashMap>, @@ -928,117 +917,6 @@ fn builtin_to_extension_config(name: &str) -> ExtensionConfig { } } -fn build_model_state(current_model: &str, inventory: &ProviderInventoryEntry) -> SessionModelState { - let mut available_models = inventory - .models - .iter() - .map(|model| ModelInfo::new(ModelId::new(model.id.as_str()), model.name.as_str())) - .collect::>(); - if !available_models - .iter() - .any(|model| model.model_id.0.as_ref() == current_model) - { - available_models.insert( - 0, - ModelInfo::new(ModelId::new(current_model), current_model), - ); - } - SessionModelState::new(ModelId::new(current_model), available_models) -} - -struct ProviderOptionEntry { - id: String, - label: String, -} - -async fn list_provider_entries(current_provider: Option<&str>) -> Vec { - let mut providers = crate::providers::providers() - .await - .into_iter() - .map(|(metadata, _)| ProviderOptionEntry { - id: metadata.name, - label: metadata.display_name, - }) - .collect::>(); - providers.sort_by(|left, right| left.id.cmp(&right.id)); - providers.dedup_by(|left, right| left.id == right.id); - - if let Some(current_provider) = current_provider { - if current_provider != DEFAULT_PROVIDER_ID - && !providers - .iter() - .any(|provider| provider.id == current_provider) - { - providers.push(ProviderOptionEntry { - id: current_provider.to_string(), - label: current_provider.to_string(), - }); - providers.sort_by(|left, right| left.id.cmp(&right.id)); - } - } - - let mut entries = Vec::with_capacity(providers.len() + 1); - entries.push(ProviderOptionEntry { - id: DEFAULT_PROVIDER_ID.to_string(), - label: DEFAULT_PROVIDER_LABEL.to_string(), - }); - entries.extend(providers); - entries -} - -async fn build_provider_options(current_provider: Option<&str>) -> Vec { - list_provider_entries(current_provider) - .await - .into_iter() - .map(|provider| SessionConfigSelectOption::new(provider.id, provider.label)) - .collect() -} - -fn session_provider_selection(session: &Session) -> &str { - session - .provider_name - .as_deref() - .unwrap_or(DEFAULT_PROVIDER_ID) -} - -/// Resolve the provider name and model config for a session from an -/// already-loaded `Config`. -async fn resolve_provider_and_model_from_config( - config: &Config, - goose_session: &Session, -) -> Result<(String, crate::model::ModelConfig), String> { - let global_provider = config.get_goose_provider().ok(); - let provider_override = goose_session - .provider_name - .as_deref() - .filter(|p| *p != DEFAULT_PROVIDER_ID); - let provider_name = provider_override - .map(ToOwned::to_owned) - .or_else(|| global_provider.clone()) - .ok_or_else(|| "Missing provider".to_string())?; - let explicitly_switched = - provider_override.is_some() && provider_override != global_provider.as_deref(); - let model_config = match &goose_session.model_config { - Some(mc) => mc.clone(), - None if explicitly_switched => { - let entry = crate::providers::get_from_registry(&provider_name) - .await - .map_err(|e| e.to_string())?; - let default_model = &entry.metadata().default_model; - crate::model::ModelConfig::new(default_model) - .map_err(|e| e.to_string())? - .with_canonical_limits(&provider_name) - } - None => { - let model_id = config.get_goose_model().map_err(|e| e.to_string())?; - crate::model::ModelConfig::new(&model_id) - .map_err(|e| e.to_string())? - .with_canonical_limits(&provider_name) - } - }; - Ok((provider_name, model_config)) -} - fn with_preserved_session_request_params( mut model_config: crate::model::ModelConfig, current_model_config: Option<&crate::model::ModelConfig>, @@ -1067,100 +945,6 @@ fn with_preserved_session_request_params( model_config } -/// Convenience wrapper: reads config from disk, then resolves provider + model. -/// Cheap enough to call from `on_new_session` (file + registry reads, no network). -async fn resolve_provider_and_model( - config_dir: &std::path::Path, - goose_session: &Session, -) -> Result<(String, crate::model::ModelConfig), String> { - let config = - Config::new(config_dir.join(CONFIG_YAML_NAME), "goose").map_err(|e| e.to_string())?; - resolve_provider_and_model_from_config(&config, goose_session).await -} - -fn build_mode_state( - current_mode: GooseMode, -) -> Result { - let mut available = Vec::with_capacity(GooseMode::VARIANTS.len()); - for &name in GooseMode::VARIANTS { - let goose_mode: GooseMode = name.parse().map_err(|_| { - agent_client_protocol::Error::internal_error() // impossible but satisfy linters - .data(format!("Failed to parse GooseMode variant: {}", name)) - })?; - let mut mode = SessionMode::new(SessionModeId::new(name), name); - mode.description = goose_mode.get_message().map(Into::into); - available.push(mode); - } - Ok(SessionModeState::new( - SessionModeId::new(current_mode.to_string()), - available, - )) -} - -fn should_refresh_inventory_for_session_init(entry: &ProviderInventoryEntry) -> bool { - entry.configured - && entry.supports_refresh - && (entry.last_updated_at.is_none() || ProviderInventoryService::is_stale(entry)) -} - -async fn build_eager_config_from_inventory( - provider_name: &str, - current_model: &str, - inventory: &ProviderInventoryEntry, - mode_state: &SessionModeState, - goose_session: &Session, -) -> (SessionModelState, Vec) { - let ms = build_model_state(current_model, inventory); - let provider_selection = session_provider_selection(goose_session); - let provider_options = build_provider_options(Some(provider_name)).await; - let config_options = - build_config_options(mode_state, &ms, provider_selection, provider_options); - (ms, config_options) -} - -fn build_config_options( - mode_state: &SessionModeState, - model_state: &SessionModelState, - provider_selection: &str, - provider_options: Vec, -) -> Vec { - let mode_options: Vec = mode_state - .available_modes - .iter() - .map(|m| { - SessionConfigSelectOption::new(m.id.0.clone(), m.name.clone()) - .description(m.description.clone()) - }) - .collect(); - let model_options: Vec = model_state - .available_models - .iter() - .map(|m| SessionConfigSelectOption::new(m.model_id.0.clone(), m.name.clone())) - .collect(); - vec![ - SessionConfigOption::select( - "provider", - "Provider", - provider_selection.to_string(), - provider_options, - ), - SessionConfigOption::select( - "mode", - "Mode", - mode_state.current_mode_id.0.clone(), - mode_options, - ) - .category(SessionConfigOptionCategory::Mode), - SessionConfigOption::select( - "model", - "Model", - model_state.current_model_id.0.clone(), - model_options, - ) - .category(SessionConfigOptionCategory::Model), - ] -} - fn to_nonnegative_u64(value: Option) -> Option { value.and_then(|v| u64::try_from(v).ok()) } @@ -1172,12 +956,34 @@ fn build_prompt_usage(session: &Session) -> Option { Some(Usage::new(total, input, output)) } -fn build_usage_update(session: &Session, context_limit: usize) -> UsageUpdate { +pub(super) struct UsageUpdates { + pub(super) custom: GooseSessionNotification, + pub(super) standard: UsageUpdate, +} + +pub(super) fn build_usage_updates(session: &Session) -> Option { let used = session.total_tokens.unwrap_or(0).max(0) as u64; - UsageUpdate::new(used, context_limit as u64) + let ctx_limit = session.model_config.as_ref()?.context_limit() as u64; + let accumulated_input_tokens = + to_nonnegative_u64(session.accumulated_input_tokens).unwrap_or(0); + let accumulated_output_tokens = + to_nonnegative_u64(session.accumulated_output_tokens).unwrap_or(0); + Some(UsageUpdates { + custom: GooseSessionNotification { + session_id: session.id.clone(), + update: GooseSessionUpdate::UsageUpdate(SessionUsageUpdate { + used, + context_limit: ctx_limit, + accumulated_input_tokens, + accumulated_output_tokens, + accumulated_cost: session.accumulated_cost, + }), + }, + standard: UsageUpdate::new(used, ctx_limit), + }) } -fn validate_absolute_cwd(cwd: &Path) -> Result<(), agent_client_protocol::Error> { +pub(super) fn validate_absolute_cwd(cwd: &Path) -> Result<(), agent_client_protocol::Error> { if !cwd.is_absolute() { return Err( agent_client_protocol::Error::invalid_params().data("cwd must be an absolute path") @@ -1192,34 +998,6 @@ fn validate_absolute_cwd(cwd: &Path) -> Result<(), agent_client_protocol::Error> } impl GooseAcpAgent { - fn available_commands_update(working_dir: &std::path::Path) -> AvailableCommandsUpdate { - let commands = crate::slash_commands::slash_command::list_acp_commands(Some(working_dir)) - .into_iter() - .map(|entry| { - let mut command = AvailableCommand::new(entry.name, entry.description); - if let Some(input_hint) = entry.input_hint { - command = command.input(AvailableCommandInput::Unstructured( - UnstructuredCommandInput::new(input_hint), - )); - } - command - }) - .collect(); - - AvailableCommandsUpdate::new(commands) - } - - fn send_available_commands_update( - cx: &ConnectionTo, - session_id: &SessionId, - working_dir: &std::path::Path, - ) -> Result<(), agent_client_protocol::Error> { - cx.send_notification(SessionNotification::new( - session_id.clone(), - SessionUpdate::AvailableCommandsUpdate(Self::available_commands_update(working_dir)), - )) - } - pub fn permission_manager(&self) -> Arc { Arc::clone(&self.permission_manager) } @@ -1236,32 +1014,37 @@ impl GooseAcpAgent { let permission_manager = Arc::new(PermissionManager::new(options.config_dir.clone())); let provider_inventory = ProviderInventoryService::new(session_manager.storage().clone()); + let agent_config = AgentConfig::new( + Arc::clone(&session_manager), + Arc::clone(&permission_manager), + None, + Config::global().get_goose_mode().unwrap_or_default(), + options.disable_session_naming, + options.goose_platform.clone(), + ); + let agent_manager = Arc::new(AgentManager::new(agent_config, None).await?); Ok(Self { sessions: Arc::new(Mutex::new(HashMap::new())), + agent_manager, provider_factory: options.provider_factory, builtins: options.builtins, client_fs_capabilities: OnceCell::new(), client_terminal: OnceCell::new(), client_mcp_host_info: OnceCell::new(), use_login_shell_path: OnceCell::new(), + client_cx: OnceCell::new(), config_dir: options.config_dir, session_manager, permission_manager, - goose_mode: options.goose_mode, disable_session_naming: options.disable_session_naming, provider_inventory, - goose_platform: options.goose_platform, additional_source_roots: options.additional_source_roots, }) } - fn load_config(&self) -> Result { - Config::new(self.config_dir.join(CONFIG_YAML_NAME), "goose").map_err(Into::into) - } - - fn config(&self) -> Result { - self.load_config().internal_err_ctx("Failed to read config") + fn config(&self) -> Result<&'static Config, agent_client_protocol::Error> { + Ok(Config::global()) } async fn create_provider( @@ -1280,427 +1063,258 @@ impl GooseAcpAgent { .await } - async fn prepare_session_init_config( + async fn maybe_refresh_provider_inventory_with_agent( &self, - resolved: &Result<(String, crate::model::ModelConfig), String>, - mode_state: &SessionModeState, goose_session: &Session, - ) -> ( - Option, - Option>, - Option>, + agent: &Arc, ) { - let Ok((provider_name, model_config)) = resolved else { - return (None, None, None); + let Some(provider_name) = goose_session.provider_name.as_deref() else { + return; }; - let Some(mut inventory) = self .provider_inventory - .entry_for_provider(provider_name) + .find_entry_for_provider(provider_name) .await - .ok() - .flatten() else { - return (None, None, None); + return; }; - - let mut prebuilt_provider = None; - if should_refresh_inventory_for_session_init(&inventory) { - match self.load_config() { - Ok(config) => { - let ext_state = EnabledExtensionsState::extensions_or_default( - Some(&goose_session.extension_data), - &config, - ); - Config::global().invalidate_secrets_cache(); - match self - .create_provider( - provider_name, - model_config.clone(), - ext_state, - Some(goose_session.working_dir.clone()), - ) - .await - { - Ok(provider) => { - let provider_id = provider_name.clone(); - prebuilt_provider = Some(provider.clone()); - match self - .provider_inventory - .plan_refresh_jobs(std::slice::from_ref(&provider_id)) - .await - { - Ok(plan) - if plan - .started - .iter() - .any(|job| job.provider_id == provider_id) => - { - let refresh_job = plan - .started - .into_iter() - .find(|job| job.provider_id == provider_id); - if let Some(refresh_job) = refresh_job { - let mut refresh_guard = self - .provider_inventory - .refresh_guard(&refresh_job.identity); - let fetch_result: Result> = - match ensure_refresh_identity_current( - &provider_id, - &refresh_job.identity, - ) - .await - { - Ok(()) => match AssertUnwindSafe( - provider.fetch_recommended_models(), - ) - .catch_unwind() - .await - { - Ok(Ok(models)) => Ok(models), - Ok(Err(error)) => { - Err(anyhow::anyhow!(error.to_string())) - } - Err(_) => Err(anyhow::anyhow!( - "provider inventory refresh task panicked" - )), - }, - Err(error) => Err(error), - }; - match fetch_result { - Ok(models) => { - if let Err(error) = self - .provider_inventory - .store_refreshed_models_for_identity( - &refresh_job.identity, - &models, - ) - .await - { - warn!( - provider = %provider_id, - error = %error, - "failed to store refreshed provider inventory during session init" - ); - } else { - refresh_guard.complete(); - } - } - Err(error) => { - let error_message = error.to_string(); - if let Err(store_error) = self - .provider_inventory - .store_refresh_error_for_identity( - &refresh_job.identity, - error_message.clone(), - ) - .await - { - warn!( - provider = %provider_id, - error = %store_error, - "failed to store provider inventory refresh error during session init" - ); - } else { - refresh_guard.complete(); - } - warn!( - provider = %provider_id, - error = %error_message, - "provider inventory refresh failed during session init" - ); - } - } - } - } - Ok(_) => {} - Err(error) => warn!( - provider = %provider_id, - error = %error, - "failed to plan provider inventory refresh during session init" - ), - } - - if let Ok(Some(refreshed_inventory)) = self - .provider_inventory - .entry_for_provider(provider_name) - .await - { - inventory = refreshed_inventory; - } - } - Err(error) => warn!( - provider = %provider_name, - error = %error, - "failed to initialize provider during synchronous inventory refresh" - ), - } - } - Err(error) => warn!( + if !should_refresh_inventory_for_session_init(&inventory) { + return; + } + let provider = match agent.provider().await { + Ok(provider) => provider, + Err(error) => { + warn!( provider = %provider_name, + session = %goose_session.id, error = %error, - "failed to load config during synchronous inventory refresh" - ), + "agent has no provider available for inventory refresh" + ); + return; + } + }; + self.provider_inventory + .refresh_with_provider(provider_name, &provider, &mut inventory, "session init") + .await; + } + + async fn get_or_create_session_agent_with_results( + &self, + cx: &ConnectionTo, + session_id: String, + ) -> Result { + self.agent_manager + .get_or_create_agent_with_runtime_context( + session_id, + RuntimeContext { + mcp_host_info: self.client_mcp_host_info.get().cloned(), + use_login_shell_path: self.use_login_shell_path.get().copied(), + session_name_update_tx: (!self.disable_session_naming) + .then(|| spawn_session_name_update_notifier(cx.clone())), + }, + ) + .await + .internal_err_ctx("Failed to create agent") + } + + fn initial_session_extensions( + &self, + config: &Config, + mcp_servers: Vec, + ) -> Result, agent_client_protocol::Error> { + let mut extensions = Vec::new(); + for builtin in &self.builtins { + push_or_replace_extension(&mut extensions, builtin_to_extension_config(builtin)); + } + + if mcp_servers.is_empty() { + for extension in get_enabled_extensions_with_config(config) { + push_or_replace_extension(&mut extensions, extension); + } + } else { + for mcp_server in mcp_servers { + let extension = mcp_server_to_extension_config(mcp_server).map_err(|message| { + agent_client_protocol::Error::invalid_params().data(message) + })?; + push_or_replace_extension(&mut extensions, extension); } } - let (model_state, config_options) = build_eager_config_from_inventory( - provider_name, - model_config.model_name.as_str(), - &inventory, - mode_state, - goose_session, - ) - .await; - (Some(model_state), Some(config_options), prebuilt_provider) + Ok(extensions) } - fn spawn_agent_setup( + async fn apply_acp_extension_overrides( &self, cx: &ConnectionTo, - agent_tx: tokio::sync::watch::Sender, - req: AgentSetupRequest, + agent: &Arc, + session: &Session, ) { - let AgentSetupRequest { - session_id, - goose_session, - mcp_servers, - resolved_provider, - prebuilt_provider, - } = req; - - let goose_mode = goose_session.goose_mode; - let setup_session_id = goose_session.id.clone(); - let agent_session_id = SessionId::new(setup_session_id.clone()); - let sid = sid_short(session_id.0.as_ref()); - - let cx = cx.clone(); - let sessions = Arc::clone(&self.sessions); - let session_manager = Arc::clone(&self.session_manager); - let permission_manager = Arc::clone(&self.permission_manager); - let config_dir = self.config_dir.clone(); - let builtins = self.builtins.clone(); let client_fs_capabilities = self .client_fs_capabilities .get() .cloned() .unwrap_or_default(); let client_terminal = self.client_terminal.get().copied().unwrap_or(false); - let client_mcp_host_info = self.client_mcp_host_info.get().cloned(); - let use_login_shell_path = self.use_login_shell_path.get().copied().unwrap_or(false); - let provider_factory = Arc::clone(&self.provider_factory); - let disable_session_naming = self.disable_session_naming; - let goose_platform = self.goose_platform.clone(); + if !client_fs_capabilities.read_text_file + && !client_fs_capabilities.write_text_file + && !client_terminal + { + return; + } - tokio::spawn(async move { - let t_setup = std::time::Instant::now(); - debug!(target: "perf", sid = %sid, "perf: agent_setup start (background)"); - // Shared config — read once, used by both phases. - let config = match Config::new(config_dir.join(CONFIG_YAML_NAME), "goose") { - Ok(c) => c, - Err(e) => { - let msg = e.to_string(); - error!(error = %msg, "Background agent setup failed (config)"); - let _ = agent_tx.send(Some(Err(msg))); - return; - } - }; + if !agent + .extension_manager + .is_extension_enabled("developer") + .await + { + return; + } - let session_name_update_tx = - (!disable_session_naming).then(|| spawn_session_name_update_notifier(cx.clone())); - - // ── Phase 1: create agent + init provider (fast, ~55ms) ────── - let phase1: Result, String> = async { - let agent = Arc::new(Agent::with_config( - AgentConfig::new( - session_manager, - permission_manager, - None, - goose_mode, - disable_session_naming, - goose_platform, - ) - .with_mcp_host_info(client_mcp_host_info) - .with_session_name_update_tx(session_name_update_tx) - .with_use_login_shell_path(use_login_shell_path), - )); + let context = agent.extension_manager.get_context().clone(); + let dev_client = match DeveloperClient::new(context) { + Ok(dev_client) => dev_client, + Err(error) => { + warn!(error = %error, "Failed to create ACP developer client"); + return; + } + }; - // Init provider — reuse the pre-resolved name + model when - // available (already computed in on_new_session), otherwise - // fall back to reading config (e.g. load_session path). - let (provider_name, model_config) = match resolved_provider { - Some(resolved) => resolved, - None => resolve_provider_and_model_from_config(&config, &goose_session).await?, - }; - let ext_state = EnabledExtensionsState::extensions_or_default( - Some(&goose_session.extension_data), - &config, - ); - let provider = match prebuilt_provider { - Some(provider) => provider, - None => provider_factory( - provider_name.to_string(), - model_config, - ext_state, - Some(goose_session.working_dir.clone()), - ) - .await - .map_err(|e| e.to_string())?, - }; - agent - .update_provider(provider.clone(), &goose_session.id) - .await - .map_err(|e| e.to_string())?; + let client: Arc = Arc::new(AcpTools { + inner: Arc::new(dev_client), + cx: cx.clone(), + session_id: SessionId::new(session.id.clone()), + fs_read: client_fs_capabilities.read_text_file, + fs_write: client_fs_capabilities.write_text_file, + terminal: client_terminal, + }); + let info = client.get_info().cloned(); - agent - .update_goose_mode(goose_mode, &setup_session_id) - .await - .map_err(|e| e.to_string())?; + let developer_config = agent + .extension_manager + .get_extension_configs() + .await + .into_iter() + .find(|extension| extension.name() == "developer") + .unwrap_or_else(|| builtin_to_extension_config("developer")); - Ok(agent) - } + agent + .extension_manager + .add_client("developer".into(), developer_config, client, info, None) .await; + } - let agent = match phase1 { - Ok(agent) => { - // Signal ProviderReady — unblocks setProvider / update_provider - // while extensions continue loading below. - let _ = - agent_tx.send(Some(Ok(AgentSetupProgress::ProviderReady(agent.clone())))); - debug!(target: "perf", sid = %sid, ms = t_setup.elapsed().as_millis() as u64, "perf: agent_setup provider_ready (signalled)"); - agent - } - Err(e) => { - error!(error = %e, "Background agent setup failed (provider init)"); - debug!(target: "perf", sid = %sid, ms = t_setup.elapsed().as_millis() as u64, "perf: agent_setup failed (provider)"); - let _ = agent_tx.send(Some(Err(e))); - return; - } - }; + async fn prepare_acp_session_agent( + &self, + cx: &ConnectionTo, + session: &Session, + ) -> Result<(Arc, Vec), agent_client_protocol::Error> { + let agent_result = self + .get_or_create_session_agent_with_results(cx, session.id.clone()) + .await?; + let agent = agent_result.agent.clone(); + self.apply_acp_extension_overrides(cx, &agent, session) + .await; + self.maybe_refresh_provider_inventory_with_agent(session, &agent) + .await; - // ── Phase 2: load extensions (slow, may take seconds) ──────── - let phase2: Result<(), String> = async { - let mut extensions = get_enabled_extensions_with_config(&config); - extensions.extend(builtins.iter().map(|b| builtin_to_extension_config(b))); + Ok((agent, agent_result.extension_results)) + } - let acp_developer = if (client_fs_capabilities.read_text_file - || client_fs_capabilities.write_text_file - || client_terminal) - && extensions.iter().any(|e| e.name() == "developer") - { - let context = agent.extension_manager.get_context().clone(); - match DeveloperClient::new(context) { - Ok(dev_client) => { - let client: Arc = Arc::new(AcpTools { - inner: Arc::new(dev_client), - cx: cx.clone(), - session_id: session_id.clone(), - fs_read: client_fs_capabilities.read_text_file, - fs_write: client_fs_capabilities.write_text_file, - terminal: client_terminal, - }); - let dev_ext = extensions.iter().find(|e| e.name() == "developer"); - let available_tools = dev_ext - .and_then(|e| match e { - ExtensionConfig::Platform { - available_tools, .. - } => Some(available_tools.clone()), - _ => None, - }) - .unwrap_or_default(); - let def = &PLATFORM_EXTENSIONS["developer"]; - let config = ExtensionConfig::Platform { - name: def.name.into(), - description: def.description.into(), - display_name: Some(def.display_name.into()), - bundled: Some(true), - available_tools, - }; - Some((client, config)) - } - Err(e) => { - warn!(error = %e, "Failed to create developer client"); - None - } - } - } else { - None - }; + async fn prepare_session_for_activation( + &self, + mut session: Session, + cwd: std::path::PathBuf, + mcp_servers: Vec, + include_messages_on_reload: bool, + ) -> Result { + let config = Config::global(); + let mut builder = self.session_manager.update(&session.id); + let mut session_needs_update = false; + + if cwd != session.working_dir { + builder = builder.working_dir(cwd); + session_needs_update = true; + } - let skip_developer = acp_developer.is_some(); - let sid_str = Some(agent_session_id.0.to_string()); + if session.provider_name.is_none() || session.model_config.is_none() { + let (resolved_provider, resolved_model_config) = + resolve_default_provider_model_config(config)?; + builder = builder + .provider_name(resolved_provider) + .model_config(resolved_model_config); + session_needs_update = true; + } - if skip_developer { - extensions.retain(|ext| ext.name() != "developer"); - } + if !mcp_servers.is_empty() + || EnabledExtensionsState::from_extension_data(&session.extension_data).is_none() + { + let extension_data = + self.build_enabled_extensions_data(config, &session, mcp_servers)?; + builder = builder.extension_data(extension_data); + session_needs_update = true; + } - let ext_manager = &agent.extension_manager; - let working_dir = goose_session.working_dir.clone(); - let extension_futures = extensions - .into_iter() - .map(|ext| { - let ext_manager = Arc::clone(ext_manager); - let sid_inner = sid_str.clone(); - let working_dir = working_dir.clone(); - async move { - let name = ext.name().to_string(); - if let Err(e) = ext_manager - .add_extension(ext, Some(working_dir), None, sid_inner.as_deref()) - .await - { - warn!(extension = %name, error = %e, "extension load failed"); - } - } - }) - .collect::>(); - futures::future::join_all(extension_futures).await; - - if let Some((client, config)) = acp_developer { - let info = client.get_info().cloned(); - agent - .extension_manager - .add_client("developer".into(), config, client, info, None) - .await; - } + if session_needs_update { + let session_id = session.id.clone(); + builder + .apply() + .await + .internal_err_ctx("Failed to update session")?; - GooseAcpAgent::add_mcp_extensions(&agent, mcp_servers, &setup_session_id) - .await - .map_err(|e| e.to_string())?; + let _ = self.agent_manager.remove_session(&session_id).await; - Ok(()) - } - .await; + session = self + .session_manager + .get_session(&session_id, include_messages_on_reload) + .await + .internal_err_ctx("Failed to reload session")?; + } - if let Err(e) = &phase2 { - // Extension failures are non-fatal — individual failures are - // already logged as warnings. Log the top-level error but - // don't block the session: the provider is ready and the agent - // is usable. - error!(error = %e, "Background agent setup: extension phase had errors"); - } + Ok(session) + } - // Promote the handle to Ready and apply any working directory that - // was set while we were loading — regardless of phase-2 outcome, - // since the agent (with its provider) is fully usable. - { - let mut locked = sessions.lock().await; - if let Some(session) = locked.get_mut(session_id.0.as_ref()) { - if let Some(dir) = session.pending_working_dir.take() { - agent.extension_manager.update_working_dir(&dir).await; - } - session.agent = AgentHandle::Ready(agent.clone()); - } - } + fn build_enabled_extensions_data( + &self, + config: &Config, + session: &Session, + mcp_servers: Vec, + ) -> Result { + let extensions = self.initial_session_extensions(config, mcp_servers)?; + let mut extension_data = session.extension_data.clone(); + EnabledExtensionsState::new(extensions) + .to_extension_data(&mut extension_data) + .internal_err_ctx("Failed to initialize session extensions")?; + Ok(extension_data) + } - let _ = agent_tx.send(Some(Ok(AgentSetupProgress::FullyReady(agent)))); - debug!( - target: "perf", - sid = %sid, - ms = t_setup.elapsed().as_millis() as u64, - "perf: agent_setup done{}", - if phase2.is_err() { " (with extension errors)" } else { "" } - ); - }); + async fn register_acp_session( + &self, + session_id: String, + agent: Arc, + tool_requests: HashMap, + ) { + let acp_session = GooseAcpSession { + agent, + tool_requests, + chain_membership: HashMap::new(), + responded_tool_ids: HashSet::new(), + summarized_chains: HashSet::new(), + cancel_token: None, + }; + self.sessions.lock().await.insert(session_id, acp_session); + } + + async fn activate_acp_session( + &self, + cx: &ConnectionTo, + session: &Session, + tool_requests: HashMap, + ) -> Result<(Arc, Vec), agent_client_protocol::Error> { + let (agent, extension_results) = self.prepare_acp_session_agent(cx, session).await?; + self.register_acp_session(session.id.clone(), agent.clone(), tool_requests) + .await; + + Ok((agent, extension_results)) } pub async fn has_session(&self, session_id: &str) -> bool { @@ -1782,6 +1396,7 @@ impl GooseAcpAgent { session_id: &SessionId, session_id_str: &str, message_id: Option<&str>, + message_created: i64, agent: &Arc, session: &mut GooseAcpSession, cx: &ConnectionTo, @@ -1790,9 +1405,10 @@ impl GooseAcpAgent { MessageContent::Text(text) => { cx.send_notification(SessionNotification::new( session_id.clone(), - SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(text.text.clone()), - ))), + SessionUpdate::AgentMessageChunk( + ContentChunk::new(ContentBlock::Text(TextContent::new(text.text.clone()))) + .meta(message_update_meta(message_id, message_created)), + ), ))?; } MessageContent::ToolRequest(tool_request) => { @@ -1820,19 +1436,21 @@ impl GooseAcpAgent { MessageContent::Thinking(thinking) => { cx.send_notification(SessionNotification::new( session_id.clone(), - SessionUpdate::AgentThoughtChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(thinking.thinking.clone()), - ))), + SessionUpdate::AgentThoughtChunk( + ContentChunk::new(ContentBlock::Text(TextContent::new( + thinking.thinking.clone(), + ))) + .meta(message_update_meta(message_id, message_created)), + ), ))?; } - MessageContent::ActionRequired(action_required) => { - if let ActionRequiredData::ToolConfirmation { + MessageContent::ActionRequired(action_required) => match &action_required.data { + ActionRequiredData::ToolConfirmation { id, tool_name, arguments, prompt, - } = &action_required.data - { + } => { self.handle_tool_permission_request( cx, agent, @@ -1843,6 +1461,25 @@ impl GooseAcpAgent { prompt.clone(), )?; } + ActionRequiredData::Elicitation { + id, + message, + requested_schema, + } => { + send_elicitation_interaction_update( + cx, + session_id.0.as_ref(), + id.clone(), + InteractionState::Pending, + Some(message.clone()), + Some(requested_schema.clone()), + Some(interaction_update_meta(message_id, message_created)), + )?; + } + ActionRequiredData::ElicitationResponse { .. } => {} + }, + MessageContent::SystemNotification(notification) => { + send_status_message_update(cx, session_id.0.as_ref(), notification)?; } _ => {} } @@ -1879,10 +1516,7 @@ impl GooseAcpAgent { } if let Ok(tool_call) = &tool_request.tool_call { - let agent = match &session.agent { - AgentHandle::Ready(a) => a.clone(), - AgentHandle::Loading(_) => return Ok(()), - }; + let agent = session.agent.clone(); let sid = session_id.clone(); let request_id = tool_request.id.clone(); let cx = cx.clone(); @@ -2127,15 +1761,7 @@ impl GooseAcpAgent { return; } - let agent = match &session.agent { - AgentHandle::Ready(a) => a.clone(), - AgentHandle::Loading(_) => { - warn!( - "tool chain summary: agent still loading; skipping chain anchored at {first_id}", - ); - return; - } - }; + let agent = session.agent.clone(); // Snapshot (name, args_json) for each step in document order. let steps: Vec<(String, String)> = chain @@ -2377,6 +2003,111 @@ fn outcome_to_confirmation(outcome: &RequestPermissionOutcome) -> PermissionConf } } +fn prompt_error_from_message_content( + content_item: &MessageContent, +) -> Option { + match content_item { + MessageContent::SystemNotification(notification) + if notification.notification_type == SystemNotificationType::CreditsExhausted => + { + Some(credits_exhausted_prompt_error(notification)) + } + _ => None, + } +} + +fn credits_exhausted_prompt_error( + notification: &SystemNotificationContent, +) -> agent_client_protocol::Error { + let mut data = serde_json::Map::new(); + data.insert( + "reason".to_string(), + serde_json::Value::String("credits_exhausted".to_string()), + ); + + if let Some(url) = notification + .data + .as_ref() + .and_then(|data| data.get("top_up_url")) + .and_then(|url| url.as_str()) + { + data.insert( + "url".to_string(), + serde_json::Value::String(url.to_string()), + ); + } + + agent_client_protocol::Error::new(-32603, notification.msg.clone()) + .data(serde_json::Value::Object(data)) +} + +fn send_status_message_update( + cx: &ConnectionTo, + session_id: &str, + notification: &SystemNotificationContent, +) -> Result<(), agent_client_protocol::Error> { + if let Some(status) = status_message_from_system_notification(notification) { + cx.send_notification(GooseSessionNotification { + session_id: session_id.to_string(), + update: GooseSessionUpdate::StatusMessage(StatusMessageUpdate { status }), + })?; + } + Ok(()) +} + +fn status_message_from_system_notification( + notification: &SystemNotificationContent, +) -> Option { + match notification.notification_type { + SystemNotificationType::InlineMessage => Some(StatusMessage::Notice { + message: notification.msg.clone(), + }), + SystemNotificationType::ThinkingMessage => Some(StatusMessage::Progress { + message: notification.msg.clone(), + }), + SystemNotificationType::CreditsExhausted => None, + } +} + +fn send_elicitation_interaction_update( + cx: &ConnectionTo, + session_id: &str, + id: String, + state: InteractionState, + message: Option, + requested_schema: Option, + meta: Option, +) -> Result<(), agent_client_protocol::Error> { + cx.send_notification(GooseSessionNotification { + session_id: session_id.to_string(), + update: GooseSessionUpdate::InteractionUpdate(InteractionUpdate { + interaction: Interaction::Elicitation { + id, + state, + message, + requested_schema, + }, + meta, + }), + }) +} + +fn interaction_update_meta(message_id: Option<&str>, created: i64) -> serde_json::Value { + serde_json::Value::Object(message_update_meta(message_id, created)) +} + +fn message_update_meta(message_id: Option<&str>, created: i64) -> Meta { + let mut goose = serde_json::Map::new(); + goose.insert("created".to_string(), serde_json::json!(created)); + if let Some(id) = message_id { + goose.insert("messageId".to_string(), serde_json::json!(id)); + } + + let mut meta = serde_json::Map::new(); + meta.insert("goose".to_string(), serde_json::Value::Object(goose)); + meta +} + fn extract_tool_call_update_meta( tool_response: &crate::conversation::message::ToolResponse, ) -> Option { @@ -2401,32 +2132,6 @@ fn replay_message_meta(message: &Message) -> Meta { meta } -fn replay_audience_annotations(audience: &[Role]) -> Annotations { - Annotations::new().audience( - audience - .iter() - .map(|role| match role { - Role::Assistant => agent_client_protocol::schema::Role::Assistant, - Role::User => agent_client_protocol::schema::Role::User, - }) - .collect::>(), - ) -} - -fn send_replay_content_chunk( - cx: &ConnectionTo, - session_id: &SessionId, - message: &Message, - content: ContentBlock, -) -> std::result::Result<(), agent_client_protocol::Error> { - let chunk = ContentChunk::new(content).meta(replay_message_meta(message)); - let update = match message.role { - Role::User => SessionUpdate::UserMessageChunk(chunk), - Role::Assistant => SessionUpdate::AgentMessageChunk(chunk), - }; - cx.send_notification(SessionNotification::new(session_id.clone(), update)) -} - fn replay_message_goose_meta(message: &Message) -> serde_json::Map { let mut goose = serde_json::Map::new(); goose.insert("created".to_string(), serde_json::json!(message.created)); @@ -2550,503 +2255,94 @@ impl GooseAcpAgent { cx: &ConnectionTo, args: NewSessionRequest, ) -> Result { - debug!(?args, "new session request"); - let t_start = std::time::Instant::now(); - validate_absolute_cwd(&args.cwd)?; - - let requested_provider = args - .meta - .as_ref() - .and_then(|m| m.get("provider")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let project_id = args - .meta - .as_ref() - .and_then(|m| m.get("projectId")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - // When _meta.client is set, the session is created by a known client - // (e.g. "goose" for the desktop app) and treated as a User session. - // Without it, sessions default to Acp for programmatic ACP clients. - let session_type = match args - .meta - .as_ref() - .and_then(|m| m.get("client")) - .and_then(|v| v.as_str()) - { - Some(_) => SessionType::User, - None => SessionType::Acp, - }; - - let t0 = std::time::Instant::now(); - let goose_session = self - .session_manager - .create_session( - args.cwd.clone(), - "New Chat".to_string(), - session_type, - self.goose_mode, - ) - .await - .internal_err_ctx("Failed to create session")?; - - let mut builder = self.session_manager.update(&goose_session.id); - if let Some(ref provider) = requested_provider { - builder = builder.provider_name(provider); - } - if let Some(pid) = project_id { - builder = builder.project_id(Some(pid)); - } - builder - .apply() - .await - .internal_err_ctx("Failed to update session")?; - - let goose_session = self - .session_manager - .get_session(&goose_session.id, false) - .await - .internal_err_ctx("Failed to reload session")?; - - let session_id_str = goose_session.id.clone(); - let sid = sid_short(&session_id_str); - debug!(target: "perf", sid = %sid, ms = t0.elapsed().as_millis() as u64, "perf: new_session create_session"); - - let (agent_tx, agent_rx) = tokio::sync::watch::channel::(None); - - let acp_session = GooseAcpSession { - agent: AgentHandle::Loading(agent_rx), - tool_requests: HashMap::new(), - chain_membership: HashMap::new(), - responded_tool_ids: HashSet::new(), - summarized_chains: HashSet::new(), - cancel_token: None, - pending_working_dir: None, - }; - self.sessions - .lock() - .await - .insert(session_id_str.clone(), acp_session); - - let mode_state = build_mode_state(self.goose_mode)?; - - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; - let initial_usage_update = resolved - .as_ref() - .ok() - .map(|(_, mc)| build_usage_update(&goose_session, mc.context_limit())); - let acp_session_id = SessionId::new(session_id_str); - let (model_state, config_options, prebuilt_provider) = self - .prepare_session_init_config(&resolved, &mode_state, &goose_session) - .await; - - let working_dir = goose_session.working_dir.clone(); - - self.spawn_agent_setup( - cx, - agent_tx, - AgentSetupRequest { - session_id: acp_session_id.clone(), - goose_session, - mcp_servers: args.mcp_servers, - resolved_provider: resolved.as_ref().ok().cloned(), - prebuilt_provider, - }, - ); - - let mut response = NewSessionResponse::new(acp_session_id.clone()).modes(mode_state); - if let Some(ms) = model_state { - response = response.models(ms); - } - if let Some(co) = config_options { - response = response.config_options(co); - } - if let Some(usage_update) = initial_usage_update { - cx.send_notification(SessionNotification::new( - acp_session_id.clone(), - SessionUpdate::UsageUpdate(usage_update), - ))?; - } - Self::send_available_commands_update(cx, &acp_session_id, &working_dir)?; - debug!( - target: "perf", - sid = %sid, - ms = t_start.elapsed().as_millis() as u64, - "perf: new_session done (agent setup continues in background)" - ); - Ok(response) - } - - /// Look up the session and return the agent if already ready, or the watch - /// receiver if still loading. Optionally sets a cancellation token on the - /// session (needed by `on_prompt`). - async fn get_agent_or_receiver( - &self, - session_id: &str, - cancel_token: Option, - ) -> Result< - Either, tokio::sync::watch::Receiver>, - agent_client_protocol::Error, - > { - let mut sessions = self.sessions.lock().await; - let session = sessions.get_mut(session_id).ok_or_else(|| { - agent_client_protocol::Error::resource_not_found(Some(session_id.to_string())) - .data(format!("Session not found: {}", session_id)) - })?; - if let Some(token) = cancel_token { - session.cancel_token = Some(token); - } - match &session.agent { - AgentHandle::Ready(agent) => Ok(Either::Left(agent.clone())), - AgentHandle::Loading(rx) => Ok(Either::Right(rx.clone())), - } - } - - /// Wait until the agent is **fully ready** (provider + all extensions). - /// Most callers (e.g. `on_prompt`, `on_get_tools`) should use this. - async fn get_session_agent( - &self, - session_id: &str, - cancel_token: Option, - ) -> Result, agent_client_protocol::Error> { - let mut rx = match self.get_agent_or_receiver(session_id, cancel_token).await? { - Either::Left(agent) => return Ok(agent), - Either::Right(rx) => rx, - }; - // Wait specifically for FullyReady (not just ProviderReady). - let guard = rx - .wait_for(|v| { - matches!( - v, - Some(Ok(AgentSetupProgress::FullyReady(_))) | Some(Err(_)) - ) - }) - .await - .map_err(|_| { - agent_client_protocol::Error::internal_error() - .data("Agent setup task was dropped".to_string()) - })?; - match guard.as_ref().unwrap() { - Ok(AgentSetupProgress::FullyReady(agent)) => Ok(agent.clone()), - Err(e) => Err(agent_client_protocol::Error::internal_error().data(e.clone())), - // wait_for predicate excludes ProviderReady - _ => unreachable!(), - } - } - - /// Wait only until the **provider** is initialized. Extensions may still - /// be loading in the background. Use this for operations that only touch - /// the provider (e.g. `update_provider`, `set_model`, `build_config_update`). - async fn get_session_agent_provider_ready( - &self, - session_id: &str, - ) -> Result, agent_client_protocol::Error> { - let mut rx = match self.get_agent_or_receiver(session_id, None).await? { - Either::Left(agent) => return Ok(agent), - Either::Right(rx) => rx, - }; - // Any signal (ProviderReady, FullyReady, or Err) unblocks us. - let guard = rx.wait_for(|v| v.is_some()).await.map_err(|_| { - agent_client_protocol::Error::internal_error() - .data("Agent setup task was dropped".to_string()) - })?; - match guard.as_ref().unwrap() { - Ok(progress) => match progress { - AgentSetupProgress::ProviderReady(agent) - | AgentSetupProgress::FullyReady(agent) => Ok(agent.clone()), - }, - Err(e) => Err(agent_client_protocol::Error::internal_error().data(e.clone())), - } - } - - async fn add_mcp_extensions( - agent: &Arc, - mcp_servers: Vec, - session_id: &str, - ) -> Result<(), agent_client_protocol::Error> { - let mut configs = Vec::with_capacity(mcp_servers.len()); - for mcp_server in mcp_servers { - let config = match mcp_server_to_extension_config(mcp_server) { - Ok(c) => c, - Err(msg) => { - return Err(agent_client_protocol::Error::invalid_params().data(msg)); - } - }; - configs.push(config); - } - - if configs.is_empty() { - return Ok(()); - } - - let results = agent - .add_extensions_bulk(configs, session_id) - .await - .internal_err()?; - for result in &results { - if !result.success { - let error_msg = result.error.as_deref().unwrap_or("unknown error"); - return Err(agent_client_protocol::Error::internal_error().data(format!( - "Failed to add MCP server '{}': {}", - result.name, error_msg - ))); - } - } - Ok(()) + self.handle_new_session(cx, args).await } - - async fn on_load_session( - &self, - cx: &ConnectionTo, - args: LoadSessionRequest, - ) -> Result { - debug!(?args, "load session request"); - validate_absolute_cwd(&args.cwd)?; - - let session_id = args.session_id.0.to_string(); - let sid = sid_short(&session_id); - let t_start = std::time::Instant::now(); - - let t0 = std::time::Instant::now(); - let goose_session = self - .session_manager - .get_session(&session_id, true) - .await - .map_err(|_| { - agent_client_protocol::Error::resource_not_found(Some(session_id.clone())) - .data(format!("Session not found: {}", session_id)) - })?; - debug!(target: "perf", sid = %sid, ms = t0.elapsed().as_millis() as u64, "perf: load_session get_session"); - let loaded_mode = goose_session.goose_mode; - - // ── REPLAY MESSAGES ── - // Stream user-visible messages back to the client so the chat view - // populates immediately, before the slow agent/provider/extension setup. - let messages = goose_session - .conversation - .as_ref() - .map(|c| c.messages().to_vec()) - .unwrap_or_default(); - debug!( - target: "perf", - sid = %sid, - messages = messages.len(), - "perf: load_session messages loaded" - ); - - let mut replay_tool_requests = - HashMap::::new(); - - for message in &messages { - if !message.metadata.user_visible { - continue; - } - - for content_item in &message.content { - match content_item { - MessageContent::Text(text) => { - let mut tc = TextContent::new(text.text.clone()); - if let Some(audience) = text.audience() { - tc = tc.annotations(replay_audience_annotations(audience)); - } - send_replay_content_chunk( - cx, - &args.session_id, - message, - ContentBlock::Text(tc), - )?; - } - MessageContent::Image(image) => { - let mut image_content = - ImageContent::new(image.data.clone(), image.mime_type.clone()); - if let Some(audience) = image.audience() { - image_content = - image_content.annotations(replay_audience_annotations(audience)); - } - send_replay_content_chunk( - cx, - &args.session_id, - message, - ContentBlock::Image(image_content), - )?; - } - MessageContent::ToolRequest(tool_request) => { - // Replay-only: emit the ToolCall notification and - // stash the request for location extraction, but - // don't require a full GooseAcpSession. - replay_tool_requests.insert(tool_request.id.clone(), tool_request.clone()); - - let pending_tool_call = pending_tool_call_from_request(tool_request); - let mut meta = pending_tool_call.identity_meta; - // If this tool request is the first of a chain whose - // summary was persisted at completion time, attach the - // chain summary to the initial ToolCall so the chain - // header is correct on first paint after reload. - if let Some(chain_summary) = tool_request.persisted_chain_summary() { - meta = with_tool_chain_summary_meta( - meta, - &chain_summary.summary, - chain_summary.count, - ); - } - let tool_call = pending_tool_call - .tool_call - .meta(merge_replay_message_meta(meta, message)); - - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - SessionUpdate::ToolCall(tool_call), - ))?; - } - MessageContent::ToolResponse(tool_response) => { - // Replay-only: emit the ToolCallUpdate notification, - // using the stashed replay_tool_requests for location - // extraction. - let status = match &tool_response.tool_result { - Ok(result) if result.is_error == Some(true) => ToolCallStatus::Failed, - Ok(_) => ToolCallStatus::Completed, - Err(_) => ToolCallStatus::Failed, - }; - - let mut fields = ToolCallUpdateFields::new().status(status); - if let Some(raw_output) = - extract_tool_raw_output(&tool_response.tool_result) - { - fields = fields.raw_output(raw_output); - } - if !tool_response - .tool_result - .as_ref() - .is_ok_and(|r| r.is_acp_aware()) - { - let content = build_tool_call_content(&tool_response.tool_result); - fields = fields.content(content); - - let locations = extract_locations_from_meta(tool_response) - .unwrap_or_else(|| { - if let Some(tool_request) = - replay_tool_requests.get(&tool_response.id) - { - extract_tool_locations(tool_request, tool_response) - } else { - Vec::new() - } - }); - if !locations.is_empty() { - fields = fields.locations(locations); - } - } - - let update = - ToolCallUpdate::new(ToolCallId::new(tool_response.id.clone()), fields) - .meta(merge_replay_message_meta( - extract_tool_call_update_meta(tool_response), - message, - )); - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - SessionUpdate::ToolCallUpdate(update), - ))?; - } - MessageContent::Thinking(thinking) => { - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - SessionUpdate::AgentThoughtChunk( - ContentChunk::new(ContentBlock::Text(TextContent::new( - thinking.thinking.clone(), - ))) - .meta(replay_message_meta(message)), - ), - ))?; - } - _ => {} + + /// Look up the session's agent. Optionally sets a cancellation token on + /// the session (needed by `on_prompt`). + async fn get_session_agent( + &self, + session_id: &str, + cancel_token: Option, + ) -> Result, agent_client_protocol::Error> { + { + let mut sessions = self.sessions.lock().await; + if let Some(session) = sessions.get_mut(session_id) { + if let Some(token) = cancel_token { + session.cancel_token = Some(token); } + return Ok(session.agent.clone()); } } - // Update working directory. - self.session_manager - .update(&session_id) - .working_dir(args.cwd.clone()) - .apply() - .await - .internal_err_ctx("Failed to update session working directory")?; - let goose_session = self + let cx = self.client_cx.get().ok_or_else(|| { + agent_client_protocol::Error::resource_not_found(Some(session_id.to_string())) + .data(format!("Session not found: {}", session_id)) + })?; + let session = self .session_manager - .get_session(&session_id, false) - .await - .internal_err_ctx("Failed to reload session")?; - - // Register the session with a Loading handle. - let (agent_tx, agent_rx) = tokio::sync::watch::channel::(None); - - let acp_session = GooseAcpSession { - agent: AgentHandle::Loading(agent_rx), - tool_requests: replay_tool_requests, - chain_membership: HashMap::new(), - responded_tool_ids: HashSet::new(), - summarized_chains: HashSet::new(), - cancel_token: None, - pending_working_dir: None, - }; - self.sessions - .lock() + .get_session(session_id, false) .await - .insert(session_id.clone(), acp_session); - - let mode_state = build_mode_state(loaded_mode)?; - - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; - let initial_usage_update = resolved - .as_ref() - .ok() - .map(|(_, mc)| build_usage_update(&goose_session, mc.context_limit())) - .or_else(|| { - goose_session - .model_config - .as_ref() - .map(|mc| build_usage_update(&goose_session, mc.context_limit())) - }); - let (model_state, config_options, prebuilt_provider) = self - .prepare_session_init_config(&resolved, &mode_state, &goose_session) - .await; + .map_err(|_| { + agent_client_protocol::Error::resource_not_found(Some(session_id.to_string())) + .data(format!("Session not found: {}", session_id)) + })?; + let (agent, _) = self + .activate_acp_session(cx, &session, HashMap::new()) + .await?; - self.spawn_agent_setup( - cx, - agent_tx, - AgentSetupRequest { - session_id: args.session_id.clone(), - goose_session, - mcp_servers: args.mcp_servers, - resolved_provider: None, - prebuilt_provider, - }, - ); + if let Some(token) = cancel_token { + let mut sessions = self.sessions.lock().await; + if let Some(session) = sessions.get_mut(session_id) { + session.cancel_token = Some(token); + } + } + Ok(agent) + } - let mut response = LoadSessionResponse::new().modes(mode_state); - if let Some(ms) = model_state { - response = response.models(ms); + #[allow(dead_code)] + async fn add_mcp_extensions( + agent: &Arc, + mcp_servers: Vec, + session_id: &str, + ) -> Result<(), agent_client_protocol::Error> { + let mut configs = Vec::with_capacity(mcp_servers.len()); + for mcp_server in mcp_servers { + let config = match mcp_server_to_extension_config(mcp_server) { + Ok(c) => c, + Err(msg) => { + return Err(agent_client_protocol::Error::invalid_params().data(msg)); + } + }; + configs.push(config); } - if let Some(co) = config_options { - response = response.config_options(co); + + if configs.is_empty() { + return Ok(()); } - if let Some(usage_update) = initial_usage_update { - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - SessionUpdate::UsageUpdate(usage_update), - ))?; + + let results = agent + .add_extensions_bulk(configs, session_id) + .await + .internal_err()?; + for result in &results { + if !result.success { + let error_msg = result.error.as_deref().unwrap_or("unknown error"); + return Err(agent_client_protocol::Error::internal_error().data(format!( + "Failed to add MCP server '{}': {}", + result.name, error_msg + ))); + } } - Self::send_available_commands_update(cx, &args.session_id, &args.cwd)?; - debug!( - target: "perf", - sid = %sid, - ms = t_start.elapsed().as_millis() as u64, - "perf: load_session done (agent setup continues in background)" - ); - Ok(response) + Ok(()) + } + + async fn on_load_session( + &self, + cx: &ConnectionTo, + args: LoadSessionRequest, + ) -> Result { + self.handle_load_session(cx, args).await } async fn on_prompt( @@ -3146,6 +2442,11 @@ impl GooseAcpAgent { })?; for content_item in &message.content { + if let Some(error) = prompt_error_from_message_content(content_item) { + session.cancel_token = None; + return Err(error); + } + match content_item { MessageContent::ToolRequest(tr) => { if let Some(msg_id) = stored_message_id.as_deref() { @@ -3178,6 +2479,7 @@ impl GooseAcpAgent { &args.session_id, &session_id, stored_message_id.as_deref(), + message.created, &agent, session, cx, @@ -3210,16 +2512,16 @@ impl GooseAcpAgent { .get_session(&session_id, false) .await .internal_err_ctx("Failed to load session")?; - let provider = agent - .provider() - .await - .internal_err_ctx("Failed to get provider")?; - let usage_update = - build_usage_update(&session, provider.get_model_config().context_limit()); - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - SessionUpdate::UsageUpdate(usage_update), - ))?; + if let Some(updates) = build_usage_updates(&session) { + cx.send_notification(updates.custom)?; + // Standard ACP notification — emitted alongside the custom one for + // backwards compatibility. Remove once all known clients have + // migrated to `_goose/unstable/session/update`. + cx.send_notification(SessionNotification::new( + args.session_id.clone(), + SessionUpdate::UsageUpdate(updates.standard), + ))?; + } debug!( target: "perf", @@ -3263,13 +2565,52 @@ impl GooseAcpAgent { Ok(()) } + async fn on_elicitation_respond( + &self, + cx: &ConnectionTo, + req: ElicitationRespondRequest, + ) -> Result { + ActionRequiredManager::global() + .submit_response(req.elicitation_id.clone(), req.user_data.clone()) + .await + .invalid_params_err_ctx("Failed to submit elicitation response")?; + + let response_message = Message::user() + .with_generated_id() + .with_content(MessageContent::action_required_elicitation_response( + req.elicitation_id.clone(), + req.user_data, + )) + .agent_only(); + + self.session_manager + .add_message(&req.session_id, &response_message) + .await + .internal_err_ctx("Failed to persist elicitation response")?; + + send_elicitation_interaction_update( + cx, + &req.session_id, + req.elicitation_id, + InteractionState::Submitted, + None, + None, + Some(interaction_update_meta( + response_message.id.as_deref(), + response_message.created, + )), + )?; + + Ok(EmptyResponse {}) + } + async fn on_set_model( &self, session_id: &str, model_id: &str, ) -> Result { let config = self.config()?; - let agent = self.get_session_agent_provider_ready(session_id).await?; + let agent = self.get_session_agent(session_id, None).await?; let current_provider = agent .provider() .await @@ -3277,7 +2618,7 @@ impl GooseAcpAgent { let provider_name = current_provider.get_name().to_string(); let current_model_config = current_provider.get_model_config(); let extensions = - EnabledExtensionsState::for_session(&self.session_manager, session_id, &config).await; + EnabledExtensionsState::for_session(&self.session_manager, session_id, config).await; let model_config = crate::model::ModelConfig::new(model_id) .invalid_params_err_ctx("Invalid model config")? .with_canonical_limits(&provider_name); @@ -3319,7 +2660,7 @@ impl GooseAcpAgent { .get_session(&session_id.0, false) .await .internal_err()?; - let agent = self.get_session_agent_provider_ready(&session_id.0).await?; + let agent = self.get_session_agent(&session_id.0, None).await?; let provider = agent .provider() .await @@ -3362,7 +2703,7 @@ impl GooseAcpAgent { .data(format!("Invalid mode: {}", mode_id)) })?; - let agent = self.get_session_agent_provider_ready(session_id).await?; + let agent = self.get_session_agent(session_id, None).await?; agent .update_goose_mode(mode, session_id) .await @@ -3382,7 +2723,7 @@ impl GooseAcpAgent { request_params: Option>, ) -> Result<(), agent_client_protocol::Error> { let config = self.config()?; - let agent = self.get_session_agent_provider_ready(session_id).await?; + let agent = self.get_session_agent(session_id, None).await?; let current_provider = agent .provider() .await @@ -3390,8 +2731,6 @@ impl GooseAcpAgent { let current_provider_name = current_provider.get_name(); let current_model_config = current_provider.get_model_config(); let current_model = current_model_config.model_name.clone(); - let has_default_overrides = - model_name.is_some() || context_limit.is_some() || request_params.is_some(); let use_default_provider = provider_name == DEFAULT_PROVIDER_ID; let resolved_provider_name = if use_default_provider { config @@ -3424,7 +2763,7 @@ impl GooseAcpAgent { ); let extensions = - EnabledExtensionsState::for_session(&self.session_manager, session_id, &config).await; + EnabledExtensionsState::for_session(&self.session_manager, session_id, config).await; let session = self .session_manager .get_session(session_id, false) @@ -3448,32 +2787,8 @@ impl GooseAcpAgent { .update_goose_mode(mode, session_id) .await .internal_err_ctx("Failed to propagate mode")?; - let provider = agent - .provider() - .await - .internal_err_ctx("Failed to get provider")?; // provider_name is already updated on the session by the agent's update_provider call. - - if use_default_provider { - let update = self - .session_manager - .update(session_id) - .provider_name(DEFAULT_PROVIDER_ID); - if has_default_overrides { - update - .model_config(provider.get_model_config()) - .apply() - .await - .internal_err_ctx("Failed to persist default provider selection overrides")?; - } else { - update - .clear_model_config() - .apply() - .await - .internal_err_ctx("Failed to persist default provider selection")?; - } - } Ok(()) } @@ -3508,10 +2823,14 @@ impl GooseAcpAgent { .into_iter() .map(|s| { let meta = session_meta(&s); - SessionInfo::new(SessionId::new(s.id), s.working_dir) - .title(s.name) + let title = display_title(&s); + let mut info = SessionInfo::new(SessionId::new(s.id), s.working_dir) .updated_at(s.updated_at.to_rfc3339()) - .meta(meta) + .meta(meta); + if let Some(t) = title { + info = info.title(t); + } + info }) .collect(); let next_cursor = page @@ -3527,80 +2846,7 @@ impl GooseAcpAgent { cx: &ConnectionTo, args: ForkSessionRequest, ) -> Result { - validate_absolute_cwd(&args.cwd)?; - let source_session_id = &*args.session_id.0; - - let new_session = self - .session_manager - .copy_session(source_session_id, "Fork".to_string()) - .await - .internal_err()?; - let new_session_id = new_session.id.clone(); - - // Update working dir for the fork. - self.session_manager - .update(&new_session_id) - .working_dir(args.cwd.clone()) - .apply() - .await - .internal_err()?; - - let goose_session = self - .session_manager - .get_session(&new_session_id, false) - .await - .internal_err()?; - - let (agent_tx, agent_rx) = tokio::sync::watch::channel::(None); - - let acp_session = GooseAcpSession { - agent: AgentHandle::Loading(agent_rx), - tool_requests: HashMap::new(), - chain_membership: HashMap::new(), - responded_tool_ids: HashSet::new(), - summarized_chains: HashSet::new(), - cancel_token: None, - pending_working_dir: None, - }; - self.sessions - .lock() - .await - .insert(new_session_id.clone(), acp_session); - - let mode_state = build_mode_state(self.goose_mode)?; - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; - let (model_state, config_options, prebuilt_provider) = self - .prepare_session_init_config(&resolved, &mode_state, &goose_session) - .await; - - let acp_session_id = SessionId::new(new_session_id.clone()); - - self.spawn_agent_setup( - cx, - agent_tx, - AgentSetupRequest { - session_id: acp_session_id.clone(), - goose_session, - mcp_servers: args.mcp_servers, - resolved_provider: resolved.ok(), - prebuilt_provider, - }, - ); - - let meta = session_meta(&new_session); - - let mut response = ForkSessionResponse::new(acp_session_id.clone()) - .modes(mode_state) - .meta(meta); - - if let Some(ms) = model_state { - response = response.models(ms); - } - if let Some(co) = config_options { - response = response.config_options(co); - } - Self::send_available_commands_update(cx, &acp_session_id, &args.cwd)?; - Ok(response) + self.handle_fork_session(cx, args).await } async fn on_close_session( @@ -3614,6 +2860,10 @@ impl GooseAcpAgent { } } sessions.remove(session_id); + drop(sessions); + + let _ = self.agent_manager.remove_session(session_id).await; + info!(session_id = %session_id, "ACP session closed"); Ok(CloseSessionResponse::new()) } @@ -3671,8 +2921,7 @@ mod tests { use crate::conversation::message::{ToolRequest, ToolResponse}; use agent_client_protocol::schema::{ EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio, - PermissionOptionId, ResourceLink, SelectedPermissionOutcome, SessionConfigSelectOption, - SessionMode, SessionModeId, SessionModeState, + PermissionOptionId, ResourceLink, SelectedPermissionOutcome, }; use rmcp::model::{CallToolRequestParams, Content as RmcpContent}; use std::io::Write; @@ -3834,90 +3083,6 @@ print(\"hello, world\") ); } - fn tool_request_block(id: &str) -> crate::conversation::message::MessageContent { - crate::conversation::message::MessageContent::ToolRequest(ToolRequest { - id: id.to_string(), - tool_call: Ok(CallToolRequestParams::new("dummy")), - metadata: None, - tool_meta: None, - }) - } - - fn text_block(text: &str) -> crate::conversation::message::MessageContent { - crate::conversation::message::MessageContent::text(text) - } - - #[test] - fn extract_tool_chains_returns_empty_for_no_tool_blocks() { - let content = vec![text_block("hello"), text_block("world")]; - assert!(extract_tool_chains(&content).is_empty()); - } - - #[test] - fn extract_tool_chains_returns_single_chain_when_only_tools() { - let content = vec![ - tool_request_block("a"), - tool_request_block("b"), - tool_request_block("c"), - ]; - let chains = extract_tool_chains(&content); - assert_eq!( - chains, - vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]] - ); - } - - #[test] - fn extract_tool_chains_breaks_on_text_block() { - let content = vec![ - tool_request_block("a"), - tool_request_block("b"), - text_block("interlude"), - tool_request_block("c"), - tool_request_block("d"), - ]; - let chains = extract_tool_chains(&content); - assert_eq!( - chains, - vec![ - vec!["a".to_string(), "b".to_string()], - vec!["c".to_string(), "d".to_string()], - ] - ); - } - - #[test] - fn extract_tool_chains_includes_singletons() { - let content = vec![ - tool_request_block("a"), - text_block("split"), - tool_request_block("b"), - text_block("split"), - tool_request_block("c"), - ]; - let chains = extract_tool_chains(&content); - assert_eq!( - chains, - vec![ - vec!["a".to_string()], - vec!["b".to_string()], - vec!["c".to_string()], - ] - ); - } - - #[test] - fn extract_tool_chains_keeps_run_when_text_leads_or_trails() { - let content = vec![ - text_block("intro"), - tool_request_block("a"), - tool_request_block("b"), - text_block("outro"), - ]; - let chains = extract_tool_chains(&content); - assert_eq!(chains, vec![vec!["a".to_string(), "b".to_string()]]); - } - fn buf_entry(tool_id: &str, msg_id: &str) -> (String, String) { (tool_id.to_string(), msg_id.to_string()) } @@ -4148,56 +3313,6 @@ print(\"hello, world\") assert_eq!(outcome_to_confirmation(&input), expected); } - #[test_case( - vec!["model-a".into(), "model-b".into()] - => SessionModelState::new( - ModelId::new("unused"), - vec![ModelInfo::new(ModelId::new("unused"), "unused"), - ModelInfo::new(ModelId::new("model-a"), "model-a"), - ModelInfo::new(ModelId::new("model-b"), "model-b")], - ) - ; "returns current and available models" - )] - #[test_case( - vec![] - => SessionModelState::new( - ModelId::new("unused"), - vec![ModelInfo::new(ModelId::new("unused"), "unused")], - ) - ; "empty model list" - )] - fn test_build_model_state(models: Vec) -> SessionModelState { - let inventory = ProviderInventoryEntry { - provider_id: "mock".to_string(), - provider_name: "Mock".to_string(), - description: "Mock".to_string(), - default_model: "unused".to_string(), - configured: true, - provider_type: crate::providers::base::ProviderType::Builtin, - category: crate::providers::catalog::ProviderSetupCategory::Model, - config_keys: vec![], - setup_steps: vec![], - supports_refresh: true, - refreshing: false, - models: models - .into_iter() - .map(|id| crate::providers::inventory::InventoryModel { - name: id.clone(), - id, - family: None, - context_limit: None, - reasoning: None, - recommended: false, - }) - .collect(), - last_updated_at: None, - last_refresh_attempt_at: None, - last_refresh_error: None, - model_selection_hint: None, - }; - build_model_state("unused", &inventory) - } - fn json_object(pairs: Vec<(&str, serde_json::Value)>) -> rmcp::model::JsonObject { pairs.into_iter().map(|(k, v)| (k.to_string(), v)).collect() } @@ -4424,6 +3539,57 @@ print(\"hello, world\") ); } + #[test] + fn test_message_update_meta_includes_created_and_message_id() { + let meta = message_update_meta(Some("msg_live"), 1_700_000_000); + + assert_eq!( + meta.get("goose"), + Some(&serde_json::json!({ + "created": 1_700_000_000, + "messageId": "msg_live", + })), + ); + } + + #[test] + fn test_credits_exhausted_system_notification_maps_to_prompt_error() { + let content = MessageContent::SystemNotification(SystemNotificationContent { + notification_type: SystemNotificationType::CreditsExhausted, + msg: "Please add credits to your account, then resend your message to continue." + .to_string(), + data: Some(serde_json::json!({ + "top_up_url": "https://router.tetrate.ai/billing" + })), + }); + + let error = prompt_error_from_message_content(&content).expect("expected prompt error"); + let value = serde_json::to_value(error).unwrap(); + + assert_eq!( + value, + serde_json::json!({ + "code": -32603, + "message": "Please add credits to your account, then resend your message to continue.", + "data": { + "reason": "credits_exhausted", + "url": "https://router.tetrate.ai/billing" + } + }) + ); + } + + #[test] + fn test_non_credit_system_notification_does_not_map_to_prompt_error() { + let content = MessageContent::SystemNotification(SystemNotificationContent { + notification_type: SystemNotificationType::InlineMessage, + msg: "Compaction complete".to_string(), + data: None, + }); + + assert!(prompt_error_from_message_content(&content).is_none()); + } + #[test] fn test_merge_replay_message_meta_omits_message_id_when_none() { let message = Message::new(Role::Assistant, 1_700_000_000, vec![]); @@ -4533,122 +3699,27 @@ print(\"hello, world\") #[test] fn test_build_usage_update_clamps_negative_used_to_zero() { - let session = make_session_with_usage(Some(-7), Some(0), Some(0), None, None, None); - let usage = build_usage_update(&session, 258_000); + let mut session = make_session_with_usage(Some(-7), Some(0), Some(0), None, None, None); + session.model_config = Some( + crate::model::ModelConfig::new("test-model") + .unwrap() + .with_context_limit(Some(258_000)), + ); + let updates = build_usage_updates(&session).expect("usage updates should be present"); + assert_eq!(updates.custom.session_id, "session-1"); + let usage = match updates.custom.update { + GooseSessionUpdate::UsageUpdate(usage) => usage, + other => panic!("expected usage update, got {other:?}"), + }; assert_eq!(usage.used, 0); - assert_eq!(usage.size, 258_000); - } - - #[test_case( - GooseMode::Auto - => Ok(SessionModeState::new( - SessionModeId::new("auto"), - vec![ - SessionMode::new(SessionModeId::new("auto"), "auto") - .description("Automatically approve tool calls"), - SessionMode::new(SessionModeId::new("approve"), "approve") - .description("Ask before every tool call"), - SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") - .description("Ask only for sensitive tool calls"), - SessionMode::new(SessionModeId::new("chat"), "chat") - .description("Chat only, no tool calls"), - ], - )) - ; "auto mode" - )] - #[test_case( - GooseMode::Approve - => Ok(SessionModeState::new( - SessionModeId::new("approve"), - vec![ - SessionMode::new(SessionModeId::new("auto"), "auto") - .description("Automatically approve tool calls"), - SessionMode::new(SessionModeId::new("approve"), "approve") - .description("Ask before every tool call"), - SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") - .description("Ask only for sensitive tool calls"), - SessionMode::new(SessionModeId::new("chat"), "chat") - .description("Chat only, no tool calls"), - ], - )) - ; "approve mode" - )] - fn test_build_mode_state( - current_mode: GooseMode, - ) -> Result { - build_mode_state(current_mode) + assert_eq!(usage.context_limit, 258_000); + assert_eq!(updates.standard.used, 0); + assert_eq!(updates.standard.size, 258_000); } - #[test_case( - build_mode_state(GooseMode::Auto).unwrap(), - "openai", - vec![ - SessionConfigSelectOption::new("anthropic", "anthropic"), - SessionConfigSelectOption::new("openai", "openai"), - ], - SessionModelState::new( - ModelId::new("gpt-4"), - vec![ModelInfo::new(ModelId::new("gpt-4"), "gpt-4"), ModelInfo::new(ModelId::new("gpt-3.5"), "gpt-3.5")], - ) - => vec![ - SessionConfigOption::select( - "provider", "Provider", "openai", - vec![ - SessionConfigSelectOption::new("anthropic", "anthropic"), - SessionConfigSelectOption::new("openai", "openai"), - ], - ), - SessionConfigOption::select( - "mode", "Mode", "auto", - vec![ - SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), - SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), - SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), - SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), - ], - ).category(SessionConfigOptionCategory::Mode), - SessionConfigOption::select( - "model", "Model", "gpt-4", - vec![ - SessionConfigSelectOption::new("gpt-4", "gpt-4"), - SessionConfigSelectOption::new("gpt-3.5", "gpt-3.5"), - ], - ).category(SessionConfigOptionCategory::Model), - ] - ; "auto mode with multiple models" - )] - #[test_case( - build_mode_state(GooseMode::Approve).unwrap(), - "openai", - vec![SessionConfigSelectOption::new("openai", "openai")], - SessionModelState::new(ModelId::new("only-model"), vec![ModelInfo::new(ModelId::new("only-model"), "only-model")]) - => vec![ - SessionConfigOption::select( - "provider", "Provider", "openai", - vec![SessionConfigSelectOption::new("openai", "openai")], - ), - SessionConfigOption::select( - "mode", "Mode", "approve", - vec![ - SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), - SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), - SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), - SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), - ], - ).category(SessionConfigOptionCategory::Mode), - SessionConfigOption::select( - "model", "Model", "only-model", - vec![SessionConfigSelectOption::new("only-model", "only-model")], - ).category(SessionConfigOptionCategory::Model), - ] - ; "approve mode with single model" - )] - fn test_build_config_options( - mode_state: SessionModeState, - provider_name: &'static str, - provider_options: Vec, - model_state: SessionModelState, - ) -> Vec { - build_config_options(&mode_state, &model_state, provider_name, provider_options) + #[test] + fn test_build_usage_update_requires_model_config() { + let session = make_session_with_usage(Some(120), Some(80), Some(40), None, None, None); + assert!(build_usage_updates(&session).is_none()); } } diff --git a/crates/goose/src/acp/server/config.rs b/crates/goose/src/acp/server/config.rs index 41a9b96d491c..1657e79e583f 100644 --- a/crates/goose/src/acp/server/config.rs +++ b/crates/goose/src/acp/server/config.rs @@ -114,11 +114,11 @@ impl GooseAcpAgent { let config = self.config()?; let model = model_id.clone().unwrap_or_else(|| { - crate::config::get_provider_entry(&config, &provider_id) + crate::config::get_provider_entry(config, &provider_id) .map(|e| e.model) .unwrap_or_default() }); - crate::config::set_active_provider(&config, &provider_id, &model) + crate::config::set_active_provider(config, &provider_id, &model) .internal_err_ctx("Failed to save default provider")?; Ok(DefaultsReadResponse { diff --git a/crates/goose/src/acp/server/custom_dispatch.rs b/crates/goose/src/acp/server/custom_dispatch.rs index 81c6e67107e1..65c23aa8f6ed 100644 --- a/crates/goose/src/acp/server/custom_dispatch.rs +++ b/crates/goose/src/acp/server/custom_dispatch.rs @@ -306,6 +306,15 @@ impl GooseAcpAgent { self.on_import_session(req).await } + #[custom_method(ElicitationRespondRequest)] + async fn dispatch_elicitation_respond( + &self, + _req: ElicitationRespondRequest, + ) -> Result { + Err(agent_client_protocol::Error::invalid_params() + .data("_goose/unstable/elicitation/respond must be handled by the connection-scoped dispatcher")) + } + #[custom_method(UpdateSessionProjectRequest)] async fn dispatch_update_session_project( &self, diff --git a/crates/goose/src/acp/server/dispatch.rs b/crates/goose/src/acp/server/dispatch.rs index 7e23333e2545..b02e03511734 100644 --- a/crates/goose/src/acp/server/dispatch.rs +++ b/crates/goose/src/acp/server/dispatch.rs @@ -1,4 +1,5 @@ use super::*; +use crate::providers::inventory::ensure_refresh_identity_current; impl HandleDispatchFrom for GooseAcpHandler { fn describe_chain(&self) -> impl std::fmt::Debug { @@ -16,6 +17,12 @@ impl HandleDispatchFrom for GooseAcpHandler { // The MatchDispatchFrom chain produces an ~85KB async state machine. // Box::pin moves it to the heap so it doesn't overflow the tokio worker stack. Box::pin(async move { + // Capture the connection handle so handlers can lazily activate + // sessions that exist on disk but were never activated via + // new_session/load_session on this connection. Set-once per + // connection; the result is ignored on later requests. + let _ = agent.client_cx.set(cx.clone()); + // InitializeRequest runs inline: it sets connection-scoped state // (client fs/terminal capabilities) that later handlers read with // defaults, so a pipelined NewSessionRequest must not race ahead of it. @@ -88,6 +95,19 @@ impl HandleDispatchFrom for GooseAcpHandler { Ok(()) }) .await + .if_request({ + let agent = agent.clone(); + let cx = cx.clone(); + |req: ElicitationRespondRequest, responder: Responder| async move { + let cx_spawn = cx.clone(); + cx.spawn(async move { + responder.respond_with_result(agent.on_elicitation_respond(&cx_spawn, req).await)?; + Ok(()) + })?; + Ok(()) + } + }) + .await // set_config_option (SACP 11) and legacy set_mode/set_model; custom _goose/* in otherwise. .if_request({ let agent = agent.clone(); diff --git a/crates/goose/src/acp/server/fork_session.rs b/crates/goose/src/acp/server/fork_session.rs new file mode 100644 index 000000000000..c0e3c764b47a --- /dev/null +++ b/crates/goose/src/acp/server/fork_session.rs @@ -0,0 +1,72 @@ +use super::*; + +impl GooseAcpAgent { + #[allow(dead_code)] + pub(super) async fn handle_fork_session( + &self, + cx: &ConnectionTo, + args: ForkSessionRequest, + ) -> Result { + validate_absolute_cwd(&args.cwd)?; + let source_session_id = &*args.session_id.0; + + let source = self + .session_manager + .get_session(source_session_id, false) + .await + .internal_err()?; + let fork_name = if source.name.trim().is_empty() { + "(copy)".to_string() + } else { + format!("{} (copy)", source.name) + }; + + let new_session = self + .session_manager + .copy_session(source_session_id, fork_name) + .await + .internal_err()?; + let new_session_id = new_session.id.clone(); + + let goose_session = self + .session_manager + .get_session(&new_session_id, false) + .await + .internal_err()?; + + let goose_session = self + .prepare_session_for_activation( + goose_session, + args.cwd.clone(), + args.mcp_servers, + false, + ) + .await?; + + let (_agent, extension_results) = self + .activate_acp_session(cx, &goose_session, HashMap::new()) + .await?; + + let acp_session_id = SessionId::new(new_session_id.clone()); + let mut meta = session_meta(&new_session); + if let Ok(v) = serde_json::to_value(&extension_results) { + meta.insert("extensionResults".to_string(), v); + } + + let (mode_state, model_state, config_options) = + build_session_setup_config(&self.provider_inventory, &goose_session).await?; + + let mut response = ForkSessionResponse::new(acp_session_id.clone()) + .modes(mode_state) + .meta(meta); + + if let Some(ms) = model_state { + response = response.models(ms); + } + if let Some(co) = config_options { + response = response.config_options(co); + } + send_session_setup_notifications(cx, &goose_session)?; + Ok(response) + } +} diff --git a/crates/goose/src/acp/server/load_session.rs b/crates/goose/src/acp/server/load_session.rs new file mode 100644 index 000000000000..6a2fc8626335 --- /dev/null +++ b/crates/goose/src/acp/server/load_session.rs @@ -0,0 +1,288 @@ +use super::*; + +fn replay_audience_annotations(audience: &[Role]) -> Annotations { + Annotations::new().audience( + audience + .iter() + .map(|role| match role { + Role::Assistant => agent_client_protocol::schema::Role::Assistant, + Role::User => agent_client_protocol::schema::Role::User, + }) + .collect::>(), + ) +} + +fn send_replay_content_chunk( + cx: &ConnectionTo, + session_id: &SessionId, + message: &Message, + content: ContentBlock, +) -> std::result::Result<(), agent_client_protocol::Error> { + let chunk = ContentChunk::new(content).meta(replay_message_meta(message)); + let update = match message.role { + Role::User => SessionUpdate::UserMessageChunk(chunk), + Role::Assistant => SessionUpdate::AgentMessageChunk(chunk), + }; + cx.send_notification(SessionNotification::new(session_id.clone(), update)) +} + +fn replay_conversation_to_client( + cx: &ConnectionTo, + session: &Session, +) -> Result, agent_client_protocol::Error> +{ + let session_id = SessionId::new(session.id.clone()); + let sid = sid_short(session_id.0.as_ref()); + + let messages = session + .conversation + .as_ref() + .map(|c| c.messages().to_vec()) + .unwrap_or_default(); + debug!( + target: "perf", + sid = %sid, + messages = messages.len(), + "perf: load_session messages loaded" + ); + + let mut replay_tool_requests = + HashMap::::new(); + let submitted_elicitation_ids = collect_submitted_elicitation_ids(&messages); + + for message in &messages { + if !message.metadata.user_visible { + continue; + } + + for content_item in &message.content { + match content_item { + MessageContent::Text(text) => { + let mut tc = TextContent::new(text.text.clone()); + if let Some(audience) = text.audience() { + tc = tc.annotations(replay_audience_annotations(audience)); + } + send_replay_content_chunk(cx, &session_id, message, ContentBlock::Text(tc))?; + } + MessageContent::Image(image) => { + let mut image_content = + ImageContent::new(image.data.clone(), image.mime_type.clone()); + if let Some(audience) = image.audience() { + image_content = + image_content.annotations(replay_audience_annotations(audience)); + } + send_replay_content_chunk( + cx, + &session_id, + message, + ContentBlock::Image(image_content), + )?; + } + MessageContent::ToolRequest(tool_request) => { + replay_tool_requests.insert(tool_request.id.clone(), tool_request.clone()); + + let pending_tool_call = pending_tool_call_from_request(tool_request); + let mut meta = pending_tool_call.identity_meta; + if let Some(chain_summary) = tool_request.persisted_chain_summary() { + meta = with_tool_chain_summary_meta( + meta, + &chain_summary.summary, + chain_summary.count, + ); + } + let tool_call = pending_tool_call + .tool_call + .meta(merge_replay_message_meta(meta, message)); + + cx.send_notification(SessionNotification::new( + session_id.clone(), + SessionUpdate::ToolCall(tool_call), + ))?; + } + MessageContent::ToolResponse(tool_response) => { + let status = match &tool_response.tool_result { + Ok(result) if result.is_error == Some(true) => ToolCallStatus::Failed, + Ok(_) => ToolCallStatus::Completed, + Err(_) => ToolCallStatus::Failed, + }; + + let mut fields = ToolCallUpdateFields::new().status(status); + if let Some(raw_output) = extract_tool_raw_output(&tool_response.tool_result) { + fields = fields.raw_output(raw_output); + } + if !tool_response + .tool_result + .as_ref() + .is_ok_and(|r| r.is_acp_aware()) + { + let content = build_tool_call_content(&tool_response.tool_result); + fields = fields.content(content); + + let locations = + extract_locations_from_meta(tool_response).unwrap_or_else(|| { + if let Some(tool_request) = + replay_tool_requests.get(&tool_response.id) + { + extract_tool_locations(tool_request, tool_response) + } else { + Vec::new() + } + }); + if !locations.is_empty() { + fields = fields.locations(locations); + } + } + + let update = + ToolCallUpdate::new(ToolCallId::new(tool_response.id.clone()), fields) + .meta(merge_replay_message_meta( + extract_tool_call_update_meta(tool_response), + message, + )); + cx.send_notification(SessionNotification::new( + session_id.clone(), + SessionUpdate::ToolCallUpdate(update), + ))?; + } + MessageContent::Thinking(thinking) => { + cx.send_notification(SessionNotification::new( + session_id.clone(), + SessionUpdate::AgentThoughtChunk( + ContentChunk::new(ContentBlock::Text(TextContent::new( + thinking.thinking.clone(), + ))) + .meta(replay_message_meta(message)), + ), + ))?; + } + MessageContent::ActionRequired(action_required) => { + if let ActionRequiredData::Elicitation { + id, + message: elicitation_message, + requested_schema, + } = &action_required.data + { + if !submitted_elicitation_ids.contains(id) { + send_elicitation_interaction_update( + cx, + session_id.0.as_ref(), + id.clone(), + InteractionState::Pending, + Some(elicitation_message.clone()), + Some(requested_schema.clone()), + Some(serde_json::Value::Object(replay_message_meta(message))), + )?; + } + } + } + MessageContent::SystemNotification(_) => {} + _ => {} + } + } + } + + Ok(replay_tool_requests) +} + +fn collect_submitted_elicitation_ids(messages: &[Message]) -> HashSet { + let mut submitted_ids = HashSet::new(); + + for message in messages { + for content_item in &message.content { + if let MessageContent::ActionRequired(action_required) = content_item { + if let ActionRequiredData::ElicitationResponse { id, .. } = &action_required.data { + submitted_ids.insert(id.clone()); + } + } + } + } + + submitted_ids +} + +impl GooseAcpAgent { + pub(super) async fn handle_load_session( + &self, + cx: &ConnectionTo, + args: LoadSessionRequest, + ) -> Result { + debug!(?args, "load session request"); + validate_absolute_cwd(&args.cwd)?; + + let session_id_str = args.session_id.0.to_string(); + let sid = sid_short(&session_id_str); + let t_start = std::time::Instant::now(); + + let mut session = self + .session_manager + .get_session(&session_id_str, true) + .await + .map_err(|_| { + agent_client_protocol::Error::resource_not_found(Some(session_id_str.clone())) + .data(format!("Session not found: {}", session_id_str)) + })?; + + session = self + .prepare_session_for_activation(session, args.cwd.clone(), args.mcp_servers, true) + .await?; + + let replay_tool_requests = replay_conversation_to_client(cx, &session)?; + let (agent, extension_results) = self.prepare_acp_session_agent(cx, &session).await?; + self.register_acp_session(session_id_str.clone(), agent.clone(), replay_tool_requests) + .await; + + session = self + .session_manager + .get_session(&session_id_str, true) + .await + .internal_err_ctx("Failed to reload session")?; + + agent + .extension_manager + .update_working_dir(&session.working_dir) + .await; + + let (mode_state, model_state, config_options) = + build_session_setup_config(&self.provider_inventory, &session).await?; + + send_session_setup_notifications(cx, &session)?; + + let mut response = LoadSessionResponse::new().modes(mode_state); + if let Some(ms) = model_state { + response = response.models(ms); + } + if let Some(co) = config_options { + response = response.config_options(co); + } + + let mut meta = serde_json::Map::new(); + if let Some(recipe) = &session.recipe { + if let Ok(v) = serde_json::to_value(recipe) { + meta.insert("recipe".to_string(), v); + } + } + if let Some(values) = &session.user_recipe_values { + if let Ok(v) = serde_json::to_value(values) { + meta.insert("userRecipeValues".to_string(), v); + } + } + if let Ok(v) = serde_json::to_value(&extension_results) { + meta.insert("extensionResults".to_string(), v); + } + meta.insert( + "workingDir".to_string(), + serde_json::Value::String(session.working_dir.to_string_lossy().to_string()), + ); + if !meta.is_empty() { + response = response.meta(meta); + } + + debug!( + target: "perf", + sid = %sid, + ms = t_start.elapsed().as_millis() as u64, + "perf: load_session_refactor done" + ); + Ok(response) + } +} diff --git a/crates/goose/src/acp/server/sessions.rs b/crates/goose/src/acp/server/manage_sessions.rs similarity index 74% rename from crates/goose/src/acp/server/sessions.rs rename to crates/goose/src/acp/server/manage_sessions.rs index e46a5e75b35d..3c45f2b4793d 100644 --- a/crates/goose/src/acp/server/sessions.rs +++ b/crates/goose/src/acp/server/manage_sessions.rs @@ -20,13 +20,51 @@ impl GooseAcpAgent { .await .internal_err()?; - if let Some(session) = self.sessions.lock().await.get_mut(session_id) { - match &session.agent { - AgentHandle::Ready(agent) => { - agent.extension_manager.update_working_dir(&path).await; + if let Some(session) = self.sessions.lock().await.get(session_id) { + session + .agent + .extension_manager + .update_working_dir(&path) + .await; + } + + Ok(EmptyResponse {}) + } + + pub(super) async fn on_set_session_system_prompt( + &self, + req: SetSessionSystemPromptRequest, + ) -> Result { + let session_id = req.session_id.trim(); + if session_id.is_empty() { + return Err( + agent_client_protocol::Error::invalid_params().data("sessionId cannot be empty") + ); + } + + let agent = self.get_session_agent(session_id, None).await?; + match req.mode { + SessionSystemPromptMode::Set => { + if req.text.trim().is_empty() { + agent.clear_system_prompt_override().await; + } else { + agent.override_system_prompt(req.text).await; } - AgentHandle::Loading(_) => { - session.pending_working_dir = Some(path); + } + SessionSystemPromptMode::Append => { + let key = req + .key + .as_deref() + .map(str::trim) + .filter(|key| !key.is_empty()) + .ok_or_else(|| { + agent_client_protocol::Error::invalid_params() + .data("key cannot be empty for append mode") + })?; + if req.text.trim().is_empty() { + agent.remove_system_prompt_extra(key).await; + } else { + agent.extend_system_prompt(key.to_string(), req.text).await; } } } @@ -84,6 +122,7 @@ impl GooseAcpAgent { .await .internal_err()?; self.sessions.lock().await.remove(&req.session_id); + let _ = self.agent_manager.remove_session(&req.session_id).await; Ok(EmptyResponse {}) } @@ -156,6 +195,7 @@ impl GooseAcpAgent { .await .internal_err()?; self.sessions.lock().await.remove(&req.session_id); + let _ = self.agent_manager.remove_session(&req.session_id).await; Ok(EmptyResponse {}) } diff --git a/crates/goose/src/acp/server/new_session.rs b/crates/goose/src/acp/server/new_session.rs new file mode 100644 index 000000000000..dba27c844d49 --- /dev/null +++ b/crates/goose/src/acp/server/new_session.rs @@ -0,0 +1,108 @@ +use crate::acp::server::{meta_string, sid_short, validate_absolute_cwd, ResultExt}; +use crate::config::{Config, GooseMode}; +use crate::session::SessionType; + +use super::GooseAcpAgent; +use agent_client_protocol::schema::{NewSessionRequest, NewSessionResponse, SessionId}; +use agent_client_protocol::{Client, ConnectionTo}; +use std::collections::HashMap; +use tracing::debug; + +impl GooseAcpAgent { + #[allow(dead_code)] + pub(super) async fn handle_new_session( + &self, + cx: &ConnectionTo, + args: NewSessionRequest, + ) -> Result { + debug!(?args, "new session request"); + let t_start = std::time::Instant::now(); + validate_absolute_cwd(&args.cwd)?; + let project_id = meta_string(args.meta.as_ref(), "projectId"); + let session_type = match meta_string(args.meta.as_ref(), "client") { + Some(_) => SessionType::User, + None => SessionType::Acp, + }; + let config = Config::global(); + let (resolved_provider, resolved_model_config) = + match meta_string(args.meta.as_ref(), "provider") { + Some(provider) => { + let model_config = + super::resolve_provider_default_model_config(&provider).await?; + (provider, model_config) + } + None => super::resolve_default_provider_model_config(config)?, + }; + let current_mode: GooseMode = config.get_goose_mode().unwrap_or_default(); + let t0 = std::time::Instant::now(); + let mut goose_session = self + .session_manager + .create_session( + args.cwd.clone(), + "New Chat".to_string(), + session_type, + current_mode, + ) + .await + .internal_err_ctx("Failed to create session")?; + let mut builder = self.session_manager.update(&goose_session.id); + let extension_data = + self.build_enabled_extensions_data(config, &goose_session, args.mcp_servers)?; + builder = builder + .provider_name(resolved_provider) + .model_config(resolved_model_config) + .extension_data(extension_data); + if let Some(pid) = project_id { + builder = builder.project_id(Some(pid)); + } + builder + .apply() + .await + .internal_err_ctx("Failed to update session")?; + + goose_session = self + .session_manager + .get_session(&goose_session.id, false) + .await + .internal_err_ctx("Failed to reload session")?; + let session_id_str = goose_session.id.clone(); + let sid = sid_short(&session_id_str); + debug!(target: "perf", sid = %sid, ms = t0.elapsed().as_millis() as u64, "perf: new_session create_session"); + + let (_agent, extension_results) = self + .activate_acp_session(cx, &goose_session, HashMap::new()) + .await?; + + let goose_session = self + .session_manager + .get_session(&goose_session.id, false) + .await + .internal_err_ctx("Failed to reload session")?; + + let acp_session_id = SessionId::new(session_id_str.clone()); + + let (mode_state, model_state, config_options) = + super::build_session_setup_config(&self.provider_inventory, &goose_session).await?; + + let mut response = NewSessionResponse::new(acp_session_id.clone()).modes(mode_state); + if let Some(ms) = model_state { + response = response.models(ms); + } + if let Some(co) = config_options { + response = response.config_options(co); + } + if let Ok(extension_results) = serde_json::to_value(&extension_results) { + let mut meta = serde_json::Map::new(); + meta.insert("extensionResults".to_string(), extension_results); + response = response.meta(meta); + } + super::send_session_setup_notifications(cx, &goose_session)?; + debug!( + target: "perf", + sid = %sid, + ms = t_start.elapsed().as_millis() as u64, + "perf: new_session done" + ); + Ok(response) + } +} diff --git a/crates/goose/src/acp/server/onboarding.rs b/crates/goose/src/acp/server/onboarding.rs index d0ff12a1eabd..233f7a43e4f9 100644 --- a/crates/goose/src/acp/server/onboarding.rs +++ b/crates/goose/src/acp/server/onboarding.rs @@ -59,7 +59,7 @@ impl GooseAcpAgent { ) -> Result { let config = self.config()?; Ok(apply_onboarding_import_candidates( - &config, + config, &self.config_dir, &req, )) diff --git a/crates/goose/src/acp/server/providers.rs b/crates/goose/src/acp/server/providers.rs index 2916ccf48bb1..c24105fb098f 100644 --- a/crates/goose/src/acp/server/providers.rs +++ b/crates/goose/src/acp/server/providers.rs @@ -1,5 +1,6 @@ use super::*; use crate::config::declarative_providers; +use crate::providers::inventory::ensure_refresh_identity_current; use std::str::FromStr; fn inventory_entry_to_dto(entry: ProviderInventoryEntry) -> ProviderInventoryEntryDto { diff --git a/crates/goose/src/acp/server_factory.rs b/crates/goose/src/acp/server_factory.rs index 079287d967b0..8cd971a6bca1 100644 --- a/crates/goose/src/acp/server_factory.rs +++ b/crates/goose/src/acp/server_factory.rs @@ -23,15 +23,7 @@ impl AcpServer { } pub async fn create_agent(&self) -> Result> { - let config_path = self - .config - .config_dir - .join(crate::config::base::CONFIG_YAML_NAME); - let config = crate::config::Config::new(&config_path, "goose")?; - - let goose_mode = config - .get_goose_mode() - .unwrap_or(crate::config::GooseMode::Auto); + let config = crate::config::Config::global(); let disable_session_naming = config.get_goose_disable_session_naming().unwrap_or(false); let provider_factory: AcpProviderFactory = Arc::new( @@ -60,7 +52,6 @@ impl AcpServer { builtins: self.config.builtins.clone(), data_dir: self.config.data_dir.clone(), config_dir: self.config.config_dir.clone(), - goose_mode, disable_session_naming, goose_platform: self.config.goose_platform.clone(), additional_source_roots: self.config.additional_source_roots.clone(), diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 91f2e8748016..6ab6e62501b6 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -2522,9 +2522,14 @@ impl Agent { .await .is_ok() { - let p = crate::providers::create(&provider_name, model_config, extensions) - .await - .map_err(|e| anyhow!("Could not create provider: {}", e))?; + let p = crate::providers::create_with_working_dir( + &provider_name, + model_config, + extensions, + session.working_dir.clone(), + ) + .await + .map_err(|e| anyhow!("Could not create provider: {}", e))?; (p, false) } else { let fallback_provider_name = config @@ -2552,10 +2557,11 @@ impl Agent { .map_err(|e| anyhow!("Could not configure fallback provider: invalid model {}", e))? .with_canonical_limits(&fallback_provider_name); - let fallback_provider = crate::providers::create( + let fallback_provider = crate::providers::create_with_working_dir( &fallback_provider_name, fallback_model_config.clone(), extensions, + session.working_dir.clone(), ) .await .map_err(|e| { diff --git a/crates/goose/src/agents/execute_commands.rs b/crates/goose/src/agents/execute_commands.rs index e4e1a8c34bc2..8860c89fec16 100644 --- a/crates/goose/src/agents/execute_commands.rs +++ b/crates/goose/src/agents/execute_commands.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use anyhow::{anyhow, Result}; use crate::context_mgmt::compact_messages; -use crate::conversation::message::{Message, SystemNotificationType}; +use crate::conversation::message::Message; use crate::slash_commands::{recipe_slash_command, skill_slash_command}; use super::Agent; @@ -150,10 +150,7 @@ impl Agent { self.update_session_metrics(session_id, session.schedule_id, &usage, true) .await?; - Ok(Some(Message::assistant().with_system_notification( - SystemNotificationType::InlineMessage, - "Compaction complete", - ))) + Ok(Some(user_only_assistant_text("Compaction complete"))) } async fn handle_clear_command(&self, session_id: &str) -> Result> { @@ -172,10 +169,7 @@ impl Agent { .apply() .await?; - Ok(Some(Message::assistant().with_system_notification( - SystemNotificationType::InlineMessage, - "Conversation cleared", - ))) + Ok(Some(user_only_assistant_text("Conversation cleared"))) } async fn handle_skills_command(&self, session_id: &str) -> Result> { @@ -425,9 +419,14 @@ impl Agent { } } +fn user_only_assistant_text(text: impl Into) -> Message { + Message::assistant().with_text(text).user_only() +} + #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::MessageContent; #[test] fn parse_slash_command_splits_on_literal_space() { @@ -447,4 +446,17 @@ mod tests { assert_eq!(parsed.command, "speckit.plan\nhello"); assert_eq!(parsed.params_str, ""); } + + #[test] + fn user_only_assistant_text_is_durable_text_not_system_notification() { + let message = user_only_assistant_text("Conversation cleared"); + + assert!(message.metadata.user_visible); + assert!(!message.metadata.agent_visible); + assert_eq!(message.role, rmcp::model::Role::Assistant); + assert!(matches!( + message.content.as_slice(), + [MessageContent::Text(text)] if text.text == "Conversation cleared" + )); + } } diff --git a/crates/goose/src/bin/generate_acp_schema.rs b/crates/goose/src/bin/generate_acp_schema.rs index 905dacb1a13a..90daed775662 100644 --- a/crates/goose/src/bin/generate_acp_schema.rs +++ b/crates/goose/src/bin/generate_acp_schema.rs @@ -1,3 +1,4 @@ +use goose::acp::custom_notifications::custom_notification_schemas; use goose::acp::server::GooseAcpAgent; use schemars::SchemaGenerator; use serde_json::{json, Map, Value}; @@ -9,6 +10,7 @@ use std::path::PathBuf; fn main() { let mut generator = SchemaGenerator::default(); let methods = GooseAcpAgent::custom_method_schemas(&mut generator); + let notifications = custom_notification_schemas(&mut generator); // Collect $defs from the generator (all types referenced via subschema_for). let mut defs: Map = generator @@ -19,7 +21,7 @@ fn main() { // Track which types map to which methods so we can detect shared types. let mut type_methods: HashMap> = HashMap::new(); - for m in &methods { + for m in methods.iter().chain(notifications.iter()) { let method = m.method.clone(); if let Some(name) = &m.params_type_name { type_methods @@ -90,6 +92,7 @@ fn main() { // deduplicating response variants (e.g. EmptyResponse appears once). let mut request_variants: Vec = Vec::new(); let mut response_variants: Vec = Vec::new(); + let mut notification_variants: Vec = Vec::new(); let mut seen_response_types: BTreeSet = BTreeSet::new(); for m in &methods { @@ -113,6 +116,17 @@ fn main() { } } + for n in ¬ifications { + if let Some(name) = &n.params_type_name { + let generated_name = generated_type_name(name, &unstable_type_names); + notification_variants.push(json!({ + "allOf": [{ "$ref": format!("#/$defs/{generated_name}") }], + "description": format!("Params for {}", n.method), + "title": generated_name, + })); + } + } + // Build ExtRequest — mirrors AgentRequest structure. defs.insert( "ExtRequest".into(), @@ -174,6 +188,25 @@ fn main() { }), ); + // Build ExtNotification — fire-and-forget message with no `id` and no response. + defs.insert( + "ExtNotification".into(), + json!({ + "properties": { + "method": { "type": "string" }, + "params": { + "anyOf": [ + { "anyOf": notification_variants }, + { "description": "Untyped params", "type": ["object", "null"] }, + ] + } + }, + "required": ["method"], + "type": "object", + "x-docs-ignore": true, + }), + ); + // Assemble the root schema document. let root = json!({ "$schema": "https://json-schema.org/draft/2020-12/schema", @@ -189,6 +222,11 @@ fn main() { "allOf": [{ "$ref": "#/$defs/ExtResponse" }], "description": "Extension response (agent → client)", "title": "Response", + }, + { + "allOf": [{ "$ref": "#/$defs/ExtNotification" }], + "description": "Extension notification (agent → client, fire-and-forget)", + "title": "Notification", } ], }); @@ -219,7 +257,22 @@ fn main() { }) }) .collect(); - let meta = json!({ "methods": method_entries }); + let notification_entries: Vec = notifications + .iter() + .map(|n| { + json!({ + "method": &n.method, + "paramsType": n + .params_type_name + .as_ref() + .map(|name| generated_type_name(name, &unstable_type_names)), + }) + }) + .collect(); + let meta = json!({ + "methods": method_entries, + "notifications": notification_entries, + }); let meta_str = serde_json::to_string_pretty(&meta).expect("failed to serialize meta"); let meta_path = package_path.join("acp-meta.json"); fs::write(&meta_path, format!("{meta_str}\n")).expect("failed to write meta file"); diff --git a/crates/goose/src/config/declarative_providers.rs b/crates/goose/src/config/declarative_providers.rs index be5596cb3eba..a3a8fdc10475 100644 --- a/crates/goose/src/config/declarative_providers.rs +++ b/crates/goose/src/config/declarative_providers.rs @@ -1,7 +1,8 @@ use crate::config::paths::Paths; use crate::config::Config; use crate::providers::anthropic::AnthropicProvider; -use crate::providers::base::{ModelInfo, ProviderType}; +use crate::providers::base::{ModelInfo, ProviderDef, ProviderType}; +use crate::providers::huggingface::HuggingFaceProvider; use crate::providers::inventory::declarative_inventory_identity; use crate::providers::ollama::OllamaProvider; use crate::providers::openai::OpenAiProvider; @@ -576,21 +577,48 @@ pub fn register_declarative_provider( ProviderEngine::OpenAI => { let captured = config.clone(); let identity_config = config.clone(); - registry.register_with_name::( - &config, - provider_type, - config.dynamic_models.unwrap_or(false), - move |model| { - let mut cfg = captured.clone(); - resolve_config(&mut cfg)?; - OpenAiProvider::from_custom_config(model, cfg) - }, - move || { - let mut cfg = identity_config.clone(); - resolve_config(&mut cfg)?; - declarative_inventory_identity(&cfg) - }, - ); + if HuggingFaceProvider::matches_declarative_config(&config) { + let inventory_configured_config = config.clone(); + registry + .register_with_name_and_inventory_configured::( + &config, + provider_type, + config.dynamic_models.unwrap_or(false), + move |model| { + let mut cfg = captured.clone(); + resolve_config(&mut cfg)?; + HuggingFaceProvider::from_custom_config(model, cfg) + }, + move || { + let mut cfg = identity_config.clone(); + resolve_config(&mut cfg)?; + declarative_inventory_identity(&cfg) + }, + move || { + let mut cfg = inventory_configured_config.clone(); + if resolve_config(&mut cfg).is_err() { + return false; + } + huggingface_declarative_inventory_configured(&cfg) + }, + ); + } else { + registry.register_with_name::( + &config, + provider_type, + config.dynamic_models.unwrap_or(false), + move |model| { + let mut cfg = captured.clone(); + resolve_config(&mut cfg)?; + OpenAiProvider::from_custom_config(model, cfg) + }, + move || { + let mut cfg = identity_config.clone(); + resolve_config(&mut cfg)?; + declarative_inventory_identity(&cfg) + }, + ); + } } ProviderEngine::Ollama => { let captured = config.clone(); @@ -633,10 +661,120 @@ pub fn register_declarative_provider( } } +fn huggingface_declarative_inventory_configured(config: &DeclarativeProviderConfig) -> bool { + huggingface_declarative_inventory_configured_from_sources( + config, + |key| Config::global().get_secret::(key).is_ok(), + HuggingFaceProvider::inventory_configured, + ) +} + +fn huggingface_declarative_inventory_configured_from_sources( + config: &DeclarativeProviderConfig, + provider_secret_configured: impl FnOnce(&str) -> bool, + global_huggingface_configured: impl FnOnce() -> bool, +) -> bool { + if !config.requires_auth { + return true; + } + + if !config.api_key_env.is_empty() { + return provider_secret_configured(&config.api_key_env); + } + + global_huggingface_configured() +} + #[cfg(test)] mod tests { use super::*; + fn test_huggingface_config() -> DeclarativeProviderConfig { + DeclarativeProviderConfig { + name: "custom_hf".to_string(), + engine: ProviderEngine::OpenAI, + display_name: "Custom HF".to_string(), + description: None, + api_key_env: String::new(), + base_url: "https://router.huggingface.co/v1".to_string(), + models: vec![ModelInfo { + name: "test/model".to_string(), + resolved_model: None, + context_limit: 128_000, + input_token_cost: None, + output_token_cost: None, + currency: None, + supports_cache_control: None, + reasoning: false, + }], + headers: None, + timeout_seconds: None, + supports_streaming: Some(true), + requires_auth: true, + catalog_provider_id: Some("huggingface".to_string()), + base_path: None, + env_vars: None, + dynamic_models: Some(false), + skip_canonical_filtering: false, + model_doc_link: None, + setup_steps: Vec::new(), + fast_model: None, + preserves_thinking: true, + } + } + + #[test] + fn huggingface_inventory_allows_unauthenticated_custom_provider() { + let mut config = test_huggingface_config(); + config.requires_auth = false; + + assert!(huggingface_declarative_inventory_configured_from_sources( + &config, + |_| false, + || false, + )); + } + + #[test] + fn huggingface_inventory_accepts_provider_specific_key() { + let mut config = test_huggingface_config(); + config.api_key_env = "CUSTOM_HF_TOKEN".to_string(); + + assert!(huggingface_declarative_inventory_configured_from_sources( + &config, + |key| key == "CUSTOM_HF_TOKEN", + || false, + )); + } + + #[test] + fn huggingface_inventory_does_not_fallback_when_explicit_key_is_missing() { + let mut config = test_huggingface_config(); + config.api_key_env = "CUSTOM_HF_TOKEN".to_string(); + + assert!(!huggingface_declarative_inventory_configured_from_sources( + &config, + |_| false, + || true, + )); + } + + #[test] + fn huggingface_inventory_uses_global_token_without_provider_key() { + let config = test_huggingface_config(); + + assert!(huggingface_declarative_inventory_configured_from_sources( + &config, + |_| false, + || true, + )); + assert!(!huggingface_declarative_inventory_configured_from_sources( + &config, + |_| true, + || false, + )); + } + #[test] fn test_tanzu_json_deserializes() { let json = include_str!("../providers/declarative/tanzu.json"); diff --git a/crates/goose/src/download_manager.rs b/crates/goose/src/download_manager.rs index 8463cf0ecd5a..7ad8dd0a895b 100644 --- a/crates/goose/src/download_manager.rs +++ b/crates/goose/src/download_manager.rs @@ -131,12 +131,48 @@ impl DownloadManager { .await } + pub async fn download_model_with_bearer_token( + &self, + model_id: String, + url: String, + destination: PathBuf, + bearer_token: Option, + on_complete: Option>, + ) -> Result<()> { + self.download_model_sharded_with_bearer_token( + model_id, + vec![(url, destination)], + 0, + bearer_token, + on_complete, + ) + .await + } + pub async fn download_model_sharded( &self, model_id: String, files: Vec<(String, PathBuf)>, total_size_hint: u64, on_complete: Option>, + ) -> Result<()> { + self.download_model_sharded_with_bearer_token( + model_id, + files, + total_size_hint, + None, + on_complete, + ) + .await + } + + pub async fn download_model_sharded_with_bearer_token( + &self, + model_id: String, + files: Vec<(String, PathBuf)>, + total_size_hint: u64, + bearer_token: Option, + on_complete: Option>, ) -> Result<()> { info!(model_id = %model_id, file_count = files.len(), "Starting model download"); { @@ -186,8 +222,13 @@ impl DownloadManager { let files_for_cleanup: Vec = files.iter().map(|(_, d)| d.clone()).collect(); tokio::spawn(async move { - let result = - Self::download_files_sequentially(&files, &downloads, &model_id_clone).await; + let result = Self::download_files_sequentially( + &files, + &downloads, + &model_id_clone, + bearer_token.as_deref(), + ) + .await; match result { Ok(_) => { @@ -262,6 +303,7 @@ impl DownloadManager { files: &[(String, PathBuf)], downloads: &DownloadMap, model_id: &str, + bearer_token: Option<&str>, ) -> Result<(), anyhow::Error> { let client = reqwest::Client::builder() .connect_timeout(std::time::Duration::from_secs(30)) @@ -273,8 +315,7 @@ impl DownloadManager { let mut total: u64 = 0; let mut all_resolved = true; for (url, _) in files { - let size = client - .head(url) + let size = Self::apply_bearer_token(client.head(url), bearer_token) .send() .await .ok() @@ -329,6 +370,7 @@ impl DownloadManager { &mut cumulative_bytes, start_time, bytes_at_start, + bearer_token, ) .await?; } @@ -346,6 +388,7 @@ impl DownloadManager { cumulative_bytes: &mut u64, start_time: std::time::Instant, bytes_at_start: u64, + bearer_token: Option<&str>, ) -> Result<(), anyhow::Error> { let partial_path = partial_path_for(destination); let mut retries = 0u32; @@ -357,8 +400,7 @@ impl DownloadManager { }; // Get this file's total size - let mut file_total: u64 = client - .head(url) + let mut file_total: u64 = Self::apply_bearer_token(client.head(url), bearer_token) .send() .await .ok() @@ -386,7 +428,7 @@ impl DownloadManager { anyhow::bail!("Download cancelled"); } - let mut request = client.get(url); + let mut request = Self::apply_bearer_token(client.get(url), bearer_token); if file_bytes > 0 { request = request.header("Range", format!("bytes={}-", file_bytes)); } @@ -584,6 +626,17 @@ impl DownloadManager { } } } + + fn apply_bearer_token( + request: reqwest::RequestBuilder, + bearer_token: Option<&str>, + ) -> reqwest::RequestBuilder { + if let Some(token) = bearer_token.filter(|token| !token.is_empty()) { + request.header("Authorization", format!("Bearer {}", token)) + } else { + request + } + } } static DOWNLOAD_MANAGER: once_cell::sync::Lazy = diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index 302dd7d9571d..3c4a7414db03 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -1,16 +1,17 @@ -use crate::agents::{Agent, AgentConfig, GoosePlatform}; +use crate::agents::mcp_client::GooseMcpHostInfo; +use crate::agents::{Agent, AgentConfig, ExtensionLoadResult, GoosePlatform}; use crate::config::paths::Paths; use crate::config::permission::PermissionManager; -use crate::config::{Config, GooseMode}; +use crate::config::Config; use crate::scheduler::Scheduler; use crate::scheduler_trait::SchedulerTrait; -use crate::session::SessionManager; +use crate::session::{SessionManager, SessionNameUpdate}; use anyhow::Result; use lru::LruCache; use std::collections::HashMap; use std::num::NonZeroUsize; use std::sync::Arc; -use tokio::sync::{Mutex, OnceCell, RwLock}; +use tokio::sync::{mpsc, Mutex, OnceCell, RwLock}; use tokio_util::sync::CancellationToken; use tracing::{debug, info}; @@ -18,12 +19,23 @@ const DEFAULT_MAX_SESSION: usize = 100; static AGENT_MANAGER: OnceCell> = OnceCell::const_new(); +#[derive(Clone, Default)] +pub struct RuntimeContext { + pub mcp_host_info: Option, + pub use_login_shell_path: Option, + pub session_name_update_tx: Option>, +} + +pub struct AgentManagerGetResult { + pub agent: Arc, + pub agent_created: bool, + pub extension_results: Vec, +} + pub struct AgentManager { sessions: Arc>>>, - scheduler: Arc, - session_manager: Arc, + agent_config: AgentConfig, default_provider: Arc>>>, - default_mode: GooseMode, cancel_tokens: Arc>>, /// Per-session creation locks. When `get_or_create_agent` misses the /// `sessions` cache it acquires the per-session lock before doing the @@ -37,23 +49,14 @@ pub struct AgentManager { } impl AgentManager { - pub async fn new( - session_manager: Arc, - schedule_file_path: std::path::PathBuf, - max_sessions: Option, - default_mode: GooseMode, - ) -> Result { - let scheduler = Scheduler::new(schedule_file_path, session_manager.clone()).await?; - + pub async fn new(agent_config: AgentConfig, max_sessions: Option) -> Result { let capacity = NonZeroUsize::new(max_sessions.unwrap_or(DEFAULT_MAX_SESSION)) .unwrap_or_else(|| NonZeroUsize::new(100).unwrap()); let manager = Self { sessions: Arc::new(RwLock::new(LruCache::new(capacity))), - scheduler, - session_manager, + agent_config, default_provider: Arc::new(RwLock::new(None)), - default_mode, cancel_tokens: Arc::new(RwLock::new(HashMap::new())), creation_locks: Arc::new(Mutex::new(HashMap::new())), }; @@ -71,13 +74,18 @@ impl AgentManager { let default_mode = config.get_goose_mode().unwrap_or_default(); let schedule_file_path = Paths::data_dir().join("schedule.json"); let session_manager = Arc::new(SessionManager::instance()); - let manager = Self::new( + let scheduler = Scheduler::new(schedule_file_path, Arc::clone(&session_manager)) + .await + .map(|scheduler| scheduler as Arc)?; + let agent_config = AgentConfig::new( session_manager, - schedule_file_path, - Some(max_sessions), + PermissionManager::instance(), + Some(scheduler), default_mode, - ) - .await?; + config.get_goose_disable_session_naming().unwrap_or(false), + GoosePlatform::GooseDesktop, + ); + let manager = Self::new(agent_config, Some(max_sessions)).await?; Ok(Arc::new(manager)) }) .await @@ -85,12 +93,17 @@ impl AgentManager { } pub fn scheduler(&self) -> Arc { - Arc::clone(&self.scheduler) + Arc::clone( + self.agent_config + .scheduler_service + .as_ref() + .expect("AgentManager scheduler is not configured"), + ) } /// Get the shared SessionManager for session-only operations pub fn session_manager(&self) -> &SessionManager { - &self.session_manager + self.agent_config.session_manager.as_ref() } pub async fn set_default_provider(&self, provider: Arc) { @@ -99,11 +112,26 @@ impl AgentManager { } pub async fn get_or_create_agent(&self, session_id: String) -> Result> { + Ok(self + .get_or_create_agent_with_runtime_context(session_id, RuntimeContext::default()) + .await? + .agent) + } + + pub async fn get_or_create_agent_with_runtime_context( + &self, + session_id: String, + runtime_context: RuntimeContext, + ) -> Result { // Fast path: agent already cached. { let mut sessions = self.sessions.write().await; if let Some(existing) = sessions.get(&session_id) { - return Ok(Arc::clone(existing)); + return Ok(AgentManagerGetResult { + agent: Arc::clone(existing), + agent_created: false, + extension_results: Vec::new(), + }); } } @@ -128,7 +156,7 @@ impl AgentManager { // bail out via `?`, leaving a permanent `creation_locks` entry // for a session that never made it into the LRU cache and that // no one will ever call `remove_session` on. - let result = self.create_agent_locked(&session_id).await; + let result = self.create_agent_locked(&session_id, runtime_context).await; if result.is_err() { // Release BOTH the guard and our local Arc clone of the @@ -149,37 +177,49 @@ impl AgentManager { /// Slow-path body for `get_or_create_agent`. Must be called with the /// per-session creation lock held by the caller. - async fn create_agent_locked(&self, session_id: &str) -> Result> { + async fn create_agent_locked( + &self, + session_id: &str, + runtime_context: RuntimeContext, + ) -> Result { // Re-check under the creation lock: another caller may have // finished creating the agent while we were waiting. { let mut sessions = self.sessions.write().await; if let Some(existing) = sessions.get(session_id) { - return Ok(Arc::clone(existing)); + return Ok(AgentManagerGetResult { + agent: Arc::clone(existing), + agent_created: false, + extension_results: Vec::new(), + }); } } - let mut mode = self.default_mode; - let permission_manager = PermissionManager::instance(); - - if let Ok(session) = self.session_manager.get_session(session_id, false).await { + let mut mode = self.agent_config.goose_mode; + if let Ok(session) = self + .agent_config + .session_manager + .get_session(session_id, false) + .await + { mode = session.goose_mode; info!(goose_mode = %mode, session_id = %session_id, "Session loaded"); } - let config = AgentConfig::new( - Arc::clone(&self.session_manager), - permission_manager, - Some(Arc::clone(&self.scheduler)), - mode, - Config::global() - .get_goose_disable_session_naming() - .unwrap_or(false), - GoosePlatform::GooseDesktop, - ); + let mut config = self.agent_config.clone(); + config.goose_mode = mode; + config.mcp_host_info = runtime_context.mcp_host_info; + config.use_login_shell_path = runtime_context.use_login_shell_path; + config.session_name_update_tx = runtime_context.session_name_update_tx; let agent = Arc::new(Agent::with_config(config)); + let mut extension_results = Vec::new(); - if let Ok(session) = self.session_manager.get_session(session_id, false).await { + if let Ok(session) = self + .agent_config + .session_manager + .get_session(session_id, false) + .await + { if session.provider_name.is_some() { info!( "Restoring evicted session {} (provider: {:?})", @@ -193,7 +233,7 @@ impl AgentManager { ); } } - agent.load_extensions_from_session(&session).await; + extension_results = agent.load_extensions_from_session(&session).await; } if agent.provider().await.is_err() { @@ -210,7 +250,11 @@ impl AgentManager { let mut sessions = self.sessions.write().await; if let Some(existing) = sessions.get(session_id) { - return Ok(Arc::clone(existing)); + return Ok(AgentManagerGetResult { + agent: Arc::clone(existing), + agent_created: false, + extension_results: Vec::new(), + }); } // `push` returns the LRU-evicted entry when the cache is at // capacity, which `put` does not surface. We need the evicted @@ -226,7 +270,11 @@ impl AgentManager { self.prune_creation_lock(&evicted_id).await; } - Ok(agent) + Ok(AgentManagerGetResult { + agent, + agent_created: true, + extension_results, + }) } /// Drop the per-session creation lock for `session_id` if no other @@ -328,6 +376,8 @@ mod tests { use test_case::test_case; + use crate::agents::{AgentConfig, GoosePlatform}; + use crate::config::permission::PermissionManager; use crate::config::GooseMode; use crate::execution::SessionExecutionMode; use crate::session::SessionManager; @@ -336,15 +386,15 @@ mod tests { async fn create_test_manager(temp_dir: &TempDir) -> AgentManager { let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); - let schedule_path = temp_dir.path().join("schedule.json"); - AgentManager::new( + let agent_config = AgentConfig::new( session_manager, - schedule_path, - Some(100), + PermissionManager::instance(), + None, GooseMode::default(), - ) - .await - .unwrap() + false, + GoosePlatform::GooseDesktop, + ); + AgentManager::new(agent_config, Some(100)).await.unwrap() } #[test] @@ -632,15 +682,15 @@ mod tests { // even though only `max_sessions` agents remain cached. let temp_dir = TempDir::new().unwrap(); let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); - let schedule_path = temp_dir.path().join("schedule.json"); - let manager = AgentManager::new( + let agent_config = AgentConfig::new( session_manager, - schedule_path, - Some(2), + PermissionManager::instance(), + None, GooseMode::default(), - ) - .await - .unwrap(); + false, + GoosePlatform::GooseDesktop, + ); + let manager = AgentManager::new(agent_config, Some(2)).await.unwrap(); manager.get_or_create_agent("a".into()).await.unwrap(); manager.get_or_create_agent("b".into()).await.unwrap(); diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index aaee33764927..6bc52a140c3d 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -2,6 +2,7 @@ compile_error!("Features `rustls-tls` and `native-tls` are mutually exclusive"); pub mod acp; +pub use goose_sdk::custom_notifications; pub use goose_sdk::custom_requests; pub mod action_required_manager; pub mod agents; diff --git a/crates/goose/src/providers/catalog.rs b/crates/goose/src/providers/catalog.rs index 9dda84e58444..05b032ccb3ed 100644 --- a/crates/goose/src/providers/catalog.rs +++ b/crates/goose/src/providers/catalog.rs @@ -375,6 +375,23 @@ const SETUP_METADATA: &[CuratedSetupMetadata] = &[ secret_field_default: Some(API_KEY_FIELD), field_overrides: &[], }, + CuratedSetupMetadata { + provider_id: "huggingface", + category: ProviderSetupCategory::Model, + setup_method: ProviderSetupMethod::SingleApiKey, + group: ProviderSetupGroup::Default, + display_name: Some("Hugging Face"), + description: Some("Hugging Face Inference Providers"), + docs_url: Some("https://huggingface.co/docs/inference-providers"), + aliases: &["huggingface", "hf"], + native_connect_query: None, + binary_name: None, + setup_capabilities: setup_capabilities(false, false, false), + show_only_when_installed: false, + synthetic: false, + secret_field_default: Some(API_KEY_FIELD), + field_overrides: &[], + }, CuratedSetupMetadata { provider_id: "chatgpt_codex", category: ProviderSetupCategory::Model, @@ -1199,6 +1216,20 @@ mod tests { ["DATABRICKS_HOST", "DATABRICKS_TOKEN"] ); + let huggingface = entries + .iter() + .find(|entry| entry.provider_id == "huggingface") + .expect("setup catalog should include huggingface"); + assert_eq!(huggingface.setup_method, ProviderSetupMethod::SingleApiKey); + assert_eq!( + huggingface + .fields + .iter() + .map(|field| field.key.as_str()) + .collect::>(), + ["HF_TOKEN"] + ); + let atomic_chat = entries .iter() .find(|entry| entry.provider_id == "atomic_chat") diff --git a/crates/goose/src/providers/gemini_oauth.rs b/crates/goose/src/providers/gemini_oauth.rs index 2126c624f09f..b3ac6a493480 100644 --- a/crates/goose/src/providers/gemini_oauth.rs +++ b/crates/goose/src/providers/gemini_oauth.rs @@ -851,6 +851,11 @@ impl GeminiOAuthProvider { }) } + pub async fn cleanup() -> Result<()> { + TokenCache::new().clear(); + Ok(()) + } + async fn post_stream( &self, session_id: Option<&str>, diff --git a/crates/goose/src/providers/huggingface.rs b/crates/goose/src/providers/huggingface.rs new file mode 100644 index 000000000000..b36e1bea2bda --- /dev/null +++ b/crates/goose/src/providers/huggingface.rs @@ -0,0 +1,569 @@ +use super::api_client::{ApiClient, AuthMethod, AuthProvider}; +use super::base::{ + ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, + DEFAULT_PROVIDER_TIMEOUT_SECS, +}; +use super::errors::ProviderError; +use super::huggingface_auth; +use super::inventory::{default_inventory_identity, InventoryIdentityInput}; +use super::openai_compatible::OpenAiCompatibleProvider; +use crate::config::declarative_providers::DeclarativeProviderConfig; +use crate::config::{Config, ConfigError}; +use crate::conversation::message::Message; +use crate::model::ModelConfig; +use anyhow::{anyhow, Result}; +use futures::future::BoxFuture; +use rmcp::model::Tool; + +pub const HUGGINGFACE_API_HOST: &str = "https://router.huggingface.co/v1"; +pub const HUGGINGFACE_DOC_URL: &str = "https://huggingface.co/docs/inference-providers"; +pub const HUGGINGFACE_DEFAULT_MODEL: &str = "Qwen/Qwen3-Coder-480B-A35B-Instruct"; +pub const HUGGINGFACE_KNOWN_MODELS: &[&str] = &[ + "MiniMaxAI/MiniMax-M2.1", + "MiniMaxAI/MiniMax-M2.5", + "MiniMaxAI/MiniMax-M2.7", + "Qwen/Qwen3-235B-A22B-Thinking", + "Qwen/Qwen3-Coder-480B-A35B-Instruct", + "Qwen/Qwen3-Coder-Next", + "Qwen/Qwen3-Embedding-4B", + "Qwen/Qwen3-Embedding-8B", + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "Qwen/Qwen3-Next-80B-A3B-Thinking", + "Qwen/Qwen3.5-397B-A17B", + "XiaomiMiMo/MiMo-V2-Flash", + "deepseek-ai/DeepSeek-R1", + "deepseek-ai/DeepSeek-V3.2", + "deepseek-ai/DeepSeek-V4-Pro", + "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2-Thinking", + "moonshotai/Kimi-K2.5", + "moonshotai/Kimi-K2.6", + "zai-org/GLM-4.7", + "zai-org/GLM-4.7-Flash", + "zai-org/GLM-5", + "zai-org/GLM-5.1", +]; + +type QueryParams = Vec<(String, String)>; +type EndpointParts = (String, String, QueryParams); + +pub struct HuggingFaceProvider { + inner: OpenAiCompatibleProvider, + custom_models: Option>, + dynamic_models: Option, +} + +struct HuggingFaceAuthProvider; + +#[async_trait::async_trait] +impl AuthProvider for HuggingFaceAuthProvider { + async fn get_auth_header(&self) -> Result<(String, String)> { + let token = huggingface_auth::resolve_token_async() + .await? + .ok_or_else(missing_token_error)?; + Ok(("Authorization".to_string(), format!("Bearer {}", token))) + } +} + +impl HuggingFaceProvider { + pub fn matches_declarative_config(config: &DeclarativeProviderConfig) -> bool { + config.name == huggingface_auth::HUGGINGFACE_PROVIDER_NAME + || config.catalog_provider_id.as_deref() + == Some(huggingface_auth::HUGGINGFACE_PROVIDER_NAME) + } + + pub fn from_custom_config( + model: ModelConfig, + config: DeclarativeProviderConfig, + ) -> Result { + let custom_models = static_model_names(&config); + if config.dynamic_models == Some(false) && custom_models.is_none() { + return Err(anyhow!( + "Provider '{}' has dynamic_models: false but no static models listed; \ + at least one entry in `models` is required.", + config.name + )); + } + + let auth_method = custom_auth_method(&config)?; + let (host, completions_prefix, query_params) = + openai_compatible_endpoint_parts(&config.base_url, config.base_path.as_deref())?; + + let timeout_secs = config + .timeout_seconds + .unwrap_or(DEFAULT_PROVIDER_TIMEOUT_SECS); + let mut api_client = ApiClient::with_timeout( + host, + auth_method, + std::time::Duration::from_secs(timeout_secs), + )? + .with_query(query_params); + + if let Some(headers) = &config.headers { + let mut header_map = reqwest::header::HeaderMap::new(); + for (key, value) in headers { + let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?; + let header_value = reqwest::header::HeaderValue::from_str(value)?; + header_map.insert(header_name, header_value); + } + api_client = api_client.with_headers(header_map)?; + } + + let model = if let Some(ref fast_model_name) = config.fast_model { + model.with_fast(fast_model_name, &config.name)? + } else { + model + }; + + Ok(Self { + inner: OpenAiCompatibleProvider::new( + config.name.clone(), + api_client, + model, + completions_prefix, + ) + .with_supports_streaming(config.supports_streaming.unwrap_or(true)), + custom_models, + dynamic_models: config.dynamic_models, + }) + } + + pub async fn cleanup() -> Result<()> { + huggingface_auth::clear_oauth_token() + } +} + +#[async_trait::async_trait] +impl Provider for HuggingFaceProvider { + fn get_name(&self) -> &str { + self.inner.get_name() + } + + fn get_model_config(&self) -> ModelConfig { + self.inner.get_model_config() + } + + async fn fetch_supported_models(&self) -> Result, ProviderError> { + if let Some(custom_models) = &self.custom_models { + if self.dynamic_models == Some(false) { + return Ok(custom_models.clone()); + } + + match self.inner.fetch_supported_models().await { + Ok(models) => return Ok(models), + Err(e) if e.is_endpoint_not_found() => { + tracing::debug!( + "Models endpoint not implemented for Hugging Face provider '{}' ({}), using predefined list", + self.inner.get_name(), + e + ); + return Ok(custom_models.clone()); + } + Err(e) => return Err(e), + } + } + + self.inner.fetch_supported_models().await + } + + async fn stream( + &self, + model_config: &ModelConfig, + session_id: &str, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + self.inner + .stream(model_config, session_id, system, messages, tools) + .await + } +} + +impl ProviderDef for HuggingFaceProvider { + type Provider = Self; + + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + huggingface_auth::HUGGINGFACE_PROVIDER_NAME, + huggingface_auth::HUGGINGFACE_DISPLAY_NAME, + "Hugging Face Inference Providers via the Hugging Face Router", + HUGGINGFACE_DEFAULT_MODEL, + HUGGINGFACE_KNOWN_MODELS.to_vec(), + HUGGINGFACE_DOC_URL, + vec![ + ConfigKey::new( + huggingface_auth::HUGGINGFACE_TOKEN_SECRET_KEY, + true, + true, + None, + true, + ), + ConfigKey::new("HF_HOST", false, false, Some(HUGGINGFACE_API_HOST), false), + ], + ) + } + + fn from_env( + model: ModelConfig, + _extensions: Vec, + ) -> BoxFuture<'static, Result> { + Box::pin(async move { + let config = Config::global(); + let auth_method = + refreshable_huggingface_auth_method(huggingface_auth::has_configured_token)?; + let host: String = config + .get_param("HF_HOST") + .unwrap_or_else(|_| HUGGINGFACE_API_HOST.to_string()); + let api_client = ApiClient::new(host, auth_method)?; + + Ok(Self { + inner: OpenAiCompatibleProvider::new( + huggingface_auth::HUGGINGFACE_PROVIDER_NAME.to_string(), + api_client, + model, + String::new(), + ), + custom_models: None, + dynamic_models: None, + }) + }) + } + + fn inventory_identity() -> Result { + let metadata = Self::metadata(); + Ok(default_inventory_identity( + &metadata.name, + &metadata.name, + &metadata.config_keys, + Config::global(), + )) + } + + fn inventory_configured() -> bool { + huggingface_auth::has_configured_token().unwrap_or(false) + } +} + +fn missing_token_error() -> anyhow::Error { + anyhow!( + "Hugging Face token is not configured. Sign in from Settings > Auth or configure HF_TOKEN." + ) +} + +fn configured_api_key(config: &DeclarativeProviderConfig) -> Result> { + if config.api_key_env.is_empty() { + return Ok(None); + } + + match Config::global().get_secret::(&config.api_key_env) { + Ok(token) => Ok(Some(token)), + Err(ConfigError::NotFound(_)) => Ok(None), + Err(error) => Err(error.into()), + } +} + +fn static_model_names(config: &DeclarativeProviderConfig) -> Option> { + (!config.models.is_empty()).then(|| { + config + .models + .iter() + .map(|model| model.name.clone()) + .collect() + }) +} + +fn custom_auth_method(config: &DeclarativeProviderConfig) -> Result { + let configured_key = if config.requires_auth { + configured_api_key(config)? + } else { + None + }; + custom_auth_method_with_provider_token(config.requires_auth, configured_key) +} + +fn custom_auth_method_with_provider_token( + requires_auth: bool, + provider_token: Option, +) -> Result { + custom_auth_method_from_sources( + requires_auth, + provider_token, + huggingface_auth::has_configured_token, + ) +} + +fn custom_auth_method_from_sources( + requires_auth: bool, + provider_token: Option, + has_global_token: impl FnOnce() -> Result, +) -> Result { + if !requires_auth { + return Ok(AuthMethod::NoAuth); + } + + if let Some(token) = provider_token { + return Ok(AuthMethod::BearerToken(token)); + } + + refreshable_huggingface_auth_method(has_global_token) +} + +fn refreshable_huggingface_auth_method( + has_configured_token: impl FnOnce() -> Result, +) -> Result { + if !has_configured_token()? { + return Err(missing_token_error()); + } + + Ok(AuthMethod::Custom(Box::new(HuggingFaceAuthProvider))) +} + +fn openai_compatible_endpoint_parts( + base_url: &str, + base_path: Option<&str>, +) -> Result { + let url = + url::Url::parse(base_url).map_err(|e| anyhow!("Invalid base URL '{}': {}", base_url, e))?; + let mut host = if let Some(port) = url.port() { + format!( + "{}://{}:{}", + url.scheme(), + url.host_str().unwrap_or_default(), + port + ) + } else { + format!("{}://{}", url.scheme(), url.host_str().unwrap_or_default()) + }; + let query_params = url + .query_pairs() + .map(|(key, value)| (key.into_owned(), value.into_owned())) + .collect::>(); + + if let Some(path) = base_path { + return Ok((host, completions_prefix(path), query_params)); + } + + let path = url.path().trim_matches('/'); + if path.is_empty() { + return Ok((host, String::new(), query_params)); + } + + if let Some(parent) = path + .strip_suffix("/chat/completions") + .or_else(|| (path == "chat/completions").then_some("")) + { + if !parent.is_empty() { + host.push('/'); + host.push_str(parent); + } + return Ok((host, String::new(), query_params)); + } + + host.push('/'); + host.push_str(path); + Ok((host, String::new(), query_params)) +} + +fn completions_prefix(path: &str) -> String { + let path = path.trim_matches('/'); + if path.is_empty() { + return String::new(); + } + + let parent = path + .strip_suffix("/chat/completions") + .or_else(|| (path == "chat/completions").then_some("")) + .unwrap_or(path); + + if parent.is_empty() { + String::new() + } else { + format!("{}/", parent) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::base::ModelInfo; + + #[test] + fn metadata_preserves_huggingface_id_and_token_key() { + let metadata = HuggingFaceProvider::metadata(); + assert_eq!(metadata.name, "huggingface"); + assert_eq!(metadata.display_name, "Hugging Face"); + assert_eq!(metadata.default_model, HUGGINGFACE_DEFAULT_MODEL); + assert!(metadata + .config_keys + .iter() + .any(|key| key.name == "HF_TOKEN" && key.secret)); + } + + #[test] + fn declarative_matching_accepts_name_or_catalog_provider_id() { + let mut config = test_config(); + assert!(!HuggingFaceProvider::matches_declarative_config(&config)); + + config.name = "huggingface".to_string(); + assert!(HuggingFaceProvider::matches_declarative_config(&config)); + + config.name = "custom_hugging_face".to_string(); + config.catalog_provider_id = Some("huggingface".to_string()); + assert!(HuggingFaceProvider::matches_declarative_config(&config)); + } + + #[test] + fn endpoint_parts_use_base_url_path_as_api_host() { + let (host, prefix, query) = + openai_compatible_endpoint_parts("https://router.huggingface.co/v1?beta=1", None) + .unwrap(); + assert_eq!(host, "https://router.huggingface.co/v1"); + assert_eq!(prefix, ""); + assert_eq!(query, vec![("beta".to_string(), "1".to_string())]); + } + + #[test] + fn endpoint_parts_strip_chat_completions_suffix() { + let (host, prefix, query) = openai_compatible_endpoint_parts( + "https://router.huggingface.co/v1/chat/completions", + None, + ) + .unwrap(); + assert_eq!(host, "https://router.huggingface.co/v1"); + assert_eq!(prefix, ""); + assert!(query.is_empty()); + } + + #[test] + fn endpoint_parts_respect_explicit_base_path() { + let (host, prefix, query) = openai_compatible_endpoint_parts( + "https://router.huggingface.co", + Some("v1/chat/completions"), + ) + .unwrap(); + assert_eq!(host, "https://router.huggingface.co"); + assert_eq!(prefix, "v1/"); + assert!(query.is_empty()); + } + + #[tokio::test] + async fn custom_provider_returns_static_models_when_dynamic_models_disabled() { + let mut config = test_config(); + config.requires_auth = false; + config.dynamic_models = Some(false); + config.models = vec![ + ModelInfo::new("static-a".to_string(), 128000), + ModelInfo::new("static-b".to_string(), 128000), + ]; + + let provider = + HuggingFaceProvider::from_custom_config(ModelConfig::new("static-a").unwrap(), config) + .unwrap(); + + assert_eq!( + provider.fetch_supported_models().await.unwrap(), + vec!["static-a".to_string(), "static-b".to_string()] + ); + } + + #[test] + fn custom_provider_requires_static_models_when_dynamic_models_disabled() { + let mut config = test_config(); + config.requires_auth = false; + config.dynamic_models = Some(false); + + let error = match HuggingFaceProvider::from_custom_config( + ModelConfig::new("model").unwrap(), + config, + ) { + Ok(_) => panic!("expected dynamic_models: false without static models to fail"), + Err(error) => error, + }; + + assert_eq!( + error.to_string(), + "Provider 'custom_provider' has dynamic_models: false but no static models listed; at least one entry in `models` is required." + ); + } + + #[test] + fn custom_auth_method_respects_no_auth_config() { + let auth_method = + custom_auth_method_with_provider_token(false, Some("provider-token".to_string())) + .unwrap(); + + assert!(matches!(auth_method, AuthMethod::NoAuth)); + } + + #[test] + fn custom_auth_method_uses_provider_token_when_auth_is_required() { + let auth_method = + custom_auth_method_with_provider_token(true, Some("provider-token".to_string())) + .unwrap(); + + match auth_method { + AuthMethod::BearerToken(token) => assert_eq!(token, "provider-token"), + other => panic!("expected bearer token auth, got {other:?}"), + } + } + + #[test] + fn custom_auth_method_uses_refresh_capable_auth_for_global_token() { + let auth_method = custom_auth_method_from_sources(true, None, || Ok(true)).unwrap(); + + assert!(matches!(auth_method, AuthMethod::Custom(_))); + } + + #[test] + fn refreshable_huggingface_auth_method_uses_refresh_capable_auth() { + let auth_method = refreshable_huggingface_auth_method(|| Ok(true)).unwrap(); + + assert!(matches!(auth_method, AuthMethod::Custom(_))); + } + + #[test] + fn refreshable_huggingface_auth_method_requires_configured_token() { + let error = refreshable_huggingface_auth_method(|| Ok(false)).unwrap_err(); + + assert_eq!( + error.to_string(), + "Hugging Face token is not configured. Sign in from Settings > Auth or configure HF_TOKEN." + ); + } + + #[test] + fn custom_auth_method_requires_global_token_when_auth_is_required() { + let error = custom_auth_method_from_sources(true, None, || Ok(false)).unwrap_err(); + + assert_eq!( + error.to_string(), + "Hugging Face token is not configured. Sign in from Settings > Auth or configure HF_TOKEN." + ); + } + + fn test_config() -> DeclarativeProviderConfig { + DeclarativeProviderConfig { + name: "custom_provider".to_string(), + engine: crate::config::declarative_providers::ProviderEngine::OpenAI, + display_name: "Custom Provider".to_string(), + description: None, + api_key_env: "CUSTOM_API_KEY".to_string(), + base_url: HUGGINGFACE_API_HOST.to_string(), + models: Vec::new(), + headers: None, + timeout_seconds: None, + supports_streaming: Some(true), + requires_auth: true, + catalog_provider_id: None, + base_path: None, + env_vars: None, + dynamic_models: None, + skip_canonical_filtering: false, + model_doc_link: None, + setup_steps: vec![], + fast_model: None, + preserves_thinking: true, + } + } +} diff --git a/crates/goose/src/providers/huggingface_auth.rs b/crates/goose/src/providers/huggingface_auth.rs new file mode 100644 index 000000000000..aacb3e2e1172 --- /dev/null +++ b/crates/goose/src/providers/huggingface_auth.rs @@ -0,0 +1,896 @@ +use crate::config::paths::Paths; +use crate::config::{Config, ConfigError}; +use anyhow::{anyhow, Result}; +use axum::{extract::Query, response::Html, routing::get, Router}; +use base64::Engine; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::Digest; +use std::io; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; +use tokio::sync::{oneshot, Mutex as TokioMutex}; + +pub const HUGGINGFACE_PROVIDER_NAME: &str = "huggingface"; +pub const HUGGINGFACE_DISPLAY_NAME: &str = "Hugging Face"; +pub const HUGGINGFACE_TOKEN_SECRET_KEY: &str = "HF_TOKEN"; +pub const HUGGINGFACE_OAUTH_TOKEN_NAME: &str = "OAuth token"; +pub const HUGGINGFACE_OAUTH_CACHE_PATH: &str = "huggingface/oauth/tokens.json"; + +const AUTHORIZE_URL: &str = "https://huggingface.co/oauth/authorize"; +const TOKEN_URL: &str = "https://huggingface.co/oauth/token"; +const OAUTH_SCOPES: &str = "read-repos gated-repos inference-api"; +const HUGGINGFACE_OAUTH_CLIENT_METADATA_URL: &str = + "https://goose-docs.ai/oauth/huggingface-client-metadata.json"; +// This URI must match the redirect URI in the Hugging Face CIMD metadata. +const OAUTH_HOST: [u8; 4] = [127, 0, 0, 1]; +const OAUTH_PORT: u16 = 17863; +const OAUTH_REDIRECT_PATH: &str = "/oauth/huggingface/callback"; +const OAUTH_TIMEOUT_SECS: u64 = 300; +const HTML_AUTO_CLOSE_TIMEOUT_MS: u64 = 2000; + +static HUGGINGFACE_OAUTH_MUTEX: LazyLock> = LazyLock::new(|| TokioMutex::new(())); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HuggingFaceTokenData { + pub access_token: String, + #[serde(default)] + pub refresh_token: Option, + #[serde(default)] + pub expires_at: Option>, +} + +impl HuggingFaceTokenData { + pub fn is_expired(&self) -> bool { + self.expires_at + .is_some_and(|expires_at| expires_at <= Utc::now()) + } +} + +pub fn oauth_client_id() -> &'static str { + option_env!("GOOSE_HUGGINGFACE_OAUTH_CLIENT_ID") + .filter(|client_id| !client_id.trim().is_empty()) + .unwrap_or(HUGGINGFACE_OAUTH_CLIENT_METADATA_URL) +} + +pub fn oauth_cache_path() -> PathBuf { + Paths::in_config_dir(HUGGINGFACE_OAUTH_CACHE_PATH) +} + +pub fn load_oauth_token() -> Option { + load_oauth_token_from_path(&oauth_cache_path()) +} + +fn load_oauth_token_from_path(path: &Path) -> Option { + let contents = std::fs::read_to_string(path).ok()?; + serde_json::from_str(&contents).ok() +} + +pub fn has_oauth_token() -> bool { + load_oauth_token().is_some() +} + +pub fn usable_oauth_token() -> Option { + usable_oauth_token_from_path(&oauth_cache_path()) +} + +fn usable_oauth_token_from_path(path: &std::path::Path) -> Option { + let token = load_oauth_token_from_path(path)?; + (!token.is_expired()).then_some(token.access_token) +} + +pub fn has_usable_or_refreshable_oauth_token() -> bool { + has_usable_or_refreshable_oauth_token_from_path(&oauth_cache_path()) +} + +fn has_usable_or_refreshable_oauth_token_from_path(path: &std::path::Path) -> bool { + load_oauth_token_from_path(path).is_some_and(|token| { + !token.is_expired() + || token + .refresh_token + .as_deref() + .is_some_and(|token| !token.is_empty()) + }) +} + +pub fn has_configured_token() -> Result { + has_configured_token_from_sources(has_usable_or_refreshable_oauth_token(), hf_token_secret) +} + +fn has_configured_token_from_sources( + has_oauth_token: bool, + secret_fallback: impl FnOnce() -> Result>, +) -> Result { + if has_oauth_token { + return Ok(true); + } + + Ok(secret_fallback()?.is_some()) +} + +pub fn hf_token_secret() -> Result> { + match Config::global().get_secret::(HUGGINGFACE_TOKEN_SECRET_KEY) { + Ok(token) => Ok(Some(token)), + Err(ConfigError::NotFound(_)) => Ok(None), + Err(error) => Err(error.into()), + } +} + +pub fn resolve_token() -> Result> { + resolve_token_from_sources(None, usable_oauth_token(), hf_token_secret) +} + +pub fn resolve_token_with_provider_token(provider_token: Option) -> Result> { + resolve_token_from_sources(provider_token, usable_oauth_token(), hf_token_secret) +} + +pub async fn resolve_token_async() -> Result> { + resolve_token_async_with_provider_token(None).await +} + +pub async fn resolve_token_async_with_provider_token( + provider_token: Option, +) -> Result> { + resolve_token_async_from_sources( + provider_token, + refreshed_or_usable_oauth_token_from_path( + &oauth_cache_path(), + oauth_client_id(), + TOKEN_URL, + ), + hf_token_secret, + ) + .await +} + +async fn resolve_token_async_from_sources( + provider_token: Option, + oauth_token: impl std::future::Future>>, + secret_fallback: impl FnOnce() -> Result>, +) -> Result> { + if provider_token.is_some() { + return Ok(provider_token); + } + + match oauth_token.await { + Ok(Some(token)) => return Ok(Some(token)), + Ok(None) => {} + Err(refresh_error) => { + return match secret_fallback()? { + Some(token) => Ok(Some(token)), + None => Err(refresh_error), + }; + } + } + + secret_fallback() +} + +fn resolve_token_from_sources( + provider_token: Option, + oauth_token: Option, + secret_fallback: impl FnOnce() -> Result>, +) -> Result> { + if provider_token.is_some() { + return Ok(provider_token); + } + + if oauth_token.is_some() { + return Ok(oauth_token); + } + + secret_fallback() +} + +pub fn clear_oauth_token() -> Result<()> { + let path = oauth_cache_path(); + if path.exists() { + std::fs::remove_file(path)?; + } + Ok(()) +} + +pub async fn configure_oauth() -> Result<()> { + let token_data = perform_loopback_oauth_flow(oauth_client_id()).await?; + save_oauth_token(token_data) +} + +fn save_oauth_token(token_data: HuggingFaceTokenData) -> Result<()> { + let path = oauth_cache_path(); + save_oauth_token_to_path(&path, &token_data) +} + +fn save_oauth_token_to_path(path: &Path, token_data: &HuggingFaceTokenData) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + + let contents = serde_json::to_string(&token_data)?; + std::fs::write(path, contents)?; + restrict_token_file_permissions(path)?; + Ok(()) +} + +#[cfg(unix)] +fn restrict_token_file_permissions(path: &Path) -> Result<()> { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + Ok(()) +} + +#[cfg(not(unix))] +fn restrict_token_file_permissions(_path: &Path) -> Result<()> { + Ok(()) +} + +struct PkceChallenge { + verifier: String, + challenge: String, +} + +fn generate_pkce() -> PkceChallenge { + let verifier = nanoid::nanoid!(64); + let digest = sha2::Sha256::digest(verifier.as_bytes()); + let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); + PkceChallenge { + verifier, + challenge, + } +} + +fn generate_state() -> String { + nanoid::nanoid!(32) +} + +fn redirect_uri() -> String { + format!( + "http://{}.{}.{}.{}:{}{}", + OAUTH_HOST[0], OAUTH_HOST[1], OAUTH_HOST[2], OAUTH_HOST[3], OAUTH_PORT, OAUTH_REDIRECT_PATH + ) +} + +fn build_authorize_url(client_id: &str, pkce: &PkceChallenge, state: &str) -> Result { + let redirect = redirect_uri(); + let params = [ + ("response_type", "code"), + ("client_id", client_id), + ("redirect_uri", redirect.as_str()), + ("scope", OAUTH_SCOPES), + ("code_challenge", pkce.challenge.as_str()), + ("code_challenge_method", "S256"), + ("state", state), + ]; + let query = serde_urlencoded::to_string(params)?; + Ok(format!("{}?{}", AUTHORIZE_URL, query)) +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + expires_in: Option, +} + +fn token_data_from_response(response: TokenResponse) -> HuggingFaceTokenData { + token_data_from_response_with_refresh_fallback(response, None) +} + +fn token_data_from_response_with_refresh_fallback( + response: TokenResponse, + refresh_token_fallback: Option, +) -> HuggingFaceTokenData { + HuggingFaceTokenData { + access_token: response.access_token, + refresh_token: response.refresh_token.or(refresh_token_fallback), + expires_at: response + .expires_in + .map(|secs| Utc::now() + chrono::Duration::seconds(secs)), + } +} + +async fn exchange_code_for_tokens( + client_id: &str, + code: &str, + pkce: &PkceChallenge, +) -> Result { + let client = reqwest::Client::new(); + let redirect = redirect_uri(); + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect.as_str()), + ("client_id", client_id), + ("code_verifier", pkce.verifier.as_str()), + ]; + + let resp = client + .post(TOKEN_URL) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(anyhow!( + "Hugging Face token exchange failed ({}): {}", + status, + text + )); + } + + Ok(resp.json().await?) +} + +async fn refresh_access_token( + client_id: &str, + refresh_token: &str, + token_url: &str, +) -> Result { + let client = reqwest::Client::new(); + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ("client_id", client_id), + ]; + + let resp = client + .post(token_url) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(anyhow!( + "Hugging Face token refresh failed ({}): {}", + status, + text + )); + } + + Ok(resp.json().await?) +} + +async fn refreshed_or_usable_oauth_token_from_path( + path: &Path, + client_id: &str, + token_url: &str, +) -> Result> { + let Some(token) = load_oauth_token_from_path(path) else { + return Ok(None); + }; + + if !token.is_expired() { + return Ok(Some(token.access_token)); + } + + let Some(refresh_token) = token.refresh_token else { + return Ok(None); + }; + + let refreshed = refresh_access_token(client_id, &refresh_token, token_url).await?; + let refreshed = + token_data_from_response_with_refresh_fallback(refreshed, Some(refresh_token.clone())); + let access_token = refreshed.access_token.clone(); + save_oauth_token_to_path(path, &refreshed)?; + Ok(Some(access_token)) +} + +const HTML_SUCCESS_TEMPLATE: &str = r#" + + + goose - Hugging Face Authorization Successful + + + + +
+

Authorization Successful

+

You can close this window and return to goose.

+
+ +"#; + +fn html_success() -> String { + HTML_SUCCESS_TEMPLATE.replace("{timeout_ms}", &HTML_AUTO_CLOSE_TIMEOUT_MS.to_string()) +} + +fn html_error(error: &str) -> String { + let safe_error = v_htmlescape::escape_fmt(error); + format!( + r#" + + + goose - Hugging Face Authorization Failed + + + +
+

Authorization Failed

+

An error occurred during authorization.

+
{}
+
+ +"#, + safe_error + ) +} + +#[derive(Deserialize)] +struct CallbackParams { + code: Option, + state: Option, + error: Option, + error_description: Option, +} + +fn oauth_callback_router( + expected_state: String, + tx: Arc>>>>, +) -> Router { + Router::new().route( + OAUTH_REDIRECT_PATH, + get(move |Query(params): Query| { + let tx = tx.clone(); + let expected = expected_state.clone(); + async move { + if let Some(error) = params.error { + let msg = params.error_description.unwrap_or(error); + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(&msg)); + } + + let code = match params.code { + Some(c) => c, + None => { + let msg = "Missing authorization code"; + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(msg)); + } + }; + + if params.state.as_deref() != Some(&expected) { + let msg = "Invalid state - potential CSRF attack"; + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Err(anyhow!("{}", msg))); + } + return Html(html_error(msg)); + } + + if let Some(sender) = tx.lock().await.take() { + let _ = sender.send(Ok(code)); + } + Html(html_success()) + } + }), + ) +} + +async fn spawn_oauth_server(app: Router) -> Result> { + let addr = SocketAddr::from((OAUTH_HOST, OAUTH_PORT)); + let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { + if e.kind() == io::ErrorKind::AddrInUse { + anyhow!( + "Hugging Face OAuth callback server failed to bind to {}: port {} is already in use", + addr, + OAUTH_PORT + ) + } else { + anyhow!( + "Hugging Face OAuth callback server failed to bind to {}: {}", + addr, + e + ) + } + })?; + Ok(tokio::spawn(async move { + let server = axum::serve(listener, app); + let _ = server.await; + })) +} + +struct ServerHandleGuard(Option>); + +impl ServerHandleGuard { + fn new(handle: tokio::task::JoinHandle<()>) -> Self { + Self(Some(handle)) + } + + fn abort(&mut self) { + if let Some(handle) = self.0.take() { + handle.abort(); + } + } +} + +impl Drop for ServerHandleGuard { + fn drop(&mut self) { + self.abort(); + } +} + +async fn wait_for_oauth_code(rx: oneshot::Receiver>) -> Result { + let code_result = + tokio::time::timeout(std::time::Duration::from_secs(OAUTH_TIMEOUT_SECS), rx).await; + code_result + .map_err(|_| anyhow!("Hugging Face OAuth flow timed out"))?? + .map_err(|e| anyhow!("Hugging Face OAuth callback error: {}", e)) +} + +async fn perform_loopback_oauth_flow(client_id: &str) -> Result { + let _guard = HUGGINGFACE_OAUTH_MUTEX.try_lock().map_err(|_| { + anyhow!("Another Hugging Face OAuth flow is already in progress; please try again later") + })?; + + let pkce = generate_pkce(); + let csrf_state = generate_state(); + let auth_url = build_authorize_url(client_id, &pkce, &csrf_state)?; + + let (tx, rx) = oneshot::channel::>(); + let tx = Arc::new(TokioMutex::new(Some(tx))); + let app = oauth_callback_router(csrf_state.clone(), tx); + let server_handle = spawn_oauth_server(app).await?; + let mut server_guard = ServerHandleGuard::new(server_handle); + + if webbrowser::open(&auth_url).is_err() { + tracing::info!( + "Please open this URL in your browser to authorize goose with Hugging Face:\n{}", + auth_url + ); + } + + let code_result = wait_for_oauth_code(rx).await; + server_guard.abort(); + let code = code_result?; + + let tokens = exchange_code_for_tokens(client_id, &code, &pkce).await?; + Ok(token_data_from_response(tokens)) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use wiremock::matchers::{body_string_contains, method, path as request_path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn token_path(dir: &TempDir) -> PathBuf { + dir.path().join(HUGGINGFACE_OAUTH_CACHE_PATH) + } + + fn with_token_path(f: impl FnOnce(PathBuf) -> T) -> T { + let dir = TempDir::new().unwrap(); + f(token_path(&dir)) + } + + #[test] + fn pkce_challenge_is_url_safe_base64_of_sha256_of_verifier() { + let pkce = generate_pkce(); + assert_eq!(pkce.verifier.len(), 64); + assert_eq!(pkce.challenge.len(), 43); + assert!(!pkce.challenge.contains('=')); + assert!(!pkce.challenge.contains('+')); + assert!(!pkce.challenge.contains('/')); + } + + #[test] + fn authorize_url_contains_required_oauth_params() { + let pkce = PkceChallenge { + verifier: "v".repeat(64), + challenge: "challenge-fixture".to_string(), + }; + let url = build_authorize_url("client-fixture", &pkce, "state-fixture").unwrap(); + assert!(url.starts_with(AUTHORIZE_URL)); + assert!(url.contains("client_id=client-fixture")); + assert!(url.contains("code_challenge=challenge-fixture")); + assert!(url.contains("code_challenge_method=S256")); + assert!(url.contains("state=state-fixture")); + assert!(url.contains("scope=read-repos")); + assert!(url.contains("gated-repos")); + assert!(url.contains("inference-api")); + } + + #[test] + fn oauth_client_id_defaults_to_cimd_metadata_url() { + if option_env!("GOOSE_HUGGINGFACE_OAUTH_CLIENT_ID").is_none() { + assert_eq!(oauth_client_id(), HUGGINGFACE_OAUTH_CLIENT_METADATA_URL); + } + } + + #[test] + fn redirect_uri_matches_huggingface_cimd_metadata() { + assert_eq!( + redirect_uri(), + "http://127.0.0.1:17863/oauth/huggingface/callback" + ); + } + + #[test] + fn token_data_from_response_stores_expires_in_as_expires_at() { + let token_data = token_data_from_response(TokenResponse { + access_token: "token".to_string(), + refresh_token: None, + expires_in: Some(60), + }); + + let expires_at = token_data.expires_at.unwrap(); + assert!(expires_at > Utc::now()); + assert!(expires_at <= Utc::now() + chrono::Duration::seconds(60)); + } + + #[tokio::test] + async fn expired_oauth_token_refreshes_with_cached_refresh_token() { + let dir = TempDir::new().unwrap(); + let path = token_path(&dir); + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "expired".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(Utc::now() - chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(request_path("/")) + .and(body_string_contains("grant_type=refresh_token")) + .and(body_string_contains("refresh_token=refresh")) + .and(body_string_contains("client_id=client-fixture")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "refreshed", + "expires_in": 60 + }))) + .mount(&server) + .await; + + let token = + refreshed_or_usable_oauth_token_from_path(&path, "client-fixture", &server.uri()) + .await + .unwrap(); + + assert_eq!(token.as_deref(), Some("refreshed")); + let saved = load_oauth_token_from_path(&path).unwrap(); + assert_eq!(saved.access_token, "refreshed"); + assert_eq!(saved.refresh_token.as_deref(), Some("refresh")); + assert!(saved.expires_at.unwrap() > Utc::now()); + } + + #[test] + fn usable_oauth_token_skips_expired_token() { + with_token_path(|path| { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "expired".to_string(), + refresh_token: None, + expires_at: Some(Utc::now() - chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + assert_eq!(usable_oauth_token_from_path(&path), None); + }); + } + + #[test] + fn usable_oauth_token_returns_unexpired_token() { + with_token_path(|path| { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "valid".to_string(), + refresh_token: None, + expires_at: Some(Utc::now() + chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + assert_eq!( + usable_oauth_token_from_path(&path).as_deref(), + Some("valid") + ); + }); + } + + #[test] + fn has_usable_or_refreshable_oauth_token_accepts_unexpired_token() { + with_token_path(|path| { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "valid".to_string(), + refresh_token: None, + expires_at: Some(Utc::now() + chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + assert!(has_usable_or_refreshable_oauth_token_from_path(&path)); + }); + } + + #[test] + fn has_usable_or_refreshable_oauth_token_accepts_expired_refreshable_token() { + with_token_path(|path| { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "expired".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(Utc::now() - chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + assert!(has_usable_or_refreshable_oauth_token_from_path(&path)); + }); + } + + #[test] + fn has_usable_or_refreshable_oauth_token_rejects_expired_unrefreshable_token() { + with_token_path(|path| { + std::fs::create_dir_all(path.parent().unwrap()).unwrap(); + std::fs::write( + &path, + serde_json::to_string(&HuggingFaceTokenData { + access_token: "expired".to_string(), + refresh_token: None, + expires_at: Some(Utc::now() - chrono::Duration::minutes(1)), + }) + .unwrap(), + ) + .unwrap(); + + assert!(!has_usable_or_refreshable_oauth_token_from_path(&path)); + }); + } + + #[test] + fn has_configured_token_accepts_oauth_without_secret_lookup() { + let configured = has_configured_token_from_sources(true, || { + panic!("secret store should not be queried when OAuth is configured") + }) + .unwrap(); + + assert!(configured); + } + + #[test] + fn has_configured_token_accepts_secret_fallback() { + let configured = + has_configured_token_from_sources(false, || Ok(Some("hf-token".to_string()))).unwrap(); + + assert!(configured); + } + + #[test] + fn has_configured_token_rejects_missing_oauth_and_secret() { + let configured = has_configured_token_from_sources(false, || Ok(None)).unwrap(); + + assert!(!configured); + } + + #[cfg(unix)] + #[test] + fn save_oauth_token_restricts_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + with_token_path(|path| { + save_oauth_token_to_path( + &path, + &HuggingFaceTokenData { + access_token: "saved".to_string(), + refresh_token: None, + expires_at: None, + }, + ) + .unwrap(); + + let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600); + }); + } + + #[test] + fn resolver_prefers_provider_token_over_oauth() { + let token = resolve_token_from_sources( + Some("api-key".to_string()), + Some("oauth".to_string()), + || panic!("secret store should not be queried when provider token is usable"), + ) + .unwrap(); + + assert_eq!(token.as_deref(), Some("api-key")); + } + + #[test] + fn resolver_uses_oauth_before_secret_store() { + let token = resolve_token_from_sources(None, Some("oauth".to_string()), || { + panic!("secret store should not be queried when OAuth is usable") + }) + .unwrap(); + + assert_eq!(token.as_deref(), Some("oauth")); + } + + #[test] + fn resolver_uses_secret_store_when_no_provider_token_or_oauth_exists() { + let token = resolve_token_from_sources(None, None, || Ok(Some("secret-store".to_string()))) + .unwrap(); + + assert_eq!(token.as_deref(), Some("secret-store")); + } + + #[tokio::test] + async fn async_resolver_uses_secret_fallback_when_oauth_refresh_fails() { + let token = resolve_token_async_from_sources( + None, + async { Err(anyhow::anyhow!("refresh token revoked")) }, + || Ok(Some("secret-store".to_string())), + ) + .await + .unwrap(); + + assert_eq!(token.as_deref(), Some("secret-store")); + } + + #[tokio::test] + async fn async_resolver_reports_refresh_error_without_secret_fallback() { + let error = resolve_token_async_from_sources( + None, + async { Err(anyhow::anyhow!("refresh token revoked")) }, + || Ok(None), + ) + .await + .unwrap_err(); + + assert_eq!(error.to_string(), "refresh token revoked"); + } +} diff --git a/crates/goose/src/providers/init.rs b/crates/goose/src/providers/init.rs index 0ae29d2265ad..06d5773de0ad 100644 --- a/crates/goose/src/providers/init.rs +++ b/crates/goose/src/providers/init.rs @@ -27,6 +27,7 @@ use super::{ gemini_oauth::GeminiOAuthProvider, githubcopilot::GithubCopilotProvider, google::GoogleProvider, + huggingface::HuggingFaceProvider, kimicode::KimiCodeProvider, litellm::LiteLLMProvider, nanogpt::NanoGptProvider, @@ -76,6 +77,7 @@ async fn init_registry() -> RwLock { registry.register::(true); registry.register::(false); registry.register::(true); + registry.register::(true); registry.register::(true); registry.register::(false); registry.register::(true); @@ -111,10 +113,18 @@ async fn init_registry() -> RwLock { "chatgpt_codex", Arc::new(|| Box::pin(ChatGptCodexProvider::cleanup())), ); + registry.set_cleanup( + "gemini_oauth", + Arc::new(|| Box::pin(GeminiOAuthProvider::cleanup())), + ); registry.set_cleanup( "xai_oauth", Arc::new(|| Box::pin(XaiOAuthProvider::cleanup())), ); + registry.set_cleanup( + "huggingface", + Arc::new(|| Box::pin(HuggingFaceProvider::cleanup())), + ); if let Err(e) = load_custom_providers_into_registry(&mut registry) { tracing::warn!("Failed to load custom providers: {}", e); @@ -261,6 +271,22 @@ mod tests { assert!(!endpoint.secret, "Endpoint should not be secret"); } + #[tokio::test] + async fn test_huggingface_provider_registry_wiring() { + let huggingface = get_from_registry("huggingface") + .await + .expect("huggingface provider should be registered"); + let meta = huggingface.metadata(); + + assert_eq!(huggingface.provider_type(), ProviderType::Preferred); + assert_eq!(meta.display_name, "Hugging Face"); + assert_eq!(meta.default_model, "Qwen/Qwen3-Coder-480B-A35B-Instruct"); + assert!(meta + .config_keys + .iter() + .any(|key| key.name == "HF_TOKEN" && key.secret)); + } + #[tokio::test] async fn test_nvidia_declarative_provider_registry_wiring() { let nvidia = get_from_registry("nvidia") diff --git a/crates/goose/src/providers/inventory/mod.rs b/crates/goose/src/providers/inventory/mod.rs index f363d6963e1a..589f2ea3e611 100644 --- a/crates/goose/src/providers/inventory/mod.rs +++ b/crates/goose/src/providers/inventory/mod.rs @@ -1,4 +1,4 @@ -use super::base::{ConfigKey, ModelInfo, ProviderType}; +use super::base::{ConfigKey, ModelInfo, Provider, ProviderType}; use super::canonical::{map_provider_name, map_to_canonical_model, CanonicalModelRegistry}; use super::catalog::ProviderSetupCategory; use crate::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine}; @@ -7,10 +7,12 @@ use crate::session::session_manager::SessionStorage; use crate::utils::bytes_to_hex; use anyhow::{Context, Result}; use chrono::{DateTime, Duration, Utc}; +use futures::FutureExt; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use sqlx::{Pool, Row, Sqlite, Transaction}; use std::collections::{BTreeMap, HashMap, HashSet}; +use std::panic::AssertUnwindSafe; use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard}; use tracing::warn; @@ -310,6 +312,23 @@ impl ProviderInventoryService { })) } + pub async fn find_entry_for_provider( + &self, + provider_id: &str, + ) -> Option { + match self.entry_for_provider(provider_id).await { + Ok(entry) => entry, + Err(error) => { + warn!( + provider = %provider_id, + %error, + "failed to look up provider inventory entry" + ); + None + } + } + } + pub async fn entries(&self, provider_ids: &[String]) -> Result> { let ids = self.resolve_provider_ids(provider_ids).await; let handles: Vec<_> = ids @@ -561,6 +580,106 @@ impl ProviderInventoryService { } } + pub(crate) async fn refresh_with_provider( + &self, + provider_name: &str, + provider: &Arc, + inventory: &mut ProviderInventoryEntry, + context: &str, + ) { + let provider_id = provider_name.to_string(); + match self + .plan_refresh_jobs(std::slice::from_ref(&provider_id)) + .await + { + Ok(plan) + if plan + .started + .iter() + .any(|job| job.provider_id == provider_id) => + { + let refresh_job = plan + .started + .into_iter() + .find(|job| job.provider_id == provider_id); + if let Some(refresh_job) = refresh_job { + let mut refresh_guard = self.refresh_guard(&refresh_job.identity); + let fetch_result: Result> = + match ensure_refresh_identity_current(&provider_id, &refresh_job.identity) + .await + { + Ok(()) => { + match AssertUnwindSafe(provider.fetch_recommended_models()) + .catch_unwind() + .await + { + Ok(Ok(models)) => Ok(models), + Ok(Err(error)) => Err(anyhow::anyhow!(error.to_string())), + Err(_) => Err(anyhow::anyhow!( + "provider inventory refresh task panicked" + )), + } + } + Err(error) => Err(error), + }; + match fetch_result { + Ok(models) => { + if let Err(error) = self + .store_refreshed_models_for_identity(&refresh_job.identity, &models) + .await + { + warn!( + provider = %provider_id, + context = %context, + error = %error, + "failed to store refreshed provider inventory" + ); + } else { + refresh_guard.complete(); + } + } + Err(error) => { + let error_message = error.to_string(); + if let Err(store_error) = self + .store_refresh_error_for_identity( + &refresh_job.identity, + error_message.clone(), + ) + .await + { + warn!( + provider = %provider_id, + context = %context, + error = %store_error, + "failed to store provider inventory refresh error" + ); + } else { + refresh_guard.complete(); + } + warn!( + provider = %provider_id, + context = %context, + error = %error_message, + "provider inventory refresh failed" + ); + } + } + } + } + Ok(_) => {} + Err(error) => warn!( + provider = %provider_id, + context = %context, + error = %error, + "failed to plan provider inventory refresh" + ), + } + + if let Some(refreshed_inventory) = self.find_entry_for_provider(provider_name).await { + *inventory = refreshed_inventory; + } + } + pub fn is_stale(entry: &ProviderInventoryEntry) -> bool { let Some(last_updated_at) = entry.last_updated_at else { return false; @@ -717,6 +836,19 @@ impl ProviderInventoryService { } } +pub(crate) async fn ensure_refresh_identity_current( + provider_id: &str, + planned_identity: &InventoryIdentity, +) -> Result<()> { + let current_identity = crate::providers::inventory_identity(provider_id) + .await? + .into_identity()?; + if current_identity != *planned_identity { + anyhow::bail!("provider inventory identity changed before refresh completed"); + } + Ok(()) +} + pub fn default_inventory_identity( provider_id: &str, provider_family: &str, diff --git a/crates/goose/src/providers/local_inference/hf_models.rs b/crates/goose/src/providers/local_inference/hf_models.rs index 5767193ae2ea..c9133747b016 100644 --- a/crates/goose/src/providers/local_inference/hf_models.rs +++ b/crates/goose/src/providers/local_inference/hf_models.rs @@ -1,6 +1,8 @@ use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; +use crate::providers::huggingface_auth; + use utoipa::ToSchema; const HF_API_BASE: &str = "https://huggingface.co/api/models"; @@ -245,6 +247,26 @@ fn build_download_url(repo_id: &str, filename: &str) -> String { format!("{}/{}/resolve/main/{}", HF_DOWNLOAD_BASE, repo_id, filename) } +pub fn hf_authorization_header(token: Option<&str>) -> Option { + token + .filter(|token| !token.is_empty()) + .map(|token| format!("Bearer {}", token)) +} + +fn apply_hf_auth(request: reqwest::RequestBuilder, token: Option<&str>) -> reqwest::RequestBuilder { + if let Some(header) = hf_authorization_header(token) { + request.header("Authorization", header) + } else { + request + } +} + +async fn optional_hf_token( + token: impl std::future::Future>>, +) -> Option { + token.await.ok().flatten() +} + fn parent_components(filename: &str) -> Vec<&str> { filename.rsplit_once('/').map_or(Vec::new(), |(parent, _)| { parent.split('/').filter(|part| !part.is_empty()).collect() @@ -404,13 +426,13 @@ fn group_into_variants(repo_id: &str, files: Vec) -> Vec Result> { let client = reqwest::Client::new(); + let token = optional_hf_token(huggingface_auth::resolve_token_async()).await; let url = format!( "{}?search={}&filter=gguf&sort=downloads&direction=-1&limit={}", HF_API_BASE, query, limit ); - let response = client - .get(&url) + let response = apply_hf_auth(client.get(&url), token.as_deref()) .header("User-Agent", "goose-ai-agent") .send() .await?; @@ -469,10 +491,10 @@ pub async fn search_gguf_models(query: &str, limit: usize) -> Result Result> { let client = reqwest::Client::new(); + let token = optional_hf_token(huggingface_auth::resolve_token_async()).await; let url = format!("{}/{}?blobs=true", HF_API_BASE, repo_id); - let response = client - .get(&url) + let response = apply_hf_auth(client.get(&url), token.as_deref()) .header("User-Agent", "goose-ai-agent") .send() .await?; @@ -494,10 +516,10 @@ pub async fn get_repo_gguf_variants(repo_id: &str) -> Result /// Fetch raw GGUF files (kept for resolve_model_spec). pub async fn get_repo_gguf_files(repo_id: &str) -> Result> { let client = reqwest::Client::new(); + let token = optional_hf_token(huggingface_auth::resolve_token_async()).await; let url = format!("{}/{}?blobs=true", HF_API_BASE, repo_id); - let response = client - .get(&url) + let response = apply_hf_auth(client.get(&url), token.as_deref()) .header("User-Agent", "goose-ai-agent") .send() .await?; @@ -556,9 +578,9 @@ pub async fn resolve_model_spec_full(spec: &str) -> Result<(String, ResolvedMode let (repo_id, quant) = parse_model_spec(spec)?; let client = reqwest::Client::new(); + let token = optional_hf_token(huggingface_auth::resolve_token_async()).await; let url = format!("{}/{}?blobs=true", HF_API_BASE, repo_id); - let response = client - .get(&url) + let response = apply_hf_auth(client.get(&url), token.as_deref()) .header("User-Agent", "goose-ai-agent") .send() .await?; @@ -739,6 +761,16 @@ mod tests { assert_eq!(parse_quantization("random-name.gguf"), "unknown"); } + #[test] + fn test_hf_authorization_header() { + assert_eq!( + hf_authorization_header(Some("hf_test")).as_deref(), + Some("Bearer hf_test") + ); + assert_eq!(hf_authorization_header(Some("")), None); + assert_eq!(hf_authorization_header(None), None); + } + #[test] fn test_parse_quantization_with_directory() { assert_eq!( @@ -926,6 +958,21 @@ mod tests { assert_eq!(mmproj.quantization, "BF16"); } + #[tokio::test] + async fn optional_hf_token_returns_resolved_token() { + let token = optional_hf_token(async { Ok(Some("token".to_string())) }).await; + + assert_eq!(token.as_deref(), Some("token")); + } + + #[tokio::test] + async fn optional_hf_token_ignores_resolution_errors() { + let token = + optional_hf_token(async { Err(anyhow::anyhow!("refresh token revoked")) }).await; + + assert_eq!(token, None); + } + #[test] fn test_select_best_mmproj_prefers_bf16_over_f16_tie() { let files = vec![ diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index b0fec1c9b4d5..1ca4d7938f33 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -31,6 +31,8 @@ pub mod gemini_oauth; pub mod githubcopilot; pub mod google; pub mod http_status; +pub mod huggingface; +pub mod huggingface_auth; mod init; pub mod inventory; pub mod kimicode; diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs index d3580bdcb1ea..7ba2a5ff436e 100644 --- a/crates/goose/src/providers/openai_compatible.rs +++ b/crates/goose/src/providers/openai_compatible.rs @@ -11,13 +11,15 @@ use tokio_util::codec::{FramedRead, LinesCodec}; use tokio_util::io::StreamReader; use super::api_client::ApiClient; -use super::base::{MessageStream, Provider}; +use super::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{ImageFormat, RequestLog}; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::openai::{create_request, response_to_streaming_message}; +use crate::providers::formats::openai::{ + create_request, get_usage, response_to_message, response_to_streaming_message, +}; use crate::providers::formats::openai_responses::responses_api_to_streaming_message; use rmcp::model::Tool; @@ -28,6 +30,7 @@ pub struct OpenAiCompatibleProvider { model: ModelConfig, /// Path prefix prepended to `chat/completions` (e.g. `"deployments/{name}/"` for Azure). completions_prefix: String, + supports_streaming: bool, } impl OpenAiCompatibleProvider { @@ -42,9 +45,15 @@ impl OpenAiCompatibleProvider { api_client, model, completions_prefix, + supports_streaming: true, } } + pub fn with_supports_streaming(mut self, supports_streaming: bool) -> Self { + self.supports_streaming = supports_streaming; + self + } + fn build_request( &self, model_config: &ModelConfig, @@ -110,7 +119,13 @@ impl Provider for OpenAiCompatibleProvider { messages: &[Message], tools: &[Tool], ) -> Result { - let payload = self.build_request(model_config, system, messages, tools, true)?; + let payload = self.build_request( + model_config, + system, + messages, + tools, + self.supports_streaming, + )?; let mut log = RequestLog::start(model_config, &payload)?; let completions_path = format!("{}chat/completions", self.completions_prefix); @@ -127,7 +142,27 @@ impl Provider for OpenAiCompatibleProvider { let _ = log.error(e); })?; - stream_openai_compat(response, log) + if self.supports_streaming { + stream_openai_compat(response, log) + } else { + let json: serde_json::Value = response.json().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)) + })?; + + let message = response_to_message(&json).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse message: {}", e)) + })?; + + let usage_data = get_usage(json.get("usage").unwrap_or(&serde_json::Value::Null)); + let usage = ProviderUsage::new(model_config.model_name.clone(), usage_data); + + log.write( + &serde_json::to_value(&message).unwrap_or_default(), + Some(&usage.usage), + )?; + + Ok(stream_from_single_message(message, usage)) + } } } @@ -190,6 +225,7 @@ pub fn stream_responses_compat( #[cfg(test)] mod tests { use super::*; + use crate::model::ModelConfig; use serde_json::json; use test_case::test_case; @@ -262,4 +298,26 @@ mod tests { "Expected {expected_variant}, got error: {err:?}" ); } + + #[test] + fn build_request_respects_non_streaming_mode() { + let provider = OpenAiCompatibleProvider::new( + "test".to_string(), + ApiClient::new( + "http://localhost".to_string(), + super::super::api_client::AuthMethod::NoAuth, + ) + .unwrap(), + ModelConfig::new_or_fail("test-model"), + String::new(), + ) + .with_supports_streaming(false); + + let payload = provider + .build_request(&provider.model, "", &[], &[], provider.supports_streaming) + .unwrap(); + + assert_eq!(payload.get("stream"), None); + assert_eq!(payload.get("stream_options"), None); + } } diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index 1970da7fe3c8..b3da205e57b2 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -161,6 +161,53 @@ impl ProviderRegistry { P: ProviderDef + 'static, F: Fn(ModelConfig) -> Result + Send + Sync + 'static, G: Fn() -> Result + Send + Sync + 'static, + { + self.register_with_name_impl::( + config, + provider_type, + supports_inventory_refresh, + constructor, + inventory_identity, + None, + ); + } + + pub fn register_with_name_and_inventory_configured( + &mut self, + config: &DeclarativeProviderConfig, + provider_type: ProviderType, + supports_inventory_refresh: bool, + constructor: F, + inventory_identity: G, + inventory_configured: H, + ) where + P: ProviderDef + 'static, + F: Fn(ModelConfig) -> Result + Send + Sync + 'static, + G: Fn() -> Result + Send + Sync + 'static, + H: Fn() -> bool + Send + Sync + 'static, + { + self.register_with_name_impl::( + config, + provider_type, + supports_inventory_refresh, + constructor, + inventory_identity, + Some(Arc::new(inventory_configured)), + ); + } + + fn register_with_name_impl( + &mut self, + config: &DeclarativeProviderConfig, + provider_type: ProviderType, + supports_inventory_refresh: bool, + constructor: F, + inventory_identity: G, + inventory_configured: Option, + ) where + P: ProviderDef + 'static, + F: Fn(ModelConfig) -> Result + Send + Sync + 'static, + G: Fn() -> Result + Send + Sync + 'static, { let base_metadata = P::metadata(); let description = config @@ -243,6 +290,12 @@ impl ProviderRegistry { model_selection_hint: None, }; let inventory_config_keys = custom_metadata.config_keys.clone(); + let default_inventory_configured = Arc::new(move || { + super::inventory::default_inventory_configured( + &inventory_config_keys, + crate::config::Config::global(), + ) + }); self.entries.insert( config.name.clone(), @@ -256,12 +309,7 @@ impl ProviderRegistry { }) }), inventory_identity: Arc::new(inventory_identity), - inventory_configured: Arc::new(move || { - super::inventory::default_inventory_configured( - &inventory_config_keys, - crate::config::Config::global(), - ) - }), + inventory_configured: inventory_configured.unwrap_or(default_inventory_configured), cleanup: None, provider_type, supports_inventory_refresh, @@ -308,3 +356,52 @@ impl ProviderRegistry { self.entries.retain(|name, _| !name.starts_with("custom_")); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::declarative_providers::ProviderEngine; + use crate::providers::openai::OpenAiProvider; + + fn test_config() -> DeclarativeProviderConfig { + DeclarativeProviderConfig { + name: "custom_hf".to_string(), + engine: ProviderEngine::OpenAI, + display_name: "Custom HF".to_string(), + description: None, + api_key_env: String::new(), + base_url: "https://router.huggingface.co/v1".to_string(), + models: vec![ModelInfo::new("test-model", 128_000)], + headers: None, + timeout_seconds: None, + supports_streaming: Some(true), + requires_auth: true, + catalog_provider_id: Some("huggingface".to_string()), + base_path: None, + env_vars: None, + dynamic_models: None, + skip_canonical_filtering: false, + model_doc_link: None, + setup_steps: vec![], + fast_model: None, + preserves_thinking: false, + } + } + + #[test] + fn register_with_name_can_override_inventory_configured() { + let mut registry = ProviderRegistry::new(); + registry.register_with_name_and_inventory_configured::( + &test_config(), + ProviderType::Declarative, + false, + |_| unreachable!("constructor is not used by this test"), + || Ok(InventoryIdentityInput::new("custom_hf", "huggingface")), + || false, + ); + + let entry = registry.entries.get("custom_hf").unwrap(); + + assert!(!entry.inventory_configured()); + } +} diff --git a/crates/goose/src/recipe/manifest.rs b/crates/goose/src/recipe/manifest.rs new file mode 100644 index 000000000000..f9d11406a902 --- /dev/null +++ b/crates/goose/src/recipe/manifest.rs @@ -0,0 +1,140 @@ +use anyhow::{anyhow, Result}; +use std::fs; +use std::hash::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::path::{Path, PathBuf}; + +use crate::recipe::build_recipe::resolve_sub_recipe_path; +use crate::recipe::local_recipes::list_local_recipes; +use crate::recipe::Recipe; + +#[derive(Debug, Clone)] +pub struct RecipeFileManifest { + pub id: String, + pub recipe: Recipe, + pub file_path: PathBuf, + pub last_modified: String, +} + +pub fn short_id_from_path(path: &str) -> String { + let mut hasher = DefaultHasher::new(); + path.hash(&mut hasher); + let h = hasher.finish(); + format!("{:016x}", h) +} + +pub fn list_recipe_file_manifests() -> Result> { + let recipes_with_path = list_local_recipes()?; + let mut manifests = Vec::new(); + + for (file_path, mut recipe) in recipes_with_path { + let Ok(last_modified) = fs::metadata(file_path.clone()).and_then(|metadata| { + metadata + .modified() + .map(|modified| chrono::DateTime::::from(modified).to_rfc3339()) + }) else { + continue; + }; + + resolve_recipe_sub_recipe_paths(&mut recipe, &file_path); + + manifests.push(RecipeFileManifest { + id: short_id_from_path(file_path.to_string_lossy().as_ref()), + recipe, + file_path, + last_modified, + }); + } + + manifests.sort_by(|a, b| b.last_modified.cmp(&a.last_modified)); + + Ok(manifests) +} + +pub fn get_recipe_file_path_by_id(id: &str) -> Result { + list_recipe_file_manifests()? + .into_iter() + .find(|manifest| manifest.id == id) + .map(|manifest| manifest.file_path) + .ok_or_else(|| anyhow!("Recipe not found: {}", id)) +} + +pub fn load_recipe_by_id(id: &str) -> Result { + let path = get_recipe_file_path_by_id(id)?; + load_recipe_from_path(&path) +} + +pub fn load_recipe_from_path(path: &Path) -> Result { + let mut recipe = Recipe::from_file_path(path)?; + resolve_recipe_sub_recipe_paths(&mut recipe, path); + Ok(recipe) +} + +fn resolve_recipe_sub_recipe_paths(recipe: &mut Recipe, recipe_path: &Path) { + let Some(recipe_dir) = recipe_path.parent() else { + return; + }; + + let Some(ref mut sub_recipes) = recipe.sub_recipes else { + return; + }; + + for sub_recipe in sub_recipes.iter_mut() { + if let Ok(resolved) = resolve_sub_recipe_path(&sub_recipe.path, recipe_dir) { + sub_recipe.path = resolved; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn short_id_from_path_is_stable() { + assert_eq!( + short_id_from_path("/tmp/example.yaml"), + short_id_from_path("/tmp/example.yaml") + ); + assert_ne!( + short_id_from_path("/tmp/example.yaml"), + short_id_from_path("/tmp/other.yaml") + ); + } + + #[test] + fn load_recipe_from_path_resolves_sub_recipe_paths() { + let temp_dir = tempfile::tempdir().unwrap(); + let child_path = temp_dir.path().join("child.yaml"); + fs::write( + &child_path, + r#" +title: Child +description: Child recipe +instructions: Child instructions +"#, + ) + .unwrap(); + let parent_path = temp_dir.path().join("parent.yaml"); + fs::write( + &parent_path, + r#" +title: Parent +description: Parent recipe +instructions: Parent instructions +sub_recipes: + - name: child + path: child.yaml +"#, + ) + .unwrap(); + + let recipe = load_recipe_from_path(&parent_path).unwrap(); + let sub_recipes = recipe.sub_recipes.unwrap(); + + assert_eq!( + sub_recipes[0].path, + child_path.to_string_lossy().to_string() + ); + } +} diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 4dc007600560..3de1b8dda65e 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -15,6 +15,7 @@ use utoipa::ToSchema; pub mod build_recipe; pub mod local_recipes; +pub mod manifest; pub mod read_recipe_file_content; mod recipe_extension_adapter; pub mod template_recipe; diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index cb41621427e3..6b1f8bb5ff5b 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -1,8 +1,7 @@ -use ahash::AHasher; -use dashmap::DashMap; +use lru::LruCache; use rmcp::model::Tool; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::num::NonZeroUsize; +use std::sync::{Arc, Mutex}; use tiktoken_rs::CoreBPE; use tokio::sync::OnceCell; @@ -10,7 +9,7 @@ use crate::conversation::message::Message; static TOKENIZER: OnceCell> = OnceCell::const_new(); -const MAX_TOKEN_CACHE_SIZE: usize = 10_000; +const MAX_TOKEN_CACHE_SIZE: usize = 1_024; // token use for various bits of a tool calls: const FUNC_INIT: usize = 7; @@ -22,38 +21,54 @@ const FUNC_END: usize = 12; pub struct TokenCounter { tokenizer: Arc, - token_cache: Arc>, + token_cache: Mutex>, +} + +#[derive(Clone, Copy, Eq, Hash, PartialEq)] +struct TokenCacheKey { + len: usize, + hash: [u8; 32], +} + +impl TokenCacheKey { + fn from_text(text: &str) -> Self { + Self { + len: text.len(), + hash: *blake3::hash(text.as_bytes()).as_bytes(), + } + } } impl TokenCounter { pub async fn new() -> Result { let tokenizer = get_tokenizer().await?; + let cache_capacity = + NonZeroUsize::new(MAX_TOKEN_CACHE_SIZE).expect("token cache capacity must be non-zero"); Ok(Self { tokenizer, - token_cache: Arc::new(DashMap::new()), + token_cache: Mutex::new(LruCache::new(cache_capacity)), }) } pub fn count_tokens(&self, text: &str) -> usize { - let mut hasher = AHasher::default(); - text.hash(&mut hasher); - let hash = hasher.finish(); - - if let Some(count) = self.token_cache.get(&hash) { - return *count; + let cache_key = TokenCacheKey::from_text(text); + if let Some(count) = self + .token_cache + .lock() + .expect("token cache mutex poisoned") + .get(&cache_key) + .copied() + { + return count; } let tokens = self.tokenizer.encode_with_special_tokens(text); let count = tokens.len(); - if self.token_cache.len() >= MAX_TOKEN_CACHE_SIZE { - if let Some(entry) = self.token_cache.iter().next() { - let old_hash = *entry.key(); - self.token_cache.remove(&old_hash); - } - } - - self.token_cache.insert(hash, count); + self.token_cache + .lock() + .expect("token cache mutex poisoned") + .put(cache_key, count); count } @@ -173,11 +188,17 @@ impl TokenCounter { } pub fn clear_cache(&self) { - self.token_cache.clear(); + self.token_cache + .lock() + .expect("token cache mutex poisoned") + .clear(); } pub fn cache_size(&self) -> usize { - self.token_cache.len() + self.token_cache + .lock() + .expect("token cache mutex poisoned") + .len() } } @@ -260,13 +281,13 @@ mod tests { let counter = create_token_counter().await.unwrap(); let mut cached_texts = Vec::new(); - for i in 0..50 { + for i in 0..=MAX_TOKEN_CACHE_SIZE { let text = format!("Test string number {}", i); counter.count_tokens(&text); cached_texts.push(text); } - assert!(counter.cache_size() <= MAX_TOKEN_CACHE_SIZE); + assert_eq!(counter.cache_size(), MAX_TOKEN_CACHE_SIZE); let recent_text = &cached_texts[cached_texts.len() - 1]; let start_size = counter.cache_size(); diff --git a/crates/goose/tests/acp_common_tests/mod.rs b/crates/goose/tests/acp_common_tests/mod.rs index 18976ab7511e..d2af23d07497 100644 --- a/crates/goose/tests/acp_common_tests/mod.rs +++ b/crates/goose/tests/acp_common_tests/mod.rs @@ -16,18 +16,21 @@ use fs_err as fs; use goose::acp::server::AcpProviderFactory; use goose::config::base::CONFIG_YAML_NAME; use goose::config::GooseMode; -use goose::conversation::message::Message; -use goose::model::ModelConfig; -use goose::providers::base::{ - stream_from_single_message, MessageStream, Provider, ProviderUsage, Usage, -}; -use goose::providers::errors::ProviderError; use goose_test_support::{McpFixture, FAKE_CODE, TEST_IMAGE_B64, TEST_MODEL}; use sqlx::sqlite::SqlitePoolOptions; use std::sync::Arc; use std::time::Duration; const SHELL_TEST_CONTENT: &str = "test-shell-content-98765"; +const OPENAI_SESSION_NAME_RESPONSE: &str = r#"data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1766229303,"model":"gpt-5-nano","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1766229303,"model":"gpt-5-nano","choices":[{"index":0,"delta":{"content":"Generated Test Title"},"finish_reason":null}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1766229303,"model":"gpt-5-nano","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + +data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1766229303,"model":"gpt-5-nano","choices":[],"usage":{"prompt_tokens":100,"completion_tokens":10,"total_tokens":110}} + +data: [DONE]"#; struct BasicSession { conn: C, @@ -58,46 +61,6 @@ async fn new_basic_session(config: TestConnectionConfig) -> Basic BasicSession { conn, session } } -struct NamingProvider { - model_config: ModelConfig, -} - -#[async_trait::async_trait] -impl Provider for NamingProvider { - fn get_name(&self) -> &str { - "naming-test" - } - - async fn stream( - &self, - _model_config: &ModelConfig, - _session_id: &str, - system: &str, - _messages: &[Message], - _tools: &[rmcp::model::Tool], - ) -> Result { - let text = if system.contains("four words or less") || system.contains("4 words or less") { - "Generated Test Title" - } else { - "2" - }; - Ok(stream_from_single_message( - Message::assistant().with_text(text), - ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()), - )) - } - - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } -} - -fn naming_provider_factory() -> AcpProviderFactory { - Arc::new(|_provider_name, model_config, _extensions, _working_dir| { - Box::pin(async move { Ok(Arc::new(NamingProvider { model_config }) as Arc) }) - }) -} - pub async fn run_list_sessions() { let BasicSession { conn, session } = new_basic_session::(TestConnectionConfig::default()).await; @@ -119,6 +82,7 @@ pub async fn run_list_sessions() { serde_json::Value::Number(2.into()), ); expected_meta.insert("userSetName".to_string(), serde_json::Value::Bool(false)); + expected_meta.insert("hasRecipe".to_string(), serde_json::Value::Bool(false)); assert_eq!( response, ListSessionsResponse::new(vec![SessionInfo::new( @@ -132,9 +96,21 @@ pub async fn run_list_sessions() { pub async fn run_session_name_update_notification() { let expected_session_id = C::expected_session_id(); - let openai = OpenAiFixture::new(vec![], expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + r#"\nwhat should we call this conversation?""#.into(), + include_str!("../acp_test_data/openai_basic.txt"), + ), + ( + "Generate a short title for the above messages.".into(), + OPENAI_SESSION_NAME_RESPONSE, + ), + ], + expected_session_id.clone(), + ) + .await; let config = TestConnectionConfig { - provider_factory: Some(naming_provider_factory()), disable_session_naming: false, ..Default::default() }; @@ -923,6 +899,39 @@ pub async fn run_new_session_returns_initial_config() { assert!(!models.available_models.is_empty()); } +pub async fn run_new_session_uses_current_config_mode() { + let temp_dir = tempfile::tempdir().unwrap(); + let config_path = temp_dir.path().join(goose::config::base::CONFIG_YAML_NAME); + fs::write( + &config_path, + format!("GOOSE_MODEL: {TEST_MODEL}\nGOOSE_PROVIDER: openai\nGOOSE_MODE: approve\n"), + ) + .unwrap(); + + let expected_session_id = C::expected_session_id(); + let openai = OpenAiFixture::new(vec![], expected_session_id.clone()).await; + let config = TestConnectionConfig { + goose_mode: GooseMode::Approve, + data_root: temp_dir.path().to_path_buf(), + ..Default::default() + }; + + let mut conn = C::new(config, openai).await; + + let global_config_path = + goose::config::paths::Paths::config_dir().join(goose::config::base::CONFIG_YAML_NAME); + fs::write( + &global_config_path, + format!("GOOSE_MODEL: {TEST_MODEL}\nGOOSE_PROVIDER: openai\nGOOSE_MODE: auto\n"), + ) + .unwrap(); + + let SessionData { session, modes, .. } = conn.new_session().await.unwrap(); + expected_session_id.set(&session.session_id().0); + + assert_eq!(modes.unwrap().current_mode_id, SessionModeId::new("auto")); +} + pub async fn run_config_option_model_set() { run_model_set_impl::(SetModelVia::ConfigOption).await; } @@ -1328,11 +1337,11 @@ pub async fn run_prompt_model_mismatch() { // TODO: add a Responses API mock to OpenAiFixture so we can test with // responses-routed models like o4-mini here. let config = TestConnectionConfig { - current_model: "gpt-4.1".to_string(), + current_model: "gpt-4o".to_string(), ..Default::default() }; - // Server starts on gpt-4.1; client is configured with TEST_MODEL. + // Server starts on gpt-4o; client is configured with TEST_MODEL. // If session_model is seeded from the response, stream() detects the // mismatch and sends set_model(TEST_MODEL) before prompting. let BasicSession { conn: _, .. } = new_basic_session::(config).await; diff --git a/crates/goose/tests/acp_custom_requests_test.rs b/crates/goose/tests/acp_custom_requests_test.rs index 24e237e49606..659a820eeb98 100644 --- a/crates/goose/tests/acp_custom_requests_test.rs +++ b/crates/goose/tests/acp_custom_requests_test.rs @@ -12,11 +12,29 @@ use goose::model::ModelConfig; use goose::providers::base::{MessageStream, Provider}; use goose::providers::errors::ProviderError; use goose_test_support::{EnforceSessionId, IgnoreSessionId}; +use serial_test::serial; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, LazyLock, Mutex}; use common_tests::fixtures::OpenAiFixture; +const DEFAULT_ACP_TEST_CONFIG: &str = "GOOSE_MODEL: gpt-4o\nGOOSE_PROVIDER: openai\n"; + +static ACP_CONFIG_ROOT: LazyLock = + LazyLock::new(|| tempfile::tempdir().unwrap()); + +fn write_acp_global_config(contents: &str) -> PathBuf { + std::env::set_var("GOOSE_PATH_ROOT", ACP_CONFIG_ROOT.path()); + let config_dir = goose::config::paths::Paths::config_dir(); + std::fs::create_dir_all(&config_dir).unwrap(); + std::fs::write( + config_dir.join(goose::config::base::CONFIG_YAML_NAME), + contents, + ) + .unwrap(); + config_dir +} + struct MockProvider { name: String, model_config: ModelConfig, @@ -75,7 +93,9 @@ fn mock_provider_factory() -> AcpProviderFactory { } #[test] +#[serial] fn test_custom_get_tools() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let mut conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -98,7 +118,9 @@ fn test_custom_get_tools() { } #[test] +#[serial] fn test_custom_get_extensions() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -124,109 +146,9 @@ fn test_custom_get_extensions() { } #[test] -fn test_new_session_passes_cwd_to_provider_factory() { - run_test(async move { - let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let cwd = tempfile::tempdir().unwrap(); - let expected_cwd = cwd.path().to_path_buf(); - let captured_cwds = Arc::new(Mutex::new(Vec::>::new())); - let factory_cwds = Arc::clone(&captured_cwds); - let provider_factory: AcpProviderFactory = Arc::new( - move |provider_name, model_config, _extensions, working_dir| { - factory_cwds.lock().unwrap().push(working_dir); - Box::pin(async move { - Ok(Arc::new(MockProvider { - name: provider_name, - model_config, - recommended_models: Vec::new(), - supported_models: Vec::new(), - }) as Arc) - }) - }, - ); - - let mut conn = AcpServerConnection::new( - TestConnectionConfig { - cwd: Some(cwd), - provider_factory: Some(provider_factory), - ..Default::default() - }, - openai, - ) - .await; - - conn.new_session().await.unwrap(); - - let captured_cwd = tokio::time::timeout(std::time::Duration::from_secs(1), async { - loop { - if let Some(cwd) = captured_cwds.lock().unwrap().first().cloned() { - break cwd; - } - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - }) - .await - .expect("provider factory was not called"); - - assert_eq!(captured_cwd, Some(expected_cwd)); - }); -} - -#[test] -fn test_load_session_passes_load_cwd_to_provider_factory() { - run_test(async move { - let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let initial_cwd = tempfile::tempdir().unwrap(); - let captured_cwds = Arc::new(Mutex::new(Vec::>::new())); - let factory_cwds = Arc::clone(&captured_cwds); - let provider_factory: AcpProviderFactory = Arc::new( - move |provider_name, model_config, _extensions, working_dir| { - factory_cwds.lock().unwrap().push(working_dir); - Box::pin(async move { - Ok(Arc::new(MockProvider { - name: provider_name, - model_config, - recommended_models: Vec::new(), - supported_models: Vec::new(), - }) as Arc) - }) - }, - ); - - let mut conn = AcpServerConnection::new( - TestConnectionConfig { - cwd: Some(initial_cwd), - provider_factory: Some(provider_factory), - ..Default::default() - }, - openai, - ) - .await; - - let SessionData { session, .. } = conn.new_session().await.unwrap(); - let session_id = session.session_id().0.to_string(); - let SessionData { - session: loaded, .. - } = conn.load_session(&session_id, vec![]).await.unwrap(); - let expected_cwd = loaded.work_dir(); - - let captured_cwd = tokio::time::timeout(std::time::Duration::from_secs(1), async { - loop { - if let Some(cwd) = captured_cwds.lock().unwrap().get(1).cloned() { - break cwd; - } - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - }) - .await - .expect("provider factory was not called for load session"); - - assert_eq!(captured_cwd, Some(expected_cwd)); - }); -} - -#[test] +#[serial] fn test_custom_list_builtin_skill_sources() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -260,7 +182,9 @@ fn test_custom_list_builtin_skill_sources() { } #[test] +#[serial] fn test_custom_provider_inventory_includes_metadata() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -291,19 +215,16 @@ fn test_custom_provider_inventory_includes_metadata() { } #[test] +#[serial] fn test_custom_preferences_read_save_remove() { - run_test(async { - let data_root = tempfile::tempdir().unwrap(); - std::fs::write( - data_root - .path() - .join(goose::config::base::CONFIG_YAML_NAME), - "GOOSE_MODEL: gpt-4o\nGOOSE_PROVIDER: openai\nGOOSE_AUTO_COMPACT_THRESHOLD: 0.7\nVOICE_AUTO_SUBMIT_PHRASES: send it\n", - ) - .unwrap(); + let config_dir = write_acp_global_config( + "GOOSE_MODEL: gpt-4o\nGOOSE_PROVIDER: openai\nGOOSE_AUTO_COMPACT_THRESHOLD: 0.7\nVOICE_AUTO_SUBMIT_PHRASES: send it\n", + ); + + run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let config = TestConnectionConfig { - data_root: data_root.path().to_path_buf(), + data_root: config_dir, ..Default::default() }; let conn = AcpServerConnection::new(config, openai).await; @@ -373,7 +294,9 @@ fn test_custom_preferences_read_save_remove() { } #[test] +#[serial] fn test_custom_preferences_save_rejects_invalid_values() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -433,17 +356,16 @@ fn test_custom_preferences_save_rejects_invalid_values() { } #[test] +#[serial] fn test_custom_defaults_read() { - run_test(async { - let data_root = tempfile::tempdir().unwrap(); - std::fs::write( - data_root.path().join(goose::config::base::CONFIG_YAML_NAME), - "GOOSE_MODEL: claude-3-5-haiku-latest\nGOOSE_PROVIDER: anthropic\n", - ) - .unwrap(); + let config_dir = write_acp_global_config( + "GOOSE_MODEL: claude-3-5-haiku-latest\nGOOSE_PROVIDER: anthropic\n", + ); + + run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let config = TestConnectionConfig { - data_root: data_root.path().to_path_buf(), + data_root: config_dir, ..Default::default() }; let conn = AcpServerConnection::new(config, openai).await; @@ -466,21 +388,15 @@ fn test_custom_defaults_read() { } #[test] +#[serial] fn test_custom_dictation_secret_save_delete() { - let root = tempfile::tempdir().unwrap(); - let root_path = root.path().to_string_lossy().to_string(); let _env = env_lock::lock_env([ - ("GOOSE_PATH_ROOT", Some(root_path.as_str())), ("GOOSE_DISABLE_KEYRING", Some("1")), ("GROQ_API_KEY", None::<&str>), ]); - let config_dir = goose::config::paths::Paths::config_dir(); - std::fs::create_dir_all(&config_dir).unwrap(); - std::fs::write( - config_dir.join(goose::config::base::CONFIG_YAML_NAME), + let config_dir = write_acp_global_config( "GOOSE_MODEL: gpt-4o\nGOOSE_PROVIDER: openai\nGOOSE_DISABLE_KEYRING: true\n", - ) - .unwrap(); + ); run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; @@ -570,7 +486,9 @@ fn test_custom_dictation_secret_save_delete() { } #[test] +#[serial] fn test_raw_config_and_secret_methods_are_removed() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -590,7 +508,9 @@ fn test_raw_config_and_secret_methods_are_removed() { } #[test] +#[serial] fn test_provider_switching_updates_session_state() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let config = TestConnectionConfig { @@ -618,7 +538,9 @@ fn test_provider_switching_updates_session_state() { } #[test] +#[serial] fn test_custom_unknown_method() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; @@ -629,6 +551,7 @@ fn test_custom_unknown_method() { } #[test] +#[serial] fn test_developer_fs_requests_use_acp_session_id() { run_test(async { let seen_session_id = Arc::new(Mutex::new(None::)); @@ -648,9 +571,14 @@ fn test_developer_fs_requests_use_acp_session_id() { Arc::new(IgnoreSessionId), ) .await; + let config_dir = write_acp_global_config(&format!( + "GOOSE_MODEL: gpt-4.1\nGOOSE_PROVIDER: openai\nOPENAI_HOST: {}\n", + openai.uri() + )); let config = TestConnectionConfig { // gpt-5-nano routes to the Responses API; use a Chat Completions // model so the canned SSE fixtures are parsed correctly. + data_root: config_dir, current_model: "gpt-4.1".to_string(), read_text_file: Some(Arc::new(move |req| { *seen_session_id_clone.lock().unwrap() = Some(req.session_id.0.to_string()); @@ -683,7 +611,9 @@ fn test_developer_fs_requests_use_acp_session_id() { } #[test] +#[serial] fn test_custom_provider_supported_models_lists_raw_provider_models() { + write_acp_global_config(DEFAULT_ACP_TEST_CONFIG); run_test(async move { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; let provider_factory: AcpProviderFactory = diff --git a/crates/goose/tests/acp_fixtures/mod.rs b/crates/goose/tests/acp_fixtures/mod.rs index 6ba0b05e1ad7..55370be5dbd0 100644 --- a/crates/goose/tests/acp_fixtures/mod.rs +++ b/crates/goose/tests/acp_fixtures/mod.rs @@ -23,13 +23,31 @@ use goose::session_context::SESSION_ID_HEADER; use goose_test_support::{ExpectedSessionId, TEST_MODEL}; use std::collections::VecDeque; use std::future::Future; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock, Mutex}; use tokio::task::JoinHandle; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; +static ACP_TEST_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); +static ACP_CONFIG_ROOT: LazyLock = + LazyLock::new(|| tempfile::tempdir().unwrap()); + +fn write_global_test_config(config_path: &Path, openai_base_url: &str) { + let contents = fs::read_to_string(config_path).unwrap(); + let mut config: serde_yaml::Mapping = serde_yaml::from_str(&contents).unwrap(); + config.insert( + serde_yaml::Value::String("OPENAI_HOST".to_string()), + serde_yaml::Value::String(openai_base_url.to_string()), + ); + + let global_config_dir = Paths::config_dir(); + fs::create_dir_all(&global_config_dir).unwrap(); + let global_config_path = global_config_dir.join(goose::config::base::CONFIG_YAML_NAME); + fs::write(&global_config_path, serde_yaml::to_string(&config).unwrap()).unwrap(); +} + pub struct OpenAiFixture { _server: MockServer, base_url: String, @@ -167,10 +185,14 @@ pub async fn spawn_acp_server_in_process( if !config_path.exists() { fs::write( &config_path, - format!("GOOSE_MODEL: {current_model}\nGOOSE_PROVIDER: openai\n"), + format!( + "GOOSE_MODEL: {current_model}\nGOOSE_PROVIDER: openai\nGOOSE_MODE: {}\n", + goose_mode + ), ) .unwrap(); } + write_global_test_config(&config_path, openai_base_url); let provider_factory = provider_factory.unwrap_or_else(|| { let base_url = openai_base_url.to_string(); Arc::new( @@ -195,7 +217,6 @@ pub async fn spawn_acp_server_in_process( builtins: builtins.to_vec(), data_dir: data_root.to_path_buf(), config_dir: data_root.to_path_buf(), - goose_mode, disable_session_naming, goose_platform: GoosePlatform::GooseCli, additional_source_roots: Vec::new(), @@ -585,6 +606,10 @@ pub fn run_test(fut: F) where F: Future + Send + 'static, { + let _guard = ACP_TEST_LOCK.lock().unwrap_or_else(|err| err.into_inner()); + if std::env::var_os("GOOSE_PATH_ROOT").is_none() { + std::env::set_var("GOOSE_PATH_ROOT", ACP_CONFIG_ROOT.path()); + } register_builtin_extensions(goose_mcp::BUILTIN_EXTENSIONS.clone()); let handle = std::thread::Builder::new() diff --git a/crates/goose/tests/acp_server_test.rs b/crates/goose/tests/acp_server_test.rs index 6faca600bf8f..22cbfa32d76d 100644 --- a/crates/goose/tests/acp_server_test.rs +++ b/crates/goose/tests/acp_server_test.rs @@ -14,9 +14,10 @@ use common_tests::{ run_load_mode, run_load_model, run_load_session_error, run_load_session_mcp, run_load_session_replays_image_attachment, run_mode_set, run_model_list, run_model_set, run_model_set_error_session_not_found, run_new_session_returns_initial_config, - run_permission_persistence, run_prompt_basic, run_prompt_error, run_prompt_image, - run_prompt_image_attachment, run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill, - run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true, + run_new_session_uses_current_config_mode, run_permission_persistence, run_prompt_basic, + run_prompt_error, run_prompt_image, run_prompt_image_attachment, run_prompt_mcp, + run_prompt_model_mismatch, run_prompt_skill, run_session_name_update_notification, + run_shell_terminal_false, run_shell_terminal_true, }; use goose::config::GooseMode; use goose::conversation::message::Message; @@ -240,6 +241,11 @@ fn test_new_session_returns_initial_config() { run_test(async { run_new_session_returns_initial_config::().await }); } +#[test] +fn test_new_session_uses_current_config_mode() { + run_test(async { run_new_session_uses_current_config_mode::().await }); +} + #[test] fn test_model_set() { run_test(async { run_model_set::().await }); diff --git a/documentation/static/oauth/huggingface-client-metadata.json b/documentation/static/oauth/huggingface-client-metadata.json new file mode 100644 index 000000000000..355b17847a95 --- /dev/null +++ b/documentation/static/oauth/huggingface-client-metadata.json @@ -0,0 +1,11 @@ +{ + "client_id": "https://goose-docs.ai/oauth/huggingface-client-metadata.json", + "client_name": "goose", + "redirect_uris": [ + "http://127.0.0.1:17863/oauth/huggingface/callback" + ], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + "code_challenge_methods_supported": ["S256"] +} diff --git a/scripts/pre-release.sh b/scripts/pre-release.sh index eed42d5464c9..d6d8f1bd6a20 100755 --- a/scripts/pre-release.sh +++ b/scripts/pre-release.sh @@ -14,7 +14,10 @@ else SEARCH="chore(release): release version" fi -PR=$(gh pr list --repo "$REPO" --search "$SEARCH in:title" --state all --limit 1 --json number,title) +# Wrap the phrase in quotes so GitHub treats it literally. Without quotes the +# parentheses in "release):" are interpreted as search operators and the query +# matches nothing. +PR=$(gh pr list --repo "$REPO" --search "\"$SEARCH\" in:title" --state all --limit 1 --json number,title) PR_NUMBER=$(echo "$PR" | jq -r '.[0].number // empty') if [[ -z "$PR_NUMBER" ]]; then diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index c1db15171d01..3bd3b51992e8 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -1312,6 +1312,66 @@ } } }, + "/config/provider-secrets": { + "get": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "list_provider_secrets", + "responses": { + "200": { + "description": "Provider secrets retrieved successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderSecretsResponse" + } + } + } + }, + "500": { + "description": "Internal server error" + } + } + } + }, + "/config/provider-secrets/{id}": { + "delete": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "delete_provider_secret", + "parameters": [ + { + "name": "id", + "in": "path", + "description": "Provider secret identifier", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Provider secret deleted successfully", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + }, + "400": { + "description": "Invalid provider secret identifier" + }, + "500": { + "description": "Internal server error" + } + } + } + }, "/config/providers": { "get": { "tags": [ @@ -7154,6 +7214,91 @@ } } }, + "ProviderSecret": { + "type": "object", + "required": [ + "id", + "provider", + "provider_display_name", + "name", + "storage", + "status", + "configured", + "has_secret", + "can_delete", + "can_configure" + ], + "properties": { + "can_configure": { + "type": "boolean" + }, + "can_delete": { + "type": "boolean" + }, + "configure_provider": { + "type": "string", + "nullable": true + }, + "configured": { + "type": "boolean" + }, + "expires_at": { + "type": "string", + "format": "date-time", + "nullable": true + }, + "has_secret": { + "type": "boolean" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "provider_display_name": { + "type": "string" + }, + "status": { + "$ref": "#/components/schemas/ProviderSecretStatus" + }, + "storage": { + "$ref": "#/components/schemas/ProviderSecretStorage" + } + } + }, + "ProviderSecretStatus": { + "type": "string", + "enum": [ + "valid", + "expired", + "unknown" + ] + }, + "ProviderSecretStorage": { + "type": "string", + "enum": [ + "secret_store", + "provider_cache" + ] + }, + "ProviderSecretsResponse": { + "type": "object", + "required": [ + "secrets" + ], + "properties": { + "secrets": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ProviderSecret" + } + } + } + }, "ProviderTemplate": { "type": "object", "required": [ diff --git a/ui/desktop/src/api/index.ts b/ui/desktop/src/api/index.ts index 871f52ded6bf..da395910541d 100644 --- a/ui/desktop/src/api/index.ts +++ b/ui/desktop/src/api/index.ts @@ -1,4 +1,4 @@ // This file is auto-generated by @hey-api/openapi-ts -export { addExtension, agentAddExtension, agentRemoveExtension, callTool, cancelDownload, cancelLocalModelDownload, checkProvider, cleanupProviderCache, configureProviderOauth, confirmToolAction, createCustomProvider, createRecipe, createSchedule, decodeRecipe, deleteLocalModel, deleteModel, deleteRecipe, deleteSchedule, deleteSession, diagnostics, downloadHfModel, downloadModel, encodeRecipe, exportApp, exportSession, forkSession, getCanonicalModelInfo, getCustomProvider, getDictationConfig, getDownloadProgress, getExtensions, getFeatures, getLocalModelDownloadProgress, getModelSettings, getPrompt, getPrompts, getProviderCatalog, getProviderCatalogTemplate, getProviderModelInfo, getProviderModels, getRepoFiles, getSession, getSessionExtensions, getSessionInsights, getSlashCommands, getTools, getTunnelStatus, importApp, importSession, importSessionNostr, inspectRunningJob, killRunningJob, listApps, listBuiltinChatTemplates, listLocalModels, listModels, listRecipes, listSchedules, listSessions, mcpUiProxy, type Options, parseRecipe, pauseSchedule, providers, readAllConfig, readConfig, readResource, recipeToYaml, removeConfig, removeCustomProvider, removeExtension, reply, resetPrompt, restartAgent, resumeAgent, runNowHandler, savePrompt, saveRecipe, scanRecipe, scheduleRecipe, searchHfModels, searchSessions, sendTelemetryEvent, sessionCancel, sessionEvents, sessionReply, sessionsHandler, setConfigProvider, setRecipeSlashCommand, shareSessionNostr, startAgent, startNanogptSetup, startOpenrouterSetup, startTetrateSetup, startTunnel, status, stopAgent, stopTunnel, syncFeaturedModels, systemInfo, transcribeDictation, unpauseSchedule, updateAgentProvider, updateCustomProvider, updateFromSession, updateModelSettings, updateSchedule, updateSession, updateSessionName, updateSessionUserRecipeValues, updateWorkingDir, upsertConfig, upsertPermissions, validateConfig } from './sdk.gen'; -export type { ActionRequired, ActionRequiredData, AddExtensionData, AddExtensionErrors, AddExtensionRequest, AddExtensionResponse, AddExtensionResponses, AgentAddExtensionData, AgentAddExtensionErrors, AgentAddExtensionResponse, AgentAddExtensionResponses, AgentRemoveExtensionData, AgentRemoveExtensionErrors, AgentRemoveExtensionResponse, AgentRemoveExtensionResponses, Annotations, Author, AuthorRequest, CallToolData, CallToolError, CallToolErrors, CallToolRequest, CallToolResponse, CallToolResponse2, CallToolResponses, CancelDownloadData, CancelDownloadErrors, CancelDownloadResponses, CancelLocalModelDownloadData, CancelLocalModelDownloadErrors, CancelLocalModelDownloadResponses, CancelRequest, ChatRequest, ChatTemplate, CheckProviderData, CheckProviderRequest, CleanupProviderCacheData, CleanupProviderCacheErrors, CleanupProviderCacheResponse, CleanupProviderCacheResponses, ClientOptions, CommandType, ConfigKey, ConfigKeyQuery, ConfigResponse, ConfigureProviderOauthData, ConfigureProviderOauthErrors, ConfigureProviderOauthResponses, ConfirmToolActionData, ConfirmToolActionErrors, ConfirmToolActionRequest, ConfirmToolActionResponses, Content, ContentBlock, Conversation, CreateCustomProviderData, CreateCustomProviderErrors, CreateCustomProviderResponse, CreateCustomProviderResponse2, CreateCustomProviderResponses, CreateRecipeData, CreateRecipeErrors, CreateRecipeRequest, CreateRecipeResponse, CreateRecipeResponse2, CreateRecipeResponses, CreateScheduleData, CreateScheduleErrors, CreateScheduleRequest, CreateScheduleResponse, CreateScheduleResponses, CspMetadata, DeclarativeProviderConfig, DecodeRecipeData, DecodeRecipeErrors, DecodeRecipeRequest, DecodeRecipeResponse, DecodeRecipeResponse2, DecodeRecipeResponses, DeleteLocalModelData, DeleteLocalModelErrors, DeleteLocalModelResponses, DeleteModelData, DeleteModelErrors, DeleteModelResponses, DeleteRecipeData, DeleteRecipeErrors, DeleteRecipeRequest, DeleteRecipeResponse, DeleteRecipeResponses, DeleteScheduleData, DeleteScheduleErrors, DeleteScheduleResponse, DeleteScheduleResponses, DeleteSessionData, DeleteSessionErrors, DeleteSessionResponses, DiagnosticsData, DiagnosticsErrors, DiagnosticsResponse, DiagnosticsResponses, DictationProvider, DictationProviderStatus, DownloadHfModelData, DownloadHfModelErrors, DownloadHfModelResponse, DownloadHfModelResponses, DownloadModelData, DownloadModelErrors, DownloadModelRequest, DownloadModelResponses, DownloadProgress, DownloadStatus, EmbeddedResource, EncodeRecipeData, EncodeRecipeErrors, EncodeRecipeRequest, EncodeRecipeResponse, EncodeRecipeResponse2, EncodeRecipeResponses, Envs, EnvVarConfig, ErrorResponse, ExportAppData, ExportAppError, ExportAppErrors, ExportAppResponse, ExportAppResponses, ExportSessionData, ExportSessionErrors, ExportSessionResponse, ExportSessionResponses, ExtensionConfig, ExtensionData, ExtensionEntry, ExtensionLoadResult, ExtensionQuery, ExtensionResponse, FeaturesResponse, ForkRequest, ForkResponse, ForkSessionData, ForkSessionErrors, ForkSessionResponse, ForkSessionResponses, FrontendToolRequest, GetCanonicalModelInfoData, GetCanonicalModelInfoResponse, GetCanonicalModelInfoResponses, GetCustomProviderData, GetCustomProviderErrors, GetCustomProviderResponse, GetCustomProviderResponses, GetDictationConfigData, GetDictationConfigResponse, GetDictationConfigResponses, GetDownloadProgressData, GetDownloadProgressErrors, GetDownloadProgressResponse, GetDownloadProgressResponses, GetExtensionsData, GetExtensionsErrors, GetExtensionsResponse, GetExtensionsResponses, GetFeaturesData, GetFeaturesResponse, GetFeaturesResponses, GetLocalModelDownloadProgressData, GetLocalModelDownloadProgressErrors, GetLocalModelDownloadProgressResponse, GetLocalModelDownloadProgressResponses, GetModelSettingsData, GetModelSettingsErrors, GetModelSettingsResponse, GetModelSettingsResponses, GetPromptData, GetPromptErrors, GetPromptResponse, GetPromptResponses, GetPromptsData, GetPromptsResponse, GetPromptsResponses, GetProviderCatalogData, GetProviderCatalogErrors, GetProviderCatalogResponse, GetProviderCatalogResponses, GetProviderCatalogTemplateData, GetProviderCatalogTemplateErrors, GetProviderCatalogTemplateResponse, GetProviderCatalogTemplateResponses, GetProviderModelInfoData, GetProviderModelInfoErrors, GetProviderModelInfoResponse, GetProviderModelInfoResponses, GetProviderModelsData, GetProviderModelsErrors, GetProviderModelsResponse, GetProviderModelsResponses, GetRepoFilesData, GetRepoFilesResponse, GetRepoFilesResponses, GetSessionData, GetSessionErrors, GetSessionExtensionsData, GetSessionExtensionsErrors, GetSessionExtensionsResponse, GetSessionExtensionsResponses, GetSessionInsightsData, GetSessionInsightsErrors, GetSessionInsightsResponse, GetSessionInsightsResponses, GetSessionResponse, GetSessionResponses, GetSlashCommandsData, GetSlashCommandsResponse, GetSlashCommandsResponses, GetToolsData, GetToolsErrors, GetToolsQuery, GetToolsResponse, GetToolsResponses, GetTunnelStatusData, GetTunnelStatusResponse, GetTunnelStatusResponses, GooseApp, GooseMode, HfGgufFile, HfModelInfo, HfQuantVariant, Icon, IconTheme, ImageContent, ImportAppData, ImportAppError, ImportAppErrors, ImportAppRequest, ImportAppResponse, ImportAppResponse2, ImportAppResponses, ImportSessionData, ImportSessionErrors, ImportSessionNostrData, ImportSessionNostrErrors, ImportSessionNostrRequest, ImportSessionNostrResponse, ImportSessionNostrResponses, ImportSessionRequest, ImportSessionResponse, ImportSessionResponses, InferenceMetadata, InspectJobResponse, InspectRunningJobData, InspectRunningJobErrors, InspectRunningJobResponse, InspectRunningJobResponses, JsonObject, KillJobResponse, KillRunningJobData, KillRunningJobResponses, ListAppsData, ListAppsError, ListAppsErrors, ListAppsRequest, ListAppsResponse, ListAppsResponse2, ListAppsResponses, ListBuiltinChatTemplatesData, ListBuiltinChatTemplatesResponse, ListBuiltinChatTemplatesResponses, ListLocalModelsData, ListLocalModelsResponse, ListLocalModelsResponses, ListModelsData, ListModelsResponse, ListModelsResponses, ListRecipeResponse, ListRecipesData, ListRecipesErrors, ListRecipesResponse, ListRecipesResponses, ListSchedulesData, ListSchedulesErrors, ListSchedulesResponse, ListSchedulesResponse2, ListSchedulesResponses, ListSessionsData, ListSessionsErrors, ListSessionsResponse, ListSessionsResponses, LoadedProvider, LocalModelResponse, McpAppResource, McpUiProxyData, McpUiProxyErrors, McpUiProxyResponses, Message, MessageContent, MessageEvent, MessageMetadata, ModelCapabilities, ModelConfig, ModelDownloadStatus, ModelInfo, ModelInfoData, ModelInfoQuery, ModelInfoResponse, ModelSettings, ModelTemplate, ParseRecipeData, ParseRecipeError, ParseRecipeErrors, ParseRecipeRequest, ParseRecipeResponse, ParseRecipeResponse2, ParseRecipeResponses, PauseScheduleData, PauseScheduleErrors, PauseScheduleResponse, PauseScheduleResponses, Permission, PermissionLevel, PermissionsMetadata, PrincipalType, PromptContentResponse, PromptsListResponse, ProviderCatalogEntry, ProviderDetails, ProviderEngine, ProviderMetadata, ProviderModelInfoQuery, ProvidersData, ProvidersResponse, ProvidersResponse2, ProvidersResponses, ProviderTemplate, ProviderType, RawAudioContent, RawEmbeddedResource, RawImageContent, RawResource, RawTextContent, ReadAllConfigData, ReadAllConfigResponse, ReadAllConfigResponses, ReadConfigData, ReadConfigErrors, ReadConfigResponses, ReadResourceData, ReadResourceErrors, ReadResourceRequest, ReadResourceResponse, ReadResourceResponse2, ReadResourceResponses, Recipe, RecipeManifest, RecipeParameter, RecipeParameterInputType, RecipeParameterRequirement, RecipeToYamlData, RecipeToYamlError, RecipeToYamlErrors, RecipeToYamlRequest, RecipeToYamlResponse, RecipeToYamlResponse2, RecipeToYamlResponses, RedactedThinkingContent, RemoveConfigData, RemoveConfigErrors, RemoveConfigResponse, RemoveConfigResponses, RemoveCustomProviderData, RemoveCustomProviderErrors, RemoveCustomProviderResponse, RemoveCustomProviderResponses, RemoveExtensionData, RemoveExtensionErrors, RemoveExtensionRequest, RemoveExtensionResponse, RemoveExtensionResponses, ReplyData, ReplyErrors, ReplyResponse, ReplyResponses, RepoVariantsResponse, ResetPromptData, ResetPromptErrors, ResetPromptResponse, ResetPromptResponses, ResourceContents, ResourceMetadata, Response, RestartAgentData, RestartAgentErrors, RestartAgentRequest, RestartAgentResponse, RestartAgentResponse2, RestartAgentResponses, ResumeAgentData, ResumeAgentErrors, ResumeAgentRequest, ResumeAgentResponse, ResumeAgentResponse2, ResumeAgentResponses, RetryConfig, Role, RunNowHandlerData, RunNowHandlerErrors, RunNowHandlerResponse, RunNowHandlerResponses, RunNowResponse, SamplingConfig, SavePromptData, SavePromptErrors, SavePromptRequest, SavePromptResponse, SavePromptResponses, SaveRecipeData, SaveRecipeError, SaveRecipeErrors, SaveRecipeRequest, SaveRecipeResponse, SaveRecipeResponse2, SaveRecipeResponses, ScanRecipeData, ScanRecipeRequest, ScanRecipeResponse, ScanRecipeResponse2, ScanRecipeResponses, ScheduledJob, ScheduleRecipeData, ScheduleRecipeErrors, ScheduleRecipeRequest, ScheduleRecipeResponses, SearchHfModelsData, SearchHfModelsErrors, SearchHfModelsResponse, SearchHfModelsResponses, SearchSessionsData, SearchSessionsErrors, SearchSessionsResponse, SearchSessionsResponses, SendTelemetryEventData, SendTelemetryEventResponses, Session, SessionCancelData, SessionCancelResponses, SessionDisplayInfo, SessionEventsData, SessionEventsErrors, SessionEventsResponse, SessionEventsResponses, SessionExtensionsResponse, SessionInsights, SessionListResponse, SessionReplyData, SessionReplyErrors, SessionReplyRequest, SessionReplyResponse, SessionReplyResponse2, SessionReplyResponses, SessionsHandlerData, SessionsHandlerErrors, SessionsHandlerResponse, SessionsHandlerResponses, SessionsQuery, SessionType, SetConfigProviderData, SetProviderRequest, SetRecipeSlashCommandData, SetRecipeSlashCommandErrors, SetRecipeSlashCommandResponses, SetSlashCommandRequest, Settings, SetupResponse, ShareSessionNostrData, ShareSessionNostrErrors, ShareSessionNostrRequest, ShareSessionNostrResponse, ShareSessionNostrResponse2, ShareSessionNostrResponses, SlashCommand, SlashCommandsResponse, StartAgentData, StartAgentError, StartAgentErrors, StartAgentRequest, StartAgentResponse, StartAgentResponses, StartNanogptSetupData, StartNanogptSetupResponse, StartNanogptSetupResponses, StartOpenrouterSetupData, StartOpenrouterSetupResponse, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponse, StartTetrateSetupResponses, StartTunnelData, StartTunnelError, StartTunnelErrors, StartTunnelResponse, StartTunnelResponses, StatusData, StatusResponse, StatusResponses, StopAgentData, StopAgentErrors, StopAgentRequest, StopAgentResponse, StopAgentResponses, StopTunnelData, StopTunnelError, StopTunnelErrors, StopTunnelResponses, SubRecipe, SuccessCheck, SyncFeaturedModelsData, SyncFeaturedModelsResponses, SystemInfo, SystemInfoData, SystemInfoResponse, SystemInfoResponses, SystemNotificationContent, SystemNotificationType, TaskSupport, TelemetryEventRequest, Template, TextContent, ThinkingContent, ThinkingEffort, TokenState, Tool, ToolAnnotations, ToolCallingMode, ToolConfirmationRequest, ToolExecution, ToolInfo, ToolPermission, ToolRequest, ToolResponse, TranscribeDictationData, TranscribeDictationErrors, TranscribeDictationResponse, TranscribeDictationResponses, TranscribeRequest, TranscribeResponse, TunnelInfo, TunnelState, UiMetadata, UnpauseScheduleData, UnpauseScheduleErrors, UnpauseScheduleResponse, UnpauseScheduleResponses, UpdateAgentProviderData, UpdateAgentProviderErrors, UpdateAgentProviderResponses, UpdateCustomProviderData, UpdateCustomProviderErrors, UpdateCustomProviderRequest, UpdateCustomProviderResponse, UpdateCustomProviderResponses, UpdateFromSessionData, UpdateFromSessionErrors, UpdateFromSessionRequest, UpdateFromSessionResponses, UpdateModelSettingsData, UpdateModelSettingsErrors, UpdateModelSettingsResponse, UpdateModelSettingsResponses, UpdateProviderRequest, UpdateScheduleData, UpdateScheduleErrors, UpdateScheduleRequest, UpdateScheduleResponse, UpdateScheduleResponses, UpdateSessionData, UpdateSessionErrors, UpdateSessionNameData, UpdateSessionNameErrors, UpdateSessionNameRequest, UpdateSessionNameResponses, UpdateSessionRequest, UpdateSessionResponses, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesError, UpdateSessionUserRecipeValuesErrors, UpdateSessionUserRecipeValuesRequest, UpdateSessionUserRecipeValuesResponse, UpdateSessionUserRecipeValuesResponse2, UpdateSessionUserRecipeValuesResponses, UpdateWorkingDirData, UpdateWorkingDirErrors, UpdateWorkingDirRequest, UpdateWorkingDirResponses, UpsertConfigData, UpsertConfigErrors, UpsertConfigQuery, UpsertConfigResponse, UpsertConfigResponses, UpsertPermissionsData, UpsertPermissionsErrors, UpsertPermissionsQuery, UpsertPermissionsResponse, UpsertPermissionsResponses, ValidateConfigData, ValidateConfigErrors, ValidateConfigResponse, ValidateConfigResponses, WhisperModelResponse, WindowProps } from './types.gen'; +export { addExtension, agentAddExtension, agentRemoveExtension, callTool, cancelDownload, cancelLocalModelDownload, checkProvider, cleanupProviderCache, configureProviderOauth, confirmToolAction, createCustomProvider, createRecipe, createSchedule, decodeRecipe, deleteLocalModel, deleteModel, deleteProviderSecret, deleteRecipe, deleteSchedule, deleteSession, diagnostics, downloadHfModel, downloadModel, encodeRecipe, exportApp, exportSession, forkSession, getCanonicalModelInfo, getCustomProvider, getDictationConfig, getDownloadProgress, getExtensions, getFeatures, getLocalModelDownloadProgress, getModelSettings, getPrompt, getPrompts, getProviderCatalog, getProviderCatalogTemplate, getProviderModelInfo, getProviderModels, getRepoFiles, getSession, getSessionExtensions, getSessionInsights, getSlashCommands, getTools, getTunnelStatus, importApp, importSession, importSessionNostr, inspectRunningJob, killRunningJob, listApps, listBuiltinChatTemplates, listLocalModels, listModels, listProviderSecrets, listRecipes, listSchedules, listSessions, mcpUiProxy, type Options, parseRecipe, pauseSchedule, providers, readAllConfig, readConfig, readResource, recipeToYaml, removeConfig, removeCustomProvider, removeExtension, reply, resetPrompt, restartAgent, resumeAgent, runNowHandler, savePrompt, saveRecipe, scanRecipe, scheduleRecipe, searchHfModels, searchSessions, sendTelemetryEvent, sessionCancel, sessionEvents, sessionReply, sessionsHandler, setConfigProvider, setRecipeSlashCommand, shareSessionNostr, startAgent, startNanogptSetup, startOpenrouterSetup, startTetrateSetup, startTunnel, status, stopAgent, stopTunnel, syncFeaturedModels, systemInfo, transcribeDictation, unpauseSchedule, updateAgentProvider, updateCustomProvider, updateFromSession, updateModelSettings, updateSchedule, updateSession, updateSessionName, updateSessionUserRecipeValues, updateWorkingDir, upsertConfig, upsertPermissions, validateConfig } from './sdk.gen'; +export type { ActionRequired, ActionRequiredData, AddExtensionData, AddExtensionErrors, AddExtensionRequest, AddExtensionResponse, AddExtensionResponses, AgentAddExtensionData, AgentAddExtensionErrors, AgentAddExtensionResponse, AgentAddExtensionResponses, AgentRemoveExtensionData, AgentRemoveExtensionErrors, AgentRemoveExtensionResponse, AgentRemoveExtensionResponses, Annotations, Author, AuthorRequest, CallToolData, CallToolError, CallToolErrors, CallToolRequest, CallToolResponse, CallToolResponse2, CallToolResponses, CancelDownloadData, CancelDownloadErrors, CancelDownloadResponses, CancelLocalModelDownloadData, CancelLocalModelDownloadErrors, CancelLocalModelDownloadResponses, CancelRequest, ChatRequest, ChatTemplate, CheckProviderData, CheckProviderRequest, CleanupProviderCacheData, CleanupProviderCacheErrors, CleanupProviderCacheResponse, CleanupProviderCacheResponses, ClientOptions, CommandType, ConfigKey, ConfigKeyQuery, ConfigResponse, ConfigureProviderOauthData, ConfigureProviderOauthErrors, ConfigureProviderOauthResponses, ConfirmToolActionData, ConfirmToolActionErrors, ConfirmToolActionRequest, ConfirmToolActionResponses, Content, ContentBlock, Conversation, CreateCustomProviderData, CreateCustomProviderErrors, CreateCustomProviderResponse, CreateCustomProviderResponse2, CreateCustomProviderResponses, CreateRecipeData, CreateRecipeErrors, CreateRecipeRequest, CreateRecipeResponse, CreateRecipeResponse2, CreateRecipeResponses, CreateScheduleData, CreateScheduleErrors, CreateScheduleRequest, CreateScheduleResponse, CreateScheduleResponses, CspMetadata, DeclarativeProviderConfig, DecodeRecipeData, DecodeRecipeErrors, DecodeRecipeRequest, DecodeRecipeResponse, DecodeRecipeResponse2, DecodeRecipeResponses, DeleteLocalModelData, DeleteLocalModelErrors, DeleteLocalModelResponses, DeleteModelData, DeleteModelErrors, DeleteModelResponses, DeleteProviderSecretData, DeleteProviderSecretErrors, DeleteProviderSecretResponse, DeleteProviderSecretResponses, DeleteRecipeData, DeleteRecipeErrors, DeleteRecipeRequest, DeleteRecipeResponse, DeleteRecipeResponses, DeleteScheduleData, DeleteScheduleErrors, DeleteScheduleResponse, DeleteScheduleResponses, DeleteSessionData, DeleteSessionErrors, DeleteSessionResponses, DiagnosticsData, DiagnosticsErrors, DiagnosticsResponse, DiagnosticsResponses, DictationProvider, DictationProviderStatus, DownloadHfModelData, DownloadHfModelErrors, DownloadHfModelResponse, DownloadHfModelResponses, DownloadModelData, DownloadModelErrors, DownloadModelRequest, DownloadModelResponses, DownloadProgress, DownloadStatus, EmbeddedResource, EncodeRecipeData, EncodeRecipeErrors, EncodeRecipeRequest, EncodeRecipeResponse, EncodeRecipeResponse2, EncodeRecipeResponses, Envs, EnvVarConfig, ErrorResponse, ExportAppData, ExportAppError, ExportAppErrors, ExportAppResponse, ExportAppResponses, ExportSessionData, ExportSessionErrors, ExportSessionResponse, ExportSessionResponses, ExtensionConfig, ExtensionData, ExtensionEntry, ExtensionLoadResult, ExtensionQuery, ExtensionResponse, FeaturesResponse, ForkRequest, ForkResponse, ForkSessionData, ForkSessionErrors, ForkSessionResponse, ForkSessionResponses, FrontendToolRequest, GetCanonicalModelInfoData, GetCanonicalModelInfoResponse, GetCanonicalModelInfoResponses, GetCustomProviderData, GetCustomProviderErrors, GetCustomProviderResponse, GetCustomProviderResponses, GetDictationConfigData, GetDictationConfigResponse, GetDictationConfigResponses, GetDownloadProgressData, GetDownloadProgressErrors, GetDownloadProgressResponse, GetDownloadProgressResponses, GetExtensionsData, GetExtensionsErrors, GetExtensionsResponse, GetExtensionsResponses, GetFeaturesData, GetFeaturesResponse, GetFeaturesResponses, GetLocalModelDownloadProgressData, GetLocalModelDownloadProgressErrors, GetLocalModelDownloadProgressResponse, GetLocalModelDownloadProgressResponses, GetModelSettingsData, GetModelSettingsErrors, GetModelSettingsResponse, GetModelSettingsResponses, GetPromptData, GetPromptErrors, GetPromptResponse, GetPromptResponses, GetPromptsData, GetPromptsResponse, GetPromptsResponses, GetProviderCatalogData, GetProviderCatalogErrors, GetProviderCatalogResponse, GetProviderCatalogResponses, GetProviderCatalogTemplateData, GetProviderCatalogTemplateErrors, GetProviderCatalogTemplateResponse, GetProviderCatalogTemplateResponses, GetProviderModelInfoData, GetProviderModelInfoErrors, GetProviderModelInfoResponse, GetProviderModelInfoResponses, GetProviderModelsData, GetProviderModelsErrors, GetProviderModelsResponse, GetProviderModelsResponses, GetRepoFilesData, GetRepoFilesResponse, GetRepoFilesResponses, GetSessionData, GetSessionErrors, GetSessionExtensionsData, GetSessionExtensionsErrors, GetSessionExtensionsResponse, GetSessionExtensionsResponses, GetSessionInsightsData, GetSessionInsightsErrors, GetSessionInsightsResponse, GetSessionInsightsResponses, GetSessionResponse, GetSessionResponses, GetSlashCommandsData, GetSlashCommandsResponse, GetSlashCommandsResponses, GetToolsData, GetToolsErrors, GetToolsQuery, GetToolsResponse, GetToolsResponses, GetTunnelStatusData, GetTunnelStatusResponse, GetTunnelStatusResponses, GooseApp, GooseMode, HfGgufFile, HfModelInfo, HfQuantVariant, Icon, IconTheme, ImageContent, ImportAppData, ImportAppError, ImportAppErrors, ImportAppRequest, ImportAppResponse, ImportAppResponse2, ImportAppResponses, ImportSessionData, ImportSessionErrors, ImportSessionNostrData, ImportSessionNostrErrors, ImportSessionNostrRequest, ImportSessionNostrResponse, ImportSessionNostrResponses, ImportSessionRequest, ImportSessionResponse, ImportSessionResponses, InferenceMetadata, InspectJobResponse, InspectRunningJobData, InspectRunningJobErrors, InspectRunningJobResponse, InspectRunningJobResponses, JsonObject, KillJobResponse, KillRunningJobData, KillRunningJobResponses, ListAppsData, ListAppsError, ListAppsErrors, ListAppsRequest, ListAppsResponse, ListAppsResponse2, ListAppsResponses, ListBuiltinChatTemplatesData, ListBuiltinChatTemplatesResponse, ListBuiltinChatTemplatesResponses, ListLocalModelsData, ListLocalModelsResponse, ListLocalModelsResponses, ListModelsData, ListModelsResponse, ListModelsResponses, ListProviderSecretsData, ListProviderSecretsErrors, ListProviderSecretsResponse, ListProviderSecretsResponses, ListRecipeResponse, ListRecipesData, ListRecipesErrors, ListRecipesResponse, ListRecipesResponses, ListSchedulesData, ListSchedulesErrors, ListSchedulesResponse, ListSchedulesResponse2, ListSchedulesResponses, ListSessionsData, ListSessionsErrors, ListSessionsResponse, ListSessionsResponses, LoadedProvider, LocalModelResponse, McpAppResource, McpUiProxyData, McpUiProxyErrors, McpUiProxyResponses, Message, MessageContent, MessageEvent, MessageMetadata, ModelCapabilities, ModelConfig, ModelDownloadStatus, ModelInfo, ModelInfoData, ModelInfoQuery, ModelInfoResponse, ModelSettings, ModelTemplate, ParseRecipeData, ParseRecipeError, ParseRecipeErrors, ParseRecipeRequest, ParseRecipeResponse, ParseRecipeResponse2, ParseRecipeResponses, PauseScheduleData, PauseScheduleErrors, PauseScheduleResponse, PauseScheduleResponses, Permission, PermissionLevel, PermissionsMetadata, PrincipalType, PromptContentResponse, PromptsListResponse, ProviderCatalogEntry, ProviderDetails, ProviderEngine, ProviderMetadata, ProviderModelInfoQuery, ProvidersData, ProviderSecret, ProviderSecretsResponse, ProviderSecretStatus, ProviderSecretStorage, ProvidersResponse, ProvidersResponse2, ProvidersResponses, ProviderTemplate, ProviderType, RawAudioContent, RawEmbeddedResource, RawImageContent, RawResource, RawTextContent, ReadAllConfigData, ReadAllConfigResponse, ReadAllConfigResponses, ReadConfigData, ReadConfigErrors, ReadConfigResponses, ReadResourceData, ReadResourceErrors, ReadResourceRequest, ReadResourceResponse, ReadResourceResponse2, ReadResourceResponses, Recipe, RecipeManifest, RecipeParameter, RecipeParameterInputType, RecipeParameterRequirement, RecipeToYamlData, RecipeToYamlError, RecipeToYamlErrors, RecipeToYamlRequest, RecipeToYamlResponse, RecipeToYamlResponse2, RecipeToYamlResponses, RedactedThinkingContent, RemoveConfigData, RemoveConfigErrors, RemoveConfigResponse, RemoveConfigResponses, RemoveCustomProviderData, RemoveCustomProviderErrors, RemoveCustomProviderResponse, RemoveCustomProviderResponses, RemoveExtensionData, RemoveExtensionErrors, RemoveExtensionRequest, RemoveExtensionResponse, RemoveExtensionResponses, ReplyData, ReplyErrors, ReplyResponse, ReplyResponses, RepoVariantsResponse, ResetPromptData, ResetPromptErrors, ResetPromptResponse, ResetPromptResponses, ResourceContents, ResourceMetadata, Response, RestartAgentData, RestartAgentErrors, RestartAgentRequest, RestartAgentResponse, RestartAgentResponse2, RestartAgentResponses, ResumeAgentData, ResumeAgentErrors, ResumeAgentRequest, ResumeAgentResponse, ResumeAgentResponse2, ResumeAgentResponses, RetryConfig, Role, RunNowHandlerData, RunNowHandlerErrors, RunNowHandlerResponse, RunNowHandlerResponses, RunNowResponse, SamplingConfig, SavePromptData, SavePromptErrors, SavePromptRequest, SavePromptResponse, SavePromptResponses, SaveRecipeData, SaveRecipeError, SaveRecipeErrors, SaveRecipeRequest, SaveRecipeResponse, SaveRecipeResponse2, SaveRecipeResponses, ScanRecipeData, ScanRecipeRequest, ScanRecipeResponse, ScanRecipeResponse2, ScanRecipeResponses, ScheduledJob, ScheduleRecipeData, ScheduleRecipeErrors, ScheduleRecipeRequest, ScheduleRecipeResponses, SearchHfModelsData, SearchHfModelsErrors, SearchHfModelsResponse, SearchHfModelsResponses, SearchSessionsData, SearchSessionsErrors, SearchSessionsResponse, SearchSessionsResponses, SendTelemetryEventData, SendTelemetryEventResponses, Session, SessionCancelData, SessionCancelResponses, SessionDisplayInfo, SessionEventsData, SessionEventsErrors, SessionEventsResponse, SessionEventsResponses, SessionExtensionsResponse, SessionInsights, SessionListResponse, SessionReplyData, SessionReplyErrors, SessionReplyRequest, SessionReplyResponse, SessionReplyResponse2, SessionReplyResponses, SessionsHandlerData, SessionsHandlerErrors, SessionsHandlerResponse, SessionsHandlerResponses, SessionsQuery, SessionType, SetConfigProviderData, SetProviderRequest, SetRecipeSlashCommandData, SetRecipeSlashCommandErrors, SetRecipeSlashCommandResponses, SetSlashCommandRequest, Settings, SetupResponse, ShareSessionNostrData, ShareSessionNostrErrors, ShareSessionNostrRequest, ShareSessionNostrResponse, ShareSessionNostrResponse2, ShareSessionNostrResponses, SlashCommand, SlashCommandsResponse, StartAgentData, StartAgentError, StartAgentErrors, StartAgentRequest, StartAgentResponse, StartAgentResponses, StartNanogptSetupData, StartNanogptSetupResponse, StartNanogptSetupResponses, StartOpenrouterSetupData, StartOpenrouterSetupResponse, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponse, StartTetrateSetupResponses, StartTunnelData, StartTunnelError, StartTunnelErrors, StartTunnelResponse, StartTunnelResponses, StatusData, StatusResponse, StatusResponses, StopAgentData, StopAgentErrors, StopAgentRequest, StopAgentResponse, StopAgentResponses, StopTunnelData, StopTunnelError, StopTunnelErrors, StopTunnelResponses, SubRecipe, SuccessCheck, SyncFeaturedModelsData, SyncFeaturedModelsResponses, SystemInfo, SystemInfoData, SystemInfoResponse, SystemInfoResponses, SystemNotificationContent, SystemNotificationType, TaskSupport, TelemetryEventRequest, Template, TextContent, ThinkingContent, ThinkingEffort, TokenState, Tool, ToolAnnotations, ToolCallingMode, ToolConfirmationRequest, ToolExecution, ToolInfo, ToolPermission, ToolRequest, ToolResponse, TranscribeDictationData, TranscribeDictationErrors, TranscribeDictationResponse, TranscribeDictationResponses, TranscribeRequest, TranscribeResponse, TunnelInfo, TunnelState, UiMetadata, UnpauseScheduleData, UnpauseScheduleErrors, UnpauseScheduleResponse, UnpauseScheduleResponses, UpdateAgentProviderData, UpdateAgentProviderErrors, UpdateAgentProviderResponses, UpdateCustomProviderData, UpdateCustomProviderErrors, UpdateCustomProviderRequest, UpdateCustomProviderResponse, UpdateCustomProviderResponses, UpdateFromSessionData, UpdateFromSessionErrors, UpdateFromSessionRequest, UpdateFromSessionResponses, UpdateModelSettingsData, UpdateModelSettingsErrors, UpdateModelSettingsResponse, UpdateModelSettingsResponses, UpdateProviderRequest, UpdateScheduleData, UpdateScheduleErrors, UpdateScheduleRequest, UpdateScheduleResponse, UpdateScheduleResponses, UpdateSessionData, UpdateSessionErrors, UpdateSessionNameData, UpdateSessionNameErrors, UpdateSessionNameRequest, UpdateSessionNameResponses, UpdateSessionRequest, UpdateSessionResponses, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesError, UpdateSessionUserRecipeValuesErrors, UpdateSessionUserRecipeValuesRequest, UpdateSessionUserRecipeValuesResponse, UpdateSessionUserRecipeValuesResponse2, UpdateSessionUserRecipeValuesResponses, UpdateWorkingDirData, UpdateWorkingDirErrors, UpdateWorkingDirRequest, UpdateWorkingDirResponses, UpsertConfigData, UpsertConfigErrors, UpsertConfigQuery, UpsertConfigResponse, UpsertConfigResponses, UpsertPermissionsData, UpsertPermissionsErrors, UpsertPermissionsQuery, UpsertPermissionsResponse, UpsertPermissionsResponses, ValidateConfigData, ValidateConfigErrors, ValidateConfigResponse, ValidateConfigResponses, WhisperModelResponse, WindowProps } from './types.gen'; diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index 081dfb57fcd5..91833f88b9f2 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -2,7 +2,7 @@ import type { Client, Options as Options2, TDataShape } from './client'; import { client } from './client.gen'; -import type { AddExtensionData, AddExtensionErrors, AddExtensionResponses, AgentAddExtensionData, AgentAddExtensionErrors, AgentAddExtensionResponses, AgentRemoveExtensionData, AgentRemoveExtensionErrors, AgentRemoveExtensionResponses, CallToolData, CallToolErrors, CallToolResponses, CancelDownloadData, CancelDownloadErrors, CancelDownloadResponses, CancelLocalModelDownloadData, CancelLocalModelDownloadErrors, CancelLocalModelDownloadResponses, CheckProviderData, CleanupProviderCacheData, CleanupProviderCacheErrors, CleanupProviderCacheResponses, ConfigureProviderOauthData, ConfigureProviderOauthErrors, ConfigureProviderOauthResponses, ConfirmToolActionData, ConfirmToolActionErrors, ConfirmToolActionResponses, CreateCustomProviderData, CreateCustomProviderErrors, CreateCustomProviderResponses, CreateRecipeData, CreateRecipeErrors, CreateRecipeResponses, CreateScheduleData, CreateScheduleErrors, CreateScheduleResponses, DecodeRecipeData, DecodeRecipeErrors, DecodeRecipeResponses, DeleteLocalModelData, DeleteLocalModelErrors, DeleteLocalModelResponses, DeleteModelData, DeleteModelErrors, DeleteModelResponses, DeleteRecipeData, DeleteRecipeErrors, DeleteRecipeResponses, DeleteScheduleData, DeleteScheduleErrors, DeleteScheduleResponses, DeleteSessionData, DeleteSessionErrors, DeleteSessionResponses, DiagnosticsData, DiagnosticsErrors, DiagnosticsResponses, DownloadHfModelData, DownloadHfModelErrors, DownloadHfModelResponses, DownloadModelData, DownloadModelErrors, DownloadModelResponses, EncodeRecipeData, EncodeRecipeErrors, EncodeRecipeResponses, ExportAppData, ExportAppErrors, ExportAppResponses, ExportSessionData, ExportSessionErrors, ExportSessionResponses, ForkSessionData, ForkSessionErrors, ForkSessionResponses, GetCanonicalModelInfoData, GetCanonicalModelInfoResponses, GetCustomProviderData, GetCustomProviderErrors, GetCustomProviderResponses, GetDictationConfigData, GetDictationConfigResponses, GetDownloadProgressData, GetDownloadProgressErrors, GetDownloadProgressResponses, GetExtensionsData, GetExtensionsErrors, GetExtensionsResponses, GetFeaturesData, GetFeaturesResponses, GetLocalModelDownloadProgressData, GetLocalModelDownloadProgressErrors, GetLocalModelDownloadProgressResponses, GetModelSettingsData, GetModelSettingsErrors, GetModelSettingsResponses, GetPromptData, GetPromptErrors, GetPromptResponses, GetPromptsData, GetPromptsResponses, GetProviderCatalogData, GetProviderCatalogErrors, GetProviderCatalogResponses, GetProviderCatalogTemplateData, GetProviderCatalogTemplateErrors, GetProviderCatalogTemplateResponses, GetProviderModelInfoData, GetProviderModelInfoErrors, GetProviderModelInfoResponses, GetProviderModelsData, GetProviderModelsErrors, GetProviderModelsResponses, GetRepoFilesData, GetRepoFilesResponses, GetSessionData, GetSessionErrors, GetSessionExtensionsData, GetSessionExtensionsErrors, GetSessionExtensionsResponses, GetSessionInsightsData, GetSessionInsightsErrors, GetSessionInsightsResponses, GetSessionResponses, GetSlashCommandsData, GetSlashCommandsResponses, GetToolsData, GetToolsErrors, GetToolsResponses, GetTunnelStatusData, GetTunnelStatusResponses, ImportAppData, ImportAppErrors, ImportAppResponses, ImportSessionData, ImportSessionErrors, ImportSessionNostrData, ImportSessionNostrErrors, ImportSessionNostrResponses, ImportSessionResponses, InspectRunningJobData, InspectRunningJobErrors, InspectRunningJobResponses, KillRunningJobData, KillRunningJobResponses, ListAppsData, ListAppsErrors, ListAppsResponses, ListBuiltinChatTemplatesData, ListBuiltinChatTemplatesResponses, ListLocalModelsData, ListLocalModelsResponses, ListModelsData, ListModelsResponses, ListRecipesData, ListRecipesErrors, ListRecipesResponses, ListSchedulesData, ListSchedulesErrors, ListSchedulesResponses, ListSessionsData, ListSessionsErrors, ListSessionsResponses, McpUiProxyData, McpUiProxyErrors, McpUiProxyResponses, ParseRecipeData, ParseRecipeErrors, ParseRecipeResponses, PauseScheduleData, PauseScheduleErrors, PauseScheduleResponses, ProvidersData, ProvidersResponses, ReadAllConfigData, ReadAllConfigResponses, ReadConfigData, ReadConfigErrors, ReadConfigResponses, ReadResourceData, ReadResourceErrors, ReadResourceResponses, RecipeToYamlData, RecipeToYamlErrors, RecipeToYamlResponses, RemoveConfigData, RemoveConfigErrors, RemoveConfigResponses, RemoveCustomProviderData, RemoveCustomProviderErrors, RemoveCustomProviderResponses, RemoveExtensionData, RemoveExtensionErrors, RemoveExtensionResponses, ReplyData, ReplyErrors, ReplyResponses, ResetPromptData, ResetPromptErrors, ResetPromptResponses, RestartAgentData, RestartAgentErrors, RestartAgentResponses, ResumeAgentData, ResumeAgentErrors, ResumeAgentResponses, RunNowHandlerData, RunNowHandlerErrors, RunNowHandlerResponses, SavePromptData, SavePromptErrors, SavePromptResponses, SaveRecipeData, SaveRecipeErrors, SaveRecipeResponses, ScanRecipeData, ScanRecipeResponses, ScheduleRecipeData, ScheduleRecipeErrors, ScheduleRecipeResponses, SearchHfModelsData, SearchHfModelsErrors, SearchHfModelsResponses, SearchSessionsData, SearchSessionsErrors, SearchSessionsResponses, SendTelemetryEventData, SendTelemetryEventResponses, SessionCancelData, SessionCancelResponses, SessionEventsData, SessionEventsErrors, SessionEventsResponses, SessionReplyData, SessionReplyErrors, SessionReplyResponses, SessionsHandlerData, SessionsHandlerErrors, SessionsHandlerResponses, SetConfigProviderData, SetRecipeSlashCommandData, SetRecipeSlashCommandErrors, SetRecipeSlashCommandResponses, ShareSessionNostrData, ShareSessionNostrErrors, ShareSessionNostrResponses, StartAgentData, StartAgentErrors, StartAgentResponses, StartNanogptSetupData, StartNanogptSetupResponses, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, StartTunnelData, StartTunnelErrors, StartTunnelResponses, StatusData, StatusResponses, StopAgentData, StopAgentErrors, StopAgentResponses, StopTunnelData, StopTunnelErrors, StopTunnelResponses, SyncFeaturedModelsData, SyncFeaturedModelsResponses, SystemInfoData, SystemInfoResponses, TranscribeDictationData, TranscribeDictationErrors, TranscribeDictationResponses, UnpauseScheduleData, UnpauseScheduleErrors, UnpauseScheduleResponses, UpdateAgentProviderData, UpdateAgentProviderErrors, UpdateAgentProviderResponses, UpdateCustomProviderData, UpdateCustomProviderErrors, UpdateCustomProviderResponses, UpdateFromSessionData, UpdateFromSessionErrors, UpdateFromSessionResponses, UpdateModelSettingsData, UpdateModelSettingsErrors, UpdateModelSettingsResponses, UpdateScheduleData, UpdateScheduleErrors, UpdateScheduleResponses, UpdateSessionData, UpdateSessionErrors, UpdateSessionNameData, UpdateSessionNameErrors, UpdateSessionNameResponses, UpdateSessionResponses, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesErrors, UpdateSessionUserRecipeValuesResponses, UpdateWorkingDirData, UpdateWorkingDirErrors, UpdateWorkingDirResponses, UpsertConfigData, UpsertConfigErrors, UpsertConfigResponses, UpsertPermissionsData, UpsertPermissionsErrors, UpsertPermissionsResponses, ValidateConfigData, ValidateConfigErrors, ValidateConfigResponses } from './types.gen'; +import type { AddExtensionData, AddExtensionErrors, AddExtensionResponses, AgentAddExtensionData, AgentAddExtensionErrors, AgentAddExtensionResponses, AgentRemoveExtensionData, AgentRemoveExtensionErrors, AgentRemoveExtensionResponses, CallToolData, CallToolErrors, CallToolResponses, CancelDownloadData, CancelDownloadErrors, CancelDownloadResponses, CancelLocalModelDownloadData, CancelLocalModelDownloadErrors, CancelLocalModelDownloadResponses, CheckProviderData, CleanupProviderCacheData, CleanupProviderCacheErrors, CleanupProviderCacheResponses, ConfigureProviderOauthData, ConfigureProviderOauthErrors, ConfigureProviderOauthResponses, ConfirmToolActionData, ConfirmToolActionErrors, ConfirmToolActionResponses, CreateCustomProviderData, CreateCustomProviderErrors, CreateCustomProviderResponses, CreateRecipeData, CreateRecipeErrors, CreateRecipeResponses, CreateScheduleData, CreateScheduleErrors, CreateScheduleResponses, DecodeRecipeData, DecodeRecipeErrors, DecodeRecipeResponses, DeleteLocalModelData, DeleteLocalModelErrors, DeleteLocalModelResponses, DeleteModelData, DeleteModelErrors, DeleteModelResponses, DeleteProviderSecretData, DeleteProviderSecretErrors, DeleteProviderSecretResponses, DeleteRecipeData, DeleteRecipeErrors, DeleteRecipeResponses, DeleteScheduleData, DeleteScheduleErrors, DeleteScheduleResponses, DeleteSessionData, DeleteSessionErrors, DeleteSessionResponses, DiagnosticsData, DiagnosticsErrors, DiagnosticsResponses, DownloadHfModelData, DownloadHfModelErrors, DownloadHfModelResponses, DownloadModelData, DownloadModelErrors, DownloadModelResponses, EncodeRecipeData, EncodeRecipeErrors, EncodeRecipeResponses, ExportAppData, ExportAppErrors, ExportAppResponses, ExportSessionData, ExportSessionErrors, ExportSessionResponses, ForkSessionData, ForkSessionErrors, ForkSessionResponses, GetCanonicalModelInfoData, GetCanonicalModelInfoResponses, GetCustomProviderData, GetCustomProviderErrors, GetCustomProviderResponses, GetDictationConfigData, GetDictationConfigResponses, GetDownloadProgressData, GetDownloadProgressErrors, GetDownloadProgressResponses, GetExtensionsData, GetExtensionsErrors, GetExtensionsResponses, GetFeaturesData, GetFeaturesResponses, GetLocalModelDownloadProgressData, GetLocalModelDownloadProgressErrors, GetLocalModelDownloadProgressResponses, GetModelSettingsData, GetModelSettingsErrors, GetModelSettingsResponses, GetPromptData, GetPromptErrors, GetPromptResponses, GetPromptsData, GetPromptsResponses, GetProviderCatalogData, GetProviderCatalogErrors, GetProviderCatalogResponses, GetProviderCatalogTemplateData, GetProviderCatalogTemplateErrors, GetProviderCatalogTemplateResponses, GetProviderModelInfoData, GetProviderModelInfoErrors, GetProviderModelInfoResponses, GetProviderModelsData, GetProviderModelsErrors, GetProviderModelsResponses, GetRepoFilesData, GetRepoFilesResponses, GetSessionData, GetSessionErrors, GetSessionExtensionsData, GetSessionExtensionsErrors, GetSessionExtensionsResponses, GetSessionInsightsData, GetSessionInsightsErrors, GetSessionInsightsResponses, GetSessionResponses, GetSlashCommandsData, GetSlashCommandsResponses, GetToolsData, GetToolsErrors, GetToolsResponses, GetTunnelStatusData, GetTunnelStatusResponses, ImportAppData, ImportAppErrors, ImportAppResponses, ImportSessionData, ImportSessionErrors, ImportSessionNostrData, ImportSessionNostrErrors, ImportSessionNostrResponses, ImportSessionResponses, InspectRunningJobData, InspectRunningJobErrors, InspectRunningJobResponses, KillRunningJobData, KillRunningJobResponses, ListAppsData, ListAppsErrors, ListAppsResponses, ListBuiltinChatTemplatesData, ListBuiltinChatTemplatesResponses, ListLocalModelsData, ListLocalModelsResponses, ListModelsData, ListModelsResponses, ListProviderSecretsData, ListProviderSecretsErrors, ListProviderSecretsResponses, ListRecipesData, ListRecipesErrors, ListRecipesResponses, ListSchedulesData, ListSchedulesErrors, ListSchedulesResponses, ListSessionsData, ListSessionsErrors, ListSessionsResponses, McpUiProxyData, McpUiProxyErrors, McpUiProxyResponses, ParseRecipeData, ParseRecipeErrors, ParseRecipeResponses, PauseScheduleData, PauseScheduleErrors, PauseScheduleResponses, ProvidersData, ProvidersResponses, ReadAllConfigData, ReadAllConfigResponses, ReadConfigData, ReadConfigErrors, ReadConfigResponses, ReadResourceData, ReadResourceErrors, ReadResourceResponses, RecipeToYamlData, RecipeToYamlErrors, RecipeToYamlResponses, RemoveConfigData, RemoveConfigErrors, RemoveConfigResponses, RemoveCustomProviderData, RemoveCustomProviderErrors, RemoveCustomProviderResponses, RemoveExtensionData, RemoveExtensionErrors, RemoveExtensionResponses, ReplyData, ReplyErrors, ReplyResponses, ResetPromptData, ResetPromptErrors, ResetPromptResponses, RestartAgentData, RestartAgentErrors, RestartAgentResponses, ResumeAgentData, ResumeAgentErrors, ResumeAgentResponses, RunNowHandlerData, RunNowHandlerErrors, RunNowHandlerResponses, SavePromptData, SavePromptErrors, SavePromptResponses, SaveRecipeData, SaveRecipeErrors, SaveRecipeResponses, ScanRecipeData, ScanRecipeResponses, ScheduleRecipeData, ScheduleRecipeErrors, ScheduleRecipeResponses, SearchHfModelsData, SearchHfModelsErrors, SearchHfModelsResponses, SearchSessionsData, SearchSessionsErrors, SearchSessionsResponses, SendTelemetryEventData, SendTelemetryEventResponses, SessionCancelData, SessionCancelResponses, SessionEventsData, SessionEventsErrors, SessionEventsResponses, SessionReplyData, SessionReplyErrors, SessionReplyResponses, SessionsHandlerData, SessionsHandlerErrors, SessionsHandlerResponses, SetConfigProviderData, SetRecipeSlashCommandData, SetRecipeSlashCommandErrors, SetRecipeSlashCommandResponses, ShareSessionNostrData, ShareSessionNostrErrors, ShareSessionNostrResponses, StartAgentData, StartAgentErrors, StartAgentResponses, StartNanogptSetupData, StartNanogptSetupResponses, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, StartTunnelData, StartTunnelErrors, StartTunnelResponses, StatusData, StatusResponses, StopAgentData, StopAgentErrors, StopAgentResponses, StopTunnelData, StopTunnelErrors, StopTunnelResponses, SyncFeaturedModelsData, SyncFeaturedModelsResponses, SystemInfoData, SystemInfoResponses, TranscribeDictationData, TranscribeDictationErrors, TranscribeDictationResponses, UnpauseScheduleData, UnpauseScheduleErrors, UnpauseScheduleResponses, UpdateAgentProviderData, UpdateAgentProviderErrors, UpdateAgentProviderResponses, UpdateCustomProviderData, UpdateCustomProviderErrors, UpdateCustomProviderResponses, UpdateFromSessionData, UpdateFromSessionErrors, UpdateFromSessionResponses, UpdateModelSettingsData, UpdateModelSettingsErrors, UpdateModelSettingsResponses, UpdateScheduleData, UpdateScheduleErrors, UpdateScheduleResponses, UpdateSessionData, UpdateSessionErrors, UpdateSessionNameData, UpdateSessionNameErrors, UpdateSessionNameResponses, UpdateSessionResponses, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesErrors, UpdateSessionUserRecipeValuesResponses, UpdateWorkingDirData, UpdateWorkingDirErrors, UpdateWorkingDirResponses, UpsertConfigData, UpsertConfigErrors, UpsertConfigResponses, UpsertPermissionsData, UpsertPermissionsErrors, UpsertPermissionsResponses, ValidateConfigData, ValidateConfigErrors, ValidateConfigResponses } from './types.gen'; export type Options = Options2 & { /** @@ -233,6 +233,10 @@ export const getProviderCatalog = (options export const getProviderCatalogTemplate = (options: Options) => (options.client ?? client).get({ url: '/config/provider-catalog/{id}', ...options }); +export const listProviderSecrets = (options?: Options) => (options?.client ?? client).get({ url: '/config/provider-secrets', ...options }); + +export const deleteProviderSecret = (options: Options) => (options.client ?? client).delete({ url: '/config/provider-secrets/{id}', ...options }); + export const providers = (options?: Options) => (options?.client ?? client).get({ url: '/config/providers', ...options }); export const cleanupProviderCache = (options: Options) => (options.client ?? client).post({ url: '/config/providers/{name}/cleanup', ...options }); diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 09e243ee2bd2..398813947fd9 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -1014,6 +1014,29 @@ export type ProviderModelInfoQuery = { model: string; }; +export type ProviderSecret = { + can_configure: boolean; + can_delete: boolean; + configure_provider?: string | null; + configured: boolean; + expires_at?: string | null; + has_secret: boolean; + id: string; + name: string; + provider: string; + provider_display_name: string; + status: ProviderSecretStatus; + storage: ProviderSecretStorage; +}; + +export type ProviderSecretStatus = 'valid' | 'expired' | 'unknown'; + +export type ProviderSecretStorage = 'secret_store' | 'provider_cache'; + +export type ProviderSecretsResponse = { + secrets: Array; +}; + export type ProviderTemplate = { api_url: string; doc_url: string; @@ -2548,6 +2571,61 @@ export type GetProviderCatalogTemplateResponses = { export type GetProviderCatalogTemplateResponse = GetProviderCatalogTemplateResponses[keyof GetProviderCatalogTemplateResponses]; +export type ListProviderSecretsData = { + body?: never; + path?: never; + query?: never; + url: '/config/provider-secrets'; +}; + +export type ListProviderSecretsErrors = { + /** + * Internal server error + */ + 500: unknown; +}; + +export type ListProviderSecretsResponses = { + /** + * Provider secrets retrieved successfully + */ + 200: ProviderSecretsResponse; +}; + +export type ListProviderSecretsResponse = ListProviderSecretsResponses[keyof ListProviderSecretsResponses]; + +export type DeleteProviderSecretData = { + body?: never; + path: { + /** + * Provider secret identifier + */ + id: string; + }; + query?: never; + url: '/config/provider-secrets/{id}'; +}; + +export type DeleteProviderSecretErrors = { + /** + * Invalid provider secret identifier + */ + 400: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type DeleteProviderSecretResponses = { + /** + * Provider secret deleted successfully + */ + 200: string; +}; + +export type DeleteProviderSecretResponse = DeleteProviderSecretResponses[keyof DeleteProviderSecretResponses]; + export type ProvidersData = { body?: never; path?: never; diff --git a/ui/desktop/src/components/bottom_menu/BottomMenuExtensionSelection.tsx b/ui/desktop/src/components/bottom_menu/BottomMenuExtensionSelection.tsx index 295ad9b20af6..7ac91a9de0ae 100644 --- a/ui/desktop/src/components/bottom_menu/BottomMenuExtensionSelection.tsx +++ b/ui/desktop/src/components/bottom_menu/BottomMenuExtensionSelection.tsx @@ -1,4 +1,3 @@ -import { AppEvents } from '../../constants/events'; import { useCallback, useEffect, useMemo, useState, useRef } from 'react'; import { Puzzle } from 'lucide-react'; import { DropdownMenu, DropdownMenuContent, DropdownMenuTrigger } from '../ui/dropdown-menu'; @@ -8,8 +7,7 @@ import { FixedExtensionEntry, useConfig } from '../ConfigContext'; import { toastService } from '../../toasts'; import { formatExtensionName } from '../settings/extensions/subcomponents/ExtensionList'; import { nameToKey } from '../settings/extensions/utils'; -import { ExtensionConfig } from '../../api'; -import { getSessionExtensions } from '../../acp/extensions'; +import { ExtensionConfig, getSessionExtensions } from '../../api'; import { addToAgent, removeFromAgent } from '../settings/extensions/agent-api'; import { setExtensionOverride, @@ -17,6 +15,7 @@ import { getExtensionOverrides, } from '../../store/extensionOverrides'; import { defineMessages, useIntl } from '../../i18n'; +import { AppEvents } from '../../constants/events'; const i18n = defineMessages({ manageExtensions: { @@ -69,6 +68,8 @@ interface BottomMenuExtensionSelectionProps { sessionId: string | null; } +type GetSessionExtensionsSignal = Parameters[0]['signal']; + export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionSelectionProps) => { const intl = useIntl(); const [searchQuery, setSearchQuery] = useState(''); @@ -78,28 +79,25 @@ export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionS const [isTransitioning, setIsTransitioning] = useState(false); const [pendingSort, setPendingSort] = useState(false); const [togglingExtension, setTogglingExtension] = useState(null); - const [refreshTrigger, setRefreshTrigger] = useState(0); const [isSessionExtensionsLoaded, setIsSessionExtensionsLoaded] = useState(false); const sortTimeoutRef = useRef | null>(null); + const latestSessionIdRef = useRef(sessionId); const { extensionsList: allExtensions } = useConfig(); const isHubView = !sessionId; useEffect(() => { + latestSessionIdRef.current = sessionId; setIsSessionExtensionsLoaded(false); setSessionExtensions([]); - }, [sessionId]); + setPendingSort(false); + setIsTransitioning(false); + setTogglingExtension(null); - useEffect(() => { - const handleExtensionsLoaded = () => { - setRefreshTrigger((prev) => prev + 1); - }; - - window.addEventListener(AppEvents.SESSION_EXTENSIONS_LOADED, handleExtensionsLoaded); - - return () => { - window.removeEventListener(AppEvents.SESSION_EXTENSIONS_LOADED, handleExtensionsLoaded); - }; - }, []); + if (sortTimeoutRef.current) { + clearTimeout(sortTimeoutRef.current); + sortTimeoutRef.current = null; + } + }, [sessionId]); useEffect(() => { return () => { @@ -109,28 +107,71 @@ export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionS }; }, []); + const loadSessionExtensions = useCallback( + async (targetSessionId: string, signal?: GetSessionExtensionsSignal) => { + const response = await getSessionExtensions({ + path: { session_id: targetSessionId }, + signal, + throwOnError: true, + }); + + if (signal?.aborted || latestSessionIdRef.current !== targetSessionId) { + return; + } + + setSessionExtensions(response.data?.extensions ?? []); + setIsSessionExtensionsLoaded(true); + }, + [] + ); + useEffect(() => { - if (refreshTrigger === 0 && !isOpen) { + if (!sessionId) { + setIsSessionExtensionsLoaded(true); return; } - const fetchExtensions = async () => { - if (!sessionId) { + let controller: AbortController | null = null; + + const loadExtensionsForCurrentSession = (event: Event) => { + const targetSessionId = (event as CustomEvent<{ sessionId?: string }>).detail?.sessionId; + + if (targetSessionId !== sessionId) { return; } - try { - const extensions = await getSessionExtensions(sessionId); - setSessionExtensions(extensions); - setIsSessionExtensionsLoaded(true); - } catch (error) { + controller?.abort(); + const currentController = new AbortController(); + controller = currentController; + + loadSessionExtensions(targetSessionId, currentController.signal).catch((error) => { + if (currentController.signal.aborted || latestSessionIdRef.current !== targetSessionId) { + return; + } + console.error('Failed to fetch session extensions:', error); setIsSessionExtensionsLoaded(true); - } + }); }; - fetchExtensions(); - }, [sessionId, isOpen, refreshTrigger]); + window.addEventListener(AppEvents.SESSION_EXTENSIONS_LOADED, loadExtensionsForCurrentSession); + + return () => { + controller?.abort(); + window.removeEventListener( + AppEvents.SESSION_EXTENSIONS_LOADED, + loadExtensionsForCurrentSession + ); + }; + }, [sessionId, loadSessionExtensions]); + + const finishSessionTransition = useCallback((targetSessionId: string) => { + if (latestSessionIdRef.current === targetSessionId) { + setPendingSort(false); + setIsTransitioning(false); + setTogglingExtension(null); + } + }, []); const handleToggle = useCallback( async (extensionConfig: FixedExtensionEntry) => { @@ -156,6 +197,7 @@ export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionS setPendingSort(false); setIsTransitioning(false); setTogglingExtension(null); + sortTimeoutRef.current = null; }, 800); toastService.success({ @@ -192,12 +234,17 @@ export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionS clearTimeout(sortTimeoutRef.current); } - sortTimeoutRef.current = setTimeout(async () => { - const extensions = await getSessionExtensions(sessionId); - setSessionExtensions(extensions); - setPendingSort(false); - setIsTransitioning(false); - setTogglingExtension(null); + sortTimeoutRef.current = setTimeout(() => { + loadSessionExtensions(sessionId) + .catch((error) => { + if (latestSessionIdRef.current === sessionId) { + console.error('Failed to fetch session extensions:', error); + } + }) + .finally(() => { + finishSessionTransition(sessionId); + sortTimeoutRef.current = null; + }); }, 800); } catch { setIsTransitioning(false); @@ -205,7 +252,7 @@ export const BottomMenuExtensionSelection = ({ sessionId }: BottomMenuExtensionS setTogglingExtension(null); } }, - [sessionId, isHubView, togglingExtension, intl] + [sessionId, isHubView, togglingExtension, intl, loadSessionExtensions, finishSessionTransition] ); // Merge all available extensions with session-specific or hub override state diff --git a/ui/desktop/src/components/settings/SettingsView.tsx b/ui/desktop/src/components/settings/SettingsView.tsx index 5b47b12b4440..ae6c82b35797 100644 --- a/ui/desktop/src/components/settings/SettingsView.tsx +++ b/ui/desktop/src/components/settings/SettingsView.tsx @@ -2,17 +2,36 @@ import { ScrollArea } from '../ui/scroll-area'; import { Tabs, TabsContent, TabsList, TabsTrigger } from '../ui/tabs'; import { View, ViewOptions } from '../../utils/navigationUtils'; import ModelsSection from './models/ModelsSection'; +import SessionSharingSection from './sessions/SessionSharingSection'; +import ExternalBackendSection from './app/ExternalBackendSection'; import AppSettingsSection from './app/AppSettingsSection'; import ConfigSettings from './config/ConfigSettings'; import PromptsSettingsSection from './PromptsSettingsSection'; import { ExtensionConfig } from '../../api'; import { MainPanelLayout } from '../Layout/MainPanelLayout'; -import { Bot, Monitor, MessageSquare, FileText, Keyboard } from 'lucide-react'; +import { + Bot, + Share2, + Monitor, + MessageSquare, + FileText, + Keyboard, + HardDrive, + Network, + KeyRound, +} from 'lucide-react'; import { useState, useEffect, useRef } from 'react'; +import TunnelSection from './tunnel/TunnelSection'; +import GatewaySettingsSection from './gateways/GatewaySettingsSection'; +import { getTunnelStatus } from '../../api/sdk.gen'; import ChatSettingsSection from './chat/ChatSettingsSection'; import KeyboardShortcutsSection from './keyboard/KeyboardShortcutsSection'; +import AuthSettingsSection from './auth/AuthSettingsSection'; +import LocalInferenceSection from './localInference/LocalInferenceSection'; +import MeshSection from './mesh/MeshSection'; import { CONFIGURATION_ENABLED } from '../../updates'; import { trackSettingsTabViewed } from '../../utils/analytics'; +import { useFeatures } from '../../contexts/FeaturesContext'; import { defineMessages, useIntl } from '../../i18n'; const i18n = defineMessages({ @@ -24,10 +43,18 @@ const i18n = defineMessages({ id: 'settingsView.tabModels', defaultMessage: 'Models', }, + tabLocalInference: { + id: 'settingsView.tabLocalInference', + defaultMessage: 'Local Inference', + }, tabChat: { id: 'settingsView.tabChat', defaultMessage: 'Chat', }, + tabSession: { + id: 'settingsView.tabSession', + defaultMessage: 'Session', + }, tabPrompts: { id: 'settingsView.tabPrompts', defaultMessage: 'Prompts', @@ -36,6 +63,10 @@ const i18n = defineMessages({ id: 'settingsView.tabKeyboard', defaultMessage: 'Keyboard', }, + tabAuth: { + id: 'settingsView.tabAuth', + defaultMessage: 'Auth', + }, tabApp: { id: 'settingsView.tabApp', defaultMessage: 'App', @@ -59,7 +90,9 @@ export default function SettingsView({ viewOptions: SettingsViewOptions; }) { const [activeTab, setActiveTab] = useState('models'); + const [tunnelDisabled, setTunnelDisabled] = useState(false); const hasTrackedInitialTab = useRef(false); + const { localInference } = useFeatures(); const intl = useIntl(); const handleTabChange = (tab: string) => { @@ -75,20 +108,39 @@ export default function SettingsView({ update: 'app', models: 'models', modes: 'chat', + sharing: 'sharing', styles: 'chat', tools: 'chat', app: 'app', chat: 'chat', prompts: 'prompts', keyboard: 'keyboard', + auth: 'auth', + gateway: 'sharing', + 'local-inference': 'local-inference', + mesh: 'mesh', }; const targetTab = sectionToTab[viewOptions.section]; - if (targetTab) { + if ( + targetTab && + (targetTab !== 'local-inference' || localInference) && + (targetTab !== 'mesh' || !tunnelDisabled) + ) { setActiveTab(targetTab); } } - }, [viewOptions.section]); + }, [viewOptions.section, localInference, tunnelDisabled]); + + // Reset active tab if local-inference or mesh becomes unavailable + useEffect(() => { + if (!localInference && activeTab === 'local-inference') { + setActiveTab('models'); + } + if (tunnelDisabled && activeTab === 'mesh') { + setActiveTab('models'); + } + }, [localInference, tunnelDisabled, activeTab]); useEffect(() => { if (!hasTrackedInitialTab.current) { @@ -97,6 +149,16 @@ export default function SettingsView({ } }, [activeTab]); + useEffect(() => { + getTunnelStatus() + .then(({ data }) => { + setTunnelDisabled(data?.state === 'disabled'); + }) + .catch(() => { + setTunnelDisabled(false); + }); + }, []); + useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.key === 'Escape' && !event.defaultPrevented) { @@ -139,10 +201,38 @@ export default function SettingsView({ {intl.formatMessage(i18n.tabModels)} + {localInference && ( + + + {intl.formatMessage(i18n.tabLocalInference)} + + )} + {!tunnelDisabled && ( + + + Mesh + + )} {intl.formatMessage(i18n.tabChat)} + + + {intl.formatMessage(i18n.tabSession)} + {intl.formatMessage(i18n.tabKeyboard)} + + + {intl.formatMessage(i18n.tabAuth)} + {intl.formatMessage(i18n.tabApp)} @@ -174,6 +268,24 @@ export default function SettingsView({ + {localInference && ( + + + + )} + + {!tunnelDisabled && ( + + + + )} + + +
+ + + {!tunnelDisabled && ( +
+ + +
+ )} +
+
+ + + + + { + const actual = await vi.importActual('../../../api'); + return { + ...actual, + configureProviderOauth: vi.fn(), + listProviderSecrets: vi.fn(), + deleteProviderSecret: vi.fn(), + }; +}); + +vi.mock('../../ModelAndProviderContext', () => ({ + useModelAndProvider: () => ({ + currentProvider: 'openai', + }), +})); + +vi.mock('react-toastify', () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + }, +})); + +const mockedListProviderSecrets = vi.mocked(listProviderSecrets); +const mockedDeleteProviderSecret = vi.mocked(deleteProviderSecret); +const mockedConfigureProviderOauth = vi.mocked(configureProviderOauth); +const mockedToast = vi.mocked(toast); + +const renderWithIntl = (ui: React.ReactElement, options?: RenderOptions) => + render(ui, { wrapper: IntlTestWrapper, ...options }); + +const providerSecret: ProviderSecret = { + id: 'secret_store:openai:OPENAI_API_KEY', + provider: 'openai', + provider_display_name: 'OpenAI', + name: 'OPENAI_API_KEY', + storage: 'secret_store', + expires_at: null, + status: 'unknown', + configured: true, + has_secret: true, + can_delete: true, + can_configure: false, + configure_provider: null, +}; + +const apiResult = (data: T) => ({ + data, + request: {} as never, + response: {} as never, +}); + +describe('AuthSettingsSection', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockedListProviderSecrets.mockResolvedValue(apiResult({ secrets: [] })); + mockedDeleteProviderSecret.mockResolvedValue(apiResult('ok')); + mockedConfigureProviderOauth.mockResolvedValue(apiResult('ok')); + }); + + it('renders an empty state when no credentials are stored', async () => { + renderWithIntl(); + + expect(screen.getByText('Loading credentials...')).toBeInTheDocument(); + expect(await screen.findByText('No locally stored provider credentials were found.')).toBeInTheDocument(); + }); + + it('renders provider credentials with storage and expiry status', async () => { + mockedListProviderSecrets.mockResolvedValue( + apiResult({ + secrets: [ + { + ...providerSecret, + expires_at: '2027-01-01T12:00:00Z', + status: 'valid', + }, + ], + }) + ); + + renderWithIntl(); + + expect(await screen.findByText('OpenAI')).toBeInTheDocument(); + expect(screen.getByText('OPENAI_API_KEY')).toBeInTheDocument(); + expect(screen.getByText('Secret store')).toBeInTheDocument(); + expect(screen.getByText(/Expires/)).toBeInTheDocument(); + }); + + it('does not render an expiry badge when expiry is unknown', async () => { + mockedListProviderSecrets.mockResolvedValue(apiResult({ secrets: [providerSecret] })); + + renderWithIntl(); + + expect(await screen.findByText('OpenAI')).toBeInTheDocument(); + expect(screen.getByText('Secret store')).toBeInTheDocument(); + expect(screen.queryByText('Expiry unknown')).not.toBeInTheDocument(); + expect(screen.queryByText(/Expires/)).not.toBeInTheDocument(); + }); + + it('deletes a credential after confirmation and refreshes the list', async () => { + const user = userEvent.setup(); + mockedListProviderSecrets + .mockResolvedValueOnce(apiResult({ secrets: [providerSecret] })) + .mockResolvedValueOnce(apiResult({ secrets: [] })); + + renderWithIntl(); + + expect(await screen.findByText('OpenAI')).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: 'Delete credential' })); + + expect(screen.getByText('Delete the OPENAI_API_KEY credential for OpenAI?')).toBeInTheDocument(); + expect( + screen.getByText( + 'This is the active provider. New requests may fail until you configure another credential.' + ) + ).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: 'Delete' })); + + await waitFor(() => { + expect(mockedDeleteProviderSecret).toHaveBeenCalledWith({ + path: { id: 'secret_store:openai:OPENAI_API_KEY' }, + throwOnError: true, + }); + }); + await waitFor(() => { + expect(mockedToast.success).toHaveBeenCalledWith('Credential deleted'); + }); + expect(await screen.findByText('No locally stored provider credentials were found.')).toBeInTheDocument(); + }); + + it('configures the permanent Hugging Face credential row', async () => { + const user = userEvent.setup(); + const huggingFaceSecret: ProviderSecret = { + id: 'provider_cache:huggingface', + provider: 'huggingface', + provider_display_name: 'Hugging Face', + name: 'OAuth token', + storage: 'provider_cache', + expires_at: null, + status: 'unknown', + configured: false, + has_secret: false, + can_delete: false, + can_configure: true, + configure_provider: 'huggingface', + }; + + mockedListProviderSecrets + .mockResolvedValueOnce(apiResult({ secrets: [huggingFaceSecret] })) + .mockResolvedValueOnce( + apiResult({ + secrets: [ + { + ...huggingFaceSecret, + configured: true, + has_secret: true, + can_delete: true, + }, + ], + }) + ); + + renderWithIntl(); + + expect(await screen.findByText('Hugging Face')).toBeInTheDocument(); + expect(screen.queryByRole('button', { name: 'Delete credential' })).not.toBeInTheDocument(); + await user.click(screen.getByRole('button', { name: 'Sign in' })); + + await waitFor(() => { + expect(mockedConfigureProviderOauth).toHaveBeenCalledWith({ + path: { name: 'huggingface' }, + throwOnError: true, + }); + }); + await waitFor(() => { + expect(mockedToast.success).toHaveBeenCalledWith('Credential configured'); + }); + }); +}); diff --git a/ui/desktop/src/components/settings/auth/AuthSettingsSection.tsx b/ui/desktop/src/components/settings/auth/AuthSettingsSection.tsx new file mode 100644 index 000000000000..35b321fd25e5 --- /dev/null +++ b/ui/desktop/src/components/settings/auth/AuthSettingsSection.tsx @@ -0,0 +1,317 @@ +import { useCallback, useEffect, useState } from 'react'; +import { KeyRound, Loader2, LogIn, RefreshCw, Trash2 } from 'lucide-react'; +import { toast } from 'react-toastify'; +import { + configureProviderOauth, + deleteProviderSecret, + listProviderSecrets, + ProviderSecret, +} from '../../../api'; +import { errorMessage } from '../../../utils/conversionUtils'; +import { useModelAndProvider } from '../../ModelAndProviderContext'; +import { Button } from '../../ui/button'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '../../ui/card'; +import { ConfirmationModal } from '../../ui/ConfirmationModal'; +import { defineMessages, useIntl } from '../../../i18n'; + +const i18n = defineMessages({ + title: { + id: 'authSettings.title', + defaultMessage: 'Provider Credentials', + }, + description: { + id: 'authSettings.description', + defaultMessage: 'Manage provider credentials stored locally by goose.', + }, + loading: { + id: 'authSettings.loading', + defaultMessage: 'Loading credentials...', + }, + empty: { + id: 'authSettings.empty', + defaultMessage: 'No locally stored provider credentials were found.', + }, + failedToLoad: { + id: 'authSettings.failedToLoad', + defaultMessage: 'Failed to load provider credentials', + }, + deleteTitle: { + id: 'authSettings.deleteTitle', + defaultMessage: 'Delete credential', + }, + deleteMessage: { + id: 'authSettings.deleteMessage', + defaultMessage: 'Delete the {name} credential for {provider}?', + }, + activeProviderWarning: { + id: 'authSettings.activeProviderWarning', + defaultMessage: 'This is the active provider. New requests may fail until you configure another credential.', + }, + delete: { + id: 'authSettings.delete', + defaultMessage: 'Delete', + }, + cancel: { + id: 'authSettings.cancel', + defaultMessage: 'Cancel', + }, + deleted: { + id: 'authSettings.deleted', + defaultMessage: 'Credential deleted', + }, + failedToDelete: { + id: 'authSettings.failedToDelete', + defaultMessage: 'Failed to delete credential: {error}', + }, + storageSecretStore: { + id: 'authSettings.storageSecretStore', + defaultMessage: 'Secret store', + }, + storageProviderCache: { + id: 'authSettings.storageProviderCache', + defaultMessage: 'Provider cache', + }, + expiresAt: { + id: 'authSettings.expiresAt', + defaultMessage: 'Expires {date}', + }, + deleteCredential: { + id: 'authSettings.deleteCredential', + defaultMessage: 'Delete credential', + }, + signIn: { + id: 'authSettings.signIn', + defaultMessage: 'Sign in', + }, + reauthorize: { + id: 'authSettings.reauthorize', + defaultMessage: 'Reauthorize', + }, + signedIn: { + id: 'authSettings.signedIn', + defaultMessage: 'Credential configured', + }, + failedToConfigure: { + id: 'authSettings.failedToConfigure', + defaultMessage: 'Failed to configure credential: {error}', + }, +}); + +function storageLabel(secret: ProviderSecret, intl: ReturnType) { + if (secret.storage === 'provider_cache') { + return intl.formatMessage(i18n.storageProviderCache); + } + return intl.formatMessage(i18n.storageSecretStore); +} + +function expiryLabel(secret: ProviderSecret, intl: ReturnType) { + if (!secret.expires_at) { + return null; + } + return intl.formatMessage(i18n.expiresAt, { + date: intl.formatDate(new Date(secret.expires_at), { + dateStyle: 'medium', + timeStyle: 'short', + }), + }); +} + +function expiryClass(secret: ProviderSecret) { + if (secret.status === 'expired') { + return 'border-red-500/30 bg-red-500/10 text-red-700 dark:text-red-300'; + } + return 'border-green-500/30 bg-green-500/10 text-green-700 dark:text-green-300'; +} + +export default function AuthSettingsSection() { + const intl = useIntl(); + const { currentProvider } = useModelAndProvider(); + const [secrets, setSecrets] = useState([]); + const [loading, setLoading] = useState(true); + const [deletingId, setDeletingId] = useState(null); + const [configuringId, setConfiguringId] = useState(null); + const [secretToDelete, setSecretToDelete] = useState(null); + + const loadSecrets = useCallback(async () => { + setLoading(true); + try { + const response = await listProviderSecrets({ throwOnError: true }); + setSecrets(response.data?.secrets ?? []); + } catch { + toast.error(intl.formatMessage(i18n.failedToLoad)); + setSecrets([]); + } finally { + setLoading(false); + } + }, [intl]); + + useEffect(() => { + loadSecrets(); + }, [loadSecrets]); + + const confirmDelete = async () => { + if (!secretToDelete) { + return; + } + + setDeletingId(secretToDelete.id); + try { + await deleteProviderSecret({ + path: { id: secretToDelete.id }, + throwOnError: true, + }); + toast.success(intl.formatMessage(i18n.deleted)); + setSecretToDelete(null); + await loadSecrets(); + } catch (error) { + toast.error( + intl.formatMessage(i18n.failedToDelete, { + error: errorMessage(error, 'Unknown error'), + }) + ); + } finally { + setDeletingId(null); + } + }; + + const configureSecret = async (secret: ProviderSecret) => { + if (!secret.configure_provider) { + return; + } + + setConfiguringId(secret.id); + try { + await configureProviderOauth({ + path: { name: secret.configure_provider }, + throwOnError: true, + }); + toast.success(intl.formatMessage(i18n.signedIn)); + await loadSecrets(); + } catch (error) { + toast.error( + intl.formatMessage(i18n.failedToConfigure, { + error: errorMessage(error, 'Unknown error'), + }) + ); + } finally { + setConfiguringId(null); + } + }; + + const isActiveProvider = secretToDelete?.provider === currentProvider; + + return ( +
+ + + + + {intl.formatMessage(i18n.title)} + + {intl.formatMessage(i18n.description)} + + + {loading ? ( +
+ + {intl.formatMessage(i18n.loading)} +
+ ) : secrets.length === 0 ? ( +
{intl.formatMessage(i18n.empty)}
+ ) : ( +
+ {secrets.map((secret) => ( +
+
+
+

+ {secret.provider_display_name} +

+ + {storageLabel(secret, intl)} + + {expiryLabel(secret, intl) && ( + + {expiryLabel(secret, intl)} + + )} +
+

+ {secret.name} +

+
+
+ {secret.can_configure && secret.configure_provider && ( + + )} + {secret.can_delete && ( + + )} +
+
+ ))} +
+ )} +
+
+ + setSecretToDelete(null)} + confirmLabel={intl.formatMessage(i18n.delete)} + cancelLabel={intl.formatMessage(i18n.cancel)} + confirmVariant="destructive" + isSubmitting={!!deletingId} + /> +
+ ); +} diff --git a/ui/desktop/src/components/settings/auth/HuggingFaceSignInPrompt.tsx b/ui/desktop/src/components/settings/auth/HuggingFaceSignInPrompt.tsx new file mode 100644 index 000000000000..a4b6c408ed1a --- /dev/null +++ b/ui/desktop/src/components/settings/auth/HuggingFaceSignInPrompt.tsx @@ -0,0 +1,116 @@ +import { useCallback, useEffect, useState } from 'react'; +import { Loader2, LogIn } from 'lucide-react'; +import { toast } from 'react-toastify'; +import { configureProviderOauth, listProviderSecrets } from '../../../api'; +import { errorMessage } from '../../../utils/conversionUtils'; +import { defineMessages, useIntl } from '../../../i18n'; +import { Button } from '../../ui/button'; + +const HUGGINGFACE_PROVIDER = 'huggingface'; +const HUGGINGFACE_OAUTH_SECRET_ID = 'provider_cache:huggingface'; + +const i18n = defineMessages({ + title: { + id: 'huggingFaceSignInPrompt.title', + defaultMessage: 'Hugging Face', + }, + signIn: { + id: 'huggingFaceSignInPrompt.signIn', + defaultMessage: 'Sign in', + }, + signingIn: { + id: 'huggingFaceSignInPrompt.signingIn', + defaultMessage: 'Signing in...', + }, + signedIn: { + id: 'huggingFaceSignInPrompt.signedIn', + defaultMessage: 'Hugging Face signed in', + }, + failedToConfigure: { + id: 'huggingFaceSignInPrompt.failedToConfigure', + defaultMessage: 'Failed to sign in to Hugging Face: {error}', + }, +}); + +interface HuggingFaceSignInPromptProps { + description: string; + className?: string; + onSignedIn?: () => void; +} + +export default function HuggingFaceSignInPrompt({ + description, + className, + onSignedIn, +}: HuggingFaceSignInPromptProps) { + const intl = useIntl(); + const [loading, setLoading] = useState(true); + const [loggedIn, setLoggedIn] = useState(false); + const [signingIn, setSigningIn] = useState(false); + + const loadStatus = useCallback(async () => { + setLoading(true); + try { + const response = await listProviderSecrets({ throwOnError: true }); + const huggingFaceSecret = response.data?.secrets.find( + (secret) => secret.id === HUGGINGFACE_OAUTH_SECRET_ID + ); + setLoggedIn(Boolean(huggingFaceSecret?.has_secret && huggingFaceSecret.status !== 'expired')); + } catch { + setLoggedIn(false); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + loadStatus(); + }, [loadStatus]); + + const signIn = async () => { + setSigningIn(true); + try { + await configureProviderOauth({ + path: { name: HUGGINGFACE_PROVIDER }, + throwOnError: true, + }); + toast.success(intl.formatMessage(i18n.signedIn)); + setLoggedIn(true); + onSignedIn?.(); + } catch (error) { + toast.error( + intl.formatMessage(i18n.failedToConfigure, { + error: errorMessage(error, 'Unknown error'), + }) + ); + await loadStatus(); + } finally { + setSigningIn(false); + } + }; + + if (loading || loggedIn) { + return null; + } + + return ( +
+
+

{intl.formatMessage(i18n.title)}

+

{description}

+
+ +
+ ); +} diff --git a/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx b/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx index ea4a8ee58f3b..f8c98999c6b0 100644 --- a/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx +++ b/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx @@ -17,6 +17,7 @@ import { import { HuggingFaceModelSearch } from './HuggingFaceModelSearch'; import { ModelSettingsPanel } from './ModelSettingsPanel'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../ui/dialog'; +import HuggingFaceSignInPrompt from '../auth/HuggingFaceSignInPrompt'; const i18n = defineMessages({ title: { @@ -96,6 +97,11 @@ const i18n = defineMessages({ id: 'localInferenceSettings.visionEncoderNotDownloaded', defaultMessage: 'Vision encoder not downloaded', }, + huggingFaceSignInNote: { + id: 'localInferenceSettings.huggingFaceSignInNote', + defaultMessage: + 'Sign in to increase rate limits when searching and downloading models, and to access private or gated Hugging Face repositories.', + }, }); const VisionBadge = ({ @@ -328,6 +334,8 @@ export const LocalInferenceSettings = () => {

+ + {/* Active Downloads */} {downloads.size > 0 && (
diff --git a/ui/desktop/src/components/settings/providers/modal/ProviderConfigurationModal.tsx b/ui/desktop/src/components/settings/providers/modal/ProviderConfigurationModal.tsx index 45ddf6fb85dd..28e64bebdbbd 100644 --- a/ui/desktop/src/components/settings/providers/modal/ProviderConfigurationModal.tsx +++ b/ui/desktop/src/components/settings/providers/modal/ProviderConfigurationModal.tsx @@ -26,6 +26,7 @@ import { import { Button } from '../../../../components/ui/button'; import { errorMessage } from '../../../../utils/conversionUtils'; import { defineMessages, useIntl } from '../../../../i18n'; +import HuggingFaceSignInPrompt from '../../auth/HuggingFaceSignInPrompt'; const i18n = defineMessages({ deleteConfigHeader: { @@ -114,6 +115,11 @@ const i18n = defineMessages({ id: 'providerConfigurationModal.close', defaultMessage: 'Close', }, + huggingFaceOAuthDescription: { + id: 'providerConfigurationModal.huggingFaceOAuthDescription', + defaultMessage: + 'Sign in to use Hugging Face Inference Providers without manually entering an API token.', + }, }); /** Render a setup step string, turning `backtick` spans into and newlines into
. */ @@ -176,6 +182,7 @@ export default function ProviderConfigurationModal({ const hasOAuth = provider.metadata.config_keys.some((key) => key.oauth_flow); const hasConfig = configKeys.length > 0; const hasDeviceCodeFlow = provider.metadata.config_keys.some((key) => key.device_code_flow); + const isHuggingFaceProvider = provider.name === 'huggingface'; const isConfigured = provider.is_configured; const headerText = showDeleteConfirmation @@ -422,6 +429,20 @@ export default function ProviderConfigurationModal({ /> )} + {isHuggingFaceProvider && !hasOAuth && ( + { + if (onConfigured) { + onConfigured(provider); + } else { + onClose(); + } + }} + /> + )} + {isExternalSetup && (

diff --git a/ui/desktop/src/hooks/useChatStream.ts b/ui/desktop/src/hooks/useChatStream.ts index 1d6470252ec5..f67390a76606 100644 --- a/ui/desktop/src/hooks/useChatStream.ts +++ b/ui/desktop/src/hooks/useChatStream.ts @@ -748,7 +748,9 @@ export function useChatStream({ }, }, }); - window.dispatchEvent(new CustomEvent(AppEvents.SESSION_EXTENSIONS_LOADED)); + window.dispatchEvent( + new CustomEvent(AppEvents.SESSION_EXTENSIONS_LOADED, { detail: { sessionId } }) + ); onSessionLoaded?.(); return; } @@ -776,7 +778,9 @@ export function useChatStream({ const extensionResults = resumeData?.extension_results; showExtensionLoadResults(extensionResults); - window.dispatchEvent(new CustomEvent(AppEvents.SESSION_EXTENSIONS_LOADED)); + window.dispatchEvent( + new CustomEvent(AppEvents.SESSION_EXTENSIONS_LOADED, { detail: { sessionId } }) + ); const pendingRequestId = pendingReattachRequestIdRef.current; const reattachedToActiveRequest = activeRequestIdRef.current !== null; diff --git a/ui/desktop/src/i18n/messages/en.json b/ui/desktop/src/i18n/messages/en.json index 76f990a9707a..aa3a5a206d52 100644 --- a/ui/desktop/src/i18n/messages/en.json +++ b/ui/desktop/src/i18n/messages/en.json @@ -44,6 +44,66 @@ "appsView.title": { "defaultMessage": "Apps" }, + "authSettings.activeProviderWarning": { + "defaultMessage": "This is the active provider. New requests may fail until you configure another credential." + }, + "authSettings.cancel": { + "defaultMessage": "Cancel" + }, + "authSettings.delete": { + "defaultMessage": "Delete" + }, + "authSettings.deleteCredential": { + "defaultMessage": "Delete credential" + }, + "authSettings.deleteMessage": { + "defaultMessage": "Delete the {name} credential for {provider}?" + }, + "authSettings.deleteTitle": { + "defaultMessage": "Delete credential" + }, + "authSettings.deleted": { + "defaultMessage": "Credential deleted" + }, + "authSettings.description": { + "defaultMessage": "Manage provider credentials stored locally by goose." + }, + "authSettings.empty": { + "defaultMessage": "No locally stored provider credentials were found." + }, + "authSettings.expiresAt": { + "defaultMessage": "Expires {date}" + }, + "authSettings.failedToConfigure": { + "defaultMessage": "Failed to configure credential: {error}" + }, + "authSettings.failedToDelete": { + "defaultMessage": "Failed to delete credential: {error}" + }, + "authSettings.failedToLoad": { + "defaultMessage": "Failed to load provider credentials" + }, + "authSettings.loading": { + "defaultMessage": "Loading credentials..." + }, + "authSettings.reauthorize": { + "defaultMessage": "Reauthorize" + }, + "authSettings.signIn": { + "defaultMessage": "Sign in" + }, + "authSettings.signedIn": { + "defaultMessage": "Credential configured" + }, + "authSettings.storageProviderCache": { + "defaultMessage": "Provider cache" + }, + "authSettings.storageSecretStore": { + "defaultMessage": "Secret store" + }, + "authSettings.title": { + "defaultMessage": "Provider Credentials" + }, "backButton.back": { "defaultMessage": "Back" }, @@ -1460,6 +1520,21 @@ "huggingFaceModelSearch.tooLarge": { "defaultMessage": "May not fit in memory ({size} model, {available} available)" }, + "huggingFaceSignInPrompt.failedToConfigure": { + "defaultMessage": "Failed to sign in to Hugging Face: {error}" + }, + "huggingFaceSignInPrompt.signIn": { + "defaultMessage": "Sign in" + }, + "huggingFaceSignInPrompt.signedIn": { + "defaultMessage": "Hugging Face signed in" + }, + "huggingFaceSignInPrompt.signingIn": { + "defaultMessage": "Signing in..." + }, + "huggingFaceSignInPrompt.title": { + "defaultMessage": "Hugging Face" + }, "imagePreview.altText": { "defaultMessage": "ApeMind Agent image" }, @@ -1799,6 +1874,9 @@ "localInferenceSettings.featuredModels": { "defaultMessage": "Featured Models" }, + "localInferenceSettings.huggingFaceSignInNote": { + "defaultMessage": "Sign in to increase rate limits when searching and downloading models, and to access private or gated Hugging Face repositories." + }, "localInferenceSettings.modelSettings": { "defaultMessage": "Model Settings" }, @@ -2795,6 +2873,9 @@ "providerConfigurationModal.goBack": { "defaultMessage": "Go Back" }, + "providerConfigurationModal.huggingFaceOAuthDescription": { + "defaultMessage": "Sign in to use Hugging Face Inference Providers without manually entering an API token." + }, "providerConfigurationModal.oauthLoginFailed": { "defaultMessage": "OAuth login failed: {error}" }, @@ -4046,6 +4127,9 @@ "settingsView.tabApp": { "defaultMessage": "App" }, + "settingsView.tabAuth": { + "defaultMessage": "Auth" + }, "settingsView.tabChat": { "defaultMessage": "Chat" }, diff --git a/ui/sdk/generate-schema.ts b/ui/sdk/generate-schema.ts index 1d7c4858dfc6..66d7ed673c95 100644 --- a/ui/sdk/generate-schema.ts +++ b/ui/sdk/generate-schema.ts @@ -71,7 +71,10 @@ async function postProcessTypes() { await fs.writeFile(tsPath, src); } -async function postProcessIndex(meta: { methods: unknown[] }) { +async function postProcessIndex(meta: { + methods: unknown[]; + notifications?: unknown[]; +}) { const indexPath = resolve(OUTPUT_DIR, "index.ts"); let src = await fs.readFile(indexPath, "utf8"); @@ -88,6 +91,10 @@ async function postProcessIndex(meta: { methods: unknown[] }) { export const GOOSE_EXT_METHODS = ${JSON.stringify(meta.methods, null, 2)} as const; export type GooseExtMethod = (typeof GOOSE_EXT_METHODS)[number]; + +export const GOOSE_EXT_NOTIFICATIONS = ${JSON.stringify(meta.notifications ?? [], null, 2)} as const; + +export type GooseExtNotification = (typeof GOOSE_EXT_NOTIFICATIONS)[number]; `, { parser: "typescript" }, ); @@ -126,6 +133,32 @@ interface MethodMeta { responseType: string | null; } +interface NotificationMeta { + method: string; + paramsType: string | null; +} + +function methodToHandlerName(method: string): string { + let methodParts = method.split(/[/_]/).filter((part) => part.length > 0); + let prefix = ""; + if (methodParts[0] == "goose" && methodParts[1] == "unstable") { + methodParts.shift(); + methodParts.shift(); + prefix = "unstable_"; + } else if (methodParts[0] == "goose") { + methodParts.shift(); + } + const body = methodParts + .map((part) => + part.replace(/[^a-zA-Z0-9]+(.)/g, (_, chr: string) => chr.toUpperCase()), + ) + .map((part, i) => + i === 0 ? part : part.charAt(0).toUpperCase() + part.slice(1), + ) + .join(""); + return `${prefix}${body}`; +} + function methodToCamelCase(method: string): string { let methodParts = method.split(/[/_]/).filter((part) => part.length > 0); @@ -150,9 +183,13 @@ function methodToCamelCase(method: string): string { return `${prefix}${suffix}`; } -async function generateClient(meta: { methods: MethodMeta[] }) { +async function generateClient(meta: { + methods: MethodMeta[]; + notifications?: NotificationMeta[]; +}) { const typeImports = new Set(); const zodImports = new Set(); + const upstreamTypeImports = new Set(["Client"]); const methodDefs: string[] = []; @@ -200,6 +237,67 @@ async function generateClient(meta: { methods: MethodMeta[] }) { }`); } + const handlerFields: string[] = []; + const dispatchCases: string[] = []; + const handlerKeys: string[] = []; + + for (const n of meta.notifications ?? []) { + const handlerName = methodToHandlerName(n.method); + handlerKeys.push(handlerName); + if (!n.paramsType) { + handlerFields.push( + ` ${handlerName}?: (params: Record) => Promise;`, + ); + dispatchCases.push( + ` case "${n.method}": { + await ${handlerName}?.(params); + return; + }`, + ); + continue; + } + typeImports.add(n.paramsType); + const zodName = `z${n.paramsType}`; + zodImports.add(zodName); + handlerFields.push( + ` ${handlerName}?: (notification: ${n.paramsType}) => Promise;`, + ); + dispatchCases.push( + ` case "${n.method}": { + const parsed = ${zodName}.parse(params) as ${n.paramsType}; + await ${handlerName}?.(parsed); + return; + }`, + ); + } + + const handlerDestructure = + handlerKeys.length > 0 + ? `const { ${handlerKeys.join(", ")}, ...rest } = callbacks;` + : `const rest = callbacks;`; + const handlersInterface = `export interface GooseExtNotifications { +${handlerFields.join("\n")} +}`; + + const dispatcherFn = `export function installGooseExtNotificationDispatcher( + callbacks: GooseClientCallbacks, +): Client { + ${handlerDestructure} + const userExtNotification = rest.extNotification; + return { + ...rest, + extNotification: async (method, params) => { + switch (method) { +${dispatchCases.join("\n")} + default: + await userExtNotification?.(method, params); + return; + } + }, + }; +}`; + + const upstreamImportLine = `import type { ${[...upstreamTypeImports].sort().join(", ")} } from "@agentclientprotocol/sdk";`; const typeImportLine = typeImports.size ? `import type { ${[...typeImports].sort().join(", ")} } from "./types.gen.js";` : ""; @@ -213,6 +311,7 @@ export interface ExtMethodProvider { extMethod(method: string, params: Record): Promise>; } +${upstreamImportLine} ${typeImportLine} ${zodImportLine} @@ -220,6 +319,12 @@ export class GooseExtClient { constructor(private conn: ExtMethodProvider) {} ${methodDefs.join("\n")} } + +${handlersInterface} + +export type GooseClientCallbacks = Client & GooseExtNotifications; + +${dispatcherFn} `; src = await prettier.format(src, { parser: "typescript" }); diff --git a/ui/sdk/src/generated/client.gen.ts b/ui/sdk/src/generated/client.gen.ts index 22beacb47533..0d332565a382 100644 --- a/ui/sdk/src/generated/client.gen.ts +++ b/ui/sdk/src/generated/client.gen.ts @@ -7,6 +7,7 @@ export interface ExtMethodProvider { ): Promise>; } +import type { Client } from "@agentclientprotocol/sdk"; import type { AddConfigExtensionRequest_unstable, AddExtensionRequest_unstable, @@ -40,6 +41,7 @@ import type { DictationSecretSaveRequest_unstable, DictationTranscribeRequest_unstable, DictationTranscribeResponse_unstable, + ElicitationRespondRequest_unstable, ExportSessionRequest_unstable, ExportSessionResponse_unstable, ExportSourceRequest_unstable, @@ -50,6 +52,7 @@ import type { GetSessionExtensionsResponse_unstable, GetToolsRequest_unstable, GetToolsResponse_unstable, + GooseSessionNotification_unstable, GooseToolCallRequest_unstable, GooseToolCallResponse_unstable, ImportSessionRequest_unstable, @@ -115,6 +118,7 @@ import { zGetExtensionsResponse_unstable, zGetSessionExtensionsResponse_unstable, zGetToolsResponse_unstable, + zGooseSessionNotification_unstable, zGooseToolCallResponse_unstable, zImportSessionResponse_unstable, zImportSourcesResponse_unstable, @@ -527,6 +531,12 @@ export class GooseExtClient { ) as ImportSessionResponse_unstable; } + async elicitationRespond_unstable( + params: ElicitationRespondRequest_unstable, + ): Promise { + await this.conn.extMethod("_goose/unstable/elicitation/respond", params); + } + async sessionProjectUpdate_unstable( params: UpdateSessionProjectRequest_unstable, ): Promise { @@ -716,3 +726,35 @@ export class GooseExtClient { ); } } + +export interface GooseExtNotifications { + unstable_sessionUpdate?: ( + notification: GooseSessionNotification_unstable, + ) => Promise; +} + +export type GooseClientCallbacks = Client & GooseExtNotifications; + +export function installGooseExtNotificationDispatcher( + callbacks: GooseClientCallbacks, +): Client { + const { unstable_sessionUpdate, ...rest } = callbacks; + const userExtNotification = rest.extNotification; + return { + ...rest, + extNotification: async (method, params) => { + switch (method) { + case "_goose/unstable/session/update": { + const parsed = zGooseSessionNotification_unstable.parse( + params, + ) as GooseSessionNotification_unstable; + await unstable_sessionUpdate?.(parsed); + return; + } + default: + await userExtNotification?.(method, params); + return; + } + }, + }; +} diff --git a/ui/sdk/src/generated/index.ts b/ui/sdk/src/generated/index.ts index 85a63bc04964..6e3488320a5e 100644 --- a/ui/sdk/src/generated/index.ts +++ b/ui/sdk/src/generated/index.ts @@ -1,6 +1,6 @@ // This file is auto-generated by @hey-api/openapi-ts -export type { AddConfigExtensionRequest_unstable, AddExtensionRequest_unstable, ArchiveSessionRequest_unstable, CreateSourceRequest_unstable, CreateSourceResponse_unstable, CustomProviderConfigDto, CustomProviderCreateRequest_unstable, CustomProviderCreateResponse_unstable, CustomProviderDeleteRequest_unstable, CustomProviderDeleteResponse_unstable, CustomProviderReadRequest_unstable, CustomProviderReadResponse_unstable, CustomProviderUpdateRequest_unstable, CustomProviderUpdateResponse_unstable, DefaultsReadRequest_unstable, DefaultsReadResponse_unstable, DefaultsSaveRequest_unstable, DeleteSessionRequest, DeleteSourceRequest_unstable, DictationConfigRequest_unstable, DictationConfigResponse_unstable, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest_unstable, DictationModelDeleteRequest_unstable, DictationModelDownloadProgressRequest_unstable, DictationModelDownloadProgressResponse_unstable, DictationModelDownloadRequest_unstable, DictationModelOption, DictationModelSelectRequest_unstable, DictationModelsListRequest_unstable, DictationModelsListResponse_unstable, DictationProviderStatusEntry, DictationSecretDeleteRequest_unstable, DictationSecretSaveRequest_unstable, DictationTranscribeRequest_unstable, DictationTranscribeResponse_unstable, EmptyResponse, ExportSessionRequest_unstable, ExportSessionResponse_unstable, ExportSourceRequest_unstable, ExportSourceResponse_unstable, ExtRequest, ExtResponse, GetExtensionsRequest_unstable, GetExtensionsResponse_unstable, GetSessionExtensionsRequest_unstable, GetSessionExtensionsResponse_unstable, GetToolsRequest_unstable, GetToolsResponse_unstable, GooseToolCallRequest_unstable, GooseToolCallResponse_unstable, ImportSessionRequest_unstable, ImportSessionResponse_unstable, ImportSourcesRequest_unstable, ImportSourcesResponse_unstable, ListProvidersRequest_unstable, ListProvidersResponse_unstable, ListSourcesRequest_unstable, ListSourcesResponse_unstable, OnboardingImportApplyRequest_unstable, OnboardingImportApplyResponse_unstable, OnboardingImportCandidate, OnboardingImportCounts, OnboardingImportScanRequest_unstable, OnboardingImportScanResponse_unstable, OnboardingImportSourceKind, PreferenceKey, PreferencesReadRequest_unstable, PreferencesReadResponse_unstable, PreferencesRemoveRequest_unstable, PreferencesSaveRequest_unstable, PreferenceValue, ProviderCatalogListRequest_unstable, ProviderCatalogListResponse_unstable, ProviderCatalogTemplateRequest_unstable, ProviderCatalogTemplateResponse_unstable, ProviderConfigAuthenticateRequest_unstable, ProviderConfigChangeResponse_unstable, ProviderConfigDeleteRequest_unstable, ProviderConfigFieldUpdate, ProviderConfigFieldValueDto, ProviderConfigKey, ProviderConfigReadRequest_unstable, ProviderConfigReadResponse_unstable, ProviderConfigSaveRequest_unstable, ProviderConfigStatusDto, ProviderConfigStatusRequest_unstable, ProviderConfigStatusResponse_unstable, ProviderInventoryEntryDto, ProviderInventoryModelDto, ProviderSetupCatalogEntryDto, ProviderSetupCatalogListRequest_unstable, ProviderSetupCatalogListResponse_unstable, ProviderSetupCategoryDto, ProviderSetupFieldDto, ProviderSetupGroupDto, ProviderSetupMethodDto, ProviderSupportedModelsListRequest_unstable, ProviderSupportedModelsListResponse_unstable, ProviderTemplateCapabilitiesDto, ProviderTemplateCatalogEntryDto, ProviderTemplateDto, ProviderTemplateModelDto, ReadResourceRequest_unstable, ReadResourceResponse_unstable, RefreshProviderInventoryRequest_unstable, RefreshProviderInventoryResponse_unstable, RefreshProviderInventorySkipDto, RefreshProviderInventorySkipReasonDto, RemoveConfigExtensionRequest_unstable, RemoveExtensionRequest_unstable, RenameSessionRequest_unstable, SessionSystemPromptMode, SetSessionSystemPromptRequest_unstable, SourceEntry, SourceScope, SourceType, ToggleConfigExtensionRequest_unstable, UnarchiveSessionRequest_unstable, UpdateSessionProjectRequest_unstable, UpdateSourceRequest_unstable, UpdateSourceResponse_unstable, UpdateWorkingDirRequest_unstable } from './types.gen.js'; +export type { AddConfigExtensionRequest_unstable, AddExtensionRequest_unstable, ArchiveSessionRequest_unstable, CreateSourceRequest_unstable, CreateSourceResponse_unstable, CustomProviderConfigDto, CustomProviderCreateRequest_unstable, CustomProviderCreateResponse_unstable, CustomProviderDeleteRequest_unstable, CustomProviderDeleteResponse_unstable, CustomProviderReadRequest_unstable, CustomProviderReadResponse_unstable, CustomProviderUpdateRequest_unstable, CustomProviderUpdateResponse_unstable, DefaultsReadRequest_unstable, DefaultsReadResponse_unstable, DefaultsSaveRequest_unstable, DeleteSessionRequest, DeleteSourceRequest_unstable, DictationConfigRequest_unstable, DictationConfigResponse_unstable, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest_unstable, DictationModelDeleteRequest_unstable, DictationModelDownloadProgressRequest_unstable, DictationModelDownloadProgressResponse_unstable, DictationModelDownloadRequest_unstable, DictationModelOption, DictationModelSelectRequest_unstable, DictationModelsListRequest_unstable, DictationModelsListResponse_unstable, DictationProviderStatusEntry, DictationSecretDeleteRequest_unstable, DictationSecretSaveRequest_unstable, DictationTranscribeRequest_unstable, DictationTranscribeResponse_unstable, ElicitationRespondRequest_unstable, EmptyResponse, ExportSessionRequest_unstable, ExportSessionResponse_unstable, ExportSourceRequest_unstable, ExportSourceResponse_unstable, ExtNotification, ExtRequest, ExtResponse, GetExtensionsRequest_unstable, GetExtensionsResponse_unstable, GetSessionExtensionsRequest_unstable, GetSessionExtensionsResponse_unstable, GetToolsRequest_unstable, GetToolsResponse_unstable, GooseSessionNotification_unstable, GooseSessionUpdate, GooseToolCallRequest_unstable, GooseToolCallResponse_unstable, ImportSessionRequest_unstable, ImportSessionResponse_unstable, ImportSourcesRequest_unstable, ImportSourcesResponse_unstable, Interaction, InteractionState, InteractionUpdate, ListProvidersRequest_unstable, ListProvidersResponse_unstable, ListSourcesRequest_unstable, ListSourcesResponse_unstable, OnboardingImportApplyRequest_unstable, OnboardingImportApplyResponse_unstable, OnboardingImportCandidate, OnboardingImportCounts, OnboardingImportScanRequest_unstable, OnboardingImportScanResponse_unstable, OnboardingImportSourceKind, PreferenceKey, PreferencesReadRequest_unstable, PreferencesReadResponse_unstable, PreferencesRemoveRequest_unstable, PreferencesSaveRequest_unstable, PreferenceValue, ProviderCatalogListRequest_unstable, ProviderCatalogListResponse_unstable, ProviderCatalogTemplateRequest_unstable, ProviderCatalogTemplateResponse_unstable, ProviderConfigAuthenticateRequest_unstable, ProviderConfigChangeResponse_unstable, ProviderConfigDeleteRequest_unstable, ProviderConfigFieldUpdate, ProviderConfigFieldValueDto, ProviderConfigKey, ProviderConfigReadRequest_unstable, ProviderConfigReadResponse_unstable, ProviderConfigSaveRequest_unstable, ProviderConfigStatusDto, ProviderConfigStatusRequest_unstable, ProviderConfigStatusResponse_unstable, ProviderInventoryEntryDto, ProviderInventoryModelDto, ProviderSetupCatalogEntryDto, ProviderSetupCatalogListRequest_unstable, ProviderSetupCatalogListResponse_unstable, ProviderSetupCategoryDto, ProviderSetupFieldDto, ProviderSetupGroupDto, ProviderSetupMethodDto, ProviderSupportedModelsListRequest_unstable, ProviderSupportedModelsListResponse_unstable, ProviderTemplateCapabilitiesDto, ProviderTemplateCatalogEntryDto, ProviderTemplateDto, ProviderTemplateModelDto, ReadResourceRequest_unstable, ReadResourceResponse_unstable, RefreshProviderInventoryRequest_unstable, RefreshProviderInventoryResponse_unstable, RefreshProviderInventorySkipDto, RefreshProviderInventorySkipReasonDto, RemoveConfigExtensionRequest_unstable, RemoveExtensionRequest_unstable, RenameSessionRequest_unstable, SessionSystemPromptMode, SessionUsageUpdate, SetSessionSystemPromptRequest_unstable, SourceEntry, SourceScope, SourceType, StatusMessage, StatusMessageUpdate, ToggleConfigExtensionRequest_unstable, UnarchiveSessionRequest_unstable, UpdateSessionProjectRequest_unstable, UpdateSourceRequest_unstable, UpdateSourceResponse_unstable, UpdateWorkingDirRequest_unstable } from './types.gen.js'; export const GOOSE_EXT_METHODS = [ { @@ -188,6 +188,11 @@ export const GOOSE_EXT_METHODS = [ requestType: "ImportSessionRequest_unstable", responseType: "ImportSessionResponse_unstable", }, + { + method: "_goose/unstable/elicitation/respond", + requestType: "ElicitationRespondRequest_unstable", + responseType: "EmptyResponse", + }, { method: "_goose/unstable/session/project/update", requestType: "UpdateSessionProjectRequest_unstable", @@ -291,3 +296,12 @@ export const GOOSE_EXT_METHODS = [ ] as const; export type GooseExtMethod = (typeof GOOSE_EXT_METHODS)[number]; + +export const GOOSE_EXT_NOTIFICATIONS = [ + { + method: "_goose/unstable/session/update", + paramsType: "GooseSessionNotification_unstable", + }, +] as const; + +export type GooseExtNotification = (typeof GOOSE_EXT_NOTIFICATIONS)[number]; diff --git a/ui/sdk/src/generated/types.gen.ts b/ui/sdk/src/generated/types.gen.ts index 410f74ada5be..cfec115f598e 100644 --- a/ui/sdk/src/generated/types.gen.ts +++ b/ui/sdk/src/generated/types.gen.ts @@ -743,6 +743,15 @@ export type ImportSessionResponse_unstable = { messageCount: number; }; +/** + * Submit a response for a pending MCP elicitation in an active session. + */ +export type ElicitationRespondRequest_unstable = { + sessionId: string; + elicitationId: string; + userData?: unknown; +}; + /** * Update the project association for a session. */ @@ -1087,10 +1096,77 @@ export type DictationModelSelectRequest_unstable = { modelId: string; }; +/** + * Goose-custom session update notification — a parallel to ACP's + * `session/update` carrying goose-specific update variants. + */ +export type GooseSessionNotification_unstable = { + sessionId: string; + update: GooseSessionUpdate; +}; + +/** + * Discriminated union of goose-specific session update payloads. + * Variant tag matches ACP's convention (`sessionUpdate: ""`). + * + * `discriminator.mapping` is what makes TS codegen (`@hey-api/openapi-ts`) + * emit the correct snake_case tag value even when this enum has a single + * variant. Add a mapping entry per variant. + */ +export type GooseSessionUpdate = ({ + sessionUpdate: 'usage_update'; +} & SessionUsageUpdate) | ({ + sessionUpdate: 'status_message'; +} & StatusMessageUpdate) | ({ + sessionUpdate: 'interaction_update'; +} & InteractionUpdate); + +/** + * Streaming context-window usage update for a session. + */ +export type SessionUsageUpdate = { + used: number; + contextLimit: number; + accumulatedInputTokens: number; + accumulatedOutputTokens: number; + accumulatedCost?: number | null; +}; + +export type StatusMessage = { + message: string; + type: 'notice'; +} | { + message: string; + type: 'progress'; +}; + +/** + * Live UI/session status. This is not conversation transcript content, and + * should not be persisted or replayed as history. + */ +export type StatusMessageUpdate = { + status: StatusMessage; +}; + +export type Interaction = { + id: string; + state: InteractionState; + message?: string | null; + requestedSchema?: unknown; + type: 'elicitation'; +}; + +export type InteractionState = 'pending' | 'submitted'; + +export type InteractionUpdate = { + interaction: Interaction; + _meta?: unknown; +}; + export type ExtRequest = { id: string; method: string; - params?: AddExtensionRequest_unstable | RemoveExtensionRequest_unstable | GetToolsRequest_unstable | GooseToolCallRequest_unstable | ReadResourceRequest_unstable | UpdateWorkingDirRequest_unstable | SetSessionSystemPromptRequest_unstable | DeleteSessionRequest | GetExtensionsRequest_unstable | AddConfigExtensionRequest_unstable | RemoveConfigExtensionRequest_unstable | ToggleConfigExtensionRequest_unstable | GetSessionExtensionsRequest_unstable | ListProvidersRequest_unstable | ProviderSupportedModelsListRequest_unstable | ProviderCatalogListRequest_unstable | ProviderSetupCatalogListRequest_unstable | ProviderCatalogTemplateRequest_unstable | CustomProviderCreateRequest_unstable | CustomProviderReadRequest_unstable | CustomProviderUpdateRequest_unstable | CustomProviderDeleteRequest_unstable | RefreshProviderInventoryRequest_unstable | ProviderConfigReadRequest_unstable | ProviderConfigStatusRequest_unstable | ProviderConfigSaveRequest_unstable | ProviderConfigDeleteRequest_unstable | ProviderConfigAuthenticateRequest_unstable | PreferencesReadRequest_unstable | PreferencesSaveRequest_unstable | PreferencesRemoveRequest_unstable | DefaultsReadRequest_unstable | DefaultsSaveRequest_unstable | OnboardingImportScanRequest_unstable | OnboardingImportApplyRequest_unstable | ExportSessionRequest_unstable | ImportSessionRequest_unstable | UpdateSessionProjectRequest_unstable | RenameSessionRequest_unstable | ArchiveSessionRequest_unstable | UnarchiveSessionRequest_unstable | CreateSourceRequest_unstable | ListSourcesRequest_unstable | UpdateSourceRequest_unstable | DeleteSourceRequest_unstable | ExportSourceRequest_unstable | ImportSourcesRequest_unstable | DictationTranscribeRequest_unstable | DictationConfigRequest_unstable | DictationSecretSaveRequest_unstable | DictationSecretDeleteRequest_unstable | DictationModelsListRequest_unstable | DictationModelDownloadRequest_unstable | DictationModelDownloadProgressRequest_unstable | DictationModelCancelRequest_unstable | DictationModelDeleteRequest_unstable | DictationModelSelectRequest_unstable | { + params?: AddExtensionRequest_unstable | RemoveExtensionRequest_unstable | GetToolsRequest_unstable | GooseToolCallRequest_unstable | ReadResourceRequest_unstable | UpdateWorkingDirRequest_unstable | SetSessionSystemPromptRequest_unstable | DeleteSessionRequest | GetExtensionsRequest_unstable | AddConfigExtensionRequest_unstable | RemoveConfigExtensionRequest_unstable | ToggleConfigExtensionRequest_unstable | GetSessionExtensionsRequest_unstable | ListProvidersRequest_unstable | ProviderSupportedModelsListRequest_unstable | ProviderCatalogListRequest_unstable | ProviderSetupCatalogListRequest_unstable | ProviderCatalogTemplateRequest_unstable | CustomProviderCreateRequest_unstable | CustomProviderReadRequest_unstable | CustomProviderUpdateRequest_unstable | CustomProviderDeleteRequest_unstable | RefreshProviderInventoryRequest_unstable | ProviderConfigReadRequest_unstable | ProviderConfigStatusRequest_unstable | ProviderConfigSaveRequest_unstable | ProviderConfigDeleteRequest_unstable | ProviderConfigAuthenticateRequest_unstable | PreferencesReadRequest_unstable | PreferencesSaveRequest_unstable | PreferencesRemoveRequest_unstable | DefaultsReadRequest_unstable | DefaultsSaveRequest_unstable | OnboardingImportScanRequest_unstable | OnboardingImportApplyRequest_unstable | ExportSessionRequest_unstable | ImportSessionRequest_unstable | ElicitationRespondRequest_unstable | UpdateSessionProjectRequest_unstable | RenameSessionRequest_unstable | ArchiveSessionRequest_unstable | UnarchiveSessionRequest_unstable | CreateSourceRequest_unstable | ListSourcesRequest_unstable | UpdateSourceRequest_unstable | DeleteSourceRequest_unstable | ExportSourceRequest_unstable | ImportSourcesRequest_unstable | DictationTranscribeRequest_unstable | DictationConfigRequest_unstable | DictationSecretSaveRequest_unstable | DictationSecretDeleteRequest_unstable | DictationModelsListRequest_unstable | DictationModelDownloadRequest_unstable | DictationModelDownloadProgressRequest_unstable | DictationModelCancelRequest_unstable | DictationModelDeleteRequest_unstable | DictationModelSelectRequest_unstable | { [key: string]: unknown; } | null; }; @@ -1106,3 +1182,10 @@ export type ExtResponse = { }; id: string; }; + +export type ExtNotification = { + method: string; + params?: GooseSessionNotification_unstable | { + [key: string]: unknown; + } | null; +}; diff --git a/ui/sdk/src/generated/zod.gen.ts b/ui/sdk/src/generated/zod.gen.ts index efebe29bc0d9..0597c7f716d0 100644 --- a/ui/sdk/src/generated/zod.gen.ts +++ b/ui/sdk/src/generated/zod.gen.ts @@ -760,6 +760,15 @@ export const zImportSessionResponse_unstable = z.object({ messageCount: z.number().int().gte(0) }); +/** + * Submit a response for a pending MCP elicitation in an active session. + */ +export const zElicitationRespondRequest_unstable = z.object({ + sessionId: z.string(), + elicitationId: z.string(), + userData: z.unknown().optional().default(null) +}); + /** * Update the project association for a session. */ @@ -1088,6 +1097,86 @@ export const zDictationModelSelectRequest_unstable = z.object({ modelId: z.string() }); +/** + * Streaming context-window usage update for a session. + */ +export const zSessionUsageUpdate = z.object({ + used: z.number().int().gte(0), + contextLimit: z.number().int().gte(0), + accumulatedInputTokens: z.number().int().gte(0), + accumulatedOutputTokens: z.number().int().gte(0), + accumulatedCost: z.union([ + z.number(), + z.null() + ]).optional() +}); + +export const zStatusMessage = z.union([ + z.object({ + message: z.string(), + type: z.literal('notice') + }), + z.object({ + message: z.string(), + type: z.literal('progress') + }) +]); + +/** + * Live UI/session status. This is not conversation transcript content, and + * should not be persisted or replayed as history. + */ +export const zStatusMessageUpdate = z.object({ + status: zStatusMessage +}); + +export const zInteractionState = z.enum(['pending', 'submitted']); + +export const zInteraction = z.object({ + id: z.string(), + state: zInteractionState, + message: z.union([ + z.string(), + z.null() + ]).optional(), + requestedSchema: z.unknown().optional(), + type: z.literal('elicitation') +}); + +export const zInteractionUpdate = z.object({ + interaction: zInteraction, + _meta: z.unknown().optional() +}); + +/** + * Discriminated union of goose-specific session update payloads. + * Variant tag matches ACP's convention (`sessionUpdate: ""`). + * + * `discriminator.mapping` is what makes TS codegen (`@hey-api/openapi-ts`) + * emit the correct snake_case tag value even when this enum has a single + * variant. Add a mapping entry per variant. + */ +export const zGooseSessionUpdate = z.union([ + z.object({ + sessionUpdate: z.literal('usage_update') + }).and(zSessionUsageUpdate), + z.object({ + sessionUpdate: z.literal('status_message') + }).and(zStatusMessageUpdate), + z.object({ + sessionUpdate: z.literal('interaction_update') + }).and(zInteractionUpdate) +]); + +/** + * Goose-custom session update notification — a parallel to ACP's + * `session/update` carrying goose-specific update variants. + */ +export const zGooseSessionNotification_unstable = z.object({ + sessionId: z.string(), + update: zGooseSessionUpdate +}); + export const zExtRequest = z.object({ id: z.string(), method: z.string(), @@ -1130,6 +1219,7 @@ export const zExtRequest = z.object({ zOnboardingImportApplyRequest_unstable, zExportSessionRequest_unstable, zImportSessionRequest_unstable, + zElicitationRespondRequest_unstable, zUpdateSessionProjectRequest_unstable, zRenameSessionRequest_unstable, zArchiveSessionRequest_unstable, @@ -1210,3 +1300,14 @@ export const zExtResponse = z.union([ id: z.string() }) ]); + +export const zExtNotification = z.object({ + method: z.string(), + params: z.union([ + zGooseSessionNotification_unstable, + z.union([ + z.record(z.unknown()), + z.null() + ]) + ]).optional() +}); diff --git a/ui/sdk/src/goose-client.ts b/ui/sdk/src/goose-client.ts index c697dfccad0b..f969f7a570be 100644 --- a/ui/sdk/src/goose-client.ts +++ b/ui/sdk/src/goose-client.ts @@ -1,6 +1,5 @@ import { ClientSideConnection, - type Client, type Stream, type InitializeRequest, type InitializeResponse, @@ -26,19 +25,28 @@ import { type SetSessionModelRequest, type SetSessionModelResponse, } from "@agentclientprotocol/sdk"; -import { GooseExtClient } from "./generated/client.gen.js"; +import { + GooseExtClient, + installGooseExtNotificationDispatcher, + type GooseClientCallbacks, +} from "./generated/client.gen.js"; import { createHttpStream } from "./http-stream.js"; export class GooseClient { private conn: ClientSideConnection; private ext: GooseExtClient; - constructor(toClient: () => Client, streamOrUrl: Stream | string) { + constructor( + toClient: () => GooseClientCallbacks, + streamOrUrl: Stream | string, + ) { const stream = typeof streamOrUrl === "string" ? createHttpStream(streamOrUrl) : streamOrUrl; - this.conn = new ClientSideConnection(toClient, stream); + const toAcpClient = () => + installGooseExtNotificationDispatcher(toClient()); + this.conn = new ClientSideConnection(toAcpClient, stream); this.ext = new GooseExtClient(this.conn); } diff --git a/ui/sdk/src/index.ts b/ui/sdk/src/index.ts index 5e587ed92e89..aa4cbe960f54 100644 --- a/ui/sdk/src/index.ts +++ b/ui/sdk/src/index.ts @@ -1,5 +1,9 @@ export * from "./generated/types.gen.js"; export * from "./generated/zod.gen.js"; +export { + type GooseClientCallbacks, + type GooseExtNotifications, +} from "./generated/client.gen.js"; export { GooseClient } from "./goose-client.js"; export { createHttpStream } from "./http-stream.js"; export * from "./mcp-apps.js";