Skip to content

Commit 47a981d

Browse files
committed
issue/497 - support tensor init modes
1 parent ebe544d commit 47a981d

4 files changed

Lines changed: 271 additions & 124 deletions

File tree

test/infinicore/framework/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from .base import TensorSpec, TestConfig, TestRunner, TestCase, BaseOperatorTest
1+
# [file name]: __init__.py
2+
# [file content begin]
3+
from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
4+
from .tensor import TensorSpec, TensorInitializer
25
from .utils import (
36
compare_results,
47
create_test_comparator,
@@ -17,12 +20,11 @@
1720

1821
__all__ = [
1922
"TensorSpec",
23+
"TensorInitializer",
2024
"TestConfig",
2125
"TestRunner",
2226
"TestCase",
2327
"BaseOperatorTest",
24-
"ParameterMapping",
25-
"create_test_cases",
2628
"compare_results",
2729
"create_test_comparator",
2830
"convert_infinicore_to_torch",

test/infinicore/framework/base.py

Lines changed: 52 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .datatypes import to_torch_dtype, to_infinicore_dtype
88
from .devices import InfiniDeviceNames, torch_device_map
9+
from .tensor import TensorSpec, TensorInitializer
910
from .utils import (
1011
create_test_comparator,
1112
infinicore_tensor_from_torch,
@@ -15,50 +16,6 @@
1516
)
1617

1718

18-
class TensorSpec:
19-
"""Tensor specification supporting various input types and per-tensor dtype"""
20-
21-
def __init__(
22-
self,
23-
shape=None,
24-
dtype=None,
25-
strides=None,
26-
value=None,
27-
is_scalar=False,
28-
is_contiguous=True,
29-
):
30-
self.shape = shape
31-
self.dtype = dtype
32-
self.strides = strides
33-
self.value = value
34-
self.is_scalar = is_scalar
35-
self.is_contiguous = is_contiguous
36-
37-
@classmethod
38-
def from_tensor(cls, shape, dtype=None, strides=None, is_contiguous=True):
39-
return cls(
40-
shape=shape,
41-
dtype=dtype,
42-
strides=strides,
43-
is_scalar=False,
44-
is_contiguous=is_contiguous,
45-
)
46-
47-
@classmethod
48-
def from_scalar(cls, value, dtype=None):
49-
return cls(value=value, dtype=dtype, is_scalar=True)
50-
51-
@classmethod
52-
def from_strided_tensor(cls, shape, strides, dtype=None):
53-
return cls(
54-
shape=shape,
55-
dtype=dtype,
56-
strides=strides,
57-
is_scalar=False,
58-
is_contiguous=False,
59-
)
60-
61-
6219
class TestCase:
6320
"""Test case"""
6421

@@ -101,17 +58,27 @@ def __str__(self):
10158
input_strs.append(f"scalar({inp.value}{dtype_str})")
10259
elif hasattr(inp, "shape"):
10360
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
61+
init_str = (
62+
f", init={inp.init_mode}"
63+
if inp.init_mode != TensorInitializer.RANDOM
64+
else ""
65+
)
10466
if hasattr(inp, "is_contiguous") and not inp.is_contiguous:
105-
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}")
67+
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}")
10668
else:
107-
input_strs.append(f"tensor{inp.shape}{dtype_str}")
69+
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
10870
else:
10971
input_strs.append(str(inp))
11072

11173
base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]"
11274
if self.output:
11375
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
114-
base_str += f", output=tensor{self.output.shape}{dtype_str}"
76+
init_str = (
77+
f", init={self.output.init_mode}"
78+
if self.output.init_mode != TensorInitializer.RANDOM
79+
else ""
80+
)
81+
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
11582
if self.kwargs:
11683
base_str += f", kwargs={self.kwargs}"
11784
if self.description:
@@ -252,17 +219,14 @@ def infinicore_operator(self, *inputs, out=None, **kwargs):
252219
"""Unified Infinicore operator function"""
253220
pass
254221

255-
def create_strided_tensor(self, shape, strides, dtype, device_str):
222+
def create_strided_tensor(
223+
self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM
224+
):
256225
"""Create a non-contiguous tensor with specific strides"""
257-
total_size = 1
258-
for i in range(len(shape)):
259-
total_size += (shape[i] - 1) * abs(strides[i])
260-
261-
base_tensor = torch.rand(total_size, dtype=dtype, device=device_str)
262-
strided_tensor = torch.as_strided(base_tensor, shape, strides)
263-
return strided_tensor
226+
spec = TensorSpec.from_strided_tensor(shape, strides, dtype, init_mode)
227+
return spec.create_torch_tensor(device, dtype)
264228

