Skip to content

Commit 636fe90

Browse files
:> various analyses
:> var boosted training change and configs > search updates
1 parent 4328da3 commit 636fe90

9 files changed

Lines changed: 385 additions & 73 deletions

File tree

configs/crystal_searches/base.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ mol_seed: 0
1919
opt_seed: 0
2020
sampling_mode: all # 'all' or 'random' or 'ordered'
2121
mols_to_sample: 2
22-
num_samples: 1000 # per mol, per space group, per Zp
22+
num_samples: 10000 # per mol, per space group, per Zp
2323

2424
sgs_to_search: [ 14 ]
2525
zp_to_search: [ 1 ]
2626

27-
batch_size: 50
27+
batch_size: 2000
2828
grow_batch_size: false
2929

3030
init_target_cp: 0.6828 # a float, null (use target value), 'wide' (0.0-0.7) or 'std' target packing coefficient for initial structure sampling
@@ -56,15 +56,15 @@ opt:
5656
- optim_target: 'rdf_dist' # lj qlj elj silu ellipsoid classification_score rdf_score rdf_dist latent_dist
5757
enforce_reduced: false
5858
compression_factor: 0.0
59-
cutoff: 4 # can be as low as 6 for SiLU, 10 otherwise
59+
cutoff: 10 # can be as low as 6 for SiLU, 10 otherwise
6060
init_lr: 0.001
61-
convergence_eps: 0.000001
61+
convergence_eps: 0.0001
6262
optimizer_func: 'rprop' # NOTE rprop is by far the fastest and most reliable
6363
anneal_lr: false
6464
grad_norm_clip: 0.1
6565
show_tqdm: true
66-
max_num_steps: 40000
67-
rdf_warmup: 100
66+
max_num_steps: 500
67+
rdf_warmup: null
6868
target_packing_coeff: 0.6828
6969
umbrella: false
7070
umbrella_sigma: 0.25

examples/crystal_search_reporting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,14 @@ def compack_fig(matches, rmsds, write_fig):
195195
fig.write_image(r'C:\Users\mikem\OneDrive\NYU\CSD\papers\mxt_code\compack_fig.png', width=900, height=900)
196196

197197

198-
def batch_compack(best_sample_inds, optimized_samples, reference_cluster_batch): # todo refactor into analysis code
198+
def batch_compack(best_sample_inds, optimized_samples, reference_cluster_batch, ref_ind: int = 0): # todo refactor into analysis code
199199
# generate the crystals in ccdc format
200200
best_crystals_batch = collate_data_list([optimized_samples[ind] for ind in best_sample_inds])
201-
best_cluster_batch = best_crystals_batch.mol2cluster().to('cpu')
201+
best_cluster_batch = best_crystals_batch.mol2cluster(cutoff=10).to('cpu')
202202
_ = cluster_batch_to_ccdc_crystals(best_cluster_batch, np.arange(best_cluster_batch.num_graphs))
203-
mol = ase_mol_from_crystaldata(reference_cluster_batch, index=0, mode='unit cell')
204-
#mol.info['spacegroup'] = Spacegroup(int(best_cluster_batch.sg_ind[0]), setting=1)
205-
mol.write('DAFMUV.cif')
203+
mol = ase_mol_from_crystaldata(reference_cluster_batch, index=ref_ind, mode='unit cell')
204+
mol.info['spacegroup'] = Spacegroup(int(reference_cluster_batch.sg_ind[ref_ind]), setting=1)
205+
mol.write('compack_placeholder.cif')
206206

207207
print(f"Running COMPACK on {len(best_sample_inds)} crystals")
208208
pool = mp.Pool(8)
@@ -225,7 +225,7 @@ def batch_compack(best_sample_inds, optimized_samples, reference_cluster_batch):
225225

226226

227227
def single_compack_run(ind):
228-
ref_crystal = CrystalReader('DAFMUV.cif')[0]
228+
ref_crystal = CrystalReader('compack_placeholder.cif')[0]
229229
sample_crystal = CrystalReader(f'temp_{ind}.cif')[0]
230230
similarity_engine = PackingSimilarity()
231231
similarity_engine.settings.distance_tolerance = 0.4

