Skip to content

Commit 32c8bde

Browse files
committed
train: remove policy target-mask leakage and harden runtime settings
1 parent 112ea85 commit 32c8bde

3 files changed

Lines changed: 12 additions & 20 deletions

File tree

src/model/system.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,9 @@ def _common_step(
101101
batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
102102
) -> dict[str, torch.Tensor]:
103103
boards, target_pis, target_vs = batch
104-
legal_mask = (target_pis > 0).to(dtype=boards.dtype)
105-
has_legal_support = torch.sum(legal_mask, dim=1, keepdim=True) > 0
106-
if not bool(torch.all(has_legal_support).item()):
107-
# Defensive fallback: keep logits finite even if a malformed target row is all zeros.
108-
legal_mask = torch.where(has_legal_support, legal_mask, torch.ones_like(legal_mask))
109-
110-
pi_logits, v_pred = self(boards, action_mask=legal_mask)
104+
# Training must not see a target-derived action mask, otherwise policy loss can
105+
# become artificially easy (label leakage) when targets are sparse/one-hot.
106+
pi_logits, v_pred = self(boards)
111107

112108
loss_v = functional.mse_loss(v_pred.view(-1), target_vs.view(-1))
113109
log_probs = functional.log_softmax(pi_logits, dim=1)

tests/test_training_step_numerics.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_forward_passes_action_mask_to_inner_model(self) -> None:
102102
self.assertIsInstance(action_mask_obj, torch.Tensor)
103103
self.assertTrue(torch.equal(action_mask_obj, mask))
104104

105-
def test_common_step_uses_policy_support_as_action_mask(self) -> None:
105+
def test_common_step_does_not_pass_target_derived_action_mask(self) -> None:
106106
system = AtaxxZero(
107107
learning_rate=1e-3,
108108
d_model=64,
@@ -121,15 +121,11 @@ def test_common_step_uses_policy_support_as_action_mask(self) -> None:
121121

122122
with patch.object(system.model, "forward", return_value=(pi_logits, v_pred)) as forward_spy:
123123
_ = system._common_step((boards, target_pis, target_vs))
124-
_, kwargs = forward_spy.call_args
124+
args, kwargs = forward_spy.call_args
125125

126-
self.assertIn("action_mask", kwargs)
127-
action_mask = kwargs["action_mask"]
128-
self.assertIsInstance(action_mask, torch.Tensor)
129-
self.assertEqual(action_mask.shape, target_pis.shape)
130-
self.assertEqual(float(action_mask[0, 7].item()), 1.0)
131-
self.assertEqual(float(torch.sum(action_mask[0]).item()), 1.0)
132-
self.assertTrue(torch.all(action_mask[1] == 1.0).item())
126+
self.assertEqual(len(args), 1)
127+
self.assertIsInstance(args[0], torch.Tensor)
128+
self.assertEqual(kwargs.get("action_mask"), None)
133129

134130
def test_common_step_applies_value_loss_coefficient(self) -> None:
135131
system = AtaxxZero(

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
src = root / "src"
1717
if str(src) not in sys.path:
1818
sys.path.insert(0, str(src))
19-
2019
if TYPE_CHECKING:
2120
from data.replay_buffer import ReplayBuffer
2221
from model.system import AtaxxZero
23-
2422
from training.bootstrap import generate_imitation_data # noqa: E402
2523
from training.callbacks import OptimizerStateTransfer # noqa: E402
2624
from training.checkpointing import ( # noqa: E402
@@ -73,7 +71,7 @@ def _build_train_loader(buffer: ReplayBuffer, device: str) -> DataLoader[object]
7371
batch_size=cfg_int("batch_size"),
7472
shuffle=True,
7573
num_workers=cfg_int("num_workers"),
76-
persistent_workers=True,
74+
persistent_workers=cfg_bool("persistent_workers"),
7775
pin_memory=(device == "cuda"),
7876
prefetch_factor=2,
7977
)
@@ -99,7 +97,7 @@ def _build_val_loader(buffer: ReplayBuffer, device: str) -> DataLoader[object] |
9997
batch_size=cfg_int("batch_size"),
10098
shuffle=False,
10199
num_workers=cfg_int("num_workers"),
102-
persistent_workers=True,
100+
persistent_workers=cfg_bool("persistent_workers"),
103101
pin_memory=(device == "cuda"),
104102
prefetch_factor=2,
105103
)
@@ -494,6 +492,8 @@ def main() -> None:
494492
fail_on_error=cfg_bool("fail_on_hf_upload_error"),
495493
)
496494
except Exception as exc:
495+
if cfg_bool("fail_on_hf_upload_error"):
496+
raise
497497
log(f"HF upload wait failed: {exc}")
498498
hf_upload_executor.shutdown(wait=False, cancel_futures=True)
499499
if __name__ == "__main__":

0 commit comments

Comments
 (0)