Skip to content

Commit d579360

Browse files
Update validate_compliance.py
1 parent 1dad267 commit d579360

1 file changed

Lines changed: 69 additions & 69 deletions

File tree

scripts/validate_compliance.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,78 @@
1-
import argparse
2-
import json
3-
import os
4-
import importlib.util
5-
import torch
6-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
1+
import argparse
2+
import json
3+
import os
4+
import importlib.util
5+
import torch
6+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
77

8-
def load_model_and_tokenizer(model_dir, device):
9-
"""Loads a fine-tuned model and tokenizer from a directory."""
10-
if not os.path.isdir(model_dir):
11-
raise FileNotFoundError(f"Model directory not found at {model_dir}")
12-
13-
tokenizer = AutoTokenizer.from_pretrained(model_dir)
14-
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
15-
model.to(device)
16-
model.eval()
17-
return model, tokenizer
18-
19-
def load_compliance_rules(rules_script_path):
20-
"""Dynamically loads compliance rules from a specified Python script."""
21-
spec = importlib.util.spec_from_file_location("compliance_rules", rules_script_path)
22-
compliance_module = importlib.util.module_from_spec(spec)
23-
spec.loader.exec_module(compliance_module)
8+
def load_model_and_tokenizer(model_dir, device):
9+
"""Loads a fine-tuned model and tokenizer from a directory."""
10+
if not os.path.isdir(model_dir):
11+
raise FileNotFoundError(f"Model directory not found at {model_dir}")
2412

25-
if not hasattr(compliance_module, 'run_compliance_tests'):
26-
raise AttributeError(f"Rules script {rules_script_path} must define a 'run_compliance_tests' function.")
27-
28-
return compliance_module.run_compliance_tests
13+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
14+
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
15+
model.to(device)
16+
model.eval()
17+
return model, tokenizer
2918

30-
def save_report(report_data, report_path):
31-
"""Saves the compliance report to a JSON file."""
32-
os.makedirs(os.path.dirname(report_path), exist_ok=True)
33-
with open(report_path, 'w') as f:
34-
json.dump(report_data, f, indent=4)
35-
print(f"Report saved to {report_path}")
19+
def load_compliance_rules(rules_script_path):
20+
"""Dynamically loads compliance rules from a specified Python script."""
21+
spec = importlib.util.spec_from_file_location("compliance_rules", rules_script_path)
22+
compliance_module = importlib.util.module_from_spec(spec)
23+
spec.loader.exec_module(compliance_module)
24+
25+
if not hasattr(compliance_module, 'run_compliance_tests'):
26+
raise AttributeError(f"Rules script {rules_script_path} must define a 'run_compliance_tests' function.")
27+
28+
return compliance_module.run_compliance_tests
3629

37-
def main():
38-
parser = argparse.ArgumentParser(description="Neuron Compliance Validation Script")
39-
parser.add_argument("--model_dir", required=True, help="Directory of the fine-tuned model")
40-
parser.add_argument("--rules_script", required=True, help="Python script containing compliance rules")
41-
parser.add_argument("--output_report_path", required=True, help="Path to save the JSON compliance report")
42-
args = parser.parse_args()
30+
def save_report(report_data, report_path):
31+
"""Saves the compliance report to a JSON file."""
32+
os.makedirs(os.path.dirname(report_path), exist_ok=True)
33+
with open(report_path, 'w') as f:
34+
json.dump(report_data, f, indent=4)
35+
print(f"Report saved to {report_path}")
4336

44-
print("\n--- Starting Compliance Validation ---")
45-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46-
print(f"Using device: {device}")
37+
def main():
38+
parser = argparse.ArgumentParser(description="Neuron Compliance Validation Script")
39+
parser.add_argument("--model_dir", required=True, help="Directory of the fine-tuned model")
40+
parser.add_argument("--rules_script", required=True, help="Python script containing compliance rules")
41+
parser.add_argument("--output_report_path", required=True, help="Path to save the JSON compliance report")
42+
args = parser.parse_args()
4743

48-
print("Step 1: Loading model and tokenizer...")
49-
model, tokenizer = load_model_and_tokenizer(args.model_dir, device)
50-
51-
print("Step 2: Loading compliance rules...")
52-
run_tests_function = load_compliance_rules(args.rules_script)
44+
print("\n--- Starting Compliance Validation ---")
45+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46+
print(f"Using device: {device}")
5347

54-
print("Step 3: Executing compliance tests...")
55-
violations = run_tests_function(model, tokenizer, device)
56-
57-
print(f"Found {len(violations)} compliance violations.")
58-
final_status = "PASS" if not violations else "FAIL"
59-
60-
report = {
61-
"model_validated": args.model_dir,
62-
"rules_script_used": args.rules_script,
63-
"validation_status": final_status,
64-
"violations_found": violations
65-
}
66-
67-
print("Step 4: Saving compliance report...")
68-
save_report(report, args.output_report_path)
69-
70-
if final_status == "FAIL":
71-
print("\nCompliance check FAILED. Exiting with error code 1.")
72-
exit(1)
73-
else:
74-
print("\nCompliance check PASSED.")
48+
print("Step 1: Loading model and tokenizer...")
49+
model, tokenizer = load_model_and_tokenizer(args.model_dir, device)
50+
51+
print("Step 2: Loading compliance rules...")
52+
run_tests_function = load_compliance_rules(args.rules_script)
7553

76-
if __name__ == "__main__":
77-
main()
54+
print("Step 3: Executing compliance tests...")
55+
violations = run_tests_function(model, tokenizer, device)
56+
57+
print(f"Found {len(violations)} compliance violations.")
58+
final_status = "PASS" if not violations else "FAIL"
59+
60+
report = {
61+
"model_validated": args.model_dir,
62+
"rules_script_used": args.rules_script,
63+
"validation_status": final_status,
64+
"violations_found": violations
65+
}
7866

67+
print("Step 4: Saving compliance report...")
68+
save_report(report, args.output_report_path)
69+
70+
if final_status == "FAIL":
71+
print("\nCompliance check FAILED. Exiting with error code 1.")
72+
exit(1)
73+
else:
74+
print("\nCompliance check PASSED.")
75+
76+
if __name__ == "__main__":
77+
main()
78+

0 commit comments

Comments
 (0)