Skip to content
Open
31 changes: 7 additions & 24 deletions app/ChirpHeliumJoinRpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import psycopg2
import psycopg2.extras
import redis.asyncio as redis
import grpc
from google.protobuf.json_format import MessageToJson, MessageToDict
Expand All @@ -23,34 +21,19 @@ class ChirpstackJoins:
def __init__(
self,
route_id: str,
postgres_host: str,
postgres_user: str,
postgres_pass: str,
postgres_name: str,
postgres_port: str,
postgres_ssl_mode: str,
pool,
chirpstack_host: str,
chirpstack_token: str,
):
self.route_id = route_id
self.pg_host = postgres_host
self.pg_user = postgres_user
self.pg_pass = postgres_pass
self.pg_name = postgres_name
self.pg_port = postgres_port
self.pg_ssl_mode = postgres_ssl_mode
conn_str = f"postgresql://{self.pg_user}:{self.pg_pass}@{self.pg_host}:{self.pg_port}/{self.pg_name}"
if self.pg_ssl_mode[0] != "require":
self.postgres = conn_str
else:
self.postgres = "%s?sslmode=%s" % (conn_str, self.pg_ssl_mode)
self.pool = pool
self.cs_grpc = chirpstack_host
self.auth_token = [("authorization", f"Bearer {chirpstack_token}")]

def db_transaction(self, query: str):
with psycopg2.connect(self.postgres) as con:
with con.cursor() as cur:
cur.execute(query)
async def db_transaction(self, query: str):
async with self.pool.acquire() as con:
async with con.transaction():
await con.execute(query)

###########################################################################
# follow internal redis stream gRPC for actionable changes
Expand Down Expand Up @@ -160,4 +143,4 @@ async def add_session_key(self, dev_eui):
devices["fCntUp"],
devices["nFCntDown"],
)
self.db_transaction(query)
await self.db_transaction(query)
127 changes: 25 additions & 102 deletions app/ChirpHeliumKeysRpc.py
Original file line number Diff line number Diff line change
@@ -1,144 +1,79 @@
from functools import wraps
# import asyncpg
import psycopg2
import psycopg2.extras
import grpc
from google.protobuf.json_format import MessageToDict
from chirpstack_api import api
import logging

from ChirpHeliumCrypto import get_route_skfs, update_device_skfs
from protos.helium import iot_config
from DatabasePool import Database


# def my_logger(orig_func):
# logging.basicConfig(
# filename='chirpstack-hpr.log',
# filemode='a',
# format='%(asctime)s %(levelname)s:%(name)s:%(message)s',
# level=logging.INFO,
# datefmt='%Y-%m-%d %H:%M:%S',
# )
# logging.getLogger("asyncio").setLevel(logging.INFO)
#
# @wraps(orig_func)
# def wrapper(*args, **kwargs):
# logging.info(
# f'Passed args: {args}, kwargs: {kwargs}')
# return orig_func(*args, **kwargs)
# return wrapper


class ChirpDeviceKeys:
def __init__(
self,
route_id: str,
postgres_host: str,
postgres_user: str,
postgres_pass: str,
postgres_name: str,
postgres_port: str,
postgres_ssl_mode: str,
pool,
chirpstack_host: str,
chirpstack_token: str,
):
self.route_id = route_id
self.pg_host = postgres_host
self.pg_user = postgres_user
self.pg_pass = postgres_pass
self.pg_name = postgres_name
self.pg_port = postgres_port
self.pg_ssl_mode = postgres_ssl_mode
conn_str = f"postgresql://{self.pg_user}:{self.pg_pass}@{self.pg_host}:{self.pg_port}/{self.pg_name}"
if self.pg_ssl_mode[0] != "require":
self.postgres = conn_str
else:
self.postgres = "%s?sslmode=%s" % (conn_str, self.pg_ssl_mode)
self.pool = pool
self.cs_gprc = chirpstack_host
self.auth_token = [("authorization", f"Bearer {chirpstack_token}")]

def db_fetch(self, query: str):
with psycopg2.connect(self.postgres) as con:
with con.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(query)
return cur.fetchall()

async def async_db_fetch(self, query: str):
db = Database()
await db.connect()
assert db.pool

async with db.pool.acquire() as conn:
async def db_fetch(self, query: str):
async with self.pool.acquire() as conn:
async with conn.transaction():
cur = await conn.fetch(query)
await db.close()
return cur

def db_transaction(self, query: str):
with psycopg2.connect(self.postgres) as con:
with con.cursor() as cur:
cur.execute(query)
return await conn.fetch(query)

async def async_db_transaction(self, query: str):
db = Database()
await db.connect()
assert db.pool
async with db.pool.acquire() as conn:
async def db_transaction(self, query: str):
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(query)

