Skip to content

Commit 50df425

Browse files
committed
Update the pipeline to use bulk
Signed-off-by: ziad hany <ziadhany2016@gmail.com>
1 parent b08c66c commit 50df425

File tree

2 files changed

+91
-45
lines changed

2 files changed

+91
-45
lines changed

vulnerabilities/pipelines/v2_improvers/reference_collect_commits.py

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,21 @@
88
#
99

1010
from aboutcode.pipeline import LoopProgress
11+
from django.db.models import Prefetch
1112
from packageurl.contrib.purl2url import purl2url
1213
from packageurl.contrib.url2purl import url2purl
1314

1415
from aboutcode.federated import get_core_purl
16+
from vulnerabilities.models import AdvisoryReference
1517
from vulnerabilities.models import AdvisoryV2
18+
from vulnerabilities.models import ImpactedPackage
1619
from vulnerabilities.models import PackageCommitPatch
17-
from vulnerabilities.pipelines import VulnerableCodeBaseImporterPipelineV2
18-
from vulnerabilities.pipes.advisory import VCS_URLS_SUPPORTED_TYPES
20+
from vulnerabilities.models import Patch
21+
from vulnerabilities.pipelines import VulnerableCodePipeline
1922
from vulnerabilities.utils import is_commit
2023

2124

22-
class CollectReferencesFixCommitsPipeline(VulnerableCodeBaseImporterPipelineV2):
25+
class CollectReferencesFixCommitsPipeline(VulnerableCodePipeline):
2326
"""
2427
Improver pipeline to scout References/Patch and create PackageCommitPatch entries.
2528
"""
@@ -30,45 +33,98 @@ class CollectReferencesFixCommitsPipeline(VulnerableCodeBaseImporterPipelineV2):
3033
def steps(cls):
3134
return (cls.collect_and_store_fix_commits,)
3235

