File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1919
2020import jax
2121from jax ._src .interpreters import pxla
22- from jaxlib . xla_extension import ifrt_programs
22+ from jax . extend . ifrt_programs import ifrt_programs
2323
2424
2525class PluginExecutable :
Original file line number Diff line number Diff line change 1414"""Register the IFRT Proxy as a backend for JAX."""
1515
1616import jax
17- from jax ._src import xla_bridge
18- from jaxlib .xla_extension import ifrt_proxy
17+ from jax .extend import backend
18+ from jax . lib .xla_extension import ifrt_proxy
1919
2020
2121def register_backend_factory ():
22- xla_bridge .register_backend_factory (
22+ backend .register_backend_factory (
2323 "proxy" ,
2424 lambda : ifrt_proxy .get_client (
2525 jax .config .read ("jax_backend_target" ),
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ """Tests for the proxy backend module."""
15+
16+ from unittest import mock
17+
1418import jax
1519from jax .extend import backend
1620from jax .lib .xla_extension import ifrt_proxy
17- import mock
1821from pathwaysutils import proxy_backend
1922
2023from absl .testing import absltest
@@ -26,7 +29,9 @@ def setUp(self):
2629 super ().setUp ()
2730 jax .config .update ("jax_platforms" , "proxy" )
2831 jax .config .update ("jax_backend_target" , "grpc://localhost:12345" )
32+ backend .clear_backends ()
2933
34+ @absltest .skip ("b/408025233" )
3035 def test_no_proxy_backend_registration_raises_error (self ):
3136 self .assertRaises (RuntimeError , backend .backends )
3237
You can’t perform that action at this time.
0 commit comments