Skip to content

Commit 0ca567f

Browse files
committed
Move example model training to tasks
1 parent f4d68dd commit 0ca567f

2 files changed

Lines changed: 70 additions & 9 deletions

File tree

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1-
from django.conf import settings
21
import djclick as click
3-
from mlflow import MlflowClient
4-
from sklearn.datasets import load_diabetes
5-
from sklearn.ensemble import RandomForestRegressor
6-
from sklearn.model_selection import train_test_split
2+
import mlflow
73

4+
from bats_ai.core.tasks import example_train
85

9-
@click.command()
10-
def command():
11-
click.echo("Running Mlflow experiment")
126

13-
client = MlflowClient(tracking_uri=settings.MLFLOW_ENDPOINT)
7+
@click.command()
@click.option('--experiment-name', type=click.STRING, required=False, default='Default')
def command(experiment_name):
    """Queue the example training task for an existing MLflow experiment.

    Looks up ``experiment_name`` with ``mlflow.get_experiment_by_name`` and,
    when it exists, dispatches ``example_train`` asynchronously via ``.delay``.
    When the experiment is missing, a hint is printed and nothing is queued.

    Args:
        experiment_name: Name of the MLflow experiment to log to
            (``--experiment-name``; defaults to ``'Default'``).
    """
    # Imported locally so this command does not depend on a module-level import.
    from django.conf import settings

    # The dispatched task logs to settings.MLFLOW_ENDPOINT, so the existence
    # check must query that same tracking server rather than mlflow's default.
    mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT)

    click.echo('Finding experiment')
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        click.echo(
            f'Could not find experiment {experiment_name}.'
            ' Use the create experiment command to create a new experiment.'
        )
        return

    click.echo(f'Creating a log for experiment {experiment_name}')
    example_train.delay(experiment_name)

bats_ai/core/tasks.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from PIL import Image
55
from celery import shared_task
66
import cv2
7+
from django.conf import settings
78
from django.core.files import File
89
import numpy as np
910
import scipy
@@ -243,3 +244,56 @@ def predict(compressed_spectrogram_id: int):
243244
recording_annotation.species.set(species)
244245
recording_annotation.save()
245246
return label, score, confs
247+
248+
249+
def train_body(experiment_name: str) -> None:
    """Train a small iris classifier and record it in MLflow.

    Fits a ``LogisticRegression`` on an 80/20 split of the iris dataset, then
    logs the hyperparameters, the test-set accuracy, a tag and the fitted
    model (registered as ``'tracking-quickstart'``) to a new run in the
    experiment ``experiment_name`` on ``settings.MLFLOW_ENDPOINT``.

    Args:
        experiment_name: Name of the MLflow experiment that receives the run.
    """
    import logging

    import mlflow
    from mlflow.models import infer_signature
    from sklearn import datasets
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split

    logger = logging.getLogger(__name__)

    X, y = datasets.load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # NOTE(review): 'multi_class' is deprecated in recent scikit-learn
    # releases -- confirm the pinned version before dropping it.
    params = {
        'solver': 'lbfgs',
        'max_iter': 1000,
        'multi_class': 'auto',
        'random_state': 8888,
    }

    lr = LogisticRegression(**params)
    lr.fit(X_train, y_train)
    accuracy = accuracy_score(y_test, lr.predict(X_test))

    mlflow.set_tracking_uri(settings.MLFLOW_ENDPOINT)
    mlflow.set_experiment(experiment_name)
    logger.info('MLflow tracking URI: %s', mlflow.get_tracking_uri())

    # Close any run still active in this process; start_run() below would
    # otherwise refuse to open a second, non-nested run.
    mlflow.end_run()
    with mlflow.start_run():
        # get_artifact_uri() implicitly starts a run when none is active, so
        # it is only queried here, inside the run that will hold the results.
        logger.info('MLflow artifact URI: %s', mlflow.get_artifact_uri())

        mlflow.log_params(params)
        mlflow.log_metric('accuracy', accuracy)
        mlflow.set_tag('Training Info', 'Basic LR model for iris data')

        signature = infer_signature(X_train, lr.predict(X_train))
        mlflow.sklearn.log_model(
            sk_model=lr,
            artifact_path='iris_model',
            signature=signature,
            # A handful of rows is enough to illustrate the input schema;
            # storing the whole training set only bloats the artifact.
            input_example=X_train[:5],
            registered_model_name='tracking-quickstart',
        )
295+
296+
297+
@shared_task
def example_train(experiment_name: str) -> None:
    """Celery entry point that runs ``train_body`` for ``experiment_name``."""
    train_body(experiment_name)

0 commit comments

Comments
 (0)