|
4 | 4 | from PIL import Image |
5 | 5 | from celery import shared_task |
6 | 6 | import cv2 |
| 7 | +from django.conf import settings |
7 | 8 | from django.core.files import File |
8 | 9 | import numpy as np |
9 | 10 | import scipy |
@@ -243,3 +244,56 @@ def predict(compressed_spectrogram_id: int): |
243 | 244 | recording_annotation.species.set(species) |
244 | 245 | recording_annotation.save() |
245 | 246 | return label, score, confs |
| 247 | + |
| 248 | + |
| 249 | +def train_body(experiment_name: str): |
| 250 | + import mlflow |
| 251 | + from mlflow.models import infer_signature |
| 252 | + from sklearn import datasets |
| 253 | + from sklearn.linear_model import LogisticRegression |
| 254 | + from sklearn.metrics import accuracy_score |
| 255 | + from sklearn.model_selection import train_test_split |
| 256 | + |
| 257 | + X, y = datasets.load_iris(return_X_y=True) |
| 258 | + |
| 259 | + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) |
| 260 | + |
| 261 | + params = { |
| 262 | + 'solver': 'lbfgs', |
| 263 | + 'max_iter': 1000, |
| 264 | + 'multi_class': 'auto', |
| 265 | + 'random_state': 8888, |
| 266 | + } |
| 267 | + |
| 268 | + lr = LogisticRegression(**params) |
| 269 | + lr.fit(X_train, y_train) |
| 270 | + |
| 271 | + y_pred = lr.predict(X_test) |
| 272 | + |
| 273 | + accuracy = accuracy_score(y_test, y_pred) |
| 274 | + |
| 275 | + mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT) |
| 276 | + mlflow.set_experiment(experiment_name) |
| 277 | + |
| 278 | + print(mlflow.get_tracking_uri()) |
| 279 | + print(mlflow.get_artifact_uri()) |
| 280 | + |
| 281 | + mlflow.end_run() |
| 282 | + with mlflow.start_run(): |
| 283 | + mlflow.log_params(params) |
| 284 | + mlflow.log_metric('accuracy', accuracy) |
| 285 | + mlflow.set_tag('Training Info', 'Basic LR model for iris data') |
| 286 | + |
| 287 | + signature = infer_signature(X_train, lr.predict(X_train)) |
| 288 | + _ = mlflow.sklearn.log_model( |
| 289 | + sk_model=lr, |
| 290 | + artifact_path='iris_model', |
| 291 | + signature=signature, |
| 292 | + input_example=X_train, |
| 293 | + registered_model_name='tracking-quickstart', |
| 294 | + ) |
| 295 | + |
| 296 | + |
| 297 | +@shared_task |
| 298 | +def example_train(experiment_name: str): |
| 299 | + train_body(experiment_name) |
0 commit comments