|
5 | 5 | from mindspore.ops import stop_gradient |
6 | 6 | import math |
7 | 7 |
|
| 8 | +import pygmtools |
8 | 9 | import inspect |
9 | 10 | import functools |
10 | 11 | _max_signature = inspect.signature(mindspore.ops.max) |
@@ -161,7 +162,8 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo |
161 | 162 | -float('inf'), dtype=log_s.dtype |
162 | 163 | )), axis=1) |
163 | 164 | for b in range(batch_size): |
164 | | - log_s[b, int(ori_nrows[b]):int(nrows[b]), :int(ncols[b])] = -100 |
| 165 | + if int(nrows[b]) > int(ori_nrows[b]): |
| 166 | + log_s[b, int(ori_nrows[b]):int(nrows[b]), :int(ncols[b])] = -100 |
165 | 167 |
|
166 | 168 | # assign the unmatch weights |
167 | 169 | if unmatchrows is not None and unmatchcols is not None: |
@@ -238,7 +240,8 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo |
238 | 240 | if dummy_shape[1] > 0: |
239 | 241 | ret_log_s = ret_log_s[:, :-dummy_shape[1]] |
240 | 242 | for b in range(batch_size): |
241 | | - ret_log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf') |
| 243 | + if int(nrows[b]) > int(ori_nrows[b]): |
| 244 | + ret_log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf') |
242 | 245 |
|
243 | 246 | if transposed_batch.any(): |
244 | 247 | s_t = ret_log_s.swapaxes(1, 2) |
@@ -602,3 +605,323 @@ def _mm(input1, input2): |
602 | 605 | mindspore implementation of _mm |
603 | 606 | """ |
604 | 607 | return mindspore.ops.matmul(input1, input2) |
| 608 | + |
| 609 | + |
| 610 | +############################################ |
| 611 | +# Neural Network Solvers # |
| 612 | +############################################ |
| 613 | + |
| 614 | +from pygmtools.mindspore_modules import * |
| 615 | + |
| 616 | + |
class PCA_GM_Net(nn.Cell):
    """
    MindSpore implementation of PCA-GM and IPCA-GM network.

    Stacks ``num_layers`` siamese graph-convolution layers with per-layer
    affinity modules and a single cross-graph embedding-update layer placed
    before the last GNN layer.  ``cross_iter_num > 0`` at construction time
    selects the iterative (IPCA-GM) variant handled in ``construct``.
    """

    def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_iter_num=-1):
        super().__init__()
        self.gnn_layer = num_layers  # number of stacked siamese GNN layers
        self.gnn_layer_list = nn.CellList()
        self.affinity_list = nn.CellList()
        self.cross_graph_list = nn.CellList()
        for i in range(self.gnn_layer):
            # Channel plan: in -> hidden -> ... -> hidden -> out
            if i == 0:
                gnn_layer = Siamese_Gconv(in_channel, hidden_channel)
            elif 0 < i < self.gnn_layer - 1:
                gnn_layer = Siamese_Gconv(hidden_channel, hidden_channel)
            else:
                gnn_layer = Siamese_Gconv(hidden_channel, out_channel)
            self.gnn_layer_list.append(gnn_layer)

            # A real affinity module is needed after the last layer always,
            # and after the second-to-last layer only for the non-iterative
            # (PCA-GM) variant; Identity placeholders keep list indices
            # aligned with the GNN layers.
            if i == self.gnn_layer - 1:
                self.affinity_list.append(WeightedInnerProdAffinity(out_channel))
            elif i == self.gnn_layer - 2 and cross_iter_num <= 0:
                self.affinity_list.append(WeightedInnerProdAffinity(hidden_channel))
            else:
                self.affinity_list.append(Identity())

            # Cross-graph update (concat own embedding with the soft-matched
            # peer embedding, then project back to hidden_channel) only
            # exists before the last GNN layer.
            if i == self.gnn_layer - 2:
                self.cross_graph_list.append(nn.Dense(hidden_channel * 2, hidden_channel))
            else:
                self.cross_graph_list.append(Identity())

    def construct(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_tau):
        """
        Compute a Sinkhorn-normalized soft matching matrix between two graphs.

        :param feat1: node features of graph 1
        :param feat2: node features of graph 2
        :param A1: adjacency matrix of graph 1
        :param A2: adjacency matrix of graph 2
        :param n1: numbers of valid nodes per batch item in graph 1
        :param n2: numbers of valid nodes per batch item in graph 2
        :param cross_iter_num: <= 0 runs the PCA-GM path; > 0 runs that many
            IPCA-GM cross-graph refinement iterations
        :param sk_max_iter: max Sinkhorn iterations
        :param sk_tau: Sinkhorn temperature
        :return: matching score tensor after the final Sinkhorn step
        """
        _sinkhorn_func = functools.partial(
            sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
        )
        emb1, emb2 = feat1, feat2
        if cross_iter_num <= 0:
            # PCA-GM: one cross-graph update before the final GNN layer.
            for i in range(self.gnn_layer):
                emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])
                if i == self.gnn_layer - 2:
                    s = self.affinity_list[i](emb1, emb2)
                    s = _sinkhorn_func(s, n1, n2)
                    cross_graph = self.cross_graph_list[i]
                    new_emb1 = cross_graph(mindspore.ops.concat((emb1, mindspore.ops.BatchMatMul()(s, emb2)), axis=-1))
                    new_emb2 = cross_graph(
                        mindspore.ops.concat((emb2, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1)), axis=-1)
                    )
                    emb1, emb2 = new_emb1, new_emb2

            s = self.affinity_list[self.gnn_layer - 1](emb1, emb2)
            s = _sinkhorn_func(s, n1, n2)
        else:
            # IPCA-GM: run all but the last GNN layer once, then iterate the
            # cross-graph update + last layer, starting from s = 0.
            for i in range(self.gnn_layer - 1):
                emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])

            emb1_0, emb2_0 = emb1, emb2
            s = mindspore.ops.zeros((emb1.shape[0], emb1.shape[1], emb2.shape[1]), emb1.dtype)
            for _ in range(cross_iter_num):
                i = self.gnn_layer - 2
                cross_graph = self.cross_graph_list[i]
                # Each iteration restarts from the pre-cross embeddings.
                emb1 = cross_graph(mindspore.ops.concat((emb1_0, mindspore.ops.BatchMatMul()(s, emb2_0)), axis=-1))
                emb2 = cross_graph(
                    mindspore.ops.concat((emb2_0, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1_0)), axis=-1)
                )
                i = self.gnn_layer - 1
                emb1, emb2 = self.gnn_layer_list[i]([A1, emb1], [A2, emb2])
                s = self.affinity_list[i](emb1, emb2)
                s = _sinkhorn_func(s, n1, n2)
        return s
| 688 | + |
class CIE_Net(nn.Cell):
    """
    MindSpore implementation of CIE graph matching network.

    Same overall layout as :class:`PCA_GM_Net` but uses channel-independent
    convolutions that also propagate edge embeddings through the layers.
    """

    def __init__(self, in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers):
        super().__init__()
        self.gnn_layer = num_layers  # number of stacked siamese GNN layers
        self.gnn_layer_list = nn.CellList()
        self.affinity_list = nn.CellList()
        self.cross_graph_list = nn.CellList()
        for i in range(self.gnn_layer):
            # Channel plan: (node in, edge in) -> hidden -> ... -> out
            if i == 0:
                gnn_layer = Siamese_ChannelIndependentConv(in_node_channel, hidden_channel, in_edge_channel)
            elif 0 < i < self.gnn_layer - 1:
                gnn_layer = Siamese_ChannelIndependentConv(hidden_channel, hidden_channel, hidden_channel)
            else:
                gnn_layer = Siamese_ChannelIndependentConv(hidden_channel, out_channel, hidden_channel)
            self.gnn_layer_list.append(gnn_layer)

            # Real affinity modules after the last two layers; Identity
            # placeholders elsewhere keep indices aligned with GNN layers.
            if i == self.gnn_layer - 1:
                self.affinity_list.append(WeightedInnerProdAffinity(out_channel))
            elif i == self.gnn_layer - 2:
                self.affinity_list.append(WeightedInnerProdAffinity(hidden_channel))
            else:
                self.affinity_list.append(Identity())

            # Cross-graph node-embedding update only before the last layer.
            if i == self.gnn_layer - 2:
                self.cross_graph_list.append(nn.Dense(hidden_channel * 2, hidden_channel))
            else:
                self.cross_graph_list.append(Identity())

    def construct(self, feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2, sk_max_iter, sk_tau):
        """
        Compute a Sinkhorn-normalized soft matching matrix between two graphs.

        :param feat_node1: node features of graph 1
        :param feat_node2: node features of graph 2
        :param A1: adjacency matrix of graph 1
        :param A2: adjacency matrix of graph 2
        :param feat_edge1: edge features of graph 1
        :param feat_edge2: edge features of graph 2
        :param n1: numbers of valid nodes per batch item in graph 1
        :param n2: numbers of valid nodes per batch item in graph 2
        :param sk_max_iter: max Sinkhorn iterations
        :param sk_tau: Sinkhorn temperature
        :return: matching score tensor after the final Sinkhorn step
        """
        _sinkhorn_func = functools.partial(
            sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
        )
        emb1, emb2 = feat_node1, feat_node2
        emb_edge1, emb_edge2 = feat_edge1, feat_edge2
        for i in range(self.gnn_layer):
            # Each layer updates node AND edge embeddings of both graphs.
            emb1, emb2, emb_edge1, emb_edge2 = self.gnn_layer_list[i]([A1, emb1, emb_edge1], [A2, emb2, emb_edge2])
            if i == self.gnn_layer - 2:
                # Cross-graph update: mix each node embedding with the
                # soft-matched peer embedding before the last layer.
                s = self.affinity_list[i](emb1, emb2)
                s = _sinkhorn_func(s, n1, n2)
                cross_graph = self.cross_graph_list[i]
                new_emb1 = cross_graph(mindspore.ops.concat((emb1, mindspore.ops.BatchMatMul()(s, emb2)), axis=-1))
                new_emb2 = cross_graph(
                    mindspore.ops.concat((emb2, mindspore.ops.BatchMatMul()(s.swapaxes(1, 2), emb1)), axis=-1)
                )
                emb1, emb2 = new_emb1, new_emb2

        s = self.affinity_list[self.gnn_layer - 1](emb1, emb2)
        s = _sinkhorn_func(s, n1, n2)
        return s
| 742 | + |
| 743 | + |
class NGM_Net(nn.Cell):
    """
    MindSpore implementation of NGM network.

    Operates on the association graph induced by the affinity matrix K: each
    node of the association graph corresponds to one candidate match, and the
    final classifier scores each candidate.
    """

    def __init__(self, gnn_channels, sk_emb):
        super().__init__()
        self.gnn_layer = len(gnn_channels)  # one NGM conv layer per entry
        self.gnn_layer_list = nn.CellList()
        for i in range(self.gnn_layer):
            # Each layer's input width is the previous layer's output width
            # (+ sk_emb Sinkhorn-embedding channels); the first layer starts
            # from 1-channel node/edge features.
            if i == 0:
                gnn_layer = NGMConvLayer(1, 1, gnn_channels[i] + sk_emb, gnn_channels[i], sk_channel=sk_emb)
            else:
                gnn_layer = NGMConvLayer(
                    gnn_channels[i - 1] + sk_emb, gnn_channels[i - 1],
                    gnn_channels[i] + sk_emb, gnn_channels[i], sk_channel=sk_emb
                )
            self.gnn_layer_list.append(gnn_layer)
        # Maps the final per-node embedding to a scalar matching score.
        self.classifier = nn.Dense(gnn_channels[-1] + sk_emb, 1)

    def construct(self, K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau):
        """
        Score candidate matches on the association graph and normalize.

        :param K: affinity matrix (association-graph weights)
        :param n1: numbers of valid nodes per batch item in graph 1
        :param n2: numbers of valid nodes per batch item in graph 2
        :param n1max: max node count of graph 1 (kept for interface symmetry)
        :param n2max: max node count of graph 2; used to un-flatten scores
        :param v0: initial association-graph node embedding
        :param sk_max_iter: max Sinkhorn iterations
        :param sk_tau: Sinkhorn temperature
        :return: Sinkhorn-normalized matching score tensor
        """
        _sinkhorn_func = functools.partial(
            sinkhorn, dummy_row=False, max_iter=sk_max_iter, tau=sk_tau, batched_operation=False
        )
        emb = v0
        # Binary adjacency of the association graph: edge wherever K != 0.
        A = (K != 0).astype(K.dtype)
        emb_K = mindspore.ops.expand_dims(K, axis=-1)
        for i in range(self.gnn_layer):
            emb_K, emb = self.gnn_layer_list[i](A, emb_K, emb, n1, n2, sk_func=_sinkhorn_func)
        v = self.classifier(emb)
        # Un-flatten the association-graph scores back into a (n1, n2)-shaped
        # matching matrix per batch item.
        s = v.reshape((v.shape[0], int(n2max), -1)).swapaxes(1, 2)
        return _sinkhorn_func(s, n1, n2, dummy_row=True)
| 776 | + |
| 777 | + |
# Pretrained checkpoint registries, one per solver.
# Each entry maps a pretrain tag -> (checkpoint filename, download URL list,
# expected md5 checksum); consumed by _get_pretrain_file().
pca_gm_pretrain_path = {
    'voc': ('pca_gm_voc_mindspore.ckpt',
            ['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc_mindspore.ckpt'],
            'e49379424ffea6759526bd9d436a0dcf'),
    'willow': ('pca_gm_willow_mindspore.ckpt',
               ['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_willow_mindspore.ckpt'],
               '3ed733a9b04e2a83142b08db2b9952cf'),
    'voc-all': ('pca_gm_voc-all_mindspore.ckpt',
                ['https://huggingface.co/heatingma/pygmtools/resolve/main/pca_gm_voc-all_mindspore.ckpt'],
                '28d40ccdc8bc743d2eca459758d887a2'),
}

ipca_gm_pretrain_path = {
    'voc': ('ipca_gm_voc_mindspore.ckpt',
            ['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_voc_mindspore.ckpt'],
            'c9f888eefbc22684317f5deedf175da7'),
    'willow': ('ipca_gm_willow_mindspore.ckpt',
               ['https://huggingface.co/heatingma/pygmtools/resolve/main/ipca_gm_willow_mindspore.ckpt'],
               '8d25c7bc7d7350467e07e53d1d004d63'),
}

cie_pretrain_path = {
    'voc': ('cie_voc_mindspore.ckpt',
            ['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_voc_mindspore.ckpt'],
            'be80f98e26af89a68421286f60d544f7'),
    'willow': ('cie_willow_mindspore.ckpt',
               ['https://huggingface.co/heatingma/pygmtools/resolve/main/cie_willow_mindspore.ckpt'],
               'b295e0bfa7367e9a2830b0ec25a99220'),
}

ngm_pretrain_path = {
    'voc': ('ngm_voc_mindspore.ckpt',
            ['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_voc_mindspore.ckpt'],
            'afa3d94ac9685dba82629e9ef79b19cf'),
    'willow': ('ngm_willow_mindspore.ckpt',
               ['https://huggingface.co/heatingma/pygmtools/resolve/main/ngm_willow_mindspore.ckpt'],
               '877aad75a62ad6cddbd45a0f4ece3790'),
}
| 816 | + |
| 817 | + |
def _save_model(model, path):
    """Save *model*'s parameters to *path* as a MindSpore checkpoint file."""
    mindspore.save_checkpoint(model, path)
