-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathextract.jl
More file actions
48 lines (47 loc) · 1.82 KB
/
extract.jl
File metadata and controls
48 lines (47 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Here extract_local_tensor and insert_local_tensor
# are essentially inverse operations, adapted for different kinds of
# algorithms and networks.
#
# In the simplest case, extract_local_tensor contracts together a few
# tensors of the network and returns the result, while
# insert_local_tensor takes that tensor, factorizes it back
# apart, and puts the pieces back into the network.
#
"""
    default_extracter(state, projected_operator, region, ortho; internal_kwargs)

Build the local tensor for `region` and reposition `projected_operator` on it.

`state` is first passed through `orthogonalize(state, ortho)`. Then:

- if `region isa AbstractEdge`, the tensor at `ortho` is split with `svd`; the `U`
  factor is stored back at `ortho` and the local tensor is `S * V`;
- otherwise the local tensor is the product of `state[v]` over every `v` in `region`.

Returns the tuple `(state, projected_operator, local_tensor)`, where `state` is the
result of `orthogonalize` (with the entry at `ortho` overwritten in the edge case) and
`projected_operator` is the result of `position(projected_operator, state, region)`.

`internal_kwargs` is a required keyword that this method accepts but does not read.
"""
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
    gauged_state = orthogonalize(state, ortho)
    local_tensor = if region isa AbstractEdge
        # `only` requires that exactly one vertex of the edge's support differs from `ortho`.
        neighbor = only(setdiff(support(region), [ortho]))
        # Indices of the `ortho` tensor that are not shared with the neighboring tensor.
        row_inds = uniqueinds(gauged_state[ortho], gauged_state[neighbor])
        edge_tags = tags(gauged_state, region)
        # TODO: replace with a call to factorize
        U, S, V = svd(
            gauged_state[ortho], row_inds; lefttags=edge_tags, righttags=edge_tags
        )
        gauged_state[ortho] = U
        S * V
    else
        prod(gauged_state[v] for v in region)
    end
    repositioned = position(projected_operator, gauged_state, region)
    return gauged_state, repositioned, local_tensor
end
"""
    extract_and_truncate(state, projected_operator, region, ortho;
                         maxdim=nothing, cutoff=nothing, internal_kwargs)

Build the local tensor for `region` and reposition `projected_operator` on it,
forwarding optional truncation keywords to `svd` in the edge case.

`state` is first passed through `orthogonalize(state, ortho)`. Then:

- if `region isa AbstractEdge`, the tensor at `ortho` is split with `svd`; the `U`
  factor is stored back at `ortho` and the local tensor is `S * V`. `maxdim` and
  `cutoff` are passed to `svd` only when they are not `nothing`;
- otherwise the local tensor is the product of `state[v]` over every `v` in `region`,
  and `maxdim`/`cutoff` are not used.

Returns the tuple `(state, projected_operator, local_tensor)`, where `state` is the
result of `orthogonalize` (with the entry at `ortho` overwritten in the edge case) and
`projected_operator` is the result of `position(projected_operator, state, region)`.

`internal_kwargs` is a required keyword that this method accepts but does not read.
"""
function extract_and_truncate(
    state,
    projected_operator,
    region,
    ortho;
    maxdim=nothing,
    cutoff=nothing,
    internal_kwargs,
)
    # Keep only the truncation settings the caller actually supplied, so that `svd`
    # falls back to its own defaults for the rest.
    svd_kwargs = merge(
        isnothing(maxdim) ? (;) : (; maxdim), isnothing(cutoff) ? (;) : (; cutoff)
    )
    gauged_state = orthogonalize(state, ortho)
    if region isa AbstractEdge
        # `only` requires that exactly one vertex of the edge's support differs from `ortho`.
        neighbor = only(setdiff(support(region), [ortho]))
        # Indices of the `ortho` tensor that are not shared with the neighboring tensor.
        row_inds = uniqueinds(gauged_state[ortho], gauged_state[neighbor])
        edge_tags = tags(gauged_state, region)
        # TODO: replace with a call to factorize
        U, S, V = svd(
            gauged_state[ortho],
            row_inds;
            lefttags=edge_tags,
            righttags=edge_tags,
            svd_kwargs...,
        )
        gauged_state[ortho] = U
        local_tensor = S * V
    else
        local_tensor = prod(gauged_state[v] for v in region)
    end
    repositioned = position(projected_operator, gauged_state, region)
    return gauged_state, repositioned, local_tensor
end