diff --git a/gigl/orchestration/kubeflow/components/config_populator/component.yaml b/gigl/orchestration/kubeflow/components/config_populator/component.yaml index 2b9bd941b..deb792861 100644 --- a/gigl/orchestration/kubeflow/components/config_populator/component.yaml +++ b/gigl/orchestration/kubeflow/components/config_populator/component.yaml @@ -4,6 +4,8 @@ inputs: - {name: job_name, type: String, description: 'Unique name to identify the job'} - {name: template_uri, type: String, description: 'GBML Template uri'} - {name: resource_config_uri, type: String, description: 'Runtine argument for resource and env specifications of each component'} +- {name: cpu_docker_uri, type: String, description: "Uri to dockerized source code compiled for cpu at runtime"} +- {name: cuda_docker_uri, type: String, description: "Uri to dockerized source code compiled for gpu at runtime"} outputs: - {name: frozen_gbml_config_uri, type: String, description: 'Output frozen gbml config uri, populated'} @@ -15,5 +17,7 @@ implementation: --job_name, {inputValue: job_name}, --template_uri, {inputValue: template_uri}, --resource_config_uri, {inputValue: resource_config_uri}, + --cpu_docker_uri, {inputValue: cpu_docker_uri}, + --cuda_docker_uri, {inputValue: cuda_docker_uri}, --output_file_path_frozen_gbml_config_uri, {outputPath: frozen_gbml_config_uri} ] diff --git a/gigl/orchestration/kubeflow/components/config_validator/component.yaml b/gigl/orchestration/kubeflow/components/config_validator/component.yaml index 2dd9743a4..7c4ec057e 100644 --- a/gigl/orchestration/kubeflow/components/config_validator/component.yaml +++ b/gigl/orchestration/kubeflow/components/config_validator/component.yaml @@ -6,6 +6,8 @@ inputs: - {name: start_at, type: String, description: 'Start component'} - {name: resource_config_uri, type: String, description: 'Runtine argument for resource and env specifications of each component'} - {name: stop_after, type: String, description: 'Stop component'} +- {name: cpu_docker_uri, type: String, description: "Uri to dockerized source code compiled for cpu at runtime"} +- {name: cuda_docker_uri, type: String, description: "Uri to dockerized source code compiled for gpu at runtime"} outputs: implementation: @@ -17,5 +19,7 @@ implementation: --task_config_uri, {inputValue: task_config_uri}, --start_at, {inputValue: start_at}, --resource_config_uri, {inputValue: resource_config_uri}, - --stop_after, {inputValue: stop_after} + --stop_after, {inputValue: stop_after}, + --cpu_docker_uri, {inputValue: cpu_docker_uri}, + --cuda_docker_uri, {inputValue: cuda_docker_uri} ] diff --git a/gigl/orchestration/kubeflow/components/data_preprocessor/component.yaml b/gigl/orchestration/kubeflow/components/data_preprocessor/component.yaml index d0039b69f..67b593c07 100644 --- a/gigl/orchestration/kubeflow/components/data_preprocessor/component.yaml +++ b/gigl/orchestration/kubeflow/components/data_preprocessor/component.yaml @@ -5,6 +5,8 @@ inputs: - {name: task_config_uri, type: String, description: 'Frozen GBML config uri'} - {name: resource_config_uri, type: String, description: 'Runtine argument for resource and env specifications of each component'} - {name: custom_worker_image_uri, type: String, description: "Docker image to use for the worker harness in dataflow "} +- {name: cpu_docker_uri, type: String, description: "Uri to dockerized source code compiled for cpu execution at runtime"} +- {name: cuda_docker_uri, type: String, description: "Uri to dockerized source code compiled for gpu execution at runtime"} outputs: implementation: @@ -16,4 +18,6 @@ implementation: --task_config_uri, {inputValue: task_config_uri}, --resource_config_uri, {inputValue: resource_config_uri}, --custom_worker_image_uri, {inputValue: custom_worker_image_uri}, + --cpu_docker_uri, {inputValue: cpu_docker_uri}, + --cuda_docker_uri, {inputValue: cuda_docker_uri}, ] diff --git a/gigl/orchestration/kubeflow/components/post_processor/component.yaml b/gigl/orchestration/kubeflow/components/post_processor/component.yaml index ef6516a28..0203b0c38 100644 --- a/gigl/orchestration/kubeflow/components/post_processor/component.yaml +++ b/gigl/orchestration/kubeflow/components/post_processor/component.yaml @@ -4,6 +4,8 @@ inputs: - {name: job_name, type: String, description: 'Unique name to identify the job'} - {name: task_config_uri, type: String, description: 'Frozen gbml config uri'} - {name: resource_config_uri, type: String, description: 'Runtine argument for resource and env specifications of each component'} +- {name: cpu_docker_uri, type: String, description: "Uri to dockerized source code compiled for cpu execution at runtime"} +- {name: cuda_docker_uri, type: String, description: "Uri to dockerized source code compiled for gpu execution at runtime"} outputs: implementation: @@ -14,4 +16,6 @@ implementation: --job_name, {inputValue: job_name}, --task_config_uri, {inputValue: task_config_uri}, --resource_config_uri, {inputValue: resource_config_uri}, + --cpu_docker_uri, {inputValue: cpu_docker_uri}, + --cuda_docker_uri, {inputValue: cuda_docker_uri}, ] diff --git a/gigl/orchestration/kubeflow/kfp_pipeline.py b/gigl/orchestration/kubeflow/kfp_pipeline.py index 524107ce2..b0ad70e30 100644 --- a/gigl/orchestration/kubeflow/kfp_pipeline.py +++ b/gigl/orchestration/kubeflow/kfp_pipeline.py @@ -66,6 +66,8 @@ def _generate_component_task( job_name=job_name, template_uri=task_config_uri, resource_config_uri=resource_config_uri, + cpu_docker_uri=common_pipeline_component_configs.cpu_container_image, + cuda_docker_uri=common_pipeline_component_configs.cuda_container_image, **common_pipeline_component_configs.additional_job_args.get(component, {}), ) @@ -76,6 +78,8 @@ def _generate_component_task( start_at=start_at, resource_config_uri=resource_config_uri, stop_after=stop_after, + cpu_docker_uri=common_pipeline_component_configs.cpu_container_image, + cuda_docker_uri=common_pipeline_component_configs.cuda_container_image, **common_pipeline_component_configs.additional_job_args.get(component, {}), ) elif component == GiGLComponents.SubgraphSampler: @@ -101,6 +105,8 @@ def _generate_component_task( task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, custom_worker_image_uri=common_pipeline_component_configs.dataflow_container_image, + cpu_docker_uri=common_pipeline_component_configs.cpu_container_image, + cuda_docker_uri=common_pipeline_component_configs.cuda_container_image, **common_pipeline_component_configs.additional_job_args.get(component, {}), ) elif component == GiGLComponents.Inferencer: @@ -113,6 +119,15 @@ def _generate_component_task( cuda_docker_uri=common_pipeline_component_configs.cuda_container_image, **common_pipeline_component_configs.additional_job_args.get(component, {}), ) + elif component == GiGLComponents.PostProcessor: + component_task = _speced_component_op_dict[component]( + job_name=job_name, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + cpu_docker_uri=common_pipeline_component_configs.cpu_container_image, + cuda_docker_uri=common_pipeline_component_configs.cuda_container_image, + **common_pipeline_component_configs.additional_job_args.get(component, {}), + ) else: component_task = _speced_component_op_dict[component]( job_name=job_name, diff --git a/gigl/orchestration/local/runner.py b/gigl/orchestration/local/runner.py index 06bf4c2b2..2964d652c 100644 --- a/gigl/orchestration/local/runner.py +++ b/gigl/orchestration/local/runner.py @@ -8,7 +8,7 @@ from gigl.common.utils.proto_utils import ProtoUtils from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.config_populator.config_populator import ConfigPopulator from gigl.src.data_preprocessor.data_preprocessor import DataPreprocessor from gigl.src.inference.inferencer import Inferencer @@ -78,11 +78,6 @@ def run( f"dataflow_docker_uri: {pipeline_config.dataflow_docker_uri}" ) - initialize_metrics( - task_config_uri=pipeline_config.task_config_uri, - service_name=pipeline_config.applied_task_identifier, - ) - if start_at == GiGLComponents.ConfigPopulator.value: frozen_config_uri = Runner.run_config_populator(pipeline_config) pipeline_config.task_config_uri = frozen_config_uri @@ -107,8 +102,27 @@ def run( if started: method(pipeline_config) + @staticmethod + def _initialize_component_runtime( + pipeline_config: PipelineConfig, + component: GiGLComponents, + ) -> None: + initialize_gigl_runtime( + applied_task_identifier=pipeline_config.applied_task_identifier, + task_config_uri=pipeline_config.task_config_uri, + resource_config_uri=pipeline_config.resource_config_uri, + service_name=pipeline_config.applied_task_identifier, + component=component, + cpu_docker_uri=pipeline_config.custom_cpu_docker_uri, + cuda_docker_uri=pipeline_config.custom_cuda_docker_uri, + ) + @staticmethod def config_check(start_at: str, pipeline_config: PipelineConfig): + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.ConfigValidator, + ) proto_utils = ProtoUtils() gbml_config_pb: gbml_config_pb2.GbmlConfig = proto_utils.read_proto_from_yaml( uri=pipeline_config.task_config_uri, proto_cls=gbml_config_pb2.GbmlConfig @@ -123,17 +137,27 @@ def config_check(start_at: str, pipeline_config: PipelineConfig): @staticmethod def run_config_populator(pipeline_config: PipelineConfig) -> Uri: logger.info("Running Config Populator...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.ConfigPopulator, + ) config_populator = ConfigPopulator() return config_populator.run( applied_task_identifier=pipeline_config.applied_task_identifier, task_config_uri=pipeline_config.task_config_uri, resource_config_uri=pipeline_config.resource_config_uri, + cpu_docker_uri=pipeline_config.custom_cpu_docker_uri, + cuda_docker_uri=pipeline_config.custom_cuda_docker_uri, ) @staticmethod def run_data_preprocessor(pipeline_config: PipelineConfig) -> None: logger.info("Running Data Preprocessor...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.DataPreprocessor, + ) data_preprocessor = DataPreprocessor() data_preprocessor.run( applied_task_identifier=pipeline_config.applied_task_identifier, @@ -145,6 +169,10 @@ def run_data_preprocessor(pipeline_config: PipelineConfig) -> None: @staticmethod def run_subgraph_sampler(pipeline_config: PipelineConfig) -> None: logger.info("Running Subgraph Sampler...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.SubgraphSampler, + ) subgraph_sampler = SubgraphSampler() subgraph_sampler.run( applied_task_identifier=pipeline_config.applied_task_identifier, @@ -155,6 +183,10 @@ def run_subgraph_sampler(pipeline_config: PipelineConfig) -> None: @staticmethod def run_split_generator(pipeline_config: PipelineConfig) -> None: logger.info("Running Split Generator...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.SplitGenerator, + ) split_generator = SplitGenerator() split_generator.run( applied_task_identifier=pipeline_config.applied_task_identifier, @@ -165,6 +197,10 @@ def run_split_generator(pipeline_config: PipelineConfig) -> None: @staticmethod def run_trainer(pipeline_config: PipelineConfig) -> None: logger.info("Running Trainer...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.Trainer, + ) trainer = Trainer() trainer.run( applied_task_identifier=pipeline_config.applied_task_identifier, @@ -177,6 +213,10 @@ def run_trainer(pipeline_config: PipelineConfig) -> None: @staticmethod def run_inferencer(pipeline_config: PipelineConfig) -> None: logger.info("Running Inferencer...") + Runner._initialize_component_runtime( + pipeline_config=pipeline_config, + component=GiGLComponents.Inferencer, + ) inferencer = Inferencer() inferencer.run( applied_task_identifier=pipeline_config.applied_task_identifier, diff --git a/gigl/src/common/utils/gigl_runtime.py b/gigl/src/common/utils/gigl_runtime.py index 2c784b71b..655f3b9da 100644 --- a/gigl/src/common/utils/gigl_runtime.py +++ b/gigl/src/common/utils/gigl_runtime.py @@ -4,6 +4,10 @@ from typing import Optional from gigl.common import Uri +from gigl.env.constants import ( + GIGL_CPU_DOCKER_URI_ENV_KEY, + GIGL_CUDA_DOCKER_URI_ENV_KEY, +) from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.utils.gigl_env import get_gigl_runtime_env_vars from gigl.src.common.utils.metrics_service_provider import initialize_metrics @@ -20,6 +24,10 @@ def initialize_gigl_runtime( ) -> None: """Initialize GiGL runtime environment and metrics for a component. + For ``SubgraphSampler`` and ``SplitGenerator`` only metrics are initialized; + runtime env vars are not set, since these legacy (Scala/Spark) components do + not consume the GiGL Python runtime. + Args: applied_task_identifier: Unique identifier for the GiGL job. task_config_uri: URI to the task config YAML file. @@ -29,14 +37,32 @@ def initialize_gigl_runtime( cpu_docker_uri: CPU source image URI. Defaults to the release CPU image. cuda_docker_uri: CUDA source image URI. Defaults to the release CUDA image. """ + if component in {GiGLComponents.SubgraphSampler, GiGLComponents.SplitGenerator}: + initialize_metrics(task_config_uri=task_config_uri, service_name=service_name) + return + + # TODO(kmonte): Also expose the dataflow docker URI (used as custom_worker_image_uri by + # DataPreprocessor/Inferencer) as a GIGL_DATAFLOW_DOCKER_URI env var for parity with the + # CPU/CUDA docker URIs. Requires a new key in gigl/env/constants.py and threading it + # through get_gigl_runtime_env_vars. + resolved_cpu_docker_uri = ( + os.environ.get(GIGL_CPU_DOCKER_URI_ENV_KEY) + if cpu_docker_uri is None + else cpu_docker_uri + ) + resolved_cuda_docker_uri = ( + os.environ.get(GIGL_CUDA_DOCKER_URI_ENV_KEY) + if cuda_docker_uri is None + else cuda_docker_uri + ) os.environ.update( get_gigl_runtime_env_vars( applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, component=component, - cpu_docker_uri=cpu_docker_uri, - cuda_docker_uri=cuda_docker_uri, + cpu_docker_uri=resolved_cpu_docker_uri, + cuda_docker_uri=resolved_cuda_docker_uri, ) ) initialize_metrics(task_config_uri=task_config_uri, service_name=service_name) diff --git a/gigl/src/config_populator/config_populator.py b/gigl/src/config_populator/config_populator.py index d53a6fa11..becbe669d 100644 --- a/gigl/src/config_populator/config_populator.py +++ b/gigl/src/config_populator/config_populator.py @@ -12,6 +12,7 @@ from gigl.common.metrics.decorators import flushes_metrics, profileit from gigl.common.utils.proto_utils import ProtoUtils from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.constants.metrics import TIMER_CONFIG_POPULATOR_S from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.dataset_split import DatasetSplit @@ -19,9 +20,9 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.task_metadata import TaskMetadataPbWrapper from gigl.src.common.types.task_metadata import TaskMetadataType +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.utils.metrics_service_provider import ( get_metrics_service_instance, - initialize_metrics, ) from snapchat.research.gbml import ( dataset_metadata_pb2, @@ -616,6 +617,8 @@ def run( applied_task_identifier: AppliedTaskIdentifier, task_config_uri: Uri, resource_config_uri: Uri, + cpu_docker_uri: Optional[str] = None, + cuda_docker_uri: Optional[str] = None, ) -> GcsUri: """ Runs the ConfigPopulator; given an input GbmlConfig file, produces a frozen one. @@ -624,12 +627,20 @@ def run( applied_task_identifier (AppliedTaskIdentifier): The job name. task_config_uri (Uri): Template GbmlConfig URI. resource_config_uri: GiGL resource config Uri + cpu_docker_uri (Optional[str]): CPU source image URI. Defaults to the release CPU image. + cuda_docker_uri (Optional[str]): CUDA source image URI. Defaults to the release CUDA image. Returns: GcsUri: The URI of the frozen GbmlConfig. """ - initialize_metrics( - task_config_uri=task_config_uri, service_name=applied_task_identifier + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=applied_task_identifier, + component=GiGLComponents.ConfigPopulator, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, ) resource_config = get_resource_config(resource_config_uri=resource_config_uri) @@ -673,6 +684,18 @@ def run( type=str, help="Runtime argument for resource and env specifications of each component", ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + default=None, + help="Uri to dockerized source code compiled for cpu at runtime", + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + default=None, + help="Uri to dockerized source code compiled for gpu at runtime", + ) args = parser.parse_args() @@ -688,6 +711,8 @@ def run( applied_task_identifier=ati, task_config_uri=template_uri, resource_config_uri=resource_config_uri, + cpu_docker_uri=args.cpu_docker_uri, + cuda_docker_uri=args.cuda_docker_uri, ) # Write fozen_gbml_config_uri to file where it can be read by subsequent components diff --git a/gigl/src/data_preprocessor/data_preprocessor.py b/gigl/src/data_preprocessor/data_preprocessor.py index 98c0bc153..543954753 100644 --- a/gigl/src/data_preprocessor/data_preprocessor.py +++ b/gigl/src/data_preprocessor/data_preprocessor.py @@ -36,9 +36,9 @@ ) from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.file_loader import FileLoader +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.utils.metrics_service_provider import ( get_metrics_service_instance, - initialize_metrics, ) from gigl.src.data_preprocessor.lib.data_preprocessor_config import ( DataPreprocessorConfig, @@ -998,14 +998,36 @@ def run( help="Docker image to use for the worker harness in dataflow", required=False, ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for CPU execution", + required=False, + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for GPU execution", + required=False, + ) args = parser.parse_args() ati = AppliedTaskIdentifier(args.job_name) task_config_uri = UriFactory.create_uri(args.task_config_uri) resource_config_uri = UriFactory.create_uri(args.resource_config_uri) custom_worker_image_uri = args.custom_worker_image_uri + cpu_docker_uri = args.cpu_docker_uri + cuda_docker_uri = args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=ati, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.DataPreprocessor, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) data_preprocessor = DataPreprocessor() data_preprocessor.run( diff --git a/gigl/src/inference/inferencer.py b/gigl/src/inference/inferencer.py index 732d4405d..35e5d5410 100644 --- a/gigl/src/inference/inferencer.py +++ b/gigl/src/inference/inferencer.py @@ -4,12 +4,13 @@ from gigl.common import Uri, UriFactory from gigl.common.logger import Logger from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, ) -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.inference.lib.assets import InferenceAssets from gigl.src.inference.v1.gnn_inferencer import InferencerV1 from gigl.src.inference.v2.glt_inferencer import GLTInferencer @@ -111,9 +112,17 @@ def run( cpu_docker_uri = args.cpu_docker_uri cuda_docker_uri = args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) - applied_task_identifier = AppliedTaskIdentifier(args.job_name) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Inferencer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + inferencer = Inferencer() inferencer.run( applied_task_identifier=applied_task_identifier, diff --git a/gigl/src/inference/v1/gnn_inferencer.py b/gigl/src/inference/v1/gnn_inferencer.py index 8e0b6901c..7f17a0439 100644 --- a/gigl/src/inference/v1/gnn_inferencer.py +++ b/gigl/src/inference/v1/gnn_inferencer.py @@ -19,15 +19,16 @@ from gigl.common.metrics.decorators import flushes_metrics, profileit from gigl.common.utils import os_utils from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.constants.metrics import TIMER_INFERENCER_S from gigl.src.common.graph_builder.graph_builder_factory import GraphBuilderFactory from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.graph_data import NodeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.utils.metrics_service_provider import ( get_metrics_service_instance, - initialize_metrics, ) from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets @@ -412,11 +413,25 @@ def __init__(self, bq_gcp_project: str): task_config_uri = UriFactory.create_uri(args.task_config_uri) resource_config_uri = UriFactory.create_uri(args.resource_config_uri) custom_worker_image_uri = args.custom_worker_image_uri - - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + cpu_docker_uri = args.cpu_docker_uri + cuda_docker_uri = args.cuda_docker_uri applied_task_identifier = AppliedTaskIdentifier(args.job_name) - inferencer = InferencerV1(bq_gcp_project=get_resource_config().project) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Inferencer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + + inferencer = InferencerV1( + bq_gcp_project=get_resource_config( + resource_config_uri=resource_config_uri + ).project + ) inferencer.run( applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, diff --git a/gigl/src/inference/v2/glt_inferencer.py b/gigl/src/inference/v2/glt_inferencer.py index d1e5e4f5d..c51c4bbc7 100644 --- a/gigl/src/inference/v2/glt_inferencer.py +++ b/gigl/src/inference/v2/glt_inferencer.py @@ -10,7 +10,7 @@ from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, ) -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.vertex_ai_launcher import ( launch_graph_store_enabled_job, launch_single_pool_job, @@ -170,7 +170,15 @@ def run( resource_config_uri = UriFactory.create_uri(args.resource_config_uri) cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Inferencer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) glt_inferencer = GLTInferencer() glt_inferencer.run( diff --git a/gigl/src/post_process/post_processor.py b/gigl/src/post_process/post_processor.py index d4ce790fb..6642dac89 100644 --- a/gigl/src/post_process/post_processor.py +++ b/gigl/src/post_process/post_processor.py @@ -10,6 +10,7 @@ from gigl.common.utils import os_utils from gigl.common.utils.gcs import GcsUtils from gigl.src.common.constants import gcs as gcs_constants +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.constants.metrics import TIMER_POST_PROCESSOR_S from gigl.src.common.translators.model_eval_metrics_translator import ( EvalMetricsCollectionTranslator, @@ -18,9 +19,9 @@ from gigl.src.common.types.model_eval_metrics import EvalMetricsCollection from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.file_loader import FileLoader +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.utils.metrics_service_provider import ( get_metrics_service_instance, - initialize_metrics, ) from gigl.src.post_process.lib.base_post_processor import BasePostProcessor from gigl.src.post_process.utils.unenumeration import unenumerate_all_inferred_bq_assets @@ -184,12 +185,34 @@ def run( type=str, help="Runtime argument for resource and env specifications of each component", ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for CPU execution", + required=False, + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for GPU execution", + required=False, + ) args = parser.parse_args() task_config_uri = UriFactory.create_uri(args.task_config_uri) resource_config_uri = UriFactory.create_uri(args.resource_config_uri) applied_task_identifier = AppliedTaskIdentifier(args.job_name) - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + cpu_docker_uri = args.cpu_docker_uri + cuda_docker_uri = args.cuda_docker_uri + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.PostProcessor, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) post_processor = PostProcessor() post_processor.run( diff --git a/gigl/src/split_generator/split_generator.py b/gigl/src/split_generator/split_generator.py index db3969cea..98fbc2f5c 100644 --- a/gigl/src/split_generator/split_generator.py +++ b/gigl/src/split_generator/split_generator.py @@ -241,6 +241,9 @@ def run( if not args.job_name or not args.task_config_uri or not args.resource_config_uri: raise RuntimeError("Missing command-line arguments") + # SubgraphSampler/SplitGenerator are legacy Scala/Spark components that do not + # consume the GiGL Python runtime env vars, so we only initialize metrics here + # (rather than initialize_gigl_runtime). See the skip branch in initialize_gigl_runtime. initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) split_generator = SplitGenerator() diff --git a/gigl/src/subgraph_sampler/subgraph_sampler.py b/gigl/src/subgraph_sampler/subgraph_sampler.py index 34a5e7d38..8b7c140d1 100644 --- a/gigl/src/subgraph_sampler/subgraph_sampler.py +++ b/gigl/src/subgraph_sampler/subgraph_sampler.py @@ -374,6 +374,9 @@ def run( applied_task_identifier = AppliedTaskIdentifier(args.job_name) custom_worker_image_uri = args.custom_worker_image_uri + # SubgraphSampler/SplitGenerator are legacy Scala/Spark components that do not + # consume the GiGL Python runtime env vars, so we only initialize metrics here + # (rather than initialize_gigl_runtime). See the skip branch in initialize_gigl_runtime. initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) sgs = SubgraphSampler() diff --git a/gigl/src/training/trainer.py b/gigl/src/training/trainer.py index b1ee5dbe0..197aab0f3 100644 --- a/gigl/src/training/trainer.py +++ b/gigl/src/training/trainer.py @@ -4,10 +4,11 @@ import gigl.src.common.constants.gcs as gcs_constants from gigl.common import Uri, UriFactory from gigl.common.logger import Logger +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.file_loader import FileLoader -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime # TODO: (svij) Rename Trainer to TrainerV1 from gigl.src.training.v1.trainer import Trainer as TrainerV1 @@ -124,7 +125,15 @@ def run( resource_config_uri = UriFactory.create_uri(args.resource_config_uri) cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Trainer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) trainer = Trainer() trainer.run( diff --git a/gigl/src/training/v1/lib/training_process.py b/gigl/src/training/v1/lib/training_process.py index 9d8e8f21b..c79bd6983 100644 --- a/gigl/src/training/v1/lib/training_process.py +++ b/gigl/src/training/v1/lib/training_process.py @@ -25,6 +25,7 @@ is_distributed_available_and_initialized, should_distribute, ) +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.constants.metrics import ( TIMER_TRAINER_CLEANUP_ENV_S, TIMER_TRAINER_EXPORT_INFERENCE_ASSETS_S, @@ -42,9 +43,9 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.task_metadata import TaskMetadataType from gigl.src.common.utils.file_loader import FileLoader +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.utils.metrics_service_provider import ( get_metrics_service_instance, - initialize_metrics, ) from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.common.utils.time import current_formatted_datetime @@ -408,6 +409,18 @@ def __cleanup_training_env(self): type=str, help="Runtime argument for resource and env specifications of each component", ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for CPU training", + required=False, + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + help="User Specified or KFP compiled Docker Image for GPU training", + required=False, + ) args = parser.parse_args() if not args.job_name or not args.task_config_uri or not args.resource_config_uri: @@ -420,12 +433,22 @@ def __cleanup_training_env(self): logger.info(f"Starting training with device: {device}") task_config_uri = UriFactory.create_uri(args.task_config_uri) + resource_config_uri = UriFactory.create_uri(args.resource_config_uri) + cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri logger.info(f"Will use the following config for training: {task_config_uri}") logger.info( f"World Size: {torch_training.get_world_size()}, Rank: {torch_training.get_rank()}, Should Distribute: {torch_training.should_distribute()}" ) - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=args.job_name, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Trainer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) training_process = GnnTrainingProcess() training_process.run(task_config_uri=task_config_uri, device=device) diff --git a/gigl/src/training/v1/trainer.py b/gigl/src/training/v1/trainer.py index c1509ea54..94e216a1f 100644 --- a/gigl/src/training/v1/trainer.py +++ b/gigl/src/training/v1/trainer.py @@ -14,7 +14,8 @@ from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_env import get_gigl_runtime_env_vars +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.training.v1.lib.training_process import GnnTrainingProcess from snapchat.research.gbml.gigl_resource_config_pb2 import ( LocalResourceConfig, @@ -48,11 +49,24 @@ def run( container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri environment_variables: list[env_var.EnvVar] = [ env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"), + *[ + env_var.EnvVar(name=name, value=value) + for name, value in get_gigl_runtime_env_vars( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + component=GiGLComponents.Trainer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ).items() + ], ] job_args = [ f"--job_name={applied_task_identifier}", f"--task_config_uri={task_config_uri}", f"--resource_config_uri={resource_config_uri}", + f"--cpu_docker_uri={cpu_docker_uri}", + f"--cuda_docker_uri={cuda_docker_uri}", ] + ([] if is_cpu_training else ["--use_cuda"]) job_config = VertexAiJobConfig( @@ -151,7 +165,15 @@ def _determine_if_cpu_training(self, trainer_config) -> bool: resource_config_uri = UriFactory.create_uri(args.resource_config_uri) cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Trainer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) trainer = Trainer() trainer.run( diff --git a/gigl/src/training/v2/glt_trainer.py b/gigl/src/training/v2/glt_trainer.py index ff2acc5f9..db2481978 100644 --- a/gigl/src/training/v2/glt_trainer.py +++ b/gigl/src/training/v2/glt_trainer.py @@ -10,7 +10,7 @@ from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, ) -from gigl.src.common.utils.metrics_service_provider import initialize_metrics +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.common.vertex_ai_launcher import ( launch_graph_store_enabled_job, launch_single_pool_job, @@ -166,7 +166,15 @@ def run( resource_config_uri = UriFactory.create_uri(args.resource_config_uri) cpu_docker_uri, cuda_docker_uri = args.cpu_docker_uri, args.cuda_docker_uri - initialize_metrics(task_config_uri=task_config_uri, service_name=args.job_name) + initialize_gigl_runtime( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.Trainer, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) glt_trainer = GLTTrainer() glt_trainer.run( diff --git a/gigl/src/validation_check/config_validator.py b/gigl/src/validation_check/config_validator.py index ec0ca4caf..1e1f77340 100644 --- a/gigl/src/validation_check/config_validator.py +++ b/gigl/src/validation_check/config_validator.py @@ -9,6 +9,7 @@ from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, ) +from gigl.src.common.utils.gigl_runtime import initialize_gigl_runtime from gigl.src.validation_check.libs.frozen_config_path_checks import ( assert_preprocessed_metadata_exists, assert_split_generator_output_exists, @@ -383,12 +384,37 @@ def kfp_validation_checks( type=str, help="Runtime argument for resource and env specifications of each component", ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + default=None, + help="Uri to dockerized source code compiled for cpu at runtime", + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + default=None, + help="Uri to dockerized source code compiled for gpu at runtime", + ) args = parser.parse_args() + task_config_uri = UriFactory.create_uri(args.task_config_uri) + resource_config_uri = UriFactory.create_uri(args.resource_config_uri) + + initialize_gigl_runtime( + applied_task_identifier=args.job_name, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + service_name=args.job_name, + component=GiGLComponents.ConfigValidator, + cpu_docker_uri=args.cpu_docker_uri, + cuda_docker_uri=args.cuda_docker_uri, + ) + kfp_validation_checks( job_name=args.job_name, - task_config_uri=UriFactory.create_uri(args.task_config_uri), + task_config_uri=task_config_uri, start_at=args.start_at, - resource_config_uri=UriFactory.create_uri(args.resource_config_uri), + resource_config_uri=resource_config_uri, stop_after=args.stop_after, ) diff --git a/tests/unit/orchestration/kubeflow/component_spec_test.py b/tests/unit/orchestration/kubeflow/component_spec_test.py new file mode 100644 index 000000000..6e705b7d4 --- /dev/null +++ b/tests/unit/orchestration/kubeflow/component_spec_test.py @@ -0,0 +1,38 @@ +from pathlib import Path + +from absl.testing import absltest + +from tests.test_assets.test_case import TestCase + +_REPO_ROOT = Path(__file__).resolve().parents[4] +_COMPONENTS_ROOT = _REPO_ROOT / "gigl" / "orchestration" / "kubeflow" / "components" + + +def _read_component_spec(component_name: str) -> str: + return (_COMPONENTS_ROOT / component_name / "component.yaml").read_text() + + +class ComponentSpecTest(TestCase): + def test_data_preprocessor_accepts_source_image_uri_flags(self) -> None: + component_spec = _read_component_spec("data_preprocessor") + + self.assertIn("name: cpu_docker_uri", component_spec) + self.assertIn("name: cuda_docker_uri", component_spec) + self.assertIn("--cpu_docker_uri, {inputValue: cpu_docker_uri}", component_spec) + self.assertIn( + "--cuda_docker_uri, {inputValue: cuda_docker_uri}", component_spec + ) + + def test_post_processor_accepts_source_image_uri_flags(self) -> None: + component_spec = _read_component_spec("post_processor") + + self.assertIn("name: cpu_docker_uri", component_spec) + self.assertIn("name: cuda_docker_uri", component_spec) + self.assertIn("--cpu_docker_uri, {inputValue: cpu_docker_uri}", component_spec) + self.assertIn( + "--cuda_docker_uri, {inputValue: cuda_docker_uri}", component_spec + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/src/common/utils/gigl_runtime_test.py b/tests/unit/src/common/utils/gigl_runtime_test.py index 20d5ee064..374fe5a5f 100644 --- a/tests/unit/src/common/utils/gigl_runtime_test.py +++ b/tests/unit/src/common/utils/gigl_runtime_test.py @@ -128,3 +128,67 @@ def test_initialize_gigl_runtime_sets_env_and_initializes_metrics(self) -> None: "data-preprocessor-service", ) self.assertIsInstance(get_metrics_service_instance(), NopMetricsPublisher) + + def test_initialize_gigl_runtime_skips_env_for_scala_components(self) -> None: + task_config_uri = self._write_task_config(gbml_config_pb2.GbmlConfig()) + + for component in ( + GiGLComponents.SubgraphSampler, + GiGLComponents.SplitGenerator, + ): + with self.subTest(component=component): + with patch.dict(os.environ, {}, clear=True): + initialize_gigl_runtime( + applied_task_identifier="job-42", + task_config_uri=task_config_uri, + resource_config_uri=Uri("gs://bucket/resource.yaml"), + service_name="scala-service", + component=component, + ) + + # Runtime env vars are not set for legacy Scala/Spark components. + self.assertNotIn(GIGL_COMPONENT_ENV_KEY, os.environ) + self.assertNotIn(GIGL_TASK_CONFIG_URI_ENV_KEY, os.environ) + self.assertNotIn(GIGL_RESOURCE_CONFIG_URI_ENV_KEY, os.environ) + self.assertNotIn(GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY, os.environ) + self.assertNotIn(GIGL_CPU_DOCKER_URI_ENV_KEY, os.environ) + self.assertNotIn(GIGL_CUDA_DOCKER_URI_ENV_KEY, os.environ) + + # Metrics are still initialized. + self.assertEqual( + os.environ[JOB_NAME_GROUPING_ENV_KEY], + "scala-service", + ) + self.assertIsInstance( + get_metrics_service_instance(), NopMetricsPublisher + ) + + def test_initialize_gigl_runtime_preserves_existing_image_env_when_args_omitted( + self, + ) -> None: + task_config_uri = self._write_task_config(gbml_config_pb2.GbmlConfig()) + + with patch.dict( + os.environ, + { + GIGL_CPU_DOCKER_URI_ENV_KEY: "gcr.io/env/cpu:tag", + GIGL_CUDA_DOCKER_URI_ENV_KEY: "gcr.io/env/cuda:tag", + }, + clear=False, + ): + initialize_gigl_runtime( + applied_task_identifier="job-42", + task_config_uri=task_config_uri, + resource_config_uri=Uri("gs://bucket/resource.yaml"), + service_name="trainer-service", + component=GiGLComponents.Trainer, + ) + + self.assertEqual( + os.environ[GIGL_CPU_DOCKER_URI_ENV_KEY], + "gcr.io/env/cpu:tag", + ) + self.assertEqual( + os.environ[GIGL_CUDA_DOCKER_URI_ENV_KEY], + "gcr.io/env/cuda:tag", + ) diff --git a/tests/unit/src/config_populator/config_populator_functionality_test.py b/tests/unit/src/config_populator/config_populator_functionality_test.py index 440b4cc95..a4c177f95 100644 --- a/tests/unit/src/config_populator/config_populator_functionality_test.py +++ b/tests/unit/src/config_populator/config_populator_functionality_test.py @@ -1,10 +1,16 @@ +from unittest.mock import patch + from parameterized import param, parameterized +from gigl.common import GcsUri, LocalUri from gigl.common.logger import Logger +from gigl.common.metrics.base_metrics import NopMetricsPublisher +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.pb_wrappers.dataset_metadata import DatasetMetadataPbWrapper from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.task_metadata import TaskMetadataType +from gigl.src.common.utils import metrics_service_provider from gigl.src.common.utils.time import current_formatted_datetime from gigl.src.config_populator.config_populator import ConfigPopulator from snapchat.research.gbml import ( @@ -223,6 +229,37 @@ def test_glt_config_population_is_accurate( "", ) + def test_run_forwards_docker_uris_to_initialize_gigl_runtime(self) -> None: + config_populator = ConfigPopulator() + frozen_uri = GcsUri("gs://bucket/frozen.yaml") + + with ( + patch( + "gigl.src.config_populator.config_populator.initialize_gigl_runtime" + ) as mock_initialize_runtime, + patch("gigl.src.config_populator.config_populator.get_resource_config"), + patch.object( + ConfigPopulator, "_ConfigPopulator__run", return_value=frozen_uri + ), + patch.object( + metrics_service_provider, "_metrics_instance", NopMetricsPublisher() + ), + ): + result = config_populator.run( + applied_task_identifier=self.applied_task_identifier, + task_config_uri=LocalUri("/tmp/task.yaml"), + resource_config_uri=LocalUri("/tmp/resource.yaml"), + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + ) + + self.assertEqual(result, frozen_uri) + mock_initialize_runtime.assert_called_once() + _, kwargs = mock_initialize_runtime.call_args + self.assertEqual(kwargs["component"], GiGLComponents.ConfigPopulator) + self.assertEqual(kwargs["cpu_docker_uri"], "gcr.io/p/cpu:tag") + self.assertEqual(kwargs["cuda_docker_uri"], "gcr.io/p/cuda:tag") + def setUp(self) -> None: self.applied_task_identifier = AppliedTaskIdentifier( f"test_config_populator_functionality_{current_formatted_datetime()}" diff --git a/tests/unit/src/training/v1_trainer_test.py b/tests/unit/src/training/v1_trainer_test.py new file mode 100644 index 000000000..a9dc06638 --- /dev/null +++ b/tests/unit/src/training/v1_trainer_test.py @@ -0,0 +1,92 @@ +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.env.constants import ( + GIGL_COMPONENT_ENV_KEY, + GIGL_CPU_DOCKER_URI_ENV_KEY, + GIGL_CUDA_DOCKER_URI_ENV_KEY, +) +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.training.v1.trainer import Trainer +from snapchat.research.gbml import gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + + +def _build_resource_config_with_vertex_ai_trainer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + return gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=gigl_resource_config_pb2.SharedResourceConfig( + resource_labels={ + "env": "test", + "cost_resource_group_tag": "unittest_COMPONENT", + "cost_resource_group": "gigl_test", + }, + common_compute_config=( + gigl_resource_config_pb2.SharedResourceConfig.CommonComputeConfig( + project="test-project", + region="us-central1", + temp_assets_bucket="gs://test-temp-bucket", + temp_regional_assets_bucket="gs://test-temp-regional-bucket", + perm_assets_bucket="gs://test-perm-bucket", + temp_assets_bq_dataset_name="test_temp_dataset", + embedding_bq_dataset_name="test_embeddings_dataset", + gcp_service_account_email="test-sa@test-project.iam.gserviceaccount.com", + dataflow_runner="DataflowRunner", + ) + ), + ), + trainer_resource_config=gigl_resource_config_pb2.TrainerResourceConfig( + vertex_ai_trainer_config=gigl_resource_config_pb2.VertexAiResourceConfig( + machine_type="n1-standard-8", + num_replicas=1, + ), + ), + ) + + +class TrainerV1Test(TestCase): + @patch("gigl.src.training.v1.trainer.VertexAIService") + @patch("gigl.src.training.v1.trainer.get_resource_config") + def test_vertex_ai_training_process_receives_runtime_image_context( + self, + mock_get_resource_config, + mock_vertex_ai_service, + ) -> None: + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=_build_resource_config_with_vertex_ai_trainer() + ) + + cpu_docker_uri = "gcr.io/project/cpu:tag" + cuda_docker_uri = "gcr.io/project/cuda:tag" + Trainer().run( + applied_task_identifier=AppliedTaskIdentifier("job_1"), + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + + mock_vertex_ai_service.return_value.launch_job.assert_called_once() + job_config = mock_vertex_ai_service.return_value.launch_job.call_args.kwargs[ + "job_config" + ] + self.assertIn(f"--cpu_docker_uri={cpu_docker_uri}", job_config.args) + self.assertIn(f"--cuda_docker_uri={cuda_docker_uri}", job_config.args) + + env_vars = { + env_var.name: env_var.value for env_var in job_config.environment_variables + } + self.assertEqual(env_vars[GIGL_CPU_DOCKER_URI_ENV_KEY], cpu_docker_uri) + self.assertEqual(env_vars[GIGL_CUDA_DOCKER_URI_ENV_KEY], cuda_docker_uri) + self.assertEqual(env_vars[GIGL_COMPONENT_ENV_KEY], GiGLComponents.Trainer.name) + + +if __name__ == "__main__": + absltest.main()