Skip to content

Commit 9e1322f

Browse files
committed
sparse tensor and tfield refact
1 parent d5c6fe8 commit 9e1322f

13 files changed

Lines changed: 414 additions & 322 deletions

MinkowskiEngine/MinkowskiNonlinearity.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ def forward(self, input):
4343
if isinstance(input, TensorField):
4444
return TensorField(
4545
output,
46-
coordinate_map_key=input.coordinate_map_key,
4746
coordinate_field_map_key=input.coordinate_field_map_key,
4847
coordinate_manager=input.coordinate_manager,
49-
inverse_mapping=input.inverse_mapping,
5048
quantization_mode=input.quantization_mode,
5149
)
5250
else:
@@ -120,10 +118,8 @@ def forward(self, input: Union[SparseTensor, TensorField]):
120118
if isinstance(input, TensorField):
121119
return TensorField(
122120
out_F,
123-
coordinate_map_key=input.coordinate_map_key,
124121
coordinate_field_map_key=input.coordinate_field_map_key,
125122
coordinate_manager=input.coordinate_manager,
126-
inverse_mapping=input.inverse_mapping,
127123
quantization_mode=input.quantization_mode,
128124
)
129125
else:

MinkowskiEngine/MinkowskiNormalization.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,8 @@ def forward(self, input):
7676
if isinstance(input, TensorField):
7777
return TensorField(
7878
output,
79-
coordinate_map_key=input.coordinate_map_key,
8079
coordinate_field_map_key=input.coordinate_field_map_key,
8180
coordinate_manager=input.coordinate_manager,
82-
inverse_mapping=input.inverse_mapping,
8381
quantization_mode=input.quantization_mode,
8482
)
8583
else:

MinkowskiEngine/MinkowskiOps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ def forward(self, input: Union[SparseTensor, TensorField]):
4646
if isinstance(input, TensorField):
4747
return TensorField(
4848
output,
49-
coordinate_map_key=input.coordinate_map_key,
5049
coordinate_field_map_key=input.coordinate_field_map_key,
5150
coordinate_manager=input.coordinate_manager,
52-
inverse_mapping=input.inverse_mapping,
5351
quantization_mode=input.quantization_mode,
5452
)
5553
else:

MinkowskiEngine/MinkowskiSparseTensor.py

Lines changed: 180 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,26 @@
2222
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
2323
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
2424
# of the code.
25+
import os
2526
import torch
2627
import warnings
2728

2829
from MinkowskiCommon import convert_to_int_list, StrideType
29-
from MinkowskiEngineBackend._C import CoordinateMapKey
30+
from MinkowskiEngineBackend._C import (
31+
CoordinateMapKey,
32+
CoordinateMapType,
33+
GPUMemoryAllocatorType,
34+
MinkowskiAlgorithm,
35+
)
3036
from MinkowskiTensor import (
3137
SparseTensorQuantizationMode,
38+
SparseTensorOperationMode,
3239
Tensor,
40+
sparse_tensor_operation_mode,
41+
global_coordinate_manager,
42+
set_global_coordinate_manager,
3343
)
44+
from MinkowskiCoordinateManager import CoordinateManager
3445
from sparse_matrix_functions import MinkowskiSPMMFunction
3546

3647

