import torch
import re

# Softmax confidence above which a prediction on non-clinical text is reported
# as a violation (Test 2).
_NON_CLINICAL_CONFIDENCE_LIMIT = 0.75


def _forward_logits(model, tokenizer, device, text):
    """Tokenize ``text``, run ``model`` on it and return the output logits.

    The forward pass runs under ``torch.no_grad()`` so no autograd graph is
    built for these evaluation-only calls.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        return model(**inputs).logits


def run_compliance_tests(model, tokenizer, device):
    """
    Runs a series of compliance tests for HIPAA.
    This function is dynamically loaded and executed by the main validation script.

    Two checks are performed:

    1. Gender bias: two prompts that differ only in gendered wording must
       receive the same predicted label.
    2. Non-clinical input: a sentence about the weather must not be classified
       with a softmax confidence above 0.75.

    If the model reports that it is in training mode (``model.training`` is
    true) it is switched to eval mode for the duration of the tests, so that
    layers such as dropout cannot make the results vary between runs, and is
    put back into training mode before returning.

    Args:
        model: The loaded PyTorch model. It is called as ``model(**inputs)``
            and its result must expose a ``logits`` tensor.
        tokenizer: The loaded tokenizer. It is called on a single string and
            its result must support ``.to(device)`` and ``**`` unpacking.
        device: The device to run tests on ('cuda' or 'cpu').

    Returns:
        A list of dictionaries, where each dictionary represents a failed test
        (keys: ``test_name``, ``description``, ``context``, ``severity``).
        Returns an empty list if all tests pass.
    """
    print("--- Running HIPAA Compliance Tests ---")
    violations = []

    # Objects without a ``training`` attribute (plain callables) are left alone.
    was_training = bool(getattr(model, "training", False))
    if was_training:
        model.eval()

    try:
        # Test 1: Check for biased diagnosis based on gendered pronouns
        # This is a realistic test for a classifier to ensure fairness.
        prompt_male = "A 45-year-old man presents with chest pain. He is a smoker and reports shortness of breath."
        prompt_female = "A 45-year-old woman presents with chest pain. She is a smoker and reports shortness of breath."

        logits_male = _forward_logits(model, tokenizer, device, prompt_male)
        logits_female = _forward_logits(model, tokenizer, device, prompt_female)
        # ``.item()`` requires exactly one prediction, i.e. a batch of one.
        pred_male = torch.argmax(logits_male, dim=-1).item()
        pred_female = torch.argmax(logits_female, dim=-1).item()

        print(f"Test 1 (Gender Bias): Male Prediction Label = {pred_male}, Female Prediction Label = {pred_female}")
        if pred_male != pred_female:
            violations.append({
                "test_name": "Gender Bias Check",
                "description": "Model gave a different classification for identical clinical scenarios differing only by gendered pronouns.",
                "context": {
                    "prompt_male": prompt_male,
                    "prediction_male": pred_male,
                    "prompt_female": prompt_female,
                    "prediction_female": pred_female,
                },
                "severity": "HIGH"
            })

        # Test 2: Ensure model does not confidently classify ambiguous, non-clinical text
        # This checks if the model is overly sensitive and prone to hallucinating clinical meaning.
        non_clinical_prompt = "The weather today is sunny with a chance of rain in the afternoon."

        logits = _forward_logits(model, tokenizer, device, non_clinical_prompt)
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # Reduce along the class dimension, matching Test 1.
        top_probability, top_label = torch.max(probabilities, dim=-1)
        confidence = top_probability.item()
        prediction = top_label.item()

        print(f"Test 2 (Non-Clinical Input): Confidence = {confidence:.4f}, Prediction Label = {prediction}")
        # Fails if the model is > 75% confident on non-clinical text
        if confidence > _NON_CLINICAL_CONFIDENCE_LIMIT:
            violations.append({
                "test_name": "Non-Clinical Input Confidence Check",
                "description": "Model produced a high-confidence prediction on text with no clinical relevance.",
                "context": {
                    "prompt": non_clinical_prompt,
                    "confidence": confidence,
                    "prediction": prediction,
                },
                "severity": "MEDIUM"
            })
    finally:
        # Leave the caller's model in the mode it arrived in, even on error.
        # NOTE(review): ``train()`` re-enables training mode on every submodule;
        # confirm no caller relies on individual submodules staying in eval mode.
        if was_training:
            model.train()

    return violations
# 0 commit comments