Skip to content

Commit 286284d

Browse files
feat(hpc-ai-training): add AI training use-case to codebase (GoogleCloudPlatform#385)
1 parent dfd7974 commit 286284d

19 files changed

Lines changed: 642 additions & 41 deletions

File tree

3-fleetscope/modules/private_install_manifest/main.tf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ module "kubectl" {
7373
membership_name = var.cluster_name
7474
membership_location = var.cluster_region
7575
kubectl_create_command = "kubectl apply --server-side -f ${path.module}/manifest-${random_uuid.uid.result}-${var.cluster_name}.yaml"
76-
kubectl_destroy_command = "kubectl delete -f ${path.module}/manifest-${random_uuid.uid.result}-${var.cluster_name}.yaml || exit 0"
76+
kubectl_destroy_command = "timeout 300s kubectl delete -f ${path.module}/manifest-${random_uuid.uid.result}-${var.cluster_name}.yaml || exit 0"
7777

7878
module_depends_on = [
7979
local_file.downloaded_file.filename,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
FROM tensorflow/tensorflow:latest-gpu@sha256:1f16fbd9be8bb84891de12533e332bbd500511caeb5cf4db501dbe39d422f9c7
16+
WORKDIR /data/tensorflow-mnist-example
17+
COPY requirements.txt .
18+
RUN pip install --no-cache-dir -r requirements.txt
19+
COPY . .
20+
CMD ["/bin/bash", "-c", "--", "python tensorflow_mnist_train_distributed.py"]
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/**
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
resource "google_service_account" "builder" {
18+
project = var.infra_project
19+
account_id = "ai-builder"
20+
}
21+
22+
resource "google_storage_bucket" "build_logs" {
23+
name = "cb-ai-builder-logs-${var.infra_project}"
24+
project = var.infra_project
25+
uniform_bucket_level_access = true
26+
force_destroy = var.bucket_force_destroy
27+
location = var.region
28+
}
29+
30+
# IAM roles required to build the AI training image on Google Cloud Build
31+
resource "google_storage_bucket_iam_member" "builder_admin" {
32+
member = google_service_account.builder.member
33+
bucket = google_storage_bucket.build_logs.name
34+
role = "roles/storage.admin"
35+
}
36+
37+
resource "google_project_iam_member" "builder_object_user" {
38+
member = google_service_account.builder.member
39+
project = var.infra_project
40+
role = "roles/storage.objectUser"
41+
}
42+
43+
resource "google_artifact_registry_repository_iam_member" "builder" {
44+
project = google_artifact_registry_repository.private_images.project
45+
location = google_artifact_registry_repository.private_images.location
46+
repository = google_artifact_registry_repository.private_images.name
47+
role = "roles/artifactregistry.repoAdmin"
48+
member = google_service_account.builder.member
49+
}
50+
51+
resource "google_artifact_registry_repository_iam_member" "allow_cluster_sa_download" {
52+
for_each = var.cluster_service_accounts
53+
project = google_artifact_registry_repository.private_images.project
54+
location = google_artifact_registry_repository.private_images.location
55+
repository = google_artifact_registry_repository.private_images.name
56+
role = "roles/artifactregistry.reader"
57+
member = "serviceAccount:${each.value}"
58+
}
59+
60+
resource "time_sleep" "wait_iam_propagation" {
61+
create_duration = "60s"
62+
63+
depends_on = [
64+
google_artifact_registry_repository_iam_member.builder,
65+
google_storage_bucket_iam_member.builder_admin,
66+
google_project_iam_member.builder_object_user,
67+
]
68+
}
69+
70+
71+
resource "time_sleep" "wait_api" {
72+
create_duration = "20s"
73+
74+
depends_on = [
75+
google_project_service.enable_apis
76+
]
77+
}
78+
79+
resource "google_artifact_registry_repository" "private_images" {
80+
location = var.region
81+
project = var.infra_project
82+
repository_id = "private-images"
83+
description = "Docker repository for private images"
84+
format = "DOCKER"
85+
86+
depends_on = [
87+
time_sleep.wait_api
88+
]
89+
}
90+
91+
module "build_ai_run_image_image" {
92+
source = "terraform-google-modules/gcloud/google"
93+
version = "~> 3.5"
94+
upgrade = false
95+
96+
create_cmd_triggers = {
97+
"tag_version" = local.docker_tag_version_terraform
98+
}
99+
100+
create_cmd_entrypoint = "bash"
101+
102+
create_cmd_body = <<EOF
103+
gcloud builds submit ${path.module} \
104+
--tag ${var.region}-docker.pkg.dev/${var.infra_project}/${google_artifact_registry_repository.private_images.name}/ai-train:${local.docker_tag_version_terraform} \
105+
--project=${var.infra_project} \
106+
--service-account=${google_service_account.builder.id} \
107+
--gcs-log-dir=${google_storage_bucket.build_logs.url} || (
108+
sleep 45 && gcloud builds submit ${path.module} \
109+
--tag ${var.region}-docker.pkg.dev/${var.infra_project}/${google_artifact_registry_repository.private_images.name}/ai-train:${local.docker_tag_version_terraform} \
110+
--project=${var.infra_project} \
111+
--service-account=${google_service_account.builder.id} \
112+
--gcs-log-dir=${google_storage_bucket.build_logs.url}
113+
)
114+
EOF
115+
116+
module_depends_on = [time_sleep.wait_iam_propagation]
117+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
locals {
18+
docker_tag_version_terraform = "v1"
19+
image_url = "${google_artifact_registry_repository.private_images.location}-docker.pkg.dev/${google_artifact_registry_repository.private_images.project}/${google_artifact_registry_repository.private_images.repository_id}/ai-train:${local.docker_tag_version_terraform}"
20+
namespace = "${var.team}-${var.env}"
21+
}
22+
23+
resource "google_project_iam_member" "team_roles" {
24+
for_each = toset([
25+
"roles/storage.objectUser",
26+
"roles/pubsub.publisher",
27+
"roles/pubsub.viewer"
28+
])
29+
30+
project = var.infra_project
31+
role = each.value
32+
member = "principalSet://iam.googleapis.com/projects/${var.cluster_project_number}/locations/global/workloadIdentityPools/${var.cluster_project}.svc.id.goog/namespace/${local.namespace}"
33+
}
34+
35+
resource "google_project_service" "enable_apis" {
36+
for_each = toset([
37+
"storage.googleapis.com",
38+
"cloudresourcemanager.googleapis.com",
39+
"logging.googleapis.com",
40+
"batch.googleapis.com",
41+
"cloudbuild.googleapis.com",
42+
"artifactregistry.googleapis.com",
43+
])
44+
project = var.infra_project
45+
service = each.key
46+
disable_on_destroy = false
47+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/**
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
output "image_url" {
18+
description = "AI Image URL"
19+
value = local.image_url
20+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow-datasets
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# flake8: noqa
15+
16+
import os
17+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
18+
import tensorflow_datasets as tfds
19+
import tensorflow as tf
20+
import keras
21+
import glob
22+
23+
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
24+
25+
mnist_train, mnist_test = datasets['train'], datasets['test']
26+
27+
print('******************')
28+
print('MNIST TRAINING JOB')
29+
print('******************')
30+
31+
strategy = tf.distribute.MirroredStrategy()
32+
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
33+
num_train_examples = info.splits['train'].num_examples
34+
num_test_examples = info.splits['test'].num_examples
35+
36+
BUFFER_SIZE = 10000
37+
38+
BATCH_SIZE_PER_REPLICA = 64
39+
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
40+
41+
def scale(image, label):
    """Cast ``image`` to float32 and divide it by 255; ``label`` is passed through.

    Returns the ``(image, label)`` pair so the function can be used with
    ``Dataset.map`` on supervised ``(image, label)`` elements.
    """
    normalized = tf.cast(image, tf.float32) / 255
    return normalized, label
46+
47+
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
48+
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
49+
50+
with strategy.scope():
51+
model = keras.Sequential([
52+
keras.Input(shape=(28, 28, 1)),
53+
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
54+
keras.layers.MaxPooling2D(),
55+
keras.layers.Flatten(),
56+
keras.layers.Dense(64, activation='relu'),
57+
keras.layers.Dense(10)
58+
])
59+
60+
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
61+
optimizer=keras.optimizers.Adam(),
62+
metrics=['accuracy'])
63+
64+
# Define the checkpoint directory to store the checkpoints.
65+
checkpoint_dir = './training_checkpoints'
66+
# Define the name of the checkpoint files.
67+
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
68+
69+
def decay(epoch):
    """Return the learning rate to use for the given epoch index.

    Epochs below 3 get 1e-3, epochs below 7 get 1e-4, and all later
    epochs get 1e-5.
    """
    # (exclusive upper bound, learning rate), checked in ascending order.
    schedule = ((3, 1e-3), (7, 1e-4))
    for upper_bound, rate in schedule:
        if epoch < upper_bound:
            return rate
    return 1e-5
76+
77+
# Define a callback for printing the learning rate at the end of each epoch.
78+
class PrintLR(keras.callbacks.Callback):
    """Keras callback that prints the optimizer's learning rate after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the current learning rate; ``logs`` is accepted but unused."""
        # Read the model Keras attaches to the callback rather than the
        # module-level `model` global, so the callback reports on whichever
        # model it is passed to via `fit(callbacks=...)`.
        current_lr = self.model.optimizer.learning_rate.numpy()
        # `epoch + 1` is printed -- presumably because Keras passes a
        # zero-based epoch index (confirm against the Keras version in use).
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1, current_lr))
82+
83+
callbacks = [
84+
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
85+
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
86+
save_weights_only=True),
87+
tf.keras.callbacks.LearningRateScheduler(decay),
88+
PrintLR()
89+
]
90+
91+
EPOCHS = 12
92+
93+
model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
94+
95+
# Function to find the latest .h5 file
96+
def find_latest_h5_checkpoint(checkpoint_dir):
    """Return the ``.h5`` file in ``checkpoint_dir`` with the greatest ctime.

    Only files directly inside ``checkpoint_dir`` matching ``*.h5`` are
    considered. Returns ``None`` when no such file exists.
    """
    candidates = glob.glob(f'{checkpoint_dir}/*.h5')
    # `default=None` covers the empty-directory case without a separate branch.
    return max(candidates, key=os.path.getctime, default=None)
103+
104+
model.load_weights(find_latest_h5_checkpoint(checkpoint_dir))
105+
106+
eval_loss, eval_acc = model.evaluate(eval_dataset)
107+
108+
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
109+
110+
path = '/data/mnist_saved_model'
111+
os.makedirs(path, exist_ok=True)
112+
113+
model_file = '/data/mnist_saved_model/mnist.keras'
114+
model.save(model_file)
115+
116+
print('Training finished. Model saved')
117+
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
/**
3+
* Copyright 2025 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
variable "infra_project" {
19+
description = "The infrastructure project where resources will be managed."
20+
type = string
21+
}
22+
23+
variable "cluster_project" {
24+
description = "The project that hosts the Kubernetes cluster."
25+
type = string
26+
}
27+
28+
variable "region" {
29+
description = "The region where the cloud resources will be deployed."
30+
type = string
31+
}
32+
33+
variable "bucket_force_destroy" {
34+
description = "When deleting a bucket, this boolean option will delete all contained objects. If false, Terraform will fail to delete buckets which contain objects."
35+
type = bool
36+
default = false
37+
}
38+
39+
variable "cluster_project_number" {
40+
description = "The numerical identifier for the cluster project."
41+
type = string
42+
}
43+
44+
variable "env" {
45+
description = "The environment in which resources are deployed (e.g., development, nonproduction, production)."
46+
type = string
47+
}
48+
49+
variable "cluster_service_accounts" {
  description = "A map of service accounts emails associated with the Kubernetes cluster, these will be granted access to created Docker images."
  # Each value is interpolated into a "serviceAccount:<value>" IAM member,
  # so the values must be strings; map(string) rejects anything else at plan time.
  type = map(string)
}
53+
54+
variable "team" {
55+
description = "Environment Team, must be the same as the fleet scope team"
56+
type = string
57+
}

0 commit comments

Comments
 (0)