2222# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
2323# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
2424# of the code.
25+ import os
2526import torch
2627import warnings
2728
2829from MinkowskiCommon import convert_to_int_list , StrideType
29- from MinkowskiEngineBackend ._C import CoordinateMapKey
30+ from MinkowskiEngineBackend ._C import (
31+ CoordinateMapKey ,
32+ CoordinateMapType ,
33+ GPUMemoryAllocatorType ,
34+ MinkowskiAlgorithm ,
35+ )
3036from MinkowskiTensor import (
3137 SparseTensorQuantizationMode ,
38+ SparseTensorOperationMode ,
3239 Tensor ,
40+ sparse_tensor_operation_mode ,
41+ global_coordinate_manager ,
42+ set_global_coordinate_manager ,
3343)
44+ from MinkowskiCoordinateManager import CoordinateManager
3445from sparse_matrix_functions import MinkowskiSPMMFunction
3546
3647
@@ -107,6 +118,152 @@ class SparseTensor(Tensor):
107118
108119 """
109120
121+ def __init__ (
122+ self ,
123+ features : torch .Tensor ,
124+ coordinates : torch .Tensor = None ,
125+ # optional coordinate related arguments
126+ tensor_stride : StrideType = 1 ,
127+ coordinate_map_key : CoordinateMapKey = None ,
128+ coordinate_manager : CoordinateManager = None ,
129+ quantization_mode : SparseTensorQuantizationMode = SparseTensorQuantizationMode .RANDOM_SUBSAMPLE ,
130+ # optional manager related arguments
131+ allocator_type : GPUMemoryAllocatorType = None ,
132+ minkowski_algorithm : MinkowskiAlgorithm = None ,
133+ requires_grad = None ,
134+ device = None ,
135+ ):
136+ r"""
137+
138+ Args:
139+ :attr:`features` (:attr:`torch.FloatTensor`,
140+ :attr:`torch.DoubleTensor`, :attr:`torch.cuda.FloatTensor`, or
141+ :attr:`torch.cuda.DoubleTensor`): The features of a sparse
142+ tensor.
143+
144+ :attr:`coordinates` (:attr:`torch.IntTensor`): The coordinates
145+ associated to the features. If not provided, :attr:`coordinate_map_key`
146+ must be provided.
147+
148+ :attr:`coordinate_map_key`
149+ (:attr:`MinkowskiEngine.CoordinateMapKey`): When the coordinates
150+ are already cached in the MinkowskiEngine, we could reuse the same
151+ coordinate map by simply providing the coordinate map key. In most
152+ case, this process is done automatically. When you provide a
153+ `coordinate_map_key`, `coordinates` will be be ignored.
154+
155+ :attr:`coordinate_manager`
156+ (:attr:`MinkowskiEngine.CoordinateManager`): The MinkowskiEngine
157+ manages all coordinate maps using the `_C.CoordinateMapManager`. If
158+ not provided, the MinkowskiEngine will create a new computation
159+ graph. In most cases, this process is handled automatically and you
160+ do not need to use this.
161+
162+ :attr:`quantization_mode`
163+ (:attr:`MinkowskiEngine.SparseTensorQuantizationMode`): Defines how
164+ continuous coordinates will be quantized to define a sparse tensor.
165+ Please refer to :attr:`SparseTensorQuantizationMode` for details.
166+
167+ :attr:`requires_grad` (:attr:`bool`): Set the requires_grad flag.
168+
169+ :attr:`tensor_stride` (:attr:`int`, :attr:`list`,
170+ :attr:`numpy.array`, or :attr:`tensor.Tensor`): The tensor stride
171+ of the current sparse tensor. By default, it is 1.
172+
173+ """
174+ # Type checks
175+ assert isinstance (features , torch .Tensor ), "Features must be a torch.Tensor"
176+ assert (
177+ features .ndim == 2
178+ ), f"The feature should be a matrix, The input feature is an order-{ features .ndim } tensor."
179+ assert isinstance (quantization_mode , SparseTensorQuantizationMode )
180+ self .quantization_mode = quantization_mode
181+
182+ if coordinates is not None :
183+ assert isinstance (coordinates , torch .Tensor )
184+ if coordinate_map_key is not None :
185+ assert isinstance (coordinate_map_key , CoordinateMapKey )
186+ if coordinate_manager is not None :
187+ assert isinstance (coordinate_manager , CoordinateManager )
188+ if coordinates is None and (
189+ coordinate_map_key is None or coordinate_manager is None
190+ ):
191+ raise ValueError (
192+ "Either coordinates or (coordinate_map_key, coordinate_manager) pair must be provided."
193+ )
194+
195+ Tensor .__init__ (self )
196+
197+ # To device
198+ if device is not None :
199+ features = features .to (device )
200+ if coordinates is not None :
201+ # assertion check for the map key done later
202+ coordinates = coordinates .to (device )
203+
204+ self ._D = (
205+ coordinates .size (1 ) - 1 if coordinates is not None else coordinate_manager .D
206+ )
207+ ##########################
208+ # Setup CoordsManager
209+ ##########################
210+ if coordinate_manager is None :
211+ # If set to share the coords man, use the global coords man
212+ if (
213+ sparse_tensor_operation_mode ()
214+ == SparseTensorOperationMode .SHARE_COORDINATE_MANAGER
215+ ):
216+ coordinate_manager = global_coordinate_manager ()
217+ if coordinate_manager is None :
218+ coordinate_manager = CoordinateManager (
219+ D = self ._D ,
220+ coordinate_map_type = CoordinateMapType .CUDA
221+ if coordinates .is_cuda
222+ else CoordinateMapType .CPU ,
223+ allocator_type = allocator_type ,
224+ minkowski_algorithm = minkowski_algorithm ,
225+ )
226+ set_global_coordinate_manager (coordinate_manager )
227+ else :
228+ coordinate_manager = CoordinateManager (
229+ D = coordinates .size (1 ) - 1 ,
230+ coordinate_map_type = CoordinateMapType .CUDA
231+ if coordinates .is_cuda
232+ else CoordinateMapType .CPU ,
233+ allocator_type = allocator_type ,
234+ minkowski_algorithm = minkowski_algorithm ,
235+ )
236+ self ._manager = coordinate_manager
237+
238+ ##########################
239+ # Initialize coords
240+ ##########################
241+ if coordinates is not None :
242+ assert (
243+ features .shape [0 ] == coordinates .shape [0 ]
244+ ), "The number of rows in features and coordinates must match."
245+
246+ assert (
247+ features .is_cuda == coordinates .is_cuda
248+ ), "Features and coordinates must have the same backend."
249+
250+ coordinate_map_key = CoordinateMapKey (
251+ convert_to_int_list (tensor_stride , self ._D ), ""
252+ )
253+ coordinates , features , coordinate_map_key = self .initialize_coordinates (
254+ coordinates , features , coordinate_map_key
255+ )
256+ else : # coordinate_map_key is not None:
257+ assert coordinate_map_key .is_key_set (), "The coordinate key must be valid."
258+
259+ if requires_grad is not None :
260+ features .requires_grad_ (requires_grad )
261+
262+ self ._F = features
263+ self ._C = coordinates
264+ self .coordinate_map_key = coordinate_map_key
265+ self ._batch_rows = None
266+
110267 def initialize_coordinates (self , coordinates , features , coordinate_map_key ):
111268 if not isinstance (coordinates , (torch .IntTensor , torch .cuda .IntTensor )):
112269 warnings .warn (
@@ -117,7 +274,7 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):
117274 coordinates = torch .floor (coordinates ).int ()
118275
119276 (
120- self . coordinate_map_key ,
277+ coordinate_map_key ,
121278 (unique_index , self .inverse_mapping ),
122279 ) = self ._manager .insert_and_map (coordinates , * coordinate_map_key .get_key ())
123280 self .unique_index = unique_index .long ()
@@ -396,19 +553,15 @@ def slice(self, X, slicing_mode=0):
396553 if isinstance (X , TensorField ):
397554 return TensorField (
398555 self .F [X .inverse_mapping ],
399- coordinate_map_key = X .coordinate_map_key ,
400556 coordinate_field_map_key = X .coordinate_field_map_key ,
401557 coordinate_manager = X .coordinate_manager ,
402- inverse_mapping = X .inverse_mapping ,
403558 quantization_mode = X .quantization_mode ,
404559 )
405560 else :
406561 return TensorField (
407562 self .F [X .inverse_mapping ],
408563 coordinates = self .C [X .inverse_mapping ],
409- coordinate_map_key = X .coordinate_map_key ,
410564 coordinate_manager = X .coordinate_manager ,
411- inverse_mapping = X .inverse_mapping ,
412565 quantization_mode = X .quantization_mode ,
413566 )
414567
@@ -495,10 +648,30 @@ def features_at_coordinates(self, query_coordinates: torch.Tensor):
495648 self ._F ,
496649 query_coordinates ,
497650 self .coordinate_map_key ,
498- None ,
499651 self .coordinate_manager ,
500652 )[0 ]
501653
654+ def __repr__ (self ):
655+ return (
656+ self .__class__ .__name__
657+ + "("
658+ + os .linesep
659+ + " coordinates="
660+ + str (self .C )
661+ + os .linesep
662+ + " features="
663+ + str (self .F )
664+ + os .linesep
665+ + " coordinate_map_key="
666+ + str (self .coordinate_map_key )
667+ + os .linesep
668+ + " coordinate_manager="
669+ + str (self ._manager )
670+ + " spatial dimension="
671+ + str (self ._D )
672+ + ")"
673+ )
674+
502675 __slots__ = (
503676 "_C" ,
504677 "_F" ,
0 commit comments