Skip to content

Commit d7a4974

Browse files
committed
prevent possibility of missing input slots
1 parent 99a2c61 commit d7a4974

4 files changed

Lines changed: 98 additions & 1 deletion

File tree

population.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(self, config):
137137
self.shared_attr_vocab,
138138
graph_encoder.pin_role_embedding,
139139
min_pin_nodes=decoder_min_pin_nodes,
140+
required_input_count=len(configured_inputs),
140141
)
141142
icnn_hidden_dims = getattr(config, "latent_icnn_hidden_dims", (64, 32))
142143
if isinstance(icnn_hidden_dims, str):
@@ -1569,6 +1570,10 @@ def _assign_pin_role(idx: int, role: str) -> bool:
15691570
return False
15701571
attrs = dict(normalized_attrs[idx] or {})
15711572
current = _normalize_pin_role(attrs.get("pin_role"))
1573+
if attrs.get("_pin_role_locked"):
1574+
if current == role:
1575+
return False
1576+
return False
15721577
if current == role:
15731578
return False
15741579
attrs["pin_role"] = role
@@ -1580,6 +1585,10 @@ def _assign_pin_slot(idx: int, slot: int | None) -> bool:
15801585
return False
15811586
attrs = dict(normalized_attrs[idx] or {})
15821587
current = _node_pin_slot(idx)
1588+
if attrs.get("_pin_slot_locked"):
1589+
if current == slot:
1590+
return False
1591+
return False
15831592
if slot is None:
15841593
if "pin_slot_index" in attrs:
15851594
attrs.pop("pin_slot_index", None)

