@@ -845,6 +845,7 @@ def __init__(
845845 max_edges_per_node : int = 256 ,
846846 max_edges_per_graph : int = 4096 ,
847847 min_pin_nodes : int | None = None ,
848+ required_input_count : int | None = None ,
848849 ):
849850 super ().__init__ ()
850851 self .shared_attr_vocab = shared_attr_vocab
@@ -885,6 +886,13 @@ def __init__(
885886 self .min_pin_nodes = max (0 , int (min_pin_nodes ))
886887 except (TypeError , ValueError ):
887888 self .min_pin_nodes = 0
889+ if required_input_count is None :
890+ self .required_input_count = 0
891+ else :
892+ try :
893+ self .required_input_count = max (0 , int (required_input_count ))
894+ except (TypeError , ValueError ):
895+ self .required_input_count = 0
888896 # Encourage attribute decoder to terminate by progressively biasing the EOS logit.
889897 self .attr_eos_bias_base = 0.0
890898 self .attr_eos_bias_slope = 0.1
@@ -1022,6 +1030,8 @@ def make_name_input(token_idx: int) -> torch.Tensor:
10221030 node_loop_timer = time .perf_counter ()
10231031 node_loop_iters = 0
10241032 teacher_force_nodes = self .training and (target_node_count is not None )
1033+ forced_pin_roles : List [str | None ] = []
1034+ forced_pin_slots : List [int | None ] = []
10251035 with torch .autograd .profiler .record_function ("GraphDecoder.node_loop" ):
10261036 while True :
10271037 logger .debug (f"Decoding node { t } " )
@@ -1072,6 +1082,13 @@ def make_name_input(token_idx: int) -> torch.Tensor:
10721082 node_types .append (self .type_head (new_node ).argmax (dim = - 1 ).cpu ().tolist ())
10731083
10741084 node_index = t
1085+ forced_role = None
1086+ forced_slot = None
1087+ if self .required_input_count and node_index < self .required_input_count :
1088+ forced_role = PIN_ROLE_INPUT
1089+ forced_slot = node_index
1090+ forced_pin_roles .append (forced_role )
1091+ forced_pin_slots .append (forced_slot )
10751092 hidden_edge = hidden_node
10761093 edge_in = torch .zeros (1 , 1 , device = device )
10771094 node_edge_budget = 0
@@ -1163,6 +1180,12 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11631180 pin_role_vec = None
11641181 decoded_role = None
11651182 target_pin_role = None
1183+ forced_role = None
1184+ forced_slot = None
1185+ if node_idx < len (forced_pin_roles ):
1186+ forced_role = forced_pin_roles [node_idx ]
1187+ if node_idx < len (forced_pin_slots ):
1188+ forced_slot = forced_pin_slots [node_idx ]
11661189 if self .pin_role_head is not None :
11671190 pin_role_vec = self .pin_role_head (embedding )
11681191 decoded_role = self ._infer_pin_role (pin_role_vec , pin_role_reference )
@@ -1176,8 +1199,14 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11761199 pin_role_teacher_tokens += 1
11771200 if decoded_role is None :
11781201 decoded_role = target_pin_role
1202+ locked_role = False
1203+ if forced_role is not None :
1204+ decoded_role = forced_role
1205+ locked_role = True
11791206 if decoded_role is not None :
11801207 attrs ["pin_role" ] = decoded_role
1208+ if locked_role :
1209+ attrs ["_pin_role_locked" ] = True
11811210 slot_scalar = self .pin_slot_head (embedding ).squeeze ()
11821211 slot_scalar = _sanitize_pin_slot_scalar (slot_scalar )
11831212 slot_target = None
@@ -1189,13 +1218,19 @@ def make_name_input(token_idx: int) -> torch.Tensor:
11891218 pin_slot_teacher_loss = pin_slot_teacher_loss + slot_loss
11901219 pin_slot_teacher_tokens += 1
11911220 slot_attr = None
1192- if slot_scalar is not None :
1221+ locked_slot = False
1222+ if forced_slot is not None :
1223+ slot_attr = forced_slot
1224+ locked_slot = True
1225+ elif slot_scalar is not None :
11931226 slot_value = float (slot_scalar .detach ().item ())
11941227 slot_attr = max (0 , int (round (abs (slot_value )) - 1 ))
11951228 elif slot_target is not None :
11961229 slot_attr = max (0 , int (round (abs (float (slot_target ))) - 1 ))
11971230 if slot_attr is not None :
11981231 attrs ["pin_slot_index" ] = slot_attr
1232+ if locked_slot :
1233+ attrs ["_pin_slot_locked" ] = True
11991234 name_hidden = embedding .unsqueeze (0 ).unsqueeze (0 )
12001235 val_hidden = None
12011236 t = 0
0 commit comments