@@ -121,6 +121,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
121121 umbrella_sigma : Optional [float ] = None , # bandwidth term for umbrella sampling
122122 umbrella_epsilon : Optional [float ] = None , # repulsion term for umbrella sampling
123123 umbrella_record : Optional [list ] = None ,
124+ rdf_warmup : Optional [torch .tensor ] = 500 ,
124125):
125126 """
126127 do a local optimization via gradient descent on some score function
@@ -160,7 +161,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
160161 if target_latent is not None :
161162 target_latent = target_latent .to (init_sample .device )
162163
163- if target_rdf is not None :
164+ if False : # target_rdf is not None: # assumes we already have the box
164165 fixed_dims = [0 , 1 , 2 , 3 , 4 , 5 ]
165166 else :
166167 fixed_dims = None
@@ -172,6 +173,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
172173 'score_model' : score_model ,
173174 'optim_target' : optim_target ,
174175 'target_latent' : target_latent ,
176+ 'rdf_warmup' : rdf_warmup ,
175177 })
176178
177179 aux_config = dict2namespace ({
@@ -270,7 +272,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
270272 else :
271273 loss_and_backprop (cluster_batch , crystal_batch , grad_norm_clip ,
272274 optimizer , outputs , param_module , records ,
273- loss_config , aux_config )
275+ loss_config , aux_config , s_ind )
274276
275277 if s_ind % 10 == 0 :
276278 gc .collect ()
@@ -344,7 +346,10 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
344346
345347 """
346348 if optim_target == 'rdf_dist' :
347- if torch .amin (records ['loss' ]).log () < - 2.5 :
349+ timesteps = torch .arange (s_ind ).repeat (init_crystal_batch .num_graphs , 1 ).T
350+ traj_fig (timesteps , torch .log (records ['loss' ]), names = ['time' , 'loss' ])
351+
352+ if torch .amin (records ['loss' ][- 1 ]).log () < - 2.5 :
348353 print ("Found the crystal!" )
349354 good_ind = torch .argmin (records ['loss' ][- 1 ]).item ()
350355 sample = crystal_batch .batch_to_list ()[good_ind ]
@@ -403,8 +408,8 @@ def update_record(crystal_batch, outputs, params_record, records, s_ind):
403408
404409
405410def loss_and_backprop (cluster_batch , crystal_batch , grad_norm_clip , optimizer , outputs , param_module , records ,
406- loss_config , aux_config ):
407- loss = compute_loss (cluster_batch , crystal_batch , outputs , loss_config )
411+ loss_config , aux_config , opt_step ):
412+ loss = compute_loss (cluster_batch , crystal_batch , outputs , loss_config , opt_step )
408413 loss = compute_auxiliary_loss (cluster_batch , loss , outputs , aux_config )
409414
410415 records ['loss' ].append (loss .detach ().cpu ())
@@ -541,7 +546,7 @@ def ema_trajectory(traj: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
541546 return numer / denom
542547
543548
544- def compute_loss (cluster_batch , crystal_batch , outputs , config ):
549+ def compute_loss (cluster_batch , crystal_batch , outputs , config , opt_step ):
545550 if config .optim_target .lower () == 'lj' : # todo obviate this with analysis keys
546551 loss = outputs ['lj' ]
547552
@@ -573,8 +578,27 @@ def compute_loss(cluster_batch, crystal_batch, outputs, config):
573578 loss = outputs ['uma' ]
574579
575580 elif config .optim_target .lower () == 'rdf_dist' :
581+ n_channels = config .target_rdf .shape [- 2 ] # 120
582+ if config .rdf_warmup is not None :
583+ # channel_warmup = config.rdf_warmup
584+ # channel_onsets = torch.linspace(0, channel_warmup, n_channels) # evenly spaced turn-on times
585+ # channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
586+
587+ n_waves = 3
588+ base_periods = torch .tensor ([1.0 , 1.6 , 2.5 ]) * config .rdf_warmup
589+ channel_idx = torch .arange (n_channels , dtype = torch .float32 )
590+
591+ modulation = torch .zeros (n_channels )
592+ for i in range (n_waves ):
593+ modulation += torch .sin (
594+ 2 * torch .pi * opt_step / base_periods [i ] + 2 * torch .pi * channel_idx / n_channels * (i + 1 ))
595+ modulation = modulation / n_waves
596+ channel_weights = 0.5 + 0.5 * modulation
597+ else :
598+ channel_weights = torch .ones (n_channels )
576599 loss = compute_rdf_distance (outputs ['rdf' ][0 ], config .target_rdf ,
577- torch .linspace (0 , config .cutoff , config .target_rdf .shape [- 1 ]))
600+ torch .linspace (0 , config .cutoff , config .target_rdf .shape [- 1 ]),
601+ channel_weights = channel_weights )
578602
579603 elif config .optim_target .lower () == 'latent_dist' :
580604 loss = (config .target_latent - crystal_batch .latent_params ()).norm (dim = - 1 )
0 commit comments