import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import Base

class AltTransformer(Base):
    def __init__(self, num_sites: int, num_spin_up: int, num_spin_down: int, embedding_dim: int = 16, nhead: int = 2, dim_feedforward: int = 64, num_layers: int = 1, temperature: float = 1.0, device: str = None, **kwargs):
        '''
        A Transformer-based autoregressive NQS ansatz using the phase strategy from Bennewitz et al., where a single linear layer operates on the concatenated list of transformer hidden states in lieu of a separate phase network.
        Parent class args:
            num_sites: number of qubits in the ansatz system
            num_spin_up: total occupancy number of spin-up spin-orbitals
            num_spin_down: total occupancy number of spin-down spin-orbitals
            device: device (CPU or CUDA) to store the model on
        Child class specific args:
            embedding_dim: dimension of transformer hidden states
            nhead: number of attention heads
            dim_feedforward: dimension of transformer feedforward layer
            num_layers: number of transformer blocks
            temperature: modulus network softmax temperature parameter
        '''
        super(AltTransformer, self).__init__('AltTransformer', num_sites, num_spin_up, num_spin_down, device)

        # construct model
        self.num_in, self.num_out = num_sites, num_sites*2
        self.temperature = temperature
        # The sample function visits spatial orbitals in reverse order, with the spin-up qubit of each orbital always first. self.input_order records this qubit ordering, e.g. [4,5,2,3,0,1] for num_sites=6.
        self.input_order = np.stack([np.arange(self.num_sites-2, -1, -2), np.arange(self.num_sites-1, -1, -2)], 1).reshape(-1)
        self.input_order = torch.from_numpy(self.input_order).long().to(self.device)
        # Spatial orbital (shell) sampling order, e.g. [2,1,0] for num_sites=6
        self.shell_order = torch.arange(self.num_sites//2 - 1, -1, -1)

        transformer_layer = nn.TransformerEncoderLayer(embedding_dim, nhead, dim_feedforward=dim_feedforward, dropout=0.0, batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers)
        self.fc = nn.Linear(embedding_dim, 4)
        # 5 tokens: the 4 two-qubit shell states plus a start-of-sequence token (index 4)
        self.tok_emb = nn.Embedding(5, embedding_dim)
        self.pos_emb = nn.Embedding(len(self.shell_order), embedding_dim)
        self.apply(self._init_weights)
        self.softmax = nn.Softmax(dim=-1)
        self.log_softmax = nn.LogSoftmax(dim=-1)

        # A single linear layer over the concatenated hidden states produces the phase logits
        self.net_phase = nn.Linear(in_features=embedding_dim*len(self.shell_order), out_features=4, bias=True)

        # Causal attention mask: entries with i < j (future shells) are set to -inf
        self.mask = torch.triu(torch.full((len(self.shell_order), len(self.shell_order)), float('-inf')), diagonal=1).to(self.device)
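        # For example, with three shells the mask reads
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so the token at sequence position t attends only to positions <= t.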

    def _init_weights(self, module: nn.Module):
        '''
        Performs weight initialization for each module in the ansatz, dependent on module type
        Args:
            module: module to be initialized
        '''
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x: torch.Tensor, sample_shell: int = 0) -> torch.Tensor:
        '''
        Forward function for the Transformer ansatz (used for both sampling and training)
        Args:
            x: qubit spin configuration
            sample_shell: shell index provided during sampling to avoid extraneous forward passes
        Returns:
            prob_cond/log_psi: either conditional probabilities or logarithms of statevector entries, depending on whether the model is sampling
        '''
        # x: [bs, num_sites]
        shells = self.state2shell(x)[:, self.shell_order]
        # Start token (index 4) followed by the preceding shells (teacher forcing)
        inp = 4*torch.ones(shells.shape, dtype=torch.int64).to(self.device)
        inp[:, 1:] = shells[:, :-1]
        pos = self.shell_order.to(self.device)

        inp = self.tok_emb(inp) + self.pos_emb(pos)
        # inp is now of shape (batch_size, sequence_length, d_model)

        if self.mask.device != self.device:
            self.mask = self.mask.to(self.device)
        # During sampling, process only the sequence prefix needed for the current shell
        seq_len = len(self.shell_order) - sample_shell + 1
        output = self.transformer(inp[:, :seq_len], mask=self.mask[:seq_len, :seq_len], is_causal=True)

        if not self.sampling:
            phase_input = output.reshape(output.shape[0], -1)
        output = self.fc(output)

        if output.shape[1] < len(self.shell_order):
            # The sequence was truncated during sampling; pad back to full length
            new_output = torch.zeros(output.shape[0], len(self.shell_order), output.shape[2]).to(self.device)
            new_output[:, :output.shape[1], :] = output
            # reorder from sampling order back to shell-index order
            output = new_output[:, self.shell_order]
        else:
            output = output[:, self.shell_order]

        if (self.num_spin_up + self.num_spin_down) >= 0:
            logits_cls = self.apply_constraint(x, output)
            logits_cls /= self.temperature

        if self.sampling:
            prob_cond = self.softmax(logits_cls)
            return prob_cond
        else:
            # log-amplitude carries a 0.5 factor since |psi| = sqrt(p)
            log_psi_cond = 0.5 * self.log_softmax(logits_cls)
            idx = self.state2shell(x)
            log_psi_real = log_psi_cond.gather(-1, idx.unsqueeze(-1)).sum(-1).sum(-1)
            # phase: select the phase-head logit indexed by the final (highest-index) shell's state
            log_psi_imag = self.net_phase(phase_input).gather(-1, idx[:, -1].unsqueeze(-1)).squeeze()
            if log_psi_real.shape[0] == 1:
                log_psi_imag = log_psi_imag.reshape(log_psi_real.shape)
            log_psi = torch.stack((log_psi_real, log_psi_imag), dim=-1)
            return log_psi

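    # Output convention (as read from the code above, not an authoritative spec):
    # in training mode, forward returns log_psi of shape [bs, 2], where
    # log_psi[:, 0] is the log-amplitude (0.5 * log of the autoregressive
    # probability) and log_psi[:, 1] is the phase, so psi = exp(log_psi[:, 0]
    # + 1j*log_psi[:, 1]). In sampling mode it instead returns the per-shell
    # conditional probabilities of shape [bs, n_shells, 4].
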
    def apply_constraint(self, inp: torch.Tensor, log_psi_cond: torch.Tensor) -> torch.Tensor:
        '''
        Applies the constraints that enforce particle number and spin on the ansatz network
        Args:
            inp: input spin configurations
            log_psi_cond: unconstrained ansatz output logits
        Returns:
            log_psi_cond: ansatz outputs with constraints applied
        '''
        # convert [|-1,-1>, |1,-1>, |-1,1>, |1,1>] to [0, 1, 2, 3]
        device = inp.device
        N = inp.shape[-1] // 2
        inp_up = inp[:, self.input_order][:, ::2].clone()
        inp_down = inp[:, self.input_order][:, 1::2].clone()
        # shifted cumulative occupation counts (in sampling order), mapping +-1 to {0,1}
        inp_cumsum_up = torch.cat((torch.zeros((inp_up.shape[0], 1)).to(device), ((1 + inp_up)/2).cumsum(-1)[:, :-1]), axis=-1)
        inp_cumsum_down = torch.cat((torch.zeros((inp_down.shape[0], 1)).to(device), ((1 + inp_down)/2).cumsum(-1)[:, :-1]), axis=-1)
        upper_bound_up = self.num_spin_up
        lower_bound_up = (self.num_spin_up - (N - torch.arange(1, N+1)))
        condition1_up = (inp_cumsum_up < lower_bound_up.to(device)).float()  # remaining shells cannot reach the spin-up quota unless this one is up
        condition2_up = (inp_cumsum_up >= upper_bound_up).float()  # spin-up quota already filled
        upper_bound_down = self.num_spin_down
        lower_bound_down = (self.num_spin_down - (N - torch.arange(1, N+1)))
        condition1_down = (inp_cumsum_down < lower_bound_down.to(device)).float()
        condition2_down = (inp_cumsum_down >= upper_bound_down).float()
        idx = torch.sort(self.shell_order)[1]
        # first entry (spin-up qubit) must be up: forbid states 0 and 2
        log_psi_cond[:, :, 0].masked_fill_(condition1_up[:, idx] == 1, float('-inf'))
        log_psi_cond[:, :, 2].masked_fill_(condition1_up[:, idx] == 1, float('-inf'))
        # second entry (spin-down qubit) must be up: forbid states 0 and 1
        log_psi_cond[:, :, 0].masked_fill_(condition1_down[:, idx] == 1, float('-inf'))
        log_psi_cond[:, :, 1].masked_fill_(condition1_down[:, idx] == 1, float('-inf'))
        # first entry must be down: forbid states 1 and 3
        log_psi_cond[:, :, 1].masked_fill_(condition2_up[:, idx] == 1, float('-inf'))
        log_psi_cond[:, :, 3].masked_fill_(condition2_up[:, idx] == 1, float('-inf'))
        # second entry must be down: forbid states 2 and 3
        log_psi_cond[:, :, 2].masked_fill_(condition2_down[:, idx] == 1, float('-inf'))
        log_psi_cond[:, :, 3].masked_fill_(condition2_down[:, idx] == 1, float('-inf'))
        return log_psi_cond

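    # Worked example of the constraint (hypothetical values: num_sites=6,
    # num_spin_up = num_spin_down = 1): once one shell has been sampled as
    # |1,1>, both spin quotas are filled, so condition2_up masks states 1 and 3
    # and condition2_down masks states 2 and 3 for all later shells, leaving
    # only state 0 (|-1,-1>) with finite probability.
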
    @torch.no_grad()
    def sample(self, bs: int, num_samples: int) -> list[torch.Tensor]:
        '''
        Generates a set of samples from the ansatz state vector distribution
        Args:
            bs: maximum number of unique samples to keep
            num_samples: total number of (non-unique) samples desired
        Returns:
            uniq_samples: unique spin sample set
            uniq_count: tensor of count values (summing to num_samples) corresponding to uniq_samples
        '''
        self.eval()
        self.sampling = True
        sample_multinomial = True
        # randomly initialize a configuration of +-1 values
        uniq_samples = (torch.randn(1, self.num_sites).to(self.device) > 0.0).float() * 2 - 1
        uniq_count = torch.tensor([num_samples]).to(self.device)
        for i in self.shell_order:
            prob = self.forward(uniq_samples, i)[:, i]  # [num_uniq, 4]
            num_uniq = uniq_samples.shape[0]
            # branch every unique sample into the 4 possible states of shell i
            uniq_samples = uniq_samples.repeat(4, 1)  # [4*num_uniq, num_sites]
            # convert [|-1,-1>, |1,-1>, |-1,1>, |1,1>] to [0, 1, 2, 3]
            uniq_samples[:num_uniq, 2*i] = -1
            uniq_samples[:num_uniq, 2*i+1] = -1
            uniq_samples[num_uniq:2*num_uniq, 2*i] = 1
            uniq_samples[num_uniq:2*num_uniq, 2*i+1] = -1
            uniq_samples[2*num_uniq:3*num_uniq, 2*i] = -1
            uniq_samples[2*num_uniq:3*num_uniq, 2*i+1] = 1
            uniq_samples[3*num_uniq:4*num_uniq, 2*i] = 1
            uniq_samples[3*num_uniq:4*num_uniq, 2*i+1] = 1
            if sample_multinomial:
                # multinomially distribute each parent's count over its 4 branches
                uniq_count = torch.tensor(self.multinomial_arr(uniq_count.long().data.cpu().numpy(), prob.data.cpu().numpy())).T.flatten().to(prob.device)
            else:
                uniq_count = (uniq_count.unsqueeze(-1)*prob).T.flatten().round()
            # drop branches that received no samples, then keep the bs most frequent
            keep_idx = uniq_count > 0
            uniq_samples = uniq_samples[keep_idx]
            uniq_count = uniq_count[keep_idx]
            top_idx = uniq_count.sort()[1][-bs:]
            uniq_samples = uniq_samples[top_idx]
            uniq_count = uniq_count[top_idx]
        self.sampling = False
        self.train()
        return [uniq_samples, uniq_count]
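
# --- Usage sketch (illustrative; not part of the original module) ---
# Assumes this file lives in a package next to base.py, and that Base provides
# the state2shell and multinomial_arr helpers used above; the argument values
# below are hypothetical.
#
#   model = AltTransformer(num_sites=6, num_spin_up=1, num_spin_down=1,
#                          embedding_dim=16, nhead=2, device='cpu')
#   samples, counts = model.sample(bs=64, num_samples=10000)  # unique configs + counts
#   log_psi = model(samples)  # [n_uniq, 2]: (log-amplitude, phase) per config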