-
Notifications
You must be signed in to change notification settings - Fork 215
Expand file tree
/
Copy pathsmile.py
More file actions
239 lines (197 loc) · 9.46 KB
/
smile.py
File metadata and controls
239 lines (197 loc) · 9.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#!/usr/bin/env python3
"""
MSAL Feature Test Runner
Interprets testcase file(s) to create and execute test cases using MSAL.
Initially created by the following prompt:
Write a python implementation that can read content from feature.yml, create variables whose names are defined in the "arrange" mapping's keys, and the variables' value are derived from the "arrange" mapping's value; interpret those value as if they are python snippet using MSAL library.
"""
import os
import sys
import logging
from contextlib import contextmanager
from typing import Dict, Any, List, Optional
import yaml
import msal
import requests
# Configure the root logger once for the whole process (INFO and above,
# timestamped lines), then grab this module's own named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class SmileTestRunner:
    """Run one MSAL feature test case described by a YAML document.

    The document, fetched from ``testcase_url``, must be a mapping whose
    ``type`` is ``'MSAL Test'``. Sections read by this class:

    * ``env``: environment variables applied while the test runs;
    * ``arrange``: ``{variable_name: {ClassName: params}}`` entries, each
      turned into an MSAL object stored in ``self.variables``;
    * ``steps``: a list of mappings with an optional ``act``
      (``{"variable.Method": kwargs}``) and an optional ``assert``
      (``{result_key: expected_value}``).
    """

    # Method names used in feature files -> actual method names on the objects.
    _METHOD_ALIASES = {
        "AcquireTokenForManagedIdentity": "acquire_token_for_client",  # For ManagedIdentityClient
        "AcquireTokenForClient": "acquire_token_for_client",  # For ConfidentialClientApplication
    }

    def __init__(self, testcase_url: str, timeout: Optional[float] = 30):
        """Remember where the test case lives; nothing is fetched until load_feature().

        :param testcase_url: URL of the YAML feature file.
        :param timeout: Seconds to wait when downloading the feature file
            (None waits forever).
        """
        self.testcase_url = testcase_url
        self.timeout = timeout
        self.test_spec = None  # parsed YAML mapping, set by load_feature()
        self.variables = {}  # arrange-section variable name -> created object

    def load_feature(self) -> Dict[str, Any]:
        """Download, parse and validate the feature file.

        :return: The parsed specification (also stored in ``self.test_spec``).
        :raises Exception: Any download, parse or validation error is logged
            and then re-raised to the caller.
        """
        try:
            # Without a timeout an unresponsive server would hang the run forever.
            with requests.get(self.testcase_url, timeout=self.timeout) as response:
                response.raise_for_status()
                self.test_spec = yaml.safe_load(response.text)
            # Basic validation
            if not isinstance(self.test_spec, dict):
                raise ValueError("Feature file must contain a valid YAML dictionary")
            if self.test_spec.get('type') != 'MSAL Test':
                raise ValueError("Feature file must have type 'MSAL Test'")
            return self.test_spec
        except Exception as e:
            logger.error(f"Error loading feature file: {str(e)}")
            # Re-raise rather than sys.exit(): SystemExit is not an Exception,
            # so exiting here would slip past run_testcases' per-case handler
            # and abort every remaining test case in a batch.
            raise

    @contextmanager
    def setup_environment(self):
        """Temporarily apply the feature file's ``env`` mapping to os.environ.

        Values are converted with str(). The full original environment is
        restored on exit, even when the body raises.
        """
        original_env = os.environ.copy()
        try:
            if 'env' in self.test_spec and isinstance(self.test_spec['env'], dict):
                for key, value in self.test_spec['env'].items():
                    os.environ[key] = str(value)
                    logger.debug(f"Set environment variable {key}={value}")
            yield
        finally:
            # Restore original environment
            os.environ.clear()
            os.environ.update(original_env)

    def arrange(self):
        """Create every object declared in the ``arrange`` section.

        :raises ValueError: If the section, or any entry in it, is malformed.
        """
        arrange_spec = self.test_spec.get('arrange')
        if arrange_spec is None:  # section absent, or present but left empty in YAML
            arrange_spec = {}
        if not isinstance(arrange_spec, dict):
            raise ValueError("Arrange section must be a dictionary")
        for var_name, value_spec in arrange_spec.items():
            logger.debug(f"Creating variable '{var_name}' with {value_spec}")
            self.variables[var_name] = self._create_instance(value_spec)

    def _create_instance(self, spec: Dict[str, Any]) -> Any:
        """Create the object described by a one-item ``{ClassName: params}`` mapping.

        :raises ValueError: If ``spec`` is not a single-item mapping or names a
            class this runner does not support.
        """
        if not isinstance(spec, dict) or len(spec) != 1:
            raise ValueError(f"Invalid specification format: {spec}")
        class_name, params = next(iter(spec.items()))
        # Handle different MSAL classes
        if class_name == "ManagedIdentityClient":
            # `ManagedIdentityClient:` with no value parses as None; `**None` is a TypeError.
            return msal.ManagedIdentityClient(http_client=requests.Session(), **(params or {}))
        elif class_name == "PublicClientApplication":
            return self._create_public_client_app(params)
        elif class_name == "ConfidentialClientApplication":
            return self._create_confidential_client_app(params)
        else:
            raise ValueError(f"Unsupported class: {class_name}")

    def _create_public_client_app(self, params: Dict[str, Any]) -> Any:
        """Create a PublicClientApplication; ``authority`` is passed only when truthy.

        :raises ValueError: If ``client_id`` is missing.
        """
        if not params or 'client_id' not in params:
            raise ValueError("PublicClientApplication requires client_id")
        client_id = params.get('client_id')
        authority = params.get('authority')
        logger.debug(f"Creating PublicClientApplication with client_id: {client_id}, authority: {authority}")
        kwargs = {'client_id': client_id}
        if authority:
            kwargs['authority'] = authority
        return msal.PublicClientApplication(**kwargs)

    def _create_confidential_client_app(self, params: Dict[str, Any]) -> Any:
        """Create a ConfidentialClientApplication.

        :raises ValueError: If ``client_id`` or ``client_credential`` is missing.
        """
        if not params or 'client_id' not in params or 'client_credential' not in params:
            raise ValueError("ConfidentialClientApplication requires client_id and client_credential")
        kwargs = {
            "client_id": params.get('client_id'),
            "client_credential": params.get('client_credential'),
            "authority": params.get('authority'),
            "oidc_authority": params.get('oidc_authority'),
        }
        logger.debug(f"Creating ConfidentialClientApplication with {kwargs}")
        return msal.ConfidentialClientApplication(**kwargs)

    def execute_steps(self) -> bool:
        """Run every step's action and check its assertions.

        A step passes when its action (if any) returns and all its assertions
        (if any) hold against that same step's result. Exceptions raised by an
        action propagate to the caller.

        :return: True if every step passed.
        """
        steps = self.test_spec.get('steps') or []  # `steps:` left empty parses as None
        passed = 0
        for i, step in enumerate(steps):
            logger.debug(f"Executing step {i+1}/{len(steps)}")
            # Reset per step so an assert-only step neither reuses the previous
            # step's result nor hits an unbound name on the first step.
            result = None
            if 'act' in step:
                result = self._execute_action(step['act'])
            # A step with nothing to assert passes once its action has completed.
            if 'assert' not in step or self._validate_assertions(result, step['assert']):
                passed += 1
        logger.info(f"{passed} of {len(steps)} step(s) passed")
        return passed == len(steps)

    def _execute_action(self, act_spec: Dict[str, Any]) -> Any:
        """Call ``variable.Method`` from a one-item ``{"variable.Method": kwargs}`` mapping.

        :return: Whatever the underlying method returns.
        :raises ValueError: On a malformed spec, an unknown method alias, or an
            unknown variable / missing method.
        """
        if not isinstance(act_spec, dict) or len(act_spec) != 1:
            raise ValueError(f"Invalid action specification: {act_spec}")
        action_str, params = next(iter(act_spec.items()))
        # Parse the action string (e.g., "app1.AcquireToken")
        parts = action_str.split('.')
        if len(parts) != 2:
            raise ValueError(f"Invalid action format: {action_str}")
        var_name = parts[0]
        method_name = self._METHOD_ALIASES.get(parts[1])
        if method_name is None:
            raise ValueError(f"Unsupported method: {parts[1]}")
        if var_name not in self.variables:
            raise ValueError(f"Variable '{var_name}' not found")
        instance = self.variables[var_name]
        if not hasattr(instance, method_name):
            raise ValueError(f"Method '{method_name}' not found on {var_name}")
        method = getattr(instance, method_name)
        kwargs = params if params else {}  # `Method:` with no value means no kwargs
        logger.info(f"Calling {var_name}.{method_name} with {kwargs}")
        return method(**kwargs)

    def _validate_assertions(self, result: Any, assertions: Optional[Dict[str, Any]]) -> bool:
        """Check that ``result`` contains every expected key/value pair.

        :param result: The step's action result, expected to be a dict; None
            when the step had no action.
        :param assertions: Expected ``{key: value}`` pairs. None (an empty
            ``assert:`` in YAML) means there is nothing to check.
        :return: True if all assertions hold; stops at the first failure.
        """
        logger.info(f"Validating assertions: {assertions}")
        for key, expected_value in (assertions or {}).items():
            # A non-dict result (e.g. None for an assert-only step) has no keys.
            if not isinstance(result, dict) or key not in result:
                logger.error(f"Assertion failed: '{key}' not found in result {result}")
                return False  # Failed
            actual_value = result[key]
            if actual_value != expected_value:
                logger.error(f"Assertion failed: expected {key}='{expected_value}', got '{actual_value}'")
                return False  # Failed
            logger.debug(f"Assertion passed: {key}='{actual_value}'")
        return True  # Passed

    def run(self) -> bool:
        """Load the feature file, then arrange and execute it inside its ``env``.

        :return: True if every step passed.
        """
        self.load_feature()
        with self.setup_environment():
            self.arrange()
            result = self.execute_steps()
        if result:
            logger.info(f"Test case {self.testcase_url} passed")
        else:
            logger.error(f"Test case {self.testcase_url} failed")
        return result
