Skip to content

Commit 88d72e1

Browse files
authored
Adding Additional Functionality to Neural Quantum States (NQS) for Quantum Chemistry Repo (#49)
* Added implementations for transformer ansatz with minimal phase function and surrogate local energy Hamiltonian class. * Addressed requested changes for merging into Tangelo-Examples.
1 parent fdc8ab5 commit 88d72e1

12 files changed

Lines changed: 443 additions & 8 deletions

File tree

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
6+
from .base import Base
7+
8+
class AltTransformer(Base):
    '''
    A Transformer-based autoregressive NQS ansatz using the phase strategy from
    Bennewitz et al., where a single linear layer operates on the concatenated
    list of transformer hidden states in lieu of a separate phase network.
    '''
    def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, embedding_dim: int=16, nhead: int=2, dim_feedforward: int=64, num_layers: int=1, temperature: float=1.0, device: str=None, **kwargs):
        '''
        Builds the token/position embeddings, the transformer encoder stack, the
        modulus head, the phase head, and the fixed causal attention mask.

        Parent class args:
            num_sites: number of qubits in the ansatz system
            num_spin_up: total occupancy number of spin-up spin-orbitals
            num_spin_down: total occupancy number of spin-down spin-orbitals
            device: Device (CPU or Cuda) to store model
        Child class specific args:
            embedding_dim: dimension of transformer hidden states
            nhead: number of attention heads
            dim_feedforward: dimension of transformer feedforward layer
            num_layers: number of transformer blocks
            temperature: modulus network softmax temperature parameter
            device: device to store model on
        '''
        super(AltTransformer, self).__init__('AltTransformer', num_sites, num_spin_up, num_spin_down, device)

        # construct model
        self.num_in, self.num_out = num_sites, num_sites*2
        self.temperature = temperature
        # Sample function samples spatial orbitals in reverse order, but spin-up
        # orbitals are always sampled first. self.input_order holds this qubit
        # ordering for sampling, e.g. num_sites=6 -> [4, 5, 2, 3, 0, 1].
        self.input_order = np.stack([np.arange(self.num_sites-2,-1,-2), np.arange(self.num_sites-1,-1,-2)],1).reshape(-1) # [4,5,2,3,0,1]
        self.input_order = torch.Tensor(self.input_order).int().to(self.device)
        # Spatial-orbital ("shell") sampling order: highest shell first, e.g. [2, 1, 0].
        self.shell_order = torch.arange(self.num_sites//2-1, -1, -1) # [2,1,0]

        transformer_layer = nn.TransformerEncoderLayer(embedding_dim, nhead, dim_feedforward=dim_feedforward, dropout=0.0, batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers)
        # Modulus head: 4 logits per shell, one per two-spin-orbital occupancy state.
        self.fc = nn.Linear(embedding_dim, 4)
        # 5 tokens: the 4 shell occupancy states plus a begin-of-sequence token (index 4).
        self.tok_emb = nn.Embedding(5, embedding_dim)
        self.pos_emb = nn.Embedding(len(self.shell_order), embedding_dim)
        self.apply(self._init_weights)
        self.softmax = nn.Softmax(dim=-1)
        self.log_softmax = nn.LogSoftmax(dim=-1)

        # Phase head: a single linear layer over the concatenation of ALL hidden
        # states (Bennewitz et al. strategy) — hence in_features scales with the
        # number of shells, unlike the per-shell modulus head above.
        self.net_phase = nn.Linear(in_features=embedding_dim*len(self.shell_order), out_features=4, bias=True)

        # Additive causal attention mask: -inf above the diagonal blocks
        # attention from each shell to later shells.
        self.mask = torch.zeros((len(self.shell_order), len(self.shell_order))).to(self.device)
        for i in range(len(self.mask)):
            for j in range(len(self.mask)):
                if i < j:
                    self.mask[i][j] = float('-inf')

    def _init_weights(self, module: nn.Module):
        '''
        Performs weight initialization for each module in ansatz, dependent on module type
        (GPT-style: normal(0, 0.02) for Linear/Embedding weights, zeros for biases,
        identity-like init for LayerNorm).
        Args:
            module: module to be initialized
        '''
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x: torch.Tensor, sample_shell: int = 0) -> torch.Tensor:
        '''
        Forward function for Transformer ansatz (used for both sampling and training)
        Args:
            x: qubit spin configuration, shape [bs, num_sites]
            sample_shell: shell index provided during sampling to avoid extraneous forward passes
        Returns:
            prob_cond/log_psi: either conditional probabilities or logarithms of
            statevector entries, depending on whether self.sampling is set
            (self.sampling is toggled by sample(); presumably initialized by the
            Base class — TODO confirm).
        '''
        # x: [bs, num_sites]
        # Map the spin configuration to per-shell occupancy tokens (0..3), then
        # shift right by one position, prepending the begin-of-sequence token 4,
        # so each shell's prediction conditions only on previously sampled shells.
        shells = self.state2shell(x)[:, self.shell_order]
        input = 4*torch.ones(shells.shape, dtype=torch.int64).to(self.device)
        input[:,1:] = shells[:,:-1]
        pos = self.shell_order.to(self.device)

        input = self.tok_emb(input) + self.pos_emb(pos)
        # new x is of shape (batch_size, sequence_length, d_model)

        # Lazily migrate the mask if the model was moved after construction.
        if self.mask.device != self.device:
            self.mask = self.mask.to(self.device)
        # During sampling, only the prefix of the sequence needed for shell
        # `sample_shell` is run through the transformer (slice clamps to the full
        # length when sample_shell == 0).
        output = self.transformer(input[:,:(len(self.shell_order) - sample_shell + 1)], mask=self.mask[:(len(self.shell_order) - sample_shell + 1),:(len(self.shell_order) - sample_shell + 1)], is_causal=True)

        if not self.sampling:
            # Flatten all hidden states into one vector per configuration for
            # the phase head (only valid/needed on the full, non-truncated pass).
            phase_input = output.reshape(output.shape[0], -1)
        output = self.fc(output)

        # Pad truncated sampling passes back to the full shell count before
        # undoing the reversed shell ordering.
        if output.shape[1] < len(self.shell_order):
            new_output = torch.zeros(output.shape[0], len(self.shell_order), output.shape[2]).to(self.device)
            new_output[:,:output.shape[1],:] = output
            output = new_output[:, self.shell_order]
        else:
            output = output[:, self.shell_order]

        # NOTE(review): this condition is always true for physical (non-negative)
        # occupancies, so the particle-number constraint is effectively always applied.
        if (self.num_spin_up + self.num_spin_down) >= 0:
            logits_cls = self.apply_constraint(x, output)
        logits_cls /= self.temperature

        if self.sampling:
            prob_cond = self.softmax(logits_cls)
            return prob_cond
        else:
            # 0.5 factor: log|psi| is half the log-probability (|psi|^2 = prob).
            log_psi_cond = 0.5 * self.log_softmax(logits_cls)
            idx = self.state2shell(x)
            log_psi_real = log_psi_cond.gather(-1, idx.unsqueeze(-1)).sum(-1).sum(-1)
            # Phase is selected by the occupancy of the last shell in the ordering.
            log_psi_imag = self.net_phase(phase_input).gather(-1, idx[:, -1].unsqueeze(-1)).squeeze()
            if log_psi_real.shape[0] == 1:
                log_psi_imag = log_psi_imag.reshape(log_psi_real.shape)
            log_psi = torch.stack((log_psi_real, log_psi_imag), dim=-1)
            return log_psi

    def apply_constraint(self, inp: torch.Tensor, log_psi_cond: torch.Tensor) -> torch.Tensor:
        '''
        Applies constraints that enforce particle number and spin on ansatz network,
        by masking (with -inf, in place) any shell occupancy choice that would make
        the target spin-up/spin-down electron counts unreachable or exceeded.
        Args:
            inp: input spin configurations
            log_psi_cond: unconstrained ansatz outputs
        Returns:
            log_psi_cond: ansatz outputs with constraint applied (modified in place)
        '''
        # convert [|-1,-1>, |1,-1>, |-1,1>, |1,1>] to [0, 1, 2, 3]
        # (first ket slot = spin-up orbital, second = spin-down orbital)
        device = inp.device
        N = inp.shape[-1] // 2
        inp_up = inp[:, self.input_order][:, ::2].clone()
        inp_down = inp[:, self.input_order][:, 1::2].clone()
        # Running occupancy counts BEFORE each shell (hence the zero prepend / last drop).
        inp_cumsum_up = torch.cat((torch.zeros((inp_up.shape[0],1)).to(device), ((1 + inp_up)/2).cumsum(-1)[:, :-1]), axis=-1)
        inp_cumsum_down = torch.cat((torch.zeros((inp_down.shape[0],1)).to(device), ((1 + inp_down)/2).cumsum(-1)[:, :-1]), axis=-1)
        upper_bound_up = self.num_spin_up
        # If fewer shells remain than electrons still to place, occupation is forced.
        lower_bound_up = (self.num_spin_up - (N - torch.arange(1, N+1)))
        condition1_up = (inp_cumsum_up < lower_bound_up.to(device)).float()
        condition2_up = (inp_cumsum_up >= upper_bound_up).float()
        upper_bound_down = self.num_spin_down
        lower_bound_down = (self.num_spin_down - (N - torch.arange(1, N+1)))
        condition1_down = (inp_cumsum_down < lower_bound_down.to(device)).float()
        condition2_down = (inp_cumsum_down >= upper_bound_down).float()
        idx = torch.sort(self.shell_order)[1]
        # too few spin-up electrons remain placeable: spin-up orbital must be
        # occupied, so forbid states 0 and 2 (spin-up = -1)
        log_psi_cond[:,:,0].masked_fill_(condition1_up[:,idx]==1, float('-inf'))
        log_psi_cond[:,:,2].masked_fill_(condition1_up[:,idx]==1, float('-inf'))
        # spin-down orbital must be occupied: forbid states 0 and 1 (spin-down = -1)
        log_psi_cond[:,:,0].masked_fill_(condition1_down[:,idx]==1, float('-inf'))
        log_psi_cond[:,:,1].masked_fill_(condition1_down[:,idx]==1, float('-inf'))
        # spin-up quota already filled: forbid states 1 and 3 (spin-up = +1)
        log_psi_cond[:,:,1].masked_fill_(condition2_up[:,idx]==1, float('-inf'))
        log_psi_cond[:,:,3].masked_fill_(condition2_up[:,idx]==1, float('-inf'))
        # spin-down quota already filled: forbid states 2 and 3 (spin-down = +1)
        log_psi_cond[:,:,2].masked_fill_(condition2_down[:,idx]==1, float('-inf'))
        log_psi_cond[:,:,3].masked_fill_(condition2_down[:,idx]==1, float('-inf'))
        return log_psi_cond

    @torch.no_grad()
    def sample(self, bs: int, num_samples: int) -> list:
        '''
        Generates a set of samples from the ansatz state vector distribution,
        autoregressively one shell at a time, tracking unique configurations and
        their multiplicities rather than individual samples.
        Inputs:
            bs: total number of unique samples desired
            num_samples: total number of non-unique samples desired
        Returns:
            uniq_samples: unique spin sample set
            uniq_counts: tensor of count values (summing to num_samples) corresponding with uniq_samples
        '''
        self.eval()
        self.sampling = True
        sample_multinomial = True
        # random initialize a configuration of values +- 1
        uniq_samples = (torch.randn(1, self.num_sites).to(self.device) > 0.0).float() * 2 - 1
        uniq_count = torch.tensor([num_samples]).to(self.device)
        for i in self.shell_order:
            prob = self.forward(uniq_samples, i)[:, i] # num_uniq, 4
            num_uniq = uniq_samples.shape[0]
            # Branch every unique configuration into the 4 possible occupancies
            # of shell i; block k of the repeat gets occupancy state k.
            uniq_samples = uniq_samples.repeat(4,1) # 4*num_uniq, num_sites
            # convert [|-1,-1>, |1,-1>, |-1,1>, |1,1>] to [0, 1, 2, 3]
            uniq_samples[:num_uniq, 2*i] = -1
            uniq_samples[:num_uniq, 2*i+1] = -1
            uniq_samples[num_uniq:2*num_uniq, 2*i] = 1
            uniq_samples[num_uniq:2*num_uniq, 2*i+1] = -1
            uniq_samples[2*num_uniq:3*num_uniq, 2*i] = -1
            uniq_samples[2*num_uniq:3*num_uniq, 2*i+1] = 1
            uniq_samples[3*num_uniq:4*num_uniq, 2*i] = 1
            uniq_samples[3*num_uniq:4*num_uniq, 2*i+1] = 1
            if sample_multinomial:
                # Stochastically split each count across the 4 branches
                # (multinomial_arr presumably provided by Base — TODO confirm).
                uniq_count = torch.tensor(self.multinomial_arr(uniq_count.long().data.cpu().numpy(), prob.data.cpu().numpy())).T.flatten().to(prob.device)
            else:
                # Deterministic split: expected counts, rounded.
                uniq_count = (uniq_count.unsqueeze(-1)*prob).T.flatten().round()
            # Drop branches that received no samples.
            keep_idx = uniq_count > 0
            uniq_samples = uniq_samples[keep_idx]
            uniq_count = uniq_count[keep_idx]
            # Two-stage truncation by count keeps at most bs unique configurations.
            uniq_samples = uniq_samples[uniq_count.sort()[1][-2*bs:]]
            uniq_count = uniq_count[uniq_count.sort()[1][-2*bs:]]
            uniq_samples = uniq_samples[uniq_count.sort()[1][-bs:]]
            uniq_count = uniq_count[uniq_count.sort()[1][-bs:]]
        self.sampling = False
        self.train()
        return [uniq_samples, uniq_count]

