Skip to content

Commit 1aa3a65

Browse files
committed
added ml_training_poc
1 parent d66632f commit 1aa3a65

File tree

2 files changed

+286
-0
lines changed

2 files changed

+286
-0
lines changed

singlestoredb/ml/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ml_train import register_train_command

singlestoredb/ml/ml_train.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
import os
2+
import logging
3+
from typing import List, Dict, Any, Optional
4+
from datetime import datetime
5+
6+
import requests
7+
from IPython.core.magic import Magics, magics_class, line_magic
8+
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
9+
10+
# NOTE(review): this class is defined twice in this file with the same body;
# the later definition is the one bound to the name after import.
class SingleStoreAPIError(Exception):
    """Error raised for a non-success SingleStore Management API response.

    Attributes:
        status_code: HTTP status code passed in by the caller.
        response: The ``requests.Response`` object passed in by the caller,
            kept so handlers can inspect headers or the body.
    """

    def __init__(self, status_code: int, message: str, response: requests.Response):
        # Keep the raw pieces available to handlers.
        self.status_code = status_code
        self.response = response
        # The exception text combines the status code and the message.
        detail = f"SingleStore API Error {status_code}: {message}"
        super().__init__(detail)
16+
17+
18+
class SingleStoreJobsClient:
    """Thin wrapper around the SingleStore Management API for notebook Jobs.

    The client owns a ``requests.Session`` pre-configured with a bearer token
    and JSON headers, and exposes one method per ``/jobs`` endpoint it uses.
    Any response whose ``ok`` flag is false is turned into a
    ``SingleStoreAPIError``.
    """

    # Environment variables consulted, in this order, when no token is
    # passed to the constructor.
    _TOKEN_ENV_VARS = (
        "SINGLESTOREDB_USER_TOKEN",
        "SINGLESTOREDB_APP_TOKEN",
        "SINGLESTORE_JWT",
    )

    def __init__(
        self,
        jwt_token: Optional[str] = None,
        base_url: str = "https://api.singlestore.com/v1",
        timeout: float = 30.0,
    ):
        """Build an authenticated session.

        Args:
            jwt_token: Bearer token. When omitted, the first non-empty value
                among the ``_TOKEN_ENV_VARS`` environment variables is used.
            base_url: API root; a trailing slash is stripped.
            timeout: Seconds passed as the ``timeout`` of every HTTP request
                so a stalled connection cannot hang the caller forever.

        Raises:
            ValueError: If no token is passed and none is found in the
                environment.
        """
        # Never fall back to a credential embedded in source code.
        token = self._resolve_token(jwt_token)
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        })

    @classmethod
    def _resolve_token(cls, jwt_token: Optional[str]) -> str:
        """Return the explicit token, else the first one set in the environment."""
        token = jwt_token or next(
            (os.environ[var] for var in cls._TOKEN_ENV_VARS if os.environ.get(var)),
            None,
        )
        if not token:
            raise ValueError(
                "No API token found: pass jwt_token or set one of "
                + ", ".join(cls._TOKEN_ENV_VARS)
                + "."
            )
        return token

    @staticmethod
    def _format_start_at(start_at: datetime) -> str:
        """Render *start_at* as a second-precision UTC timestamp ending in ``Z``.

        Naive datetimes are taken to already be in UTC (the original
        behaviour). Aware datetimes are shifted to UTC first; otherwise
        ``isoformat()`` would emit an offset and the result would be an
        invalid string such as ``...+00:00Z``.
        """
        offset = start_at.utcoffset()
        if offset is not None:
            start_at = (start_at - offset).replace(tzinfo=None)
        return start_at.replace(microsecond=0).isoformat() + "Z"

    def _request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send one request and return the decoded JSON body.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        resp = self.session.request(
            method, f"{self.base_url}{path}", timeout=self.timeout, **kwargs
        )
        if not resp.ok:
            raise SingleStoreAPIError(resp.status_code, resp.text, resp)
        return resp.json()

    def list_jobs(self) -> List[Dict[str, Any]]:
        """Fetch all jobs; any filtering by name is left to the caller."""
        return self._request("GET", "/jobs")  # assume a JSON array

    def create_job(
        self,
        name: str,
        description: str,
        notebook_path: str,
        runtime_name: str,
        parameters: List[Dict[str, Any]],
        schedule_interval_minutes: Optional[int],
        schedule_mode: str,
        schedule_start_at: datetime,
        target_config: Dict[str, Any],
        create_snapshot: bool = True,
    ) -> Dict[str, Any]:
        """POST a new job definition to ``/jobs`` and return the decoded response.

        Args:
            name: Sent as ``name``.
            description: Sent as ``description``.
            notebook_path: Sent as ``executionConfig.notebookPath``.
            runtime_name: Sent as ``executionConfig.runtimeName``.
            parameters: Sent unchanged as ``parameters``.
            schedule_interval_minutes: Sent as
                ``schedule.executionIntervalInMinutes``; ``None`` is sent as
                JSON ``null``.
            schedule_mode: Sent as ``schedule.mode``.
            schedule_start_at: Sent as ``schedule.startAt`` in UTC with
                second precision.
            target_config: Sent unchanged as ``targetConfig``.
            create_snapshot: Sent as ``executionConfig.createSnapshot``.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        payload = {
            "name": name,
            "description": description,
            "executionConfig": {
                "notebookPath": notebook_path,
                "runtimeName": runtime_name,
                "createSnapshot": create_snapshot,
            },
            "parameters": parameters,
            "schedule": {
                "executionIntervalInMinutes": schedule_interval_minutes,
                "mode": schedule_mode,
                "startAt": self._format_start_at(schedule_start_at),
            },
            "targetConfig": target_config,
        }
        return self._request("POST", "/jobs", json=payload)

    def get_job(self, job_id: str) -> Dict[str, Any]:
        """Fetch a single job from ``/jobs/{job_id}``.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        return self._request("GET", f"/jobs/{job_id}")
