Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ export type MiddlewareRef = z.infer<typeof MiddlewareRefSchema>;
*/
export const ModelReferenceSchema = z.object({
name: z.string(),
configSchema: z.any().optional(),
info: z.any().optional(),
version: z.string().optional(),
config: z.any().optional(),
});

Expand Down
5 changes: 0 additions & 5 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -888,11 +888,6 @@
"name": {
"type": "string"
},
"configSchema": {},
"info": {},
"version": {
"type": "string"
},
"config": {}
},
"required": [
Expand Down
7 changes: 2 additions & 5 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,8 @@ const (
)

type ModelReference struct {
Config any `json:"config,omitempty"`
ConfigSchema any `json:"configSchema,omitempty"`
Info any `json:"info,omitempty"`
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Config any `json:"config,omitempty"`
Name string `json:"name,omitempty"`
}

// A ModelRequest is a request to generate completions from a model.
Expand Down
163 changes: 141 additions & 22 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ import {
shouldInjectFormatInstructions,
} from './generate/action.js';
import { GenerateResponseChunk } from './generate/chunk.js';
import {
GenerateMiddleware,
generateMiddleware,
GenerateMiddlewareDef,
resolveMiddleware,
} from './generate/middleware.js';
import { GenerateResponse } from './generate/response.js';
import { Message } from './message.js';
import {
Expand All @@ -51,6 +57,7 @@ import {
type GenerateRequest,
type GenerationCommonConfigSchema,
type MessageData,
type MiddlewareRef,
type ModelArgument,
type ModelMiddlewareArgument,
type Part,
Expand Down Expand Up @@ -172,7 +179,7 @@ export interface GenerateOptions<
*/
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddlewareArgument[];
use?: (ModelMiddlewareArgument | GenerateMiddleware | MiddlewareRef)[];
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
context?: ActionContext;
/** Abort signal for the generate request. */
Expand Down Expand Up @@ -362,6 +369,108 @@ function messagesFromOptions(options: GenerateOptions): MessageData[] {
/** A GenerationBlockedError is thrown when a generation is blocked. */
export class GenerationBlockedError extends GenerationResponseError {}

/**
* Normalizes a mix of middleware representations into an array of standardized `MiddlewareRef`s.
* Any raw functional middleware or unregistered middleware objects are dynamically registered
* into the provided registry.
*
* @param registry The registry to use for looking up or dynamically registering middleware.
* @param middlewareList An array of middleware functions, instances, or references.
* @returns A promise resolving to an array of normalized `MiddlewareRef` objects.
*/
export async function normalizeMiddleware(
registry: Registry,
middlewareList?: (
| ModelMiddlewareArgument
| GenerateMiddleware
| MiddlewareRef
)[]
): Promise<MiddlewareRef[]> {
if (!middlewareList || middlewareList.length === 0) {
return [];
}

const refs: MiddlewareRef[] = [];

for (let i = 0; i < middlewareList.length; i++) {
const middleware = middlewareList[i];

if (
typeof middleware === 'function' &&
(middleware as any).instantiate &&
(middleware as any).plugin
) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Middleware ${(middleware as any).name || 'function'} must be called with () when used in 'use' array.`,
});
}

if (typeof middleware === 'function') {
const name = `dynamic-middleware-${i}-${Math.random().toString(36).slice(2)}`;

const wrappedDef = generateMiddleware(
{ name, metadata: { dynamic: true } },
() => ({
model: async (req, ctx, next) => {
if (middleware.length === 3) {
return (middleware as any)(
req,
ctx,
async (modifiedReq: any, opts: any) =>
next(modifiedReq || req, opts || ctx)
);
} else {
return (middleware as any)(req, async (modifiedReq: any) =>
next(modifiedReq || req, ctx)
);
}
},
})
);
registry.registerValue('middleware', name, wrappedDef);
refs.push({ name });
continue;
}

if (
typeof middleware === 'object' &&
middleware !== null &&
'instantiate' in middleware &&
typeof middleware.instantiate === 'function'
) {
const def = middleware as GenerateMiddleware;
const registered = await registry.lookupValue<GenerateMiddleware>(
'middleware',
def.name
);
if (!registered) {
registry.registerValue('middleware', def.name, def);
}
refs.push({ name: def.name });
continue;
}

if (
typeof middleware === 'object' &&
middleware !== null &&
'name' in middleware
) {
const ref = middleware as MiddlewareRef & { __def?: GenerateMiddleware };
const registered = await registry.lookupValue<GenerateMiddleware>(
'middleware',
ref.name
);
if (!registered && ref.__def) {
registry.registerValue('middleware', ref.name, ref.__def);
}
refs.push({ name: ref.name, config: ref.config });
}
}

return refs;
}

/**
* Generate calls a generative model based on the provided prompt and configuration. If
* `history` is provided, the generation will include a conversation history in its
Expand All @@ -387,8 +496,16 @@ export async function generate<
};
const resolvedFormat = await resolveFormat(registry, resolvedOptions.output);

registry = maybeRegisterDynamicTools(registry, resolvedOptions);
registry = maybeRegisterDynamicResources(registry, resolvedOptions);
registry = Registry.withParent(registry);

maybeRegisterDynamicTools(registry, resolvedOptions);
maybeRegisterDynamicResources(registry, resolvedOptions);

const middlewareRefs = await normalizeMiddleware(
registry,
resolvedOptions.use
);
resolvedOptions.use = middlewareRefs; // Cast back because `use` can be generic

const params = await toGenerateActionOptions(registry, resolvedOptions);

Expand All @@ -400,10 +517,14 @@ export async function generate<
const streamingCallback = stripNoop(
resolvedOptions.onChunk ?? resolvedOptions.streamingCallback
) as StreamingCallback<GenerateResponseChunkData>;

const resolvedMiddleware = await resolveMiddleware(registry, middlewareRefs);
maybeRegisterDynamicMiddlewareTools(registry, resolvedMiddleware);

const response = await runWithContext(resolvedOptions.context, () =>
generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
middleware: resolvedMiddleware,
abortSignal: resolvedOptions.abortSignal,
streamingCallback,
})
Expand Down Expand Up @@ -451,46 +572,43 @@ export async function generateOperation<
return operation;
}

export function maybeRegisterDynamicMiddlewareTools(
registry: Registry,
middlewares?: GenerateMiddlewareDef[]
) {
middlewares?.forEach((mw) => {
mw.tools?.forEach((t) => {
if (isDynamicTool(t)) {
registry.registerAction('tool', t as Action);
}
});
});
}

function maybeRegisterDynamicTools<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(registry: Registry, options: GenerateOptions<O, CustomOptions>): Registry {
let hasDynamicTools = false;
>(registry: Registry, options: GenerateOptions<O, CustomOptions>) {
options?.tools?.forEach((t) => {
if (isDynamicTool(t)) {
if (!hasDynamicTools) {
hasDynamicTools = true;
// Create a temporary registry with dynamic tools for the duration of this
// generate request.
registry = Registry.withParent(registry);
}
if (isMultipartTool(t)) {
registry.registerAction('tool.v2', t);
} else {
registry.registerAction('tool', t as Action);
}
}
});
return registry;
}

function maybeRegisterDynamicResources<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(registry: Registry, options: GenerateOptions<O, CustomOptions>): Registry {
let hasDynamicResources = false;
>(registry: Registry, options: GenerateOptions<O, CustomOptions>) {
options?.resources?.forEach((r) => {
if (isDynamicResourceAction(r)) {
if (!hasDynamicResources) {
hasDynamicResources = true;
// Create a temporary registry with dynamic tools for the duration of this
// generate request.
registry = Registry.withParent(registry);
}
registry.registerAction('resource', r);
}
});
return registry;
}

export async function toGenerateActionOptions<
Expand Down Expand Up @@ -547,6 +665,7 @@ export async function toGenerateActionOptions<
returnToolRequests: options.returnToolRequests,
maxTurns: options.maxTurns,
stepName: options.stepName,
use: options.use as MiddlewareRef[] | undefined,
};
// if config is empty and it was not explicitly passed in, we delete it, don't want {}
if (Object.keys(params.config).length === 0 && !options.config) {
Expand Down
Loading
Loading