| 820 | + |
| 821 | + |
def _load_model(model, path, strict=True):
    """Load the MindSpore checkpoint at *path* into *model*.

    Prefers the newer ``load_param_into_net`` signature that accepts
    ``strict_load`` and falls back to the older positional-only signature on
    TypeError. Parameters that could not be matched in either direction are
    reported via printed warnings rather than raising.
    """
    checkpoint = mindspore.load_checkpoint(path)
    try:
        missing, unexpected = mindspore.load_param_into_net(model, checkpoint, strict_load=strict)
    except TypeError:
        # Older MindSpore releases do not accept the strict_load keyword.
        missing, unexpected = mindspore.load_param_into_net(model, checkpoint)
    for template, keys in (
        ('Warning: Unexpected key(s) in state_dict: {}. ', unexpected),
        ('Warning: Missing key(s) in state_dict: {}. ', missing),
    ):
        if len(keys) > 0:
            print(template.format(', '.join('"{}"'.format(k) for k in keys)))
| 834 | + |
| 835 | + |
| 836 | +def _get_pretrain_file(pretrain, pretrain_path): |
| 837 | + if pretrain in pretrain_path: |
| 838 | + filename, url, md5 = pretrain_path[pretrain] |
| 839 | + return pygmtools.utils.download(filename, url, md5) |
| 840 | + raise ValueError(f'Unknown pretrain tag. Available tags: {pretrain_path.keys()}') |
| 841 | + |
| 842 | + |
def pca_gm(feat1, feat2, A1, A2, n1, n2,
           in_channel, hidden_channel, out_channel, num_layers, sk_max_iter, sk_tau,
           network, pretrain):
    """
    MindSpore implementation of PCA-GM.

    Builds (and optionally loads pretrained weights into) a PCA_GM_Net when
    *network* is None.  When *feat1* is None, only the network is returned and
    no forward pass is run.  Missing *n1*/*n2* default to the full padded node
    counts taken from the feature shapes.

    :return: (matching result or None, network)
    """
    if network is None:
        network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers)
        if pretrain:
            _load_model(network, _get_pretrain_file(pretrain, pca_gm_pretrain_path))
    if feat1 is None:
        # Network construction / pretrain loading only; nothing to match.
        return None, network
    num_graphs = feat1.shape[0]
    if n1 is None:
        n1 = mindspore.Tensor([feat1.shape[1]] * num_graphs, dtype=mindspore.int32)
    if n2 is None:
        n2 = mindspore.Tensor([feat2.shape[1]] * num_graphs, dtype=mindspore.int32)
    # cross_iter_num=-1 selects the plain (non-iterative) PCA-GM path.
    return network(feat1, feat2, A1, A2, n1, n2, -1, sk_max_iter, sk_tau), network
