Skip to content

Commit 119088a

Browse files
committed
Start refactoring to create a common ExecutionContext providing the required data in various contexts
Related-To #138
1 parent 8e14fd5 commit 119088a

17 files changed

Lines changed: 321 additions & 304 deletions

File tree

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package sk.ainet.compile.nn
2+
3+
import sk.ainet.context.ExecutionContext
4+
import sk.ainet.lang.nn.dsl.NeuralNetworkDsl
5+
6+
7+
import sk.ainet.lang.nn.Module
8+
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
9+
import sk.ainet.lang.tensor.data.TensorDataFactory
10+
import sk.ainet.lang.tensor.ops.TensorOps
11+
import sk.ainet.lang.tensor.ops.VoidTensorOps
12+
import sk.ainet.lang.types.DType
13+
14+
15+
/**
 * Context for the DSL to define the data type and operations.
 *
 * This interface holds the information about the data type and operations
 * that should be used in the DSL. It's used to make the DSL generic
 * and to avoid hardcoding the data type.
 *
 * @param T The default data type.
 * @param V The value type handled by the context's tensor operations
 *          (inherited from [ExecutionContext]).
 */
// NOTE(review): `T` is not referenced by any member here — it only constrains
// call sites such as `context { ... }` and `network { ... }`; confirm it is
// still needed after the ExecutionContext refactoring (#138).
public interface NeuralNetworkContext<T : DType, V> : ExecutionContext<V>
25+
26+
/**
 * Default [NeuralNetworkContext] used when the caller does not supply one.
 *
 * Provides a no-op [VoidTensorOps] and a [DenseTensorDataFactory].
 */
private class DefaultNetworkContext<T : DType, V> : NeuralNetworkContext<T, V> {
    // Initialize once instead of the previous `get() = ...` form, which
    // allocated a fresh instance on every property access.
    // NOTE(review): assumes VoidTensorOps and DenseTensorDataFactory are
    // stateless/shareable — confirm before merging.
    override val ops: TensorOps<V> = VoidTensorOps()
    override val tensorDataFactory: TensorDataFactory = DenseTensorDataFactory()
}
32+
33+
/**
 * Creates a context for the DSL with the given configuration.
 *
 * The [init] lambda runs with the freshly created context both as receiver
 * and as explicit argument, and must return the configured module.
 *
 * @param T The type of data processed by the modules.
 * @param V The value type handled by the context's tensor operations.
 * @param init The configuration function.
 * @return The configured context's module.
 */
public fun <T : DType, V> context(init: NeuralNetworkContext<T, V>.(NeuralNetworkContext<T, V>) -> Module<T, V>): Module<T, V> =
    DefaultNetworkContext<T, V>().let { ctx -> ctx.init(ctx) }
44+
45+
/**
 * Extension function to create a network within a [NeuralNetworkContext].
 *
 * Bridges the context wrapper with the network DSL by passing the context's
 * [ExecutionContext.tensorDataFactory] to the top-level `network` builder.
 */
public inline fun <reified T : DType, V> NeuralNetworkContext<T, V>.network(
    content: NeuralNetworkDsl<T, V>.() -> Unit
): Module<T, V> {
    // Fully qualified call: this extension shadows the top-level DSL function.
    return sk.ainet.lang.nn.dsl.network(tensorDataFactory, content)
}
52+
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package sk.ainet.context
2+
3+
import sk.ainet.lang.tensor.data.TensorDataFactory
4+
import sk.ainet.lang.tensor.ops.TensorOps
5+
6+
7+
/**
 * Common execution environment shared across contexts: exposes the tensor
 * operations and the factory used to allocate tensor data.
 *
 * @param V The value type the tensor operations work on.
 */
public interface ExecutionContext<V> {
    /** Tensor operations used to execute computations in this context. */
    public val ops: TensorOps<V>

    /** Factory used to create backing data for tensors in this context. */
    public val tensorDataFactory: TensorDataFactory
}
11+
12+
13+
/**
 * Memory usage information.
 *
 * [freeMemory] and [usagePercentage] are derived defaults computed from
 * [totalMemory] and [usedMemory]; callers may still override them.
 */
