diff --git a/.github/workflows/qa.yml b/.github/workflows/qa.yml new file mode 100644 index 0000000..bb5d2de --- /dev/null +++ b/.github/workflows/qa.yml @@ -0,0 +1,33 @@ +name: QA Instructions + +on: + pull_request: + types: [opened, synchronize] + +permissions: + pull-requests: write + models: read + +jobs: + qa-instructions: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version-file: ".node-version" + cache: "npm" + + - name: Install dependencies + run: npm ci + + - name: Build + run: npm run build + + - name: Generate QA Instructions + uses: ./ + with: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index 25dee65..b7dd5d1 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,17 @@ [![CI](https://github.com/slifty/qa-instructions-action/actions/workflows/ci.yml/badge.svg)](https://github.com/slifty/qa-instructions-action/actions/workflows/ci.yml) -A GitHub Action that automatically generates QA testing instructions for pull requests using Claude. On each PR push, it gathers context about the changes and posts (or updates) a comment with structured testing instructions. +A GitHub Action that automatically generates QA testing instructions for pull requests using AI. On each PR push, it gathers context about the changes and posts (or updates) a comment with structured testing instructions. + +Supports two AI providers: + +- **GitHub Models** (default) — uses the GitHub Models inference API with your existing `GITHUB_TOKEN`. No API keys or subscriptions required. +- **Anthropic** — uses the Anthropic API with a Claude model. Requires an API key. ## Usage +### GitHub Models (default) + ```yaml name: QA Instructions on: @@ -14,6 +21,7 @@ on: permissions: pull-requests: write + models: read jobs: qa-instructions: @@ -22,23 +30,59 @@ jobs: - uses: slifty/qa-instructions-action@v1 with: github-token: ${{ secrets.GITHUB_TOKEN }} - anthropic-api-key: ${{ secrets.ANTHROPIC_API_KEY }} ``` **Requirements:** +- `permissions: models: read` is required for GitHub Models API access - `permissions: pull-requests: write` is required for posting PR comments -- `ANTHROPIC_API_KEY` must be stored as a repository secret - The `synchronize` event type triggers on each push, updating the existing comment +### Anthropic + +```yaml +name: QA Instructions +on: + pull_request: + types: [opened, synchronize] + +permissions: + pull-requests: write + +jobs: + qa-instructions: + runs-on: ubuntu-latest + steps: + - uses: slifty/qa-instructions-action@v1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + provider: anthropic + anthropic-api-key: ${{ secrets.ANTHROPIC_API_KEY }} +``` + +**Requirements:** + +- `ANTHROPIC_API_KEY` must be stored as a repository secret +- `permissions: pull-requests: write` is required for posting PR comments + ## Inputs -| Input | Description | Required | Default | -| ------------------- | ------------------------------------------- | -------- | ---------------------------- | -| `github-token` | GitHub token for API access | Yes | `${{ github.token }}` | -| `anthropic-api-key` | Anthropic API key for Claude | Yes | — | -| `prompt` | Optional custom instructions for the prompt | No | `""` | -| `model` | Claude model to use | No | `claude-sonnet-4-5-20250929` | +| Input | Description | Required | Default | +| ------------------- | ------------------------------------------------------------ | -------- | --------------------- | +| `github-token` | GitHub token for API access and GitHub Models authentication | Yes | `${{ github.token }}` | +| `provider` | AI provider: `"github-models"` or `"anthropic"` | No | `"github-models"` | +| `anthropic-api-key` | Anthropic API key (required when provider is `"anthropic"`) | No | `""` | +| `prompt` | Optional custom instructions appended to the prompt | No | `""` | +| `model` | AI model to use (defaults to a provider-appropriate model) | No | `""` | + +### Default models + +| Provider | Default model | +| --------------- | ---------------------------- | +| `github-models` | `openai/gpt-4o` | +| `anthropic` | `claude-sonnet-4-5-20250929` | + +You can override the model with any model supported by the chosen provider. ## Outputs @@ -50,7 +94,7 @@ jobs: 1. Gathers PR context: metadata, diff, changed file contents, repository file tree, and commit history 2. Builds a structured prompt with tiered truncation to fit within model context limits -3. Sends the prompt to Claude, which generates QA instructions covering: +3. Sends the prompt to the configured AI provider, which generates QA instructions covering: - Summary of changes - Test environment setup - Specific test scenarios with steps and expected results diff --git a/action.yml b/action.yml index 12b0543..ba77be3 100644 --- a/action.yml +++ b/action.yml @@ -1,21 +1,26 @@ name: "QA Instructions Action" -description: "A GitHub Action that generates QA testing instructions for pull requests using Claude" +description: "A GitHub Action that generates QA testing instructions for pull requests using AI" inputs: github-token: - description: "GitHub token for API access" + description: "GitHub token for API access (also used for GitHub Models authentication)" required: true default: ${{ github.token }} + provider: + description: 'AI provider to use: "github-models" or "anthropic"' + required: false + default: "github-models" anthropic-api-key: - description: "Anthropic API key for Claude" - required: true + description: 'Anthropic API key for Claude (required when provider is "anthropic")' + required: false + default: "" prompt: description: "Optional custom instructions appended to the prompt" required: false default: "" model: - description: "Claude model to use" + description: "AI model to use (defaults to provider-appropriate model if not set)" required: false - default: "claude-sonnet-4-5-20250929" + default: "" outputs: instructions: description: "The generated QA instructions" diff --git a/src/claude.test.ts b/src/claude.test.ts index 1a87005..59fc065 100644 --- a/src/claude.test.ts +++ b/src/claude.test.ts @@ -14,31 +14,35 @@ vi.mock("@anthropic-ai/sdk", () => { }; }); -import { generateQAInstructions } from "./claude.js"; -import { DEFAULT_MODEL } from "./constants.js"; +import { createAnthropicProvider } from "./claude.js"; +import { DEFAULT_ANTHROPIC_MODEL } from "./constants.js"; -describe("generateQAInstructions", () => { +describe("createAnthropicProvider", () => { beforeEach(() => { vi.clearAllMocks(); }); - it("sends context to Claude and returns text response", async () => { + it("sends prompts to Claude and returns text response", async () => { mockCreate.mockResolvedValue({ content: [{ type: "text", text: "## QA Instructions\n\nTest this." }], }); - const result = await generateQAInstructions( + const provider = createAnthropicProvider( "test-api-key", - DEFAULT_MODEL, - "PR context here", + DEFAULT_ANTHROPIC_MODEL, + ); + const result = await provider.generateQAInstructions( + "system prompt", + "user prompt", ); expect(result).toBe("## QA Instructions\n\nTest this."); expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - model: DEFAULT_MODEL, + model: DEFAULT_ANTHROPIC_MODEL, max_tokens: 4096, - messages: [{ role: "user", content: "PR context here" }], + system: "system prompt", + messages: [{ role: "user", content: "user prompt" }], }), ); }); @@ -48,8 +52,10 @@ describe("generateQAInstructions", () => { content: [], }); + const provider = createAnthropicProvider("key", "model"); + await expect( - generateQAInstructions("key", "model", "context"), + provider.generateQAInstructions("system", "user"), ).rejects.toThrow("No text content in Claude response"); }); }); diff --git a/src/claude.ts b/src/claude.ts index a3cb705..2ec3ab5 100644 --- a/src/claude.ts +++ b/src/claude.ts @@ -1,42 +1,30 @@ import Anthropic from "@anthropic-ai/sdk"; +import type { AiProvider } from "./types.js"; -const SYSTEM_PROMPT = `You are an expert QA engineer reviewing a pull request. Your job is to generate clear, actionable QA testing instructions that a human tester can follow. - -Scale your response to the complexity of the changes. A small documentation fix needs just a sentence or two. A large feature needs thorough coverage. Be concise — omit sections that add no value for the specific PR. - -Analyze the provided PR context and produce testing instructions using whichever of these sections are relevant: - -- **Summary** — What the PR changes and why (1-3 sentences). -- **Test Environment Setup** — Prerequisites or setup steps, if any beyond the standard dev environment. Omit if none. -- **Test Scenarios** — Numbered test cases with steps and expected results. Focus on the most important paths; don't enumerate the obvious. -- **Regression Risks** — Areas that might break as a side effect. Omit if the changes are well-isolated. -- **Things to Watch For** — Edge cases or concerns spotted in the code. Omit if nothing stands out. - -Be specific and practical. Reference actual file names, function names, and UI elements from the PR when possible.`; - -export async function generateQAInstructions( +export function createAnthropicProvider( apiKey: string, model: string, - promptContext: string, -): Promise { +): AiProvider { const client = new Anthropic({ apiKey }); - const response = await client.messages.create({ - model, - max_tokens: 4096, - system: SYSTEM_PROMPT, - messages: [ - { - role: "user", - content: promptContext, - }, - ], - }); - - const textBlock = response.content.find((block) => block.type === "text"); - if (!textBlock || textBlock.type !== "text") { - throw new Error("No text content in Claude response"); - } - - return textBlock.text; + return { + async generateQAInstructions( + systemPrompt: string, + userPrompt: string, + ): Promise { + const response = await client.messages.create({ + model, + max_tokens: 4096, + system: systemPrompt, + messages: [{ role: "user", content: userPrompt }], + }); + + const textBlock = response.content.find((block) => block.type === "text"); + if (!textBlock || textBlock.type !== "text") { + throw new Error("No text content in Claude response"); + } + + return textBlock.text; + }, + }; } diff --git a/src/constants.ts b/src/constants.ts index 4c4a2ec..810dd64 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -1,9 +1,52 @@ export const COMMENT_MARKER = ""; -export const DEFAULT_MODEL = "claude-sonnet-4-5-20250929"; +export const VALID_PROVIDERS = ["github-models", "anthropic"] as const; +export type Provider = (typeof VALID_PROVIDERS)[number]; -export const MAX_DIFF_CHARS = 80_000; -export const MAX_CHANGED_FILES_CHARS = 60_000; -export const MAX_FILE_CHARS = 10_000; -export const MAX_FILE_TREE_CHARS = 20_000; -export const MAX_TOTAL_CHARS = 180_000; +export const DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-5-20250929"; +export const DEFAULT_GITHUB_MODELS_MODEL = "openai/gpt-4o"; + +export const GITHUB_MODELS_BASE_URL = + "https://models.github.ai/inference/chat/completions"; + +export interface ContextLimits { + maxDiffChars: number; + maxChangedFilesChars: number; + maxFileChars: number; + maxFileTreeChars: number; + maxTotalChars: number; +} + +export const ANTHROPIC_CONTEXT_LIMITS: ContextLimits = { + maxDiffChars: 80_000, + maxChangedFilesChars: 60_000, + maxFileChars: 10_000, + maxFileTreeChars: 20_000, + maxTotalChars: 180_000, +}; + +// GitHub Models free tier: 8k input tokens for gpt-4o (~32k chars). +// JSON encoding inflates newlines (\n → \\n), so budget ~20k chars +// for the user prompt after reserving room for the system prompt and +// JSON/HTTP overhead. +export const GITHUB_MODELS_CONTEXT_LIMITS: ContextLimits = { + maxDiffChars: 8_000, + maxChangedFilesChars: 6_000, + maxFileChars: 3_000, + maxFileTreeChars: 3_000, + maxTotalChars: 20_000, +}; + +export const SYSTEM_PROMPT = `You are an expert QA engineer reviewing a pull request. Your job is to generate clear, actionable QA testing instructions that a human tester can follow. + +Scale your response to the complexity of the changes. A small documentation fix needs just a sentence or two. A large feature needs thorough coverage. Be concise — omit sections that add no value for the specific PR. + +Analyze the provided PR context and produce testing instructions using whichever of these sections are relevant: + +- **Summary** — What the PR changes and why (1-3 sentences). +- **Test Environment Setup** — Prerequisites or setup steps, if any beyond the standard dev environment. Omit if none. +- **Test Scenarios** — Numbered test cases with steps and expected results. Focus on the most important paths; don't enumerate the obvious. +- **Regression Risks** — Areas that might break as a side effect. Omit if the changes are well-isolated. +- **Things to Watch For** — Edge cases or concerns spotted in the code. Omit if nothing stands out. + +Be specific and practical. Reference actual file names, function names, and UI elements from the PR when possible.`; diff --git a/src/context-builder.test.ts b/src/context-builder.test.ts index cf719f2..d6770e6 100644 --- a/src/context-builder.test.ts +++ b/src/context-builder.test.ts @@ -1,7 +1,9 @@ import { describe, it, expect } from "vitest"; import { buildPromptContext } from "./context-builder.js"; import type { PrData } from "./types.js"; -import { MAX_DIFF_CHARS, MAX_FILE_CHARS } from "./constants.js"; +import { ANTHROPIC_CONTEXT_LIMITS } from "./constants.js"; + +const limits = ANTHROPIC_CONTEXT_LIMITS; function makePrData(overrides: Partial = {}): PrData { return { @@ -16,7 +18,7 @@ function makePrData(overrides: Partial = {}): PrData { describe("buildPromptContext", () => { it("includes PR title and description", () => { - const result = buildPromptContext(makePrData(), ""); + const result = buildPromptContext(makePrData(), "", limits); expect(result).toContain("**Title:** Test PR"); expect(result).toContain("Test body"); }); @@ -25,7 +27,7 @@ describe("buildPromptContext", () => { const data = makePrData({ metadata: { title: "Test", body: "", headSha: "abc" }, }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).not.toContain("**Description:**"); }); @@ -36,7 +38,7 @@ describe("buildPromptContext", () => { { sha: "1234567abcdef", message: "Fix bug" }, ], }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).toContain("## Commits"); expect(result).toContain("- abcdef1 Initial commit"); expect(result).toContain("- 1234567 Fix bug"); @@ -44,7 +46,7 @@ describe("buildPromptContext", () => { it("includes diff content", () => { const data = makePrData({ diff: "+added line\n-removed line" }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).toContain("## Diff"); expect(result).toContain("+added line\n-removed line"); }); @@ -52,9 +54,13 @@ describe("buildPromptContext", () => { it("truncates long diffs at line boundaries", () => { const lines = Array.from({ length: 10000 }, (_, i) => `+line ${i}`); const longDiff = lines.join("\n"); - expect(longDiff.length).toBeGreaterThan(MAX_DIFF_CHARS); + expect(longDiff.length).toBeGreaterThan(limits.maxDiffChars); - const result = buildPromptContext(makePrData({ diff: longDiff }), ""); + const result = buildPromptContext( + makePrData({ diff: longDiff }), + "", + limits, + ); expect(result).toContain("[Content truncated]"); }); @@ -65,7 +71,7 @@ describe("buildPromptContext", () => { { filename: "big.ts", content: "x".repeat(500) }, ], }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).toContain("## Changed File Contents"); // Big file should appear before small file const bigIdx = result.indexOf("big.ts"); @@ -73,7 +79,7 @@ describe("buildPromptContext", () => { expect(bigIdx).toBeLessThan(smallIdx); }); - it("truncates individual files exceeding MAX_FILE_CHARS", () => { + it("truncates individual files exceeding maxFileChars", () => { const data = makePrData({ changedFiles: [ { @@ -84,21 +90,22 @@ describe("buildPromptContext", () => { }, ], }); - // The content is larger than MAX_FILE_CHARS - expect(data.changedFiles[0].content.length).toBeGreaterThan(MAX_FILE_CHARS); + expect(data.changedFiles[0].content.length).toBeGreaterThan( + limits.maxFileChars, + ); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).toContain("[Content truncated]"); }); - it("stops adding files when total exceeds MAX_CHANGED_FILES_CHARS", () => { + it("stops adding files when total exceeds maxChangedFilesChars", () => { const files = Array.from({ length: 20 }, (_, i) => ({ filename: `file${i}.ts`, - content: "x".repeat(MAX_FILE_CHARS - 100), + content: "x".repeat(limits.maxFileChars - 100), })); const data = makePrData({ changedFiles: files }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); // Not all 20 files should be included const includedCount = (result.match(/### file\d+\.ts/g) || []).length; @@ -110,7 +117,7 @@ describe("buildPromptContext", () => { const data = makePrData({ fileTree: ["src/index.ts", "src/utils.ts", "package.json"], }); - const result = buildPromptContext(data, ""); + const result = buildPromptContext(data, "", limits); expect(result).toContain("## Repository File Tree"); expect(result).toContain("src/index.ts"); }); @@ -119,26 +126,25 @@ describe("buildPromptContext", () => { const result = buildPromptContext( makePrData(), "Focus on accessibility testing", + limits, ); expect(result).toContain("## Additional Instructions"); expect(result).toContain("Focus on accessibility testing"); }); it("omits additional instructions when custom prompt is empty", () => { - const result = buildPromptContext(makePrData(), ""); + const result = buildPromptContext(makePrData(), "", limits); expect(result).not.toContain("## Additional Instructions"); }); it("respects overall character cap", () => { - // Create data that would exceed MAX_TOTAL_CHARS + // Create data that would exceed maxTotalChars const data = makePrData({ diff: "x".repeat(80_000), changedFiles: [{ filename: "big.ts", content: "y".repeat(60_000) }], fileTree: Array.from({ length: 5000 }, (_, i) => `path/${i}.ts`), }); - const result = buildPromptContext(data, ""); - // The result includes truncation markers, so it will be slightly over - // but the core content should be capped + const result = buildPromptContext(data, "", limits); expect(result).toContain("[Content truncated]"); }); }); diff --git a/src/context-builder.ts b/src/context-builder.ts index 3a00040..a982195 100644 --- a/src/context-builder.ts +++ b/src/context-builder.ts @@ -1,11 +1,5 @@ import type { PrData } from "./types.js"; -import { - MAX_DIFF_CHARS, - MAX_CHANGED_FILES_CHARS, - MAX_FILE_CHARS, - MAX_FILE_TREE_CHARS, - MAX_TOTAL_CHARS, -} from "./constants.js"; +import type { ContextLimits } from "./constants.js"; function truncateAtLineBreak(text: string, maxChars: number): string { if (text.length <= maxChars) return text; @@ -15,7 +9,11 @@ function truncateAtLineBreak(text: string, maxChars: number): string { return truncated.slice(0, cutPoint) + "\n\n[Content truncated]"; } -export function buildPromptContext(data: PrData, customPrompt: string): string { +export function buildPromptContext( + data: PrData, + customPrompt: string, + limits: ContextLimits, +): string { const sections: string[] = []; // Pr title + description + commits (always included in full) @@ -31,13 +29,13 @@ export function buildPromptContext(data: PrData, customPrompt: string): string { sections.push(`## Commits\n\n${commitLines}`); } - // Diff (up to MAX_DIFF_CHARS) + // Diff (up to maxDiffChars) if (data.diff) { - const truncatedDiff = truncateAtLineBreak(data.diff, MAX_DIFF_CHARS); + const truncatedDiff = truncateAtLineBreak(data.diff, limits.maxDiffChars); sections.push(`## Diff\n\n\`\`\`diff\n${truncatedDiff}\n\`\`\``); } - // Changed file contents (up to MAX_CHANGED_FILES_CHARS total, MAX_FILE_CHARS per file) + // Changed file contents (up to maxChangedFilesChars total, maxFileChars per file) // Sort by content length descending (most-changed files first) if (data.changedFiles.length > 0) { const sorted = [...data.changedFiles].sort( @@ -47,10 +45,10 @@ export function buildPromptContext(data: PrData, customPrompt: string): string { let totalFileChars = 0; for (const file of sorted) { - if (totalFileChars >= MAX_CHANGED_FILES_CHARS) break; + if (totalFileChars >= limits.maxChangedFilesChars) break; - const remaining = MAX_CHANGED_FILES_CHARS - totalFileChars; - const perFileLimit = Math.min(MAX_FILE_CHARS, remaining); + const remaining = limits.maxChangedFilesChars - totalFileChars; + const perFileLimit = Math.min(limits.maxFileChars, remaining); const content = truncateAtLineBreak(file.content, perFileLimit); fileBlocks.push(`### ${file.filename}\n\n\`\`\`\n${content}\n\`\`\``); @@ -60,10 +58,13 @@ export function buildPromptContext(data: PrData, customPrompt: string): string { sections.push(`## Changed File Contents\n\n${fileBlocks.join("\n\n")}`); } - // File tree (lowest priority, up to MAX_FILE_TREE_CHARS) + // File tree (lowest priority, up to maxFileTreeChars) if (data.fileTree.length > 0) { const treeText = data.fileTree.join("\n"); - const truncatedTree = truncateAtLineBreak(treeText, MAX_FILE_TREE_CHARS); + const truncatedTree = truncateAtLineBreak( + treeText, + limits.maxFileTreeChars, + ); sections.push( `## Repository File Tree\n\n\`\`\`\n${truncatedTree}\n\`\`\``, ); @@ -77,8 +78,8 @@ export function buildPromptContext(data: PrData, customPrompt: string): string { let result = sections.join("\n\n"); // Overall cap - if (result.length > MAX_TOTAL_CHARS) { - result = truncateAtLineBreak(result, MAX_TOTAL_CHARS); + if (result.length > limits.maxTotalChars) { + result = truncateAtLineBreak(result, limits.maxTotalChars); } return result; diff --git a/src/github-models.test.ts b/src/github-models.test.ts new file mode 100644 index 0000000..93883f1 --- /dev/null +++ b/src/github-models.test.ts @@ -0,0 +1,88 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("@actions/core"); + +import { createGitHubModelsProvider } from "./github-models.js"; +import { GITHUB_MODELS_BASE_URL } from "./constants.js"; + +const mockFetch = vi.fn(); +vi.stubGlobal("fetch", mockFetch); + +describe("createGitHubModelsProvider", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("sends prompts to GitHub Models and returns content", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + choices: [{ message: { content: "## QA Instructions\n\nTest this." } }], + }), + }); + + const provider = createGitHubModelsProvider("test-token", "openai/gpt-4o"); + const result = await provider.generateQAInstructions( + "system prompt", + "user prompt", + ); + + expect(result).toBe("## QA Instructions\n\nTest this."); + expect(mockFetch).toHaveBeenCalledWith(GITHUB_MODELS_BASE_URL, { + method: "POST", + headers: { + Authorization: "Bearer test-token", + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: "openai/gpt-4o", + messages: [ + { role: "system", content: "system prompt" }, + { role: "user", content: "user prompt" }, + ], + }), + }); + }); + + it("throws when response is not ok", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 403, + statusText: "Forbidden", + }); + + const provider = createGitHubModelsProvider("bad-token", "openai/gpt-4o"); + + await expect( + provider.generateQAInstructions("system", "user"), + ).rejects.toThrow("GitHub Models API request failed: 403 Forbidden"); + }); + + it("throws when response has no content", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + choices: [{ message: { content: null } }], + }), + }); + + const provider = createGitHubModelsProvider("test-token", "openai/gpt-4o"); + + await expect( + provider.generateQAInstructions("system", "user"), + ).rejects.toThrow("No content in GitHub Models response"); + }); + + it("throws when response has no choices", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ choices: [] }), + }); + + const provider = createGitHubModelsProvider("test-token", "openai/gpt-4o"); + + await expect( + provider.generateQAInstructions("system", "user"), + ).rejects.toThrow("No content in GitHub Models response"); + }); +}); diff --git a/src/github-models.ts b/src/github-models.ts new file mode 100644 index 0000000..e10e815 --- /dev/null +++ b/src/github-models.ts @@ -0,0 +1,58 @@ +import * as core from "@actions/core"; +import { GITHUB_MODELS_BASE_URL } from "./constants.js"; +import type { AiProvider } from "./types.js"; + +interface GitHubModelsChoice { + message?: { + content: string | null; + }; +} + +interface GitHubModelsResponse { + choices: GitHubModelsChoice[]; +} + +export function createGitHubModelsProvider( + token: string, + model: string, +): AiProvider { + return { + async generateQAInstructions( + systemPrompt: string, + userPrompt: string, + ): Promise { + const body = JSON.stringify({ + model, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + }); + + core.debug(`GitHub Models request body size: ${body.length} bytes`); + + const response = await fetch(GITHUB_MODELS_BASE_URL, { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + body, + }); + + if (!response.ok) { + throw new Error( + `GitHub Models API request failed: ${response.status} ${response.statusText}`, + ); + } + + const data = (await response.json()) as GitHubModelsResponse; + const content = data.choices[0]?.message?.content; + if (!content) { + throw new Error("No content in GitHub Models response"); + } + + return content; + }, + }; +} diff --git a/src/index.test.ts b/src/index.test.ts index aa227a4..bc0d605 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -13,14 +13,16 @@ vi.mock("@actions/github", () => ({ })); vi.mock("./github.js"); vi.mock("./context-builder.js"); -vi.mock("./claude.js"); +vi.mock("./provider-factory.js"); vi.mock("./comment.js"); -import { DEFAULT_MODEL } from "./constants.js"; import * as ghModule from "./github.js"; import * as contextBuilder from "./context-builder.js"; -import * as claude from "./claude.js"; +import * as providerFactory from "./provider-factory.js"; import * as comment from "./comment.js"; +import { GITHUB_MODELS_CONTEXT_LIMITS } from "./constants.js"; + +const mockGenerateQAInstructions = vi.fn(); describe("run", () => { beforeEach(() => { @@ -29,7 +31,8 @@ describe("run", () => { vi.mocked(core.getInput).mockImplementation((name: string) => { const inputs: Record = { "github-token": "fake-token", - "anthropic-api-key": "fake-api-key", + "anthropic-api-key": "", + provider: "github-models", prompt: "", model: "", }; @@ -40,6 +43,14 @@ describe("run", () => { "mock-octokit" as unknown as ReturnType, ); + mockGenerateQAInstructions.mockResolvedValue("QA instructions"); + vi.mocked(providerFactory.createProvider).mockReturnValue({ + generateQAInstructions: mockGenerateQAInstructions, + }); + vi.mocked(providerFactory.getContextLimits).mockReturnValue( + GITHUB_MODELS_CONTEXT_LIMITS, + ); + vi.mocked(ghModule.getPrMetadata).mockResolvedValue({ title: "Test PR", body: "Test body", @@ -54,9 +65,6 @@ describe("run", () => { vi.mocked(contextBuilder.buildPromptContext).mockReturnValue( "prompt context", ); - vi.mocked(claude.generateQAInstructions).mockResolvedValue( - "QA instructions", - ); vi.mocked(comment.postOrUpdateComment).mockResolvedValue(); }); @@ -64,6 +72,16 @@ describe("run", () => { await run(); expect(github.getOctokit).toHaveBeenCalledWith("fake-token"); + expect(providerFactory.createProvider).toHaveBeenCalledWith({ + provider: "github-models", + model: "", + anthropicApiKey: "", + githubToken: "fake-token", + octokit: "mock-octokit", + }); + expect(providerFactory.getContextLimits).toHaveBeenCalledWith( + "github-models", + ); expect(ghModule.getPrMetadata).toHaveBeenCalledWith( "mock-octokit", "test-owner", @@ -94,10 +112,10 @@ describe("run", () => { commits: [{ sha: "abc123", message: "commit" }], }, "", + GITHUB_MODELS_CONTEXT_LIMITS, ); - expect(claude.generateQAInstructions).toHaveBeenCalledWith( - "fake-api-key", - DEFAULT_MODEL, + expect(mockGenerateQAInstructions).toHaveBeenCalledWith( + expect.any(String), "prompt context", ); expect(comment.postOrUpdateComment).toHaveBeenCalledWith( @@ -113,11 +131,12 @@ describe("run", () => { ); }); - it("uses custom model when provided", async () => { + it("passes anthropic provider and API key when configured", async () => { vi.mocked(core.getInput).mockImplementation((name: string) => { const inputs: Record = { "github-token": "fake-token", "anthropic-api-key": "fake-api-key", + provider: "anthropic", prompt: "", model: "claude-opus-4-20250514", }; @@ -126,15 +145,35 @@ describe("run", () => { await run(); - expect(claude.generateQAInstructions).toHaveBeenCalledWith( - "fake-api-key", - "claude-opus-4-20250514", - "prompt context", + expect(providerFactory.createProvider).toHaveBeenCalledWith({ + provider: "anthropic", + model: "claude-opus-4-20250514", + anthropicApiKey: "fake-api-key", + githubToken: "fake-token", + octokit: "mock-octokit", + }); + }); + + it("defaults provider to github-models when input is empty", async () => { + vi.mocked(core.getInput).mockImplementation((name: string) => { + const inputs: Record = { + "github-token": "fake-token", + "anthropic-api-key": "", + provider: "", + prompt: "", + model: "", + }; + return inputs[name] ?? ""; + }); + + await run(); + + expect(providerFactory.createProvider).toHaveBeenCalledWith( + expect.objectContaining({ provider: "github-models" }), ); }); it("fails when not a pull_request event", async () => { - // Override the context to have no pull_request const contextAny = github.context as unknown as Record; const originalPayload = contextAny.payload; contextAny.payload = {}; @@ -145,10 +184,29 @@ describe("run", () => { expect.stringContaining("pull_request event"), ); - // Restore contextAny.payload = originalPayload; }); + it("fails for invalid provider", async () => { + vi.mocked(core.getInput).mockImplementation((name: string) => { + const inputs: Record = { + "github-token": "fake-token", + "anthropic-api-key": "", + provider: "bad", + prompt: "", + model: "", + }; + return inputs[name] ?? ""; + }); + + await run(); + + expect(core.setFailed).toHaveBeenCalledWith( + expect.stringContaining('Invalid provider "bad"'), + ); + expect(providerFactory.createProvider).not.toHaveBeenCalled(); + }); + it("calls setFailed when an error occurs", async () => { vi.mocked(ghModule.getPrMetadata).mockRejectedValue(new Error("API error")); diff --git a/src/index.ts b/src/index.ts index 6e95000..2f4ad51 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,7 @@ import * as core from "@actions/core"; import * as github from "@actions/github"; -import { DEFAULT_MODEL } from "./constants.js"; +import { SYSTEM_PROMPT, VALID_PROVIDERS } from "./constants.js"; +import type { Provider } from "./constants.js"; import { getPrMetadata, getPrDiff, @@ -9,7 +10,7 @@ import { getPrCommits, } from "./github.js"; import { buildPromptContext } from "./context-builder.js"; -import { generateQAInstructions } from "./claude.js"; +import { createProvider, getContextLimits } from "./provider-factory.js"; import { postOrUpdateComment } from "./comment.js"; export async function run(): Promise { @@ -17,11 +18,10 @@ export async function run(): Promise { core.info("QA Instructions Action is running!"); const token = core.getInput("github-token", { required: true }); - const anthropicApiKey = core.getInput("anthropic-api-key", { - required: true, - }); + const provider = core.getInput("provider") || "github-models"; + const anthropicApiKey = core.getInput("anthropic-api-key"); const customPrompt = core.getInput("prompt"); - const model = core.getInput("model") || DEFAULT_MODEL; + const model = core.getInput("model"); const pullRequest = github.context.payload.pull_request; if (!pullRequest) { @@ -36,6 +36,23 @@ export async function run(): Promise { const octokit = github.getOctokit(token); + if (!VALID_PROVIDERS.includes(provider as Provider)) { + core.setFailed( + `Invalid provider "${provider}". Must be one of: ${VALID_PROVIDERS.join(", ")}`, + ); + return; + } + const validatedProvider = provider as Provider; + + const aiProvider = createProvider({ + provider: validatedProvider, + model, + anthropicApiKey, + githubToken: token, + octokit, + }); + const contextLimits = getContextLimits(validatedProvider); + // Fetch Pr metadata first (need headSha for file tree) core.info("Fetching Pr metadata..."); const metadata = await getPrMetadata(octokit, owner, repo, pullNumber); @@ -54,13 +71,13 @@ export async function run(): Promise { const promptContext = buildPromptContext( { metadata, diff, changedFiles, fileTree, commits }, customPrompt, + contextLimits, ); - // Generate QA instructions via Claude + // Generate QA instructions via AI provider core.info("Generating QA instructions..."); - const instructions = await generateQAInstructions( - anthropicApiKey, - model, + const instructions = await aiProvider.generateQAInstructions( + SYSTEM_PROMPT, promptContext, ); diff --git a/src/provider-factory.test.ts b/src/provider-factory.test.ts new file mode 100644 index 0000000..ddfc928 --- /dev/null +++ b/src/provider-factory.test.ts @@ -0,0 +1,136 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("./claude.js", () => ({ + createAnthropicProvider: vi.fn(() => ({ + generateQAInstructions: vi.fn(), + })), +})); + +vi.mock("./github-models.js", () => ({ + createGitHubModelsProvider: vi.fn(() => ({ + generateQAInstructions: vi.fn(), + })), +})); + +import { + resolveModel, + createProvider, + getContextLimits, +} from "./provider-factory.js"; +import { + DEFAULT_ANTHROPIC_MODEL, + DEFAULT_GITHUB_MODELS_MODEL, + ANTHROPIC_CONTEXT_LIMITS, + GITHUB_MODELS_CONTEXT_LIMITS, +} from "./constants.js"; +import { createAnthropicProvider } from "./claude.js"; +import { createGitHubModelsProvider } from "./github-models.js"; +import type { Octokit } from "./types.js"; + +const mockOctokit = {} as unknown as Octokit; + +describe("resolveModel", () => { + it("returns the given model when non-empty", () => { + expect(resolveModel("github-models", "custom-model")).toBe("custom-model"); + expect(resolveModel("anthropic", "custom-model")).toBe("custom-model"); + }); + + it("returns default Anthropic model when model is empty", () => { + expect(resolveModel("anthropic", "")).toBe(DEFAULT_ANTHROPIC_MODEL); + }); + + it("returns default GitHub Models model when model is empty", () => { + expect(resolveModel("github-models", "")).toBe(DEFAULT_GITHUB_MODELS_MODEL); + }); +}); + +describe("getContextLimits", () => { + it("returns Anthropic limits for anthropic provider", () => { + expect(getContextLimits("anthropic")).toBe(ANTHROPIC_CONTEXT_LIMITS); + }); + + it("returns GitHub Models limits for github-models provider", () => { + expect(getContextLimits("github-models")).toBe( + GITHUB_MODELS_CONTEXT_LIMITS, + ); + }); +}); + +describe("createProvider", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("creates a GitHub Models provider with default model", () => { + createProvider({ + provider: "github-models", + model: "", + anthropicApiKey: "", + githubToken: "test-token", + octokit: mockOctokit, + }); + + expect(createGitHubModelsProvider).toHaveBeenCalledWith( + "test-token", + DEFAULT_GITHUB_MODELS_MODEL, + ); + }); + + it("creates an Anthropic provider with default model", () => { + createProvider({ + provider: "anthropic", + model: "", + anthropicApiKey: "test-key", + githubToken: "test-token", + octokit: mockOctokit, + }); + + expect(createAnthropicProvider).toHaveBeenCalledWith( + "test-key", + DEFAULT_ANTHROPIC_MODEL, + ); + }); + + it("uses custom model when provided", () => { + createProvider({ + provider: "anthropic", + model: "claude-opus-4-20250514", + anthropicApiKey: "test-key", + githubToken: "test-token", + octokit: mockOctokit, + }); + + expect(createAnthropicProvider).toHaveBeenCalledWith( + "test-key", + "claude-opus-4-20250514", + ); + }); + + it("throws for invalid provider", () => { + expect(() => + createProvider({ + provider: "invalid", + model: "", + anthropicApiKey: "", + githubToken: "test-token", + octokit: mockOctokit, + }), + ).toThrow( + 'Invalid provider "invalid". Must be one of: github-models, anthropic', + ); + }); + + it("throws when anthropic provider is used without API key", () => { + expect(() => + createProvider({ + provider: "anthropic", + model: "", + anthropicApiKey: "", + githubToken: "test-token", + octokit: mockOctokit, + }), + ).toThrow( + 'The "anthropic-api-key" input is required when provider is "anthropic"', + ); + }); +}); diff --git a/src/provider-factory.ts b/src/provider-factory.ts new file mode 100644 index 0000000..f7a83c8 --- /dev/null +++ b/src/provider-factory.ts @@ -0,0 +1,58 @@ +import { + VALID_PROVIDERS, + DEFAULT_ANTHROPIC_MODEL, + DEFAULT_GITHUB_MODELS_MODEL, + ANTHROPIC_CONTEXT_LIMITS, + GITHUB_MODELS_CONTEXT_LIMITS, +} from "./constants.js"; +import type { Provider, ContextLimits } from "./constants.js"; +import type { AiProvider, Octokit } from "./types.js"; +import { createAnthropicProvider } from "./claude.js"; +import { createGitHubModelsProvider } from "./github-models.js"; + +export interface ProviderConfig { + provider: string; + model: string; + anthropicApiKey: string; + githubToken: string; + octokit: Octokit; +} + +export function getContextLimits(provider: Provider): ContextLimits { + return provider === "anthropic" + ? ANTHROPIC_CONTEXT_LIMITS + : GITHUB_MODELS_CONTEXT_LIMITS; +} + +export function resolveModel(provider: Provider, model: string): string { + if (model) { + return model; + } + return provider === "anthropic" + ? DEFAULT_ANTHROPIC_MODEL + : DEFAULT_GITHUB_MODELS_MODEL; +} + +export function createProvider(config: ProviderConfig): AiProvider { + const { provider, model, anthropicApiKey, githubToken } = config; + + if (!VALID_PROVIDERS.includes(provider as Provider)) { + throw new Error( + `Invalid provider "${provider}". Must be one of: ${VALID_PROVIDERS.join(", ")}`, + ); + } + + const resolvedProvider = provider as Provider; + const resolvedModel = resolveModel(resolvedProvider, model); + + if (resolvedProvider === "anthropic") { + if (!anthropicApiKey) { + throw new Error( + 'The "anthropic-api-key" input is required when provider is "anthropic"', + ); + } + return createAnthropicProvider(anthropicApiKey, resolvedModel); + } + + return createGitHubModelsProvider(githubToken, resolvedModel); +} diff --git a/src/types.ts b/src/types.ts index 93e117d..7ddfd5c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,6 +2,13 @@ import * as github from "@actions/github"; export type Octokit = ReturnType; +export interface AiProvider { + generateQAInstructions( + systemPrompt: string, + userPrompt: string, + ): Promise; +} + export interface PrMetadata { title: string; body: string;