@@ -11,7 +11,7 @@ class PCABlock(ImageProcessingBlock):
1111 Block that applies a precomputed PCA to the image
1212 """
1313
14- def __init__ (self , fp , models_dir , device ):
14+ def __init__ (self , fp , models_dir , n_features = - 1 , device = 'cuda' ):
1515 full_fp = os .path .join (models_dir , fp )
1616 pca = torch .load (full_fp , weights_only = False )
1717
@@ -28,6 +28,10 @@ def __init__(self, fp, models_dir, device):
2828 "V" : pca ["V" ].to (device ),
2929 }
3030
31+ self .n_features = n_features
32+
33+ assert self .n_features <= self .pca ["V" ].shape [- 1 ]
34+
3135 def run (self , image , intrinsics , image_orig ):
3236 _pmean = self .pca ["mean" ].view (1 , 1 , - 1 )
3337 _pv = self .pca ["V" ].unsqueeze (0 )
@@ -40,11 +44,15 @@ def run(self, image, intrinsics, image_orig):
4044 image .shape [0 ], _pv .shape [- 1 ], image .shape [2 ], image .shape [3 ]
4145 )
4246
47+ if self .n_features >= 0 :
48+ img_out = img_out [:, :self .n_features ]
49+
4350 return img_out , intrinsics
4451
4552 @property
4653 def output_feature_keys (self ):
54+ N = self .pca ["V" ].shape [- 1 ] if self .n_features == - 1 else self .n_features
4755 return FeatureKeyList (
48- label = [f"{ self .base_label } _{ i } " for i in range (self . pca [ "V" ]. shape [ - 1 ] )],
49- metainfo = ["vfm" for i in range (self . pca [ "V" ]. shape [ - 1 ] )]
56+ label = [f"{ self .base_label } _{ i } " for i in range (N )],
57+ metainfo = ["vfm" for i in range (N )]
5058 )
0 commit comments