Skip to content

Commit b02b749

Browse files
wooway777voltjia
authored andcommitted
issue/497 - support tensor from/to files
1 parent 1b40ea0 commit b02b749

1 file changed

Lines changed: 219 additions & 4 deletions

File tree

test/infinicore/framework/tensor.py

Lines changed: 219 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from pathlib import Path
23
from .datatypes import to_torch_dtype
34
from .devices import torch_device_map
45

@@ -12,9 +13,12 @@ class TensorInitializer:
1213
RANDINT = "randint"
1314
MANUAL = "manual"
1415
BINARY = "binary"
16+
FROM_FILE = "from_file"
1517

1618
@staticmethod
17-
def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=None):
19+
def create_tensor(
20+
shape, dtype, device, mode=RANDOM, strides=None, set_tensor=None, file_path=None
21+
):
1822
"""
1923
Create a torch tensor with specified initialization mode
2024
@@ -25,6 +29,7 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
2529
mode: Initialization mode
2630
strides: Optional strides for strided tensors
2731
set_tensor: Pre-existing tensor for manual/binary mode
32+
file_path: Path to file for FROM_FILE mode
2833
2934
Returns:
3035
torch.Tensor: Initialized tensor
@@ -36,9 +41,6 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
3641
# Handle strided tensors - calculate required storage size
3742
if strides is not None:
3843
# Calculate the required storage size for strided tensor
39-
# The storage size needed is: max(offset + 1) for all elements
40-
# where offset = sum(index[i] * stride[i] for i in range(len(shape)))
41-
# The maximum offset occurs at the last element: sum((shape[i]-1) * strides[i])
4244
storage_size = 0
4345
for i in range(len(shape)):
4446
if shape[i] > 0:
@@ -72,6 +74,10 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
7274
elif mode == TensorInitializer.BINARY:
7375
assert set_tensor is not None, "Binary mode requires set_tensor"
7476
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str)
77+
elif mode == TensorInitializer.FROM_FILE:
78+
base_tensor = TensorInitializer._load_from_file(
79+
file_path, storage_size, torch_dtype, torch_device_str
80+
)
7581
else:
7682
raise ValueError(f"Unsupported initialization mode: {mode}")
7783

@@ -101,11 +107,176 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
101107
assert set_tensor is not None, "Binary mode requires set_tensor"
102108
assert shape == list(set_tensor.shape), "Shape mismatch in binary mode"
103109
tensor = set_tensor.to(torch_dtype).to(torch_device_str)
110+
elif mode == TensorInitializer.FROM_FILE:
111+
tensor = TensorInitializer._load_from_file(
112+
file_path, shape, torch_dtype, torch_device_str
113+
)
104114
else:
105115
raise ValueError(f"Unsupported initialization mode: {mode}")
106116

107117
return tensor
108118

119+
@staticmethod
120+
def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str):
121+
"""
122+
Load tensor data from file using PyTorch's native methods
123+
124+
Args:
125+
file_path: Path to the file
126+
shape_or_size: Tensor shape for contiguous or size for strided
127+
torch_dtype: Target torch dtype
128+
torch_device_str: Target device string
129+
130+
Returns:
131+
torch.Tensor: Tensor with data loaded from file
132+
"""
133+
if file_path is None:
134+
raise ValueError("FROM_FILE mode requires file_path")
135+
136+
file_path = Path(file_path)
137+
if not file_path.exists():
138+
raise FileNotFoundError(f"File not found: {file_path}")
139+
140+
# Determine file type and load accordingly
141+
file_extension = file_path.suffix.lower()
142+
143+
if file_extension in [".pt", ".pth"]:
144+
# PyTorch native format
145+
tensor = torch.load(file_path, map_location=torch_device_str)
146+
147+
elif file_extension in [".bin", ".dat", ".raw"]:
148+
# Raw binary format - we need to know the expected shape
149+
tensor = TensorInitializer._load_binary_file(
150+
file_path, shape_or_size, torch_dtype, torch_device_str
151+
)
152+
153+
elif file_extension in [".npy"]:
154+
# NumPy format - fallback to numpy if needed
155+
try:
156+
import numpy as np
157+
158+
numpy_array = np.load(file_path)
159+
tensor = (
160+
torch.from_numpy(numpy_array).to(torch_dtype).to(torch_device_str)
161+
)
162+
except ImportError:
163+
raise ImportError("NumPy is required to load .npy files")
164+
165+
else:
166+
# Try to load as PyTorch format first, then fallback to binary
167+
try:
168+
tensor = torch.load(file_path, map_location=torch_device_str)
169+
except:
170+
# Fallback to binary loading
171+
tensor = TensorInitializer._load_binary_file(
172+
file_path, shape_or_size, torch_dtype, torch_device_str
173+
)
174+
175+
# Ensure correct dtype and device
176+
tensor = tensor.to(torch_dtype).to(torch_device_str)
177+
178+
# Validate shape/size
179+
if isinstance(shape_or_size, (list, tuple)):
180+
# Contiguous tensor - check shape
181+
if list(tensor.shape) != list(shape_or_size):
182+
raise ValueError(
183+
f"Tensor shape mismatch: expected {shape_or_size}, got {tensor.shape}"
184+
)
185+
else:
186+
# Strided tensor - check total size
187+
if tensor.numel() != shape_or_size:
188+
raise ValueError(
189+
f"Tensor size mismatch: expected {shape_or_size} elements, got {tensor.numel()}"
190+
)
191+
192+
return tensor
193+
194+
@staticmethod
195+
def _load_binary_file(file_path, shape_or_size, torch_dtype, torch_device_str):
196+
"""
197+
Load tensor from raw binary file
198+
199+
Args:
200+
file_path: Path to binary file
201+
shape_or_size: Expected shape or size
202+
torch_dtype: Target dtype
203+
torch_device_str: Target device
204+
205+
Returns:
206+
torch.Tensor: Loaded tensor
207+
"""
208+
# Read binary data
209+
with open(file_path, "rb") as f:
210+
binary_data = f.read()
211+
212+
# Create tensor from buffer
213+
if isinstance(shape_or_size, (list, tuple)):
214+
# Contiguous tensor with known shape
215+
tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape(
216+
shape_or_size
217+
)
218+
else:
219+
# Strided tensor - just 1D buffer
220+
tensor = torch.frombuffer(binary_data, dtype=torch_dtype)
221+
222+
return tensor.to(torch_device_str)
223+
224+
@staticmethod
225+
def save_to_file(tensor, file_path, format="auto"):
226+
"""
227+
Save tensor data to file using PyTorch's native methods
228+
229+
Args:
230+
tensor: torch.Tensor to save
231+
file_path: Path to save the file
232+
format: File format ('auto', 'torch', 'binary', 'numpy')
233+
"""
234+
file_path = Path(file_path)
235+
236+
if format == "auto":
237+
# Determine format from file extension
238+
file_extension = file_path.suffix.lower()
239+
if file_extension in [".pt", ".pth"]:
240+
format = "torch"
241+
elif file_extension in [".npy"]:
242+
format = "numpy"
243+
else:
244+
format = "binary"
245+
246+
if format == "torch":
247+
# PyTorch native format (preserves metadata)
248+
torch.save(tensor, file_path)
249+
250+
elif format == "binary":
251+
# Raw binary format
252+
with open(file_path, "wb") as f:
253+
f.write(tensor.cpu().numpy().tobytes())
254+
255+
elif format == "numpy":
256+
# NumPy format
257+
try:
258+
import numpy as np
259+
260+
np.save(file_path, tensor.cpu().numpy())
261+
except ImportError:
262+
raise ImportError("NumPy is required to save .npy files")
263+
264+
else:
265+
raise ValueError(f"Unsupported format: {format}")
266+
267+
print(
268+
f"Tensor saved to {file_path} (shape: {tensor.shape}, dtype: {tensor.dtype}, format: {format})"
269+
)
270+
271+
@staticmethod
272+
def list_supported_formats():
273+
"""Return list of supported file formats"""
274+
return {
275+
"torch": [".pt", ".pth"], # PyTorch native format
276+
"binary": [".bin", ".dat", ".raw"], # Raw binary
277+
"numpy": [".npy"], # NumPy format
278+
}
279+
109280

110281
class TensorSpec:
111282
"""Tensor specification supporting various input types and per-tensor dtype"""
@@ -120,6 +291,8 @@ def __init__(
120291
is_contiguous=True,
121292
init_mode=TensorInitializer.RANDOM, # Default to random initialization
122293
custom_tensor=None, # For manual/binary mode
294+
file_path=None, # For FROM_FILE mode
295+
file_format=None, # Optional file format hint
123296
):
124297
self.shape = shape
125298
self.dtype = dtype
@@ -129,6 +302,8 @@ def __init__(
129302
self.is_contiguous = is_contiguous
130303
self.init_mode = init_mode
131304
self.custom_tensor = custom_tensor
305+
self.file_path = file_path
306+
self.file_format = file_format
132307

133308
@classmethod
134309
def from_tensor(
@@ -139,6 +314,7 @@ def from_tensor(
139314
is_contiguous=True,
140315
init_mode=TensorInitializer.RANDOM,
141316
custom_tensor=None,
317+
file_path=None,
142318
):
143319
return cls(
144320
shape=shape,
@@ -148,6 +324,7 @@ def from_tensor(
148324
is_contiguous=is_contiguous,
149325
init_mode=init_mode,
150326
custom_tensor=custom_tensor,
327+
file_path=file_path,
151328
)
152329

153330
@classmethod
@@ -162,6 +339,7 @@ def from_strided_tensor(
162339
dtype=None,
163340
init_mode=TensorInitializer.RANDOM,
164341
custom_tensor=None,
342+
file_path=None,
165343
):
166344
return cls(
167345
shape=shape,
@@ -171,6 +349,42 @@ def from_strided_tensor(
171349
is_contiguous=False,
172350
init_mode=init_mode,
173351
custom_tensor=custom_tensor,
352+
file_path=file_path,
353+
)
354+
355+
@classmethod
356+
def from_file(
357+
cls,
358+
file_path,
359+
shape,
360+
dtype=None,
361+
strides=None,
362+
is_contiguous=True,
363+
file_format=None,
364+
):
365+
"""
366+
Create TensorSpec that loads data from file
367+
368+
Args:
369+
file_path: Path to file
370+
shape: Tensor shape
371+
dtype: infinicore dtype (inferred from file if None)
372+
strides: Optional strides for strided tensors
373+
is_contiguous: Whether tensor is contiguous
374+
file_format: Optional file format hint
375+
376+
Returns:
377+
TensorSpec: Configured for file loading
378+
"""
379+
return cls(
380+
shape=shape,
381+
dtype=dtype,
382+
strides=strides,
383+
is_scalar=False,
384+
is_contiguous=is_contiguous,
385+
init_mode=TensorInitializer.FROM_FILE,
386+
file_path=file_path,
387+
file_format=file_format,
174388
)
175389

176390
def create_torch_tensor(self, device, dtype_config, tensor_index=0):
@@ -198,4 +412,5 @@ def create_torch_tensor(self, device, dtype_config, tensor_index=0):
198412
mode=self.init_mode,
199413
strides=self.strides,
200414
set_tensor=self.custom_tensor,
415+
file_path=self.file_path,
201416
)

0 commit comments

Comments
 (0)