|
10 | 10 | # |
11 | 11 | # ===--------------------------------------------------------------------------------------===# |
12 | 12 |
|
13 | | -from typing import Any, Callable, Dict, List, Optional |
| 13 | +from typing import Dict, Optional |
14 | 14 | from abc import ABC, abstractmethod |
15 | 15 |
|
16 | 16 | import numpy as np |
@@ -264,18 +264,45 @@ def reset(self, exploration_rate: Optional[float] = None) -> None: |
264 | 264 |
|
265 | 265 |
|
266 | 266 | class CosineScheduler(ExplorationRateScheduler): |
267 | | - """ """ |
| 267 | + """Cosine annealing scheduler that oscillates exploration rate periodically. |
| 268 | + |
| 269 | + This scheduler uses a cosine wave to smoothly vary the exploration rate |
| 270 | + between min_rate and max_rate over a fixed cycle length, allowing for |
| 271 | + periodic transitions between exploration and exploitation. |
| 272 | + |
| 273 | + Attributes: |
| 274 | + cycle_length: Number of epochs for one complete cosine cycle. |
| 275 | + """ |
268 | 276 |
|
269 | 277 | def __init__( |
270 | 278 | self, exploration_rate: float, max_rate: float, min_rate: float, cycle_length: int |
271 | 279 | ): |
| 280 | + """Initialize the cosine annealing scheduler. |
| 281 | + |
| 282 | + Args: |
| 283 | + exploration_rate: Initial exploration rate. |
| 284 | + max_rate: Maximum exploration rate bound. |
| 285 | + min_rate: Minimum exploration rate bound. |
| 286 | + cycle_length: Number of epochs per complete cosine cycle. |
| 287 | + |
| 288 | + Raises: |
| 289 | + ValueError: If cycle_length is not positive. |
| 290 | + """ |
272 | 291 | super().__init__(exploration_rate, max_rate, min_rate) |
273 | 292 | if cycle_length <= 0: |
274 | 293 | raise ValueError(f"cycle_length ({cycle_length}) must be positive") |
275 | 294 | self.cycle_length = cycle_length |
276 | 295 |
|
277 | 296 | def __call__(self, epoch: int, **kwargs) -> float: |
278 | | - """ """ |
| 297 | + """Compute exploration rate using cosine annealing. |
| 298 | + |
| 299 | + Args: |
| 300 | + epoch: Current epoch number. |
| 301 | + **kwargs: Additional arguments (ignored). |
| 302 | + |
| 303 | + Returns: |
| 304 | + Updated exploration rate following cosine wave. |
| 305 | + """ |
279 | 306 | cycle_progress: float = (epoch % self.cycle_length) / self.cycle_length |
280 | 307 | cosine_factor: float = 0.5 * (1 + np.cos(np.pi * cycle_progress)) |
281 | 308 | rate: float = self.min_rate + (self.max_rate - self.min_rate) * cosine_factor |
|
0 commit comments