search_space_compression.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,7 @@ def __init__(
845845
max_edges_per_node: int = 256,
846846
max_edges_per_graph: int = 4096,
847847
min_pin_nodes: int | None = None,
848+
required_input_count: int | None = None,
848849
):
849850
super().__init__()
850851
self.shared_attr_vocab = shared_attr_vocab
@@ -885,6 +886,13 @@ def __init__(
885886
self.min_pin_nodes = max(0, int(min_pin_nodes))
886887
except (TypeError, ValueError):
887888
self.min_pin_nodes = 0
889+
if required_input_count is None:
890+
self.required_input_count = 0
891+
else:
892+
try:
893+
self.required_input_count = max(0, int(required_input_count))
894+
except (TypeError, ValueError):
895+
self.required_input_count = 0
888896
# Encourage attribute decoder to terminate by progressively biasing the EOS logit.
889897
self.attr_eos_bias_base = 0.0
890898
self.attr_eos_bias_slope = 0.1
@@ -1022,6 +1030,8 @@ def make_name_input(token_idx: int) -> torch.Tensor:
10221030
node_loop_timer = time.perf_counter()
10231031
node_loop_iters = 0
10241032
teacher_force_nodes = self.training and (target_node_count is not None)
1033+
forced_pin_roles: List[str | None] = []
1034+
forced_pin_slots: List[int | None] = []
10251035
with torch.autograd.profiler.record_function("GraphDecoder.node_loop"):
10261036
while True:
10271037
logger.debug(f"Decoding node {t}")
@@ -1072,6 +1082,13 @@ def make_name_input(token_idx: int) -> torch.Tensor:
10721082
node_types.append(self.type_head(new_node).argmax(dim=-1).cpu().tolist())
10731083

10741084
node_index = t
1085+
forced_role = None
1086+
forced_slot = None
1087+
if self.required_input_count and node_index < self.required_input_count:
1088+
forced_role = PIN_ROLE_INPUT
1089+
forced_slot = node_index
1090+
forced_pin_roles.append(forced_role)
1091+
forced_pin_slots.append(forced_slot)
10751092
hidden_edge = hidden_node
10761093
edge_in = torch.zeros(1, 1, device=device)
10771094
node_edge_budget = 0
@@ -1163,6 +1180,12 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11631180
pin_role_vec = None
11641181
decoded_role = None
11651182
target_pin_role = None
1183+
forced_role = None
1184+
forced_slot = None
1185+
if node_idx < len(forced_pin_roles):
1186+
forced_role = forced_pin_roles[node_idx]
1187+
if node_idx < len(forced_pin_slots):
1188+
forced_slot = forced_pin_slots[node_idx]
11661189
if self.pin_role_head is not None:
11671190
pin_role_vec = self.pin_role_head(embedding)
11681191
decoded_role = self._infer_pin_role(pin_role_vec, pin_role_reference)
@@ -1176,8 +1199,14 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11761199
pin_role_teacher_tokens += 1
11771200
if decoded_role is None:
11781201
decoded_role = target_pin_role
1202+
locked_role = False
1203+
if forced_role is not None:
1204+
decoded_role = forced_role
1205+
locked_role = True
11791206
if decoded_role is not None:
11801207
attrs["pin_role"] = decoded_role
1208+
if locked_role:
1209+
attrs["_pin_role_locked"] = True
11811210
slot_scalar = self.pin_slot_head(embedding).squeeze()
11821211
slot_scalar = _sanitize_pin_slot_scalar(slot_scalar)
11831212
slot_target = None
@@ -1189,13 +1218,19 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11891218
pin_slot_teacher_loss = pin_slot_teacher_loss + slot_loss
11901219
pin_slot_teacher_tokens += 1
11911220
slot_attr = None
1192-
if slot_scalar is not None:
1221+
locked_slot = False
1222+
if forced_slot is not None:
1223+
slot_attr = forced_slot
1224+
locked_slot = True
1225+
elif slot_scalar is not None:
11931226
slot_value = float(slot_scalar.detach().item())
11941227
slot_attr = max(0, int(round(abs(slot_value)) - 1))
11951228
elif slot_target is not None:
11961229
slot_attr = max(0, int(round(abs(float(slot_target))) - 1))
11971230
if slot_attr is not None:
11981231
attrs["pin_slot_index"] = slot_attr
1232+
if locked_slot:
1233+
attrs["_pin_slot_locked"] = True
11991234
name_hidden = embedding.unsqueeze(0).unsqueeze(0)
12001235
val_hidden = None
12011236
t = 0

tests/test_population.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,33 @@ def reaches_outputs():
744744
assert idx in backward
745745

746746

747+
def test_repair_preserves_seeded_input_pins():
    """Locked input-pin attributes must survive graph repair untouched."""
    config = make_config()
    population = GuidedPopulation(config)
    input_count = len(config.genome_config.input_keys)
    if input_count < 2:
        pytest.skip("insufficient configured inputs for this test")

    # Seed one locked input-pin attribute dict per configured input slot,
    # plus a single unconstrained trailing node.
    seeded = [
        {
            "pin_role": "input",
            "pin_slot_index": idx,
            "_pin_role_locked": True,
            "_pin_slot_locked": True,
        }
        for idx in range(input_count)
    ]
    seeded.append({})
    graph = _make_empty_graph_dict(len(seeded), seeded)

    population._repair_graph_dict(graph)

    repaired = graph["node_attributes"]
    for idx in range(input_count):
        assert repaired[idx]["pin_role"] == "input"
        assert repaired[idx]["pin_slot_index"] == idx
772+
773+
747774
def test_repair_preserves_predicted_edges_for_visualization():
748775
config = make_config()
749776
required_inputs = len(config.genome_config.input_keys)

tests/test_search_space_compression.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,29 @@ def test_graph_decoder_enforces_min_pin_nodes():
231231
graphs = decoder(latent)
232232
graph = graphs[0]
233233
assert len(graph["node_attributes"]) >= 3
234+
235+
236+
def test_graph_decoder_seeds_input_pin_roles():
    """The first ``required_input_count`` decoded nodes must carry input pin roles/slots."""
    decoder = GraphDecoder(
        num_node_types=3,
        latent_dim=8,
        shared_attr_vocab=SharedAttributeVocab([], embedding_dim=4),
        hidden_dim=4,
        min_pin_nodes=4,
        required_input_count=2,
    )
    decoder.eval()

    # Zero the stop-head weights and give it a large positive bias so the
    # stop signal dominates regardless of the latent input.
    with torch.no_grad():
        decoder.stop_head.weight.zero_()
        decoder.stop_head.bias.fill_(20.0)

    decoded = decoder(torch.zeros(1, decoder.latent_dim))
    node_attrs = decoded[0]["node_attributes"]
    for slot in (0, 1):
        assert node_attrs[slot]["pin_role"] == "input"
        assert node_attrs[slot]["pin_slot_index"] == slot
assert attrs[1]["pin_slot_index"] == 1

0 commit comments

Comments (0)