class InMemory(FifoInMemory):
  """In-memory caching engine with TTL.

  Every value is stored in the parent FIFO cache together with the deadline
  after which it must no longer be served. An entry whose deadline has passed
  is reported as a miss by `get`; this class itself never deletes entries, so
  removal is left to whatever eviction the parent cache performs.
  """

  def __init__(self, ttl_in_seconds, capacity=1000):
    """Creates the cache.

    Args:
      ttl_in_seconds: How long, in seconds, a stored value stays valid.
      capacity: Passed through unchanged to the parent cache.
    """
    super().__init__(capacity)
    self.ttl_in_seconds = ttl_in_seconds

  def put(self, key, value):
    """Put (key, value) into cache, stamped with its expiry deadline."""
    # A TTL is a duration, so it is measured on the monotonic clock: unlike
    # time.time(), it cannot jump when the system wall clock is adjusted.
    deadline = time.monotonic() + self.ttl_in_seconds
    super().put(key, (value, deadline))

  def get(self, key):
    """Get the value from cache.

    Returns:
      The cached value, or None when the key is absent or its entry expired.
    """
    entry = super().get(key)
    if entry is None:
      return None

    value, deadline = entry
    if deadline < time.monotonic():
      # Expired: behave exactly like a cache miss.
      return None

    return value
# Upper bound, in seconds, on how long one countByState request may block.
# Without it a stalled connection would hang the scheduling path indefinitely.
_REGION_LOAD_REQUEST_TIMEOUT_SECONDS = 30


@memoize.wrap(memoize.InMemory(60))
def get_region_load(project: str, region: str) -> int:
  """Gets the current load (number of QUEUED jobs) for a region.

  Queries the Batch `jobs:countByState` endpoint for `region` in `project`.
  Results are memoized through `memoize.wrap` using an `InMemory(60)` engine.

  This is a best-effort check: any failure (credentials, network, timeout,
  non-200 status, unparsable payload) is logged and reported as a load of 0 so
  that a broken load check never blocks scheduling.

  Args:
    project: GCP project id whose jobs are counted.
    region: Batch location to query, e.g. 'us-central1'.

  Returns:
    The QUEUED job count reported by the API, or 0 on any failure.
  """
  url = (f'https://batch.googleapis.com/v1alpha/projects/{project}/locations/'
         f'{region}/jobs:countByState?states=QUEUED')

  try:
    # Credential lookup/refresh is inside the try so that an auth failure
    # degrades to "load 0" exactly like every other failure of this check.
    creds, _ = credentials.get_default()
    if not creds.valid:
      creds.refresh(google.auth.transport.requests.Request())

    headers = {
        'Authorization': f'Bearer {creds.token}',
        'Content-Type': 'application/json'
    }
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(
        req, timeout=_REGION_LOAD_REQUEST_TIMEOUT_SECONDS) as response:
      if response.status != 200:
        logs.error(
            f'Batch countByState failed: {response.status} {response.read()}')
        return 0

      data = json.loads(response.read())
      logs.info(f'Batch countByState response for {region}: {data}')

      # Assumed payload shape (standard for Google APIs), counts as strings:
      #   { "jobCounts": { "QUEUED": "10" } }
      # A payload without "jobCounts" is treated as zero queued jobs.
      total = 0
      for state, count in data.get('jobCounts', {}).items():
        count = int(count)
        if state == 'QUEUED':
          total += count
        else:
          # Only QUEUED was requested, so anything else is unexpected.
          logs.error(f'Unknown state: {state}')

      return total
  except Exception as e:
    logs.error(f'Failed to get region load for {region}: {e}')
    return 0
