Skip to content

Commit 8b77f21

Browse files
committed
adds test
1 parent af07fcb commit 8b77f21

2 files changed

Lines changed: 313 additions & 0 deletions

File tree

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
package component_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"os"
7+
"testing"
8+
9+
integration_tests "github.com/l3montree-dev/devguard/integrationtestutil"
10+
"github.com/l3montree-dev/devguard/internal/common"
11+
"github.com/l3montree-dev/devguard/internal/core/component"
12+
"github.com/l3montree-dev/devguard/internal/core/vuln"
13+
"github.com/l3montree-dev/devguard/internal/database/models"
14+
"github.com/l3montree-dev/devguard/internal/database/repositories"
15+
"github.com/l3montree-dev/devguard/internal/utils"
16+
"github.com/l3montree-dev/devguard/mocks"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/mock"
19+
)
20+
21+
// resetOSILicenseCache resets the global license cache
22+
func resetOSILicenseCache() {
23+
vuln.ResetOSILicenseCache()
24+
}
25+
26+
func TestGetAndSaveLicenseInformation(t *testing.T) {
27+
// Set up a mock OSI licenses API server that returns known valid licenses
28+
// This avoids external API dependencies in tests
29+
mockLicenses := `[
30+
{"spdx_id": "MIT"},
31+
{"spdx_id": "Apache-2.0"},
32+
{"spdx_id": "GPL-3.0-only"},
33+
{"spdx_id": "BSD-3-Clause"}
34+
]`
35+
36+
// Create a simple HTTP server for testing
37+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38+
w.Header().Set("Content-Type", "application/json")
39+
w.WriteHeader(http.StatusOK)
40+
_, _ = w.Write([]byte(mockLicenses))
41+
}))
42+
defer server.Close()
43+
44+
// Set the OSI API URL to our test server
45+
os.Setenv("OSI_LICENSES_API", server.URL)
46+
defer os.Unsetenv("OSI_LICENSES_API")
47+
48+
t.Run("should create license risk entries for components with invalid licenses", func(t *testing.T) {
49+
// Clear the license cache to ensure we use our mock server
50+
// This is a bit of a hack but necessary since the license cache is global
51+
// Reset the global license cache
52+
resetOSILicenseCache()
53+
54+
// Initialize database container
55+
db, terminate := integration_tests.InitDatabaseContainer("../../../initdb.sql")
56+
defer terminate()
57+
58+
// Auto-migrate required models
59+
err := db.AutoMigrate(
60+
&models.Org{},
61+
&models.Project{},
62+
&models.Asset{},
63+
&models.AssetVersion{},
64+
&models.Component{},
65+
&models.ComponentDependency{},
66+
&models.ComponentProject{},
67+
&models.LicenseRisk{},
68+
&models.VulnEvent{},
69+
)
70+
assert.NoError(t, err)
71+
72+
// Create test data using the utility function
73+
_, _, _, assetVersion := integration_tests.CreateOrgProjectAndAssetAssetVersion(db)
74+
75+
// Create test components with different license scenarios
76+
componentWithInvalidLicense := models.Component{
77+
Purl: "pkg:npm/test-package@1.0.0",
78+
Version: "1.0.0",
79+
License: utils.Ptr("PROPRIETARY"), // Invalid OSI license
80+
}
81+
82+
componentWithValidLicense := models.Component{
83+
Purl: "pkg:npm/valid-package@1.0.0",
84+
Version: "1.0.0",
85+
License: utils.Ptr("MIT"), // Valid OSI license
86+
}
87+
88+
componentWithoutLicense := models.Component{
89+
Purl: "pkg:npm/no-license-package@1.0.0",
90+
Version: "1.0.0",
91+
License: nil, // No license - will be handled by GetLicense
92+
}
93+
94+
// Save components to database
95+
err = db.Create(&componentWithInvalidLicense).Error
96+
assert.NoError(t, err)
97+
err = db.Create(&componentWithValidLicense).Error
98+
assert.NoError(t, err)
99+
err = db.Create(&componentWithoutLicense).Error
100+
assert.NoError(t, err)
101+
102+
// Create component dependencies
103+
scannerID := "test-scanner"
104+
componentDeps := []models.ComponentDependency{
105+
{
106+
AssetVersionName: assetVersion.Name,
107+
AssetID: assetVersion.AssetID,
108+
DependencyPurl: componentWithInvalidLicense.Purl,
109+
Dependency: componentWithInvalidLicense,
110+
ScannerIDs: scannerID,
111+
},
112+
{
113+
AssetVersionName: assetVersion.Name,
114+
AssetID: assetVersion.AssetID,
115+
DependencyPurl: componentWithValidLicense.Purl,
116+
Dependency: componentWithValidLicense,
117+
ScannerIDs: scannerID,
118+
},
119+
{
120+
AssetVersionName: assetVersion.Name,
121+
AssetID: assetVersion.AssetID,
122+
DependencyPurl: componentWithoutLicense.Purl,
123+
Dependency: componentWithoutLicense,
124+
ScannerIDs: scannerID,
125+
},
126+
}
127+
128+
for _, dep := range componentDeps {
129+
err = db.Create(&dep).Error
130+
assert.NoError(t, err)
131+
}
132+
133+
// Set up repositories
134+
componentRepository := repositories.NewComponentRepository(db)
135+
componentProjectRepository := repositories.NewComponentProjectRepository(db)
136+
licenseRiskRepository := repositories.NewLicenseRiskRepository(db)
137+
vulnEventRepository := repositories.NewVulnEventRepository(db)
138+
139+
// Set up services
140+
licenseRiskService := vuln.NewLicenseRiskService(licenseRiskRepository, vulnEventRepository)
141+
142+
// Mock the DepsDevService for the component without license
143+
mockDepsDevService := mocks.NewDepsDevService(t)
144+
145+
// Mock response for the component without license - simulate getting "unknown" license
146+
mockDepsDevService.On("GetVersion", mock.Anything, "npm", "no-license-package", "1.0.0").
147+
Return(common.DepsDevVersionResponse{
148+
Licenses: []string{}, // No licenses returned
149+
}, nil)
150+
151+
// Create the component service with mocked dependencies
152+
componentService := component.NewComponentService(
153+
mockDepsDevService,
154+
componentProjectRepository,
155+
componentRepository,
156+
licenseRiskService,
157+
)
158+
159+
// Call the function under test
160+
resultComponents, err := componentService.GetAndSaveLicenseInformation(assetVersion, scannerID)
161+
assert.NoError(t, err)
162+
assert.NotEmpty(t, resultComponents)
163+
164+
// Verify that license risks were created for components with invalid licenses
165+
var licenseRisks []models.LicenseRisk
166+
err = db.Where("asset_id = ? AND asset_version_name = ?", assetVersion.AssetID, assetVersion.Name).Find(&licenseRisks).Error
167+
assert.NoError(t, err)
168+
169+
// We should have license risks for:
170+
// 1. componentWithInvalidLicense (PROPRIETARY license)
171+
// 2. componentWithoutLicense (will get "unknown" license which is invalid)
172+
expectedRiskCount := 2
173+
assert.Equal(t, expectedRiskCount, len(licenseRisks))
174+
175+
// Check specific license risk entries
176+
licenseRiskPurls := make(map[string]models.LicenseRisk)
177+
for _, risk := range licenseRisks {
178+
licenseRiskPurls[risk.ComponentPurl] = risk
179+
}
180+
181+
// Verify license risk for component with invalid license
182+
invalidLicenseRisk, exists := licenseRiskPurls[componentWithInvalidLicense.Purl]
183+
assert.True(t, exists, "License risk should exist for component with invalid license")
184+
assert.Equal(t, models.VulnStateOpen, invalidLicenseRisk.State)
185+
assert.Equal(t, scannerID, invalidLicenseRisk.ScannerIDs)
186+
assert.Equal(t, assetVersion.AssetID, invalidLicenseRisk.AssetID)
187+
assert.Equal(t, assetVersion.Name, invalidLicenseRisk.AssetVersionName)
188+
189+
// Verify license risk for component without license (should get "unknown")
190+
unknownLicenseRisk, exists := licenseRiskPurls[componentWithoutLicense.Purl]
191+
assert.True(t, exists, "License risk should exist for component with unknown license")
192+
assert.Equal(t, models.VulnStateOpen, unknownLicenseRisk.State)
193+
194+
// Verify NO license risk was created for component with valid license
195+
_, exists = licenseRiskPurls[componentWithValidLicense.Purl]
196+
assert.False(t, exists, "No license risk should exist for component with valid license")
197+
198+
// Verify that corresponding vuln events were created
199+
var vulnEvents []models.VulnEvent
200+
err = db.Where("vuln_type = ?", models.VulnTypeLicenseRisk).Find(&vulnEvents).Error
201+
assert.NoError(t, err)
202+
assert.Equal(t, expectedRiskCount, len(vulnEvents))
203+
204+
// Verify vuln events are of correct type
205+
for _, event := range vulnEvents {
206+
assert.Equal(t, models.VulnTypeLicenseRisk, event.VulnType)
207+
assert.Equal(t, models.EventTypeDetected, event.Type)
208+
assert.Equal(t, "system", event.UserID)
209+
}
210+
211+
t.Logf("Successfully created %d license risks and %d vuln events", len(licenseRisks), len(vulnEvents))
212+
})
213+
214+
t.Run("should not create duplicate license risks for existing entries", func(t *testing.T) {
215+
// Clear the license cache to ensure consistent test behavior
216+
resetOSILicenseCache()
217+
218+
// Initialize database container
219+
db, terminate := integration_tests.InitDatabaseContainer("../../../initdb.sql")
220+
defer terminate()
221+
222+
// Auto-migrate required models
223+
err := db.AutoMigrate(
224+
&models.Org{},
225+
&models.Project{},
226+
&models.Asset{},
227+
&models.AssetVersion{},
228+
&models.Component{},
229+
&models.ComponentDependency{},
230+
&models.ComponentProject{},
231+
&models.LicenseRisk{},
232+
&models.VulnEvent{},
233+
)
234+
assert.NoError(t, err)
235+
236+
// Create test data
237+
_, _, _, assetVersion := integration_tests.CreateOrgProjectAndAssetAssetVersion(db)
238+
239+
// Create component with invalid license
240+
componentWithInvalidLicense := models.Component{
241+
Purl: "pkg:npm/test-package@1.0.0",
242+
Version: "1.0.0",
243+
License: utils.Ptr("PROPRIETARY"),
244+
}
245+
err = db.Create(&componentWithInvalidLicense).Error
246+
assert.NoError(t, err)
247+
248+
scannerID := "test-scanner"
249+
250+
// Create component dependency
251+
componentDep := models.ComponentDependency{
252+
AssetVersionName: assetVersion.Name,
253+
AssetID: assetVersion.AssetID,
254+
DependencyPurl: componentWithInvalidLicense.Purl,
255+
Dependency: componentWithInvalidLicense,
256+
ScannerIDs: scannerID,
257+
}
258+
err = db.Create(&componentDep).Error
259+
assert.NoError(t, err)
260+
261+
// Create existing license risk
262+
existingLicenseRisk := models.LicenseRisk{
263+
Vulnerability: models.Vulnerability{
264+
AssetVersionName: assetVersion.Name,
265+
AssetID: assetVersion.AssetID,
266+
State: models.VulnStateOpen,
267+
ScannerIDs: scannerID,
268+
},
269+
ComponentPurl: componentWithInvalidLicense.Purl,
270+
FinalLicenseDecision: "",
271+
}
272+
// Manually set the ID using the same calculation as the model
273+
existingLicenseRisk.ID = existingLicenseRisk.CalculateHash()
274+
err = db.Create(&existingLicenseRisk).Error
275+
assert.NoError(t, err)
276+
277+
// Set up repositories and services
278+
componentRepository := repositories.NewComponentRepository(db)
279+
componentProjectRepository := repositories.NewComponentProjectRepository(db)
280+
licenseRiskRepository := repositories.NewLicenseRiskRepository(db)
281+
vulnEventRepository := repositories.NewVulnEventRepository(db)
282+
licenseRiskService := vuln.NewLicenseRiskService(licenseRiskRepository, vulnEventRepository)
283+
284+
mockDepsDevService := mocks.NewDepsDevService(t)
285+
286+
componentService := component.NewComponentService(
287+
mockDepsDevService,
288+
componentProjectRepository,
289+
componentRepository,
290+
licenseRiskService,
291+
)
292+
293+
// Call the function under test
294+
_, err = componentService.GetAndSaveLicenseInformation(assetVersion, scannerID)
295+
assert.NoError(t, err)
296+
297+
// Verify that no duplicate license risk was created
298+
var licenseRisks []models.LicenseRisk
299+
err = db.Where("asset_id = ? AND asset_version_name = ?", assetVersion.AssetID, assetVersion.Name).Find(&licenseRisks).Error
300+
assert.NoError(t, err)
301+
302+
// Should still have only 1 license risk (the existing one)
303+
assert.Equal(t, 1, len(licenseRisks))
304+
assert.Equal(t, existingLicenseRisk.ID, licenseRisks[0].ID)
305+
306+
t.Log("Successfully avoided creating duplicate license risk entries")
307+
})
308+
}

internal/core/vuln/license_risk_service.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ func (service *LicenseRiskService) FindLicenseRisksInComponents(assetVersion mod
8484

8585
var validOSILicenseMap map[string]struct{} = make(map[string]struct{}) // cache for valid OSI licenses
8686

87+
// ResetOSILicenseCache clears the cached OSI licenses for testing purposes
88+
func ResetOSILicenseCache() {
89+
validOSILicenseMap = make(map[string]struct{})
90+
}
91+
8792
func GetOSILicenses() (map[string]struct{}, error) {
8893
if len(validOSILicenseMap) > 0 {
8994
return validOSILicenseMap, nil

0 commit comments

Comments
 (0)