Skip to content

Commit e492662

Browse files
committed
support mindspore and tensorflow neural solvers
1 parent 577e88d commit e492662

6 files changed

Lines changed: 1402 additions & 43 deletions

File tree

pygmtools/mindspore_backend.py

Lines changed: 325 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mindspore.ops import stop_gradient
66
import math
77

8+
import pygmtools
89
import inspect
910
import functools
1011
_max_signature = inspect.signature(mindspore.ops.max)
@@ -161,7 +162,8 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
161162
-float('inf'), dtype=log_s.dtype
162163
)), axis=1)
163164
for b in range(batch_size):
164-
log_s[b, int(ori_nrows[b]):int(nrows[b]), :int(ncols[b])] = -100
165+
if int(nrows[b]) > int(ori_nrows[b]):
166+
log_s[b, int(ori_nrows[b]):int(nrows[b]), :int(ncols[b])] = -100
165167

166168
# assign the unmatch weights
167169
if unmatchrows is not None and unmatchcols is not None:
@@ -238,7 +240,8 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
238240
if dummy_shape[1] > 0:
239241
ret_log_s = ret_log_s[:, :-dummy_shape[1]]
240242
for b in range(batch_size):
241-
ret_log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf')
243+
if int(nrows[b]) > int(ori_nrows[b]):
244+
ret_log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf')
242245

