22# SPDX-License-Identifier: Apache-2.0
33
44import datetime
5+ import fnmatch
56import os
67import re
78import subprocess
1718
# Matches the SPDX license-identifier prefix followed by the rest of that line;
# the named group captures everything up to (not including) the line break.
LICENSE_IDENTIFIER_REGEX = re.compile(
    re.escape(SPDX_LICENSE_IDENTIFIER_PREFIX) + rb"(?P<license_identifier>[^\r\n]+)"
)
1920
# Expected SPDX license identifier for files whose repo-relative path has no
# directory component.
TOP_LEVEL_FILE_LICENSE_IDENTIFIER = "Apache-2.0"

# Every top-level directory needs to have an entry here, so new paths
# can't slip in without a reviewed license decision.
TOP_LEVEL_DIRS_LICENSE_IDENTIFIERS = {
    ".github": "Apache-2.0",
    "ci": "Apache-2.0",
    "cuda_bindings": "LicenseRef-NVIDIA-SOFTWARE-LICENSE",
    "cuda_core": "Apache-2.0",
    "cuda_pathfinder": "Apache-2.0",
    "cuda_python": "LicenseRef-NVIDIA-SOFTWARE-LICENSE",
    "cuda_python_test_helpers": "Apache-2.0",
    "scripts": "Apache-2.0",
    "toolshed": "Apache-2.0",
}

SPECIAL_CASE_LICENSE_IDENTIFIERS = {
    # key: repo-relative path or glob, value: expected SPDX license identifier
    "cuda_bindings/benchmarks/*": "Apache-2.0",
    "cuda_bindings/benchmarks/pytest-legacy/*": "LicenseRef-NVIDIA-SOFTWARE-LICENSE",
}
2642
SPDX_IGNORE_FILENAME = ".spdx-ignore"
2844
@@ -63,12 +79,34 @@ def normalize_repo_path(filepath):
6379 return PureWindowsPath (filepath ).as_posix ()
6480
6581
def get_top_level_directory(normalized_path):
    """Return the first path component of a "/"-separated path.

    Args:
        normalized_path: Path string that uses "/" as its separator.

    Returns:
        The text before the first "/", or None when the path contains no "/"
        at all (i.e. it has no directory component).
    """
    first_component, separator, _ = normalized_path.partition("/")
    # partition() yields an empty separator when "/" is absent.
    if not separator:
        return None
    return first_component
86+
87+
def get_expected_license_identifier(filepath):
    """Determine which SPDX license identifier a file is expected to carry.

    Lookup order:
      1. Glob patterns in SPECIAL_CASE_LICENSE_IDENTIFIERS; when several
         match, the longest pattern string wins.
      2. TOP_LEVEL_FILE_LICENSE_IDENTIFIER for paths with no directory part.
      3. The TOP_LEVEL_DIRS_LICENSE_IDENTIFIERS entry for the first path
         component.

    Args:
        filepath: Path of the file being checked; it is normalized with
            normalize_repo_path() before matching.

    Returns:
        A ``(license_identifier, configuration_error)`` pair. Exactly one of
        the two is None: ``configuration_error`` is a printable message when
        the file's top-level directory has no configured license.
    """
    normalized_path = normalize_repo_path(filepath)

    # fnmatchcase() does not treat "/" specially, so "*" also matches files in
    # nested sub-directories; pattern length is what decides between overlaps.
    matching_patterns = [
        pattern for pattern in SPECIAL_CASE_LICENSE_IDENTIFIERS if fnmatch.fnmatchcase(normalized_path, pattern)
    ]
    if matching_patterns:
        most_specific_pattern = max(matching_patterns, key=len)
        return SPECIAL_CASE_LICENSE_IDENTIFIERS[most_specific_pattern], None

    top_level_directory = get_top_level_directory(normalized_path)
    if top_level_directory is None:
        return TOP_LEVEL_FILE_LICENSE_IDENTIFIER, None

    if top_level_directory not in TOP_LEVEL_DIRS_LICENSE_IDENTIFIERS:
        # Unknown directories are reported instead of silently accepted.
        return (
            None,
            f"MISSING TOP_LEVEL_DIRS_LICENSE_IDENTIFIERS entry for top-level directory "
            f"{top_level_directory!r} required by {filepath!r}",
        )

    return TOP_LEVEL_DIRS_LICENSE_IDENTIFIERS[top_level_directory], None
72110
73111
74112def validate_required_spdx_field (filepath , blob , expected_bytes ):
def extract_license_identifier(blob):
    """Extract the SPDX license identifier text from a file's raw bytes.

    Args:
        blob: Bytes searched with LICENSE_IDENTIFIER_REGEX.

    Returns:
        The identifier as a stripped string with one trailing "-->" and/or
        "*/" comment terminator removed, or None when the regex does not
        match or nothing remains after stripping.
    """
    match = LICENSE_IDENTIFIER_REGEX.search(blob)
    if match is None:
        return None
    # Non-ASCII bytes decode to U+FFFD rather than raising.
    license_identifier = match.group("license_identifier").decode("ascii", errors="replace").strip()
    # The identifier may be followed by a comment terminator on the same line
    # (e.g. "<!-- ... -->" or "/* ... */"); drop it and the space before it.
    for comment_suffix in ("-->", "*/"):
        if license_identifier.endswith(comment_suffix):
            license_identifier = license_identifier.removesuffix(comment_suffix).rstrip()
    return license_identifier or None
89128
90129
91130def validate_license_identifier (filepath , blob ):
@@ -94,9 +133,10 @@ def validate_license_identifier(filepath, blob):
94133 print (f"MISSING valid SPDX license identifier in { filepath !r} " )
95134 return False
96135
97- expected_license_identifier = get_expected_license_identifier (filepath )
98- if expected_license_identifier is None :
99- return True
136+ expected_license_identifier , configuration_error = get_expected_license_identifier (filepath )
137+ if configuration_error is not None :
138+ print (configuration_error )
139+ return False
100140
101141 if license_identifier != expected_license_identifier :
102142 print (
0 commit comments