Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions configs/test/batch/batch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ mapping:
name: east4-network2
weight: 1
project: 'test-clusterfuzz'
queue_check_regions:
- us-central1
- us-east4
subconfigs:
central1-network1:
region: 'us-central1'
Expand Down
25 changes: 25 additions & 0 deletions src/clusterfuzz/_internal/base/memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import functools
import json
import threading
import time

from clusterfuzz._internal.base import persistent_cache
from clusterfuzz._internal.metrics import logs
Expand Down Expand Up @@ -89,6 +90,30 @@ def get_key(self, func, args, kwargs):
return _default_key(func, args, kwargs)


class InMemory(FifoInMemory):
  """FIFO in-memory caching engine whose entries expire after a fixed TTL.

  Each value is stored alongside the wall-clock time at which it stops being
  valid. Expired entries are reported as cache misses.
  """

  def __init__(self, ttl_in_seconds, capacity=1000):
    """Create a cache whose entries live for |ttl_in_seconds| seconds."""
    super().__init__(capacity)
    self.ttl_in_seconds = ttl_in_seconds

  def put(self, key, value):
    """Store |value| under |key| together with its expiry timestamp."""
    expires_at = time.time() + self.ttl_in_seconds
    super().put(key, (value, expires_at))

  def get(self, key):
    """Return the value cached under |key|, or None if missing or expired."""
    cached = super().get(key)
    if cached is None:
      return None

    value, expires_at = cached
    now = time.time()
    return None if expires_at < now else value


class FifoOnDisk:
"""On-disk caching engine."""

Expand Down
133 changes: 112 additions & 21 deletions src/clusterfuzz/_internal/batch/service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0
Comment thread
javanlacerda marked this conversation as resolved.
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
Expand All @@ -18,14 +18,19 @@
and provides a simple interface for scheduling ClusterFuzz tasks.
"""
import collections
import json
import random
import threading
from typing import Dict
from typing import List
from typing import Tuple
import urllib.request
import uuid

import google.auth.transport.requests
from google.cloud import batch_v1 as batch

from clusterfuzz._internal.base import memoize
from clusterfuzz._internal.base import retry
from clusterfuzz._internal.base import tasks
from clusterfuzz._internal.base import utils
Expand Down Expand Up @@ -65,6 +70,13 @@
# See https://cloud.google.com/batch/quotas#job_limits
MAX_CONCURRENT_VMS_PER_JOB = 1000

# Queued-job threshold used when picking a subconfig: a region whose reported
# load is at or above this value is treated as overloaded and skipped.
MAX_QUEUE_SIZE = 100


class AllRegionsOverloadedError(Exception):
  """Signals that no candidate batch region has spare queue capacity.

  Raised during subconfig selection when every candidate region is at or
  above the queue-size limit.
  """


_local = threading.local()

DEFAULT_RETRY_COUNT = 0
Expand Down Expand Up @@ -184,14 +196,58 @@ def count_queued_or_scheduled_tasks(project: str,
return (queued, scheduled)


# Upper bound (seconds) on how long the countByState HTTP call may block.
# Without a timeout a stalled connection would hang scheduling indefinitely.
_REGION_LOAD_REQUEST_TIMEOUT_SECONDS = 30


@memoize.wrap(memoize.InMemory(60))
def get_region_load(project: str, region: str) -> int:
  """Returns the number of QUEUED Batch jobs in |region| of |project|.

  The count comes from the Batch `jobs:countByState` endpoint, queried for the
  QUEUED state only. Results are memoized in memory with a 60 second TTL.

  This is a best-effort check: any failure while issuing the request or
  parsing the response is logged and reported as a load of 0, so an
  unavailable counting endpoint does not by itself mark a region overloaded.

  Args:
    project: GCP project ID owning the Batch jobs.
    region: Batch location (e.g. 'us-central1') to query.

  Returns:
    The number of queued jobs, or 0 if the count could not be determined.
  """
  creds, _ = credentials.get_default()
  if not creds.valid:
    creds.refresh(google.auth.transport.requests.Request())

  headers = {
      'Authorization': f'Bearer {creds.token}',
      'Content-Type': 'application/json'
  }

  try:
    url = (f'https://batch.googleapis.com/v1alpha/projects/{project}/locations/'
           f'{region}/jobs:countByState?states=QUEUED')
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(
        req, timeout=_REGION_LOAD_REQUEST_TIMEOUT_SECONDS) as response:
      # urlopen raises for non-2xx responses, so this only guards against
      # unexpected 2xx statuses other than 200.
      if response.status != 200:
        logs.error(
            f'Batch countByState failed: {response.status} {response.read()}')
        return 0

      data = json.loads(response.read())
      logs.info(f'Batch countByState response for {region}: {data}')

      # The response is expected to hold a mapping from state name to a
      # stringified count. Example: { "jobCounts": { "QUEUED": "10" } }
      # NOTE(review): structure assumed from the example above; confirm
      # against the v1alpha API.
      total = 0
      job_counts = data.get('jobCounts', {})
      for state, raw_count in job_counts.items():
        # Convert every count so a malformed value is treated as a failure.
        count = int(raw_count)
        if state == 'QUEUED':
          total += count
        else:
          logs.error(f'Unknown state: {state}')

      return total
  except Exception as e:
    # Deliberately broad: the load check must never break job scheduling.
    logs.error(f'Failed to get region load for {region}: {e}')
    return 0


def _get_batch_config():
  """Builds and returns a local_config.BatchConfig.

  Kept as its own function so that it is easy to mock.
  """
  batch_config = local_config.BatchConfig()
  return batch_config


def is_remote_task(command: str, job_name: str) -> bool:
"""Returns whether a task is configured to run remotely on GCP Batch.