def _get_subconfig(batch_config, instance_spec):
  """Chooses the subconfig a batch job should be created with.

  When `queue_check_regions` is missing or empty in `batch_config`, one of the
  instance spec's subconfigs is drawn according to its configured weight.

  Otherwise every candidate whose region is listed in `queue_check_regions`
  has its queue length checked; candidates at or above MAX_QUEUE_SIZE are
  dropped and one of the remaining candidates is picked uniformly at random.
  Candidates in regions that are not listed are kept without a check.

  Raises:
    AllRegionsOverloadedError: If no candidate survives the load check.
  """
  subconfigs_by_name = batch_config.get('subconfigs', {})
  candidates = instance_spec['subconfigs']
  regions_to_check = batch_config.get('queue_check_regions')

  if not regions_to_check:
    logs.info(
        'Skipping batch load check because queue_check_regions is not configured.'
    )
    picked = utils.random_weighted_choice([
        WeightedSubconfig(candidate['name'], candidate['weight'])
        for candidate in candidates
    ])
    return subconfigs_by_name[picked.name]

  # Load-aware path: keep only the candidates whose region still has room.
  project = batch_config.get('project')
  available_names = []
  for candidate in candidates:
    candidate_name = candidate['name']
    region = subconfigs_by_name[candidate_name]['region']

    if region in regions_to_check:
      queued = get_region_load(project, region)
      logs.info(f'Region {region} has {queued} queued jobs.')
      if queued >= MAX_QUEUE_SIZE:
        logs.info(f'Region {region} overloaded (load={queued}). Skipping.')
        continue

    available_names.append(candidate_name)

  if available_names:
    # Uniform random pick so concurrent schedulers do not all pile onto the
    # same region.
    return subconfigs_by_name[random.choice(available_names)]

  logs.error('All candidate regions are overloaded.')
  raise AllRegionsOverloadedError('All candidate regions are overloaded.')
def create_utask_main_jobs(self,
                           remote_tasks: List[remote_task_types.RemoteTask]):
  """Creates a batch job for a list of uworker main tasks.

  This method groups the tasks by their workload specification and creates a
  separate batch job for each group. This allows tasks with similar
  requirements to be processed together, which can improve efficiency.

  Args:
    remote_tasks: The tasks to schedule.

  Returns:
    The tasks that were NOT scheduled. This is `remote_tasks` itself when
    resolving the specs raises AllRegionsOverloadedError (nothing is created
    in that case), and an empty list once every job has been submitted.
  """
  try:
    specs = _get_specs_from_config(remote_tasks)
  except AllRegionsOverloadedError:
    # Every candidate region is overloaded: hand all tasks back to the caller
    # as uncreated instead of submitting anything.
    return remote_tasks

  # Group the input URLs by spec so each distinct spec becomes its own job(s).
  job_specs = collections.defaultdict(list)
  for remote_task in remote_tasks:
    logs.info(f'Scheduling {remote_task.command}, {remote_task.job_type}.')
    spec = specs[(remote_task.command, remote_task.job_type)]
    job_specs[spec].append(remote_task.input_download_url)

  logs.info('Creating batch jobs.')
  logs.info('Batching utask_mains.')
  for spec, input_urls in job_specs.items():
    # A single job is capped in size, so large groups are split into portions.
    for input_urls_portion in utils.batched(input_urls,
                                            MAX_CONCURRENT_VMS_PER_JOB - 1):
      self.create_job(spec, input_urls_portion)

  return []
@test_utils.with_cloud_emulators('datastore')
class GetRegionLoadTest(unittest.TestCase):
  """Exercises get_region_load against a patched urllib.request.urlopen.

  Each test passes a different project name to get_region_load.
  """

  def setUp(self):
    helpers.patch(self, ['urllib.request.urlopen'])

  def _serve(self, payload):
    """Makes the patched urlopen yield a 200 response with body `payload`."""
    fake_response = mock.Mock(status=200)
    fake_response.read.return_value = payload
    self.mock.urlopen.return_value.__enter__.return_value = fake_response

  def test_get_region_load_success(self):
    """A well-formed payload yields the QUEUED count as an int."""
    self._serve(b'{"jobCounts": {"QUEUED": "15"}}')

    self.assertEqual(
        batch_service.get_region_load('project_success', 'us-central1'), 15)

  def test_get_region_load_empty(self):
    """A payload without job counts is reported as zero load."""
    self._serve(b'{}')

    self.assertEqual(
        batch_service.get_region_load('project_empty', 'us-central1'), 0)

  def test_get_region_load_error(self):
    """A failing request is reported as zero load."""
    self.mock.urlopen.side_effect = Exception('error')

    self.assertEqual(
        batch_service.get_region_load('project_error', 'us-central1'), 0)
