|
1 | | - import argparse |
2 | | - import json |
3 | | - import os |
4 | | - import importlib.util |
5 | | - import torch |
6 | | - from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| 1 | +import argparse |
| 2 | +import json |
| 3 | +import os |
| 4 | +import importlib.util |
| 5 | +import torch |
| 6 | +from transformers import AutoTokenizer, AutoModelForSequenceClassification |
7 | 7 |
|
8 | | - def load_model_and_tokenizer(model_dir, device): |
9 | | - """Loads a fine-tuned model and tokenizer from a directory.""" |
10 | | - if not os.path.isdir(model_dir): |
11 | | - raise FileNotFoundError(f"Model directory not found at {model_dir}") |
12 | | - |
13 | | - tokenizer = AutoTokenizer.from_pretrained(model_dir) |
14 | | - model = AutoModelForSequenceClassification.from_pretrained(model_dir) |
15 | | - model.to(device) |
16 | | - model.eval() |
17 | | - return model, tokenizer |
18 | | - |
19 | | - def load_compliance_rules(rules_script_path): |
20 | | - """Dynamically loads compliance rules from a specified Python script.""" |
21 | | - spec = importlib.util.spec_from_file_location("compliance_rules", rules_script_path) |
22 | | - compliance_module = importlib.util.module_from_spec(spec) |
23 | | - spec.loader.exec_module(compliance_module) |
def load_model_and_tokenizer(model_dir, device):
    """Load a fine-tuned sequence-classification model and its tokenizer.

    Args:
        model_dir: Path of a local directory that both the tokenizer and the
            model are loaded from via ``from_pretrained``.
        device: Target passed to ``model.to`` (``main`` passes a
            ``torch.device``).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to
        ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    if os.path.isdir(model_dir):
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        classifier = AutoModelForSequenceClassification.from_pretrained(model_dir)
        classifier.to(device)
        # Evaluation mode: this script only runs inference-style checks.
        classifier.eval()
        return classifier, tokenizer

    raise FileNotFoundError(f"Model directory not found at {model_dir}")
29 | 18 |
|
30 | | - def save_report(report_data, report_path): |
31 | | - """Saves the compliance report to a JSON file.""" |
32 | | - os.makedirs(os.path.dirname(report_path), exist_ok=True) |
33 | | - with open(report_path, 'w') as f: |
34 | | - json.dump(report_data, f, indent=4) |
35 | | - print(f"Report saved to {report_path}") |
def load_compliance_rules(rules_script_path):
    """Dynamically load the compliance test entry point from a Python script.

    The script is imported under the module name ``compliance_rules`` and
    must expose a callable named ``run_compliance_tests``.

    Args:
        rules_script_path: Filesystem path of the Python rules script.

    Returns:
        The ``run_compliance_tests`` callable defined by the script.

    Raises:
        ImportError: If no import spec/loader can be built for the path
            (for example, the file does not have a Python source suffix).
        AttributeError: If the script does not define a callable
            ``run_compliance_tests``.
        Exception: Anything raised while executing the script itself
            (including ``FileNotFoundError`` for a missing file) propagates.
    """
    # NOTE(security): exec_module runs the script as arbitrary Python code;
    # only point this at rules scripts from a trusted source.
    spec = importlib.util.spec_from_file_location("compliance_rules", rules_script_path)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the path; fail with a clear message instead of an opaque NoneType error.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load rules script {rules_script_path} as a Python module."
        )

    compliance_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(compliance_module)

    run_tests = getattr(compliance_module, "run_compliance_tests", None)
    # A non-callable attribute would only fail later at call time; reject it here.
    if not callable(run_tests):
        raise AttributeError(f"Rules script {rules_script_path} must define a 'run_compliance_tests' function.")

    return run_tests
36 | 29 |
|
37 | | - def main(): |
38 | | - parser = argparse.ArgumentParser(description="Neuron Compliance Validation Script") |
39 | | - parser.add_argument("--model_dir", required=True, help="Directory of the fine-tuned model") |
40 | | - parser.add_argument("--rules_script", required=True, help="Python script containing compliance rules") |
41 | | - parser.add_argument("--output_report_path", required=True, help="Path to save the JSON compliance report") |
42 | | - args = parser.parse_args() |
def save_report(report_data, report_path):
    """Write the compliance report to a JSON file and announce its location.

    Args:
        report_data: JSON-serializable object to dump (indented by 4 spaces).
        report_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.

    Raises:
        OSError: If the directory or file cannot be created or written.
        TypeError: If ``report_data`` is not JSON-serializable.
    """
    parent_dir = os.path.dirname(report_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform locale.
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report_data, f, indent=4)
    print(f"Report saved to {report_path}")
43 | 36 |
|
44 | | - print("\n--- Starting Compliance Validation ---") |
45 | | - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
46 | | - print(f"Using device: {device}") |
def main():
    """Run the compliance validation pipeline from the command line.

    Steps: parse arguments, load the model and tokenizer, load the rules
    script, run its tests, write a JSON report, and exit with status 1 when
    any violation is reported.

    Command-line arguments (all required):
        --model_dir: Directory of the fine-tuned model.
        --rules_script: Python script containing compliance rules.
        --output_report_path: Path to save the JSON compliance report.
    """
    # Function-scope import: sys.exit is always available, whereas the bare
    # exit() builtin is injected by the site module and is missing under
    # `python -S` or in frozen/embedded interpreters.
    import sys

    parser = argparse.ArgumentParser(description="Neuron Compliance Validation Script")
    parser.add_argument("--model_dir", required=True, help="Directory of the fine-tuned model")
    parser.add_argument("--rules_script", required=True, help="Python script containing compliance rules")
    parser.add_argument("--output_report_path", required=True, help="Path to save the JSON compliance report")
    args = parser.parse_args()

    print("\n--- Starting Compliance Validation ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Step 1: Loading model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(args.model_dir, device)

    print("Step 2: Loading compliance rules...")
    run_tests_function = load_compliance_rules(args.rules_script)

    print("Step 3: Executing compliance tests...")
    # Assumes the rules script returns a sized, JSON-serializable collection
    # of violations (len() and json.dump are applied to it) - confirm
    # against the rules scripts in use.
    violations = run_tests_function(model, tokenizer, device)

    print(f"Found {len(violations)} compliance violations.")
    final_status = "PASS" if not violations else "FAIL"

    report = {
        "model_validated": args.model_dir,
        "rules_script_used": args.rules_script,
        "validation_status": final_status,
        "violations_found": violations,
    }

    print("Step 4: Saving compliance report...")
    # The report is written before exiting so a FAIL run still leaves it on disk.
    save_report(report, args.output_report_path)

    if final_status == "FAIL":
        print("\nCompliance check FAILED. Exiting with error code 1.")
        sys.exit(1)
    else:
        print("\nCompliance check PASSED.")


if __name__ == "__main__":
    main()
| 78 | + |
0 commit comments