This is determined by checking if a valid batch workload specification can
be found for the given command and job type.
"""
Expand Down Expand Up @@ -242,15 +298,46 @@ def _get_config_names(batch_tasks: List[remote_task_types.RemoteTask]):


def _get_subconfig(batch_config, instance_spec):
  """Picks the subconfig a job should use, avoiding overloaded regions.

  Candidates are the subconfigs listed in |instance_spec|. When the batch
  config defines `queue_check_regions`, every candidate whose region is in
  that list has its queue load checked, and candidates whose load is at or
  above MAX_QUEUE_SIZE are dropped. Regions not in the list are never
  checked and always stay eligible. The final pick is a weighted random
  choice over the remaining candidates, so the configured weights are
  honoured whether or not the load check is enabled.

  Args:
    batch_config: Batch config exposing `subconfigs`, `project` and,
      optionally, `queue_check_regions` through `.get()`.
    instance_spec: Mapping with a 'subconfigs' list of
      {'name': ..., 'weight': ...} entries.

  Returns:
    The chosen entry of the batch config's `subconfigs` mapping.

  Raises:
    AllRegionsOverloadedError: If the load check is enabled and every
      candidate region is overloaded.
  """
  all_subconfigs = batch_config.get('subconfigs', {})
  instance_subconfigs = instance_spec['subconfigs']

  queue_check_regions = batch_config.get('queue_check_regions')
  if not queue_check_regions:
    logs.info(
        'Skipping batch load check because queue_check_regions is not configured.'
    )
    candidates = instance_subconfigs
  else:
    # Keep only the subconfigs whose region still has queue capacity.
    project = batch_config.get('project')
    candidates = []
    for subconfig in instance_subconfigs:
      region = all_subconfigs[subconfig['name']]['region']

      if region in queue_check_regions:
        load = get_region_load(project, region)
        logs.info(f'Region {region} has {load} queued jobs.')
        if load >= MAX_QUEUE_SIZE:
          logs.info(f'Region {region} overloaded (load={load}). Skipping.')
          continue

      candidates.append(subconfig)

    if not candidates:
      logs.error('All candidate regions are overloaded.')
      raise AllRegionsOverloadedError('All candidate regions are overloaded.')

  # Weighted random choice spreads load across regions (avoiding a thundering
  # herd on a single one) while still respecting the configured weights.
  weighted_subconfigs = [
      WeightedSubconfig(subconfig['name'], subconfig['weight'])
      for subconfig in candidates
  ]
  weighted_subconfig = utils.random_weighted_choice(weighted_subconfigs)
  return all_subconfigs[weighted_subconfig.name]


def _get_specs_from_config(
Expand All @@ -277,7 +364,6 @@ def _get_specs_from_config(
versioned_images_map = instance_spec.get('versioned_docker_images')
if (base_os_version and versioned_images_map and
base_os_version in versioned_images_map):
# New path: Use the versioned image if specified and available.
docker_image_uri = versioned_images_map[base_os_version]
else:
# Fallback/legacy path: Use the original docker_image key.
Expand Down Expand Up @@ -324,7 +410,7 @@ def _get_specs_from_config(

class GcpBatchService(remote_task_types.RemoteTaskInterface):
"""A high-level service for creating and managing remote tasks.

This service provides a simple interface for scheduling ClusterFuzz tasks on
GCP Batch. It handles the details of creating batch jobs and tasks, and
provides a way to check if a task is configured to run remotely.
Expand Down Expand Up @@ -376,32 +462,37 @@ def create_utask_main_job(self, module: str, job_type: str,
remote_task_types.RemoteTask(command, job_type, input_download_url)
]
result = self.create_utask_main_jobs(batch_tasks)
if result is None:
return result
if not result:
return None
return result[0]

def create_utask_main_jobs(self,
                           remote_tasks: List[remote_task_types.RemoteTask]):
  """Creates batch jobs for a list of uworker main tasks.

  Tasks are grouped by their workload specification and each group is split
  into portions of at most MAX_CONCURRENT_VMS_PER_JOB - 1 input URLs; one
  batch job is created per portion.

  Args:
    remote_tasks: The tasks to schedule.

  Returns:
    The names of the created batch jobs. If every candidate region is
    overloaded, no job is created and |remote_tasks| is returned unchanged
    so the caller can treat them as uncreated.
    NOTE(review): the two return shapes differ (job names vs. tasks);
    confirm which contract callers rely on.
  """
  # Resolve specs exactly once, inside the guard, so an overload is handled
  # here instead of propagating to the caller.
  try:
    specs = _get_specs_from_config(remote_tasks)
  except AllRegionsOverloadedError:
    # Return the remote tasks as uncreated tasks if all regions are
    # overloaded.
    return remote_tasks

  job_specs = collections.defaultdict(list)
  for remote_task in remote_tasks:
    logs.info(f'Scheduling {remote_task.command}, {remote_task.job_type}.')
    spec = specs[(remote_task.command, remote_task.job_type)]
    job_specs[spec].append(remote_task.input_download_url)

  logs.info('Creating batch jobs.')
  jobs = []

  logs.info('Batching utask_mains.')
  for spec, input_urls in job_specs.items():
    for input_urls_portion in utils.batched(input_urls,
                                            MAX_CONCURRENT_VMS_PER_JOB - 1):
      # Create each job once; a second call would schedule a duplicate job.
      jobs.append(self.create_job(spec, input_urls_portion).name)

  return jobs
Loading