def run_testcases(testcases_url: str, timeout: Optional[float] = 30) -> bool:
    """Run every test case listed in a JSON batch document.

    The document at ``testcases_url`` is read as ``{"testcases": [url, ...]}``.
    Each listed URL is run with SmileTestRunner. An exception from one case is
    logged and counted as a failure so the remaining cases still run.

    :param testcases_url: URL of the JSON batch document.
    :param timeout: Seconds to wait when downloading the batch document
        (None waits forever).
    :return: True if every listed test case passed.
    :raises requests.RequestException: If the batch document cannot be fetched.
    """
    try:
        # Timeout so an unresponsive server cannot hang the run; `with` closes the response.
        with requests.get(testcases_url, timeout=timeout) as response:
            response.raise_for_status()
            testcases = response.json().get("testcases", [])
    except requests.RequestException as e:
        logger.error(f"Failed to fetch test cases from {testcases_url}: {e}")
        raise
    passed = 0
    for url in testcases:
        try:
            if SmileTestRunner(url).run():
                passed += 1
        except Exception as e:
            logger.error(f"Test case {url} failed: {e}")
    (logger.info if passed == len(testcases) else logger.error)(
        f"Passed {passed} of {len(testcases)} test cases"
    )
    return passed == len(testcases)
def main():
    """Command-line entry point.

    Runs a single test case (``--testcase URL``, logging at DEBUG) or a batch
    (``--batch URL``, logging at INFO). Exits with status 0 when everything
    passed and 1 otherwise.
    """
    import argparse
    parser = argparse.ArgumentParser(description="MSAL Feature Test Runner")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--testcase", help="URL for a single test case")
    group.add_argument("--batch", help="URL for a batch of test cases in JSON format")
    args = parser.parse_args()
    # Test against None, not truthiness: the required group accepts
    # `--testcase ""`, and a falsy check would leave `success` unassigned.
    if args.testcase is not None:
        logger.setLevel(logging.DEBUG)
        success = SmileTestRunner(args.testcase).run()
    else:
        # The required mutually exclusive group guarantees --batch was given.
        logger.setLevel(logging.INFO)
        success = run_testcases(args.batch)
    sys.exit(0 if success else 1)
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()