Skip to content

Commit 466d884

Browse files
committed
Integrate OpenAI support to extract the Package URL (PURL), as well as the affected and fixed versions.
Signed-off-by: ziad hany <ziadhany2016@gmail.com>
1 parent d929a80 commit 466d884

1 file changed

Lines changed: 115 additions & 156 deletions

File tree

Lines changed: 115 additions & 156 deletions
Original file line number | Diff line number | Diff line change
@@ -1,83 +1,116 @@
1-
import json
2-
import re
3-
from pathlib import Path
41
from typing import Iterable
2+
from typing import List
53

6-
import chromadb
74
from django.db.models import QuerySet
8-
from langchain_chroma import Chroma
9-
from langchain_ollama import OllamaLLM
10-
5+
from pydantic import BaseModel
6+
from pydantic_ai import Agent
7+
from pydantic_ai.models.openai import OpenAIModel
8+
from pydantic_ai.providers.openai import OpenAIProvider
119
from univers.version_range import RANGE_CLASS_BY_SCHEMES
1210

13-
from vulnerabilities.importer import AffectedPackage, AdvisoryData
14-
from vulnerabilities.improver import Inference, MAX_CONFIDENCE, Improver
11+
from vulnerabilities.importer import AdvisoryData
12+
from vulnerabilities.importer import AffectedPackage
13+
from vulnerabilities.improver import MAX_CONFIDENCE
14+
from vulnerabilities.improver import Improver
15+
from vulnerabilities.improver import Inference
1516
from vulnerabilities.improvers.default import get_exact_purls
1617
from vulnerabilities.models import Advisory
1718
from vulnerablecode.settings import env
18-
from langchain.prompts import PromptTemplate
1919
from packageurl import PackageURL
20-
from langchain_huggingface import HuggingFaceEmbeddings
21-
from langchain_community.document_loaders import UnstructuredMarkdownLoader
22-
from tqdm import tqdm
20+
from pydantic.functional_validators import field_validator
21+
22+
class Purl(BaseModel):
    """Structured LLM output wrapping a single Package URL (PURL) string.

    Validation fails unless ``string`` can be parsed by
    ``PackageURL.from_string``.
    """

    # Raw PURL text, e.g. "pkg:maven/org.apache.apr/apr-util@1.3.5".
    string: str

    @field_validator("string")
    @classmethod
    def check_valid_purl(cls, v: str) -> str:
        """Return ``v`` unchanged if it parses as a Package URL.

        Raises:
            ValueError: if ``PackageURL.from_string`` rejects ``v``. The
                parser's own exception is chained as the cause.
        """
        try:
            PackageURL.from_string(v)
        except Exception as e:
            # Re-raise as ValueError so the failure is reported as a field
            # validation error, while keeping the parser's traceback.
            raise ValueError(f"Invalid PURL '{v}': {e}") from e
        return v
32+
33+
class Versions(BaseModel):
    """Structured LLM output holding the version constraints of an advisory."""

    # Version-condition strings describing the vulnerable releases.
    affected_versions: List[str]

    # Version-condition strings describing the releases that carry the fix.
    fixed_versions: List[str]
36+
37+
38+
# System prompt for the PURL-extraction agent.
# This is a plain string, not an f-string: nothing is interpolated, so the
# braces in the JSON examples are written literally instead of being doubled.
prompt_purl_extraction = """
You are a highly specialized Vulnerability Analysis Assistant. Your task is to analyze the provided vulnerability summary or package name and extract a single valid Package URL (PURL) that conforms to the official PURL specification:

**Component Definitions (Required by PURL Specification):**
- **scheme**: Constant value `pkg`
- **type**: Package type or protocol (e.g., maven, npm, nuget, gem, pypi, rpm, etc.) — must be a known valid type
- **namespace**: A name prefix such as a Maven groupId, Docker image owner, or GitHub user/org (optional and type-specific)
- **name**: Package name (required)
- **version**: Version of the package (optional)
- **qualifiers**: Extra data like OS, arch, etc. (optional and type-specific)
- **subpath**: Subpath within the package (optional)

**Examples of Valid PURLs:**
- pkg:maven/org.apache.apr/apr-util@1.3.5
- pkg:github/apache/apr-util@1.3.5
- pkg:rpm/redhat/apr-util@1.3.5
- pkg:deb/debian/apr-util@1.3.5

**Output Instructions:**
- Identify the most appropriate and valid PURL type for the package if possible.
- If a valid and complete PURL can be constructed, return only:
`{ "string": "pkg:type/namespace/name@version?qualifiers#subpath" }`
- If no valid PURL can be constructed or the type is unknown, return:
`{}`
- Do not include any other output (no explanation, formatting, or markdown).
"""
64+
65+
# System prompt for the version-extraction agent.
# This is a plain string, not an f-string: nothing is interpolated, so the
# braces in the JSON examples are written literally instead of being doubled.
# NOTE(review): the prompt allows the literal "Not Fixed" as a fixed version,
# while handler_version_ranges builds a "vers:" range from every returned
# string -- presumably that value cannot be parsed; confirm and consider
# asking for an empty list instead.
prompt_version_extraction = """
You are a highly specialized Vulnerability Analysis Assistant. Your task is to analyze the following vulnerability summary and accurately extract the affected and fixed versions of the software.

Instructions:
- Affected Version: Use one of the following formats:
- >= <version>, <= <version>, > <version>, < <version>
- A specific range like <version1> - <version2>
- Fixed Version: Use one of the following formats:
- >= <version>, <= <version>, > <version>, < <version>
- "Not Fixed" if no fixed version is mentioned.
- Ensure accuracy by considering different ways affected and fixed versions might be described in the summary.
- Extract only version-related details without adding any extra information.

Output Format:
```json
{
"affected_versions": ["<version_condition>", "<version_condition>"],
"fixed_versions": ["<version_condition>", "<version_condition>"]
}
```
Example:
{
"affected_versions": [">=1.2.3", "<2.0.0"],
"fixed_versions": ["2.0.0"]
}

Return only the JSON object without any additional text.
"""
2393

