Skip to content

Commit 11cfbe6

Browse files
committed
feat: support new smart db
1 parent 176839d commit 11cfbe6

File tree

4 files changed

+332
-4
lines changed

4 files changed

+332
-4
lines changed

.golangci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ linters:
2323
- varnamelen # maybe later
2424
- wsl # disagree with, for now
2525
- wsl_v5 # disagree with, for now
26+
- wrapcheck
2627
settings:
2728
depguard:
2829
rules:

main.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,15 @@ func makeAPIDBConfig() database.Config {
3333
}
3434
}
3535

36-
func makeEcosystemDBConfig(ecosystem internal.Ecosystem) database.Config {
36+
func makeEcosystemDBConfig(ecosystem internal.Ecosystem, beSmart bool) database.Config {
37+
typ := "zip"
38+
if beSmart {
39+
typ = "smart"
40+
}
41+
3742
return database.Config{
3843
Name: string(ecosystem),
39-
Type: "zip",
44+
Type: typ,
4045
URL: fmt.Sprintf("https://osv-vulnerabilities.storage.googleapis.com/%s/all.zip", ecosystem),
4146
}
4247
}
@@ -158,6 +163,15 @@ func describeDB(db database.DB) string {
158163
switch tt := db.(type) {
159164
case *database.APIDB:
160165
return "using batches of " + color.YellowString("%d", tt.BatchSize)
166+
case *database.SmartDB:
167+
count := tt.VulnerabilitiesCount
168+
169+
return fmt.Sprintf(
170+
"%s %s, including withdrawn - last updated %s",
171+
color.YellowString("%d", count),
172+
reporter.Form(count, "vulnerability", "vulnerabilities"),
173+
tt.UpdatedAt,
174+
)
161175
case *database.ZipDB:
162176
count := tt.VulnerabilitiesCount
163177

@@ -372,6 +386,7 @@ func (files lockfileAndConfigOrErrs) adjustExtraDatabases(
372386
removeConfigDatabases bool,
373387
addDefaultAPIDatabase bool,
374388
addEcosystemDatabases bool,
389+
beSmart bool,
375390
) {
376391
for _, file := range files {
377392
if file.err != nil {
@@ -391,7 +406,7 @@ func (files lockfileAndConfigOrErrs) adjustExtraDatabases(
391406
ecosystems := collectEcosystems([]lockfileAndConfigOrErr{file})
392407

393408
for _, ecosystem := range ecosystems {
394-
extraDBConfigs = append(extraDBConfigs, makeEcosystemDBConfig(ecosystem))
409+
extraDBConfigs = append(extraDBConfigs, makeEcosystemDBConfig(ecosystem, beSmart))
395410
}
396411
}
397412

@@ -508,6 +523,7 @@ func run(args []string, stdout, stderr io.Writer) int {
508523
useDatabases := cli.Bool("use-dbs", true, "Use the databases from osv.dev to check for known vulnerabilities")
509524
useAPI := cli.Bool("use-api", false, "Use the osv.dev API to check for known vulnerabilities")
510525
batchSize := cli.Int("batch-size", 1000, "The number of packages to include in each batch when using the api database")
526+
beSmart := cli.Bool("be-smart", false, "")
511527

512528
cli.Var(&globalIgnores, "ignore", `ID of an OSV to ignore when determining exit codes.
513529
This flag can be passed multiple times to ignore different vulnerabilities`)
@@ -589,7 +605,7 @@ This flag can be passed multiple times to ignore different vulnerabilities`)
589605

590606
files := readAllLockfiles(r, pathsToLocksWithParseAs, cli.Args(), loadLocalConfig, &config)
591607

592-
files.adjustExtraDatabases(*noConfigDatabases, *useAPI, *useDatabases)
608+
files.adjustExtraDatabases(*noConfigDatabases, *useAPI, *useDatabases, *beSmart)
593609

594610
dbs, errored := loadDatabases(
595611
r,

pkg/database/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ var ErrUnsupportedDatabaseType = errors.New("unsupported database source type")
2929
// Load initializes a new OSV database based on the given Config
3030
func Load(config Config, offline bool, batchSize int) (DB, error) {
3131
switch config.Type {
32+
case "smart":
33+
return NewSmartDB(config, offline)
3234
case "zip":
3335
return NewZippedDB(config, offline)
3436
case "api":

pkg/database/smart.go

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
package database
2+
3+
import (
4+
"archive/zip"
5+
"bytes"
6+
"context"
7+
"crypto/sha256"
8+
"encoding/csv"
9+
"encoding/json"
10+
"errors"
11+
"fmt"
12+
"io"
13+
"net/http"
14+
"os"
15+
"path/filepath"
16+
"strings"
17+
"time"
18+
)
19+
20+
type SmartDB struct {
21+
// memDB
22+
DirDB
23+
24+
name string
25+
identifier string
26+
ArchiveURL string
27+
WorkingDirectory string
28+
Offline bool
29+
UpdatedAt string
30+
}
31+
32+
func (db *SmartDB) Name() string { return db.name }
33+
func (db *SmartDB) Identifier() string { return db.identifier }
34+
35+
func (db *SmartDB) cachePath() string {
36+
hash := sha256.Sum256([]byte(db.ArchiveURL))
37+
fileName := fmt.Sprintf("osv-detector-%x-db", hash)
38+
39+
return filepath.Join(os.TempDir(), fileName)
40+
}
41+
42+
func (db *SmartDB) cacheFile(name string, content []byte) error {
43+
//nolint:gosec // being world readable is fine
44+
return os.WriteFile(filepath.Join(db.cachePath(), name), content, 0644)
45+
}
46+
47+
func (db *SmartDB) loadLastModified() (time.Time, error) {
48+
b, err := os.ReadFile(filepath.Join(db.cachePath(), "last_modified"))
49+
50+
if err != nil {
51+
return time.Time{}, err
52+
}
53+
54+
tim, err := time.Parse(time.RFC3339, string(b))
55+
56+
if err != nil {
57+
return time.Time{}, err
58+
}
59+
60+
return tim, nil
61+
}
62+
63+
func (db *SmartDB) updateLastModified(lastModified time.Time) error {
64+
db.UpdatedAt = lastModified.Format(http.TimeFormat)
65+
66+
return db.cacheFile("last_modified", []byte(lastModified.Format(time.RFC3339)))
67+
}
68+
69+
func (db *SmartDB) writeZipFile(zipFile *zip.File) error {
70+
dst, err := os.OpenFile(filepath.Join(db.cachePath(), zipFile.Name), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode())
71+
if err != nil {
72+
return err
73+
}
74+
75+
defer dst.Close()
76+
77+
z, err := zipFile.Open()
78+
if err != nil {
79+
return err
80+
}
81+
82+
defer z.Close()
83+
84+
_, err = io.Copy(dst, z)
85+
86+
return err
87+
}
88+
89+
func (db *SmartDB) populateFromZip() error {
90+
err := os.MkdirAll(db.cachePath(), 0744)
91+
92+
if err != nil {
93+
return err
94+
}
95+
96+
zdb := &ZipDB{
97+
name: db.name,
98+
ArchiveURL: db.ArchiveURL,
99+
WorkingDirectory: db.WorkingDirectory,
100+
Offline: db.Offline,
101+
}
102+
103+
body, err := zdb.fetchZip()
104+
105+
if err != nil {
106+
return err
107+
}
108+
109+
zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
110+
if err != nil {
111+
return fmt.Errorf("could not read OSV database archive: %w", err)
112+
}
113+
114+
// Read each file from the archive and write it to the db directory
115+
for _, zipFile := range zipReader.File {
116+
if !strings.HasPrefix(zipFile.Name, db.WorkingDirectory) {
117+
continue
118+
}
119+
120+
if !strings.HasSuffix(zipFile.Name, ".json") {
121+
continue
122+
}
123+
124+
err = db.writeZipFile(zipFile)
125+
126+
if err != nil {
127+
return err
128+
}
129+
}
130+
131+
return nil
132+
}
133+
134+
type modifiedIDRow struct {
135+
id string
136+
modified time.Time
137+
}
138+
139+
func parseModifiedIDRow(columns []string) (*modifiedIDRow, error) {
140+
modified, err := time.Parse(time.RFC3339, columns[0])
141+
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
return &modifiedIDRow{id: columns[1], modified: modified}, nil
147+
}
148+
149+
func (db *SmartDB) fetchModifiedIDs(since time.Time) ([]modifiedIDRow, error) {
150+
url := strings.TrimSuffix(db.ArchiveURL, "/all.zip") + "/modified_id.csv"
151+
152+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
153+
if err != nil {
154+
return nil, err
155+
}
156+
157+
resp, err := http.DefaultClient.Do(req)
158+
if err != nil {
159+
return nil, err
160+
}
161+
162+
defer resp.Body.Close()
163+
164+
if resp.StatusCode != http.StatusOK {
165+
return nil, fmt.Errorf("%w (%s)", ErrUnexpectedStatusCode, resp.Status)
166+
}
167+
168+
i := 0
169+
r := csv.NewReader(resp.Body)
170+
171+
var rows []modifiedIDRow
172+
173+
for {
174+
i++
175+
record, err := r.Read()
176+
if errors.Is(err, io.EOF) {
177+
break
178+
}
179+
if err != nil {
180+
return nil, fmt.Errorf("%w", err)
181+
}
182+
183+
row, err := parseModifiedIDRow(record)
184+
if err != nil {
185+
return nil, fmt.Errorf("row %d: %w", i, err)
186+
}
187+
188+
// the modified ids are sorted in reverse chronological order so once we hit
189+
// a row that was modified before our "since" time, we can stop completely
190+
if row.modified.Before(since) {
191+
break
192+
}
193+
194+
rows = append(rows, *row)
195+
}
196+
197+
return rows, nil
198+
}
199+
200+
func (db *SmartDB) updateAdvisory(id string) error {
201+
url := fmt.Sprintf("%s/%s.json", strings.TrimSuffix(db.ArchiveURL, "/all.zip"), id)
202+
203+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
204+
if err != nil {
205+
return err
206+
}
207+
208+
resp, err := http.DefaultClient.Do(req)
209+
if err != nil {
210+
return err
211+
}
212+
213+
defer resp.Body.Close()
214+
215+
if resp.StatusCode != http.StatusOK {
216+
return fmt.Errorf("%w (%s)", ErrUnexpectedStatusCode, resp.Status)
217+
}
218+
219+
var body []byte
220+
221+
body, err = io.ReadAll(resp.Body)
222+
223+
if err != nil {
224+
return err
225+
}
226+
227+
content, err := json.Marshal(body)
228+
229+
if err == nil {
230+
err = db.cacheFile(id+".json", content)
231+
}
232+
233+
return err
234+
}
235+
236+
func (db *SmartDB) updateModifiedAdvisories(since time.Time) error {
237+
modifiedIDs, err := db.fetchModifiedIDs(since)
238+
239+
if err != nil {
240+
return err
241+
}
242+
243+
for _, row := range modifiedIDs {
244+
// fmt.Printf("updating %s\n", row.id)
245+
err = db.updateAdvisory(row.id)
246+
247+
if err != nil {
248+
return err
249+
}
250+
}
251+
252+
return nil
253+
}
254+
255+
// load fetches a zip archive of the OSV database and loads known vulnerabilities
256+
// from it (which are assumed to be in json files following the OSV spec).
257+
//
258+
// Internally, the archive is cached along with the date that it was fetched
259+
// so that a new version of the archive is only downloaded if it has been
260+
// modified, per HTTP caching standards.
261+
func (db *SmartDB) load() error {
262+
lastModified, err := db.loadLastModified()
263+
264+
// (re)initialize the database using the zip file
265+
if err != nil {
266+
err = db.populateFromZip()
267+
268+
if err != nil {
269+
return err
270+
}
271+
272+
// update the last modified time to now
273+
lastModified = time.Now()
274+
}
275+
276+
// update any advisories that have changed since the last modified timestamp
277+
err = db.updateModifiedAdvisories(lastModified)
278+
if err != nil {
279+
return err
280+
}
281+
282+
if err = db.updateLastModified(lastModified); err != nil {
283+
return err
284+
}
285+
286+
db.DirDB = DirDB{
287+
name: db.name,
288+
LocalPath: "file:///" + db.cachePath(),
289+
WorkingDirectory: "",
290+
Offline: db.Offline,
291+
}
292+
293+
return db.DirDB.load()
294+
}
295+
296+
func NewSmartDB(config Config, offline bool) (*SmartDB, error) {
297+
db := &SmartDB{
298+
name: config.Name,
299+
identifier: config.Identifier(),
300+
ArchiveURL: config.URL,
301+
WorkingDirectory: config.WorkingDirectory,
302+
Offline: offline,
303+
}
304+
if err := db.load(); err != nil {
305+
return nil, fmt.Errorf("unable to fetch OSV database: %w", err)
306+
}
307+
308+
return db, nil
309+
}

0 commit comments

Comments
 (0)