diff --git a/backend/app/prompts.py b/backend/app/prompts.py index 696a1349..7ed91f31 100644 --- a/backend/app/prompts.py +++ b/backend/app/prompts.py @@ -187,11 +187,20 @@ ``` EXTREMELY Important notes on syntax!!! (PAY ATTENTION TO THIS): -- Make sure to add colour to the diagram!!! This is extremely critical. +- Make sure to add colour to the diagram!!! This is extremely critical. Not Following these rules will result in a syntax error! - In Mermaid.js syntax, we cannot include special characters for nodes without being inside quotes! For example: `EX[/api/process (Backend)]:::api` and `API -->|calls Process()| Backend` are two examples of syntax errors. They should be `EX["/api/process (Backend)"]:::api` and `API -->|"calls Process()"| Backend` respectively. Notice the quotes. This is extremely important. Make sure to include quotes for any string that contains special characters. - In Mermaid.js syntax, you cannot apply a class style directly within a subgraph declaration. For example: `subgraph "Frontend Layer":::frontend` is a syntax error. However, you can apply them to nodes within the subgraph. For example: `Example["Example Node"]:::frontend` is valid, and `class Example1,Example2 frontend` is valid. -- In Mermaid.js syntax, there cannot be spaces in the relationship label names. For example: `A -->| "example relationship" | B` is a syntax error. It should be `A -->|"example relationship"| B` +- In Mermaid.js syntax, connections should be following the format `A -->|"relationship"| B` without spaces around the relationship label. For example: `A -->|"relationship"| B` is valid, and `A -->| "relationship" | B` is a syntax error. +- In Mermaid.js syntax, there cannot be spaces in the relationship label names. For example: `A -->| "example relationship" | B` is a syntax error. It should be `A -->|"example relationship"| B`. - In Mermaid.js syntax, you cannot give subgraphs an alias like nodes. For example: `subgraph A "Layer A"` is a syntax error. It should be `subgraph "Layer A"` +- In Mermaid.js syntax, you cannot use "direction TD", replace "direction TD" with "direction TB" everwhere neeeded. Very critical information , remember it. + -- Example `subgraph "Layer A" direction TD` is a syntax error. It should be `subgraph "Layer A" direction TB` +- In Mermaid.js syntax, you cannot use special characters in node names and no examples inside the nodes. + -- Example `A[("Example Node", ())] and A[("Example Node"), ()]` is a syntax error. It should be `A["Example Node"]:::example` +- In Mermaid.js syntax, you cannot use special characters in comments and no examples inside the comments. + -- Example `%% This is an example comment with special characters: @#$%^&*()[]{};:'",.<>?` is a syntax error. It should be `%% This is an example comment with special characters: @#$%^&*()[]{};:'",.<>?` +- In Mermaid.js syntax, you cannot add comments after any code line, For example: AI_ModelProviders_Group -->|"Returns Diagram Code"| BE_App %% Simplified return path is a syntax error. It should be AI_ModelProviders_Group -->|"Returns Diagram Code"| BE_App without any comments after it. It is very important, remember it! +- In Mermaid.js syntax, if you encounter the keyword "end" in the code, make sure not to add any comments after it. For example: `end %% This is an example comment` is a syntax error. It should be `end` without any comments after it! """ # ^^^ note: ive generated a few diagrams now and claude still writes incorrect mermaid code sometimes. in the future, refer to those generated diagrams and add important instructions to the prompt above to avoid those mistakes. examples are best. diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py index 51befcc3..f27e250f 100644 --- a/backend/app/routers/generate.py +++ b/backend/app/routers/generate.py @@ -1,3 +1,4 @@ +from app.services.gemini_service import GeminiService from fastapi import APIRouter, Request, HTTPException from fastapi.responses import StreamingResponse from dotenv import load_dotenv @@ -25,39 +26,88 @@ # Initialize services # claude_service = ClaudeService() +# gemini_service = GeminiService() o4_service = OpenAIo4Service() - # cache github data to avoid double API calls from cost and generate @lru_cache(maxsize=100) -def get_cached_github_data(username: str, repo: str, github_pat: str | None = None): +def get_cached_github_data(username: str, repo: str, github_pat: str | None = None, branch: str = ""): # Create a new service instance for each call with the appropriate PAT current_github_service = GitHubService(pat=github_pat) - default_branch = current_github_service.get_default_branch(username, repo) - if not default_branch: - default_branch = "main" # fallback value + defaultBranch = current_github_service.get_default_branch(username, repo) + if not defaultBranch: + defaultBranch = "main" # fallback value - file_tree = current_github_service.get_github_file_paths_as_list(username, repo) + file_tree = current_github_service.get_github_file_paths_as_list(username, repo, branch) readme = current_github_service.get_github_readme(username, repo) - return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme} + return {"defaultBranch": defaultBranch, "file_tree": file_tree, "readme": readme} +@lru_cache(maxsize=100) +def get_github_repo_branches(username: str, repo: str, github_pat: str | None = None): + """Get all branches of a GitHub repository. + """ + # Create a new service instance for each call with the appropriate PAT + current_github_service = GitHubService(pat=github_pat) + branches = current_github_service.get_github_repo_branches(username, repo) + if not branches: + raise HTTPException(status_code=404, detail="No branches found in repository") + + # Get the default branch as well to return it + defaultBranch = get_cached_github_data(username, repo, github_pat)["defaultBranch"] + if not defaultBranch: + defaultBranch = "main" + return { + "branches": branches["branches"], + "defaultBranch": defaultBranch, + } + class ApiRequest(BaseModel): username: str repo: str instructions: str = "" api_key: str | None = None github_pat: str | None = None + branch: str = "" + page: int = 1 + pageSize: int = 100 +@router.post("/branches") +async def get_repo_branches(request: Request, body: ApiRequest): + try: + # Validate input + if not body.username or not body.repo: + raise HTTPException(status_code=400, detail="Username and repo are required") + + # Create a new service instance with the appropriate PAT + current_github_service = GitHubService(pat=body.github_pat) + + # Get branches with pagination + branches_data = current_github_service.get_github_repo_branches( + body.username, body.repo, body.page, body.pageSize + ) + + # Also get the default branch for compatibility + defaultBranch = current_github_service.get_default_branch(body.username, body.repo) + + # Add defaultBranch to the response for compatibility + branches_data["defaultBranch"] = defaultBranch + + return branches_data + + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @router.post("/cost") # @limiter.limit("5/minute") # TEMP: disable rate limit for growth?? async def get_generation_cost(request: Request, body: ApiRequest): try: # Get file tree and README content - github_data = get_cached_github_data(body.username, body.repo, body.github_pat) + github_data = get_cached_github_data(body.username, body.repo, body.github_pat, body.branch) file_tree = github_data["file_tree"] readme = github_data["readme"] @@ -136,9 +186,9 @@ async def event_generator(): try: # Get cached github data github_data = get_cached_github_data( - body.username, body.repo, body.github_pat + body.username, body.repo, body.github_pat, body.branch ) - default_branch = github_data["default_branch"] + defaultBranch = github_data["defaultBranch"] file_tree = github_data["file_tree"] readme = github_data["readme"] @@ -243,7 +293,7 @@ async def event_generator(): return processed_diagram = process_click_events( - mermaid_code, body.username, body.repo, default_branch + mermaid_code, body.username, body.repo, defaultBranch ) # Send final result diff --git a/backend/app/services/gemini_service.py b/backend/app/services/gemini_service.py new file mode 100644 index 00000000..0d1bc3fc --- /dev/null +++ b/backend/app/services/gemini_service.py @@ -0,0 +1,112 @@ +from dotenv import load_dotenv +from app.utils.format_message import format_user_message +import os +import aiohttp +import json +from typing import AsyncGenerator + +load_dotenv() + +class GeminiService: + def __init__(self): + self.api_key = os.getenv("GEMINI_API_KEY") + self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent" + + def call_gemini_api( + self, + system_prompt: str, + data: dict, + api_key: str | None = None, + ) -> str: + """ + Makes an API call to Gemini and returns the response. + Args: + system_prompt (str): The instruction/system prompt + data (dict): Dictionary of variables to format into the user message + api_key (str | None): Optional custom API key + Returns: + str: Gemini's response text + """ + user_message = format_user_message(data) + key = api_key or self.api_key + if not key: + raise ValueError("Gemini API key is missing. Please set GEMINI_API_KEY in your environment or provide api_key.") + headers = { + "Content-Type": "application/json", + } + params = {"key": str(key)} + payload = { + "contents": [ + {"role": "user", "parts": [{"text": f"{system_prompt}\n{user_message}"}]} + ] + } + try: + import requests + response = requests.post(self.base_url, headers=headers, params=params, json=payload) + response.raise_for_status() + result = response.json() + return result["candidates"][0]["content"]["parts"][0]["text"] + except Exception as e: + print(f"Error in Gemini API call: {str(e)}") + raise + + async def call_gemini_api_stream( + self, + system_prompt: str, + data: dict, + api_key: str | None = None, + ) -> AsyncGenerator[str, None]: + """ + Makes a streaming API call to Gemini and yields the responses. + Args: + system_prompt (str): The instruction/system prompt + data (dict): Dictionary of variables to format into the user message + api_key (str | None): Optional custom API key + Yields: + str: Chunks of Gemini's response text + """ + user_message = format_user_message(data) + key = api_key or self.api_key + if not key: + raise ValueError("Gemini API key is missing. Please set GEMINI_API_KEY in your environment or provide api_key.") + headers = { + "Content-Type": "application/json", + } + params = {"key": str(key)} + payload = { + "contents": [ + {"role": "user", "parts": [{"text": f"{system_prompt}\n{user_message}"}]} + ] + } + try: + async with aiohttp.ClientSession() as session: + async with session.post(self.base_url, headers=headers, params=params, json=payload) as response: + if response.status != 200: + error_text = await response.text() + print(f"Error response: {error_text}") + raise ValueError(f"Gemini API returned status code {response.status}: {error_text}") + response_text = await response.text() + try: + data = json.loads(response_text) + text = data["candidates"][0]["content"]["parts"][0]["text"] + if text: + yield text + except Exception as e: + print(f"Error parsing Gemini response: {e}") + except aiohttp.ClientError as e: + print(f"Connection error: {str(e)}") + raise ValueError(f"Failed to connect to Gemini API: {str(e)}") + except Exception as e: + print(f"Unexpected error in streaming API call: {str(e)}") + raise + + def count_tokens(self, prompt: str) -> int: + """ + Counts the number of tokens in a prompt. + Args: + prompt (str): The prompt to count tokens for + Returns: + int: Estimated number of input tokens + """ + # Gemini does not have a public tokenizer, so we approximate by whitespace splitting + return len(prompt.split()) diff --git a/backend/app/services/github_service.py b/backend/app/services/github_service.py index 33a4d429..be0c327d 100644 --- a/backend/app/services/github_service.py +++ b/backend/app/services/github_service.py @@ -98,6 +98,64 @@ def _check_repository_exists(self, username, repo): f"Failed to check repository: {response.status_code}, {response.json()}" ) + def get_github_repo_branches(self, username, repo, page=1, pageSize=100): + """ + Get branches of a GitHub repository with pagination. + + Args: + username (str): The GitHub username or organization name + repo (str): The repository name + page (int): Page number to fetch (default: 1) + pageSize (int): Number of branches per page (max: 100, default: 100) + + Returns: + dict: A dictionary containing branch names, pagination info, and default branch. + """ + self._check_repository_exists(username, repo) + + # Ensure pageSize doesn't exceed GitHub's limit + pageSize = min(pageSize, 100) + + api_url = f"https://api.github.com/repos/{username}/{repo}/branches" + params = { + "page": page, + "pageSize": pageSize + } + + response = requests.get(api_url, headers=self._get_headers(), params=params) + + if response.status_code == 200: + branches_data = response.json() + branches = [branch["name"] for branch in branches_data] + + # Parse pagination info from headers + link_header = response.headers.get('Link', '') + has_next = 'rel="next"' in link_header + + # Get total count if available (not always provided by GitHub) + total_count = None + if 'Link' in response.headers: + # Try to extract total from last page link if available + import re + last_match = re.search(r'page=(\d+)>; rel="last"', link_header) + if last_match: + last_page = int(last_match.group(1)) + # Estimate total (this is approximate) + total_count = (last_page - 1) * pageSize + len(branches_data) + + return { + "branches": branches, + "pagination": { + "current_page": page, + "has_next": has_next, + "total_count": total_count + } + } + + raise Exception( + f"Failed to fetch branches: {response.status_code}, {response.json()}" + ) + def get_default_branch(self, username, repo): """Get the default branch of the repository.""" api_url = f"https://api.github.com/repos/{username}/{repo}" @@ -107,7 +165,7 @@ def get_default_branch(self, username, repo): return response.json().get("default_branch") return None - def get_github_file_paths_as_list(self, username, repo): + def get_github_file_paths_as_list(self, username, repo, branch): """ Fetches the file tree of an open-source GitHub repository, excluding static files and generated code. @@ -160,8 +218,11 @@ def should_include_file(path): return not any(pattern in path.lower() for pattern in excluded_patterns) - # Try to get the default branch first - branch = self.get_default_branch(username, repo) + #if the branch is empty, try to get the default branch + if not branch: + branch = self.get_default_branch(username, repo) + + # Finding the file tree for the specified branch if branch: api_url = f"https://api.github.com/repos/{ username}/{repo}/git/trees/{branch}?recursive=1" diff --git a/src/app/[username]/[repo]/page.tsx b/src/app/[username]/[repo]/page.tsx index 064efa8c..6f559265 100644 --- a/src/app/[username]/[repo]/page.tsx +++ b/src/app/[username]/[repo]/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useParams } from "next/navigation"; +import { useParams, useSearchParams } from "next/navigation"; import MainCard from "~/components/main-card"; import Loading from "~/components/loading"; import MermaidChart from "~/components/mermaid-diagram"; @@ -13,6 +13,7 @@ import { useStarReminder } from "~/hooks/useStarReminder"; export default function Repo() { const [zoomingEnabled, setZoomingEnabled] = useState(false); const params = useParams<{ username: string; repo: string }>(); + const branch = useSearchParams().get("branch") ?? ""; // Use the star reminder hook useStarReminder(); @@ -31,8 +32,8 @@ export default function Repo() { handleCloseApiKeyDialog, handleOpenApiKeyDialog, handleExportImage, - state, - } = useDiagram(params.username.toLowerCase(), params.repo.toLowerCase()); + state + } = useDiagram(params.username.toLowerCase(), params.repo.toLowerCase(), branch); return (
@@ -41,6 +42,7 @@ export default function Repo() { isHome={false} username={params.username.toLowerCase()} repo={params.repo.toLowerCase()} + branch={branch} showCustomization={!loading && !error} onModify={handleModify} onRegenerate={handleRegenerate} diff --git a/src/app/_actions/cache.ts b/src/app/_actions/cache.ts index 28d4720d..37ed480b 100644 --- a/src/app/_actions/cache.ts +++ b/src/app/_actions/cache.ts @@ -42,6 +42,7 @@ export async function getCachedExplanation(username: string, repo: string) { export async function cacheDiagramAndExplanation( username: string, repo: string, + branch: string, diagram: string, explanation: string, usedOwnKey = false, @@ -55,6 +56,7 @@ export async function cacheDiagramAndExplanation( diagram, explanation, usedOwnKey, + branch, }) .onConflictDoUpdate({ target: [diagramCache.username, diagramCache.repo], @@ -63,6 +65,7 @@ export async function cacheDiagramAndExplanation( explanation, usedOwnKey, updatedAt: new Date(), + branch, }, }); } catch (error) { diff --git a/src/app/page.tsx b/src/app/page.tsx index 7ab825a0..16b4bcac 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -3,23 +3,27 @@ import Hero from "~/components/hero"; export default function HomePage() { return ( -
+
-
-

+

+

Turn any GitHub repository into an interactive diagram for visualization.

-

+

This is useful for quickly visualizing projects.

-

+

You can also replace 'hub' with 'diagram' in any Github URL

+

+ You can also add '?branch=branch_name' to the URL to + visualize a specific branch. +

-
+
diff --git a/src/components/main-card.tsx b/src/components/main-card.tsx index d3fbbac9..3c85f016 100644 --- a/src/components/main-card.tsx +++ b/src/components/main-card.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback } from "react"; import { useRouter } from "next/navigation"; import { Card } from "~/components/ui/card"; import { Input } from "~/components/ui/input"; @@ -12,11 +12,14 @@ import { exampleRepos } from "~/lib/exampleRepos"; import { ExportDropdown } from "./export-dropdown"; import { ChevronUp, ChevronDown } from "lucide-react"; import { Switch } from "~/components/ui/switch"; +import { Dropdown } from "./ui/dropdown"; +import { getRepoBranches } from "~/lib/fetch-backend"; interface MainCardProps { isHome?: boolean; username?: string; repo?: string; + branch?: string; showCustomization?: boolean; onModify?: (instructions: string) => void; onRegenerate?: (instructions: string) => void; @@ -32,6 +35,7 @@ export default function MainCard({ isHome = true, username, repo, + branch, showCustomization, onModify, onRegenerate, @@ -43,28 +47,151 @@ export default function MainCard({ loading, }: MainCardProps) { const [repoUrl, setRepoUrl] = useState(""); + const [debouncedRepoUrl, setDebouncedRepoUrl] = useState(""); const [error, setError] = useState(""); + const [selectedBranch, setSelectedBranch] = useState(""); + const [branches, setBranches] = useState([]); + const [loadingBranches, setLoadingBranches] = useState(false); + const [loadingMoreBranches, setLoadingMoreBranches] = useState(false); + const pageSize = 50; + const [pagination, setPagination] = useState({ + currentPage: 1, + hasNext: false, + }); const [activeDropdown, setActiveDropdown] = useState< "customize" | "export" | null >(null); const router = useRouter(); + const fetchBranches = useCallback( + async (page: number) => { + if (!debouncedRepoUrl) { + setBranches([]); + setSelectedBranch(""); + setError(""); + setLoadingBranches(false); + return; + } + + const { sanitizedUsername, sanitizedRepo } = verifyRepoUrl(debouncedRepoUrl) ?? {}; + + if (!sanitizedUsername || !sanitizedRepo) { + setError("Invalid repository URL format"); + setLoadingBranches(false); + return; + } + + if (page === 1) { + setLoadingBranches(true); + setBranches([]); + setSelectedBranch(""); + setError(""); + setPagination({ + currentPage: 1, + hasNext: false, + }); + } else { + setLoadingMoreBranches(true); + } + + try { + const githubPat = + localStorage.getItem("github_pat") ?? process.env.GITHUB_PAT; + const branchList = await getRepoBranches( + sanitizedUsername, + sanitizedRepo, + githubPat, + page, + pageSize, + ); + + if (branchList.error) { + setError(branchList.error); + if (page === 1) { + setBranches([]); + setSelectedBranch(""); + } + return; + } + + if (page == 1) { + const branches = branchList.branches ?? []; + const defaultBranch = branchList.defaultBranch; + if (defaultBranch && !branches.includes(defaultBranch)) { + setBranches([defaultBranch, ...branches]); + } else { + setBranches(branches); + } + } else { + setBranches((prev) => { + const newBranches = branchList.branches ?? []; + return Array.from(new Set([...prev, ...newBranches])); + }); + } + + if (branchList.pagination) { + setPagination({ + currentPage: branchList.pagination.current_page, + hasNext: branchList.pagination.has_next, + }); + } + + setSelectedBranch(branch ?? branchList.defaultBranch ?? ""); + setError(""); + } catch (error) { + setError(error as string); + if (page === 1) { + setBranches([]); + setSelectedBranch(""); + } + } finally { + if (page === 1) { + setLoadingBranches(false); + } else { + setLoadingMoreBranches(false); + } + } + }, + [debouncedRepoUrl, branch], + ); + + useEffect(() => { + void fetchBranches(1); + }, [fetchBranches]); + + const loadMoreBranches = useCallback(async () => { + if (!debouncedRepoUrl || !pagination.hasNext || loadingMoreBranches) return; + await fetchBranches(pagination.currentPage + 1); + }, [ + debouncedRepoUrl, + pagination.hasNext, + pagination.currentPage, + loadingMoreBranches, + fetchBranches, + ]); + useEffect(() => { if (username && repo) { setRepoUrl(`https://github.com/${username}/${repo}`); } }, [username, repo]); + // Debouncing + useEffect(() => { + const handler = setTimeout(() => { + setDebouncedRepoUrl(repoUrl); + }, 600); + return () => clearTimeout(handler); + }, [repoUrl]); + useEffect(() => { if (loading) { setActiveDropdown(null); } }, [loading]); - const handleSubmit = (e: React.FormEvent) => { - e.preventDefault(); - setError(""); - + //verify the repoUrl format and extract username and repo + const verifyRepoUrl = (repoUrl: string) => { const githubUrlPattern = /^https?:\/\/github\.com\/([a-zA-Z0-9-_]+)\/([a-zA-Z0-9-_\.]+)\/?$/; const match = githubUrlPattern.exec(repoUrl.trim()); @@ -81,7 +208,25 @@ export default function MainCard({ } const sanitizedUsername = encodeURIComponent(username); const sanitizedRepo = encodeURIComponent(repo); - router.push(`/${sanitizedUsername}/${sanitizedRepo}`); + + return { sanitizedUsername, sanitizedRepo }; + } + + const handleSubmit = (e: React.FormEvent) => { + if(loadingBranches) { + setError("Please wait for branches to load"); + return; + } + e.preventDefault(); + setError(""); + + const { sanitizedUsername, sanitizedRepo } = verifyRepoUrl(repoUrl) ?? {}; + if (!sanitizedUsername || !sanitizedRepo) { + return; // Error will be set in verifyRepoUrl + } + + const branchQuery = `?branch=${encodeURIComponent(selectedBranch)}`; + router.push(`/${sanitizedUsername}/${sanitizedRepo}${branchQuery}`); }; const handleExampleClick = (repoPath: string, e: React.MouseEvent) => { @@ -96,7 +241,7 @@ export default function MainCard({ return (
-
+
setRepoUrl(e.target.value)} required /> +
+
+ + + + + + + {branches.length === 0 ? ( + No branches found. + ) : ( + + {branches.map((branch) => ( + { + onSelectBranch(branch); + setOpen(false); + }} + > + {branch} + + + ))} + + )} + {hasMoreBranches && ( +
+ {loadingMoreBranches + ? "Loading more branches..." + : "End of list"} +
+ )} +
+
+
+ + ); +}; diff --git a/src/components/ui/popover.tsx b/src/components/ui/popover.tsx new file mode 100644 index 00000000..e7bd907f --- /dev/null +++ b/src/components/ui/popover.tsx @@ -0,0 +1,31 @@ +"use client" + +import * as React from "react" +import * as PopoverPrimitive from "@radix-ui/react-popover" + +import { cn } from "~/lib/utils" + +const Popover = PopoverPrimitive.Root + +const PopoverTrigger = PopoverPrimitive.Trigger + +const PopoverContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = "start", sideOffset = 4, ...props }, ref) => ( + + + +)) +PopoverContent.displayName = PopoverPrimitive.Content.displayName + +export { Popover, PopoverTrigger, PopoverContent } diff --git a/src/hooks/useDiagram.ts b/src/hooks/useDiagram.ts index a763f0b9..75e0826b 100644 --- a/src/hooks/useDiagram.ts +++ b/src/hooks/useDiagram.ts @@ -39,7 +39,7 @@ interface StreamResponse { error?: string; } -export function useDiagram(username: string, repo: string) { +export function useDiagram(username: string, repo: string, branch: string) { const [diagram, setDiagram] = useState(""); const [error, setError] = useState(""); const [loading, setLoading] = useState(true); @@ -73,6 +73,7 @@ export function useDiagram(username: string, repo: string) { body: JSON.stringify({ username, repo, + branch, instructions, api_key: localStorage.getItem("openai_key") ?? undefined, github_pat: githubPat, @@ -226,7 +227,7 @@ export function useDiagram(username: string, repo: string) { setLoading(false); } }, - [username, repo, hasUsedFreeGeneration], + [username, repo, branch, hasUsedFreeGeneration], ); useEffect(() => { @@ -236,6 +237,7 @@ export function useDiagram(username: string, repo: string) { void cacheDiagramAndExplanation( username, repo, + branch, state.diagram, state.explanation ?? "No explanation provided", hasApiKey, @@ -247,7 +249,7 @@ export function useDiagram(username: string, repo: string) { } else if (state.status === "error") { setLoading(false); } - }, [state.status, state.diagram, username, repo, state.explanation]); + }, [state.status, state.diagram, username, repo, state.explanation, branch]); const getDiagram = useCallback(async () => { setLoading(true); @@ -257,7 +259,7 @@ export function useDiagram(username: string, repo: string) { try { // Check cache first - always allow access to cached diagrams const cached = await getCachedDiagram(username, repo); - const github_pat = localStorage.getItem("github_pat"); + const github_pat = localStorage.getItem("github_pat") ?? undefined; if (cached) { setDiagram(cached); @@ -281,8 +283,9 @@ export function useDiagram(username: string, repo: string) { const costEstimate = await getCostOfGeneration( username, repo, + branch, "", - github_pat ?? undefined, + github_pat, ); if (costEstimate.error) { @@ -308,7 +311,7 @@ export function useDiagram(username: string, repo: string) { } finally { setLoading(false); } - }, [username, repo, generateDiagram]); + }, [username, repo, branch, generateDiagram]); useEffect(() => { void getDiagram(); @@ -364,7 +367,7 @@ export function useDiagram(username: string, repo: string) { // return; // } - const costEstimate = await getCostOfGeneration(username, repo, ""); + const costEstimate = await getCostOfGeneration(username, repo, "", branch); if (costEstimate.error) { console.error("Cost estimation failed:", costEstimate.error); diff --git a/src/lib/fetch-backend.ts b/src/lib/fetch-backend.ts index e0eda932..c899f595 100644 --- a/src/lib/fetch-backend.ts +++ b/src/lib/fetch-backend.ts @@ -22,9 +22,59 @@ interface CostApiResponse { cost?: string; } +interface BranchesApiResponse { + branches?: string[]; + defaultBranch?: string; + pagination?: { + current_page: number; + has_next: boolean; + total_count?: number; + }; + error?: string; +} + +export async function getRepoBranches( + username: string, + repo: string, + github_pat?: string, + page = 1, + pageSize = 100, +): Promise { + try { + const baseUrl = + process.env.NEXT_PUBLIC_API_DEV_URL ?? "https://api.gitdiagram.com"; + const url = new URL(`${baseUrl}/generate/branches`); + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + username, + repo, + github_pat, + page, + pageSize + }), + }); + + if (response.status === 429) { + return { error: "Rate limit exceeded. Please try again later." }; + } + + const data = (await response.json()) as BranchesApiResponse; + return data; + } catch (error) { + console.error("Error fetching branches:", error); + return { error: "Failed to fetch branches. Please try again later." }; + } +} + export async function generateAndCacheDiagram( username: string, repo: string, + branch:string, github_pat?: string, instructions?: string, api_key?: string, @@ -45,6 +95,7 @@ export async function generateAndCacheDiagram( instructions: instructions ?? "", api_key: api_key, github_pat: github_pat, + branch: branch ?? "", }), }); @@ -62,6 +113,7 @@ export async function generateAndCacheDiagram( await cacheDiagramAndExplanation( username, repo, + branch, data.diagram!, data.explanation!, ); @@ -76,6 +128,7 @@ export async function modifyAndCacheDiagram( username: string, repo: string, instructions: string, + branch: string ): Promise { try { // First get the current diagram from cache @@ -101,6 +154,7 @@ export async function modifyAndCacheDiagram( instructions: instructions, current_diagram: currentDiagram, explanation: explanation, + branch: branch ?? "", }), }); @@ -118,6 +172,7 @@ export async function modifyAndCacheDiagram( await cacheDiagramAndExplanation( username, repo, + branch, data.diagram!, explanation, ); @@ -131,6 +186,7 @@ export async function modifyAndCacheDiagram( export async function getCostOfGeneration( username: string, repo: string, + branch: string, instructions: string, github_pat?: string, ): Promise { @@ -149,6 +205,7 @@ export async function getCostOfGeneration( repo, github_pat: github_pat, instructions: instructions ?? "", + branch:branch ?? "", }), }); diff --git a/src/server/db/schema.ts b/src/server/db/schema.ts index 1d2d9439..068ab130 100644 --- a/src/server/db/schema.ts +++ b/src/server/db/schema.ts @@ -21,10 +21,10 @@ export const createTable = pgTableCreator((name) => `gitdiagram_${name}`); export const diagramCache = createTable( "diagram_cache", { - username: varchar("username", { length: 256 }).notNull(), + username: varchar("username", { length: 512 }).notNull(), repo: varchar("repo", { length: 256 }).notNull(), - diagram: varchar("diagram", { length: 10000 }).notNull(), // Adjust length as needed - explanation: varchar("explanation", { length: 10000 }) + diagram: varchar("diagram", { length: 100000 }).notNull(), // Adjust length as needed + explanation: varchar("explanation", { length: 50000 }) .notNull() .default("No explanation provided"), // Default explanation to avoid data loss of existing rows createdAt: timestamp("created_at", { withTimezone: true }) @@ -34,8 +34,9 @@ export const diagramCache = createTable( () => new Date(), ), usedOwnKey: boolean("used_own_key").default(false), + branch: varchar("branch", { length: 512 }).notNull() }, (table) => ({ - pk: primaryKey({ columns: [table.username, table.repo] }), + pk: primaryKey({ columns: [table.username, table.repo, table.branch] }), }), );