| 864 | + |
| 865 | + |
def ipca_gm(feat1, feat2, A1, A2, n1, n2,
            in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau,
            network, pretrain):
    """
    MindSpore implementation of IPCA-GM.

    Builds (and optionally loads pretrained weights into) a PCA_GM_Net
    configured with *cross_iter* iterations when *network* is None.  When
    *feat1* is None, only the network is returned and no forward pass is run.
    Missing *n1*/*n2* default to the full padded node counts taken from the
    feature shapes.

    :return: (matching result or None, network)
    """
    if network is None:
        network = PCA_GM_Net(in_channel, hidden_channel, out_channel, num_layers, cross_iter)
        if pretrain:
            _load_model(network, _get_pretrain_file(pretrain, ipca_gm_pretrain_path))
    if feat1 is None:
        # Network construction / pretrain loading only; nothing to match.
        return None, network
    num_graphs = feat1.shape[0]
    if n1 is None:
        n1 = mindspore.Tensor([feat1.shape[1]] * num_graphs, dtype=mindspore.int32)
    if n2 is None:
        n2 = mindspore.Tensor([feat2.shape[1]] * num_graphs, dtype=mindspore.int32)
    return network(feat1, feat2, A1, A2, n1, n2, cross_iter, sk_max_iter, sk_tau), network
