|
17 | 17 | on a specific v6e-256 hardware setup using the XPK runner. |
18 | 18 | """ |
19 | 19 |
|
20 | | -import maxtext_trillium_model_configs as model_configs |
| 20 | +import os |
21 | 21 |
|
22 | | -from maxtext_xpk_runner import BenchmarkRunner |
23 | | -from maxtext_xpk_runner import HWConfig |
24 | | -from maxtext_xpk_runner import SWconfig |
25 | | -from maxtext_xpk_runner import xpk_benchmark_runner |
26 | | -from maxtext_xpk_runner import XpkConfig |
| 22 | +from benchmarks import maxtext_trillium_model_configs as model_configs |
| 23 | +from benchmarks.maxtext_xpk_runner import WorkloadConfig |
| 24 | +from benchmarks.maxtext_xpk_runner import xpk_benchmark_runner |
| 25 | +from benchmarks.maxtext_xpk_runner import XpkClusterConfig |
27 | 26 |
|
28 | 27 |
|
29 | 28 | DATE = "20241009" |
|
35 | 34 | DEVICE_TYPE = "v6e-256" |
36 | 35 | NUM_SLICES = 1 |
37 | 36 | BASE_OUTPUT_DIR = "gs://maxtext-experiments-tpem/" |
38 | | - |
39 | | -v6e_env_configs = SWconfig(base_docker_image=BASE_DOCKER_IMAGE, libtpu_version=DATE) |
40 | | -v6e_256_configs = HWConfig(num_slices=NUM_SLICES, device_type=DEVICE_TYPE) |
41 | | - |
42 | | -llama2_70b_4096 = BenchmarkRunner( |
43 | | - model_name=model_configs.llama2_70b_4096, |
44 | | - software_config=v6e_env_configs, |
45 | | - hardware_config=v6e_256_configs, |
46 | | -) |
47 | | - |
48 | | -llama2_7b_4096 = BenchmarkRunner( |
49 | | - model_name=model_configs.llama2_7b_4096, |
50 | | - software_config=v6e_env_configs, |
51 | | - hardware_config=v6e_256_configs, |
52 | | -) |
| 37 | +XPK_PATH = os.path.join("~", "xpk") |
| 38 | +BENCHMARK_STEPS = 20 |
53 | 39 |
|
54 | 40 |
|
55 | 41 | def main() -> None: |
56 | | - cluster_config = XpkConfig( |
| 42 | + cluster_config = XpkClusterConfig( |
57 | 43 | cluster_name=CLUSTER_NAME, |
58 | 44 | project=PROJECT, |
59 | 45 | zone=ZONE, |
60 | | - num_slices=NUM_SLICES, |
61 | 46 | device_type=DEVICE_TYPE, |
62 | | - base_output_directory=BASE_OUTPUT_DIR, |
63 | 47 | ) |
64 | 48 |
|
65 | | - xpk_benchmark_runner(cluster_config, [llama2_7b_4096, llama2_70b_4096]) |
| 49 | + workload_configs = [] |
| 50 | + for model in [model_configs.llama2_7b_4096, model_configs.llama2_70b_4096]: |
| 51 | + workload_configs.append( |
| 52 | + WorkloadConfig( |
| 53 | + model=model, |
| 54 | + num_slices=NUM_SLICES, |
| 55 | + device_type=DEVICE_TYPE, |
| 56 | + base_output_directory=BASE_OUTPUT_DIR, |
| 57 | + base_docker_image=BASE_DOCKER_IMAGE, |
| 58 | + libtpu_type=None, |
| 59 | + libtpu_nightly_version=DATE, |
| 60 | + pathways_config=None, |
| 61 | + xpk_path=XPK_PATH, |
| 62 | + num_steps=BENCHMARK_STEPS, |
| 63 | + priority="medium", |
| 64 | + ) |
| 65 | + ) |
| 66 | + |
| 67 | + xpk_benchmark_runner(cluster_config, workload_configs) |
66 | 68 |
|
67 | 69 |
|
68 | 70 | if __name__ == "__main__": |
|
0 commit comments