examples/neural_quantum_states/src/models/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ class Base(nn.Module):
99
'''
1010
Base template for all autoregressive NQS ansatze.
1111
Args:
12+
name: name of specific model type
1213
num_sites: qubit number
1314
num_spin_up: number of spin up electrons
1415
num_spin_down: number of spin down electrons
1516
device: Device (CPU or Cuda) to store model
1617
**kwargs: nonspecific kwargs
1718
'''
18-
def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, device: str, **kwargs):
19+
def __init__(self, name: str, num_sites: int, num_spin_up: int, num_spin_down: int, device: str, **kwargs):
1920
super().__init__()
21+
self.name = name
2022
self.num_sites = num_sites
2123
self.num_spin_up = num_spin_up
2224
self.num_spin_down = num_spin_down
@@ -154,6 +156,9 @@ def get_model(model_name: str, device: str, print_model_info: bool, **kwargs) ->
154156
elif model_name == 'retnet':
155157
from .retnet import NNQSRetNet
156158
model = NNQSRetNet(**kwargs)
159+
elif model_name == 'alt_transformer':
160+
from .alt_transformer import AltTransformer
161+
model = AltTransformer(**kwargs)
157162
else:
158163
raise ValueError(f"Unknown model_name: {model_name}")
159164
if print_model_info:

