Skip to content

Commit 291a7a7

Browse files
committed
fix(identity): stop collapsing distinct workspace accountIds
1 parent a98e609 commit 291a7a7

File tree

6 files changed

+435
-83
lines changed

6 files changed

+435
-83
lines changed

index.ts

Lines changed: 149 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -594,17 +594,87 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
594594
};
595595
};
596596

597+
const normalizeStoredAccountId = (
598+
account: { accountId?: string } | undefined,
599+
): string | undefined => {
600+
const accountId = account?.accountId?.trim();
601+
return accountId && accountId.length > 0 ? accountId : undefined;
602+
};
603+
604+
const hasDistinctNonEmptyAccountIds = (
605+
left: { accountId?: string } | undefined,
606+
right: { accountId?: string } | undefined,
607+
): boolean => {
608+
const leftId = normalizeStoredAccountId(left);
609+
const rightId = normalizeStoredAccountId(right);
610+
return !!leftId && !!rightId && leftId !== rightId;
611+
};
612+
613+
const canCollapseWithCandidateAccountId = (
614+
existing: { accountId?: string } | undefined,
615+
candidateAccountId: string | undefined,
616+
): boolean => {
617+
const existingAccountId = normalizeStoredAccountId(existing);
618+
const normalizedCandidate = candidateAccountId?.trim() || undefined;
619+
if (!existingAccountId || !normalizedCandidate) {
620+
return true;
621+
}
622+
return existingAccountId === normalizedCandidate;
623+
};
624+
597625

598626
type IdentityIndexes = {
599627
byOrganizationId: Map<string, number>;
600628
byAccountIdNoOrg: Map<string, number>;
601-
byRefreshTokenNoOrg: Map<string, number>;
629+
byRefreshTokenNoOrg: Map<string, number[]>;
602630
byEmailNoOrg: Map<string, number>;
603631
byAccountIdOrgScoped: Map<string, number[]>;
604632
byRefreshTokenOrgScoped: Map<string, number[]>;
605633
byRefreshTokenGlobal: Map<string, number[]>;
606634
};
607635

636+
const pickNewestFromIndices = (indices: number[]): number | undefined => {
637+
if (indices.length === 0) return undefined;
638+
const first = indices[0];
639+
if (typeof first !== "number") return undefined;
640+
let newestIndex = first;
641+
for (let i = 1; i < indices.length; i += 1) {
642+
const candidate = indices[i];
643+
if (typeof candidate !== "number") continue;
644+
newestIndex = pickNewestAccountIndex(newestIndex, candidate);
645+
}
646+
return newestIndex;
647+
};
648+
649+
const resolveNoOrgRefreshMatch = (
650+
indexes: IdentityIndexes,
651+
refreshToken: string,
652+
candidateAccountId: string | undefined,
653+
): number | undefined => {
654+
const candidateId = candidateAccountId?.trim() || undefined;
655+
const matches = indexes.byRefreshTokenNoOrg.get(refreshToken);
656+
if (!matches || matches.length === 0) return undefined;
657+
658+
const withNoAccountId = matches.filter((index) => {
659+
const existing = accounts[index];
660+
return !normalizeStoredAccountId(existing);
661+
});
662+
663+
if (!candidateId) {
664+
return pickNewestFromIndices(withNoAccountId);
665+
}
666+
667+
const exactMatches = matches.filter((index) => {
668+
const existing = accounts[index];
669+
return normalizeStoredAccountId(existing) === candidateId;
670+
});
671+
if (exactMatches.length > 0) {
672+
return pickNewestFromIndices(exactMatches);
673+
}
674+
675+
return pickNewestFromIndices(withNoAccountId);
676+
};
677+
608678
const resolveUniqueOrgScopedMatch = (
609679
indexes: IdentityIndexes,
610680
accountId: string | undefined,
@@ -625,7 +695,7 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
625695
const buildIdentityIndexes = (): IdentityIndexes => {
626696
const byOrganizationId = new Map<string, number>();
627697
const byAccountIdNoOrg = new Map<string, number>();
628-
const byRefreshTokenNoOrg = new Map<string, number>();
698+
const byRefreshTokenNoOrg = new Map<string, number[]>();
629699
const byEmailNoOrg = new Map<string, number>();
630700
const byAccountIdOrgScoped = new Map<string, number[]>();
631701
const byRefreshTokenOrgScoped = new Map<string, number[]>();
@@ -661,7 +731,7 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
661731
byAccountIdNoOrg.set(accountId, i);
662732
}
663733
if (refreshToken) {
664-
byRefreshTokenNoOrg.set(refreshToken, i);
734+
pushIndex(byRefreshTokenNoOrg, refreshToken, i);
665735
}
666736
if (email) {
667737
byEmailNoOrg.set(email, i);
@@ -704,17 +774,21 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
704774
}
705775
}
706776

