-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy pathtabular_train.py
More file actions
124 lines (99 loc) · 4.09 KB
/
tabular_train.py
File metadata and controls
124 lines (99 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import os
import shutil
from pprint import pprint
import yaml
from autogluon.tabular import TabularDataset, TabularPredictor
def get_input_path(path):
    """Return the path of the input file to use from a channel directory.

    Args:
        path: Directory to look in (e.g. a SageMaker input channel directory).

    Returns:
        The string ``f"{path}/{name}"`` for the chosen directory entry. When
        the directory holds several entries, the alphabetically first one is
        used and a warning is printed.

    Raises:
        FileNotFoundError: If the directory contains no entries (or, via
            ``os.listdir``, if it does not exist).
    """
    # List once and sort: os.listdir returns entries in arbitrary order, so
    # sorting makes the chosen file deterministic when there are several.
    entries = sorted(os.listdir(path))
    if not entries:
        raise FileNotFoundError(f"No input file found in {path} directory")

    file = entries[0]
    if len(entries) > 1:
        # The original message referenced an undefined name (`channel`), which
        # raised NameError exactly when this warning was supposed to be shown.
        print(f"WARN: more than one file is found in {path} directory")
    print(f"Using {file}")
    filename = f"{path}/{file}"
    return filename
def get_env_if_present(name):
    """Look up environment variable ``name``; return ``None`` when it is unset."""
    return os.environ.get(name)
if __name__ == "__main__":
    # Script entry point: parse directories from CLI flags (falling back to
    # SM_* environment variables), load the AutoGluon YAML config, train a
    # TabularPredictor, optionally score a test set, and optionally move a
    # serving script into the model directory.

    # Disable Autotune
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

    # ------------------------------------------------------------ Args parsing
    print("Starting AG")
    parser = argparse.ArgumentParser()

    # Data, model, and output directories
    parser.add_argument(
        "--output-data-dir", type=str, default=get_env_if_present("SM_OUTPUT_DATA_DIR")
    )
    parser.add_argument(
        "--model-dir", type=str, default=get_env_if_present("SM_MODEL_DIR")
    )
    parser.add_argument("--n_gpus", type=str, default=get_env_if_present("SM_NUM_GPUS"))
    parser.add_argument(
        "--train_dir", type=str, default=get_env_if_present("SM_CHANNEL_TRAIN")
    )
    parser.add_argument(
        "--test_dir",
        type=str,
        required=False,
        default=get_env_if_present("SM_CHANNEL_TEST"),
    )
    parser.add_argument(
        "--ag_config", type=str, default=get_env_if_present("SM_CHANNEL_CONFIG")
    )
    parser.add_argument(
        "--serving_script", type=str, default=get_env_if_present("SM_CHANNEL_SERVING")
    )

    # Unknown flags are tolerated rather than rejected.
    args, _ = parser.parse_known_args()

    print(f"Args: {args}")

    # Fail early with a readable message when a mandatory location is missing.
    # Without this check a None value surfaces later as a TypeError, or makes
    # os.listdir(None) silently list the current working directory.
    required_inputs = {
        "--output-data-dir": args.output_data_dir,
        "--model-dir": args.model_dir,
        "--train_dir": args.train_dir,
        "--ag_config": args.ag_config,
    }
    missing_inputs = [flag for flag, value in required_inputs.items() if not value]
    if missing_inputs:
        parser.error(
            "missing required input(s): "
            + ", ".join(missing_inputs)
            + " (pass the flag or set the matching SM_* environment variable)"
        )

    # See SageMaker-specific environment variables: https://sagemaker.readthedocs.io/en/stable/overview.html#prepare-a-training-script
    os.makedirs(args.output_data_dir, mode=0o777, exist_ok=True)

    config_file = get_input_path(args.ag_config)
    with open(config_file) as f:
        config = yaml.safe_load(f)  # AutoGluon-specific config

    if args.n_gpus:
        # NOTE(review): this value is only stored in (and printed with) the
        # config here; it is not forwarded to fit() unless the YAML's
        # ag_fit_args also carries it -- confirm that is intended.
        config["num_gpus"] = int(args.n_gpus)

    print("Running training job with the config:")
    pprint(config)

    # ---------------------------------------------------------------- Training
    train_file = get_input_path(args.train_dir)
    train_data = TabularDataset(train_file)

    save_path = os.path.normpath(args.model_dir)

    # Predictor constructor kwargs come from the config; the output path is
    # always overridden so the model lands in the model directory.
    ag_predictor_args = config["ag_predictor_args"]
    ag_predictor_args["path"] = save_path
    ag_fit_args = config["ag_fit_args"]

    predictor = TabularPredictor(**ag_predictor_args).fit(train_data, **ag_fit_args)

    # --------------------------------------------------------------- Inference
    if args.test_dir:
        test_file = get_input_path(args.test_dir)
        test_data = TabularDataset(test_file)

        # Predictions
        # NOTE(review): predict_proba is called unconditionally -- confirm it
        # is appropriate for every problem type this script is used with.
        y_pred_proba = predictor.predict_proba(test_data)
        if config.get("output_prediction_format", "csv") == "parquet":
            y_pred_proba.to_parquet(f"{args.output_data_dir}/predictions.parquet")
        else:
            y_pred_proba.to_csv(f"{args.output_data_dir}/predictions.csv")

        # Leaderboard
        if config.get("leaderboard", False):
            lb = predictor.leaderboard(test_data, silent=False)
            lb.to_csv(f"{args.output_data_dir}/leaderboard.csv")

        # Feature importance
        if config.get("feature_importance", False):
            feature_importance = predictor.feature_importance(test_data)
            feature_importance.to_csv(f"{args.output_data_dir}/feature_importance.csv")
    else:
        # No test channel: the leaderboard is produced without test data.
        if config.get("leaderboard", False):
            lb = predictor.leaderboard(silent=False)
            lb.to_csv(f"{args.output_data_dir}/leaderboard.csv")

    if args.serving_script:
        print("Saving serving script")
        serving_script_saving_path = os.path.join(save_path, "code")
        # exist_ok: a pre-existing "code" directory must not abort the job
        # after training has already finished (os.mkdir would raise here).
        os.makedirs(serving_script_saving_path, exist_ok=True)
        serving_script_path = get_input_path(args.serving_script)
        shutil.move(
            serving_script_path,
            os.path.join(
                serving_script_saving_path, os.path.basename(serving_script_path)
            ),
        )