public data class MemoryInfo(
    /**
     * Total memory available on the device
     */
    public val totalMemory: Long,

    /**
     * Memory currently in use
     */
    public val usedMemory: Long,

    /**
     * Free memory available
     */
    public val freeMemory: Long = totalMemory - usedMemory,

    /**
     * Memory usage as percentage. Guarded against totalMemory == 0, which
     * previously produced NaN (0/0) or Infinity instead of a usable value.
     */
    public val usagePercentage: Double =
        if (totalMemory > 0) (usedMemory.toDouble() / totalMemory) * 100.0 else 0.0
)
37+
38+
/**
 * Execution statistics.
 *
 * [averageExecutionTime] is a derived default computed from
 * [totalExecutionTime] and [operationsExecuted]; callers may override it.
 */
public data class ExecutionStats(
    /**
     * Total number of operations executed
     */
    public val operationsExecuted: Long = 0,

    /**
     * Total execution time in milliseconds
     */
    public val totalExecutionTime: Long = 0,

    /**
     * Average execution time per operation (0.0 when nothing has run yet)
     */
    public val averageExecutionTime: Double =
        when {
            operationsExecuted > 0 -> totalExecutionTime.toDouble() / operationsExecuted
            else -> 0.0
        },

    /**
     * Number of tensors created
     */
    public val tensorsCreated: Long = 0,

    /**
     * Peak memory usage
     */
    public val peakMemoryUsage: Long = 0
)
68+

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionContext.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ import sk.ainet.lang.types.DType
66
/**
77
* Default implementation of ExecutionContext
88
*/
9+
/*
910
public class DefaultExecutionContext(
10-
initialMode: ExecutionMode = ExecutionMode.EAGER,
11-
initialDevice: Device = Device(DeviceType.CPU)
12-
) : ExecutionContext {
11+
) : GraphExecutionContext {
1312
1413
private var _executionMode: ExecutionMode = initialMode
1514
private var _device: Device = initialDevice
@@ -144,4 +143,6 @@ public class DefaultExecutionContext(
144143
// For now, return a simple counter
145144
return _executionStats.operationsExecuted
146145
}
147-
}
146+
}
147+
148+
*/
Lines changed: 13 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -1,230 +1,52 @@
11
package sk.ainet.lang.graph
22

3+
import sk.ainet.context.ExecutionContext
4+
import sk.ainet.context.ExecutionStats
5+
import sk.ainet.context.MemoryInfo
36
import sk.ainet.lang.tensor.Tensor
47
import sk.ainet.lang.types.DType
58

6-
/**
7-
* Execution modes supported by the framework
8-
*/
9-
public enum class ExecutionMode {
10-
/**
11-
* Eager execution - operations are executed immediately
12-
*/
13-
EAGER,
14-
15-
/**
16-
* Graph execution - operations are recorded and executed later
17-
*/
18-
GRAPH
19-
}
20-
21-
/**
22-
* Device types for tensor operations
23-
*/
24-
public enum class DeviceType {
25-
CPU,
26-
GPU,
27-
TPU,
28-
CUSTOM
29-
}
30-
31-
/**
32-
* Device specification
33-
*/
34-
public data class Device(
35-
public val type: DeviceType,
36-
public val id: Int = 0,
37-
public val name: String = "${type.name.lowercase()}_$id"
38-
)
399