243246
if transposed_batch.any():
244247
s_t = ret_log_s.swapaxes(1, 2)
@@ -602,3 +605,323 @@ def _mm(input1, input2):
602605
mindspore implementation of _mm
603606
"""
604607
return mindspore.ops.matmul(input1, input2)
608+
609+
610+
############################################
611+
# Neural Network Solvers #
612+
############################################
613+
614+
from pygmtools.mindspore_modules import *
615+
616+
617+
class PCA_GM_Net(nn.Cell):
618+
"""
619+
MindSpore implementation of PCA-GM and IPCA-GM network.
620+
"""
621+
622+
def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_iter_num=-1):
623+
super().__init__()
624+
self.gnn_layer = num_layers
625+
self.gnn_layer_list = nn.CellList()
626+
self.affinity_list = nn.CellList()
627+
self.cross_graph_list = nn.CellList()
628+
for i in range(self.gnn_layer):
629+
if i == 0:
630+
gnn_layer = Siamese_Gconv(in_channel, hidden_channel)
631+
elif 0 < i < self.gnn_layer - 1:
632+
gnn_layer = Siamese_Gconv(hidden_channel, hidden_channel)
633+
else:
634+
gnn_layer = Siamese_Gconv(hidden_channel, out_channel)
635+
self.gnn_layer_list.append(gnn_layer)
636+
637+
if i == self.gnn_layer - 1:
638+
self.affinity_list.append(WeightedInnerProdAffinity(out_channel))
639+
elif i == self.gnn_layer - 2 and cross_iter_num <= 0:
640+
self.affinity_list.append(WeightedInnerProdAffinity(hidden_channel))
641+
else:
642+
self.affinity_list.append(Identity())
643+
644+
if i == self.gnn_layer - 2:
645+
self.cross_graph_list.append(nn.Dense(hidden_channel * 2, hidden_channel))
646+
else:
647+
self.cross_graph_list.append(Identity())
648+
649+
def construct(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_tau):
650+
_sinkhorn_func = functools.partial(
651+
sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
652+
)
653+
emb1, emb2 = feat1, feat2
654+
if cross_iter_num <= 0:
655+
for i in range(self.gnn_layer):
656+
emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])
657+
if i == self.gnn_layer - 2:
658+
s = self.affinity_list[i](emb1, emb2)
659+
s = _sinkhorn_func(s, n1, n2)
660+
cross_graph = self.cross_graph_list[i]
661+
new_emb1 = cross_graph(mindspore.ops.concat((emb1, mindspore.ops.BatchMatMul()(s, emb2)), axis=-1))
662+
new_emb2 = cross_graph(
663+
mindspore.ops.concat((emb2, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1)), axis=-1)
664+
)
665+
emb1, emb2 = new_emb1, new_emb2
666+
667+
s = self.affinity_list[self.gnn_layer - 1](emb1, emb2)
668+
s = _sinkhorn_func(s, n1, n2)
669+
else:
670+
for i in range(self.gnn_layer - 1):
671+
emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])
672+
673+
emb1_0, emb2_0 = emb1, emb2
674+
s = mindspore.ops.zeros((emb1.shape[0], emb1.shape[1], emb2.shape[1]), emb1.dtype)
675+
for _ in range(cross_iter_num):
676+
i = self.gnn_layer - 2
677+
cross_graph = self.cross_graph_list[i]
678+
emb1 = cross_graph(mindspore.ops.concat((emb1_0, mindspore.ops.BatchMatMul()(s, emb2_0)), axis=-1))
679+
emb2 = cross_graph(
680+
mindspore.ops.concat((emb2_0, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1_0)), axis=-1)
681+
)
682+
i = self.gnn_layer - 1
683+
emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])
684+
s = self.affinity_list[i](emb1, emb2)
685+
s = _sinkhorn_func(s, n1, n2)
686+
return s
687+
688+
689+
class CIE_Net(nn.Cell):
690+
"""
691+
MindSpore implementation of CIE graph matching network.
692+
"""
693+
694+
def __init__(self, in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers):
695+
super().__init__()
696+
self.gnn_layer = num_layers
697+
self.gnn_layer_list = nn.CellList()
698+
self.affinity_list = nn.CellList()
699+
self.cross_graph_list = nn.CellList()
700+
for i in range(self.gnn_layer):
701+
if i == 0:
702+
gnn_layer = Siamese_ChannelIndependentConv(in_node_channel, hidden_channel, in_edge_channel)
703+
elif 0 < i < self.gnn_layer - 1:
704+
gnn_layer = Siamese_ChannelIndependentConv(hidden_channel, hidden_channel, hidden_channel)
705+
else:
706+
gnn_layer = Siamese_ChannelIndependentConv(hidden_channel, out_channel, hidden_channel)
707+
self.gnn_layer_list.append(gnn_layer)
708+
709+
if i == self.gnn_layer - 1:
710+
self.affinity_list.append(WeightedInnerProdAffinity(out_channel))
711+
elif i == self.gnn_layer - 2:
712+
self.affinity_list.append(WeightedInnerProdAffinity(hidden_channel))
713+
else:
714+
self.affinity_list.append(Identity())
715+
716+
if i == self.gnn_layer - 2:
717+
self.cross_graph_list.append(nn.Dense(hidden_channel * 2, hidden_channel))
718+
else:
719+
self.cross_graph_list.append(Identity())
720+
721+
def construct(self, feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2, sk_max_iter, sk_tau):
722+
_sinkhorn_func = functools.partial(
723+
sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
724+
)
725+
emb1, emb2 = feat_node1, feat_node2
726+
emb_edge1, emb_edge2 = feat_edge1, feat_edge2
727+
for i in range(self.gnn_layer):
728+
emb1, emb2, emb_edge1, emb_edge2 = self.gnn_layer_list[i]([A1, emb1, emb_edge1], [A2, emb2, emb_edge2])
729+
if i == self.gnn_layer - 2:
730+
s = self.affinity_list[i](emb1, emb2)
731+
s = _sinkhorn_func(s, n1, n2)
732+
cross_graph = self.cross_graph_list[i]
733+
new_emb1 = cross_graph(mindspore.ops.concat((emb1, mindspore.ops.BatchMatMul()(s, emb2)), axis=-1))
734+
new_emb2 = cross_graph(
735+
mindspore.ops.concat((emb2, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1)), axis=-1)
736+
)
737+
emb1, emb2 = new_emb1, new_emb2
738+
739+
s = self.affinity_list[self.gnn_layer - 1](emb1, emb2)
740+
s = _sinkhorn_func(s, n1, n2)
741+
return s
742+
743+
744+
class NGM_Net(nn.Cell):
745+
"""
746+
MindSpore implementation of NGM network.
747+
"""
748+
749+
def __init__(self, gnn_channels, sk_emb):
750+
super().__init__()
751+
self.gnn_layer = len(gnn_channels)
752+
self.gnn_layer_list = nn.CellList()
753+
for i in range(self.gnn_layer):
754+
if i == 0:
755+
gnn_layer = NGMConvLayer(1, 1, gnn_channels[i] + sk_emb, gnn_channels[i], sk_channel=sk_emb)
756+
else:
757+
gnn_layer = NGMConvLayer(
758+
gnn_channels[i - 1] + sk_emb, gnn_channels[i - 1],
759+
gnn_channels[i] + sk_emb, gnn_channels[i], sk_channel=sk_emb
760+
)
761+
self.gnn_layer_list.append(gnn_layer)
762+
self.classifier = nn.Dense(gnn_channels[-1] + sk_emb, 1)
763+
764+
def construct(self, K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau):
765+
_sinkhorn_func = functools.partial(
766+
sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
767+
)
768+
emb = v0
769+
A = (K != 0).astype(K.dtype)
770+
emb_K = mindspore.ops.expand_dims(K, axis=-1)
771+
for i in range(self.gnn_layer):
772+
emb_K, emb = self.gnn_layer_list[i](A, emb_K, emb, n1, n2, sk_func=_sinkhorn_func)
773+
v = self.classifier(emb)
774+
s = v.reshape((v.shape[0], int(n2max), -1)).swapaxes(1, 2)
775+
return _sinkhorn_func(s, n1, n2, dummy_row=True)
776+
777+
778+
pca_gm_pretrain_path = {
779+
'voc': ('pca_gm_voc_mindspore.ckpt',
780+
['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc_mindspore.ckpt'],
781+
'e49379424ffea6759526bd9d436a0dcf'),
782+
'willow': ('pca_gm_willow_mindspore.ckpt',
783+
['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_willow_mindspore.ckpt'],
784+
'3ed733a9b04e2a83142b08db2b9952cf'),
785+
'voc-all': ('pca_gm_voc-all_mindspore.ckpt',
786+
['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc-all_mindspore.ckpt'],
787+
'28d40ccdc8bc743d2eca459758d887a2'),
788+
}
789+
790+
ipca_gm_pretrain_path = {
791+
'voc': ('ipca_gm_voc_mindspore.ckpt',
792+
['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_voc_mindspore.ckpt'],
793+
'c9f888eefbc22684317f5deedf175da7'),
794+
'willow': ('ipca_gm_willow_mindspore.ckpt',
795+
['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_willow_mindspore.ckpt'],
796+
'8d25c7bc7d7350467e07e53d1d004d63'),
797+
}
798+
799+
cie_pretrain_path = {
800+
'voc': ('cie_voc_mindspore.ckpt',
801+
['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_voc_mindspore.ckpt'],
802+
'be80f98e26af89a68421286f60d544f7'),
803+
'willow': ('cie_willow_mindspore.ckpt',
804+
['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_willow_mindspore.ckpt'],
805+
'b295e0bfa7367e9a2830b0ec25a99220'),
806+
}
807+
808+
ngm_pretrain_path = {
809+
'voc': ('ngm_voc_mindspore.ckpt',
810+
['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_voc_mindspore.ckpt'],
811+
'afa3d94ac9685dba82629e9ef79b19cf'),
812+
'willow': ('ngm_willow_mindspore.ckpt',
813+
['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_willow_mindspore.ckpt'],
814+
'877aad75a62ad6cddbd45a0f4ece3790'),
815+
}
816+
817+
818+
def _save_model(model, path):
819+
mindspore.save_checkpoint(model, path)
820+
821+
822+
def _load_model(model, path, strict=True):
823+
param_dict = mindspore.load_checkpoint(path)
824+
try:
825+
param_not_load, ckpt_not_load = mindspore.load_param_into_net(model, param_dict, strict_load=strict)
826+
except TypeError:
827+
param_not_load, ckpt_not_load = mindspore.load_param_into_net(model, param_dict)
828+
if len(ckpt_not_load) > 0:
829+
print('Warning: Unexpected key(s) in state_dict: {}. '.format(
830+
', '.join('"{}"'.format(k) for k in ckpt_not_load)))
831+
if len(param_not_load) > 0:
832+
print('Warning: Missing key(s) in state_dict: {}. '.format(
833+
', '.join('"{}"'.format(k) for k in param_not_load)))
834+
835+
836+
def _get_pretrain_file(pretrain, pretrain_path):
837+
if pretrain in pretrain_path:
838+
filename, url, md5 = pretrain_path[pretrain]
839+
return pygmtools.utils.download(filename, url, md5)
840+
raise ValueError(f'Unknown pretrain tag. Available tags: {pretrain_path.keys()}')
841+
842+
843+
def pca_gm(feat1, feat2, A1, A2, n1, n2,
844+
in_channel, hidden_channel, out_channel, num_layers, sk_max_iter, sk_tau,
845+
network, pretrain):
846+
"""
847+
MindSpore implementation of PCA-GM.
848+
"""
849+
forward_pass = feat1 is not None
850+
if network is None:
851+
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers)
852+
if pretrain:
853+
_load_model(network, _get_pretrain_file(pretrain, pca_gm_pretrain_path))
854+
if forward_pass:
855+
batch_size = feat1.shape[0]
856+
if n1 is None:
857+
n1 = mindspore.Tensor([feat1.shape[1]] * batch_size, dtype=mindspore.int32)
858+
if n2 is None:
859+
n2 = mindspore.Tensor([feat2.shape[1]] * batch_size, dtype=mindspore.int32)
860+
result = network(feat1, feat2, A1, A2, n1, n2, -1, sk_max_iter, sk_tau)
861+
else:
862+
result = None
863+
return result, network
864+
865+
866+
def ipca_gm(feat1, feat2, A1, A2, n1, n2,
867+
in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau,
868+
network, pretrain):
869+
"""
870+
MindSpore implementation of IPCA-GM.
871+
"""
872+
forward_pass = feat1 is not None
873+
if network is None:
874+
network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers, cross_iter)
875+
if pretrain:
876+
_load_model(network, _get_pretrain_file(pretrain, ipca_gm_pretrain_path))
877+
if forward_pass:
878+
batch_size = feat1.shape[0]
879+
if n1 is None:
880+
n1 = mindspore.Tensor([feat1.shape[1]] * batch_size, dtype=mindspore.int32)
881+
if n2 is None:
882+
n2 = mindspore.Tensor([feat2.shape[1]] * batch_size, dtype=mindspore.int32)
883+
result = network(feat1, feat2, A1, A2, n1, n2, cross_iter, sk_max_iter, sk_tau)
884+
else:
885+
result = None
886+
return result, network
887+
888+
889+
def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2,
890+
in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers, sk_max_iter, sk_tau,
891+
network, pretrain):
892+
"""
893+
MindSpore implementation of CIE.
894+
"""
895+
forward_pass = feat_node1 is not None
896+
if network is None:
897+
network = CIE_Net(in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers)
898+
if pretrain:
899+
_load_model(network, _get_pretrain_file(pretrain, cie_pretrain_path))
900+
if forward_pass:
901+
batch_size = feat_node1.shape[0]
902+
if n1 is None:
903+
n1 = mindspore.Tensor([feat_node1.shape[1]] * batch_size, dtype=mindspore.int32)
904+
if n2 is None:
905+
n2 = mindspore.Tensor([feat_node1.shape[1]] * batch_size, dtype=mindspore.int32)
906+
result = network(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2, sk_max_iter, sk_tau)
907+
else:
908+
result = None
909+
return result, network
910+
911+
912+
def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau, network, return_network, pretrain):
913+
"""
914+
MindSpore implementation of NGM.
915+
"""
916+
forward_pass = K is not None
917+
if network is None:
918+
network = NGM_Net(gnn_channels, sk_emb)
919+
if pretrain:
920+
_load_model(network, _get_pretrain_file(pretrain, ngm_pretrain_path))
921+
if forward_pass:
922+
batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
923+
v0 = v0 / mindspore.ops.mean(v0)
924+
result = network(K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau)
925+
else:
926+
result = None
927+
return result, network

0 commit comments

Comments
 (0)