diff --git a/.env.example b/.env.example index 8f0cb921..e6ce9328 100644 --- a/.env.example +++ b/.env.example @@ -290,13 +290,22 @@ DATABASE_URL=postgres://user:password@localhost:5432/boost_dashboard # ============================================================================== # Reddit (reddit_activity_tracker) # ============================================================================== -# Register a "script" app at https://www.reddit.com/prefs/apps +# Register a "script" app at https://www.reddit.com/prefs/apps (preferred): # REDDIT_CLIENT_ID=your_client_id # REDDIT_CLIENT_SECRET=your_client_secret # REDDIT_USER_AGENT=r_cpp_scraper/1.0 by u/yourusername # -# Optional: minimum seconds between API requests (default 1.0, ~60 req/min) +# Alternative auth (when client credentials are unavailable): +# REDDIT_BEARER_TOKEN=token_v2_cookie_value (~24h) +# REDDIT_SESSION_COOKIE=reddit_session_cookie (~180d; auto-mints bearer) +# REDDIT_CSRF_TOKEN=csrf_token (optional; required if session mint fails) +# +# Optional rate limiting and discovery (defaults shown): # REQUEST_INTERVAL=1.0 +# RATE_LIMIT_LOW_WATERMARK=2.0 +# +# First run when DB is empty: scrape this many days back (default 30) +# REDDIT_DEFAULT_LOOKBACK_DAYS=30 # ============================================================================== # YouTube (cppa_youtube_script_tracker) diff --git a/config/boost_collector_schedule.yaml b/config/boost_collector_schedule.yaml index e23d3618..e747c5bb 100644 --- a/config/boost_collector_schedule.yaml +++ b/config/boost_collector_schedule.yaml @@ -54,3 +54,9 @@ groups: tasks: - command: run_boost_mailing_list_tracker schedule: daily + + reddit: + default_time: "17:00" + tasks: + - command: run_reddit_activity_tracker + schedule: daily diff --git a/config/settings.py b/config/settings.py index f9f625dc..f8918d4c 100644 --- a/config/settings.py +++ b/config/settings.py @@ -530,8 +530,14 @@ def _slack_team_scope_from_env(): REDDIT_CLIENT_ID = (env("REDDIT_CLIENT_ID", 
default="") or "").strip() REDDIT_CLIENT_SECRET = (env("REDDIT_CLIENT_SECRET", default="") or "").strip() REDDIT_USER_AGENT = (env("REDDIT_USER_AGENT", default="") or "").strip() +REDDIT_BEARER_TOKEN = (env("REDDIT_BEARER_TOKEN", default="") or "").strip() +REDDIT_SESSION_COOKIE = (env("REDDIT_SESSION_COOKIE", default="") or "").strip() +REDDIT_CSRF_TOKEN = (env("REDDIT_CSRF_TOKEN", default="") or "").strip() or None # Minimum seconds between API requests (default 1.0, ~60 req/min). Env: REQUEST_INTERVAL. REDDIT_REQUEST_INTERVAL = env.float("REQUEST_INTERVAL", default=1.0) +# Pause when X-Ratelimit-Remaining drops below this value. Env: RATE_LIMIT_LOW_WATERMARK. +REDDIT_RATE_LIMIT_LOW_WATERMARK = env.float("RATE_LIMIT_LOW_WATERMARK", default=2.0) +REDDIT_DEFAULT_LOOKBACK_DAYS = env.int("REDDIT_DEFAULT_LOOKBACK_DAYS", default=30) # WG21 Paper Tracker Configuration WG21_GITHUB_DISPATCH_ENABLED = env.bool("WG21_GITHUB_DISPATCH_ENABLED", default=False) diff --git a/core/_version.py b/core/_version.py index 7f90c03b..5873b74b 100644 --- a/core/_version.py +++ b/core/_version.py @@ -1,2 +1,2 @@ # file generated by setuptools-scm; do not edit -version = "0.1.1.dev579+g8b4cba29b.d20260609" +version = "0.1.1.dev584+g9efa67002.d20260612" diff --git a/cppa_user_tracker/admin.py b/cppa_user_tracker/admin.py index fcbcadba..0ba4f12b 100644 --- a/cppa_user_tracker/admin.py +++ b/cppa_user_tracker/admin.py @@ -7,6 +7,7 @@ GitHubAccount, Identity, MailingListProfile, + RedditUser, SlackUser, TempProfileIdentityRelation, TmpIdentity, @@ -99,3 +100,17 @@ class WG21PaperAuthorProfileAdmin(ModelAdmin): list_display = ("id", "identity", "display_name", "updated_at") search_fields = ("display_name",) raw_id_fields = ("identity",) + + +@admin.register(RedditUser) +class RedditUserAdmin(ModelAdmin): + list_display = ( + "id", + "identity", + "reddit_user_id", + "username", + "display_name", + "updated_at", + ) + search_fields = ("reddit_user_id", "username", "display_name") + 
raw_id_fields = ("identity",) diff --git a/cppa_user_tracker/migrations/0009_reddituser_alter_baseprofile_type.py b/cppa_user_tracker/migrations/0009_reddituser_alter_baseprofile_type.py new file mode 100644 index 00000000..ab9fee4a --- /dev/null +++ b/cppa_user_tracker/migrations/0009_reddituser_alter_baseprofile_type.py @@ -0,0 +1,66 @@ +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("cppa_user_tracker", "0008_wg21paperauthorprofile_author_alias"), + ] + + operations = [ + migrations.CreateModel( + name="RedditUser", + fields=[ + ( + "baseprofile_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="cppa_user_tracker.baseprofile", + ), + ), + ( + "reddit_user_id", + models.CharField( + blank=True, + db_index=True, + max_length=64, + null=True, + unique=True, + ), + ), + ( + "username", + models.CharField(db_index=True, max_length=255, unique=True), + ), + ( + "display_name", + models.CharField(blank=True, db_index=True, max_length=255), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ], + bases=("cppa_user_tracker.baseprofile",), + ), + migrations.AlterField( + model_name="baseprofile", + name="type", + field=models.CharField( + choices=[ + ("github", "GitHub"), + ("slack", "Slack"), + ("mailing_list", "Mailing list"), + ("wg21", "WG21"), + ("discord", "Discord"), + ("youtube", "YouTube"), + ("reddit", "Reddit"), + ], + db_index=True, + max_length=20, + ), + ), + ] diff --git a/cppa_user_tracker/models.py b/cppa_user_tracker/models.py index ed3d42da..8cf284e9 100644 --- a/cppa_user_tracker/models.py +++ b/cppa_user_tracker/models.py @@ -17,6 +17,7 @@ class ProfileType(models.TextChoices): WG21 = "wg21", "WG21" # pyright: ignore[reportCallIssue] DISCORD = "discord", "Discord" # pyright: 
ignore[reportCallIssue] YOUTUBE = "youtube", "YouTube" # pyright: ignore[reportCallIssue] + REDDIT = "reddit", "Reddit" # pyright: ignore[reportCallIssue] class GitHubAccountType(models.TextChoices): @@ -195,6 +196,26 @@ def save(self, *args, **kwargs): updated_at = models.DateTimeField(auto_now=True) +class RedditUser(BaseProfile): + """Profile for Reddit; extends BaseProfile.""" + + def save(self, *args, **kwargs): + self.type = ProfileType.REDDIT + super().save(*args, **kwargs) + + reddit_user_id = models.CharField( + max_length=64, + unique=True, + db_index=True, + null=True, + blank=True, + ) + username = models.CharField(max_length=255, unique=True, db_index=True) + display_name = models.CharField(max_length=255, db_index=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class YoutubeSpeaker(BaseProfile): """YouTube speaker profile. diff --git a/cppa_user_tracker/services.py b/cppa_user_tracker/services.py index 2cbd0b8a..5e0dd7d9 100644 --- a/cppa_user_tracker/services.py +++ b/cppa_user_tracker/services.py @@ -29,10 +29,13 @@ MailingListProfile, SlackUser, DiscordProfile, + RedditUser, WG21PaperAuthorProfile, YoutubeSpeaker, ) +_REDDIT_DELETED_AUTHORS = frozenset({"", "[deleted]", "AutoModerator"}) + # --- Identity --- def create_identity( @@ -451,3 +454,98 @@ def get_or_create_youtube_speaker( speaker.display_name = display_name_val speaker.save(update_fields=["display_name", "updated_at"]) return speaker, created + + +class RedditClientProtocol(Protocol): + """Protocol for a Reddit API client used by get_or_create_reddit_user.""" + + def fetch_user_about(self, username: str) -> dict[str, Any] | None: ... 
+ + +def _normalize_reddit_username(author: str | None) -> str | None: + username = (author or "").strip() + if not username or username in _REDDIT_DELETED_AUTHORS: + return None + return username + + +def _display_name_from_reddit_profile( + profile: dict[str, Any] | None, username: str +) -> str: + if not profile: + return username + subreddit = profile.get("subreddit") or {} + if isinstance(subreddit, dict): + title = (subreddit.get("title") or "").strip() + if title: + return title + return username + + +@transaction.atomic +def get_or_create_reddit_user( + username: str, + *, + reddit_user_id: str | None = None, + display_name: str | None = None, + client: RedditClientProtocol | None = None, +) -> RedditUser | None: + """Get or create a RedditUser; call /user/about only when the user is new.""" + normalized = _normalize_reddit_username(username) + if not normalized: + return None + + existing = RedditUser.objects.filter(username=normalized).first() + if existing is not None: + return existing + + profile_data: dict[str, Any] | None = None + if client is not None: + profile_data = client.fetch_user_about(normalized) + + resolved_reddit_user_id = (reddit_user_id or "").strip() or None + if profile_data: + profile_id = (profile_data.get("id") or "").strip() + if profile_id: + resolved_reddit_user_id = f"t2_{profile_id}" + elif profile_data.get("fullname"): + resolved_reddit_user_id = str(profile_data.get("fullname")).strip() + + resolved_display_name = (display_name or "").strip() + if not resolved_display_name: + resolved_display_name = _display_name_from_reddit_profile( + profile_data, normalized + ) + + user, created = RedditUser.objects.get_or_create( + username=normalized, + defaults={ + "reddit_user_id": resolved_reddit_user_id, + "display_name": resolved_display_name, + }, + ) + if not created: + if resolved_reddit_user_id: + user.reddit_user_id = resolved_reddit_user_id + user.display_name = resolved_display_name or user.display_name + user.save() + 
return user + + +def resolve_reddit_user_from_author_data( + data: dict[str, Any], + *, + client: RedditClientProtocol | None = None, +) -> RedditUser | None: + """Resolve RedditUser from submission/comment author fields.""" + author = data.get("author") + author_fullname = data.get("author_fullname") + username = _normalize_reddit_username(author) + if not username: + return None + reddit_user_id = (author_fullname or "").strip() or None + return get_or_create_reddit_user( + username, + reddit_user_id=reddit_user_id, + client=client, + ) diff --git a/cppa_user_tracker/tests/test_services.py b/cppa_user_tracker/tests/test_services.py index fde04cf1..0586d0a9 100644 --- a/cppa_user_tracker/tests/test_services.py +++ b/cppa_user_tracker/tests/test_services.py @@ -1,5 +1,7 @@ """Tests for cppa_user_tracker.services.""" +from unittest.mock import MagicMock + import pytest from cppa_user_tracker.models import ( @@ -8,6 +10,7 @@ GitHubAccount, GitHubAccountType, Identity, + RedditUser, SlackUser, TempProfileIdentityRelation, WG21PaperAuthorProfile, @@ -806,6 +809,66 @@ def test_get_or_create_discord_profile_updates_existing(): assert profile.is_bot is True +# --- get_or_create_reddit_user --- + + +@pytest.mark.django_db +def test_get_or_create_reddit_user_creates_and_updates(): + client = MagicMock() + client.fetch_user_about.return_value = { + "id": "abc123", + "name": "Taladar", + "subreddit": {"title": "Taladar"}, + } + user = services.get_or_create_reddit_user( + "Taladar", + reddit_user_id="t2_old", + client=client, + ) + assert user is not None + assert user.username == "Taladar" + assert user.reddit_user_id == "t2_abc123" + assert user.display_name == "Taladar" + + user2 = services.get_or_create_reddit_user("Taladar", client=client) + assert user2.pk == user.pk + client.fetch_user_about.assert_called_once() + + +@pytest.mark.django_db +def test_get_or_create_reddit_user_skips_about_for_existing_user(): + RedditUser.objects.create( + username="Taladar", + 
reddit_user_id="t2_abc123", + display_name="Taladar", + ) + client = MagicMock() + user = services.get_or_create_reddit_user("Taladar", client=client) + assert user is not None + assert user.username == "Taladar" + client.fetch_user_about.assert_not_called() + + +@pytest.mark.django_db +def test_get_or_create_reddit_user_deleted_author_returns_none(): + assert services.get_or_create_reddit_user("[deleted]") is None + + +@pytest.mark.django_db +def test_resolve_reddit_user_from_author_data(): + client = MagicMock() + client.fetch_user_about.return_value = { + "id": "abc123", + "subreddit": {"title": "Taladar"}, + } + user = services.resolve_reddit_user_from_author_data( + {"author": "Taladar", "author_fullname": "t2_abc123"}, + client=client, + ) + assert user is not None + assert user.username == "Taladar" + + # --- get_or_create_youtube_speaker --- diff --git a/docs/service_api/cppa_user_tracker.md b/docs/service_api/cppa_user_tracker.md index 5717dd6f..2d921687 100644 --- a/docs/service_api/cppa_user_tracker.md +++ b/docs/service_api/cppa_user_tracker.md @@ -24,12 +24,14 @@ | `get_or_create_identity` | display_name: str = '', description: str = '', defaults: dict[str, Any] \| None = None | tuple[Identity, bool] | Get or create an Identity by display_name. If exists, updates description from defaults. | | `get_or_create_mailing_list_profile` | display_name: str = '', email: str = '' | tuple[MailingListProfile, bool] | Get or create a MailingListProfile by display_name and email. Returns (profile, created). | | `get_or_create_owner_account` | client: GitHubClientProtocol, owner: str | GitHubAccount | Get or create a GitHubAccount for an owner (org or user). For use by any app. | +| `get_or_create_reddit_user` | username: str, *, reddit_user_id: str \| None = None, display_name: str \| None = None, client: RedditClientProtocol \| None = None | RedditUser \| None | Get or create a RedditUser; call /user/about only when the user is new. 
| | `get_or_create_slack_user` | user_data: SlackUserPayload \| dict[str, Any] | tuple[SlackUser, bool] | Get or create a SlackUser from Slack API user data. Returns (SlackUser, created). | | `get_or_create_unknown_github_account` | name: str \| None = None, email: str = '' | tuple[GitHubAccount, bool] | Get or create a GitHubAccount for commits with no API author/committer. | | `get_or_create_wg21_paper_author_profile` | display_name: str, email: str \| None = None | tuple[WG21PaperAuthorProfile, bool] | Get or create a WG21PaperAuthorProfile by display_name, with optional email disambiguation. | | `get_or_create_youtube_speaker` | external_id: str, display_name: str = '', identity: Identity \| None = None | tuple[YoutubeSpeaker, bool] | Get or create a YoutubeSpeaker by external_id. Returns (speaker, created). | | `remove_email` | email_obj: Email | None | Remove an email from a profile. | | `remove_temp_profile_identity_relation` | base_profile: BaseProfile, target_identity: TmpIdentity | None | Remove the staging relation between base_profile and target_identity. | +| `resolve_reddit_user_from_author_data` | data: dict[str, Any], *, client: RedditClientProtocol \| None = None | RedditUser \| None | Resolve RedditUser from submission/comment author fields. | | `update_email` | email_obj: Email, **kwargs: Any | Email | Update an Email instance. Allowed keys: email, is_primary, is_active. | diff --git a/docs/service_api/reddit_activity_tracker.md b/docs/service_api/reddit_activity_tracker.md index b8966a0d..d1815976 100644 --- a/docs/service_api/reddit_activity_tracker.md +++ b/docs/service_api/reddit_activity_tracker.md @@ -12,6 +12,13 @@ | Function | Parameters | Return type | Summary | | --- | --- | --- | --- | +| `get_latest_comment_created_utc` | | int | Return max created_utc across comments, or 0 when empty. | +| `get_latest_submission_created_utc` | | int | Return max created_utc across submissions, or 0 when empty. 
| +| `get_or_create_submission_stub` | submission_id: str, *, subreddit: str = 'cpp' | RedditSubmission | Ensure a submission row exists for FK when only a comment link_id is known. | +| `resolve_submission_for_comment` | comment_data: dict, submissions_by_id: dict[str, RedditSubmission] | RedditSubmission | Return the submission row for a period comment, creating a stub if needed. | +| `submission_id_from_link_id` | link_id: str | str \| None | — | +| `upsert_reddit_comment` | data: dict[str, Any], submission: RedditSubmission, *, session: RedditSession \| None = None | RedditComment | Update or create a comment keyed by reddit_comment_id. | +| `upsert_reddit_submission` | data: dict[str, Any], *, session: RedditSession \| None = None | RedditSubmission | Update or create a submission keyed by reddit_submission_id. | diff --git a/reddit_activity_tracker/admin.py b/reddit_activity_tracker/admin.py index 1f697318..e7d837c0 100644 --- a/reddit_activity_tracker/admin.py +++ b/reddit_activity_tracker/admin.py @@ -6,9 +6,9 @@ @admin.register(RedditSubmission) class RedditSubmissionAdmin(admin.ModelAdmin): list_display = ( - "reddit_id", + "reddit_submission_id", "subreddit", - "author", + "user", "title", "score", "num_comments", @@ -16,21 +16,22 @@ class RedditSubmissionAdmin(admin.ModelAdmin): "fetched_at", ) list_filter = ("subreddit",) - search_fields = ("reddit_id", "title", "author") - ordering = ("-created_utc", "reddit_id") + search_fields = ("reddit_submission_id", "title", "user__username") + raw_id_fields = ("user",) + ordering = ("-created_utc", "reddit_submission_id") @admin.register(RedditComment) class RedditCommentAdmin(admin.ModelAdmin): list_display = ( - "reddit_id", + "reddit_comment_id", "submission", - "author", + "user", "score", "created_utc", "fetched_at", ) list_filter = ("submission__subreddit",) - search_fields = ("reddit_id", "author", "body") - raw_id_fields = ("submission",) - ordering = ("created_utc", "reddit_id") + search_fields = 
("reddit_comment_id", "user__username", "body") + raw_id_fields = ("submission", "user") + ordering = ("created_utc", "reddit_comment_id") diff --git a/reddit_activity_tracker/fetcher.py b/reddit_activity_tracker/fetcher.py index 471e1f9c..f7dcb2e7 100644 --- a/reddit_activity_tracker/fetcher.py +++ b/reddit_activity_tracker/fetcher.py @@ -2,21 +2,117 @@ Reddit OAuth API client for reddit_activity_tracker. Ported from reddit-scraper/scraper.py (RedditSession + build_session). +Supports client credentials, bearer token, or session-cookie auth. """ from __future__ import annotations +import base64 +import json import logging import random import time +from datetime import datetime, timezone import requests from django.conf import settings logger = logging.getLogger(__name__) +SUBREDDIT = "cpp" MAX_RETRIES = 5 RETRY_BASE_DELAY = 2.0 +RATE_LIMIT_LOW_WATERMARK = getattr(settings, "REDDIT_RATE_LIMIT_LOW_WATERMARK", 2.0) +_SHREDDIT_TOKEN_URL = "https://www.reddit.com/svc/shreddit/token" +_PLACEHOLDER_VALUES = frozenset({"your_client_id", "your_client_secret"}) + + +def _normalize_bearer(token: str) -> str: + token = token.strip() + if token.lower().startswith("bearer "): + return token[7:].strip() + return token + + +def _jwt_expiry(token: str) -> float | None: + try: + parts = _normalize_bearer(token).split(".") + if len(parts) != 3: + return None + payload = parts[1] + "=" * (-len(parts[1]) % 4) + data = json.loads(base64.urlsafe_b64decode(payload)) + exp = data.get("exp") + return float(exp) if exp is not None else None + except (ValueError, json.JSONDecodeError, TypeError): + return None + + +def _is_bearer_expired(token: str, leeway: float = 60) -> bool: + exp = _jwt_expiry(token) + if exp is None: + return False + return time.time() >= exp - leeway + + +def mint_bearer_from_session( + session_cookie: str, + user_agent: str, + csrf_token: str | None = None, +) -> str: + """Exchange a reddit_session cookie for a fresh token_v2 bearer JWT.""" + sess = 
requests.Session() + sess.headers.update( + { + "User-Agent": user_agent, + "Content-Type": "application/json", + "Origin": "https://www.reddit.com", + } + ) + sess.cookies.set("reddit_session", session_cookie.strip(), domain=".reddit.com") + + csrf = csrf_token.strip() if csrf_token else None + if csrf: + sess.cookies.set("csrf_token", csrf, domain=".reddit.com") + else: + sess.get("https://www.reddit.com/", timeout=30) + csrf = sess.cookies.get("csrf_token") + + if not csrf: + raise RuntimeError( + "Could not obtain csrf_token — set REDDIT_CSRF_TOKEN in .env " + "(DevTools → Cookies → csrf_token)" + ) + + resp = sess.post(_SHREDDIT_TOKEN_URL, json={"csrf_token": csrf}, timeout=30) + if resp.status_code != 200: + raise RuntimeError( + f"Failed to mint bearer token from session (HTTP {resp.status_code})" + ) + + data = resp.json() + token = data.get("token") + if not token: + raise RuntimeError("Reddit token endpoint returned no token") + + expires = data.get("expires") + if expires: + logger.info( + "Minted bearer token from session (expires %s)", + datetime.fromtimestamp(expires / 1000, tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M UTC" + ), + ) + else: + logger.info("Minted bearer token from session") + return token + + +def _credentials_configured(value: str | None) -> str | None: + if not value or not value.strip(): + return None + if value.strip() in _PLACEHOLDER_VALUES: + return None + return value.strip() class RedditSession: @@ -25,15 +121,53 @@ class RedditSession: _TOKEN_URL = "https://www.reddit.com/api/v1/access_token" _API_BASE = "https://oauth.reddit.com" - def __init__(self, client_id: str, client_secret: str, user_agent: str) -> None: + def __init__( + self, + client_id: str | None, + client_secret: str | None, + user_agent: str, + *, + bearer_token: str | None = None, + session_cookie: str | None = None, + csrf_token: str | None = None, + ) -> None: self._client_id = client_id self._client_secret = client_secret + self._user_agent = user_agent + 
self._session_cookie = session_cookie + self._csrf_token = csrf_token + self._bearer_mode = bearer_token is not None self._session = requests.Session() self._session.headers.update({"User-Agent": user_agent}) self._token_expiry: float = 0.0 self._last_request_at: float = 0.0 + self._remaining: float | None = None + self._reset: float | None = None + if bearer_token: + self._apply_bearer(bearer_token) + + def _apply_bearer(self, token: str) -> None: + token = _normalize_bearer(token) + self._session.headers["Authorization"] = f"Bearer {token}" + exp = _jwt_expiry(token) + self._token_expiry = exp if exp is not None else float("inf") + + def _remint_bearer_from_session(self) -> None: + if not self._session_cookie: + raise RuntimeError( + "Bearer token expired and no REDDIT_SESSION_COOKIE available to re-mint" + ) + logger.info("Re-minting bearer token from REDDIT_SESSION_COOKIE...") + self._apply_bearer( + mint_bearer_from_session( + self._session_cookie, self._user_agent, self._csrf_token + ) + ) def _refresh_token(self) -> None: + if self._bearer_mode: + self._remint_bearer_from_session() + return logger.info("Obtaining OAuth token...") auth = requests.auth.HTTPBasicAuth(self._client_id, self._client_secret) resp = self._session.post( @@ -52,22 +186,64 @@ def _refresh_token(self) -> None: ) def _ensure_token(self) -> None: + if self._bearer_mode: + if time.time() >= self._token_expiry and self._session_cookie: + self._remint_bearer_from_session() + return if time.time() >= self._token_expiry: self._refresh_token() + def _update_rate_limit_state(self, resp: requests.Response) -> None: + remaining = resp.headers.get("X-Ratelimit-Remaining") + reset = resp.headers.get("X-Ratelimit-Reset") + if remaining is not None: + self._remaining = float(remaining) + if reset is not None: + self._reset = float(reset) + + def _backoff_seconds(self, resp: requests.Response | None, delay: float) -> float: + if resp is not None: + retry_after = resp.headers.get("Retry-After") + if 
retry_after is not None: + return float(retry_after) + random.uniform(0, 1) + reset = resp.headers.get("X-Ratelimit-Reset") + if reset is not None: + return float(reset) + random.uniform(0.5, 1.5) + return delay + random.uniform(0, 1) + def _throttle(self) -> None: - """Enforce a minimum gap between requests.""" + if ( + self._remaining is not None + and self._remaining < RATE_LIMIT_LOW_WATERMARK + and self._reset is not None + ): + wait = max(self._reset, 0) + random.uniform(0.5, 1.5) + logger.warning( + "Rate limit low (%.1f remaining, reset in %.1fs) — sleeping %.1fs", + self._remaining, + self._reset, + wait, + ) + time.sleep(wait) + self._remaining = None + self._reset = None + self._last_request_at = time.time() + return + elapsed = time.time() - self._last_request_at interval = settings.REDDIT_REQUEST_INTERVAL if elapsed < interval: time.sleep(interval - elapsed) self._last_request_at = time.time() - def get(self, path: str, params: dict | None = None) -> dict: - """ - GET from the Reddit OAuth API with rate-limit enforcement and - exponential backoff on 429 / transient errors. 
- """ + def _request( + self, + method: str, + path: str, + *, + params: dict | None = None, + data: dict | None = None, + ) -> dict: self._ensure_token() url = f"{self._API_BASE}{path}" delay = RETRY_BASE_DELAY @@ -75,11 +251,14 @@ def get(self, path: str, params: dict | None = None) -> dict: for attempt in range(1, MAX_RETRIES + 1): self._throttle() try: - resp = self._session.get(url, params=params, timeout=30) + if method == "GET": + resp = self._session.get(url, params=params, timeout=30) + else: + resp = self._session.post(url, data=data, timeout=30) except requests.exceptions.RequestException as exc: if attempt == MAX_RETRIES: raise - wait = delay + random.uniform(0, 1) + wait = self._backoff_seconds(None, delay) logger.warning( "Network error (attempt %d/%d): %s — retrying in %.1fs", attempt, @@ -91,13 +270,26 @@ def get(self, path: str, params: dict | None = None) -> dict: delay *= 2 continue + self._update_rate_limit_state(resp) + if resp.status_code == 401: + if self._bearer_mode: + if self._session_cookie and attempt < MAX_RETRIES: + logger.warning( + "Bearer token rejected — re-minting from session..." 
+ ) + self._remint_bearer_from_session() + continue + raise RuntimeError( + "Bearer token rejected — update REDDIT_BEARER_TOKEN or " + "REDDIT_SESSION_COOKIE in .env" + ) logger.warning("Token expired mid-run, refreshing...") self._refresh_token() continue if resp.status_code == 429: - wait = delay + random.uniform(0, 1) + wait = self._backoff_seconds(resp, delay) logger.warning( "Rate limited (429) on attempt %d/%d — retrying in %.1fs", attempt, @@ -105,13 +297,15 @@ def get(self, path: str, params: dict | None = None) -> dict: wait, ) time.sleep(wait) + self._remaining = None + self._reset = None delay *= 2 continue if resp.status_code != 200: if attempt == MAX_RETRIES: resp.raise_for_status() - wait = delay + random.uniform(0, 1) + wait = self._backoff_seconds(resp, delay) logger.warning( "HTTP %d on attempt %d/%d — retrying in %.1fs", resp.status_code, @@ -123,21 +317,183 @@ def get(self, path: str, params: dict | None = None) -> dict: delay *= 2 continue + if ( + self._remaining is not None + and self._remaining < RATE_LIMIT_LOW_WATERMARK + ): + logger.info( + "Rate limit quota low after request (%.1f remaining, reset in %.1fs)", + self._remaining, + self._reset or 0, + ) + return resp.json() raise RuntimeError(f"All {MAX_RETRIES} retries exhausted for {url}") + def get(self, path: str, params: dict | None = None) -> dict: + """GET from the Reddit OAuth API with rate-limit enforcement.""" + return self._request("GET", path, params=params) + + def fetch_user_about(self, username: str) -> dict | None: + """Fetch /user/{username}/about; returns None for deleted/invalid users.""" + username = (username or "").strip() + if not username or username in {"[deleted]", "AutoModerator"}: + return None + try: + payload = self.get(f"/user/{username}/about", params={"raw_json": 1}) + except requests.exceptions.HTTPError: + return None + data = payload.get("data") + return data if isinstance(data, dict) else None + + def fetch_comments_in_range( + self, + start_ts: int, + 
end_ts: int, + *, + subreddit: str = SUBREDDIT, + ) -> list[dict]: + """Paginate /r/{subreddit}/comments and keep items created in range.""" + comments: list[dict] = [] + after: str | None = None + + logger.info( + "Reddit: searching r/%s recent comments for %d..%d", + subreddit, + start_ts, + end_ts, + ) + + while True: + params: dict = {"limit": 100, "raw_json": 1} + if after: + params["after"] = after + + data = self.get(f"/r/{subreddit}/comments", params=params) + listing = data.get("data", {}) + children = listing.get("children", []) + + if not children: + break + + page_timestamps: list[int] = [] + for child in children: + if child.get("kind") != "t1": + continue + comment = child.get("data", {}) + created = int(comment.get("created_utc", 0)) + page_timestamps.append(created) + if start_ts <= created <= end_ts: + comments.append(comment) + + if page_timestamps and min(page_timestamps) < start_ts: + break + + after = listing.get("after") + if not after: + break + + logger.info("Reddit: fetched %d comments in range", len(comments)) + return comments + + def fetch_submissions_in_range( + self, + start_ts: int, + end_ts: int, + *, + subreddit: str = SUBREDDIT, + ) -> list[dict]: + """Paginate /r/{subreddit}/new and keep submissions created in range.""" + posts: dict[str, dict] = {} + after: str | None = None + + logger.info( + "Reddit: searching r/%s recent submissions for %d..%d", + subreddit, + start_ts, + end_ts, + ) + + while True: + params: dict = {"limit": 100, "raw_json": 1} + if after: + params["after"] = after + + data = self.get(f"/r/{subreddit}/new", params=params) + listing = data.get("data", {}) + children = listing.get("children", []) + + if not children: + break + + page_timestamps: list[int] = [] + for child in children: + if child.get("kind") != "t3": + continue + post = child.get("data", {}) + created = int(post.get("created_utc", 0)) + page_timestamps.append(created) + if start_ts <= created <= end_ts: + posts[post["id"]] = post + + if 
page_timestamps and min(page_timestamps) < start_ts: + break + + after = listing.get("after") + if not after: + break + + discovered = sorted(posts.values(), key=lambda post: int(post["created_utc"])) + logger.info("Submission discovery found %d posts in range", len(discovered)) + return discovered + def build_session() -> RedditSession: """Build a RedditSession from Django settings (REDDIT_*).""" - client_id = settings.REDDIT_CLIENT_ID - client_secret = settings.REDDIT_CLIENT_SECRET user_agent = settings.REDDIT_USER_AGENT + if not user_agent: + raise EnvironmentError("Missing required setting: REDDIT_USER_AGENT") + + client_id = _credentials_configured(settings.REDDIT_CLIENT_ID) + client_secret = _credentials_configured(settings.REDDIT_CLIENT_SECRET) + if client_id and client_secret: + logger.info("Using official Reddit API (client credentials)") + return RedditSession(client_id, client_secret, user_agent) + + bearer_raw = settings.REDDIT_BEARER_TOKEN + session_cookie = settings.REDDIT_SESSION_COOKIE + csrf_token = settings.REDDIT_CSRF_TOKEN + + if bearer_raw and not _is_bearer_expired(bearer_raw): + logger.warning("Using REDDIT_BEARER_TOKEN") + return RedditSession( + None, + None, + user_agent, + bearer_token=bearer_raw, + session_cookie=session_cookie or None, + csrf_token=csrf_token, + ) + + if session_cookie: + bearer_token = mint_bearer_from_session(session_cookie, user_agent, csrf_token) + return RedditSession( + None, + None, + user_agent, + bearer_token=bearer_token, + session_cookie=session_cookie, + csrf_token=csrf_token, + ) - if not all([client_id, client_secret, user_agent]): + if bearer_raw: raise EnvironmentError( - "Missing one or more required settings: " - "REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET, REDDIT_USER_AGENT" + "REDDIT_BEARER_TOKEN is expired — paste a fresh token_v2 or set " + "REDDIT_SESSION_COOKIE to auto-mint" ) - return RedditSession(client_id, client_secret, user_agent) + raise EnvironmentError( + "No Reddit credentials configured. 
Set REDDIT_CLIENT_ID + " + "REDDIT_CLIENT_SECRET, or REDDIT_BEARER_TOKEN, or REDDIT_SESSION_COOKIE" + ) diff --git a/reddit_activity_tracker/management/commands/run_reddit_activity_tracker.py b/reddit_activity_tracker/management/commands/run_reddit_activity_tracker.py index e8c43de2..5a4b64f1 100644 --- a/reddit_activity_tracker/management/commands/run_reddit_activity_tracker.py +++ b/reddit_activity_tracker/management/commands/run_reddit_activity_tracker.py @@ -3,42 +3,177 @@ from __future__ import annotations import logging +import time +from datetime import datetime, timedelta, timezone from typing import Any +from django.conf import settings + from core.collectors import AbstractCollector, BaseCollectorCommand from core.protocols import TrackerResult from core.tracker_result import GenericTrackerResult - -from reddit_activity_tracker.fetcher import build_session +from reddit_activity_tracker.fetcher import RedditSession, build_session +from reddit_activity_tracker.models import RedditSubmission +from reddit_activity_tracker.services import ( + get_latest_comment_created_utc, + get_latest_submission_created_utc, + resolve_submission_for_comment, + upsert_reddit_comment, + upsert_reddit_submission, +) +from reddit_activity_tracker.workspace import ( + write_comment_json, + write_submission_json, + write_user_json, +) logger = logging.getLogger(__name__) +DEFAULT_LOOKBACK_DAYS = 30 + + +def _parse_since(value: str | None) -> int | None: + if not value or not value.strip(): + return None + raw = value.strip() + try: + if "T" in raw or " " in raw: + dt = datetime.fromisoformat(raw.replace("Z", "+00:00")) + else: + dt = datetime.strptime(raw, "%Y-%m-%d") + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return int(dt.astimezone(timezone.utc).timestamp()) + except ValueError as exc: + raise ValueError( + f"Invalid --since value {value!r}; use YYYY-MM-DD or ISO datetime" + ) from exc + + +def _default_lookback_start() -> int: + lookback_days = getattr( + 
settings, "REDDIT_DEFAULT_LOOKBACK_DAYS", DEFAULT_LOOKBACK_DAYS + ) + start = datetime.now(timezone.utc) - timedelta(days=lookback_days) + return int(start.timestamp()) + + +def _resolve_submission_start_ts(options: dict[str, Any]) -> int: + since_override = _parse_since(options.get("since")) + if since_override is not None: + return since_override + + latest = get_latest_submission_created_utc() + if latest > 0: + return latest + + return _default_lookback_start() + + +def _resolve_comment_start_ts(options: dict[str, Any]) -> int: + since_override = _parse_since(options.get("since")) + if since_override is not None: + return since_override + + latest = get_latest_comment_created_utc() + if latest > 0: + return latest + + return _default_lookback_start() + + +def _record_user( + user, + seen_users: set[int], + counts: dict[str, int], +) -> None: + if user is None or user.pk in seen_users: + return + write_user_json(user) + seen_users.add(user.pk) + counts["users"] += 1 + class RedditActivityTrackerCollector(AbstractCollector): - """Collector stub — full fetch/upsert pipeline ships in PR2.""" + """Scrape r/cpp submissions and comments for the incremental time window.""" - def __init__(self, *, stdout: Any, style: Any) -> None: - self.stdout = stdout - self.style = style + def __init__(self, *, options: dict[str, Any]) -> None: + self.options = options + self._session: RedditSession | None = None + self._counts = { + "submissions": 0, + "comments": 0, + "users": 0, + } @property def name(self) -> str: return "reddit_activity_tracker" def validate_config(self) -> None: - build_session() + self._session = build_session() def collect(self) -> TrackerResult: - logger.info("run_reddit_activity_tracker: stub — fetch/upsert in PR2") - self.stdout.write( - self.style.SUCCESS("reddit_activity_tracker completed (stub)") + if self._session is None: + self._session = build_session() + + submission_start_ts = _resolve_submission_start_ts(self.options) + comment_start_ts = 
_resolve_comment_start_ts(self.options) + end_ts = int(time.time()) + logger.info( + "reddit_activity_tracker: submission window %d..%d, comment window %d..%d (UTC)", + submission_start_ts, + end_ts, + comment_start_ts, + end_ts, ) - logger.info("run_reddit_activity_tracker: finished successfully") - return GenericTrackerResult.ok() + + posts = self._session.fetch_submissions_in_range(submission_start_ts, end_ts) + comments_data = self._session.fetch_comments_in_range(comment_start_ts, end_ts) + + submissions_by_id: dict[str, RedditSubmission] = {} + seen_users: set[int] = set() + + for post in posts: + submission = upsert_reddit_submission(post, session=self._session) + write_submission_json(submission) + submissions_by_id[post["id"]] = submission + self._counts["submissions"] += 1 + _record_user(submission.user, seen_users, self._counts) + + for comment_data in comments_data: + submission = resolve_submission_for_comment( + comment_data, + submissions_by_id, + ) + comment = upsert_reddit_comment( + comment_data, + submission, + session=self._session, + ) + write_comment_json(comment) + self._counts["comments"] += 1 + _record_user(comment.user, seen_users, self._counts) + + logger.info( + "reddit_activity_tracker: finished submissions=%d comments=%d users=%d", + self._counts["submissions"], + self._counts["comments"], + self._counts["users"], + ) + return GenericTrackerResult.ok(**self._counts) class Command(BaseCollectorCommand): - help = "Run the reddit_activity_tracker collector (stub)." + help = "Run reddit_activity_tracker: scrape r/cpp submissions and comments." + + def add_arguments(self, parser) -> None: + parser.add_argument( + "--since", + type=str, + default=None, + help="Override start date (YYYY-MM-DD or ISO datetime). 
Default: latest DB timestamp.", + ) - def get_collector(self, **_options: Any) -> AbstractCollector: - return RedditActivityTrackerCollector(stdout=self.stdout, style=self.style) + def get_collector(self, **options: Any) -> AbstractCollector: + return RedditActivityTrackerCollector(options=options) diff --git a/reddit_activity_tracker/migrations/0002_replace_author_with_user_fk.py b/reddit_activity_tracker/migrations/0002_replace_author_with_user_fk.py new file mode 100644 index 00000000..eeb07d28 --- /dev/null +++ b/reddit_activity_tracker/migrations/0002_replace_author_with_user_fk.py @@ -0,0 +1,79 @@ +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("cppa_user_tracker", "0009_reddituser_alter_baseprofile_type"), + ("reddit_activity_tracker", "0001_initial"), + ] + + operations = [ + migrations.RenameField( + model_name="redditsubmission", + old_name="reddit_id", + new_name="reddit_submission_id", + ), + migrations.RenameField( + model_name="redditcomment", + old_name="reddit_id", + new_name="reddit_comment_id", + ), + migrations.AlterModelOptions( + name="redditsubmission", + options={ + "ordering": ["-created_utc", "reddit_submission_id"], + "verbose_name": "Reddit submission", + "verbose_name_plural": "Reddit submissions", + }, + ), + migrations.AlterModelOptions( + name="redditcomment", + options={ + "ordering": ["created_utc", "reddit_comment_id"], + "verbose_name": "Reddit comment", + "verbose_name_plural": "Reddit comments", + }, + ), + migrations.RemoveField( + model_name="redditsubmission", + name="author", + ), + migrations.RemoveField( + model_name="redditsubmission", + name="author_id", + ), + migrations.RemoveField( + model_name="redditcomment", + name="author", + ), + migrations.RemoveField( + model_name="redditcomment", + name="author_id", + ), + migrations.AddField( + model_name="redditsubmission", + name="user", + field=models.ForeignKey( + blank=True, + 
db_column="reddit_user_id", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="submissions", + to="cppa_user_tracker.reddituser", + ), + ), + migrations.AddField( + model_name="redditcomment", + name="user", + field=models.ForeignKey( + blank=True, + db_column="reddit_user_id", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="comments", + to="cppa_user_tracker.reddituser", + ), + ), + ] diff --git a/reddit_activity_tracker/models.py b/reddit_activity_tracker/models.py index 498b8fdf..5b50a4da 100644 --- a/reddit_activity_tracker/models.py +++ b/reddit_activity_tracker/models.py @@ -4,12 +4,18 @@ class RedditSubmission(models.Model): - """Reddit post (submission) from a subreddit; keyed by reddit_id (t3_* fullname).""" + """Reddit post (submission) from a subreddit; keyed by reddit_submission_id (t3_*).""" - reddit_id = models.CharField(max_length=20, unique=True, db_index=True) + reddit_submission_id = models.CharField(max_length=20, unique=True, db_index=True) subreddit = models.CharField(max_length=128, db_index=True) - author = models.CharField(max_length=255, blank=True) - author_id = models.CharField(max_length=64, blank=True) + user = models.ForeignKey( + "cppa_user_tracker.RedditUser", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="submissions", + db_column="reddit_user_id", + ) title = models.CharField(max_length=1024) selftext = models.TextField(blank=True) selftext_html = models.TextField(blank=True) @@ -21,26 +27,32 @@ class RedditSubmission(models.Model): fetched_at = models.DateTimeField(auto_now=True) class Meta: - ordering = ["-created_utc", "reddit_id"] + ordering = ["-created_utc", "reddit_submission_id"] verbose_name = "Reddit submission" verbose_name_plural = "Reddit submissions" def __str__(self) -> str: - return f"{self.reddit_id}: {self.title[:60]}" + return f"{self.reddit_submission_id}: {self.title[:60]}" class RedditComment(models.Model): - """Reddit comment on a 
submission; keyed by reddit_id (t1_* fullname).""" + """Reddit comment on a submission; keyed by reddit_comment_id (t1_*).""" - reddit_id = models.CharField(max_length=20, unique=True, db_index=True) + reddit_comment_id = models.CharField(max_length=20, unique=True, db_index=True) submission = models.ForeignKey( RedditSubmission, on_delete=models.CASCADE, related_name="comments", ) parent_id = models.CharField(max_length=20, blank=True) - author = models.CharField(max_length=255, blank=True) - author_id = models.CharField(max_length=64, blank=True) + user = models.ForeignKey( + "cppa_user_tracker.RedditUser", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="comments", + db_column="reddit_user_id", + ) body = models.TextField(blank=True) url = models.URLField(max_length=1024) score = models.IntegerField(default=0) @@ -48,9 +60,9 @@ class RedditComment(models.Model): fetched_at = models.DateTimeField(auto_now=True) class Meta: - ordering = ["created_utc", "reddit_id"] + ordering = ["created_utc", "reddit_comment_id"] verbose_name = "Reddit comment" verbose_name_plural = "Reddit comments" def __str__(self) -> str: - return f"{self.reddit_id} on {self.submission.reddit_id}" + return f"{self.reddit_comment_id} on {self.submission.reddit_submission_id}" diff --git a/reddit_activity_tracker/services.py b/reddit_activity_tracker/services.py index 6f55bbc8..8dcc5a72 100644 --- a/reddit_activity_tracker/services.py +++ b/reddit_activity_tracker/services.py @@ -1,5 +1,179 @@ -"""All creates/updates/deletes for reddit_activity_tracker models must go through this module. +""" +Service layer for reddit_activity_tracker. -See CONTRIBUTING.md (service layer). Call these functions from collectors and -management commands instead of using Model.objects.* directly outside tests. +All creates/updates/deletes for this app's models must go through functions in this +module. See CONTRIBUTING.md for the project-wide rule. 
""" + +from __future__ import annotations + +from typing import Any + +from django.db import transaction +from django.db.models import Max + +from cppa_user_tracker.services import resolve_reddit_user_from_author_data + +from reddit_activity_tracker.fetcher import SUBREDDIT, RedditSession +from reddit_activity_tracker.models import RedditComment, RedditSubmission + + +def submission_id_from_link_id(link_id: str) -> str | None: + link_id = (link_id or "").strip() + if link_id.startswith("t3_"): + return link_id[3:] or None + return link_id or None + + +@transaction.atomic +def get_or_create_submission_stub( + submission_id: str, + *, + subreddit: str = "cpp", +) -> RedditSubmission: + """Ensure a submission row exists for FK when only a comment link_id is known.""" + post_id = (submission_id or "").strip().removeprefix("t3_") + if not post_id: + raise ValueError("Submission id is required") + + reddit_submission_id = f"t3_{post_id}" + permalink = f"/r/{subreddit}/comments/{post_id}/" + submission, _created = RedditSubmission.objects.get_or_create( + reddit_submission_id=reddit_submission_id, + defaults={ + "subreddit": subreddit, + "user": None, + "title": "", + "selftext": "", + "selftext_html": "", + "url": f"https://www.reddit.com{permalink}", + "permalink": permalink, + "score": 0, + "num_comments": 0, + "created_utc": 0, + }, + ) + return submission + + +def resolve_submission_for_comment( + comment_data: dict, + submissions_by_id: dict[str, RedditSubmission], +) -> RedditSubmission: + """Return the submission row for a period comment, creating a stub if needed.""" + post_id = submission_id_from_link_id(comment_data.get("link_id", "")) + if not post_id: + raise ValueError("Comment link_id is required") + + submission = submissions_by_id.get(post_id) + if submission is not None: + return submission + + reddit_submission_id = f"t3_{post_id}" + existing = RedditSubmission.objects.filter( + reddit_submission_id=reddit_submission_id + ).first() + if existing is not 
None: + return existing + + return get_or_create_submission_stub( + post_id, + subreddit=(comment_data.get("subreddit") or SUBREDDIT).strip(), + ) + + +@transaction.atomic +def upsert_reddit_submission( + data: dict[str, Any], + *, + session: RedditSession | None = None, +) -> RedditSubmission: + """Update or create a submission keyed by reddit_submission_id.""" + post_id = (data.get("id") or "").strip() + if not post_id: + raise ValueError("Submission id is required") + + reddit_submission_id = ( + data.get("reddit_submission_id") or data.get("name") or f"t3_{post_id}" + ) + if not str(reddit_submission_id).startswith("t3_"): + reddit_submission_id = f"t3_{post_id}" + + user = resolve_reddit_user_from_author_data(data, client=session) + defaults = { + "subreddit": (data.get("subreddit") or "cpp").strip(), + "user": user, + "title": (data.get("title") or "")[:1024], + "selftext": data.get("selftext") or "", + "selftext_html": data.get("selftext_html") or "", + "url": data.get("url") or f"https://www.reddit.com{data.get('permalink', '')}", + "permalink": data.get("permalink") or "", + "score": int(data.get("score") or 0), + "num_comments": int(data.get("num_comments") or 0), + "created_utc": int(data.get("created_utc") or 0), + } + + submission, _created = RedditSubmission.objects.update_or_create( + reddit_submission_id=reddit_submission_id, + defaults=defaults, + ) + return submission + + +@transaction.atomic +def upsert_reddit_comment( + data: dict[str, Any], + submission: RedditSubmission, + *, + session: RedditSession | None = None, +) -> RedditComment: + """Update or create a comment keyed by reddit_comment_id.""" + comment_id = (data.get("id") or "").strip() + if not comment_id: + raise ValueError("Comment id is required") + + reddit_comment_id = ( + data.get("reddit_comment_id") or data.get("name") or f"t1_{comment_id}" + ) + if not str(reddit_comment_id).startswith("t1_"): + reddit_comment_id = f"t1_{comment_id}" + + user = 
resolve_reddit_user_from_author_data(data, client=session) + permalink = (data.get("permalink") or "").strip() + if permalink: + url = ( + permalink + if permalink.startswith("http") + else f"https://www.reddit.com{permalink}" + ) + else: + submission_permalink = submission.permalink.rstrip("/") + url = f"https://www.reddit.com{submission_permalink}/{comment_id}/" + + defaults = { + "submission": submission, + "user": user, + "parent_id": data.get("parent_id") or "", + "body": data.get("body") or "", + "url": url, + "score": int(data.get("score") or 0), + "created_utc": int(data.get("created_utc") or 0), + } + + comment, _created = RedditComment.objects.update_or_create( + reddit_comment_id=reddit_comment_id, + defaults=defaults, + ) + return comment + + +def get_latest_submission_created_utc() -> int: + """Return max created_utc across submissions, or 0 when empty.""" + latest = RedditSubmission.objects.aggregate(latest=Max("created_utc"))["latest"] + return latest if latest is not None else 0 + + +def get_latest_comment_created_utc() -> int: + """Return max created_utc across comments, or 0 when empty.""" + latest = RedditComment.objects.aggregate(latest=Max("created_utc"))["latest"] + return latest if latest is not None else 0 diff --git a/reddit_activity_tracker/tests/conftest.py b/reddit_activity_tracker/tests/conftest.py new file mode 100644 index 00000000..25a4ff08 --- /dev/null +++ b/reddit_activity_tracker/tests/conftest.py @@ -0,0 +1,64 @@ +"""Shared fixtures for reddit_activity_tracker tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from reddit_activity_tracker.fetcher import RedditSession + + +@pytest.fixture +def reddit_session() -> RedditSession: + return RedditSession("cid", "secret", "test/1.0") + + +@pytest.fixture +def sample_submission_payload() -> dict: + return { + "id": "7i8fbd", + "name": "t3_7i8fbd", + "subreddit": "cpp", + "author": "Taladar", + "author_fullname": "t2_abc123", + 
"title": "Winsock Chat server and client", + "selftext": "", + "selftext_html": "", + "url": "https://github.com/VedantParanjape/Chat-Server-and-Client", + "permalink": "/r/cpp/comments/7i8fbd/winsock_chat_server_and_client/", + "score": 0, + "num_comments": 8, + "created_utc": 1512670704, + } + + +@pytest.fixture +def sample_comment_payload() -> dict: + return { + "id": "1h3p", + "name": "t1_1h3p", + "author": "Taladar", + "author_fullname": "t2_abc123", + "parent_id": "t3_7ijpx", + "body": "I call bullshit.", + "score": 1, + "created_utc": 1229786861, + } + + +@pytest.fixture +def mock_user_about() -> dict: + return { + "id": "abc123", + "name": "Taladar", + "subreddit": {"title": "Taladar"}, + } + + +@pytest.fixture +def mock_reddit_session(mock_user_about: dict) -> MagicMock: + session = MagicMock(spec=RedditSession) + session.fetch_user_about.return_value = mock_user_about + session.get.return_value = {"data": {"children": []}} + return session diff --git a/reddit_activity_tracker/tests/test_collector_integration.py b/reddit_activity_tracker/tests/test_collector_integration.py new file mode 100644 index 00000000..11e97809 --- /dev/null +++ b/reddit_activity_tracker/tests/test_collector_integration.py @@ -0,0 +1,90 @@ +"""Integration tests for RedditActivityTrackerCollector.""" + +from io import StringIO +from unittest.mock import MagicMock, patch + +import pytest +from django.core.management import call_command +from django.test import override_settings + +from reddit_activity_tracker.management.commands.run_reddit_activity_tracker import ( + RedditActivityTrackerCollector, +) +from reddit_activity_tracker.models import RedditComment, RedditSubmission + + +@pytest.mark.django_db +@override_settings( + REDDIT_CLIENT_ID="cid", + REDDIT_CLIENT_SECRET="secret", + REDDIT_USER_AGENT="test/1.0", + WORKSPACE_DIR="/tmp/reddit_collector_test", +) +@patch( + "reddit_activity_tracker.management.commands.run_reddit_activity_tracker.build_session" +) +def 
test_collector_end_to_end(mock_build_session, tmp_path, settings): + settings.WORKSPACE_DIR = str(tmp_path) + session = MagicMock() + session.fetch_user_about.return_value = { + "id": "abc123", + "name": "Taladar", + "subreddit": {"title": "Taladar"}, + } + session.fetch_submissions_in_range.return_value = [ + { + "id": "7i8fbd", + "subreddit": "cpp", + "author": "Taladar", + "author_fullname": "t2_abc123", + "title": "Winsock Chat server and client", + "selftext": "", + "selftext_html": "", + "url": "https://github.com/example/repo", + "permalink": "/r/cpp/comments/7i8fbd/winsock/", + "score": 0, + "num_comments": 1, + "created_utc": 1512670704, + } + ] + session.fetch_comments_in_range.return_value = [ + { + "id": "1h3p", + "author": "Taladar", + "author_fullname": "t2_abc123", + "parent_id": "t3_7i8fbd", + "link_id": "t3_7i8fbd", + "body": "Nice post", + "score": 1, + "created_utc": 1512670800, + } + ] + mock_build_session.return_value = session + + collector = RedditActivityTrackerCollector(options={"since": "2017-12-01"}) + result = collector.run() + + assert result.success is True + assert result.counts["submissions"] == 1 + assert result.counts["comments"] == 1 + assert RedditSubmission.objects.filter(reddit_submission_id="t3_7i8fbd").exists() + assert RedditComment.objects.filter(reddit_comment_id="t1_1h3p").exists() + assert ( + tmp_path / "reddit_activity_tracker" / "submissions" / "t3_7i8fbd.json" + ).exists() + assert (tmp_path / "reddit_activity_tracker" / "comments" / "t1_1h3p.json").exists() + + +@pytest.mark.django_db +@patch( + "reddit_activity_tracker.management.commands.run_reddit_activity_tracker.build_session" +) +def test_run_command_success(mock_build_session): + session = MagicMock() + session.fetch_submissions_in_range.return_value = [] + session.fetch_comments_in_range.return_value = [] + mock_build_session.return_value = session + out = StringIO() + call_command("run_reddit_activity_tracker", stdout=out, verbosity=0) + assert "Collector 
finished" in out.getvalue() or out.getvalue() == "" + mock_build_session.assert_called() diff --git a/reddit_activity_tracker/tests/test_fetcher.py b/reddit_activity_tracker/tests/test_fetcher.py deleted file mode 100644 index d3eb56bd..00000000 --- a/reddit_activity_tracker/tests/test_fetcher.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Tests for reddit_activity_tracker.fetcher (RedditSession).""" - -import itertools -from unittest.mock import MagicMock, patch - -import pytest -import requests -from django.test import override_settings -from requests.exceptions import ConnectionError - -from reddit_activity_tracker import fetcher - -_REDDIT_SETTINGS = { - "REDDIT_CLIENT_ID": "cid", - "REDDIT_CLIENT_SECRET": "secret", - "REDDIT_USER_AGENT": "test/1.0", -} -_NO_THROTTLE = {"REDDIT_REQUEST_INTERVAL": 0.0} - - -def _token_response(expires_in: int = 3600) -> MagicMock: - resp = MagicMock() - resp.status_code = 200 - resp.json.return_value = {"access_token": "test-token", "expires_in": expires_in} - resp.raise_for_status = MagicMock() - return resp - - -def _api_response(status: int = 200, json_data: dict | None = None) -> MagicMock: - resp = MagicMock() - resp.status_code = status - resp.json.return_value = json_data if json_data is not None else {"ok": True} - resp.raise_for_status = MagicMock( - side_effect=( - requests.exceptions.HTTPError(response=resp) if status >= 400 else None - ) - ) - return resp - - -@override_settings(**_REDDIT_SETTINGS) -def test_build_session_success(): - session = fetcher.build_session() - assert isinstance(session, fetcher.RedditSession) - - -@override_settings( - REDDIT_CLIENT_ID="", - REDDIT_CLIENT_SECRET="", - REDDIT_USER_AGENT="", -) -def test_build_session_missing_env(): - with pytest.raises(EnvironmentError, match="REDDIT_CLIENT_ID"): - fetcher.build_session() - - -@override_settings(**_REDDIT_SETTINGS, **_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -@patch("reddit_activity_tracker.fetcher.time.sleep") -def 
test_token_fetch_success(mock_sleep, mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - session._session.post = MagicMock(return_value=_token_response()) - session._session.get = MagicMock( - return_value=_api_response(json_data={"data": {"children": []}}) - ) - - result = session.get("/r/cpp/new", params={"limit": 5}) - - assert result == {"data": {"children": []}} - session._session.post.assert_called_once() - assert "Bearer test-token" in session._session.headers["Authorization"] - - -@override_settings(**_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -def test_token_fetch_failure(mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - fail_resp = MagicMock() - fail_resp.raise_for_status.side_effect = requests.exceptions.HTTPError( - response=fail_resp - ) - session._session.post = MagicMock(return_value=fail_resp) - - with pytest.raises(requests.exceptions.HTTPError): - session.get("/r/cpp/new") - - -@override_settings(**_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -@patch("reddit_activity_tracker.fetcher.time.sleep") -def test_429_retry_and_backoff(mock_sleep, mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - session._session.post = MagicMock(return_value=_token_response()) - session._session.get = MagicMock( - side_effect=[ - _api_response(status=429), - _api_response(status=429), - _api_response(json_data={"ok": True}), - ] - ) - - result = session.get("/r/cpp/new") - - assert result == {"ok": True} - assert session._session.get.call_count == 3 - assert mock_sleep.call_count >= 2 - - -@override_settings(**_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -@patch("reddit_activity_tracker.fetcher.time.sleep") -def 
test_401_mid_run_token_refresh(mock_sleep, mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - session._session.post = MagicMock(return_value=_token_response()) - session._session.get = MagicMock( - side_effect=[ - _api_response(status=401), - _api_response(json_data={"refreshed": True}), - ] - ) - - result = session.get("/r/cpp/new") - - assert result == {"refreshed": True} - assert session._session.post.call_count == 2 - - -@override_settings(**_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -@patch("reddit_activity_tracker.fetcher.time.sleep") -def test_network_error_retry(mock_sleep, mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - session._session.post = MagicMock(return_value=_token_response()) - session._session.get = MagicMock( - side_effect=[ - ConnectionError("down"), - _api_response(json_data={"ok": True}), - ] - ) - - result = session.get("/r/cpp/new") - - assert result == {"ok": True} - assert session._session.get.call_count == 2 - - -@override_settings(**_NO_THROTTLE) -@patch("reddit_activity_tracker.fetcher.time.time") -@patch("reddit_activity_tracker.fetcher.time.sleep") -def test_all_retries_exhausted(mock_sleep, mock_time): - mock_time.side_effect = lambda: next(itertools.count(1000, 1)) - session = fetcher.RedditSession("cid", "secret", "ua/1.0") - session._session.post = MagicMock(return_value=_token_response()) - session._session.get = MagicMock(side_effect=ConnectionError("down")) - - with pytest.raises(ConnectionError): - session.get("/r/cpp/new") - - assert session._session.get.call_count == fetcher.MAX_RETRIES diff --git a/reddit_activity_tracker/tests/test_models.py b/reddit_activity_tracker/tests/test_models.py index 2a2b184f..620b2a00 100644 --- a/reddit_activity_tracker/tests/test_models.py +++ b/reddit_activity_tracker/tests/test_models.py 
@@ -4,14 +4,15 @@ from django.db import IntegrityError from model_bakery import baker +from cppa_user_tracker.models import RedditUser from reddit_activity_tracker.models import RedditComment, RedditSubmission @pytest.mark.django_db -def test_reddit_submission_reddit_id_unique(): +def test_reddit_submission_reddit_submission_id_unique(): baker.make( RedditSubmission, - reddit_id="t3_abc123", + reddit_submission_id="t3_abc123", subreddit="cpp", title="First", url="https://example.com", @@ -21,7 +22,7 @@ def test_reddit_submission_reddit_id_unique(): with pytest.raises(IntegrityError): baker.make( RedditSubmission, - reddit_id="t3_abc123", + reddit_submission_id="t3_abc123", subreddit="cpp", title="Duplicate", url="https://example.com/2", @@ -34,7 +35,7 @@ def test_reddit_submission_reddit_id_unique(): def test_reddit_comment_cascade_delete(): submission = baker.make( RedditSubmission, - reddit_id="t3_sub001", + reddit_submission_id="t3_sub001", subreddit="cpp", title="Post", url="https://example.com", @@ -43,37 +44,38 @@ def test_reddit_comment_cascade_delete(): ) baker.make( RedditComment, - reddit_id="t1_cmt001", + reddit_comment_id="t1_cmt001", submission=submission, parent_id="t3_sub001", url="https://www.reddit.com/r/cpp/comments/sub001/cmt001/", created_utc=1700000100, ) - assert RedditComment.objects.filter(reddit_id="t1_cmt001").exists() + assert RedditComment.objects.filter(reddit_comment_id="t1_cmt001").exists() submission.delete() - assert not RedditComment.objects.filter(reddit_id="t1_cmt001").exists() + assert not RedditComment.objects.filter(reddit_comment_id="t1_cmt001").exists() @pytest.mark.django_db -def test_reddit_submission_str(): +def test_reddit_submission_user_fk(): + user = baker.make(RedditUser, username="Taladar", display_name="Taladar") submission = baker.make( RedditSubmission, - reddit_id="t3_str001", + reddit_submission_id="t3_str001", subreddit="cpp", + user=user, title="Hello World", url="https://example.com", 
permalink="/r/cpp/comments/str001/", created_utc=1700000000, ) - assert "t3_str001" in str(submission) - assert "Hello World" in str(submission) + assert submission.user.username == "Taladar" @pytest.mark.django_db def test_reddit_comment_str(): submission = baker.make( RedditSubmission, - reddit_id="t3_sub002", + reddit_submission_id="t3_sub002", subreddit="cpp", title="Post", url="https://example.com", @@ -82,7 +84,7 @@ def test_reddit_comment_str(): ) comment = baker.make( RedditComment, - reddit_id="t1_cmt002", + reddit_comment_id="t1_cmt002", submission=submission, parent_id="t3_sub002", url="https://www.reddit.com/r/cpp/comments/sub002/cmt002/", diff --git a/reddit_activity_tracker/tests/test_run_reddit_activity_tracker_command.py b/reddit_activity_tracker/tests/test_run_reddit_activity_tracker_command.py index 72993d68..5e50d3a5 100644 --- a/reddit_activity_tracker/tests/test_run_reddit_activity_tracker_command.py +++ b/reddit_activity_tracker/tests/test_run_reddit_activity_tracker_command.py @@ -12,8 +12,12 @@ ) @pytest.mark.django_db def test_run_reddit_activity_tracker_writes_success(mock_build_session): - mock_build_session.return_value = MagicMock() + session = MagicMock() + session.fetch_submissions_in_range.return_value = [] + session.fetch_comments_in_range.return_value = [] + mock_build_session.return_value = session out = StringIO() call_command("run_reddit_activity_tracker", stdout=out, verbosity=0) - assert "completed" in out.getvalue().lower() mock_build_session.assert_called_once() + session.fetch_submissions_in_range.assert_called_once() + session.fetch_comments_in_range.assert_called_once() diff --git a/reddit_activity_tracker/tests/test_services.py b/reddit_activity_tracker/tests/test_services.py new file mode 100644 index 00000000..e4418028 --- /dev/null +++ b/reddit_activity_tracker/tests/test_services.py @@ -0,0 +1,111 @@ +"""Tests for reddit_activity_tracker.services.""" + +import pytest +from model_bakery import baker + +from 
# NOTE(review): this chunk opens mid-statement with
# "reddit_activity_tracker import services"; the leading "from" is presumed to
# sit immediately before the chunk -- confirm when applying. `pytest` and
# `baker` are used by the tests below but imported outside the visible chunk.
from reddit_activity_tracker import services
from reddit_activity_tracker.models import RedditComment, RedditSubmission


@pytest.mark.django_db
def test_upsert_reddit_submission_creates_and_updates(
    sample_submission_payload, mock_reddit_session
):
    """Upserting the same submission payload twice reuses the existing row."""
    submission = services.upsert_reddit_submission(
        sample_submission_payload,
        session=mock_reddit_session,
    )
    assert submission.reddit_submission_id == "t3_7i8fbd"
    assert submission.user.username == "Taladar"

    # Second pass with a changed score must update, not insert.
    sample_submission_payload["score"] = 5
    updated = services.upsert_reddit_submission(
        sample_submission_payload,
        session=mock_reddit_session,
    )
    assert updated.pk == submission.pk
    assert updated.score == 5


@pytest.mark.django_db
def test_upsert_reddit_comment_creates_and_updates(
    sample_submission_payload,
    sample_comment_payload,
    mock_reddit_session,
):
    """Upserting the same comment payload twice reuses the existing row."""
    submission = services.upsert_reddit_submission(
        sample_submission_payload,
        session=mock_reddit_session,
    )
    comment = services.upsert_reddit_comment(
        sample_comment_payload,
        submission,
        session=mock_reddit_session,
    )
    assert comment.reddit_comment_id == "t1_1h3p"
    assert comment.submission_id == submission.pk

    # Second pass with a changed score must update, not insert.
    sample_comment_payload["score"] = 9
    updated = services.upsert_reddit_comment(
        sample_comment_payload,
        submission,
        session=mock_reddit_session,
    )
    assert updated.pk == comment.pk
    assert updated.score == 9


@pytest.mark.django_db
def test_submission_id_from_link_id():
    """A "t3_" link id yields the bare id; an empty link id yields None."""
    assert services.submission_id_from_link_id("t3_7ijpx") == "7ijpx"
    assert services.submission_id_from_link_id("") is None


@pytest.mark.django_db
def test_resolve_submission_for_comment_uses_stub():
    """A comment whose submission is not supplied resolves to a title-less stub."""
    submission = services.resolve_submission_for_comment(
        {"link_id": "t3_oldpost", "subreddit": "cpp"},
        {},
    )
    assert submission.reddit_submission_id == "t3_oldpost"
    assert submission.title == ""


@pytest.mark.django_db
def test_get_or_create_submission_stub_creates_minimal_row():
    """The stub helper accepts bare and "t3_"-prefixed ids for the same row."""
    submission = services.get_or_create_submission_stub("7i8fbd")
    assert submission.reddit_submission_id == "t3_7i8fbd"
    assert submission.title == ""
    assert submission.created_utc == 0

    again = services.get_or_create_submission_stub("t3_7i8fbd")
    assert again.pk == submission.pk


@pytest.mark.django_db
def test_get_latest_submission_and_comment_created_utc_empty_db():
    """With no rows stored, both "latest created_utc" lookups return 0."""
    assert services.get_latest_submission_created_utc() == 0
    assert services.get_latest_comment_created_utc() == 0


@pytest.mark.django_db
def test_get_latest_submission_and_comment_created_utc_independent():
    """Submission and comment high-water marks are tracked separately."""
    # Reuse the instance baker returns instead of re-querying it by id.
    submission = baker.make(
        RedditSubmission,
        reddit_submission_id="t3_a",
        subreddit="cpp",
        title="A",
        url="https://example.com/a",
        permalink="/r/cpp/comments/a/",
        created_utc=100,
    )
    baker.make(
        RedditComment,
        reddit_comment_id="t1_b",
        submission=submission,
        parent_id="t3_a",
        url="https://example.com/b",
        created_utc=200,
    )
    assert services.get_latest_submission_created_utc() == 100
    assert services.get_latest_comment_created_utc() == 200


# Patch metadata carried over verbatim from the original line:
#   diff --git a/reddit_activity_tracker/tests/test_workspace.py
#   b/reddit_activity_tracker/tests/test_workspace.py new file mode 100644
#   index 00000000..900d3517 --- /dev/null
#   +++ b/reddit_activity_tracker/tests/test_workspace.py @@ -0,0 +1,88 @@
"""Tests for reddit_activity_tracker.workspace."""

import json

import pytest
from model_bakery import baker

from cppa_user_tracker.models import RedditUser
from reddit_activity_tracker.models import RedditComment, RedditSubmission
from reddit_activity_tracker.workspace import (
    get_comment_json_path,
    get_submission_json_path,
    get_user_json_path,
    write_comment_json,
    write_submission_json,
    write_user_json,
)

# WORKSPACE_DIR is redirected with the pytest-django `settings` fixture only.
# Stacking `@override_settings` on the same test is unwound at function
# return, i.e. before the fixture's teardown; the fixture then restores the
# snapshot it took *inside* the decorator, which can leave the decorator's
# value installed for later tests.


@pytest.mark.django_db
def test_write_user_json_creates_file(tmp_path, settings):
    """write_user_json writes the user's fields to the per-username path."""
    settings.WORKSPACE_DIR = str(tmp_path)
    user = baker.make(
        RedditUser,
        username="Taladar",
        reddit_user_id="t2_abc123",
        display_name="Taladar",
    )
    path = write_user_json(user)
    assert path == get_user_json_path("Taladar")
    payload = json.loads(path.read_text(encoding="utf-8"))
    assert payload["username"] == "Taladar"
    assert payload["reddit_user_id"] == "t2_abc123"


@pytest.mark.django_db
def test_write_submission_json_overwrites(tmp_path, settings):
    """A second write for the same submission replaces the earlier file."""
    settings.WORKSPACE_DIR = str(tmp_path)
    submission = baker.make(
        RedditSubmission,
        reddit_submission_id="t3_7i8fbd",
        subreddit="cpp",
        title="First",
        url="https://example.com",
        permalink="/r/cpp/comments/7i8fbd/",
        created_utc=1512670704,
        score=0,
    )
    write_submission_json(submission)
    submission.score = 10
    submission.save()
    path = write_submission_json(submission)
    payload = json.loads(path.read_text(encoding="utf-8"))
    assert payload["reddit_submission_id"] == "t3_7i8fbd"
    assert payload["score"] == 10
    assert path == get_submission_json_path("t3_7i8fbd")


@pytest.mark.django_db
def test_write_comment_json(tmp_path, settings):
    """write_comment_json records the comment id and its submission's id."""
    settings.WORKSPACE_DIR = str(tmp_path)
    submission = baker.make(
        RedditSubmission,
        reddit_submission_id="t3_sub",
        subreddit="cpp",
        title="Post",
        url="https://example.com",
        permalink="/r/cpp/comments/sub/",
        created_utc=100,
    )
    comment = baker.make(
        RedditComment,
        reddit_comment_id="t1_cmt",
        submission=submission,
        parent_id="t3_sub",
        body="hello",
        url="https://example.com/c",
        created_utc=200,
    )
    path = write_comment_json(comment)
    payload = json.loads(path.read_text(encoding="utf-8"))
    assert payload["reddit_comment_id"] == "t1_cmt"
    assert payload["submission_id"] == "t3_sub"
    assert path == get_comment_json_path("t1_cmt")


# Patch metadata carried over verbatim from the original line (the header
# continues on the following line):
#   diff --git a/reddit_activity_tracker/workspace.py
# Patch metadata carried over verbatim from the original line:
#   b/reddit_activity_tracker/workspace.py new file mode 100644
#   index 00000000..c0341fbd --- /dev/null
#   +++ b/reddit_activity_tracker/workspace.py @@ -0,0 +1,136 @@
"""
Workspace paths and JSON writers for reddit_activity_tracker.

Layout:
    workspace/reddit_activity_tracker/users/{username}.json
    workspace/reddit_activity_tracker/submissions/{reddit_submission_id}.json
    workspace/reddit_activity_tracker/comments/{reddit_comment_id}.json
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path

from config.workspace import get_workspace_path
from core.operations.file_ops import sanitize_filename
from cppa_user_tracker.models import RedditUser
from reddit_activity_tracker.models import RedditComment, RedditSubmission

_APP_SLUG = "reddit_activity_tracker"


def _iso(dt: datetime | None) -> str | None:
    """Render *dt* as a UTC ISO-8601 string ending in "Z"; None stays None.

    A naive datetime is taken to already be in UTC.
    """
    if dt is None:
        return None
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    stamp = aware.astimezone(timezone.utc).isoformat()
    return stamp.replace("+00:00", "Z")


def get_workspace_root() -> Path:
    """Return this app's workspace directory as given by get_workspace_path."""
    return get_workspace_path(_APP_SLUG)


def _ensure_subdir(name: str) -> Path:
    """Return the *name* directory under the workspace root, creating it."""
    target = get_workspace_root() / name
    target.mkdir(parents=True, exist_ok=True)
    return target


def _users_dir() -> Path:
    """Return the "users" directory, creating it if it does not exist."""
    return _ensure_subdir("users")


def _submissions_dir() -> Path:
    """Return the "submissions" directory, creating it if it does not exist."""
    return _ensure_subdir("submissions")


def _comments_dir() -> Path:
    """Return the "comments" directory, creating it if it does not exist."""
    return _ensure_subdir("comments")


def _slug(value: str) -> str:
    """Turn *value* into a file-name stem, falling back to "unknown".

    The value is whitespace-trimmed, passed through sanitize_filename, and
    then stripped of leading/trailing underscores.
    """
    # NOTE(review): if sanitize_filename keeps underscores, values that differ
    # only by leading/trailing "_" end up with the same stem -- confirm that
    # is acceptable for usernames.
    trimmed = (value or "").strip()
    stem = sanitize_filename(trimmed).strip("_")
    return stem if stem else "unknown"


def get_user_json_path(username: str) -> Path:
    """Return the JSON path for *username* (the users directory is created)."""
    directory = _users_dir()
    return directory / f"{_slug(username)}.json"


def get_submission_json_path(reddit_submission_id: str) -> Path:
    """Return the JSON path for a submission id (its directory is created)."""
    directory = _submissions_dir()
    return directory / f"{_slug(reddit_submission_id)}.json"


def get_comment_json_path(reddit_comment_id: str) -> Path:
    """Return the JSON path for a comment id (its directory is created)."""
    directory = _comments_dir()
    return directory / f"{_slug(reddit_comment_id)}.json"


def _write_json(path: Path, payload: dict) -> Path:
    """Write *payload* to *path* as indented UTF-8 JSON and return *path*.

    Parent directories are created as needed, non-ASCII characters are kept
    as-is, and the file ends with a single newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    path.write_text(f"{serialized}\n", encoding="utf-8")
    return path


def user_to_dict(user: RedditUser) -> dict:
    """Build the JSON-ready mapping for a Reddit user."""
    created = _iso(user.created_at)
    updated = _iso(user.updated_at)
    return {
        "reddit_user_id": user.reddit_user_id,
        "username": user.username,
        "display_name": user.display_name,
        "created_at": created,
        "updated_at": updated,
    }


def submission_to_dict(submission: RedditSubmission) -> dict:
    """Build the JSON-ready mapping for a submission.

    "user" holds the author's username, or None when no user is linked.
    """
    author = submission.user.username if submission.user_id else None
    return {
        "reddit_submission_id": submission.reddit_submission_id,
        "subreddit": submission.subreddit,
        "user": author,
        "title": submission.title,
        "selftext": submission.selftext,
        "selftext_html": submission.selftext_html,
        "url": submission.url,
        "permalink": submission.permalink,
        "score": submission.score,
        "num_comments": submission.num_comments,
        "created_utc": submission.created_utc,
        "fetched_at": _iso(submission.fetched_at),
    }


def comment_to_dict(comment: RedditComment) -> dict:
    """Build the JSON-ready mapping for a comment.

    "submission_id" is the parent submission's reddit_submission_id, and
    "user" is the author's username, or None when no user is linked.
    """
    parent_submission_id = comment.submission.reddit_submission_id
    author = comment.user.username if comment.user_id else None
    return {
        "reddit_comment_id": comment.reddit_comment_id,
        "submission_id": parent_submission_id,
        "user": author,
        "parent_id": comment.parent_id,
        "body": comment.body,
        "url": comment.url,
        "score": comment.score,
        "created_utc": comment.created_utc,
        "fetched_at": _iso(comment.fetched_at),
    }


def write_user_json(user: RedditUser) -> Path:
    """Write *user* to its per-username JSON file and return the path."""
    destination = get_user_json_path(user.username)
    return _write_json(destination, user_to_dict(user))


def write_submission_json(submission: RedditSubmission) -> Path:
    """Write *submission* to its per-id JSON file and return the path."""
    destination = get_submission_json_path(submission.reddit_submission_id)
    return _write_json(destination, submission_to_dict(submission))


def write_comment_json(comment: RedditComment) -> Path:
    """Write *comment* to its per-id JSON file and return the path."""
    destination = get_comment_json_path(comment.reddit_comment_id)
    return _write_json(destination, comment_to_dict(comment))