|
58 | 58 | ) |
59 | 59 | from physicsnemo.models.vfgn.graph_network_modules import VFGNLearnedSimulator |
60 | 60 |
|
61 | | -physical_devices = tf.config.list_physical_devices("GPU") |
62 | | -try: |
63 | | - for device_ in physical_devices: |
64 | | - tf.config.experimental.set_memory_growth(device_, True) |
65 | | -except: |
66 | | - # Invalid device or cannot modify virtual devices once initialized. |
67 | | - pass |
68 | | - |
69 | 61 |
|
70 | 62 | def Train(rank_zero_logger, dist, cfg: DictConfig): |
71 | 63 | """ |
@@ -131,18 +123,13 @@ def Train(rank_zero_logger, dist, cfg: DictConfig): |
131 | 123 | writer = SummaryWriter(log_dir=cfg.data_options.ckpt_path_vfgn) |
132 | 124 |
|
133 | 125 | optimizer = None |
| 126 | + scaler = None |
134 | 127 | # todo : check device |
135 | 128 | device = "cpu" |
136 | 129 | step = 0 |
137 | 130 | running_loss = 0.0 |
138 | 131 | best_loss = 1000.0 |
139 | 132 |
|
140 | | - # Native PyTorch automatic mixed precision (AMP) replaces NVIDIA Apex. |
141 | | - # GradScaler is a no-op when enabled=False, so it is safe to construct |
142 | | - # unconditionally and only activate when fp16 training is requested. |
143 | | - use_amp = cfg.general.fp16 |
144 | | - scaler = torch.amp.GradScaler("cuda", enabled=use_amp) |
145 | | - |
146 | 133 | rank_zero_logger.info("Training started...") |
147 | 134 |
|
148 | 135 | for features, targets in tqdm(dataset): |
@@ -184,11 +171,10 @@ def Train(rank_zero_logger, dist, cfg: DictConfig): |
184 | 171 |
|
185 | 172 | sampled_noise *= noise_mask |
186 | 173 |
|
187 | | - amp_active = ( |
188 | | - use_amp and isinstance(device, torch.device) and device.type == "cuda" |
189 | | - ) |
| 174 | + amp_enabled = cfg.general.fp16 and scaler is not None |
190 | 175 | with torch.autocast( |
191 | | - device_type="cuda", dtype=torch.float16, enabled=amp_active |
| 176 | + device_type=device.type if isinstance(device, torch.device) else "cpu", |
| 177 | + enabled=amp_enabled, |
192 | 178 | ): |
193 | 179 | pred_target = model( |
194 | 180 | next_positions=targets.to(device), |
@@ -219,8 +205,8 @@ def Train(rank_zero_logger, dist, cfg: DictConfig): |
219 | 205 | model.setMessagePassingDevices(message_passing_devices) |
220 | 206 | model = model.to(device) |
221 | 207 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) |
222 | | - # Mixed precision is handled via the torch.amp GradScaler / autocast |
223 | | - # constructed above; no extra optimizer wrapping is required. |
| 208 | + if cfg.general.fp16: |
| 209 | + scaler = torch.amp.GradScaler(device.type) |
224 | 210 |
|
225 | 211 | scheduler = torch.optim.lr_scheduler.ExponentialLR( |
226 | 212 | optimizer, gamma=0.1, verbose=True |
@@ -398,7 +384,7 @@ def Train(rank_zero_logger, dist, cfg: DictConfig): |
398 | 384 | rank_zero_logger.info(f"loss: {loss}") |
399 | 385 | # back propagation |
400 | 386 | optimizer.zero_grad() |
401 | | - if use_amp: |
| 387 | + if cfg.general.fp16: |
402 | 388 | scaler.scale(loss).backward() |
403 | 389 | scaler.step(optimizer) |
404 | 390 | scaler.update() |
|
0 commit comments