707-
const byRefreshToken = identityIndexes.byRefreshTokenNoOrg.get(result.refresh);
777+
const byRefreshToken = resolveNoOrgRefreshMatch(
778+
identityIndexes,
779+
result.refresh,
780+
normalizedAccountId,
781+
);
708782
if (byRefreshToken !== undefined) {
709783
return byRefreshToken;
710784
}
711785

712-
if (accountEmail) {
713-
const byEmail = identityIndexes.byEmailNoOrg.get(accountEmail);
714-
if (byEmail !== undefined) {
715-
return byEmail;
716-
}
786+
if (accountEmail && !normalizedAccountId) {
787+
const byEmail = identityIndexes.byEmailNoOrg.get(accountEmail);
788+
if (byEmail !== undefined) {
789+
return byEmail;
717790
}
791+
}
718792

719793
const orgScoped = resolveUniqueOrgScopedMatch(
720794
identityIndexes,
@@ -723,10 +797,20 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
723797
);
724798
if (orgScoped !== undefined) return orgScoped;
725799

726-
// Absolute last resort: only collapse when refresh token maps to a
727-
// single account. Avoids merging distinct org-scoped variants.
728-
return asUniqueIndex(identityIndexes.byRefreshTokenGlobal.get(result.refresh));
729-
})();
800+
// Absolute last resort: only collapse when refresh token maps to a
801+
// single compatible account. Avoids merging distinct workspace variants.
802+
const globalRefreshMatch = asUniqueIndex(
803+
identityIndexes.byRefreshTokenGlobal.get(result.refresh),
804+
);
805+
if (globalRefreshMatch === undefined) {
806+
return undefined;
807+
}
808+
const existing = accounts[globalRefreshMatch];
809+
if (!canCollapseWithCandidateAccountId(existing, normalizedAccountId)) {
810+
return undefined;
811+
}
812+
return globalRefreshMatch;
813+
})();
730814

731815
if (existingIndex === undefined) {
732816
accounts.push({
@@ -784,7 +868,12 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
784868
const indicesToRemove = new Set<number>();
785869
const refreshMap = new Map<
786870
string,
787-
{ byOrg: Map<string, number>; preferredOrgIndex?: number; fallbackIndex?: number }
871+
{
872+
byOrg: Map<string, number>;
873+
preferredOrgIndex?: number;
874+
fallbackNoAccountIdIndex?: number;
875+
fallbackByAccountId: Map<string, number>;
876+
}
788877
>();
789878

790879
const pickPreferredOrgIndex = (
@@ -798,25 +887,37 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
798887
const collapseFallbackIntoPreferredOrg = (entry: {
799888
byOrg: Map<string, number>;
800889
preferredOrgIndex?: number;
801-
fallbackIndex?: number;
890+
fallbackNoAccountIdIndex?: number;
891+
fallbackByAccountId: Map<string, number>;
802892
}): void => {
803-
if (entry.preferredOrgIndex === undefined || entry.fallbackIndex === undefined) {
893+
if (entry.preferredOrgIndex === undefined) {
804894
return;
805895
}
806896

807897
const preferredOrgIndex = entry.preferredOrgIndex;
808-
const fallbackIndex = entry.fallbackIndex;
809-
if (preferredOrgIndex === fallbackIndex) {
810-
entry.fallbackIndex = undefined;
811-
return;
812-
}
813-
814-
const target = accounts[preferredOrgIndex];
815-
const source = accounts[fallbackIndex];
816-
if (target && source) {
898+
const collapseFallbackIndex = (fallbackIndex: number): boolean => {
899+
if (preferredOrgIndex === fallbackIndex) return true;
900+
const target = accounts[preferredOrgIndex];
901+
const source = accounts[fallbackIndex];
902+
if (!target || !source) return true;
903+
if (hasDistinctNonEmptyAccountIds(target, source)) {
904+
return false;
905+
}
817906
mergeAccountRecords(preferredOrgIndex, fallbackIndex);
818907
indicesToRemove.add(fallbackIndex);
819-
entry.fallbackIndex = undefined;
908+
return true;
909+
};
910+
911+
if (typeof entry.fallbackNoAccountIdIndex === "number") {
912+
if (collapseFallbackIndex(entry.fallbackNoAccountIdIndex)) {
913+
entry.fallbackNoAccountIdIndex = undefined;
914+
}
915+
}
916+
917+
for (const [accountId, fallbackIndex] of entry.fallbackByAccountId) {
918+
if (collapseFallbackIndex(fallbackIndex)) {
919+
entry.fallbackByAccountId.delete(accountId);
920+
}
820921
}
821922
};
822923

@@ -831,7 +932,8 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
831932
entry = {
832933
byOrg: new Map<string, number>(),
833934
preferredOrgIndex: undefined,
834-
fallbackIndex: undefined,
935+
fallbackNoAccountIdIndex: undefined,
936+
fallbackByAccountId: new Map<string, number>(),
835937
};
836938
refreshMap.set(refreshToken, entry);
837939
}
@@ -854,17 +956,34 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
854956
continue;
855957
}
856958

857-
const existingFallback = entry.fallbackIndex;
959+
const fallbackAccountId = normalizeStoredAccountId(account);
960+
if (fallbackAccountId) {
961+
const existingFallback = entry.fallbackByAccountId.get(fallbackAccountId);
962+
if (typeof existingFallback === "number") {
963+
const newestIndex = pickNewestAccountIndex(existingFallback, i);
964+
const obsoleteIndex = newestIndex === existingFallback ? i : existingFallback;
965+
mergeAccountRecords(newestIndex, obsoleteIndex);
966+
indicesToRemove.add(obsoleteIndex);
967+
entry.fallbackByAccountId.set(fallbackAccountId, newestIndex);
968+
collapseFallbackIntoPreferredOrg(entry);
969+
continue;
970+
}
971+
entry.fallbackByAccountId.set(fallbackAccountId, i);
972+
collapseFallbackIntoPreferredOrg(entry);
973+
continue;
974+
}
975+
976+
const existingFallback = entry.fallbackNoAccountIdIndex;
858977
if (typeof existingFallback === "number") {
859978
const newestIndex = pickNewestAccountIndex(existingFallback, i);
860979
const obsoleteIndex = newestIndex === existingFallback ? i : existingFallback;
861980
mergeAccountRecords(newestIndex, obsoleteIndex);
862981
indicesToRemove.add(obsoleteIndex);
863-
entry.fallbackIndex = newestIndex;
982+
entry.fallbackNoAccountIdIndex = newestIndex;
864983
collapseFallbackIntoPreferredOrg(entry);
865984
continue;
866985
}
867-
entry.fallbackIndex = i;
986+
entry.fallbackNoAccountIdIndex = i;
868987
collapseFallbackIntoPreferredOrg(entry);
869988
}
870989

lib/auth/token-utils.ts

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,37 @@ function extractOrganizationIdsByIndex(value: unknown): Array<string | undefined
226226
});
227227
}
228228