4010
/**
4111
* Context for managing execution state, including mode switching,
4212
* device management, and memory management.
4313
*/
44-
public interface ExecutionContext {
45-
46-
/**
47-
* Current execution mode
48-
*/
49-
public val executionMode: ExecutionMode
50-
51-
/**
52-
* Current device for tensor operations
53-
*/
54-
public val device: Device
55-
14+
public interface GraphExecutionContext<V> : ExecutionContext<V> {
15+
16+
5617
/**
5718
* Current execution tape (null if not recording)
5819
*/
5920
public val currentTape: ExecutionTape?
60-
21+
6122
/**
6223
* Tape stack for nested execution contexts
6324
*/
6425
public val tapeStack: TapeStack
65-
26+
6627
/**
6728
* Whether operations should be recorded
6829
*/
6930
public val isRecording: Boolean get() = currentTape?.isRecording == true
70-
71-
/**
72-
* Switch to eager execution mode
73-
*/
74-
public fun switchToEager()
75-
76-
/**
77-
* Switch to graph execution mode
78-
*/
79-
public fun switchToGraph()
80-
81-
/**
82-
* Set the device for tensor operations
83-
*/
84-
public fun setDevice(device: Device)
85-
86-
/**
87-
* Start recording operations with a new tape
88-
*/
89-
public fun startRecording(tape: ExecutionTape = createTape())
90-
91-
/**
92-
* Stop recording operations
93-
*/
94-
public fun stopRecording(): ExecutionTape?
95-
96-
/**
97-
* Execute an operation in the current context
98-
*/
99-
public fun <T : DType, V> executeOperation(
100-
operation: Operation,
101-
inputs: List<Tensor<T, V>>
102-
): List<Tensor<T, V>>
103-
104-
/**
105-
* Create a new execution tape
106-
*/
107-
public fun createTape(): ExecutionTape
108-
109-
/**
110-
* Create a new gradient tape
111-
*/
112-
public fun createGradientTape(): GradientTape
113-
114-
/**
115-
* Execute with a specific execution mode temporarily
116-
*/
117-
public fun <R> withExecutionMode(mode: ExecutionMode, block: () -> R): R
118-
119-
/**
120-
* Execute with a specific device temporarily
121-
*/
122-
public fun <R> withDevice(device: Device, block: () -> R): R
123-
124-
/**
125-
* Execute with recording enabled temporarily
126-
*/
127-
public fun <R> withRecording(tape: ExecutionTape = createTape(), block: () -> R): Pair<R, ExecutionTape>
128-
31+
32+
12933
/**
13034
* Get memory usage information
13135
*/
13236
public fun getMemoryInfo(): MemoryInfo
133-
37+
13438
/**
13539
* Force garbage collection
13640
*/
13741
public fun collectGarbage()
138-
42+
13943
/**
14044
* Get execution statistics
14145
*/
14246
public fun getExecutionStats(): ExecutionStats
143-
47+
14448
/**
14549
* Reset execution statistics
14650
*/
14751
public fun resetExecutionStats()
14852
}
149-
150-
/**
151-
* Memory usage information
152-
*/
153-
public data class MemoryInfo(
154-
/**
155-
* Total memory available on the device
156-
*/
157-
public val totalMemory: Long,
158-
159-
/**
160-
* Memory currently in use
161-
*/
162-
public val usedMemory: Long,
163-
164-
/**
165-
* Free memory available
166-
*/
167-
public val freeMemory: Long = totalMemory - usedMemory,
168-
169-
/**
170-
* Memory usage as percentage
171-
*/
172-
public val usagePercentage: Double = (usedMemory.toDouble() / totalMemory) * 100.0
173-
)
174-
175-
/**
176-
* Execution statistics
177-
*/
178-
public data class ExecutionStats(
179-
/**
180-
* Total number of operations executed
181-
*/
182-
public val operationsExecuted: Long = 0,
183-
184-
/**
185-
* Total execution time in milliseconds
186-
*/
187-
public val totalExecutionTime: Long = 0,
188-
189-
/**
190-
* Average execution time per operation
191-
*/
192-
public val averageExecutionTime: Double =
193-
if (operationsExecuted > 0) totalExecutionTime.toDouble() / operationsExecuted else 0.0,
194-
195-
/**
196-
* Number of tensors created
197-
*/
198-
public val tensorsCreated: Long = 0,
199-
200-
/**
201-
* Peak memory usage
202-
*/
203-
public val peakMemoryUsage: Long = 0
204-
)
205-
206-
/**
207-
* Global execution context
208-
*/
209-
public object GlobalExecutionContext {
210-
private var _current: ExecutionContext? = null
211-
212-
/**
213-
* Get the current global execution context
214-
*/
215-
public fun current(): ExecutionContext {
216-
return _current ?: throw IllegalStateException("No execution context set")
217-
}
218-
219-
/**
220-
* Set the global execution context
221-
*/
222-
public fun setCurrent(context: ExecutionContext) {
223-
_current = context
224-
}
225-
226-
/**
227-
* Check if a global execution context is set
228-
*/
229-
public fun isSet(): Boolean = _current != null
230-
}

0 commit comments

Comments
 (0)