265-
def prepare_inputs(self, test_case, device_str, dtype_config):
229+
def prepare_inputs(self, test_case, device, dtype_config):
266230
"""Prepare input data"""
267231
inputs = []
268232

@@ -271,49 +235,26 @@ def prepare_inputs(self, test_case, device_str, dtype_config):
271235
if input_spec.is_scalar:
272236
inputs.append(input_spec.value)
273237
else:
274-
shape = input_spec.shape
275-
276-
if input_spec.dtype is not None:
277-
tensor_dtype = to_torch_dtype(input_spec.dtype)
278-
elif (
279-
isinstance(dtype_config, dict) and f"input_{i}" in dtype_config
280-
):
281-
tensor_dtype = to_torch_dtype(dtype_config[f"input_{i}"])
282-
elif isinstance(dtype_config, (list, tuple)) and i < len(
283-
dtype_config
284-
):
285-
tensor_dtype = to_torch_dtype(dtype_config[i])
286-
else:
287-
tensor_dtype = to_torch_dtype(dtype_config)
288-
289-
if input_spec.is_contiguous or input_spec.strides is None:
290-
tensor = torch.rand(
291-
shape, dtype=tensor_dtype, device=device_str
292-
)
293-
else:
294-
tensor = self.create_strided_tensor(
295-
shape, input_spec.strides, tensor_dtype, device_str
296-
)
297-
238+
tensor = input_spec.create_torch_tensor(device, dtype_config, i)
298239
inputs.append(tensor)
299240
else:
300241
inputs.append(input_spec)
301242

302243
return inputs, test_case.kwargs
303244

304245
def get_output_dtype(self, test_case, dtype_config, torch_result=None):
305-
"""Determine output dtype"""
246+
"""Determine output dtype - returns infinicore dtype, not torch dtype"""
306247
if test_case.output and test_case.output.dtype is not None:
307-
return to_torch_dtype(test_case.output.dtype)
248+
return test_case.output.dtype
308249
elif isinstance(dtype_config, dict) and "output" in dtype_config:
309-
return to_torch_dtype(dtype_config["output"])
250+
return dtype_config["output"]
310251
elif torch_result is not None:
311-
return torch_result.dtype
252+
return to_infinicore_dtype(torch_result.dtype)
312253
else:
313254
if isinstance(dtype_config, (list, tuple)):
314-
return to_torch_dtype(dtype_config[0])
255+
return dtype_config[0]
315256
else:
316-
return to_torch_dtype(dtype_config)
257+
return dtype_config
317258

318259
def run_test(self, device, test_case, dtype_config, config):
319260
"""Unified test execution flow"""
@@ -350,7 +291,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
350291
"""Run a single test with specified operation mode"""
351292
device_str = torch_device_map[device]
352293

353-
inputs, kwargs = self.prepare_inputs(test_case, device_str, dtype_config)
294+
inputs, kwargs = self.prepare_inputs(test_case, device, dtype_config)
354295

355296
infini_inputs = []
356297
for inp in inputs:
@@ -378,8 +319,9 @@ def infini_op():
378319

379320
infini_result = infini_op()
380321

