-
Notifications
You must be signed in to change notification settings - Fork 73
Expand file tree
/
Copy pathcircular_buffer.py
More file actions
114 lines (92 loc) · 4.2 KB
/
Copy pathcircular_buffer.py
File metadata and controls
114 lines (92 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Circular buffer written in JAX.
See circular_buffer_test.py for usage.
"""
import collections
import functools
from typing import Generic, Tuple, TypeVar
import jax
from jax import tree_util
import jax.numpy as jnp
# Immutable state threaded through every CircularBuffer method.
#   values: pair (buffer pytree, per-slot insertion-index array) as built by
#           CircularBuffer.init.
#   idx:    number of CircularBuffer.add calls applied so far.
CircularBufferState = collections.namedtuple(
    "CircularBufferState", "values idx")

# Pytree type of a single element stored in the buffer.
T = TypeVar("T")
class CircularBuffer(Generic[T]):
    """Stateless class to manage a fixed-size circular buffer of pytrees.

    All buffer contents live in a `CircularBufferState`. Every method takes a
    state and returns a new state (or values derived from it). The instance
    itself only holds static configuration, which is why it is passed to
    `jax.jit` as a static argument.

    State layout, as built by `init`:
      state.values[0]: pytree shaped like `abstract_value`, each leaf with an
        extra leading axis of length `size`.
      state.values[1]: integer array of shape [size]. For each slot it holds
        the value of `state.idx` at the time that slot was last written.
        Slots that were never written hold `-size`.
      state.idx: number of `add` calls applied so far.
    """

    def __init__(self, abstract_value: T, size: int):
        """Initializer.

        Args:
          abstract_value: a pytree of jax.ShapedArray with the shape of each
            element in the circular buffer. Only `.shape` and `.dtype` of each
            leaf are read.
          size: length of circular buffer.
        """
        self.abstract_value: T = abstract_value
        self.size = size

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, default=0.0) -> CircularBufferState:
        """Construct the initial state of the circular buffer.

        Args:
          default: value every slot of every leaf is filled with, cast to the
            leaf's dtype. It is traced by jit (not static).

        Returns:
          A state with `idx == 0` and every slot marked as never written.
        """

        def build_one(x):
            # One [size, *x.shape] array per leaf, filled with `default`.
            return jnp.full((self.size,) + tuple(x.shape), default,
                            dtype=x.dtype)

        empty_buffer = tree_util.tree_map(build_one, self.abstract_value)
        # -size marks "never written"; real insertion indices are >= 0.
        never_written = jnp.full([self.size], -self.size, dtype=jnp.int64)
        return CircularBufferState(
            idx=jnp.asarray(0, jnp.int64),
            values=(empty_buffer, never_written))

    @functools.partial(jax.jit, static_argnums=(0,))
    def add(self, state: CircularBufferState,
            value: T) -> CircularBufferState:  # pytype: disable=invalid-annotation
        """Add `value` to the buffer, overwriting the oldest entry once full.

        Args:
          state: current buffer state.
          value: pytree with the same structure as `abstract_value`.

        Returns:
          New state with `value` written to slot `state.idx % size`. That
          slot's insertion index is set to `state.idx`, and `idx` is
          incremented by one.
        """
        slot = state.idx % self.size

        def write_slot(buf, new):
            # Every leaf (including the insertion-index array) has the leading
            # `size` axis, so indexing the first axis covers all ranks.
            return buf.at[slot].set(new)

        new_values = tree_util.tree_map(write_slot, state.values,
                                        (value, state.idx))
        return CircularBufferState(idx=state.idx + 1, values=new_values)

    def _reorder(self, vals, idx):
        """Roll `vals` along axis 0 so slot `idx % size` comes first."""
        offset = idx % self.size
        return jnp.roll(vals, -offset, axis=0)

    def _age_rank(self, state: CircularBufferState) -> jnp.ndarray:
        """Per-slot rank `size - (state.idx - insertion_idx)`, -1 if unwritten.

        The most recently added slot gets `size - 1`. For a full buffer the
        oldest slot gets 0.
        """
        stored = state.values[1]
        rank = jnp.clip(stored - state.idx + self.size, -1, self.size)
        # Never-written slots carry a negative sentinel (-size), while real
        # insertion indices are always >= 0. Testing the sign, rather than
        # relying on the clip alone, also covers an empty buffer (idx == 0).
        # There, stored - idx + size == 0 would otherwise look like a valid
        # slot.
        return jnp.where(stored < 0, -1, rank)

    @functools.partial(jax.jit, static_argnums=(0,))
    def stack_with_idx(self, state: CircularBufferState) -> Tuple[T, jnp.ndarray]:  # pytype: disable=invalid-annotation
        """Return raw values with integer array containing index.

        Args:
          state: State of circular buffer

        Returns:
          values: The values contained in the circular buffer in raw slot
            order, with a leading dimension of size `self.size`.
          idx: Per-slot integer representing when each element was added:
            `size - (state.idx - insertion_index)`, so the newest element is
            `size - 1`. Slots that were never written are -1.
        """
        return state.values[0], self._age_rank(state)

    @functools.partial(jax.jit, static_argnums=(0,))
    def stack_reorder(self, state: CircularBufferState) -> Tuple[T, jnp.ndarray]:  # pytype: disable=invalid-annotation
        """Reorder the values oldest-first and return them with a mask.

        Returns:
          values: pytree whose leaves are rolled so position 0 is slot
            `state.idx % size` and the last position is the newest entry.
          mask: integer array of shape [size], 1 where the (reordered) slot
            has been written and 0 where it has not.
        """
        rank = self._age_rank(state)
        mask = self._reorder(jnp.where(rank == -1, 0, 1), state.idx)
        reordered = tree_util.tree_map(lambda x: self._reorder(x, state.idx),
                                       state.values[0])
        return reordered, mask

    @functools.partial(jax.jit, static_argnums=(0,))
    def gather_from_present(
            self, state: CircularBufferState, idxs: jnp.ndarray) -> T:  # pytype: disable=invalid-annotation
        """Get the values for each position in `idxs`, relative to the present.

        Position `i` refers to slot `(state.idx + i) % size`. This is the same
        ordering `stack_reorder` produces, so for a full buffer 0 is the oldest
        entry and -1 the most recently added one. Positions wrap modulo `size`.
        """
        offset = idxs % self.size
        slots = (state.idx + offset) % self.size
        return tree_util.tree_map(lambda x: x[slots], state.values[0])