1111from torch_scatter import scatter
1212from tqdm import tqdm
1313
14-
1514from mxtaltools .analysis .crystal_rdf import compute_rdf_distance
1615from mxtaltools .common .geometry_utils import enforce_crystal_system
1716from mxtaltools .common .utils import is_cuda_oom
1817from mxtaltools .dataset_utils .utils import collate_data_list
1918from mxtaltools .models .utils import enforce_1d_bound , softmax_and_score
2019
2120
22-
23-
2421def dict2namespace (data_dict : dict ):
2522 """
2623 Recursively converts a dictionary and its internal dictionaries into an
@@ -131,16 +128,16 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
131128 # lennard jones need 10 angstroms to nicely converge
132129 cutoff = 10
133130
134- energy_computes = ['lj' ]
131+ energy_computes = ['lj' , 'elj' ]
135132 min_num_steps = 50
136133 num_samples = init_crystal_batch .num_graphs
137134
138135 if optim_target .lower () == 'silu' :
139136 energy_computes .append ('silu' )
140137 elif optim_target .lower () == 'qlj' :
141138 energy_computes .append ('qlj' )
142- elif optim_target .lower () == 'elj' :
143- energy_computes .append ('elj' )
139+ # elif optim_target.lower() == 'elj': # always do this
140+ # energy_computes.append('elj')
144141 elif optim_target .lower () == 'ellipsoid' :
145142 energy_computes .append ('ellipsoid' )
146143 elif optim_target .lower () == 'reduce' :
@@ -161,7 +158,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
161158 if target_latent is not None :
162159 target_latent = target_latent .to (init_sample .device )
163160
164- if False : # target_rdf is not None: # assumes we already have the box
161+ if False : # target_rdf is not None: # assumes we already have the box
165162 fixed_dims = [0 , 1 , 2 , 3 , 4 , 5 ]
166163 else :
167164 fixed_dims = None
@@ -280,8 +277,8 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
280277
281278 scheduler1 .step () # shrink
282279 s_ind += 1
283- if s_ind % 50 == 0 :
284- pbar .update (50 )
280+ if s_ind % 10 == 0 :
281+ pbar .update (10 )
285282 if s_ind >= min (max_num_steps , max (50 , min_num_steps )):
286283 converged = check_convergence (params_record , s_ind , convergence_eps ,
287284 optimizer , init_lr )
@@ -580,25 +577,32 @@ def compute_loss(cluster_batch, crystal_batch, outputs, config, opt_step):
580577 elif config .optim_target .lower () == 'rdf_dist' :
581578 n_channels = config .target_rdf .shape [- 2 ] # 120
582579 if config .rdf_warmup is not None :
583- # channel_warmup = config.rdf_warmup
584- # channel_onsets = torch.linspace(0, channel_warmup, n_channels) # evenly spaced turn-on times
585- # channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
586-
587- n_waves = 3
588- base_periods = torch .tensor ([1.0 , 1.6 , 2.5 ]) * config .rdf_warmup
589- channel_idx = torch .arange (n_channels , dtype = torch .float32 )
590-
591- modulation = torch .zeros (n_channels )
592- for i in range (n_waves ):
593- modulation += torch .sin (
594- 2 * torch .pi * opt_step / base_periods [i ] + 2 * torch .pi * channel_idx / n_channels * (i + 1 ))
595- modulation = modulation / n_waves
596- channel_weights = 0.5 + 0.5 * modulation
580+ channel_warmup = config .rdf_warmup
581+ channel_onsets = torch .linspace (0 , channel_warmup , n_channels ) # evenly spaced turn-on times
582+ channel_weights = torch .sigmoid ((opt_step - channel_onsets ) / (channel_warmup / n_channels * 0.5 ))
583+
584+ # n_waves = 3
585+ # base_periods = torch.tensor([1.0, 1.6, 2.5]) * config.rdf_warmup
586+ # channel_idx = torch.arange(n_channels, dtype=torch.float32)
587+ #
588+ # modulation = torch.zeros(n_channels)
589+ # for i in range(n_waves):
590+ # modulation += torch.sin(
591+ # 2 * torch.pi * opt_step / base_periods[i] + 2 * torch.pi * channel_idx / n_channels * (i + 1))
592+ # modulation = modulation / n_waves
593+ # channel_weights = 0.5 + 0.5 * modulation
597594 else :
598595 channel_weights = torch .ones (n_channels )
599- loss = compute_rdf_distance (outputs ['rdf' ][0 ], config .target_rdf ,
600- torch .linspace (0 , config .cutoff , config .target_rdf .shape [- 1 ]),
601- channel_weights = channel_weights )
596+ rdf_loss = compute_rdf_distance (outputs ['rdf' ][0 ], config .target_rdf ,
597+ torch .linspace (0 , config .cutoff , config .target_rdf .shape [- 1 ]),
598+ channel_weights = channel_weights )
599+
600+ en_cut = - 307 # set equal or higher to the target energy
601+ beta = 5
602+ lj_en = outputs ['elj' ]
603+ lj_loss = F .softplus (beta * (lj_en - en_cut )) / beta
604+
605+ loss = rdf_loss + lj_loss / 100
602606
603607 elif config .optim_target .lower () == 'latent_dist' :
604608 loss = (config .target_latent - crystal_batch .latent_params ()).norm (dim = - 1 )
@@ -637,7 +641,7 @@ def compute_auxiliary_loss(cluster_batch, loss, outputs, config):
637641 record = config .umbrella_record .to (cluster_batch .device )
638642 dists = torch .cdist (latents , record )
639643 penalty = torch .exp (- dists ** 2 / (2 * config .umbrella_sigma ** 2 )).sum (dim = 1 ).clip (max = 10 )
640- loss = loss + config .umbrella_epsilon * penalty
644+ loss = loss + config .umbrella_epsilon * penalty
641645
642646 return loss
643647
0 commit comments