381-
comparison_dtype = to_infinicore_dtype(
382-
self.get_output_dtype(test_case, dtype_config, torch_result)
322+
# Get comparison dtype (infinicore dtype)
323+
comparison_dtype = self.get_output_dtype(
324+
test_case, dtype_config, torch_result
383325
)
384326

385327
compare_fn = create_test_comparator(
@@ -408,26 +350,40 @@ def infini_op():
408350
if not test_case.output:
409351
raise ValueError("IN_PLACE test requires output specification")
410352

353+
# Get output dtype and create output tensor
411354
output_dtype = self.get_output_dtype(test_case, dtype_config)
412355
output_shape = test_case.output.shape
413356

357+
# Use TensorSpec to create output tensor with specified initialization mode
414358
if test_case.output.is_contiguous or test_case.output.strides is None:
415-
torch_output = torch.zeros(
416-
output_shape, dtype=output_dtype, device=device_str
359+
output_spec = TensorSpec.from_tensor(
360+
output_shape, output_dtype, init_mode=test_case.output.init_mode
417361
)
418362
else:
419-
torch_output = self.create_strided_tensor(
420-
output_shape, test_case.output.strides, output_dtype, device_str
363+
output_spec = TensorSpec.from_strided_tensor(
364+
output_shape,
365+
test_case.output.strides,
366+
output_dtype,
367+
init_mode=test_case.output.init_mode,
421368
)
369+
370+
torch_output = output_spec.create_torch_tensor(device, output_dtype)
371+
372+
# For non-contiguous tensors, we need to ensure zeros initialization
373+
if (
374+
not test_case.output.is_contiguous
375+
and test_case.output.strides is not None
376+
):
422377
torch_output.zero_()
423378

424379
def torch_op_inplace():
425380
self.torch_operator(*inputs, out=torch_output, **kwargs)
426381

427382
torch_op_inplace()
428383

384+
# Create infinicore output tensor
429385
torch_dummy = torch.zeros(
430-
output_shape, dtype=output_dtype, device=device_str
386+
output_shape, dtype=to_torch_dtype(output_dtype), device=device_str
431387
)
432388
if (
433389
not test_case.output.is_contiguous
@@ -441,8 +397,8 @@ def infini_op_inplace():
441397

442398
infini_op_inplace()
443399

444-
comparison_dtype = to_infinicore_dtype(
445-
self.get_output_dtype(test_case, dtype_config, torch_output)
400+
comparison_dtype = self.get_output_dtype(
401+
test_case, dtype_config, torch_output
446402
)
447403
compare_fn = create_test_comparator(
448404
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"

test/infinicore/framework/templates.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,31 @@
1818
4. get_dtype_combinations() -> Optional[List[Dict]]
1919
- Define mixed dtype configurations for multi-dtype tests
2020
- Return None for single-dtype tests
21-
- Available mixed dtype definition methods:
22-
23-
METHOD 1: Explicit dictionary per combination (Recommended)
24-
Format: [{"input_0": dtype1, "input_1": dtype2, "output": dtype3}]
25-
Example:
26-
[{"input_0": infinicore.float16, "input_1": infinicore.float32, "output": infinicore.float16}]
27-
28-
METHOD 2: Rule-based combination generation
29-
Generate all valid combinations with business logic constraints
30-
Example: Output bf16 requires input bf16
31-
32-
METHOD 3: Multi-output support with complex structure
33-
Format: [{"inputs": [dtype1, dtype2], "outputs": [dtype3, dtype4], "params": {"scale": dtype5}}]
34-
Requires special handling in prepare_inputs()
35-
36-
METHOD 4: Per-tensor specification in TestCase
37-
Individual tensors can specify dtype in TensorSpec
38-
Overrides dtype_combinations for specific tensors
39-
40-
METHOD 5: Hybrid approach with fallback
41-
Combine explicit combinations with generated ones
42-
Support both simple and complex dtype requirements
4321
4422
5. torch_operator(*inputs, out=None, **kwargs) -> torch.Tensor
4523
- Implement PyTorch reference implementation
4624
4725
6. infinicore_operator(*inputs, out=None, **kwargs) -> infinicore.Tensor
4826
- Implement Infinicore operator implementation
4927
50-
Usage examples:
51-
- Single dtype: Return dtype list from get_tensor_dtypes(), None from get_dtype_combinations()
52-
- Mixed dtype: Return dtype combinations from get_dtype_combinations(), basic dtypes from get_tensor_dtypes()
28+
New Tensor Initialization Modes:
29+
- TensorInitializer.RANDOM (default): Random values using torch.rand
30+
- TensorInitializer.ZEROS: All zeros using torch.zeros
31+
- TensorInitializer.ONES: All ones using torch.ones
32+
- TensorInitializer.RANDINT: Random integers using torch.randint
33+
- TensorInitializer.MANUAL: Use a pre-existing tensor with shape/strides validation
34+
- TensorInitializer.BINARY: Use a pre-existing tensor with shape validation only
35+
36+
Usage examples in TestCase creation:
37+
- Basic: TensorSpec.from_tensor(shape)
38+
- With initialization: TensorSpec.from_tensor(shape, init_mode=TensorInitializer.ZEROS)
39+
- Strided with custom init: TensorSpec.from_strided_tensor(shape, strides, init_mode=TensorInitializer.ONES)
5340
"""
5441

5542
import torch
5643
import infinicore
5744
from .base import BaseOperatorTest
45+
from .tensor import TensorSpec, TensorInitializer
5846

5947

6048
class BinaryOperatorTest(BaseOperatorTest):

0 commit comments

Comments
 (0)