Skip to content

Commit 7009df2

Browse files
authored
Merge pull request #14 from cristofima/dev - refactor: enhance security and improve error handling in training and frontend components
- Replace wildcard CORS origins (*) with specific allowed origins based on Amplify domain and environment
- Improve error handling by replacing bare exception catches with specific exception types and logging
- Add guards against empty datasets and NaN values in training preprocessing
2 parents 1e95bc5 + 48aad75 commit 7009df2

10 files changed

Lines changed: 132 additions & 51 deletions

File tree

.github/SETUP_CICD.md

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,41 @@ cat > github-actions-permissions.json <<EOF
127127
{
128128
"Sid": "APIGatewayManagement",
129129
"Effect": "Allow",
130-
"Action": "apigateway:*",
131-
"Resource": "*"
130+
"Action": [
131+
"apigateway:GET",
132+
"apigateway:POST",
133+
"apigateway:PUT",
134+
"apigateway:PATCH",
135+
"apigateway:DELETE",
136+
"apigateway:UpdateRestApiPolicy"
137+
],
138+
"Resource": [
139+
"arn:aws:apigateway:*::/restapis",
140+
"arn:aws:apigateway:*::/restapis/*"
141+
]
132142
},
133143
{
134144
"Sid": "BatchManagement",
135145
"Effect": "Allow",
136-
"Action": "batch:*",
146+
"Action": [
147+
"batch:CreateComputeEnvironment",
148+
"batch:UpdateComputeEnvironment",
149+
"batch:DeleteComputeEnvironment",
150+
"batch:DescribeComputeEnvironments",
151+
"batch:CreateJobQueue",
152+
"batch:UpdateJobQueue",
153+
"batch:DeleteJobQueue",
154+
"batch:DescribeJobQueues",
155+
"batch:RegisterJobDefinition",
156+
"batch:DeregisterJobDefinition",
157+
"batch:DescribeJobDefinitions",
158+
"batch:SubmitJob",
159+
"batch:DescribeJobs",
160+
"batch:ListJobs",
161+
"batch:TerminateJob",
162+
"batch:TagResource",
163+
"batch:UntagResource"
164+
],
137165
"Resource": "*"
138166
},
139167
{

backend/training/eda.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pandas as pd
22
import numpy as np
3-
from typing import Dict, List, Tuple, Any
3+
from typing import List, Tuple
44
import re
55

66

@@ -54,6 +54,10 @@ def __init__(self, df: pd.DataFrame, target_column: str):
5454

5555
def _detect_problem_type(self) -> str:
5656
"""Detect if classification or regression"""
57+
# Guard against empty target
58+
if len(self.target) == 0:
59+
return 'classification' # Default fallback
60+
5761
if pd.api.types.is_numeric_dtype(self.target):
5862
unique_ratio = self.target.nunique() / len(self.target)
5963
if unique_ratio < 0.05 or self.target.nunique() < 20:
@@ -108,9 +112,10 @@ def _analyze_columns(self):
108112

109113
if self.problem_type == 'classification':
110114
class_counts = self.target.value_counts()
111-
imbalance_ratio = class_counts.max() / class_counts.min()
112-
if imbalance_ratio > 3:
113-
self.warnings.append(f"Class imbalance detected (ratio: {imbalance_ratio:.1f}:1)")
115+
if len(class_counts) > 0 and class_counts.min() > 0:
116+
imbalance_ratio = class_counts.max() / class_counts.min()
117+
if imbalance_ratio > 3:
118+
self.warnings.append(f"Class imbalance detected (ratio: {imbalance_ratio:.1f}:1)")
114119

115120
def _get_css(self) -> str:
116121
"""Return CSS styles"""

backend/training/model_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def get_feature_importance(model: AutoML, feature_names) -> Dict[str, float]:
170170
print(f" {i+1}. {feature}: {importance:.4f}")
171171

172172
return feature_importance
173-
except Exception:
174-
pass
173+
except (AttributeError, TypeError) as e:
174+
print(f"Could not extract feature importances from model: {e}")
175175

176176
# Fallback: Create equal importance for all features
177177
print("\nCould not extract feature importances, using equal weights")

backend/training/preprocessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ def detect_useless_columns(self, df: pd.DataFrame) -> List[str]:
184184

185185
def detect_problem_type(self, y: pd.Series) -> str:
186186
"""Detect if problem is classification or regression"""
187+
# Guard against empty target
188+
if len(y) == 0:
189+
return 'classification' # Default fallback
190+
187191
# Check if target is numeric
188192
if pd.api.types.is_numeric_dtype(y):
189193
# If numeric, check unique values ratio

backend/training/training_report.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Dict, Any
2-
from datetime import datetime
2+
from datetime import datetime, timezone
33

44

55
def generate_training_report(
@@ -382,7 +382,7 @@ def _generate_config_info(self) -> str:
382382

383383
def generate(self) -> str:
384384
"""Generate complete HTML report"""
385-
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC")
385+
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
386386

387387
html = f"""
388388
<!DOCTYPE html>

frontend/app/results/[jobId]/page.tsx

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ export default function ResultsPage() {
3030
setTimeout(() => setCopiedPython(false), 2000);
3131
};
3232

33+
// Generate Docker commands for model prediction (extracted to avoid duplication)
34+
const getDockerCommands = (jobId: string) => {
35+
const modelFile = `model_${jobId.slice(0, 8)}.pkl`;
36+
return `# Build prediction container (one time)
37+
docker build -f scripts/Dockerfile.predict -t automl-predict .
38+
39+
# Show model info and required features
40+
docker run --rm -v \${PWD}:/data automl-predict /data/${modelFile} --info
41+
42+
# Generate sample input JSON (auto-detects features)
43+
docker run --rm -v \${PWD}:/data automl-predict /data/${modelFile} -g /data/sample_input.json
44+
45+
# Edit sample_input.json with your values, then predict
46+
docker run --rm -v \${PWD}:/data automl-predict /data/${modelFile} --json /data/sample_input.json
47+
48+
# Batch predictions from CSV
49+
docker run --rm -v \${PWD}:/data automl-predict /data/${modelFile} -i /data/test.csv -o /data/predictions.csv`;
50+
};
51+
3352
useEffect(() => {
3453
const fetchResults = async () => {
3554
try {
@@ -276,39 +295,10 @@ export default function ResultsPage() {
276295
</div>
277296
<div className="relative">
278297
<pre className="bg-gray-900 text-gray-100 rounded-lg p-4 overflow-x-auto text-sm font-mono">
279-
<code>{`# Build prediction container (one time)
280-
docker build -f scripts/Dockerfile.predict -t automl-predict .
281-
282-
# Show model info and required features
283-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl --info
284-
285-
# Generate sample input JSON (auto-detects features)
286-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl -g /data/sample_input.json
287-
288-
# Edit sample_input.json with your values, then predict
289-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl --json /data/sample_input.json
290-
291-
# Batch predictions from CSV
292-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl -i /data/test.csv -o /data/predictions.csv`}</code>
298+
<code>{getDockerCommands(job.job_id)}</code>
293299
</pre>
294300
<button
295-
onClick={() => {
296-
const code = `# Build prediction container (one time)
297-
docker build -f scripts/Dockerfile.predict -t automl-predict .
298-
299-
# Show model info and required features
300-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl --info
301-
302-
# Generate sample input JSON (auto-detects features)
303-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl -g /data/sample_input.json
304-
305-
# Edit sample_input.json with your values, then predict
306-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl --json /data/sample_input.json
307-
308-
# Batch predictions from CSV
309-
docker run --rm -v \${PWD}:/data automl-predict /data/model_${job.job_id.slice(0, 8)}.pkl -i /data/test.csv -o /data/predictions.csv`;
310-
handleCopyDocker(code);
311-
}}
301+
onClick={() => handleCopyDocker(getDockerCommands(job.job_id))}
312302
className={`absolute top-2 right-2 px-3 py-1 text-xs rounded transition-all cursor-pointer ${
313303
copiedDocker
314304
? 'bg-green-600 text-white'

frontend/next.config.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ const nextConfig: NextConfig = {
66
unoptimized: true,
77
},
88

9-
// Trailing slashes for better compatibility
9+
// Trailing slashes ensure consistent URL handling across:
10+
// - AWS Amplify SSR deployments (prevents 404 on refresh)
11+
// - Static file serving and client-side navigation
1012
trailingSlash: true,
1113
};
1214

infrastructure/terraform/s3.tf

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
# =============================================================================
2+
# CORS Origins - Computed once and validated
3+
# =============================================================================
4+
# The cors_origins local calculates allowed origins based on:
5+
# 1. Manual override via var.cors_allowed_origins (highest priority)
6+
# 2. Amplify domain (if enabled via github_repository + github_token)
7+
# 3. localhost:3000 (only in dev environment)
8+
#
9+
# IMPORTANT: Outside dev (e.g. production), either enable Amplify OR set cors_allowed_origins; otherwise the list is empty and the CORS preconditions below fail
10+
# =============================================================================
11+
locals {
12+
cors_origins = length(var.cors_allowed_origins) > 0 ? var.cors_allowed_origins : concat(
13+
local.amplify_enabled ? ["https://${aws_amplify_app.frontend[0].default_domain}"] : [],
14+
var.environment == "dev" ? ["http://localhost:3000"] : []
15+
)
16+
}
17+
118
# S3 Bucket for Datasets
219
resource "aws_s3_bucket" "datasets" {
320
bucket = "${local.name_prefix}-datasets-${local.account_id}"
@@ -35,9 +52,16 @@ resource "aws_s3_bucket_cors_configuration" "datasets" {
3552
cors_rule {
3653
allowed_headers = ["*"]
3754
allowed_methods = ["PUT", "GET"]
38-
allowed_origins = ["*"]
55+
allowed_origins = local.cors_origins
3956
max_age_seconds = 3600
4057
}
58+
59+
lifecycle {
60+
precondition {
61+
condition = length(local.cors_origins) > 0
62+
error_message = "CORS allowed_origins cannot be empty. Either enable Amplify (set github_repository and github_token), use dev environment, or set cors_allowed_origins manually."
63+
}
64+
}
4165
}
4266

4367
# S3 Bucket for Models
@@ -77,10 +101,17 @@ resource "aws_s3_bucket_cors_configuration" "models" {
77101
cors_rule {
78102
allowed_headers = ["*"]
79103
allowed_methods = ["GET"]
80-
allowed_origins = ["*"]
104+
allowed_origins = local.cors_origins
81105
expose_headers = ["Content-Disposition"]
82106
max_age_seconds = 3600
83107
}
108+
109+
lifecycle {
110+
precondition {
111+
condition = length(local.cors_origins) > 0
112+
error_message = "CORS allowed_origins cannot be empty. Either enable Amplify (set github_repository and github_token), use dev environment, or set cors_allowed_origins manually."
113+
}
114+
}
84115
}
85116

86117
# S3 Bucket for Reports
@@ -120,8 +151,15 @@ resource "aws_s3_bucket_cors_configuration" "reports" {
120151
cors_rule {
121152
allowed_headers = ["*"]
122153
allowed_methods = ["GET"]
123-
allowed_origins = ["*"]
154+
allowed_origins = local.cors_origins
124155
expose_headers = ["Content-Disposition"]
125156
max_age_seconds = 3600
126157
}
158+
159+
lifecycle {
160+
precondition {
161+
condition = length(local.cors_origins) > 0
162+
error_message = "CORS allowed_origins cannot be empty. Either enable Amplify (set github_repository and github_token), use dev environment, or set cors_allowed_origins manually."
163+
}
164+
}
127165
}

infrastructure/terraform/variables.tf

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,11 @@ variable "github_token" {
124124
sensitive = true
125125
default = ""
126126
}
127+
128+
variable "cors_allowed_origins" {
129+
description = "List of allowed origins for S3 CORS configuration. Use specific domains for security."
130+
type = list(string)
131+
default = []
132+
# When empty, falls back to the Amplify domain (if Amplify is enabled), plus http://localhost:3000 in the dev environment
133+
# For production, specify exact frontend URLs
134+
}

scripts/predict.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ def prepare_input(data: pd.DataFrame, preprocessor) -> pd.DataFrame:
176176
numeric_cols = df.select_dtypes(include=[np.number]).columns
177177
for col in numeric_cols:
178178
if df[col].isnull().any():
179-
df[col].fillna(df[col].median(), inplace=True)
179+
median_val = df[col].median()
180+
# Fallback to 0 if median is NaN (empty column or all NaN)
181+
if pd.isna(median_val):
182+
median_val = 0
183+
df[col].fillna(median_val, inplace=True)
180184

181185
categorical_cols = df.select_dtypes(include=['object']).columns
182186
for col in categorical_cols:
@@ -243,8 +247,9 @@ def predict_single(model_package: dict, input_data: dict) -> dict:
243247

244248
result['probabilities'] = {str(label): float(p) for label, p in zip(class_labels, probas)}
245249
result['confidence'] = float(max(probas))
246-
except Exception:
247-
pass
250+
except (AttributeError, ValueError, IndexError) as e:
251+
# Log warning but continue - probabilities are optional
252+
print(f"⚠️ Could not compute class probabilities: {e}")
248253

249254
return result
250255

@@ -283,8 +288,9 @@ def predict_batch(model_package: dict, input_path: str, output_path: str) -> Non
283288
try:
284289
probas = model.predict_proba(X)
285290
df['confidence'] = probas.max(axis=1)
286-
except Exception:
287-
pass
291+
except (AttributeError, ValueError) as e:
292+
# Log warning but continue - confidence scores are optional
293+
print(f"⚠️ Could not compute confidence scores: {e}")
288294

289295
# Save results
290296
df.to_csv(output_path, index=False)

0 commit comments

Comments
 (0)