-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
166 lines (129 loc) · 5.17 KB
/
client.py
File metadata and controls
166 lines (129 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""HTTP client SDK for RL-Testing-Env OpenEnv environment.
This module provides a Python client for programmatic interaction with the
RL-Testing-Env server. Use this when building custom agents or integrating
the environment into your training pipeline.
Example Usage:
from client import AutoTestEnvClient
from models import TestAction
# Connect to the environment server
client = AutoTestEnvClient("http://localhost:7860")
# Start a new episode
result = client.reset(task_id="unit_test_writer", seed=42)
print(result.observation.code_under_test)
# Submit test code
action = TestAction(
action_type="submit_tests",
test_code="def test_basic(): assert True"
)
result = client.step(action)
print(f"Reward: {result.reward}")
"""
import json
import urllib.request
import urllib.error
from typing import Optional
from models import TestAction, TestObservation, StepResult, EpisodeState
class AutoTestEnvClient:
    """Client for interacting with the RL-Testing-Env server.

    This client provides a Python interface to the HTTP API exposed by
    the RL-Testing-Env FastAPI server. Requests and responses are JSON,
    sent with the standard-library ``urllib``.

    Attributes:
        base_url: The base URL of the RL-Testing-Env server, stored without
            a trailing slash.
        timeout: Default timeout for HTTP requests in seconds.
    """

    def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 60):
        """Initialize the client.

        Args:
            base_url: Base URL of the RL-Testing-Env server. A trailing "/"
                is stripped so endpoints like "/reset" can be appended.
            timeout: Default timeout for HTTP requests in seconds.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[dict] = None,
        timeout: Optional[float] = None,
    ) -> dict:
        """Send an HTTP request to the server and decode the JSON response.

        Args:
            method: HTTP method, e.g. "GET" or "POST".
            endpoint: API endpoint path, e.g. "/reset".
            data: Optional dictionary sent as the JSON request body. When
                None, no body and no Content-Type header are sent.
            timeout: Optional per-request timeout override in seconds. Only
                ``None`` falls back to ``self.timeout``.

        Returns:
            Parsed JSON response.

        Raises:
            RuntimeError: On an HTTP error status or a URL/connection error.
                For HTTP errors, the server-provided error detail (e.g. a
                FastAPI ``detail`` field) is appended when it can be read.
            ValueError: ``json.JSONDecodeError`` if a successful response
                body is not valid JSON.
        """
        url = f"{self.base_url}{endpoint}"
        body = None
        headers = {}
        if data is not None:
            body = json.dumps(data).encode("utf-8")
            headers["Content-Type"] = "application/json"
        req = urllib.request.Request(url, data=body, headers=headers, method=method)
        # `timeout or self.timeout` would silently ignore an explicit 0/0.0.
        effective_timeout = self.timeout if timeout is None else timeout
        try:
            with urllib.request.urlopen(req, timeout=effective_timeout) as resp:
                return json.loads(resp.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first.
            message = f"HTTP {e.code}: {e.reason}"
            detail = self._error_detail(e)
            if detail:
                message = f"{message} - {detail}"
            raise RuntimeError(message) from e
        except urllib.error.URLError as e:
            raise RuntimeError(f"URL Error: {e.reason}") from e

    @staticmethod
    def _error_detail(err: urllib.error.HTTPError, limit: int = 500) -> str:
        """Extract a human-readable error detail from an HTTP error response.

        If the body is a JSON object with a "detail" key, that value is used
        (serialized with ``json.dumps`` when it is not already a string).
        Otherwise the stripped raw body is used.

        Args:
            err: The HTTPError raised by ``urlopen``.
            limit: Maximum number of characters kept before "..." is appended.

        Returns:
            The detail text, or "" if the body is empty or unreadable.
        """
        try:
            raw = err.read().decode("utf-8", errors="replace").strip()
        except Exception:  # best-effort: never mask the original HTTP error
            return ""
        detail = raw
        try:
            payload = json.loads(raw)
        except ValueError:
            payload = None
        if isinstance(payload, dict) and "detail" in payload:
            value = payload["detail"]
            detail = value if isinstance(value, str) else json.dumps(value)
        if len(detail) > limit:
            detail = detail[:limit] + "..."
        return detail

    def _post(self, endpoint: str, data: dict, timeout: Optional[float] = None) -> dict:
        """Make a POST request with a JSON body to the server.

        Args:
            endpoint: API endpoint (e.g., "/reset", "/step").
            data: Dictionary to send as JSON body.
            timeout: Optional timeout override in seconds.

        Returns:
            Parsed JSON response as dictionary.

        Raises:
            RuntimeError: If the request fails.
        """
        return self._request("POST", endpoint, data=data, timeout=timeout)

    def _get(self, endpoint: str, timeout: Optional[float] = None) -> dict:
        """Make a GET request to the server.

        Args:
            endpoint: API endpoint (e.g., "/state", "/health").
            timeout: Optional timeout override in seconds.

        Returns:
            Parsed JSON response as dictionary.

        Raises:
            RuntimeError: If the request fails.
        """
        return self._request("GET", endpoint, timeout=timeout)

    def reset(self, task_id: str = "unit_test_writer", seed: int = 42) -> StepResult:
        """Reset the environment and start a new episode.

        Args:
            task_id: The task to run. One of:
                - "unit_test_writer" (easy)
                - "coverage_audit" (medium)
                - "regression_audit" (hard)
            seed: Random seed for deterministic task generation.

        Returns:
            StepResult containing initial observation.
        """
        data = self._post("/reset", {"task_id": task_id, "seed": seed})
        return StepResult(**data)

    def step(self, action: TestAction) -> StepResult:
        """Execute one step in the environment.

        Args:
            action: The TestAction to execute; serialized via ``model_dump()``.

        Returns:
            StepResult containing new observation, reward, and done flag.
        """
        data = self._post("/step", action.model_dump())
        return StepResult(**data)

    def state(self) -> EpisodeState:
        """Get the current episode state.

        Returns:
            EpisodeState with current progress information.
        """
        data = self._get("/state")
        return EpisodeState(**data)

    def health(self) -> dict:
        """Check server health.

        Returns:
            Health status dictionary as returned by the "/health" endpoint.
        """
        return self._get("/health")
# Convenience function for quick client creation
def connect(base_url: str = "http://localhost:7860", timeout: float = 60) -> AutoTestEnvClient:
    """Create and return an AutoTestEnvClient.

    Args:
        base_url: Base URL of the RL-Testing-Env server.
        timeout: Default timeout for HTTP requests in seconds, forwarded to
            the client. Defaults to 60, the same default AutoTestEnvClient
            uses, so existing ``connect(url)`` calls behave as before.

    Returns:
        Configured AutoTestEnvClient instance.
    """
    return AutoTestEnvClient(base_url, timeout=timeout)