diff --git a/collectoss/tasks/github/messages.py b/collectoss/tasks/github/messages.py index 342eeb2ca..a477409dd 100644 --- a/collectoss/tasks/github/messages.py +++ b/collectoss/tasks/github/messages.py @@ -1,12 +1,13 @@ import logging from datetime import timedelta, timezone +from typing import Generator from collectoss.tasks.init.celery_app import celery_app as celery from collectoss.tasks.init.celery_app import CoreRepoCollectionTask from collectoss.application.db.data_parse import * from collectoss.tasks.github.util.github_data_access import GithubDataAccess, UrlNotFoundException from collectoss.tasks.github.util.github_task_session import GithubTaskManifest -from collectoss.tasks.util.worker_util import remove_duplicate_dicts +from collectoss.tasks.util.worker_util import batched, remove_duplicate_dicts from collectoss.tasks.github.util.util import get_owner_repo from collectoss.application.db.models import PullRequest, Message, Issue, PullRequestMessageRef, IssueMessageRef, Contributor, Repo, CollectionStatus from collectoss.application.db import get_engine, get_session @@ -38,18 +39,14 @@ def collect_github_messages(repo_git: str, full_collection: bool) -> None: # subtract 2 days to ensure all data is collected core_data_last_collected = (get_core_data_last_collected(repo_id) - timedelta(days=2)).replace(tzinfo=timezone.utc) + message_data = fast_retrieve_all_pr_and_issue_messages(repo_git, logger, None, task_name, core_data_last_collected) - if is_repo_small(repo_id): - message_data = fast_retrieve_all_pr_and_issue_messages(repo_git, logger, manifest.key_auth, task_name, core_data_last_collected) - - if message_data: - process_messages(message_data, task_name, repo_id, logger, db_session) - - else: - logger.info(f"{owner}/{repo} has no messages") + if message_data: + for batch in batched(message_data, 1000): + process_messages(batch, task_name, repo_id, logger, db_session) else: - process_large_issue_and_pr_message_collection(repo_id, repo_git, logger, 
manifest.key_auth, task_name, db_session, core_data_last_collected) + logger.info(f"{owner}/{repo} has no messages") def is_repo_small(repo_id): @@ -60,7 +57,7 @@ def is_repo_small(repo_id): return result != None -def fast_retrieve_all_pr_and_issue_messages(repo_git: str, logger, key_auth, task_name, since) -> None: +def fast_retrieve_all_pr_and_issue_messages(repo_git: str, logger, key_auth, task_name, since) -> Generator: owner, repo = get_owner_repo(repo_git) @@ -73,70 +70,14 @@ def fast_retrieve_all_pr_and_issue_messages(repo_git: str, logger, key_auth, tas # define logger for task logger.info(f"Collecting github comments for {owner}/{repo}") - github_data_access = GithubDataAccess(key_auth, logger) + github_data_access = GithubDataAccess(None, logger) message_count = github_data_access.get_resource_count(url) logger.info(f"{task_name}: Collecting {message_count} github messages") - return list(github_data_access.paginate_resource(url)) - - -def process_large_issue_and_pr_message_collection(repo_id, repo_git: str, logger, key_auth, task_name, db_session, since) -> None: - - message_batch_size = get_batch_size("message") - - owner, repo = get_owner_repo(repo_git) - - # define logger for task - logger.info(f"Collecting github comments for {owner}/{repo}") - - engine = get_engine() - - with engine.connect() as connection: + return github_data_access.paginate_resource(url) - if since: - query = text(f""" - (select pr_comments_url from pull_requests WHERE repo_id={repo_id} AND pr_comments_url IS NOT NULL AND pr_updated_at > timestamptz(timestamp '{since}') order by pr_created_at desc) - UNION - (select comments_url as comment_url from issues WHERE repo_id={repo_id} AND comments_url IS NOT NULL AND updated_at > timestamptz(timestamp '{since}') order by created_at desc); - """) - else: - - query = text(f""" - (select pr_comments_url from pull_requests WHERE repo_id={repo_id} AND pr_comments_url IS NOT NULL order by pr_created_at desc) - UNION - (select comments_url as 
from itertools import islice


def batched(iterable, n):
    """Batch items from an iterable into lists of at most ``n`` items.

    The input is consumed lazily: a chunk is only pulled from the underlying
    iterator when the caller asks for the next batch, so generators can be
    processed without being materialised in full.

    Args:
        iterable: Any iterable, including one-shot iterators and generators.
        n: Maximum number of items per chunk. Must be at least 1.

    Returns:
        A generator yielding lists. Every list holds ``n`` items except
        possibly the last, which holds the remainder. Nothing is yielded
        for an empty input.

    Raises:
        ValueError: If ``n`` is less than 1. Raised at call time, before any
            item is consumed.
    """
    # With n == 0, islice() produces an empty chunk, which would end the loop
    # immediately and silently discard every item. Reject it up front instead.
    if n < 1:
        raise ValueError("n must be at least one")

    it = iter(iterable)

    def _chunks():
        # An empty list means the source iterator is exhausted.
        while chunk := list(islice(it, n)):
            yield chunk

    return _chunks()
class TestBatched:
    """Unit tests for the ``batched`` chunking helper."""

    @pytest.mark.unit
    def test_batched_evenly_divisible(self):
        """A length that is a multiple of the size splits into full chunks."""
        expected = [[1, 2, 3], [4, 5, 6]]
        assert [*batched([1, 2, 3, 4, 5, 6], 3)] == expected

    @pytest.mark.unit
    def test_batched_remainder(self):
        """Leftover items end up in a shorter final chunk."""
        expected = [[1, 2, 3], [4, 5]]
        assert [*batched([1, 2, 3, 4, 5], 3)] == expected

    @pytest.mark.unit
    def test_batched_smaller_than_batch_size(self):
        """Fewer items than the chunk size produce a single chunk."""
        chunks = [*batched([1, 2], 100)]
        assert chunks == [[1, 2]]

    @pytest.mark.unit
    def test_batched_empty(self):
        """An empty input produces no chunks at all."""
        chunks = [*batched([], 3)]
        assert chunks == []

    @pytest.mark.unit
    def test_batched_size_one(self):
        """A chunk size of one wraps every item in its own list."""
        chunks = [*batched([1, 2, 3], 1)]
        assert chunks == [[1], [2], [3]]

    @pytest.mark.unit
    def test_batched_consumes_generator_lazily(self):
        """Only the items needed for the requested chunk are pulled."""
        seen = []

        def tracking_gen():
            # Record each value at the moment it is pulled from the source.
            for value in range(6):
                seen.append(value)
                yield value

        batches = batched(tracking_gen(), 3)

        next(batches)
        assert seen == [0, 1, 2], "should only have consumed first batch"

        next(batches)
        assert seen == [0, 1, 2, 3, 4, 5]

    @pytest.mark.unit
    def test_batched_returns_lists(self):
        """Every chunk is a real list, even when the input is a range."""
        for chunk in list(batched(range(4), 2)):
            assert isinstance(chunk, list)