Skip to content

Commit ff021b7

Browse files
committed
Code cleaned up to make it more readable
1 parent 8a1723c commit ff021b7

4 files changed

Lines changed: 3 additions & 33 deletions

File tree

recml/core/data/iterator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,18 @@ def _maybe_to_numpy(
6868
) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
6969
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
7070
return x
71-
# FIX: Check for attribute existence to avoid crashes on non-Tensor objects
7271
if hasattr(x, "_numpy"):
7372
numpy = x._numpy() # pylint: disable=protected-access
7473
elif hasattr(x, "numpy"):
7574
numpy = x.numpy()
7675
else:
77-
return x # Return as-is if it can't be converted
76+
return x
7877

7978
if isinstance(numpy, np.ndarray):
80-
# `numpy` shares the same underlying buffer as the `x` Tensor.
8179
# Tensors are expected to be immutable, so we disable writes.
8280
numpy.setflags(write=False)
8381
return numpy
8482

85-
# FIX: Use jax.tree.map instead of tf.nest.map_structure
8683
return jax.tree.map(_maybe_to_numpy, batch)
8784

8885
@property
@@ -115,7 +112,6 @@ def _to_element_spec(
115112
)
116113
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
117114

118-
# element_spec = tf.nest.map_structure(_to_element_spec, batch)
119115
element_spec = jax.tree.map(_to_element_spec, batch)
120116
self._element_spec = element_spec
121117
return element_spec

recml/core/training/partitioning.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def _shard(x: np.ndarray) -> jax.Array:
110110
def partition_init(
111111
self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
112112
) -> CreateStateFn:
113-
# FIXED: Use 'with self.mesh'
114113
with self.mesh:
115114
if abstract_batch is not None:
116115
mesh_context.set_global_mesh(self.mesh)
@@ -122,7 +121,6 @@ def partition_init(
122121
init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
123122

124123
def _wrapped_init(batch: PyTree) -> State:
125-
# FIXED: Use 'with self.mesh'
126124
with self.mesh:
127125
state = init_fn(batch)
128126
state = _maybe_unbox_state(state)
@@ -136,7 +134,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
136134
jit_kws["out_shardings"] = (self.state_sharding, None)
137135
jit_kws["donate_argnums"] = (1,)
138136

139-
# FIXED: Use 'with self.mesh' and legacy bridge
140137
with self.mesh:
141138
mesh_context.set_global_mesh(self.mesh)
142139
step_fn = jax.jit(
@@ -146,7 +143,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
146143
)
147144

148145
def _wrapped_step(batch: PyTree, state: State) -> Any:
149-
# FIXED: Use 'with self.mesh'
150146
with self.mesh:
151147
return step_fn(batch, state)
152148

@@ -238,9 +234,7 @@ def partition_init(
238234
" model parallel partitioner."
239235
)
240236

241-
# FIXED: Use 'with self.mesh' directly
242237
with self.mesh:
243-
# FIXED: Legacy bridge
244238
mesh_context.set_global_mesh(self.mesh)
245239
abstract_state = jax.eval_shape(init_fn, abstract_batch)
246240
specs = nn.get_partition_spec(abstract_state)
@@ -254,7 +248,6 @@ def partition_init(
254248
compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)
255249

256250
def _init(batch: PyTree) -> State:
257-
# FIXED: Use 'with self.mesh' directly
258251
with self.mesh:
259252
state = compiled_init_fn(batch)
260253
state = _maybe_unbox_state(state)
@@ -273,7 +266,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
273266
else:
274267
jit_kws["out_shardings"] = None
275268

276-
# FIXED: Use 'with self.mesh' directly and legacy bridge
269+
277270
with self.mesh:
278271
mesh_context.set_global_mesh(self.mesh)
279272
step_fn = jax.jit(
@@ -296,7 +289,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
296289
)
297290

298291
def _step(batch: PyTree, state: State) -> Any:
299-
# FIXED: Use 'with self.mesh' directly
300292
with self.mesh:
301293
return step_fn(batch, state)
302294

recml/layers/linen/sparsecore.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -371,17 +371,6 @@ class SparsecoreEmbed(nn.Module):
371371
sparsecore_config: SparsecoreConfig
372372
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None
373373

374-
# def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
375-
# if self.mesh is not None:
376-
# return self.mesh
377-
# abstract_mesh = jax.sharding.get_abstract_mesh()
378-
# if not abstract_mesh.shape_tuple:
379-
# raise ValueError(
380-
# 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
381-
# ' sure to set the mesh when calling the sparsecore module.'
382-
# )
383-
# return abstract_mesh
384-
385374
def get_mesh(self) -> jax.sharding.Mesh:
386375
# Try to get the mesh from our custom global context
387376
mesh = mesh_context.get_global_mesh()

training.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,4 @@ RUN pip install "protobuf>=6.31.1" --no-deps
6868
CMD ["python", "recml/examples/dlrm_experiment_test.py"]
6969
```
7070

71-
You can use this dockerfile to run the DLRM model experiment from this repo in your own environment.
72-
73-
74-
75-
76-
77-
78-
71+
You can use this Dockerfile to run the DLRM model experiment from this repo in your own environment.

0 commit comments

Comments
 (0)