33-
def get_vcs_commit(self, url):
34-
"""Extracts and VCS URL and commit hash from URL.
36+
def get_vcs_data(self, url):
37+
"""Extracts a VCS URL and commit hash from URL.
3538
>> get_vcs_commit('https://github.com/aboutcode-org/vulnerablecode/commit/98e516011d6e096e25247b82fc5f196bbeecff10')
36-
('https://github.com/aboutcode-org/vulnerablecode', '98e516011d6e096e25247b82fc5f196bbeecff10')
39+
("pkg:github/aboutcode-org/vulnerablecode", 'https://github.com/aboutcode-org/vulnerablecode', '98e516011d6e096e25247b82fc5f196bbeecff10')
3740
>> get_vcs_commit('https://github.com/aboutcode-org/vulnerablecode/pull/1974')
3841
None
3942
"""
40-
purl = url2purl(url)
41-
if not purl or purl.type not in VCS_URLS_SUPPORTED_TYPES:
42-
return None
43+
try:
44+
purl = url2purl(url)
45+
if not purl:
46+
return
4347

44-
version = getattr(purl, "version", None)
45-
if not version or not is_commit(version):
46-
return None
47-
48-
vcs_url = purl2url(get_core_purl(purl).to_string())
49-
return (vcs_url, version) if vcs_url else None
48+
version = purl.version
49+
if not version or not is_commit(version):
50+
return
51+
base_purl = get_core_purl(purl)
52+
vcs_url = purl2url(base_purl.to_string())
53+
if base_purl and vcs_url and version:
54+
return base_purl, vcs_url, version
55+
except Exception as e:
56+
self.log(f"Invalid URL: url:{url} error:{e}")
5057

5158
def collect_and_store_fix_commits(self):
52-
impacted_packages_advisories = (
53-
AdvisoryV2.objects.filter(impacted_packages__isnull=False)
54-
.prefetch_related("references", "patches", "impacted_packages")
55-
.distinct()
59+
advisories = AdvisoryV2.objects.only("id").prefetch_related(
60+
Prefetch("references", queryset=AdvisoryReference.objects.only("url")),
61+
Prefetch("patches", queryset=Patch.objects.only("patch_url")),
5662
)
5763

58-
progress = LoopProgress(
59-
total_iterations=impacted_packages_advisories.count(), logger=self.log
60-
)
61-
for adv in progress.iter(impacted_packages_advisories.paginated(per_page=500)):
64+
progress = LoopProgress(total_iterations=advisories.count(), logger=self.log)
65+
66+
commit_batch = []
67+
updated_pkg_patch_commit_count = 0
68+
batch_size = 1000
69+
for adv in progress.iter(advisories.paginated(per_page=batch_size)):
6270
urls = {r.url for r in adv.references.all()} | {p.patch_url for p in adv.patches.all()}
63-
impacted_packages = list(adv.impacted_packages.all())
6471

6572
for url in urls:
66-
vcs_data = self.get_vcs_commit(url)
73+
vcs_data = self.get_vcs_data(url)
6774
if not vcs_data:
6875
continue
76+
base_purl, vcs_url, commit_hash = vcs_data
77+
commit_batch.append((str(base_purl), vcs_url, commit_hash, adv.id))
78+
79+
if len(commit_batch) >= batch_size:
80+
updated_pkg_patch_commit_count += self.bulk_commit_batch_update(commit_batch)
81+
commit_batch.clear()
82+
83+
if commit_batch:
84+
updated_pkg_patch_commit_count += self.bulk_commit_batch_update(commit_batch)
85+
commit_batch.clear()
86+
87+
self.log(f"Successfully processed pkg patch commit {updated_pkg_patch_commit_count:,d}")
88+
89+
def bulk_commit_batch_update(self, vcs_data_table):
90+
impact_data = {(row[0], row[3]) for row in vcs_data_table} # base_purl, adv_id
91+
commit_data = {(row[1], row[2]) for row in vcs_data_table} # vcs_url, commit_hash
92+
93+
adv_ids = {aid for _, aid in impact_data}
94+
existing_impacts = ImpactedPackage.objects.filter(advisory_id__in=adv_ids)
95+
existing_impact_pairs = {(ip.base_purl, ip.advisory_id) for ip in existing_impacts}
96+
97+
new_impacts = impact_data - existing_impact_pairs
98+
if new_impacts:
99+
ImpactedPackage.objects.bulk_create(
100+
[ImpactedPackage(base_purl=bp, advisory_id=aid) for bp, aid in new_impacts]
101+
)
102+
103+
PackageCommitPatch.objects.bulk_create(
104+
[
105+
PackageCommitPatch(vcs_url=vcs_url, commit_hash=commit_hash)
106+
for vcs_url, commit_hash in commit_data
107+
],
108+
ignore_conflicts=True,
109+
)
110+
111+
adv_ids = {adv_id for _, adv_id in impact_data}
112+
fetched_impacts = {
113+
(impacted_pkg.base_purl, impacted_pkg.advisory_id): impacted_pkg
114+
for impacted_pkg in ImpactedPackage.objects.filter(advisory_id__in=adv_ids)
115+
}
116+
117+
commit_hashes = {commit_hash for _, commit_hash in commit_data}
118+
fetched_commits = {
119+
(pkg_commit_patch.vcs_url, pkg_commit_patch.commit_hash): pkg_commit_patch
120+
for pkg_commit_patch in PackageCommitPatch.objects.filter(commit_hash__in=commit_hashes)
121+
}
122+
123+
for base_purl, vcs_url, commit_hash, adv_id in vcs_data_table:
124+
impacted_package = fetched_impacts.get((base_purl, adv_id))
125+
package_commit_obj = fetched_commits.get((vcs_url, commit_hash))
126+
127+
if impacted_package and package_commit_obj:
128+
package_commit_obj.fixed_in_impacts.add(impacted_package)
69129

70-
vcs_url, commit_hash = vcs_data
71-
package_commit_obj, _ = PackageCommitPatch.objects.get_or_create(
72-
vcs_url=vcs_url, commit_hash=commit_hash
73-
)
74-
package_commit_obj.fixed_in_impacts.add(*impacted_packages)
130+
return len(vcs_data_table)

vulnerabilities/tests/pipelines/v2_improvers/test_reference_collect_commits_v2.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,25 @@ def test_collect_fix_commits_pipeline_creates_entry():
3030
unique_content_id="11111",
3131
date_collected=datetime.now(),
3232
)
33-
package = PackageV2.objects.create(
34-
type="foo",
35-
name="testpkg",
36-
version="1.0",
37-
)
33+
3834
reference = AdvisoryReference.objects.create(
3935
url="https://github.com/test/testpkg/commit/6bd301819f8f69331a55ae2336c8b111fc933f3d"
4036
)
41-
impact = ImpactedPackage.objects.create(advisory=advisory)
42-
impact.affecting_packages.add(package)
4337
advisory.references.add(reference)
4438

4539
pipeline = CollectReferencesFixCommitsPipeline()
4640
pipeline.collect_and_store_fix_commits()
4741

4842
package_commit_patch = PackageCommitPatch.objects.all()
43+
impacted_packages = advisory.impacted_packages.all()
4944

5045
assert package_commit_patch.count() == 1
46+
assert impacted_packages.count() == 1
47+
5148
fix = package_commit_patch.first()
5249
assert fix.commit_hash == "6bd301819f8f69331a55ae2336c8b111fc933f3d"
5350
assert fix.vcs_url == "https://github.com/test/testpkg"
54-
assert impact.fixed_by_package_commit_patches.count() == 1
51+
assert impacted_packages.first().fixed_by_package_commit_patches.count() == 1
5552

5653

5754
@pytest.mark.django_db
@@ -64,13 +61,6 @@ def test_collect_fix_commits_pipeline_skips_non_commit_urls():
6461
unique_content_id="11111",
6562
date_collected=datetime.now(),
6663
)
67-
package = PackageV2.objects.create(
68-
type="pypi",
69-
name="otherpkg",
70-
version="2.0",
71-
)
72-
impact = ImpactedPackage.objects.create(advisory=advisory)
73-
impact.affecting_packages.add(package)
7464

7565
reference = AdvisoryReference.objects.create(
7666
url="https://github.com/test/testpkg/issues/12"

0 commit comments

Comments
 (0)