| 887 | + |
| 888 | + |
def cie(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2,
        in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers, sk_max_iter, sk_tau,
        network, pretrain):
    """
    MindSpore implementation of CIE.

    Builds (and optionally loads pretrained weights into) a CIE_Net when
    *network* is None.  When *feat_node1* is None, only the network is
    returned and no forward pass is run.  Missing *n1*/*n2* default to the
    full padded node counts taken from each graph's own feature shape.

    :return: (matching result or None, network)
    """
    forward_pass = feat_node1 is not None
    if network is None:
        network = CIE_Net(in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers)
        if pretrain:
            _load_model(network, _get_pretrain_file(pretrain, cie_pretrain_path))
    if forward_pass:
        batch_size = feat_node1.shape[0]
        if n1 is None:
            n1 = mindspore.Tensor([feat_node1.shape[1]] * batch_size, dtype=mindspore.int32)
        if n2 is None:
            # Bug fix: the default n2 must come from the SECOND graph's
            # features (was feat_node1, which is wrong whenever the two
            # graphs have different padded sizes; pca_gm/ipca_gm use feat2).
            n2 = mindspore.Tensor([feat_node2.shape[1]] * batch_size, dtype=mindspore.int32)
        result = network(feat_node1, feat_node2, A1, A2, feat_edge1, feat_edge2, n1, n2, sk_max_iter, sk_tau)
    else:
        result = None
    return result, network
| 910 | + |
| 911 | + |
def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau, network, return_network, pretrain):
    """
    MindSpore implementation of NGM.

    Builds (and optionally loads pretrained weights into) an NGM_Net when
    *network* is None.  When *K* is None, only the network is returned and no
    forward pass is run; otherwise the problem sizes and the initial vector
    are derived/validated via ``_check_and_init_gm`` before the network call.

    :return: (matching result or None, network)
    """
    if network is None:
        network = NGM_Net(gnn_channels, sk_emb)
        if pretrain:
            _load_model(network, _get_pretrain_file(pretrain, ngm_pretrain_path))
    if K is None:
        # Network construction / pretrain loading only; nothing to match.
        return None, network
    _, n1, n2, n1max, n2max, _, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
    # Normalize the initial vector so its mean is 1.
    v0 = v0 / mindspore.ops.mean(v0)
    return network(K, n1, n2, n1max, n2max, v0, sk_max_iter, sk_tau), network
0 commit comments