91+
92+
import os
93+
import logging
94+
from typing import List, Dict, Any, Optional
95+
from datetime import datetime
96+
97+
import requests
98+
from IPython.core.magic import Magics, magics_class, line_magic
99+
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
100+
101+
102+
# NOTE(review): this is the second definition of this class in this file
# (same body as the earlier one); being later, it is the one bound to the
# name after import.
class SingleStoreAPIError(Exception):
    """Error raised for a non-success SingleStore Management API response.

    Attributes:
        status_code: HTTP status code passed in by the caller.
        response: The ``requests.Response`` object passed in by the caller,
            kept so handlers can inspect headers or the body.
    """

    def __init__(self, status_code: int, message: str, response: requests.Response):
        # Keep the raw pieces available to handlers.
        self.status_code = status_code
        self.response = response
        # The exception text combines the status code and the message.
        detail = f"SingleStore API Error {status_code}: {message}"
        super().__init__(detail)
108+
109+
110+
# NOTE(review): this is the second definition of this class in this file;
# being later, it is the one bound to the name after import.
class SingleStoreJobsClient:
    """Thin wrapper around the SingleStore Management API for notebook Jobs.

    The client owns a ``requests.Session`` pre-configured with a bearer token
    and JSON headers, and exposes one method per ``/jobs`` endpoint it uses.
    Any response whose ``ok`` flag is false is turned into a
    ``SingleStoreAPIError``.
    """

    # Environment variables consulted, in this order, when no token is
    # passed to the constructor.
    _TOKEN_ENV_VARS = (
        "SINGLESTOREDB_USER_TOKEN",
        "SINGLESTOREDB_APP_TOKEN",
        "SINGLESTORE_JWT",
    )

    def __init__(
        self,
        jwt_token: Optional[str] = None,
        base_url: str = "https://api.singlestore.com/v1",
        timeout: float = 30.0,
    ):
        """Build an authenticated session.

        Args:
            jwt_token: Bearer token. When omitted, the first non-empty value
                among the ``_TOKEN_ENV_VARS`` environment variables is used.
            base_url: API root; a trailing slash is stripped.
            timeout: Seconds passed as the ``timeout`` of every HTTP request
                so a stalled connection cannot hang the caller forever.

        Raises:
            ValueError: If no token is passed and none is found in the
                environment.
        """
        # Never fall back to a credential embedded in source code.
        token = self._resolve_token(jwt_token)
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        })

    @classmethod
    def _resolve_token(cls, jwt_token: Optional[str]) -> str:
        """Return the explicit token, else the first one set in the environment."""
        token = jwt_token or next(
            (os.environ[var] for var in cls._TOKEN_ENV_VARS if os.environ.get(var)),
            None,
        )
        if not token:
            raise ValueError(
                "No API token found: pass jwt_token or set one of "
                + ", ".join(cls._TOKEN_ENV_VARS)
                + "."
            )
        return token

    @staticmethod
    def _format_start_at(start_at: datetime) -> str:
        """Render *start_at* as a second-precision UTC timestamp ending in ``Z``.

        Naive datetimes are taken to already be in UTC (the original
        behaviour). Aware datetimes are shifted to UTC first; otherwise
        ``isoformat()`` would emit an offset and the result would be an
        invalid string such as ``...+00:00Z``.
        """
        offset = start_at.utcoffset()
        if offset is not None:
            start_at = (start_at - offset).replace(tzinfo=None)
        return start_at.replace(microsecond=0).isoformat() + "Z"

    def _request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send one request and return the decoded JSON body.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        resp = self.session.request(
            method, f"{self.base_url}{path}", timeout=self.timeout, **kwargs
        )
        if not resp.ok:
            raise SingleStoreAPIError(resp.status_code, resp.text, resp)
        return resp.json()

    def list_jobs(self) -> List[Dict[str, Any]]:
        """Fetch all jobs; any filtering by name is left to the caller."""
        return self._request("GET", "/jobs")  # assume a JSON array

    def create_job(
        self,
        name: str,
        description: str,
        notebook_path: str,
        runtime_name: str,
        parameters: List[Dict[str, Any]],
        schedule_interval_minutes: Optional[int],
        schedule_mode: str,
        schedule_start_at: datetime,
        target_config: Dict[str, Any],
        create_snapshot: bool = True,
    ) -> Dict[str, Any]:
        """POST a new job definition to ``/jobs`` and return the decoded response.

        Args:
            name: Sent as ``name``.
            description: Sent as ``description``.
            notebook_path: Sent as ``executionConfig.notebookPath``.
            runtime_name: Sent as ``executionConfig.runtimeName``.
            parameters: Sent unchanged as ``parameters``.
            schedule_interval_minutes: Sent as
                ``schedule.executionIntervalInMinutes``; ``None`` is sent as
                JSON ``null``.
            schedule_mode: Sent as ``schedule.mode``.
            schedule_start_at: Sent as ``schedule.startAt`` in UTC with
                second precision.
            target_config: Sent unchanged as ``targetConfig``.
            create_snapshot: Sent as ``executionConfig.createSnapshot``.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        payload = {
            "name": name,
            "description": description,
            "executionConfig": {
                "notebookPath": notebook_path,
                "runtimeName": runtime_name,
                "createSnapshot": create_snapshot,
            },
            "parameters": parameters,
            "schedule": {
                "executionIntervalInMinutes": schedule_interval_minutes,
                "mode": schedule_mode,
                "startAt": self._format_start_at(schedule_start_at),
            },
            "targetConfig": target_config,
        }
        return self._request("POST", "/jobs", json=payload)

    def get_job(self, job_id: str) -> Dict[str, Any]:
        """Fetch a single job from ``/jobs/{job_id}``.

        Raises:
            SingleStoreAPIError: If the response's ``ok`` flag is false.
        """
        return self._request("GET", f"/jobs/{job_id}")
183+
184+
185+
@magics_class
class SSMLFunctionMagics(Magics):
    """Line magic that schedules a run-once classification-train job via REST.

    The magic parses its arguments, forwards them as string parameters of a
    notebook job, and creates that job through ``SingleStoreJobsClient``.
    """

    def __init__(self, shell):
        """Attach to *shell* and set up an INFO-level logger for this module."""
        super().__init__(shell)
        self.logger = logging.getLogger(__name__)
        # Only add a handler the first time, so repeated registration does
        # not duplicate every log line.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    @magic_arguments()
    @argument('--job_name', type=str, required=True, help="Job name")
    @argument('--workspace', type=str, required=True, help="Workspace identifier")
    @argument('--db', type=str, required=True, help="Database/schema name")
    @argument('--input_table', type=str, required=True, help="Source table")
    @argument('--target_column', type=str, required=True, help="Target column")
    @argument('--model', type=str, default='auto', help="Model to train")
    @argument('--evaluation_criteria', type=str, required=True, help="Metric (roc_auc, etc.)")
    @argument('--selected_features', type=str, nargs='+', help="List of features")
    @line_magic
    def SS_ML_FUNCTION_CLASSIFICATION_TRAIN(self, line: str):
        """Create a run-once job that executes the classification notebook.

        The job's target workspace is read from the SINGLESTOREDB_WORKSPACE
        environment variable and its runtime from SINGLESTORE_RUNTIME
        (default "notebooks-cpu-small"). API failures are logged rather than
        raised; a missing SINGLESTOREDB_WORKSPACE raises ValueError.
        """
        # Local import: the module-level import only brings in `datetime`.
        from datetime import timezone

        # 1) Parse arguments
        args = parse_argstring(self.SS_ML_FUNCTION_CLASSIFICATION_TRAIN, line)
        job_name = args.job_name
        job_descr = f"Train classification on {args.input_table} target={args.target_column} job_name={job_name}"

        # 2) Instantiate client
        # TODO: the check for an already-existing job with this name is
        # currently disabled, so a job is created on every invocation.
        client = SingleStoreJobsClient()

        # 3) Build parameters payload (every value is forwarded as a string)
        forwarded = (
            'job_name', 'workspace', 'db', 'input_table',
            'target_column', 'model', 'evaluation_criteria',
        )
        params: List[Dict[str, Any]] = [
            {"name": key, "type": "string", "value": getattr(args, key)}
            for key in forwarded
        ]
        if args.selected_features:
            # The feature list is flattened into one comma-separated string.
            params.append({
                "name": "selected_features",
                "type": "string",
                "value": ",".join(args.selected_features),
            })

        # 4) Schedule + targetConfig defaults (from ENV)
        runtime = os.environ.get("SINGLESTORE_RUNTIME", "notebooks-cpu-small")
        target_id = os.environ.get("SINGLESTOREDB_WORKSPACE")
        if not target_id:
            # Name the variable that is actually read above.
            raise ValueError("Please set SINGLESTOREDB_WORKSPACE in your env.")
        target_cfg = {
            "databaseName": args.db,
            "resumeTarget": True,
            "targetType": "Workspace",
            "targetID": target_id,
        }

        # Naive UTC "now", obtained without the deprecated datetime.utcnow().
        start_at = datetime.now(timezone.utc).replace(tzinfo=None)

        # 5) Create the job
        try:
            created = client.create_job(
                name=job_name,
                description=job_descr,
                notebook_path="ml_Classification_pipeline.ipynb",
                runtime_name=runtime,
                parameters=params,
                schedule_interval_minutes=None,  # run-once jobs have no interval
                schedule_mode="Once",
                schedule_start_at=start_at,
                target_config=target_cfg,
                create_snapshot=True,
            )
        except SingleStoreAPIError as e:
            self.logger.error("Failed to create job '%s': %s", job_name, e)
            return

        self.logger.info(
            "Created job '%s' (ID=%s)", created.get('name'), created.get('jobID')
        )
279+
280+
281+
# register so the magic is available immediately:
282+
# get_ipython().register_magics(SSMLFunctionMagics)
283+
def register_train_command():
    """Register the SSMLFunctionMagics with the IPython shell.

    Raises:
        RuntimeError: If there is no active IPython shell to register with.
    """
    # Import explicitly instead of relying on the `get_ipython` name that
    # IPython injects into builtins only while a shell is running; without
    # this, calling the function elsewhere fails with a bare NameError.
    from IPython import get_ipython

    shell = get_ipython()
    if shell is None:
        raise RuntimeError(
            "register_train_command() must be called from an active IPython session."
        )
    shell.register_magics(SSMLFunctionMagics)

0 commit comments

Comments
 (0)