@@ -107,6 +118,152 @@ class SparseTensor(Tensor):
107118
108119
"""
109120

121+
def __init__(
122+
self,
123+
features: torch.Tensor,
124+
coordinates: torch.Tensor = None,
125+
# optional coordinate related arguments
126+
tensor_stride: StrideType = 1,
127+
coordinate_map_key: CoordinateMapKey = None,
128+
coordinate_manager: CoordinateManager = None,
129+
quantization_mode: SparseTensorQuantizationMode = SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
130+
# optional manager related arguments
131+
allocator_type: GPUMemoryAllocatorType = None,
132+
minkowski_algorithm: MinkowskiAlgorithm = None,
133+
requires_grad=None,
134+
device=None,
135+
):
136+
r"""
137+
138+
Args:
139+
:attr:`features` (:attr:`torch.FloatTensor`,
140+
:attr:`torch.DoubleTensor`, :attr:`torch.cuda.FloatTensor`, or
141+
:attr:`torch.cuda.DoubleTensor`): The features of a sparse
142+
tensor.
143+
144+
:attr:`coordinates` (:attr:`torch.IntTensor`): The coordinates
145+
associated to the features. If not provided, :attr:`coordinate_map_key`
146+
must be provided.
147+
148+
:attr:`coordinate_map_key`
149+
(:attr:`MinkowskiEngine.CoordinateMapKey`): When the coordinates
150+
are already cached in the MinkowskiEngine, we could reuse the same
151+
coordinate map by simply providing the coordinate map key. In most
152+
case, this process is done automatically. When you provide a
153+
`coordinate_map_key`, `coordinates` will be be ignored.
154+
155+
:attr:`coordinate_manager`
156+
(:attr:`MinkowskiEngine.CoordinateManager`): The MinkowskiEngine
157+
manages all coordinate maps using the `_C.CoordinateMapManager`. If
158+
not provided, the MinkowskiEngine will create a new computation
159+
graph. In most cases, this process is handled automatically and you
160+
do not need to use this.
161+
162+
:attr:`quantization_mode`
163+
(:attr:`MinkowskiEngine.SparseTensorQuantizationMode`): Defines how
164+
continuous coordinates will be quantized to define a sparse tensor.
165+
Please refer to :attr:`SparseTensorQuantizationMode` for details.
166+
167+
:attr:`requires_grad` (:attr:`bool`): Set the requires_grad flag.
168+
169+
:attr:`tensor_stride` (:attr:`int`, :attr:`list`,
170+
:attr:`numpy.array`, or :attr:`tensor.Tensor`): The tensor stride
171+
of the current sparse tensor. By default, it is 1.
172+
173+
"""
174+
# Type checks
175+
assert isinstance(features, torch.Tensor), "Features must be a torch.Tensor"
176+
assert (
177+
features.ndim == 2
178+
), f"The feature should be a matrix, The input feature is an order-{features.ndim} tensor."
179+
assert isinstance(quantization_mode, SparseTensorQuantizationMode)
180+
self.quantization_mode = quantization_mode
181+
182+
if coordinates is not None:
183+
assert isinstance(coordinates, torch.Tensor)
184+
if coordinate_map_key is not None:
185+
assert isinstance(coordinate_map_key, CoordinateMapKey)
186+
if coordinate_manager is not None:
187+
assert isinstance(coordinate_manager, CoordinateManager)
188+
if coordinates is None and (
189+
coordinate_map_key is None or coordinate_manager is None
190+
):
191+
raise ValueError(
192+
"Either coordinates or (coordinate_map_key, coordinate_manager) pair must be provided."
193+
)
194+
195+
Tensor.__init__(self)
196+
197+
# To device
198+
if device is not None:
199+
features = features.to(device)
200+
if coordinates is not None:
201+
# assertion check for the map key done later
202+
coordinates = coordinates.to(device)
203+
204+
self._D = (
205+
coordinates.size(1) - 1 if coordinates is not None else coordinate_manager.D
206+
)
207+
##########################
208+
# Setup CoordsManager
209+
##########################
210+
if coordinate_manager is None:
211+
# If set to share the coords man, use the global coords man
212+
if (
213+
sparse_tensor_operation_mode()
214+
== SparseTensorOperationMode.SHARE_COORDINATE_MANAGER
215+
):
216+
coordinate_manager = global_coordinate_manager()
217+
if coordinate_manager is None:
218+
coordinate_manager = CoordinateManager(
219+
D=self._D,
220+
coordinate_map_type=CoordinateMapType.CUDA
221+
if coordinates.is_cuda
222+
else CoordinateMapType.CPU,
223+
allocator_type=allocator_type,
224+
minkowski_algorithm=minkowski_algorithm,
225+
)
226+
set_global_coordinate_manager(coordinate_manager)
227+
else:
228+
coordinate_manager = CoordinateManager(
229+
D=coordinates.size(1) - 1,
230+
coordinate_map_type=CoordinateMapType.CUDA
231+
if coordinates.is_cuda
232+
else CoordinateMapType.CPU,
233+
allocator_type=allocator_type,
234+
minkowski_algorithm=minkowski_algorithm,
235+
)
236+
self._manager = coordinate_manager
237+
238+
##########################
239+
# Initialize coords
240+
##########################
241+
if coordinates is not None:
242+
assert (
243+
features.shape[0] == coordinates.shape[0]
244+
), "The number of rows in features and coordinates must match."
245+
246+
assert (
247+
features.is_cuda == coordinates.is_cuda
248+
), "Features and coordinates must have the same backend."
249+
250+
coordinate_map_key = CoordinateMapKey(
251+
convert_to_int_list(tensor_stride, self._D), ""
252+
)
253+
coordinates, features, coordinate_map_key = self.initialize_coordinates(
254+
coordinates, features, coordinate_map_key
255+
)
256+
else: # coordinate_map_key is not None:
257+
assert coordinate_map_key.is_key_set(), "The coordinate key must be valid."
258+
259+
if requires_grad is not None:
260+
features.requires_grad_(requires_grad)
261+
262+
self._F = features
263+
self._C = coordinates
264+
self.coordinate_map_key = coordinate_map_key
265+
self._batch_rows = None
266+
110267
def initialize_coordinates(self, coordinates, features, coordinate_map_key):
111268
if not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor)):
112269
warnings.warn(
@@ -117,7 +274,7 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):
117274
coordinates = torch.floor(coordinates).int()
118275

119276
(
120-
self.coordinate_map_key,
277+
coordinate_map_key,
121278
(unique_index, self.inverse_mapping),
122279
) = self._manager.insert_and_map(coordinates, *coordinate_map_key.get_key())
123280
self.unique_index = unique_index.long()
@@ -396,19 +553,15 @@ def slice(self, X, slicing_mode=0):
396553
if isinstance(X, TensorField):
397554
return TensorField(
398555
self.F[X.inverse_mapping],
399-
coordinate_map_key=X.coordinate_map_key,
400556
coordinate_field_map_key=X.coordinate_field_map_key,
401557
coordinate_manager=X.coordinate_manager,
402-
inverse_mapping=X.inverse_mapping,
403558
quantization_mode=X.quantization_mode,
404559
)
405560
else:
406561
return TensorField(
407562
self.F[X.inverse_mapping],
408563
coordinates=self.C[X.inverse_mapping],
409-
coordinate_map_key=X.coordinate_map_key,
410564
coordinate_manager=X.coordinate_manager,
411-
inverse_mapping=X.inverse_mapping,
412565
quantization_mode=X.quantization_mode,
413566
)
414567

@@ -495,10 +648,30 @@ def features_at_coordinates(self, query_coordinates: torch.Tensor):
495648
self._F,
496649
query_coordinates,
497650
self.coordinate_map_key,
498-
None,
499651
self.coordinate_manager,
500652
)[0]
501653

654+
def __repr__(self):
655+
return (
656+
self.__class__.__name__
657+
+ "("
658+
+ os.linesep
659+
+ " coordinates="
660+
+ str(self.C)
661+
+ os.linesep
662+
+ " features="
663+
+ str(self.F)
664+
+ os.linesep
665+
+ " coordinate_map_key="
666+
+ str(self.coordinate_map_key)
667+
+ os.linesep
668+
+ " coordinate_manager="
669+
+ str(self._manager)
670+
+ " spatial dimension="
671+
+ str(self._D)
672+
+ ")"
673+
)
674+
502675
__slots__ = (
503676
"_C",
504677
"_F",

0 commit comments

Comments
 (0)