2494
class AISummaryImprover(Improver):
2595
"""
2696
A pipeline for improving vulnerability version extraction using AI.
2797
This pipeline analyzes vulnerability summaries and extracts affected and fixed versions.
2898
"""
2999

30-
llm = OllamaLLM(
31-
model=env.str("OLLAMA_MODEL_NAME"),
32-
base_url=env.str("OLLAMA_BASE_URL")
33-
)
34-
35-
# Initialize embeddings
36-
embeddings = HuggingFaceEmbeddings(
37-
model_name="sentence-transformers/all-MiniLM-L6-v2",
38-
model_kwargs={"device": "cpu"},
39-
encode_kwargs={"normalize_embeddings": True},
40-
)
41-
42-
# Initialize ChromaDB Client (do this once)
43-
chroma_client = chromadb.PersistentClient(path="purl_index")
44-
45-
# Create the vector store using LangChain's Chroma integration
46-
vector_db = Chroma(
47-
client=chroma_client,
48-
collection_name="purl_embeddings",
49-
embedding_function=embeddings,
50-
)
51-
52-
# Check if collection exists and contains documents
53-
existing_docs = vector_db.get()
54-
if existing_docs and existing_docs.get("documents"):
55-
print(f"✅ ChromaDB collection loaded successfully! {len(existing_docs['documents'])} documents found.")
56-
else:
57-
print(f"⚠️ Collection not found or empty. Initializing ChromaDB.")
58-
59-
# Load documents
60-
markdown_path = "/agent/purl_db/PURL.rst"
61-
loader = UnstructuredMarkdownLoader(markdown_path)
62-
docs = loader.load() # This returns a list of Documents
63-
64-
if not docs:
65-
print("❌ No documents loaded. Please check the file path and format.")
66-
else:
67-
print(f"✅ Loaded {len(docs)} documents.")
68-
collection = chroma_client.get_or_create_collection(name="purl_embeddings")
69-
70-
# Index each document by its file name
71-
for i, doc in enumerate(tqdm(docs, desc="Indexing documents")):
72-
file = doc.metadata.get("source", "unknown")
73-
file_name = Path(file).stem
74-
collection.add(
75-
ids=[file_name],
76-
documents=[doc.page_content],
77-
metadatas=[{"file_name": file_name}],
78-
)
100+
# Chat model shared by both extraction agents. The API key is read from the
# OPENAI_API_KEY setting when this class body is evaluated.
openai_model = OpenAIModel(
    "gpt-4o-mini",
    provider=OpenAIProvider(api_key=env.str("OPENAI_API_KEY")),
)

# NOTE(review): disabled alternative that targets an Ollama server through the
# same OpenAI-compatible classes. The endpoint URL belongs in ``base_url``
# (``openai_client`` expects a client object, not a URL string) -- confirm
# before enabling.
# ollama_model = OpenAIModel(
#     model_name=env.str("OLLAMA_MODEL_NAME"),
#     provider=OpenAIProvider(base_url=env.str("OLLAMA_BASE_URL")),
# )

# Agent whose structured output is a ``Purl``.
purl_agent = Agent(
    openai_model,
    system_prompt=prompt_purl_extraction,
    output_type=Purl,
)

# Agent whose structured output is a ``Versions``.
versions_agent = Agent(
    openai_model,
    system_prompt=prompt_version_extraction,
    output_type=Versions,
)
79113

80-
print("✅ Documents indexed in ChromaDB.")
81114

82115
@property
83116
def interesting_advisories(self) -> QuerySet:
@@ -86,27 +119,24 @@ def interesting_advisories(self) -> QuerySet:
86119
)
87120

