1+ package sk.ainet.lang.tensor
2+
3+ import sk.ainet.lang.tensor.data.TensorData
4+ import sk.ainet.lang.tensor.ops.TensorOps
5+ import sk.ainet.lang.types.DType
6+
7+ /* *
8+ * A materialization strategy that immediately copies all data from a tensor view
9+ * into a new, standalone tensor with contiguous memory layout.
10+ *
11+ * This strategy provides immediate materialization by iterating through all
12+ * elements in the view and copying them to a new tensor with the view's shape.
13+ * The resulting tensor is completely independent of the parent tensor and can
14+ * be used even after the parent tensor is garbage collected.
15+ *
16+ * ## Characteristics
17+ *
18+ * - **Immediate Execution**: Materialization happens synchronously when called
19+ * - **Memory Independent**: Result tensor has no dependencies on parent tensor
20+ * - **Contiguous Layout**: Output data is stored in standard row-major order
21+ * - **Type Preservation**: Maintains the same data type and value type as the view
22+ *
23+ * ## Trade-offs
24+ *
25+ * **Benefits:**
26+ * - Predictable memory usage and performance
27+ * - No ongoing computational overhead for element access
28+ * - Enables garbage collection of parent tensors
29+ * - Compatible with all downstream operations
30+ *
31+ * **Costs:**
32+ * - Immediate memory allocation for full tensor size
33+ * - Computational cost of copying all elements
34+ * - Temporary memory pressure during materialization
35+ *
36+ * ## Usage Scenarios
37+ *
38+ * This strategy is optimal when:
39+ * - The materialized tensor will be accessed frequently
40+ * - Memory usage is predictable and acceptable
41+ * - The parent tensor can be released after materialization
42+ * - Compatibility with external libraries is required
43+ *
44+ * @param T the data type constraint extending DType
45+ * @param V the actual value type that will be stored and accessed
46+ */
47+ public class CopyMaterializationStrategy <T : DType , V > : MaterializationStrategy <T , V > {
48+
49+ override val name: String = " CopyMaterialization"
50+
51+ override fun materialize (view : TensorView <T , V >): Tensor <T , V > {
52+ val viewShape = view.viewShape
53+ val viewVolume = viewShape.volume
54+
55+ // Create a new data array to hold the materialized elements
56+ val materializedData = createDataArray(view, viewVolume)
57+
58+ // Copy all elements from the view to the new array
59+ copyViewElements(view, materializedData, viewShape)
60+
61+ // Create and return the materialized tensor
62+ return createMaterializedTensor(view, materializedData, viewShape)
63+ }
64+
65+ override fun canMaterialize (view : TensorView <T , V >): Boolean {
66+ // CopyMaterializationStrategy can handle any view as long as:
67+ // 1. The view has a valid shape
68+ // 2. Memory is available for allocation
69+ return try {
70+ view.viewShape.volume >= 0
71+ } catch (e: Exception ) {
72+ false
73+ }
74+ }
75+
76+ override fun estimateMemoryOverhead (view : TensorView <T , V >): Long {
77+ // Estimate memory required for a copy of the view data
78+ val viewVolume = view.viewShape.volume
79+ val bytesPerElement = estimateBytesPerElement(view.dtype)
80+ return viewVolume.toLong() * bytesPerElement
81+ }
82+
83+ /* *
84+ * Creates a data array suitable for storing the materialized view elements.
85+ *
86+ * This method needs to create an appropriate array type based on the
87+ * tensor's value type. Since we don't have direct access to the tensor
88+ * factory here, we'll need to work with the existing data structure.
89+ */
90+ @Suppress(" UNCHECKED_CAST" )
91+ private fun createDataArray (view : TensorView <T , V >, volume : Int ): Array <V ?> {
92+ return arrayOfNulls<Any >(volume) as Array <V ?>
93+ }
94+
95+ /* *
96+ * Copies all elements from the tensor view to the materialized data array.
97+ *
98+ * This method iterates through the view's coordinate space and copies
99+ * each element to the corresponding position in the output array using
100+ * row-major order.
101+ */
102+ private fun copyViewElements (view : TensorView <T , V >, data : Array <V ?>, shape : Shape ) {
103+ val dimensions = shape.dimensions
104+ val indices = IntArray (dimensions.size)
105+
106+ fun copyRecursive (dimension : Int , flatIndex : Int ): Int {
107+ var currentIndex = flatIndex
108+
109+ if (dimension == dimensions.size) {
110+ // Base case: copy the element at this coordinate
111+ val element = view.data.get(* indices)
112+ data[currentIndex] = element
113+ return currentIndex + 1
114+ }
115+
116+ // Recursive case: iterate through this dimension
117+ for (i in 0 until dimensions[dimension]) {
118+ indices[dimension] = i
119+ currentIndex = copyRecursive(dimension + 1 , currentIndex)
120+ }
121+
122+ return currentIndex
123+ }
124+
125+ copyRecursive(0 , 0 )
126+ }
127+
128+ /* *
129+ * Creates a materialized tensor from the copied data.
130+ *
131+ * This method constructs a new Tensor instance using the copied data
132+ * and the view's shape and data type information.
133+ */
134+ private fun createMaterializedTensor (
135+ view : TensorView <T , V >,
136+ data : Array <V ?>,
137+ shape : Shape
138+ ): Tensor <T , V > {
139+ // Create a simple tensor implementation that wraps our materialized data
140+ return MaterializedTensor (
141+ data = MaterializedTensorData <T , V >(shape, data),
142+ ops = view.ops,
143+ dtype = view.dtype
144+ )
145+ }
146+
147+ /* *
148+ * Estimates the number of bytes per element for the given data type.
149+ */
150+ private fun estimateBytesPerElement (dtype : DType ): Int {
151+ return when (dtype.name) {
152+ " FP32" -> 4
153+ " FP16" -> 2
154+ " Int32" -> 4
155+ " Int8" -> 1
156+ " Int4" -> 1 // Packed, but estimate 1 byte for simplicity
157+ " Ternary" -> 1 // Packed, but estimate 1 byte for simplicity
158+ else -> 4 // Default to 4 bytes
159+ }
160+ }
161+
162+ /* *
163+ * Simple tensor data implementation for materialized tensors.
164+ */
165+ private class MaterializedTensorData <T : DType , V >(
166+ override val shape : Shape ,
167+ private val data : Array <V ?>
168+ ) : TensorData<T, V> {
169+
170+ override fun get (vararg indices : Int ): V {
171+ val flatIndex = calculateFlatIndex(indices)
172+ return data[flatIndex] ? : throw IllegalStateException (" Null data at index $flatIndex " )
173+ }
174+
175+ override fun set (vararg indices : Int , value : V ) {
176+ val flatIndex = calculateFlatIndex(indices)
177+ data[flatIndex] = value
178+ }
179+
180+ private fun calculateFlatIndex (indices : IntArray ): Int {
181+ require(indices.size == shape.dimensions.size) {
182+ " Expected ${shape.dimensions.size} indices, got ${indices.size} "
183+ }
184+
185+ var flatIndex = 0
186+ var stride = 1
187+
188+ // Calculate flat index using row-major order
189+ for (i in shape.dimensions.size - 1 downTo 0 ) {
190+ require(indices[i] >= 0 && indices[i] < shape.dimensions[i]) {
191+ " Index ${indices[i]} out of bounds for dimension $i with size ${shape.dimensions[i]} "
192+ }
193+ flatIndex + = indices[i] * stride
194+ stride * = shape.dimensions[i]
195+ }
196+
197+ return flatIndex
198+ }
199+ }
200+
201+ /* *
202+ * Simple tensor implementation for materialized tensors.
203+ */
204+ private class MaterializedTensor <T : DType , V >(
205+ override val data : TensorData <T , V >,
206+ override val ops : TensorOps <V >,
207+ override val dtype : T
208+ ) : Tensor<T, V>
209+ }
0 commit comments