Skip to content

Commit 4328da3

Browse files
:> GFN configs
:> search updates :> analysis updates
1 parent 5314c5c commit 4328da3

8 files changed

Lines changed: 210 additions & 128 deletions

File tree

configs/crystal_searches/base.yaml

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
11
device: cuda
2-
mol_path: D:\crystal_datasets\nehzor\NEHZOR0_std_conf.pt
3-
dataset_path: D:\crystal_datasets\nehzor\NEHZOR_structures_std_conf.pt #D:\crystal_datasets\xuldud\xuldud_sg61_zp1.pt
4-
target_path: D:\crystal_datasets\nehzor\NEHZOR0_std_conf.pt #D:\crystal_datasets\mipcas\MIPCAS.pt
5-
umbrella_path: D:\crystal_datasets\opt_outputs\nehzor_test_umbrella.pt # null
6-
target_identifier: NEHZOR #MIPCAS
2+
mol_path: D:\crystal_datasets\nacjaf\nacjaf_nikos.pt
3+
dataset_path: null #D:\crystal_datasets\nacjaf\nacjaf_nikos.pt #D:\crystal_datasets\xuldud\xuldud_sg61_zp1.pt
4+
target_path: D:\crystal_datasets\nacjaf\nacjaf_nikos.pt
5+
umbrella_path: null #D:\crystal_datasets\opt_outputs\nehzor_test_umbrella.pt # null
6+
target_identifier: Target_XXII_10581_0242_0005_P1211_02_01_07_1 #MIPCAS
77
out_dir: D:\crystal_datasets\opt_outputs
8-
run_name: NEHZOR_rdf
8+
run_name: NACJAF_chirality
99
save_trajs: false
1010

1111
uma_predictor_path: D:\crystal_datasets\esen_s.pt
1212

1313
init_sample_method: random # 'random' 'reasonable' 'data'
14-
init_reduced: true # initialize in our reduced frame
14+
init_reduced: false # initialize in our reduced frame
1515

1616
# if method is 'data' optimize provided pre-built crystals
1717
# else, from random initial conditions
1818
mol_seed: 0
19-
opt_seed: 7
19+
opt_seed: 0
2020
sampling_mode: all # 'all' or 'random' or 'ordered'
2121
mols_to_sample: 2
22-
num_samples: 10000 # per mol, per space group, per Zp
22+
num_samples: 1000 # per mol, per space group, per Zp
2323

2424
sgs_to_search: [ 14 ]
2525
zp_to_search: [ 1 ]
2626

2727
batch_size: 50
2828
grow_batch_size: false
2929

30-
init_target_cp: wide # a float, null (use target value), 'wide' (0.0-0.7) or 'std' target packing coefficient for initial structure sampling
30+
init_target_cp: 0.6828 # a float, null (use target value), 'wide' (0.0-0.7) or 'std' target packing coefficient for initial structure sampling
3131

3232
# Optimization can have multiple stages by duplicating the opt block below and adjusting params
3333
# the search algo will run n consecutive optimizations, where n is the number of list elements
@@ -53,18 +53,19 @@ init_target_cp: wide # a float, null (use target value), 'wide' (0.0-0.7) or 's
5353
# umbrella_epsilon: Optional[float] = None, # repulsion term for umbrella sampling
5454

5555
opt:
56-
- optim_target: 'uma' # lj qlj elj silu ellipsoid classification_score rdf_score rdf_dist latent_dist
57-
enforce_reduced: true
58-
compression_factor: 1.0
59-
cutoff: 10 # can be as low as 6 for SiLU, 10 otherwise
60-
init_lr: 0.1
61-
convergence_eps: 0.001
56+
- optim_target: 'rdf_dist' # lj qlj elj silu ellipsoid classification_score rdf_score rdf_dist latent_dist
57+
enforce_reduced: false
58+
compression_factor: 0.0
59+
cutoff: 4 # can be as low as 6 for SiLU, 10 otherwise
60+
init_lr: 0.001
61+
convergence_eps: 0.000001
6262
optimizer_func: 'rprop' # NOTE rprop is by far the fastest and most reliable
63-
anneal_lr: true
63+
anneal_lr: false
6464
grad_norm_clip: 0.1
6565
show_tqdm: true
66-
max_num_steps: 250
67-
target_packing_coeff: null
66+
max_num_steps: 40000
67+
rdf_warmup: 100
68+
target_packing_coeff: 0.6828
6869
umbrella: false
6970
umbrella_sigma: 0.25
7071
umbrella_epsilon: 40.0