examples/neural_quantum_states/src/models/made.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@
88
class MADE(Base):
99
'''
1010
Class implements MADE-based NQS ansatz
11+
Parent class args:
12+
num_sites: number of qubits in the ansatz system
13+
num_spin_up: total occupancy number of spin-up spin-orbitals
14+
num_spin_down: total occupancy number of spin-down spin-orbitals
15+
device: Device (CPU or Cuda) to store model
1116
Child class specific args:
1217
made_width: width of modulus and phase network hidden layers
1318
made_depth: number of hidden layers in modulus and phase networks
1419
temperature: Temperature variable for modulus softmax
1520
**kwargs: nonspecific kwargs
1621
'''
1722
def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, made_width: int=64, made_depth: int=2, temperature: float=1.0, device: str=None, **kwargs):
18-
super(MADE, self).__init__(num_sites, num_spin_up, num_spin_down, device)
23+
super(MADE, self).__init__('MADE', num_sites, num_spin_up, num_spin_down, device)
1924
self.temperature = temperature
2025
# construct model
2126
self.net = []

examples/neural_quantum_states/src/models/retnet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ class NNQSRetNet(Base):
1010
def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, made_width: int=64, made_depth: int=2, embedding_dim: int=16, nhead: int=2, dim_feedforward: int=64, num_layers: int=1, temperature: float=1.0, device: str=None, **kwargs):
1111
'''
1212
Retentive network (RetNet) NQS ansatz
13+
Parent class args:
14+
num_sites: number of qubits in the ansatz system
15+
num_spin_up: total occupancy number of spin-up spin-orbitals
16+
num_spin_down: total occupancy number of spin-down spin-orbitals
17+
device: Device (CPU or Cuda) to store model
1318
Child class specific args:
1419
made_width: width of phase network hidden layers
1520
made_depth: number of phase network hidden layers
@@ -20,7 +25,7 @@ def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, made_wi
2025
temperature: RetNet softmax temperature parameter
2126
device: device on which the model is stored
2227
'''
23-
super(NNQSRetNet, self).__init__(num_sites, num_spin_up, num_spin_down, device)
28+
super(NNQSRetNet, self).__init__('RetNet', num_sites, num_spin_up, num_spin_down, device)
2429