def fetch_all_devices(self) -> list[str]:
with psycopg2.connect(self.postgres) as con:
with con.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"""
SELECT encode(dev_eui, 'hex') AS dev_eui
FROM device
WHERE is_disabled=false;
"""
)
return [dev['dev_eui'] for dev in cur.fetchall()]
async def fetch_all_devices(self) -> list[str]:
query = """
SELECT encode(dev_eui, 'hex') AS dev_eui
FROM device
WHERE is_disabled=false;
"""
return [dev['dev_eui'] for dev in await self.db_fetch(query)]

def chunker(self, seq, size):
return (seq[pos:pos + size] for pos in range(0, len(seq), size))

###########################################################################
# Chirpstack gRPC API calls
###########################################################################
# @my_logger
def get_device(self, dev_eui: str) -> dict[str]:
with grpc.insecure_channel(self.cs_gprc) as channel:
async def get_device(self, dev_eui: str) -> dict[str]:
async with grpc.aio.insecure_channel(self.cs_gprc) as channel:
client = api.DeviceServiceStub(channel)
req = api.GetDeviceRequest()
req.dev_eui = dev_eui
resp = client.Get(req, metadata=self.auth_token)
resp = await client.Get(req, metadata=self.auth_token)
data = MessageToDict(resp)["device"]
return data

# @my_logger
def get_device_activation(self, dev_eui: str) -> dict[str]:
with grpc.insecure_channel(self.cs_gprc) as channel:
async def get_device_activation(self, dev_eui: str) -> dict[str]:
async with grpc.aio.insecure_channel(self.cs_gprc) as channel:
client = api.DeviceServiceStub(channel)
req = api.GetDeviceActivationRequest()
req.dev_eui = dev_eui
resp = client.GetActivation(req, metadata=self.auth_token)
resp = await client.GetActivation(req, metadata=self.auth_token)
data = MessageToDict(resp)
if bool(data):
return data["deviceActivation"]
return data

# @my_logger
def get_merged_keys(self, dev_eui: str) -> dict[str]:
async def get_merged_keys(self, dev_eui: str) -> str:
devices = {
"devAddr": "",
"appSKey": "",
"nwkSEncKey": "",
"name": "",
}

devices.update(self.get_device(dev_eui))
devices.update(self.get_device_activation(dev_eui))
devices.update(await self.get_device(dev_eui))
devices.update(await self.get_device_activation(dev_eui))

max_copies = 0
if devices.get("variables") and "max_copies" in devices.get("variables"):
Expand Down Expand Up @@ -172,11 +107,9 @@ def get_merged_keys(self, dev_eui: str) -> dict[str]:
devices["fCntUp"],
devices["nFCntDown"],
)
self.db_transaction(query)
# await self.async_db_transaction(query)
await self.db_transaction(query)
return f"Updated: {dev_eui}"

# @my_logger
async def helium_skfs_update(self):
"""
TODO:
Expand All @@ -188,15 +121,10 @@ async def helium_skfs_update(self):
WHERE is_disabled=false
AND dev_addr != '';
"""
# all_helium_devices = self.db_fetch(helium_devices)
all_helium_devices = await self.async_db_fetch(helium_devices)
# logging.info(f"All Helium Devices: {all_helium_devices}")
all_helium_devices = await self.db_fetch(helium_devices)

skfs_list = await get_route_skfs()

# logging.info(f"All Helium Devices: {all_helium_devices}")
# logging.info(f"SKFS List: {skfs_list}")

# Convert the lists to sets for efficient set operations
# compare dev_addr & session_key for match else remove
all_helium_sessions_set = {
Expand All @@ -222,9 +150,6 @@ async def helium_skfs_update(self):
for d in skfs_list
}

# logging.info(f"All Helium Devices Set: {all_helium_devices_set}")
# logging.info(f"SKFS List Set: {skfs_list_set}")

# Devices to add to skfs_list
devices_to_add = all_helium_devices_set - skfs_list_set
logging.info(f"Devices_to_add: {devices_to_add}")
Expand All @@ -241,7 +166,6 @@ async def helium_skfs_update(self):
action=iot_config.ActionV1(1)
) for dev_addr, nws_key in devices_to_remove
]
# logging.info(f'RM-SKFS: {rm_skfs}')

if devices_to_add:
add_skfs = [
Expand All @@ -252,7 +176,6 @@ async def helium_skfs_update(self):
max_copies=max_copies
) for dev_addr, nws_key, max_copies in devices_to_add
]
# logging.info(f'ADD-SKFS: {add_skfs}')

skfs_action = rm_skfs + add_skfs

Expand Down
Loading