mxtaltools/common/geometry_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,8 +1663,8 @@ def compute_latent_distance(latents1: torch.Tensor,
16631663

16641664
rot_dists = []
16651665
for zp in range(z_prime): # this should be replaced with a proper vector distance
1666-
rmat1 = rotvec2rotmat(lat_sph_rotvec1[...,3 * zp:3 * zp + 3], 'spherical')
1667-
rmat2 = rotvec2rotmat(lat_sph_rotvec2[...,3 * zp:3 * zp + 3], 'spherical')
1666+
rmat1 = rotvec2rotmat(lat_sph_rotvec1[..., 3 * zp:3 * zp + 3], 'spherical')
1667+
rmat2 = rotvec2rotmat(lat_sph_rotvec2[..., 3 * zp:3 * zp + 3], 'spherical')
16681668

16691669
R_delta = rmat1 @ rmat2.transpose(-1, -2)
16701670

@@ -1676,10 +1676,11 @@ def compute_latent_distance(latents1: torch.Tensor,
16761676
rot_dists = torch.stack(rot_dists).sum(0)
16771677

16781678
"overall distance metric"
1679-
#dists = 0.5 * box_dist + 0.25 * (positions_dist / z_prime / 2.5) + 0.25 * (rot_dists / z_prime / 2)
1680-
#scales = [2 * sqrt(6), 2*sqrt(3), torch.pi] # maximum variation per dist
1681-
scales = [1, 0.836, 0.293]# [0.0127, 0.0152, 0.0433] # empirical std over CSD samples
1682-
dists = scales[0] * 0.5 * box_dist + scales[1] * 0.25 * (positions_dist / z_prime) + scales[2] * 0.25 * (rot_dists / z_prime)
1679+
# dists = 0.5 * box_dist + 0.25 * (positions_dist / z_prime / 2.5) + 0.25 * (rot_dists / z_prime / 2)
1680+
# scales = [2 * sqrt(6), 2*sqrt(3), torch.pi] # maximum variation per dist
1681+
scales = [1, 0.836, 0.293] # [0.0127, 0.0152, 0.0433] # empirical std over CSD samples
1682+
dists = scales[0] * 0.5 * box_dist + scales[1] * 0.25 * (positions_dist / z_prime) + scales[2] * 0.25 * (
1683+
rot_dists / z_prime)
16831684

16841685
return dists
16851686

mxtaltools/crystal_search/crystal_opt_utils.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,13 @@
1111
from torch_scatter import scatter
1212
from tqdm import tqdm
1313

14-
1514
from mxtaltools.analysis.crystal_rdf import compute_rdf_distance
1615
from mxtaltools.common.geometry_utils import enforce_crystal_system
1716
from mxtaltools.common.utils import is_cuda_oom
1817
from mxtaltools.dataset_utils.utils import collate_data_list
1918
from mxtaltools.models.utils import enforce_1d_bound, softmax_and_score
2019

2120

22-
23-
2421
def dict2namespace(data_dict: dict):
2522
"""
2623
Recursively converts a dictionary and its internal dictionaries into an
@@ -131,16 +128,16 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
131128
# Lennard-Jones needs a 10 angstrom cutoff to converge nicely
132129
cutoff = 10
133130

134-
energy_computes = ['lj']
131+
energy_computes = ['lj', 'elj']
135132
min_num_steps = 50
136133
num_samples = init_crystal_batch.num_graphs
137134

138135
if optim_target.lower() == 'silu':
139136
energy_computes.append('silu')
140137
elif optim_target.lower() == 'qlj':
141138
energy_computes.append('qlj')
142-
elif optim_target.lower() == 'elj':
143-
energy_computes.append('elj')
139+
# elif optim_target.lower() == 'elj': # always do this
140+
# energy_computes.append('elj')
144141
elif optim_target.lower() == 'ellipsoid':
145142
energy_computes.append('ellipsoid')
146143
elif optim_target.lower() == 'reduce':
@@ -161,7 +158,7 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
161158
if target_latent is not None:
162159
target_latent = target_latent.to(init_sample.device)
163160

164-
if False: #target_rdf is not None: # assumes we already have the box
161+
if False: # target_rdf is not None: # assumes we already have the box
165162
fixed_dims = [0, 1, 2, 3, 4, 5]
166163
else:
167164
fixed_dims = None
@@ -280,8 +277,8 @@ def gradient_descent_optimization( # todo consolidate kwargs somewhere
280277

281278
scheduler1.step() # shrink
282279
s_ind += 1
283-
if s_ind % 50 == 0:
284-
pbar.update(50)
280+
if s_ind % 10 == 0:
281+
pbar.update(10)
285282
if s_ind >= min(max_num_steps, max(50, min_num_steps)):
286283
converged = check_convergence(params_record, s_ind, convergence_eps,
287284
optimizer, init_lr)
@@ -580,25 +577,32 @@ def compute_loss(cluster_batch, crystal_batch, outputs, config, opt_step):
580577
elif config.optim_target.lower() == 'rdf_dist':
581578
n_channels = config.target_rdf.shape[-2] # 120
582579
if config.rdf_warmup is not None:
583-
# channel_warmup = config.rdf_warmup
584-
# channel_onsets = torch.linspace(0, channel_warmup, n_channels) # evenly spaced turn-on times
585-
# channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
586-
587-
n_waves = 3
588-
base_periods = torch.tensor([1.0, 1.6, 2.5]) * config.rdf_warmup
589-
channel_idx = torch.arange(n_channels, dtype=torch.float32)
590-
591-
modulation = torch.zeros(n_channels)
592-
for i in range(n_waves):
593-
modulation += torch.sin(
594-
2 * torch.pi * opt_step / base_periods[i] + 2 * torch.pi * channel_idx / n_channels * (i + 1))
595-
modulation = modulation / n_waves
596-
channel_weights = 0.5 + 0.5 * modulation
580+
channel_warmup = config.rdf_warmup
581+
channel_onsets = torch.linspace(0, channel_warmup, n_channels) # evenly spaced turn-on times
582+
channel_weights = torch.sigmoid((opt_step - channel_onsets) / (channel_warmup / n_channels * 0.5))
583+
584+
# n_waves = 3
585+
# base_periods = torch.tensor([1.0, 1.6, 2.5]) * config.rdf_warmup
586+
# channel_idx = torch.arange(n_channels, dtype=torch.float32)
587+
#
588+
# modulation = torch.zeros(n_channels)
589+
# for i in range(n_waves):
590+
# modulation += torch.sin(
591+
# 2 * torch.pi * opt_step / base_periods[i] + 2 * torch.pi * channel_idx / n_channels * (i + 1))
592+
# modulation = modulation / n_waves
593+
# channel_weights = 0.5 + 0.5 * modulation
597594
else:
598595
channel_weights = torch.ones(n_channels)
599-
loss = compute_rdf_distance(outputs['rdf'][0], config.target_rdf,
600-
torch.linspace(0, config.cutoff, config.target_rdf.shape[-1]),
601-
channel_weights=channel_weights)
596+
rdf_loss = compute_rdf_distance(outputs['rdf'][0], config.target_rdf,
597+
torch.linspace(0, config.cutoff, config.target_rdf.shape[-1]),
598+
channel_weights=channel_weights)
599+
600+
en_cut = -307 # set equal to or higher than the target energy
601+
beta = 5
602+
lj_en = outputs['elj']
603+
lj_loss = F.softplus(beta * (lj_en - en_cut)) / beta
604+
605+
loss = rdf_loss + lj_loss / 100
602606

603607
elif config.optim_target.lower() == 'latent_dist':
604608
loss = (config.target_latent - crystal_batch.latent_params()).norm(dim=-1)
@@ -637,7 +641,7 @@ def compute_auxiliary_loss(cluster_batch, loss, outputs, config):
637641
record = config.umbrella_record.to(cluster_batch.device)
638642
dists = torch.cdist(latents, record)
639643
penalty = torch.exp(-dists ** 2 / (2 * config.umbrella_sigma ** 2)).sum(dim=1).clip(max=10)
640-
loss = loss + config.umbrella_epsilon * penalty
644+
loss = loss + config.umbrella_epsilon * penalty
641645

642646
return loss
643647

mxtaltools/crystal_search/run_search.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,7 @@
1515
recover_opt_state, process_target, save_umbrella_record
1616
from mxtaltools.dataset_utils.utils import collate_data_list
1717

18-
if __name__ == '__main__':
19-
args = parse_args() # call config with "python run_search.py --config /path/to/config.yaml"
20-
source_dir = Path(__file__).resolve().parent.parent.parent
21-
if args.config is None:
22-
config_path = source_dir / 'configs' / 'crystal_searches' / 'base.yaml'
23-
else:
24-
config_path = Path(args.config)
25-
26-
config = dict2namespace(load_yaml(config_path))
27-
18+
def crystal_search(config):
2819
device = config.device
2920
umbrella_path = config.umbrella_path
3021

@@ -102,14 +93,6 @@
10293
umbrella_record = torch.load(umbrella_path, weights_only=False)
10394
save_umbrella_record(umbrella_record, new_latents, umbrella_path, opt_config['umbrella_sigma'], opt_config['umbrella_epsilon'])
10495

105-
# cursor = 0
106-
# bsz = config.batch_size
107-
# while cursor < len(opt_outs):
108-
# batch = collate_data_list(opt_outs[cursor:cursor + bsz])
109-
# en = batch.elj
110-
# print([en.quantile(ii) for ii in torch.linspace(0, 1, 10)])
111-
# cursor += bsz
112-
11396
cursor += config.batch_size
11497
prev_best_samples = None
11598
pbar.update(min(config.batch_size, num_samples - cursor)) # safe final update
@@ -143,6 +126,9 @@
143126
else:
144127
raise e
145128

129+
130+
return opt_outs
131+
146132
print(f"Sampling complete! Optimized a total of {len(opt_outs)} crystal samples.")
147133

148134
# batch = collate_data_list(opt_outs)
@@ -151,6 +137,19 @@
151137
# batch.plot_batch_density_funnel(split_by_sg=True)
152138

153139
aa = 1
140+
141+
if __name__ == '__main__':
142+
args = parse_args() # call config with "python run_search.py --config /path/to/config.yaml"
143+
source_dir = Path(__file__).resolve().parent.parent.parent
144+
if args.config is None:
145+
config_path = source_dir / 'configs' / 'crystal_searches' / 'base.yaml'
146+
else:
147+
config_path = Path(args.config)
148+
149+
config = dict2namespace(load_yaml(config_path))
150+
151+
crystal_search(config)
152+
154153
"""
155154
156155

mxtaltools/crystal_search/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def coarse_crystal_filter(lj_record, lj_cutoff, packing_coeff_record, packing_cu
162162

163163
def get_initial_state(config, crystal_batch, device, batch_idx, target):
164164
# sample initial parameters
165-
if config.init_sample_method == 'data':
165+
if config.init_sample_method == 'data' or config.init_sample_method == 'in_config':
166166
return crystal_batch
167167

168168
if config.init_target_cp == 'std':
@@ -236,6 +236,9 @@ def init_samples_to_optim(config, target=None):
236236
config.num_samples)] # torch.arange(config.mol_seed * config.num_samples, (config.mol_seed + 1) * config.num_samples)
237237
samples_to_optim = [samples_to_optim[ind] for ind in index_block]
238238
return samples_to_optim
239+
elif config.init_sample_method == 'in_config':
240+
samples_to_optim = config.samples_to_optim
241+
return samples_to_optim
239242
else:
240243
if target is None:
241244
mol_list = torch.load(config.mol_path, weights_only=False)

mxtaltools/dataset_utils/data_class_methods/crystal_ops.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -963,25 +963,36 @@ def destandardize_aunit_orientation(self, std_aunit_orientation):
963963
orientation_means = torch.tensor([[0, 0, torch.pi / 2]], dtype=torch.float32, device=self.device)
964964
return std_aunit_orientation * orientation_stds + orientation_means
965965

966-
def _build_feature_labels(self):
966+
def _build_feature_labels(self, space):
967967
lattice_features = ['a', 'b', 'c',
968968
r'$\alpha$', r'$\beta$', r'$\gamma$']
969969
if self.max_z_prime == 1:
970970
lattice_features.extend([
971971
f'u', f'v', f'w',
972972
])
973-
lattice_features.extend([
974-
f'x', f'y', f'z'
975-
])
973+
if space == 'latent':
974+
lattice_features.extend([
975+
f'θ', f'φ', f'r'
976+
])
977+
else:
978+
lattice_features.extend([
979+
f'x', f'y', f'z'
980+
])
981+
976982
else:
977983
for zp in range(self.max_z_prime):
978984
lattice_features.extend([
979985
f'aunit{zp} u', f'aunit{zp} v', f'aunit{zp} w',
980986
])
981987
for zp in range(self.max_z_prime):
982-
lattice_features.extend([
983-
f'x{zp}', f'y{zp}', f'z{zp}'
984-
])
988+
if space == 'latent':
989+
lattice_features.extend([
990+
f'θ{zp}', f'φ{zp}', f'r{zp}'
991+
])
992+
else:
993+
lattice_features.extend([
994+
f'x{zp}', f'y{zp}', f'z{zp}'
995+
])
985996
return lattice_features
986997

987998
def _set_cell_ranges(self, space, samples):
@@ -1225,7 +1236,7 @@ def plot_batch_cell_params(self, space='real',
12251236
print("Cell statistics only works for a batch of crystal data objects")
12261237
return None
12271238

1228-
lattice_features = self._build_feature_labels()
1239+
lattice_features = self._build_feature_labels(space=space)
12291240
samples = self._get_samples(space)
12301241
num_dists, dist_names, dists = self._collect_sample_dists(samples, ref_dist, quantiles, split_by_sg,
12311242
split_by_zp, aux_dists, override_energy)
@@ -1270,6 +1281,10 @@ def plot_batch_cell_params(self, space='real',
12701281
# showgrid=False, zeroline=False, ticks='outside',
12711282
# tickwidth=1, mirror=True
12721283
# )
1284+
fig.update_yaxes(
1285+
showgrid=False, zeroline=False, showticklabels=False, ticks='',
1286+
mirror=True
1287+
)
12731288
if len(dists) > 1:
12741289
fig.update_traces(opacity=0.5)
12751290
if show:
@@ -1290,7 +1305,7 @@ def plot_batch_staircase(self, space='real',
12901305
ref_dist=None,
12911306
):
12921307

1293-
labels = self._build_feature_labels()
1308+
labels = self._build_feature_labels(space=space)
12941309
samples = self._get_samples(space)
12951310
if torch.is_tensor(samples):
12961311
samples = samples.detach().cpu().numpy()

0 commit comments

Comments
 (0)