|
| 1 | +import sys |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | +from scipy.stats import ks_2samp |
| 6 | + |
| 7 | +from network_security.constant.training_pipeline import SCHEMA_FILE_PATH |
| 8 | +from network_security.entity.artifact_entity import ( |
| 9 | + DataIngestionArtifact, |
| 10 | + DataValidationArtifact, |
| 11 | +) |
| 12 | +from network_security.entity.config_entity import DataValidationConfig |
| 13 | +from network_security.exception.exception import NetworkSecurityException |
| 14 | +from network_security.logging.logger import logging |
| 15 | +from network_security.utils.main_utils.utils import read_yaml_file, write_yaml_file |
| 16 | + |
| 17 | + |
class DataValidation:
    """Validate ingested train/test CSV files and report distribution drift.

    The class reads the schema from ``SCHEMA_FILE_PATH``, checks the train and
    test frames against it (column count, presence and dtype of the declared
    numerical columns), runs a per-column two-sample Kolmogorov-Smirnov test
    between train and test, writes the drift report as YAML, and copies both
    frames to the validated-data locations named in the config.
    """

    def __init__(
        self,
        data_ingestion_artifact: DataIngestionArtifact,
        data_validation_config: DataValidationConfig,
    ) -> None:
        """Store the pipeline inputs and load the schema file.

        Args:
            data_ingestion_artifact: Provides ``trained_file_path`` and
                ``test_file_path`` of the ingested CSVs.
            data_validation_config: Provides ``valid_train_file_path``,
                ``valid_test_file_path`` and ``drift_report_file_path``.

        Raises:
            NetworkSecurityException: If the schema cannot be read.
        """
        try:
            self.data_ingestion_artifact = data_ingestion_artifact
            self.data_validation_config = data_validation_config
            self._schema_config = read_yaml_file(SCHEMA_FILE_PATH)
            # `or []` also covers the key being present with a None value,
            # which `.get(key, [])` alone would pass through unchanged.
            self._numerical_columns = self._schema_config.get("numerical_columns") or []
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    @staticmethod
    def read_data(file_path: str) -> pd.DataFrame:
        """Load a CSV file into a DataFrame.

        Raises:
            NetworkSecurityException: If ``pd.read_csv`` fails.
        """
        try:
            return pd.read_csv(file_path)
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    def validate_number_of_columns(self, dataframe: pd.DataFrame) -> bool:
        """Return True when the frame has as many columns as the schema lists.

        Only the count is compared (``len(schema["columns"])`` against
        ``len(dataframe.columns)``); column names are not inspected here.

        Raises:
            NetworkSecurityException: If the schema has no ``"columns"`` entry
                or any other error occurs.
        """
        try:
            number_of_columns = len(self._schema_config["columns"])
            logging.info(f"Required number of columns:{number_of_columns}")
            logging.info(f"Data frame has columns:{len(dataframe.columns)}")
            return len(dataframe.columns) == number_of_columns
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    def validate_numerical_columns_exist(self, dataframe: pd.DataFrame) -> bool:
        """Check that every schema-declared numerical column exists and is numeric.

        Missing columns and present-but-non-numeric columns are logged
        separately so the cause of a failure is visible in the log.

        Returns:
            bool: True if all required numerical columns exist and have a
            numeric dtype, False otherwise.

        Raises:
            NetworkSecurityException: On any unexpected error.
        """
        try:
            missing_columns = [
                column
                for column in self._numerical_columns
                if column not in dataframe.columns
            ]
            non_numeric_columns = [
                column
                for column in self._numerical_columns
                if column in dataframe.columns
                and not pd.api.types.is_numeric_dtype(dataframe[column])
            ]

            if missing_columns:
                logging.info(f"Missing numerical columns: {missing_columns}")
            if non_numeric_columns:
                logging.info(f"Columns not of numeric type: {non_numeric_columns}")

            return not missing_columns and not non_numeric_columns
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    @staticmethod
    def _compute_drift_report(
        base_df: pd.DataFrame,
        current_df: pd.DataFrame,
        threshold: float,
    ) -> tuple:
        """Run a KS test per column of ``base_df``; return ``(report, no_drift)``.

        ``report`` maps each column name to its ``p_value`` and ``drift_status``;
        ``no_drift`` is True only when no column was flagged.
        """
        report = {}
        no_drift = True
        for column in base_df.columns:
            p_value = float(ks_2samp(base_df[column], current_df[column]).pvalue)
            # Written as a negated `<=` (not `p_value < threshold`) so that a
            # NaN p-value is flagged as drift rather than silently passing.
            drift_found = not (threshold <= p_value)
            if drift_found:
                no_drift = False
            report[column] = {"p_value": p_value, "drift_status": drift_found}
        return report, no_drift

    def detect_dataset_drift(
        self,
        base_df: pd.DataFrame,
        current_df: pd.DataFrame,
        threshold: float = 0.05,
    ) -> bool:
        """Compare column distributions of two frames and write a drift report.

        Every column of ``base_df`` is compared with the same-named column of
        ``current_df`` using ``scipy.stats.ks_2samp``. A column is marked as
        drifted when its p-value is below ``threshold``. The report is written
        once, as YAML, to ``data_validation_config.drift_report_file_path``
        (parent directories are created as needed).

        Args:
            base_df: Reference data (the training split in this pipeline).
            current_df: Data to compare against the reference.
            threshold: Significance level for the KS test.

        Returns:
            bool: True when NO column drifted (i.e. the check passed), False
            when at least one column drifted. This polarity matches how the
            result is used as a validation status by
            ``initiate_data_validation``.

        Raises:
            NetworkSecurityException: If a column of ``base_df`` is absent from
                ``current_df``, the KS test fails, or the report cannot be
                written.
        """
        try:
            report, no_drift = self._compute_drift_report(base_df, current_df, threshold)

            drift_report_file_path = self.data_validation_config.drift_report_file_path
            Path(drift_report_file_path).parent.mkdir(parents=True, exist_ok=True)
            write_yaml_file(file_path=drift_report_file_path, content=report)

            return no_drift
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    def initiate_data_validation(self) -> DataValidationArtifact:
        """Run all validation steps and return the resulting artifact.

        Steps: read train/test CSVs, check column counts, check numerical
        columns, run drift detection, then write both frames to the validated
        file paths from the config.

        Returns:
            DataValidationArtifact: ``validation_status`` is True only when
            every structural check passed and no drift was detected. The
            ``valid_*_file_path`` fields point at the files written by this
            method; the ``invalid_*`` fields are None.

        Raises:
            NetworkSecurityException: If any step fails.
        """
        try:
            train_file_path = self.data_ingestion_artifact.trained_file_path
            test_file_path = self.data_ingestion_artifact.test_file_path

            # Read the data from train and test
            train_dataframe = DataValidation.read_data(train_file_path)
            test_dataframe = DataValidation.read_data(test_file_path)

            # Validate number of columns; each outcome is kept so a later
            # check cannot mask an earlier failure.
            train_columns_ok = self.validate_number_of_columns(dataframe=train_dataframe)
            if not train_columns_ok:
                logging.info("Train dataframe does not contain all columns.\n")

            test_columns_ok = self.validate_number_of_columns(dataframe=test_dataframe)
            if not test_columns_ok:
                logging.info("Test dataframe does not contain all columns.\n")

            # Validate numerical columns
            train_numeric_ok = self.validate_numerical_columns_exist(train_dataframe)
            if not train_numeric_ok:
                logging.info("Train dataframe is missing required numerical columns or types.\n")

            test_numeric_ok = self.validate_numerical_columns_exist(test_dataframe)
            if not test_numeric_ok:
                logging.info("Test dataframe is missing required numerical columns or types.\n")

            # Check data drift (True means no drift was found)
            no_drift = self.detect_dataset_drift(
                base_df=train_dataframe,
                current_df=test_dataframe,
            )

            validation_status = all(
                (
                    train_columns_ok,
                    test_columns_ok,
                    train_numeric_ok,
                    test_numeric_ok,
                    no_drift,
                )
            )

            valid_train_file_path = self.data_validation_config.valid_train_file_path
            valid_test_file_path = self.data_validation_config.valid_test_file_path

            # The two targets are not guaranteed to share a directory, so
            # make sure both parents exist before writing.
            for target in (valid_train_file_path, valid_test_file_path):
                Path(target).parent.mkdir(parents=True, exist_ok=True)

            train_dataframe.to_csv(valid_train_file_path, index=False, header=True)
            test_dataframe.to_csv(valid_test_file_path, index=False, header=True)

            return DataValidationArtifact(
                validation_status=validation_status,
                # Report the files this step actually wrote, not the
                # ingestion inputs they were copied from.
                valid_train_file_path=valid_train_file_path,
                valid_test_file_path=valid_test_file_path,
                invalid_train_file_path=None,
                invalid_test_file_path=None,
                drift_report_file_path=self.data_validation_config.drift_report_file_path,
            )
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e
0 commit comments