@@ -347,89 +347,3 @@ def sigmoid_calculate_metrics(
347347 dtype = dtype .dtype .name ,
348348 )
349349
350-
351- # def get_output_named_shading(mesh, strategy: ShardingStrategy):
352- # match strategy:
353- # case ShardingStrategy.NO_SHARDING:
354- # return NamedSharding(mesh, P(None))
355- # case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
356- # return NamedSharding(mesh, P("device"))
357- # case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
358- # return NamedSharding(mesh, P("device"))
359- # case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
360- # assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
361- # case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
362- # assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
363-
364- # def get_out_sharding(strategy: ShardingStrategy):
365- # match strategy:
366- # case ShardingStrategy.NO_SHARDING:
367- # return P(None)
368- # case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
369- # return P("device")
370- # case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
371- # return P("device")
372- # case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
373- # assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
374- # case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
375- # assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
376-
377- # def add(m: int, dtype: jnp.dtype, num_runs: int = 1, trace_dir: str = None,
378- # ) -> Dict[str, Any]:
379- # """
380- # Z = X + Y
381- # """
382- # def f(x, y):
383- # with jax.named_scope(MARKER):
384- # return x + y
385-
386- # mesh = create_mesh(SHARDING_STRATEGY)
387- # x_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
388- # y_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
389- # out_sharding = get_out_sharding(SHARDING_STRATEGY)
390- # jit_sharded_f = jax.jit(
391- # shard_map(
392- # f,
393- # mesh,
394- # in_specs=(x_sharding.spec, y_sharding.spec),
395- # out_specs=out_sharding,
396- # check_rep=False,
397- # )
398- # )
399- # x_shape = (m)
400- # y_shape = (m)
401- # x_dtype = dtype
402- # y_dtype = dtype
403-
404- # key = jax.random.key(SEED)
405-
406- # def data_generator():
407- # """Creates new random data on host and puts it on device."""
408- # nonlocal key # Use and update the outer 'key'
409- # key, k1, k2 = jax.random.split(key, 3)
410-
411- # x_host = jax.random.normal(k1, x_shape).astype(x_dtype)
412- # y_host = jax.random.normal(k2, y_shape).astype(y_dtype)
413-
414- # x_device = jax.device_put(x_host, x_sharding)
415- # y_device = jax.device_put(y_host, y_sharding)
416-
417- # return (x_device, y_device)
418-
419- # time_ms_list = iteration_timeit(
420- # jit_sharded_f,
421- # data_generator,
422- # matrix_dim=f"{m}",
423- # tries=num_runs,
424- # task="add",
425- # trace_dir=trace_dir,
426- # )
427- # return {"time_ms_list": time_ms_list}
428-
429- # def add_calculate_metrics(
430- # m: int, dtype: jnp.dtype, time_ms_list: list[float]
431- # ) -> Dict[str, Any]:
432- # scale = 2 if dtype == jnp.bfloat16 else 1
433- # total_bytes = scale * 3 * m
434- # total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
435- # return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name)
0 commit comments