mxtaltools/analysis/crystal_rdf.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def compute_rdf_distmat_parallel(rdf_record, rr, num_cpus, chunk_size=250):
475475
return rdf_dists
476476

477477

478-
def compute_rdf_distance(rdf1, rdf2, rr, n_parallel_rdf2: int = None, return_numpy: bool = False):
478+
def compute_rdf_distance(rdf1, rdf2, rr, n_parallel_rdf2: int = None, return_numpy: bool = False, channel_weights: torch.tensor=None):
479479
"""
480480
Compute a distance metric between two radial distribution functions (RDFs), including sub-RDFs, e.g., the particular interatomic RDFs within a given sample (elementwise or atomwise modes).
481481
@@ -511,7 +511,7 @@ def compute_rdf_distance(rdf1, rdf2, rr, n_parallel_rdf2: int = None, return_num
511511
else:
512512
torch_rdf1_f = torch_rdf1
513513

514-
bin_range = (torch_range[-1] - torch_range[0])
514+
#bin_range = (torch_range[-1] - torch_range[0])
515515
bin_width = torch_range[1] - torch_range[0]
516516

517517
# RDF should measure how much mass needs to move, by how far
@@ -525,7 +525,12 @@ def compute_rdf_distance(rdf1, rdf2, rr, n_parallel_rdf2: int = None, return_num
525525
# take the raw average over nonzero element pairs
526526
eps = 1e-12
527527
active = (torch_rdf1_f.sum(dim=-1) > eps) | (torch_rdf2.sum(dim=-1) > eps)
528-
distance = (emd * active).sum(dim=-1) / active.sum(dim=-1).clamp_min(1) # ignore unused channels
528+
529+
if channel_weights is not None:
530+
w = channel_weights.to(emd.device)[None, :] # [1, n_channels]
531+
distance = (emd * active * w).sum(dim=-1) / (active * w).sum(dim=-1).clamp_min(1e-10)
532+
else:
533+
distance = (emd * active).sum(dim=-1) / active.sum(dim=-1).clamp_min(1)
529534

530535
# distance = emd.mean(-1)
531536
#

mxtaltools/constants/space_group_info.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,22 @@
179179
SYM_OPS = {1: [array([[1., 0., 0., 0.],
180180
[0., 1., 0., 0.],
181181
[0., 0., 1., 0.],
182-
[0., 0., 0., 1.]])], 2: [array([[1., 0., 0., 0.],
183-
[0., 1., 0., 0.],
184-
[0., 0., 1., 0.],
185-
[0., 0., 0., 1.]]), array([[-1., 0., 0., 0.],
186-
[0., -1., 0., 0.],
187-
[0., 0., -1., 0.],
188-
[0., 0., 0., 1.]])],
182+
[0., 0., 0., 1.]])],
183+
2: [array([[1., 0., 0., 0.],
184+
[0., 1., 0., 0.],
185+
[0., 0., 1., 0.],
186+
[0., 0., 0., 1.]]), array([[-1., 0., 0., 0.],
187+
[0., -1., 0., 0.],
188+
[0., 0., -1., 0.],
189+
[0., 0., 0., 1.]])],
189190
3: [array([[1., 0., 0., 0.],
190191
[0., 1., 0., 0.],
191192
[0., 0., 1., 0.],
192193
[0., 0., 0., 1.]]), array([[-1., 0., 0., 0.],
193194
[0., 1., 0., 0.],
194195
[0., 0., -1., 0.],
195-
[0., 0., 0., 1.]])], 4: [array([[1., 0., 0., 0.],
196+
[0., 0., 0., 1.]])],
197+
4: [array([[1., 0., 0., 0.],
196198
[0., 1., 0., 0.],
197199
[0., 0., 1., 0.],
198200
[0., 0., 0., 1.]]),
@@ -213,7 +215,8 @@
213215
array([[-1., 0., 0., 0.5],
214216
[0., 1., 0., 0.5],
215217
[0., 0., -1., 0.],
216-
[0., 0., 0., 1.]])], 6: [array([[1., 0., 0., 0.],
218+
[0., 0., 0., 1.]])],
219+
6: [array([[1., 0., 0., 0.],
217220
[0., 1., 0., 0.],
218221
[0., 0., 1., 0.],
219222
[0., 0., 0., 1.]]), array([[1., 0., 0., 0.],

mxtaltools/crystal_search/crystal_opt_utils.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
121121
umbrella_sigma: Optional[float] = None, # bandwidth term for umbrella sampling
122122
umbrella_epsilon: Optional[float] = None, # repulsion term for umbrella sampling
123123
umbrella_record: Optional[list] = None,
124+
rdf_warmup: Optional[torch.tensor] = 500,
124125
):
125126
"""
126127
do a local optimization via gradient descent on some score function
@@ -160,7 +161,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
160161
if target_latent is not None:
161162
target_latent = target_latent.to(init_sample.device)
162163

