Commit f08effb
Wan training: Change checkpoint restore_type to jax.Array
- Address review comments: change the restore_type from np.ndarray to jax.Array. This is necessary because JAX sharding has no effect on np.ndarray, using jax.Array ensures that the specified sharding is respected during checkpoint restoration.
Co-authored-by: martinarroyo <martinarroyo@google.com>1 parent 19b875b commit f08effb
2 files changed
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
66 | | - | |
| 66 | + | |
67 | 67 | | |
68 | 68 | | |
69 | 69 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
66 | | - | |
| 66 | + | |
67 | 67 | | |
68 | 68 | | |
69 | 69 | | |
| |||
0 commit comments