Skip to content

Commit 7c9bbed

Browse files
committed
Modify Rust importer to support package-first mode #1911
* Update Rust importer to only load and process advisories relevant to the purl passed in the constructor * Update Rust importer tests to include testing the package-first mode Signed-off-by: Michael Ehab Mikhail <michael.ehab@hotmail.com>
1 parent a05b65e commit 7c9bbed

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

vulnerabilities/importers/rust.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
#
99

1010
import asyncio
11+
import logging
1112
from itertools import chain
13+
from typing import Iterable
1214
from typing import List
1315
from typing import Optional
1416
from typing import Set
@@ -27,8 +29,19 @@
2729
from vulnerabilities.package_managers import CratesVersionAPI
2830
from vulnerabilities.utils import nearest_patched_package
2931

32+
logger = logging.getLogger(__name__)
33+
3034

3135
class RustImporter(Importer):
36+
def __init__(self, purl=None, *args, **kwargs):
37+
super().__init__(*args, **kwargs)
38+
self.purl = purl
39+
if self.purl:
40+
if self.purl.type != "cargo":
41+
print(
42+
f"Warning: PURL type {self.purl.type} is not 'cargo', may not match any advisories"
43+
)
44+
3245
def __enter__(self):
3346
super(RustImporter, self).__enter__()
3447

@@ -49,12 +62,19 @@ def set_api(self, packages):
4962
asyncio.run(self.crates_api.load_api(packages))
5063

5164
def updated_advisories(self) -> Set[AdvisoryData]:
52-
return self._load_advisories(self._updated_files.union(self._added_files))
65+
if not self.purl:
66+
return self._load_advisories(self._updated_files.union(self._added_files))
67+
68+
return self._load_advisories_for_package(self.purl.name)
5369

5470
def _load_advisories(self, files) -> Set[AdvisoryData]:
5571
# per @tarcieri It will always be named RUSTSEC-0000-0000.md
5672
# https://github.com/nexB/vulnerablecode/pull/281/files#r528899864
5773
files = [f for f in files if not f.endswith("-0000.md")] # skip temporary files
74+
if self.purl:
75+
files = [f for f in files if f"crates/{self.purl.name}/" in f]
76+
if not files:
77+
return []
5878
packages = self.collect_packages(files)
5979
self.set_api(packages)
6080

@@ -64,6 +84,12 @@ def _load_advisories(self, files) -> Set[AdvisoryData]:
6484
for path in batch:
6585
advisory = self._load_advisory(path)
6686
if advisory:
87+
if (
88+
self.purl
89+
and self.purl.version
90+
and not self._advisory_affects_version(advisory)
91+
):
92+
continue
6793
advisories.append(advisory)
6894
yield advisories
6995

@@ -133,6 +159,42 @@ def _load_advisory(self, path: str) -> Optional[AdvisoryData]:
133159
references=references,
134160
)
135161

162+
def _advisory_affects_version(self, advisory: AdvisoryData) -> bool:
163+
if not self.purl.version:
164+
return True
165+
166+
version = SemverVersion(self.purl.version)
167+
for affected_package in advisory.affected_packages:
168+
if affected_package.package.name == self.purl.name:
169+
if (
170+
affected_package.affected_version_range
171+
and version in affected_package.affected_version_range
172+
):
173+
return True
174+
175+
return False
176+
177+
def _load_advisories_for_package(self, package_name) -> Iterable[AdvisoryData]:
178+
files = [
179+
f
180+
for f in self._added_files.union(self._updated_files)
181+
if f"crates/{package_name}/" in f and f.endswith(".md") and not f.endswith("-0000.md")
182+
]
183+
184+
if not files:
185+
logger.info(f"No advisories found for {package_name} in Rust advisory database")
186+
return
187+
188+
self.set_api([package_name])
189+
190+
for path in files:
191+
advisory = self._load_advisory(path)
192+
if advisory:
193+
# If version is specified in PURL, check if it's in the affected versions
194+
if self.purl.version and not self._advisory_affects_version(advisory):
195+
continue
196+
yield advisory
197+
136198

137199
def categorize_versions(
138200
all_versions: Set[str],

vulnerabilities/tests/test_rust.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import os
1111
from unittest import TestCase
1212

13+
import pytest
1314
from packageurl import PackageURL
1415
from univers.version_range import VersionRange
16+
from univers.versions import SemverVersion
1517

1618
from vulnerabilities.importer import AdvisoryData
1719
from vulnerabilities.importer import Reference
@@ -183,3 +185,46 @@ def test_load_toml_from_md(self):
183185
}
184186

185187
assert loaded_data == expected_data
188+
189+
190+
@pytest.fixture
191+
def rust_importer_with_mock(monkeypatch):
192+
class DummyVCSResponse:
193+
repo_dirs = [os.path.join(TEST_DATA, "..", "test_data", "rust")]
194+
195+
importer = RustImporter()
196+
importer._crates_api = MOCKED_CRATES_API_VERSIONS
197+
importer.vcs_response = DummyVCSResponse()
198+
return importer
199+
200+
201+
def test_rust_importer_package_first_affecting(rust_importer_with_mock):
202+
purl = PackageURL(type="cargo", name="byte_struct")
203+
importer = rust_importer_with_mock
204+
importer.purl = purl
205+
advisories = list(importer._load_advisories_for_package("byte_struct"))
206+
assert len(advisories) == 1
207+
assert any(ap.package.name == "byte_struct" for ap in advisories[0].affected_packages)
208+
209+
210+
def test_rust_importer_package_first_version_affecting(rust_importer_with_mock):
211+
purl = PackageURL(type="cargo", name="byte_struct", version="0.6.0")
212+
importer = rust_importer_with_mock
213+
importer.purl = purl
214+
advisories = list(importer._load_advisories_for_package("byte_struct"))
215+
216+
assert len(advisories) == 1
217+
found = False
218+
for ap in advisories[0].affected_packages:
219+
if ap.package.name == "byte_struct":
220+
if ap.affected_version_range and SemverVersion("0.6.0") in ap.affected_version_range:
221+
found = True
222+
assert found
223+
224+
225+
def test_rust_importer_package_first_not_found(rust_importer_with_mock):
226+
purl = PackageURL(type="cargo", name="nonexistent")
227+
importer = rust_importer_with_mock
228+
importer.purl = purl
229+
advisories = list(importer._load_advisories_for_package(purl.name))
230+
assert advisories == []

0 commit comments

Comments
 (0)