-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexperiment_utils.py
More file actions
113 lines (88 loc) · 3.87 KB
/
experiment_utils.py
File metadata and controls
113 lines (88 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import pandas as pd
import torch
import yaml
import copy
from typing import Any, Dict, List, Optional
import logging
from data.data_object import DataObject
from model.model_object import ModelObject
logger = logging.getLogger(__name__)
def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML file and return its parsed contents.

    Args:
        path: Filesystem path to the YAML file.

    Returns:
        The parsed document — typically a dict for config files, though
        `yaml.safe_load` returns None for an empty file.
    """
    # Explicit encoding: config files must parse identically on every host;
    # the default for open() is locale-dependent (see PEP 597).
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge `overrides` into `base`.
    Values in `overrides` take precedence. Returns a new dict.
    Example:
        base = {"a": 1, "b": {"c": 2, "d": 3}}
        overrides = {"b": {"c": 99}, "e": 5}
        result = {"a": 1, "b": {"c": 99, "d": 3}, "e": 5}
    """
    # Deep-copy up front so neither input dict is ever aliased by the result.
    merged = copy.deepcopy(base)
    for override_key, override_value in overrides.items():
        existing = merged.get(override_key)
        if isinstance(existing, dict) and isinstance(override_value, dict):
            # Both sides are mappings: recurse instead of replacing wholesale.
            merged[override_key] = deep_merge(existing, override_value)
        else:
            # Scalar, list, or type mismatch: the override wins outright.
            merged[override_key] = copy.deepcopy(override_value)
    return merged
def resolve_layer_config(base_config_path: str, overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Load a layer's base YAML config file and apply any experiment-level overrides.
    Args:
        base_config_path: Path to the layer's own config YAML.
        overrides: Dictionary of keys to override from the experiment config.
    Returns:
        The merged configuration dictionary.
    """
    layer_config = load_yaml(base_config_path)
    # None or an empty dict both mean "nothing to override" — return as-is.
    if not overrides:
        return layer_config
    merged = deep_merge(layer_config, overrides)
    logger.info(f"Applied {len(overrides)} override(s) to {base_config_path}")
    return merged
def reconstruct_encoding_constraints(instance: torch.Tensor, cat_features_indices: List[List[int]]) -> torch.Tensor:
    """
    For a given instance, ensure that the one-hot encoded categorical features
    are valid.
    Args:
        instance: The input instance as a torch tensor.
        cat_features_indices: A list of lists indicating the indices of the one-hot encoded categorical features.
    Returns:
        The instance with reconstructed encoding constraints.
    """
    # Work on a copy so the caller's tensor is never mutated.
    fixed = instance.clone()
    for group in cat_features_indices:
        if len(group) == 1:
            # Single-column binary feature: snap the relaxed value to 0/1.
            col = group[0]
            fixed[:, col] = torch.round(fixed[:, col])
        elif len(group) > 1:
            # Multi-column group: keep only the largest entry per row as the
            # active category, zeroing the rest (valid one-hot encoding).
            winners = torch.argmax(fixed[:, group], dim=1)
            one_hot = torch.zeros_like(fixed[:, group])
            one_hot[torch.arange(fixed.size(0)), winners] = 1.0
            fixed[:, group] = one_hot
    return fixed
def setup_logging(name: str):
    """Configure root logging from a level name (e.g. "debug", "INFO").

    Unrecognized names silently fall back to INFO.
    """
    resolved_level = getattr(logging, name.upper(), logging.INFO)
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        level=resolved_level,
    )
def select_factuals(model: "ModelObject", X_test: pd.DataFrame, config) -> pd.DataFrame:
    """
    Select factual instances — rows the model predicts as the negative class —
    from the test set.

    Args:
        model: Fitted model exposing `predict`, row-aligned with X_test.
        X_test: Candidate instances.
        config: Mapping with optional keys:
            "num_factuals" (int, default 5): sample size for "negative_class".
            "factual_selection" (str, default "negative_class"):
                "negative_class" -> random sample of negatively predicted rows;
                "all"            -> every negatively predicted row.

    Returns:
        DataFrame of selected factuals (prediction column removed).

    Raises:
        ValueError: If `factual_selection` names an unknown strategy.
    """
    num_factuals = config.get("num_factuals", 5)
    factual_selection = config.get("factual_selection", "negative_class")

    # Validate the strategy before spending any work on prediction.
    if factual_selection not in ("negative_class", "all"):
        raise ValueError(f"Unknown factual selection method {factual_selection}")

    # Both strategies start from the model's predicted negative class.
    df = X_test.copy()
    df["y"] = model.predict(X_test)
    df = df[df["y"] == 0].drop(columns=["y"])

    if factual_selection == "negative_class":
        # Cap the sample size: DataFrame.sample raises if n exceeds the number
        # of available rows, so fewer-than-requested negatives must not crash.
        # Fixed seed keeps the experiment selection reproducible.
        df = df.sample(n=min(num_factuals, len(df)), random_state=42)
    return df