88121
def get_inferences(self, advisory_data: AdvisoryData) -> Iterable[Inference]:
89-
"""
90-
"""
91122
if not advisory_data:
92123
return []
93124

94125
if advisory_data.summary:
95126
purl = self.handler_purl(advisory_data.summary)
96127

97-
if not purl:
98-
return
99-
100-
affected_version_range, fixed_version = self.handler_version_ranges(summary=advisory_data.summary,
101-
supported_ecosystem=purl.type)
128+
affected_version_range, fixed_version = self.handler_version_ranges(
129+
summary=advisory_data.summary,
130+
supported_ecosystem=purl.type
131+
)
102132

103133
affected_package = AffectedPackage(
104-
package=purl,
134+
package=PackageURL(type=purl.type, namespace=purl.namespace, name=purl.name),
105135
affected_version_range=affected_version_range,
106136
fixed_version=fixed_version,
107137
)
108-
affected_purls, fixed_purls = get_exact_purls(affected_package)
109138

139+
affected_purls, fixed_purls = get_exact_purls(affected_package)
110140
for fixed_purl in fixed_purls:
111141
yield Inference(
112142
aliases=advisory_data.aliases,
@@ -120,50 +150,14 @@ def get_inferences(self, advisory_data: AdvisoryData) -> Iterable[Inference]:
120150

121151

122152
def handler_version_ranges(self, summary, supported_ecosystem):
123-
"""
124-
"""
125-
version_extraction_prompt = PromptTemplate(
126-
input_variables=["summary"],
127-
template="""
128-
You are a highly specialized Vulnerability Analysis Assistant. Your task is to analyze the following vulnerability summary and accurately extract the affected and fixed versions of the software.
129-
130-
**Vulnerability Summary:**
131-
{summary}
132-
133-
Output Format:
134-
```json
135-
{{
136-
"affected_versions": ["<version_condition>", "<version_condition>"],
137-
"fixed_versions": ["<version_condition>", "<version_condition>"]
138-
}}
139-
```
140-
141-
Instructions:
142-
- Affected Version: Use one of the following formats:
143-
- >= <version>, <= <version>, > <version>, < <version>
144-
- A specific range like <version1> - <version2>
145-
- Fixed Version: Use one of the following formats:
146-
- >= <version>, <= <version>, > <version>, < <version>
147-
- "Not Fixed" if no fixed version is mentioned.
148-
- Ensure accuracy by considering different ways affected and fixed versions might be described in the summary.
149-
- Extract only version-related details without adding any extra information.
150-
151-
Return only the JSON object without any additional text.
152-
""",
153-
)
154-
155-
version_extraction_prompt = version_extraction_prompt.format(summary=summary)
156-
json_text = self.get_llm_result(prompt=version_extraction_prompt)
157-
158-
try:
159-
match = re.search(r'```json\n(.*?)\n```', json_text, re.DOTALL).group(1)
160-
json_data = json.loads(match)
161-
except json.JSONDecodeError as e:
162-
print("Invalid JSON:", e)
163-
json_data = {}
153+
"""Extract affected and fixed version ranges from a vulnerability summary."""
154+
result = self.versions_agent.run_sync(user_prompt=f"""
155+
**Vulnerability Summary:**
156+
{summary}
157+
""")
164158

165-
affected_version_ranges = json_data.get("affected_versions", [])
166-
fixed_version_ranges = json_data.get("fixed_versions", [])
159+
affected_version_ranges = result.output.affected_versions
160+
fixed_version_ranges = result.output.fixed_versions
167161

168162
affected_version_objs = [RANGE_CLASS_BY_SCHEMES[supported_ecosystem].from_string(f"vers:{supported_ecosystem}/" + affected_version_range) for affected_version_range in affected_version_ranges]
169163
fixed_version_objs = [RANGE_CLASS_BY_SCHEMES[supported_ecosystem].from_string(f"vers:{supported_ecosystem}/" + fixed_version_version_range) for fixed_version_version_range in fixed_version_ranges]
@@ -172,46 +166,11 @@ def handler_version_ranges(self, summary, supported_ecosystem):
172166

173167
def handler_purl(self, summary):
    """
    Extract a Package URL (PURL) from a vulnerability summary.

    The summary is sent to ``self.purl_agent`` and the ``string`` field of
    its structured output is parsed with ``PackageURL.from_string``.

    Returns a ``PackageURL`` object (not a plain string). There is no
    ``None`` fallback: an exception raised by the agent run or by the
    parser propagates to the caller.
    """
    user_prompt = f"\n**Vulnerability Summary:**\n{summary}\n"
    result = self.purl_agent.run_sync(user_prompt=user_prompt)
    return PackageURL.from_string(result.output.string)

0 commit comments

Comments
 (0)