Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/s3_log_extraction/ip_utils/_ip_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def load_ip_cache(
*,
cache_type: typing.Literal["ip_to_region", "ip_not_in_services", "region_codes_to_coordinates"],
cache_type: typing.Literal["ip_to_region", "region_codes_to_coordinates"],
cache_directory: str | pathlib.Path | None = None,
) -> dict[str, str]:
"""Load the IP cache from the cache directory."""
Expand Down
81 changes: 34 additions & 47 deletions src/s3_log_extraction/ip_utils/_update_ip_to_region_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import pathlib
import random
import typing
import warnings

import tqdm
import yaml
Expand All @@ -17,7 +19,7 @@ def update_ip_to_region_codes(
batch_size: int = 1_000,
batch_limit: int | None = None,
cache_directory: str | pathlib.Path | None = None,
) -> str | None:
) -> None:
"""
Update the ``ip_to_region.yaml`` file in the cache directory.

Expand Down Expand Up @@ -58,8 +60,8 @@ def update_ip_to_region_codes(
all_ips.update(_read_ips_from_file(file_path=full_ips_file))

ip_to_region = load_ip_cache(cache_type="ip_to_region", cache_directory=cache_directory)
ip_not_in_services = load_ip_cache(cache_type="ip_not_in_services", cache_directory=cache_directory)
ips_to_update = list(all_ips - set(ip_to_region.keys()))
ip_to_determined_region = {ip: region for ip, region in ip_to_region.items() if region != "undetermined"}
ips_to_update = list(all_ips - set(ip_to_determined_region.keys()))

# If a batch limit is set, shuffle the IPs to ensure repeated runs update different IPs
if batch_limit is not None:
Expand Down Expand Up @@ -88,61 +90,41 @@ def update_ip_to_region_codes(
position=1,
leave=False,
):
region_code = _get_region_code_from_ip_address(
ip_address=ip_address,
ipinfo_handler=ipinfo_handler,
ip_not_in_services=ip_not_in_services,
)

if region_code is None:
continue

# API limit reached; do not cache and wait for it to reset
if region_code == "unknown":
continue
region_code = _get_region_code_from_ip_address(ip_address=ip_address, ipinfo_handler=ipinfo_handler)
ip_to_region[ip_address] = region_code

with ip_to_region_file_path.open(mode="w") as file_stream:
yaml.dump(data=ip_to_region, stream=file_stream)

ip_not_in_services_file_path = ip_cache_directory / "ip_not_in_services.yaml"
with ip_not_in_services_file_path.open(mode="w") as file_stream:
yaml.dump(data=ip_not_in_services, stream=file_stream)


def _get_region_code_from_ip_address(
ip_address: str, ipinfo_handler: "ipinfo.Handler", ip_not_in_services: dict[str, bool]
) -> str | None:
ip_address: str,
ipinfo_handler: "ipinfo.Handler",
) -> str | typing.Literal["undetermined", "bogon"]:
import ipinfo

# Determine if IP address belongs to GitHub, AWS, Google, or known VPNs
# Determine if the IP address belongs to GitHub, AWS, Google, or known VPNs
# Azure is not yet easily supported; keep an eye on
# https://learn.microsoft.com/en-us/answers/questions/1410071/up-to-date-azure-public-api-to-get-azure-ip-ranges
# in case that changes in the future
if ip_address not in ip_not_in_services:
for service_name in _KNOWN_SERVICES:
cidr_addresses_and_subregions = _get_cidr_address_ranges_and_subregions(service_name=service_name)

matched_cidr_address_and_subregion = next(
(
(cidr_address, subregion)
for cidr_address, subregion in cidr_addresses_and_subregions
if _ip_in_cidr(ip_address=ip_address, cidr_address=cidr_address)
),
None,
)
if matched_cidr_address_and_subregion is not None:
region_service_string = service_name

subregion = matched_cidr_address_and_subregion[1]
if subregion is not None:
region_service_string += f"/{subregion}"

ip_not_in_services[ip_address] = False
return region_service_string

# TODO: make `ip_not_in_services` a `set`
ip_not_in_services[ip_address] = True
for service_name in _KNOWN_SERVICES:
cidr_addresses_and_subregions = _get_cidr_address_ranges_and_subregions(service_name=service_name)

matched_cidr_address_and_subregion = next(
(
(cidr_address, subregion)
for cidr_address, subregion in cidr_addresses_and_subregions
if _ip_in_cidr(ip_address=ip_address, cidr_address=cidr_address)
),
None,
)
if matched_cidr_address_and_subregion is not None:
region_service_string = service_name

subregion = matched_cidr_address_and_subregion[1]
if subregion is not None:
region_service_string += f"/{subregion}"
return region_service_string

# TODO: add batching support to ipinfo requests
# Lines cannot be covered without testing on a real IP
Expand All @@ -165,4 +147,9 @@ def _get_region_code_from_ip_address(

return region_string
except ipinfo.exceptions.RequestQuotaExceededError: # pragma: no cover
return "unknown"
warnings.warn(
msg="IPInfo API request quota exceeded. Returning 'undetermined' value.",
category=RuntimeWarning,
stacklevel=2,
)
return "undetermined"
Loading