Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/cml.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: mlops-linear-regression
on: [push]
jobs:
run:
runs-on: [ubuntu-latest]
container: docker://dvcorg/cml-py3:latest
steps:
- uses: actions/checkout@v2
- name: 'Train my linear model'
env:
repo_token: ${{ secrets.GITHUB_TOKEN }}
run: |
# Your ML workflow goes here
python -m pip install --upgrade pip
pip install -r requirements.txt
python tensorflow_linear_regression_model.py

echo "## Model Metrics" > report.md
cat metrics.txt >> report.md

echo "\n## Model Performance" >> report.md
echo "Model performance metrics are on the plot below." >> report.md

cml-publish model_results.png --md >> report.md

cml-send-comment report.md
88 changes: 0 additions & 88 deletions main.py

This file was deleted.

6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mlflow
cloudpickle==2.2.0
scikit-learn==1.0.2
tensorflow
numpy
matplotlib
102 changes: 102 additions & 0 deletions tensorflow_linear_regression_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Import modules and packages
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# Functions and procedures
def plot_predictions(train_data, train_labels, test_data, test_labels, predictions):
"""
Plots training data, test data and compares predictions.
"""
plt.figure(figsize=(6, 5))
# Plot training data in blue
plt.scatter(train_data, train_labels, c="b", label="Training data")
# Plot test data in green
plt.scatter(test_data, test_labels, c="g", label="Testing data")
# Plot the predictions in red (predictions were made on the test data)
plt.scatter(test_data, predictions, c="r", label="Predictions")
# Show the legend
plt.legend(shadow='True')
# Set grids
plt.grid(which='major', c='#cccccc', linestyle='--', alpha=0.5)
# Some text
plt.title('Model Results', family='Arial', fontsize=14)
plt.xlabel('X axis values', family='Arial', fontsize=11)
plt.ylabel('Y axis values', family='Arial', fontsize=11)
# Show
plt.savefig('model_results.png', dpi=120)



def mae(y_test, y_pred):
"""
Calculuates mean absolute error between y_test and y_preds.
"""
return tf.metrics.mean_absolute_error(y_test, y_pred)


def mse(y_test, y_pred):
"""
Calculates mean squared error between y_test and y_preds.
"""
return tf.metrics.mean_squared_error(y_test, y_pred)


# Check Tensorflow version
print(tf.__version__)


# Create features
X = np.arange(-100, 100, 4)

# Create labels
y = np.arange(-90, 110, 4)


# Split data into train and test sets
X_train = X[:40] # first 40 examples (80% of data)
y_train = y[:40]

X_test = X[40:] # last 10 examples (20% of data)
y_test = y[40:]


# Take a single example of X
input_shape = X[0].shape

# Take a single example of y
output_shape = y[0].shape


# Set random seed
tf.random.set_seed(42)

# Create a model using the Sequential API
model = tf.keras.Sequential([
tf.keras.layers.Dense(1),
tf.keras.layers.Dense(1)
])

# Compile the model
model.compile(loss = tf.keras.losses.mae,
optimizer = tf.keras.optimizers.SGD(),
metrics = ['mae'])

# Fit the model
model.fit(X_train, y_train, epochs=100)


# Make and plot predictions for model_1
y_preds = model.predict(X_test)
plot_predictions(train_data=X_train, train_labels=y_train, test_data=X_test, test_labels=y_test, predictions=y_preds)


# Calculate model_1 metrics
mae_1 = np.round(float(mae(y_test, y_preds.squeeze()).numpy()), 2)
mse_1 = np.round(float(mse(y_test, y_preds.squeeze()).numpy()), 2)
print(f'\nMean Absolute Error = {mae_1}, Mean Squared Error = {mse_1}.')

# Write metrics to file
with open('metrics.txt', 'w') as outfile:
outfile.write(f'\nMean Absolute Error = {mae_1}, Mean Squared Error = {mse_1}.')