@test_utils.with_cloud_emulators('datastore')
class GetSubconfigLoadBalancingTest(unittest.TestCase):
  """Covers the region load-balancing behavior of _get_subconfig."""

  def setUp(self):
    helpers.patch(self, [
        'clusterfuzz._internal.batch.service.get_region_load',
        'clusterfuzz._internal.batch.service.random.choice',
        'clusterfuzz._internal.base.utils.random_weighted_choice',
    ])
    subconfigs = {
        'central1': {
            'region': 'us-central1',
            'network': 'n1'
        },
        'east4': {
            'region': 'us-east4',
            'network': 'n2'
        },
        'west1': {
            'region': 'us-west1',
            'network': 'n3'
        },
    }
    self.batch_config = {
        'project': 'test-project',
        'queue_check_regions': ['us-central1', 'us-east4'],
        'subconfigs': subconfigs,
    }
    self.instance_spec = {
        'subconfigs': [
            {
                'name': 'central1',
                'weight': 1
            },
            {
                'name': 'east4',
                'weight': 1
            },
        ]
    }

  def _stub_weighted_choice(self, subconfig_name):
    """Makes the patched random_weighted_choice return `subconfig_name`."""
    picked = mock.Mock()
    picked.name = subconfig_name
    self.mock.random_weighted_choice.return_value = picked

  def test_all_regions_healthy(self):
    """With every region under the limit, the first candidate can be used."""
    self.mock.get_region_load.return_value = 2  # Well below the limit of 100.
    self.mock.choice.side_effect = lambda names: names[0]

    chosen = batch_service._get_subconfig(self.batch_config,
                                          self.instance_spec)

    self.assertEqual(chosen['region'], 'us-central1')

  def test_one_region_overloaded(self):
    """An overloaded region never reaches the random pick."""
    # Loads are returned in candidate order: us-central1 is at the limit,
    # us-east4 is healthy.
    self.mock.get_region_load.side_effect = [100, 2]
    self.mock.choice.side_effect = lambda names: names[0]

    chosen = batch_service._get_subconfig(self.batch_config,
                                          self.instance_spec)

    # Only the healthy candidate may have been offered to random.choice.
    self.mock.choice.assert_called_once_with(['east4'])
    self.assertEqual(chosen['region'], 'us-east4')

  def test_all_regions_overloaded(self):
    """AllRegionsOverloadedError is raised when no region is below the limit."""
    self.mock.get_region_load.return_value = 100  # Exactly at the threshold.

    with self.assertRaises(batch_service.AllRegionsOverloadedError):
      batch_service._get_subconfig(self.batch_config, self.instance_spec)

  def test_skip_load_check_if_not_in_config(self):
    """An empty queue_check_regions list disables the load check."""
    self.batch_config['queue_check_regions'] = []
    single_candidate_spec = {'subconfigs': [{'name': 'central1', 'weight': 1}]}
    self._stub_weighted_choice('central1')

    chosen = batch_service._get_subconfig(self.batch_config,
                                          single_candidate_spec)

    self.assertEqual(chosen['region'], 'us-central1')
    self.mock.get_region_load.assert_not_called()

  def test_skip_load_check_if_disabled(self):
    """A missing queue_check_regions key disables the load check."""
    del self.batch_config['queue_check_regions']
    self._stub_weighted_choice('central1')

    chosen = batch_service._get_subconfig(self.batch_config,
                                          self.instance_spec)

    self.assertEqual(chosen['region'], 'us-central1')
    self.mock.get_region_load.assert_not_called()
helpers.patch(self, [ 'clusterfuzz._internal.base.utils.random_weighted_choice', + 'clusterfuzz._internal.batch.service.random.choice', + 'clusterfuzz._internal.batch.service.get_region_load', ]) self.mock.random_weighted_choice.return_value = batch_service.WeightedSubconfig( name='east4-network2', weight=1, ) + self.mock.choice.return_value = 'east4-network2' + self.mock.get_region_load.return_value = 0 def test_nonpreemptible(self): """Tests that _get_specs_from_config works for non-preemptibles as