229+
function extractPrimaryAuthClaimOrganizationId(
230+
payload: JWTPayload | Record<string, unknown> | null,
231+
): string | undefined {
232+
if (!payload || !isRecord(payload)) return undefined;
233+
const auth = payload[JWT_CLAIM_PATH];
234+
if (!isRecord(auth)) return undefined;
235+
const organizations = normalizeCandidateArray(auth.organizations);
236+
if (organizations.length === 0) return undefined;
237+
const firstOrganization = organizations[0];
238+
if (!isRecord(firstOrganization)) return undefined;
239+
return toStringValue(firstOrganization.id);
240+
}
241+
229242
function extractCanonicalOrganizationIds(
230243
payload: JWTPayload | Record<string, unknown> | null,
231244
): Array<string | undefined> {
232245
if (!payload || !isRecord(payload)) return [];
233246
const auth = payload[JWT_CLAIM_PATH];
234247
if (!isRecord(auth)) return [];
235-
return extractOrganizationIdsByIndex(auth.organizations);
248+
249+
const organizationIds = extractOrganizationIdsByIndex(auth.organizations);
250+
if (organizationIds.length === 0) return organizationIds;
251+
252+
// Authoritative source: idToken['https://api.openai.com/auth'].organizations[0].id
253+
// Only fall back to broader field extraction when this exact path is missing.
254+
const primaryOrganizationId = extractPrimaryAuthClaimOrganizationId(payload);
255+
if (primaryOrganizationId) {
256+
organizationIds[0] = primaryOrganizationId;
257+
}
258+
259+
return organizationIds;
236260
}
237261

238262
function resolveOrganizationOverridesForKey(
@@ -443,12 +467,12 @@ export function getAccountIdCandidates(
443467

444468
if (idToken) {
445469
const decoded = decodeJWT(idToken);
446-
const canonicalOrganizationIds = extractCanonicalOrganizationIds(decoded);
470+
const primaryOrganizationId = extractPrimaryAuthClaimOrganizationId(decoded);
447471
const idAccountId = extractAccountIdFromPayload(decoded);
448472
if (idAccountId && idAccountId !== accessId) {
449473
candidates.push({
450474
accountId: idAccountId,
451-
organizationId: canonicalOrganizationIds[0],
475+
organizationId: primaryOrganizationId,
452476
label: formatTokenCandidateLabel("ID token account", idAccountId),
453477
source: "id_token",
454478
});

0 commit comments

Comments
 (0)