@@ -29,13 +29,14 @@ def get_low_dim_basis(inf_matrix: InfluenceMatrix, compression: str = 'wavelet')
2929 :type compression: str
3030 :return: a list that contains the dimension reduction basis in the format of array(float)
3131 """
32+ low_dim_basis = {}
3233 num_of_beams = len (inf_matrix .beamlets_dict )
33- low_dim_basis = list ()
34+ num_of_beamlets = inf_matrix . beamlets_dict [ num_of_beams - 1 ][ 'end_beamlet' ] + 1
3435 beam_id = [inf_matrix .beamlets_dict [i ]['beam_id' ] for i in range (num_of_beams )]
3536 beamlets = inf_matrix .get_bev_2d_grid (beam_id = beam_id )
3637 index_position = list ()
37- num_of_beamlets = inf_matrix .beamlets_dict [num_of_beams - 1 ]['end_beamlet' ] + 1
3838 for ind in range (num_of_beams ):
39+ low_dim_basis [beam_id [ind ]] = []
3940 for i in range (inf_matrix .beamlets_dict [ind ]['start_beamlet' ],
4041 inf_matrix .beamlets_dict [ind ]['end_beamlet' ] + 1 ):
4142 index_position .append ((np .where (beamlets [ind ] == i )[0 ][0 ], np .where (beamlets [ind ] == i )[1 ][0 ]))
@@ -64,9 +65,11 @@ def get_low_dim_basis(inf_matrix: InfluenceMatrix, compression: str = 'wavelet')
6465 inf_matrix .beamlets_dict [b ]['end_beamlet' ] + 1 ):
6566 approximation [ind ] = approximation_coeffs [index_position [ind ]]
6667 horizontal [ind ] = horizontal_coeffs [index_position [ind ]]
67- low_dim_basis .append (np .stack (( approximation , horizontal )))
68+ low_dim_basis [ beam_id [ b ]] .append (np .transpose ( np . stack ([ approximation , horizontal ] )))
6869 beamlet_2d_grid [row ][col ] = 0
69- low_dim_basis = np .transpose (np .concatenate (low_dim_basis , axis = 0 ))
70- u , s , vh = scipy .sparse .linalg .svds (low_dim_basis , k = min (low_dim_basis .shape [0 ], low_dim_basis .shape [1 ]) - 1 )
71- ind = np .where (s > 0.0001 )
72- return u [:, ind [0 ]]
70+ for b in beam_id :
71+ low_dim_basis [b ] = np .concatenate (low_dim_basis [b ], axis = 1 )
72+ u , s , vh = scipy .sparse .linalg .svds (low_dim_basis [b ], k = min (low_dim_basis [b ].shape [0 ], low_dim_basis [b ].shape [1 ]) - 1 )
73+ ind = np .where (s > 0.0001 )
74+ low_dim_basis [b ] = u [:, ind [0 ]]
75+ return np .concatenate ([low_dim_basis [b ] for b in beam_id ], axis = 1 )
0 commit comments