2530
# construct model
2631
self.num_in, self.num_out = num_sites, num_sites*2

examples/neural_quantum_states/src/models/transformer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ class NNQSTransformer(Base):
99
def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, made_width: int=64, made_depth: int=2, embedding_dim: int=16, nhead: int=2, dim_feedforward: int=64, num_layers: int=1, temperature: float=1.0, device: str=None, **kwargs):
1010
'''
1111
A Transformer-based autoregressive NQS Ansatz
12+
Parent class args:
13+
num_sites: number of qubits in the ansatz system
14+
num_spin_up: total occupancy number of spin-up spin-orbitals
15+
num_spin_down: total occupancy number of spin-down spin-orbitals
16+
device: Device (CPU or Cuda) to store model
1217
Child class specific args:
1318
made_width: width of phase network hidden layers
1419
made_depth: number of phase network hidden layers
@@ -19,7 +24,7 @@ def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, made_wi
1924
temperature: modulus network softmax temperature parameter
2025
device: device to store model on
2126
'''
22-
super(NNQSTransformer, self).__init__(num_sites, num_spin_up, num_spin_down, device)
27+
super(NNQSTransformer, self).__init__('NNQSTransformer', num_sites, num_spin_up, num_spin_down, device)
2328

2429
# construct model
2530
self.num_in, self.num_out = num_sites, num_sites*2

