Skip to content

Commit c596612

Browse files
train.py fix attempt #8
1 parent 6dcd244 commit c596612

1 file changed

Lines changed: 34 additions & 5 deletions

File tree

train.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,27 @@
55
"""
66

77
import os
8+
import sys
89

910
# ВАЖНО: Установка переменных окружения ДО импорта MLflow
10-
# Используем текущую рабочую директорию (работает и локально, и в CI/CD)
11-
WORK_DIR = os.getcwd()
11+
# Используем абсолютный путь к текущей рабочей директории
12+
WORK_DIR = os.path.abspath(os.getcwd())
1213
MLRUNS_PATH = os.path.join(WORK_DIR, 'mlruns')
13-
os.environ['MLFLOW_TRACKING_URI'] = f'file:{MLRUNS_PATH}'
14+
15+
# Создаем директорию mlruns заранее, чтобы MLflow не пытался создавать её в другом месте
16+
os.makedirs(MLRUNS_PATH, exist_ok=True)
17+
18+
# Устанавливаем переменные окружения
19+
os.environ['MLFLOW_TRACKING_URI'] = f'file://{MLRUNS_PATH}'
1420
os.environ['MLFLOW_ARTIFACT_ROOT'] = MLRUNS_PATH
21+
os.environ['MLFLOW_REGISTRY_URI'] = f'file://{MLRUNS_PATH}'
22+
23+
# В CI/CD окружении дополнительно переопределяем HOME
24+
if os.getenv('CI') == 'true' or os.getenv('GITHUB_ACTIONS') == 'true':
25+
os.environ['HOME'] = WORK_DIR
26+
print(f"CI/CD обнаружен, HOME установлен на: {WORK_DIR}")
27+
28+
print(f"MLflow будет использовать: {MLRUNS_PATH}")
1529

1630
import numpy as np
1731
import pandas as pd
@@ -53,10 +67,15 @@ def load_and_prepare_data():
5367
def train_model(X_train, X_test, y_train, y_test, n_estimators=100, max_depth=5, random_state=42):
5468
"""Обучение модели Random Forest с логированием в MLflow"""
5569

70+
# Явно устанавливаем tracking URI перед каждым экспериментом
71+
MLRUNS_PATH = os.path.join(os.path.abspath(os.getcwd()), 'mlruns')
72+
mlflow.set_tracking_uri(f'file://{MLRUNS_PATH}')
73+
5674
# Установка имени эксперимента
57-
# (tracking URI уже настроен через переменные окружения в начале файла)
5875
mlflow.set_experiment("iris_classification")
5976

77+
print(f"MLflow tracking URI: {mlflow.get_tracking_uri()}")
78+
6079
with mlflow.start_run():
6180
print("\nНачало обучения модели...")
6281

@@ -116,7 +135,17 @@ def train_model(X_train, X_test, y_train, y_test, n_estimators=100, max_depth=5,
116135
try:
117136
mlflow.sklearn.log_model(model, "model")
118137
mlflow.log_artifact(model_path)
119-
print("\nМодель успешно залогирована в MLflow")
138+
139+
# Получаем информацию о текущем эксперименте
140+
run_id = mlflow.active_run().info.run_id
141+
experiment_id = mlflow.active_run().info.experiment_id
142+
mlflow_uri = mlflow.get_tracking_uri()
143+
144+
print(f"\nМодель успешно залогирована в MLflow")
145+
print(f"Tracking URI: {mlflow_uri}")
146+
print(f"Experiment ID: {experiment_id}")
147+
print(f"Run ID: {run_id}")
148+
print(f"MLFlow артефакты сохранены в: {MLRUNS_PATH}/{experiment_id}/{run_id}/artifacts/")
120149
except Exception as e:
121150
print(f"\nПредупреждение: не удалось залогировать модель в MLflow: {e}")
122151
print("Модель сохранена локально, продолжаем работу...")

0 commit comments

Comments
 (0)