11import torch
2+ from pathlib import Path
23from .datatypes import to_torch_dtype
34from .devices import torch_device_map
45
@@ -12,9 +13,12 @@ class TensorInitializer:
1213 RANDINT = "randint"
1314 MANUAL = "manual"
1415 BINARY = "binary"
16+ FROM_FILE = "from_file"
1517
1618 @staticmethod
17- def create_tensor (shape , dtype , device , mode = RANDOM , strides = None , set_tensor = None ):
19+ def create_tensor (
20+ shape , dtype , device , mode = RANDOM , strides = None , set_tensor = None , file_path = None
21+ ):
1822 """
1923 Create a torch tensor with specified initialization mode
2024
@@ -25,6 +29,7 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
2529 mode: Initialization mode
2630 strides: Optional strides for strided tensors
2731 set_tensor: Pre-existing tensor for manual/binary mode
32+ file_path: Path to file for FROM_FILE mode
2833
2934 Returns:
3035 torch.Tensor: Initialized tensor
@@ -36,9 +41,6 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
3641 # Handle strided tensors - calculate required storage size
3742 if strides is not None :
3843 # Calculate the required storage size for strided tensor
39- # The storage size needed is: max(offset + 1) for all elements
40- # where offset = sum(index[i] * stride[i] for i in range(len(shape)))
41- # The maximum offset occurs at the last element: sum((shape[i]-1) * strides[i])
4244 storage_size = 0
4345 for i in range (len (shape )):
4446 if shape [i ] > 0 :
@@ -72,6 +74,10 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
7274 elif mode == TensorInitializer .BINARY :
7375 assert set_tensor is not None , "Binary mode requires set_tensor"
7476 base_tensor = set_tensor .to (torch_dtype ).to (torch_device_str )
77+ elif mode == TensorInitializer .FROM_FILE :
78+ base_tensor = TensorInitializer ._load_from_file (
79+ file_path , storage_size , torch_dtype , torch_device_str
80+ )
7581 else :
7682 raise ValueError (f"Unsupported initialization mode: { mode } " )
7783
@@ -101,11 +107,176 @@ def create_tensor(shape, dtype, device, mode=RANDOM, strides=None, set_tensor=No
101107 assert set_tensor is not None , "Binary mode requires set_tensor"
102108 assert shape == list (set_tensor .shape ), "Shape mismatch in binary mode"
103109 tensor = set_tensor .to (torch_dtype ).to (torch_device_str )
110+ elif mode == TensorInitializer .FROM_FILE :
111+ tensor = TensorInitializer ._load_from_file (
112+ file_path , shape , torch_dtype , torch_device_str
113+ )
104114 else :
105115 raise ValueError (f"Unsupported initialization mode: { mode } " )
106116
107117 return tensor
108118
119+ @staticmethod
120+ def _load_from_file (file_path , shape_or_size , torch_dtype , torch_device_str ):
121+ """
122+ Load tensor data from file using PyTorch's native methods
123+
124+ Args:
125+ file_path: Path to the file
126+ shape_or_size: Tensor shape for contiguous or size for strided
127+ torch_dtype: Target torch dtype
128+ torch_device_str: Target device string
129+
130+ Returns:
131+ torch.Tensor: Tensor with data loaded from file
132+ """
133+ if file_path is None :
134+ raise ValueError ("FROM_FILE mode requires file_path" )
135+
136+ file_path = Path (file_path )
137+ if not file_path .exists ():
138+ raise FileNotFoundError (f"File not found: { file_path } " )
139+
140+ # Determine file type and load accordingly
141+ file_extension = file_path .suffix .lower ()
142+
143+ if file_extension in [".pt" , ".pth" ]:
144+ # PyTorch native format
145+ tensor = torch .load (file_path , map_location = torch_device_str )
146+
147+ elif file_extension in [".bin" , ".dat" , ".raw" ]:
148+ # Raw binary format - we need to know the expected shape
149+ tensor = TensorInitializer ._load_binary_file (
150+ file_path , shape_or_size , torch_dtype , torch_device_str
151+ )
152+
153+ elif file_extension in [".npy" ]:
154+ # NumPy format - fallback to numpy if needed
155+ try :
156+ import numpy as np
157+
158+ numpy_array = np .load (file_path )
159+ tensor = (
160+ torch .from_numpy (numpy_array ).to (torch_dtype ).to (torch_device_str )
161+ )
162+ except ImportError :
163+ raise ImportError ("NumPy is required to load .npy files" )
164+
165+ else :
166+ # Try to load as PyTorch format first, then fallback to binary
167+ try :
168+ tensor = torch .load (file_path , map_location = torch_device_str )
169+ except :
170+ # Fallback to binary loading
171+ tensor = TensorInitializer ._load_binary_file (
172+ file_path , shape_or_size , torch_dtype , torch_device_str
173+ )
174+
175+ # Ensure correct dtype and device
176+ tensor = tensor .to (torch_dtype ).to (torch_device_str )
177+
178+ # Validate shape/size
179+ if isinstance (shape_or_size , (list , tuple )):
180+ # Contiguous tensor - check shape
181+ if list (tensor .shape ) != list (shape_or_size ):
182+ raise ValueError (
183+ f"Tensor shape mismatch: expected { shape_or_size } , got { tensor .shape } "
184+ )
185+ else :
186+ # Strided tensor - check total size
187+ if tensor .numel () != shape_or_size :
188+ raise ValueError (
189+ f"Tensor size mismatch: expected { shape_or_size } elements, got { tensor .numel ()} "
190+ )
191+
192+ return tensor
193+
194+ @staticmethod
195+ def _load_binary_file (file_path , shape_or_size , torch_dtype , torch_device_str ):
196+ """
197+ Load tensor from raw binary file
198+
199+ Args:
200+ file_path: Path to binary file
201+ shape_or_size: Expected shape or size
202+ torch_dtype: Target dtype
203+ torch_device_str: Target device
204+
205+ Returns:
206+ torch.Tensor: Loaded tensor
207+ """
208+ # Read binary data
209+ with open (file_path , "rb" ) as f :
210+ binary_data = f .read ()
211+
212+ # Create tensor from buffer
213+ if isinstance (shape_or_size , (list , tuple )):
214+ # Contiguous tensor with known shape
215+ tensor = torch .frombuffer (binary_data , dtype = torch_dtype ).reshape (
216+ shape_or_size
217+ )
218+ else :
219+ # Strided tensor - just 1D buffer
220+ tensor = torch .frombuffer (binary_data , dtype = torch_dtype )
221+
222+ return tensor .to (torch_device_str )
223+
224+ @staticmethod
225+ def save_to_file (tensor , file_path , format = "auto" ):
226+ """
227+ Save tensor data to file using PyTorch's native methods
228+
229+ Args:
230+ tensor: torch.Tensor to save
231+ file_path: Path to save the file
232+ format: File format ('auto', 'torch', 'binary', 'numpy')
233+ """
234+ file_path = Path (file_path )
235+
236+ if format == "auto" :
237+ # Determine format from file extension
238+ file_extension = file_path .suffix .lower ()
239+ if file_extension in [".pt" , ".pth" ]:
240+ format = "torch"
241+ elif file_extension in [".npy" ]:
242+ format = "numpy"
243+ else :
244+ format = "binary"
245+
246+ if format == "torch" :
247+ # PyTorch native format (preserves metadata)
248+ torch .save (tensor , file_path )
249+
250+ elif format == "binary" :
251+ # Raw binary format
252+ with open (file_path , "wb" ) as f :
253+ f .write (tensor .cpu ().numpy ().tobytes ())
254+
255+ elif format == "numpy" :
256+ # NumPy format
257+ try :
258+ import numpy as np
259+
260+ np .save (file_path , tensor .cpu ().numpy ())
261+ except ImportError :
262+ raise ImportError ("NumPy is required to save .npy files" )
263+
264+ else :
265+ raise ValueError (f"Unsupported format: { format } " )
266+
267+ print (
268+ f"Tensor saved to { file_path } (shape: { tensor .shape } , dtype: { tensor .dtype } , format: { format } )"
269+ )
270+
271+ @staticmethod
272+ def list_supported_formats ():
273+ """Return list of supported file formats"""
274+ return {
275+ "torch" : [".pt" , ".pth" ], # PyTorch native format
276+ "binary" : [".bin" , ".dat" , ".raw" ], # Raw binary
277+ "numpy" : [".npy" ], # NumPy format
278+ }
279+
109280
110281class TensorSpec :
111282 """Tensor specification supporting various input types and per-tensor dtype"""
@@ -120,6 +291,8 @@ def __init__(
120291 is_contiguous = True ,
121292 init_mode = TensorInitializer .RANDOM , # Default to random initialization
122293 custom_tensor = None , # For manual/binary mode
294+ file_path = None , # For FROM_FILE mode
295+ file_format = None , # Optional file format hint
123296 ):
124297 self .shape = shape
125298 self .dtype = dtype
@@ -129,6 +302,8 @@ def __init__(
129302 self .is_contiguous = is_contiguous
130303 self .init_mode = init_mode
131304 self .custom_tensor = custom_tensor
305+ self .file_path = file_path
306+ self .file_format = file_format
132307
133308 @classmethod
134309 def from_tensor (
@@ -139,6 +314,7 @@ def from_tensor(
139314 is_contiguous = True ,
140315 init_mode = TensorInitializer .RANDOM ,
141316 custom_tensor = None ,
317+ file_path = None ,
142318 ):
143319 return cls (
144320 shape = shape ,
@@ -148,6 +324,7 @@ def from_tensor(
148324 is_contiguous = is_contiguous ,
149325 init_mode = init_mode ,
150326 custom_tensor = custom_tensor ,
327+ file_path = file_path ,
151328 )
152329
153330 @classmethod
@@ -162,6 +339,7 @@ def from_strided_tensor(
162339 dtype = None ,
163340 init_mode = TensorInitializer .RANDOM ,
164341 custom_tensor = None ,
342+ file_path = None ,
165343 ):
166344 return cls (
167345 shape = shape ,
@@ -171,6 +349,42 @@ def from_strided_tensor(
171349 is_contiguous = False ,
172350 init_mode = init_mode ,
173351 custom_tensor = custom_tensor ,
352+ file_path = file_path ,
353+ )
354+
355+ @classmethod
356+ def from_file (
357+ cls ,
358+ file_path ,
359+ shape ,
360+ dtype = None ,
361+ strides = None ,
362+ is_contiguous = True ,
363+ file_format = None ,
364+ ):
365+ """
366+ Create TensorSpec that loads data from file
367+
368+ Args:
369+ file_path: Path to file
370+ shape: Tensor shape
371+ dtype: infinicore dtype (inferred from file if None)
372+ strides: Optional strides for strided tensors
373+ is_contiguous: Whether tensor is contiguous
374+ file_format: Optional file format hint
375+
376+ Returns:
377+ TensorSpec: Configured for file loading
378+ """
379+ return cls (
380+ shape = shape ,
381+ dtype = dtype ,
382+ strides = strides ,
383+ is_scalar = False ,
384+ is_contiguous = is_contiguous ,
385+ init_mode = TensorInitializer .FROM_FILE ,
386+ file_path = file_path ,
387+ file_format = file_format ,
174388 )
175389
176390 def create_torch_tensor (self , device , dtype_config , tensor_index = 0 ):
@@ -198,4 +412,5 @@ def create_torch_tensor(self, device, dtype_config, tensor_index=0):
198412 mode = self .init_mode ,
199413 strides = self .strides ,
200414 set_tensor = self .custom_tensor ,
415+ file_path = self .file_path ,
201416 )
0 commit comments