examples/neural_quantum_states/src/objective/adaptive_shadows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class AdaptiveShadows(Hamiltonian):
1919
flip_bs: Number of unique bit flip patterns processed at a time on each GPU
2020
'''
2121
def __init__(self, hamiltonian_string: str, num_sites: int, sample_count: int, total_unique_samples: int, reset_prob: float, flip_bs: int, **kwargs):
22-
super(AdaptiveShadows, self).__init__(hamiltonian_string, num_sites)
22+
super(AdaptiveShadows, self).__init__('adaptive_shadows', hamiltonian_string, num_sites)
2323
# product of identity operators by default, encoded as 0
2424
self.coefficients = torch.stack((self.coefficients.real, self.coefficients.imag), dim=-1)
2525
self.coefficients_square = norm_square(self.coefficients)

examples/neural_quantum_states/src/objective/automatic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, hamiltonian_string: str, num_sites: int, flip_bs: int, **kwar
1414
num_sites: qubit number of system
1515
flip_bs: largest batch size of model input tensors that each GPU is expected to handle at once
1616
'''
17-
super(Automatic, self).__init__(hamiltonian_string, num_sites)
17+
super(Automatic, self).__init__('automatic', hamiltonian_string, num_sites)
1818
# product of identity operators by default, encoded as 0
1919
self.coefficients = torch.stack((self.coefficients.real, self.coefficients.imag), dim=-1)
2020
self.flip_bs = flip_bs

examples/neural_quantum_states/src/objective/hamiltonian.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from src.complex import scalar_mult, real, imag
55

66
class Hamiltonian(nn.Module):
7-
def __init__(self, hamiltonian_string, num_sites):
7+
def __init__(self, choice, hamiltonian_string, num_sites):
88
super().__init__()
9+
self.name = choice
910
self.operators, self.coefficients = self.parse_hamiltonian_string(hamiltonian_string, num_sites)
1011
self.num_terms, self.input_dim = self.operators.shape
1112
print("Number of terms is {}.".format(self.num_terms))
@@ -83,6 +84,9 @@ def get_hamiltonian(hamiltonian_choice: str, hamiltonian_data: dict) -> nn.Modul
8384
elif hamiltonian_choice in ['exact']:
8485
from .automatic import Automatic
8586
return Automatic(**hamiltonian_data)
87+
elif hamiltonian_choice in ['surrogate']:
88+
from .surrogate import Surrogate
89+
return Surrogate(**hamiltonian_data)
8690
else:
8791
raise Exception('Hamiltonian choice not recognized!')
8892

examples/neural_quantum_states/src/objective/naive_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, hamiltonian_string: str, num_sites: int, sample_count: int, t
1010
'''
1111
A variation of the Automatic class that stochastically estimates the input Hamiltonian with Pauli strings sampled from the distribution proportional to the absolute values of the scalar coefficients (simple to construct because the Pauli string coefficients are real for Hamiltonians). This estimated Hamiltonian can be used to create local energy estimates during NQS training for (ideally) lower computational cost.
1212
'''
13-
super(NaiveSampler, self).__init__(hamiltonian_string, num_sites)
13+
super(NaiveSampler, self).__init__('naive_sampler', hamiltonian_string, num_sites)
1414
self.flip_bs = flip_bs
1515
# product of identity operators by default, encoded as 0
1616
self.coefficients = torch.stack((self.coefficients.real, self.coefficients.imag), dim=-1)

0 commit comments

Comments
 (0)