163-
if target_rdf is not None:
164+
if False: #target_rdf is not None: # assumes we already have the box
164165
fixed_dims = [0, 1, 2, 3, 4, 5]
165166
else:
166167
fixed_dims = None
@@ -172,6 +173,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
172173
'score_model': score_model,
173174
'optim_target': optim_target,
174175
'target_latent': target_latent,
176+
'rdf_warmup': rdf_warmup,
175177
})
176178

177179
aux_config = dict2namespace({
@@ -270,7 +272,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
270272
else:
271273
loss_and_backprop(cluster_batch, crystal_batch, grad_norm_clip,
272274
optimizer, outputs, param_module, records,
273-
loss_config, aux_config)
275+
loss_config, aux_config, s_ind)
274276

275277
if s_ind % 10 == 0:
276278
gc.collect()
@@ -344,7 +346,10 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
344346
345347
"""
346348
if optim_target == 'rdf_dist':
347-
if torch.amin(records['loss']).log() < -2.5:
349+
timesteps = torch.arange(s_ind).repeat(init_crystal_batch.num_graphs, 1).T
350+
traj_fig(timesteps, torch.log(records['loss']), names=['time', 'loss'])
351+
352+
if torch.amin(records['loss'][-1]).log() < -2.5:
348353
print("Found the crystal!")
349354
good_ind = torch.argmin(records['loss'][-1]).item()
350355
sample = crystal_batch.batch_to_list()[good_ind]
@@ -403,8 +408,8 @@ def update_record(crystal_batch, outputs, params_record, records, s_ind):
403408

404409

405410
def loss_and_backprop(cluster_batch, crystal_batch, grad_norm_clip, optimizer, outputs, param_module, records,
406-
loss_config, aux_config):
407-
loss = compute_loss(cluster_batch, crystal_batch, outputs, loss_config)
411+
loss_config, aux_config, opt_step):
412+
loss = compute_loss(cluster_batch, crystal_batch, outputs, loss_config, opt_step)
408413
loss = compute_auxiliary_loss(cluster_batch, loss, outputs, aux_config)
409414

410415
records['loss'].append(loss.detach().cpu())
@@ -541,7 +546,7 @@ def ema_trajectory(traj: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
541546
return numer / denom
542547

543548

544-
def compute_loss(cluster_batch, crystal_batch, outputs, config):
549+
def compute_loss(cluster_batch, crystal_batch, outputs, config, opt_step):
545550
if config.optim_target.lower() == 'lj': # todo obviate this with analysis keys
546551
loss = outputs['lj']
547552

@@ -573,8 +578,27 @@ def compute_loss(cluster_batch, crystal_batch, outputs, config):
573578
loss = outputs['uma']
574579

575580
elif config.optim_target.lower() == 'rdf_dist':
581+
n_channels = config.target_rdf.shape[-2] # 120
582+
if config.rdf_warmup is not None:
583+
# channel_warmup = config.rdf_warmup
584+
# channel_onsets = torch.linspace(0, channel_warmup, n_channels) # evenly spaced turn-on times
585+
# channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
586+
587+
n_waves = 3
588+
base_periods = torch.tensor([1.0, 1.6, 2.5]) * config.rdf_warmup
589+
channel_idx = torch.arange(n_channels, dtype=torch.float32)
590+
591+
modulation = torch.zeros(n_channels)
592+
for i in range(n_waves):
593+
modulation += torch.sin(
594+
2 * torch.pi * opt_step / base_periods[i] + 2 * torch.pi * channel_idx / n_channels * (i + 1))
595+
modulation = modulation / n_waves
596+
channel_weights = 0.5 + 0.5 * modulation
597+
else:
598+
channel_weights = torch.ones(n_channels)
576599
loss = compute_rdf_distance(outputs['rdf'][0], config.target_rdf,
577-
torch.linspace(0, config.cutoff, config.target_rdf.shape[-1]))
600+
torch.linspace(0, config.cutoff, config.target_rdf.shape[-1]),
601+
channel_weights=channel_weights)
578602

579603
elif config.optim_target.lower() == 'latent_dist':
580604
loss = (config.target_latent - crystal_batch.latent_params()).norm(dim=-1)

mxtaltools/crystal_search/utils.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mxtaltools.mlip_interfaces.uma_utils import init_uma_crystal_predictor
1313

1414

15-
def save_umbrella_record(record, new_latents, path, sigma = 0.2, epsilon = 10):
15+
def save_umbrella_record(record, new_latents, path, sigma=0.2, epsilon=10):
1616
if len(record) > 0 and len(new_latents) > 0:
1717
dists = torch.cdist(new_latents, record)
1818
repulsion = epsilon * torch.exp(-dists ** 2 / (2 * sigma ** 2)).sum(dim=1).clip(max=10)
@@ -22,6 +22,7 @@ def save_umbrella_record(record, new_latents, path, sigma = 0.2, epsilon = 10):
2222
torch.save(record, path)
2323
return record
2424

25+
2526
def rdf_clustering(packing_coeff, rdf, rdf_cutoff, rr, samples, vdw, num_cpus=None):
2627
"""cluster samples according to rdf distances"""
2728
# rdf_dists = compute_rdf_distmat(rdf, rr)
@@ -209,13 +210,14 @@ def get_initial_state(config, crystal_batch, device, batch_idx, target):
209210
crystal_batch.sample_random_crystal_parameters(
210211
target_packing_coeff=target_cp,
211212
seed=config.opt_seed + int(batch_idx * 10000))
212-
standard_cell = target.compute_standard_cell()
213-
crystal_batch.cell_lengths = torch.tensor(standard_cell[0, :3], dtype=torch.float32,
214-
device=crystal_batch.device).repeat(crystal_batch.num_graphs, 1)
215-
crystal_batch.cell_angles = torch.tensor(standard_cell[0, 3:], dtype=torch.float32,
216-
device=crystal_batch.device).repeat(crystal_batch.num_graphs,
217-
1) * torch.pi / 2 / 90
218-
crystal_batch.box_analysis()
213+
assert False, "Below is deprecated and probably unnecessary for now"
214+
# standard_cell = target.compute_standard_cell()
215+
# crystal_batch.cell_lengths = torch.tensor(standard_cell[0, :3], dtype=torch.float32,
216+
# device=crystal_batch.device).repeat(crystal_batch.num_graphs, 1)
217+
# crystal_batch.cell_angles = torch.tensor(standard_cell[0, 3:], dtype=torch.float32,
218+
# device=crystal_batch.device).repeat(crystal_batch.num_graphs,
219+
# 1) * torch.pi / 2 / 90
220+
# crystal_batch.box_analysis()
219221

220222
else:
221223
assert False
@@ -228,7 +230,10 @@ def init_samples_to_optim(config, target=None):
228230
"""
229231
if config.init_sample_method == 'data':
230232
samples_to_optim = torch.load(config.dataset_path, weights_only=False)
231-
index_block = torch.arange(config.mol_seed * config.num_samples, (config.mol_seed + 1) * config.num_samples)
233+
if not isinstance(samples_to_optim, list):
234+
samples_to_optim = [samples_to_optim]
235+
index_block = [0 for _ in range(
236+
config.num_samples)] # torch.arange(config.mol_seed * config.num_samples, (config.mol_seed + 1) * config.num_samples)
232237
samples_to_optim = [samples_to_optim[ind] for ind in index_block]
233238
return samples_to_optim
234239
else:
@@ -305,7 +310,9 @@ def parse_opt_config(opt_config, config, device, target):
305310
opt_config['elementwise'] = False
306311
opt_config['atomwise'] = True
307312
tbatch = collate_data_list([target])
308-
out = tbatch.analyze(['rdf'], cutoff=10, rdf_cutoff=10, elementwise=False, atomwise=True, bins=100)
313+
out = tbatch.analyze(['rdf'], cutoff=opt_config['cutoff'],
314+
rdf_cutoff=opt_config['cutoff'],
315+
elementwise=False, atomwise=True, bins=100)
309316
opt_config['target_rdf'] = out['rdf'][0]
310317
if opt_config['optim_target'] in ['latent_dist']:
311318
tbatch = collate_data_list([target])

0 commit comments

Comments
 (0)