Skip to content

Commit e5b1bfd

Browse files
committed
Allow selection of PEPS type in unit cell random() function
1 parent 84e3477 commit e5b1bfd

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

varipeps/contractions/apply.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from typing import Sequence, List, Tuple, Dict, Union, Optional
1515

1616

17-
@partial(
18-
jax.jit, static_argnames=("name", "disable_identity_check")
19-
)
17+
@partial(jax.jit, static_argnames=("name", "disable_identity_check"))
2018
def apply_contraction(
2119
name: str,
2220
peps_tensors: Sequence[jnp.ndarray],

varipeps/peps/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from . import tensor
22
from . import unitcell
33
from .tensor import (
4+
PEPS_Type,
45
PEPS_Tensor,
56
PEPS_Tensor_Structure_Factor,
67
PEPS_Tensor_Split_Transfer,

varipeps/peps/tensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import collections
88
from dataclasses import dataclass
9+
from enum import Enum, IntEnum, auto, unique
910

1011
import numpy as np
1112
import jax
@@ -26,6 +27,15 @@
2627
)
2728

2829

30+
@unique
31+
class PEPS_Type(IntEnum):
32+
SQUARE = auto() #: Square-lattice based iPEPS state with full transfer tensors
33+
SQUARE_SPLIT = (
34+
auto()
35+
) #: Square-lattice based iPEPS state with split transfer tensors
36+
TRIANGULAR = auto() #: Triangular-lattice based iPEPS state
37+
38+
2939
@dataclass
3040
@register_pytree_node_class
3141
class PEPS_Tensor:
@@ -1172,6 +1182,10 @@ def load_from_group(cls: Type[T_PEPS_Tensor], grp: h5py.Group) -> T_PEPS_Tensor:
11721182
def is_split_transfer(self: T_PEPS_Tensor) -> bool:
11731183
return False
11741184

1185+
@property
1186+
def peps_type(self) -> PEPS_Type:
1187+
return PEPS_Type.SQUARE
1188+
11751189
def convert_to_split_transfer(
11761190
self: T_PEPS_Tensor, interlayer_chi: Optional[int] = None
11771191
) -> T_PEPS_Tensor_Split_Transfer:
@@ -2916,6 +2930,10 @@ def load_from_group(
29162930
def is_split_transfer(self: T_PEPS_Tensor_Split_Transfer) -> bool:
29172931
return True
29182932

2933+
@property
2934+
def peps_type(self) -> PEPS_Type:
2935+
return PEPS_Type.SQUARE_SPLIT
2936+
29192937
def convert_to_split_transfer(
29202938
self: T_PEPS_Tensor_Split_Transfer,
29212939
) -> T_PEPS_Tensor_Split_Transfer:
@@ -4043,6 +4061,10 @@ def is_split_transfer(self) -> bool:
40434061
def is_triangular_peps(self) -> bool:
40444062
return True
40454063

4064+
@property
4065+
def peps_type(self) -> PEPS_Type:
4066+
return PEPS_Type.TRIANGULAR
4067+
40464068
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
40474069
data = (
40484070
self.tensor,

varipeps/peps/unitcell.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
import jax.numpy as jnp
1818
from jax.tree_util import register_pytree_node_class
1919

20-
from .tensor import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Tensor_Triangular
20+
from .tensor import (
21+
PEPS_Tensor,
22+
PEPS_Tensor_Split_Transfer,
23+
PEPS_Tensor_Triangular,
24+
PEPS_Type,
25+
)
2126
import varipeps
2227
from varipeps.utils.random import PEPS_Random_Number_Generator
2328
from varipeps.utils.periodic_indices import calculate_periodic_indices
@@ -273,6 +278,7 @@ def random(
273278
chi: Union[int, Sequence[int]],
274279
dtype: Type[jnp.number],
275280
max_chi: Optional[int] = None,
281+
peps_type: PEPS_Type = PEPS_Type.SQUARE,
276282
*,
277283
seed: Optional[int] = None,
278284
destroy_random_state: bool = True,
@@ -320,7 +326,10 @@ def random(
320326
)
321327

322328
if isinstance(D, int):
323-
D = [(D, D, D, D) for _ in range(tensors_i.size)]
329+
if peps_type is PEPS_Type.TRIANGULAR:
330+
D = [(D, D, D, D, D, D) for _ in range(tensors_i.size)]
331+
else:
332+
D = [(D, D, D, D) for _ in range(tensors_i.size)]
324333

325334
if (
326335
not all(isinstance(j, int) for i in D for j in i)
@@ -344,19 +353,28 @@ def random(
344353

345354
peps_tensors = []
346355

356+
if peps_type is PEPS_Type.TRIANGULAR:
357+
peps_tensor_class = PEPS_Tensor_Triangular
358+
else:
359+
peps_tensor_class = PEPS_Tensor
360+
347361
for i in tensors_i:
348362
if i > 0:
349363
seed = None
350364

351365
peps_tensors.append(
352-
PEPS_Tensor.random(
366+
peps_tensor_class.random(
353367
d=d[i], D=D[i], chi=chi[i], dtype=dtype, seed=seed, max_chi=max_chi
354368
)
355369
)
356370

357371
data = cls.Unit_Cell_Data(peps_tensors=peps_tensors, structure=structure)
358372

359-
return cls(data=data)
373+
result = cls(data=data)
374+
375+
if peps_type is PEPS_Type.SQUARE_SPLIT:
376+
return result.convert_to_split_transfer()
377+
return result
360378

361379
def get_size(self) -> Tuple[int, int]:
362380
"""

0 commit comments

Comments
 (0)