Skip to content

Commit 16e211d

Browse files
committed
feat: #1276 add Asyncio SQLAlchemy support
1 parent 86169db commit 16e211d

File tree

5 files changed

+680
-1
lines changed

5 files changed

+680
-1
lines changed

requirements/testing.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ boto3<=2
1717
# For AWS tests
1818
moto>=4.0.13,<6
1919
mypy<=1.14.1
20+
# For AsyncSQLAlchemy tests
21+
greenlet<=4
22+
aiosqlite<=1

slack_sdk/oauth/installation_store/sqlalchemy/__init__.py

Lines changed: 280 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
)
1717
from sqlalchemy.engine import Engine
1818
from sqlalchemy.sql.sqltypes import Boolean
19-
19+
from sqlalchemy.ext.asyncio import AsyncEngine
2020
from slack_sdk.oauth.installation_store.installation_store import InstallationStore
2121
from slack_sdk.oauth.installation_store.models.bot import Bot
2222
from slack_sdk.oauth.installation_store.models.installation import Installation
23+
from slack_sdk.oauth.installation_store.async_installation_store import (
24+
AsyncInstallationStore,
25+
)
2326

2427

2528
class SQLAlchemyInstallationStore(InstallationStore):
@@ -362,3 +365,279 @@ def delete_installation(
362365
)
363366
)
364367
conn.execute(deletion)
368+
369+
370+
class AsyncSQLAlchemyInstallationStore(AsyncInstallationStore):
371+
default_bots_table_name: str = "slack_bots"
372+
default_installations_table_name: str = "slack_installations"
373+
374+
client_id: str
375+
engine: AsyncEngine
376+
metadata: MetaData
377+
installations: Table
378+
379+
def __init__(
380+
self,
381+
client_id: str,
382+
engine: AsyncEngine,
383+
bots_table_name: str = default_bots_table_name,
384+
installations_table_name: str = default_installations_table_name,
385+
logger: Logger = logging.getLogger(__name__),
386+
):
387+
self.metadata = sqlalchemy.MetaData()
388+
self.bots = self.build_bots_table(metadata=self.metadata, table_name=bots_table_name)
389+
self.installations = self.build_installations_table(metadata=self.metadata, table_name=installations_table_name)
390+
self.client_id = client_id
391+
self._logger = logger
392+
self.engine = engine
393+
394+
@classmethod
395+
def build_installations_table(cls, metadata: MetaData, table_name: str) -> Table:
396+
return SQLAlchemyInstallationStore.build_installations_table(metadata, table_name)
397+
398+
@classmethod
399+
def build_bots_table(cls, metadata: MetaData, table_name: str) -> Table:
400+
return SQLAlchemyInstallationStore.build_bots_table(metadata, table_name)
401+
402+
async def create_tables(self):
403+
async with self.engine.begin() as conn:
404+
await conn.run_sync(self.metadata.create_all)
405+
406+
@property
407+
def logger(self) -> Logger:
408+
return self._logger
409+
410+
async def async_save(self, installation: Installation):
411+
async with self.engine.begin() as conn:
412+
i = installation.to_dict()
413+
i["client_id"] = self.client_id
414+
415+
i_column = self.installations.c
416+
installations_rows = await conn.execute(
417+
sqlalchemy.select(i_column.id)
418+
.where(
419+
and_(
420+
i_column.client_id == self.client_id,
421+
i_column.enterprise_id == installation.enterprise_id,
422+
i_column.team_id == installation.team_id,
423+
i_column.installed_at == i.get("installed_at"),
424+
)
425+
)
426+
.limit(1)
427+
)
428+
installations_row_id: Optional[str] = None
429+
for row in installations_rows.mappings():
430+
installations_row_id = row["id"]
431+
if installations_row_id is None:
432+
await conn.execute(self.installations.insert(), i)
433+
else:
434+
update_statement = self.installations.update().where(i_column.id == installations_row_id).values(**i)
435+
await conn.execute(update_statement, i)
436+
437+
# bots
438+
await self.async_save_bot(installation.to_bot())
439+
440+
async def async_save_bot(self, bot: Bot):
441+
async with self.engine.begin() as conn:
442+
# bots
443+
b = bot.to_dict()
444+
b["client_id"] = self.client_id
445+
446+
b_column = self.bots.c
447+
bots_rows = await conn.execute(
448+
sqlalchemy.select(b_column.id)
449+
.where(
450+
and_(
451+
b_column.client_id == self.client_id,
452+
b_column.enterprise_id == bot.enterprise_id,
453+
b_column.team_id == bot.team_id,
454+
b_column.installed_at == b.get("installed_at"),
455+
)
456+
)
457+
.limit(1)
458+
)
459+
bots_row_id: Optional[str] = None
460+
for row in bots_rows.mappings():
461+
bots_row_id = row["id"]
462+
if bots_row_id is None:
463+
await conn.execute(self.bots.insert(), b)
464+
else:
465+
update_statement = self.bots.update().where(b_column.id == bots_row_id).values(**b)
466+
await conn.execute(update_statement, b)
467+
468+
async def async_find_bot(
469+
self,
470+
*,
471+
enterprise_id: Optional[str],
472+
team_id: Optional[str],
473+
is_enterprise_install: Optional[bool] = False,
474+
) -> Optional[Bot]:
475+
if is_enterprise_install or team_id is None:
476+
team_id = None
477+
478+
c = self.bots.c
479+
query = (
480+
self.bots.select()
481+
.where(
482+
and_(
483+
c.client_id == self.client_id,
484+
c.enterprise_id == enterprise_id,
485+
c.team_id == team_id,
486+
c.bot_token.is_not(None), # the latest one that has a bot token
487+
)
488+
)
489+
.order_by(desc(c.installed_at))
490+
.limit(1)
491+
)
492+
493+
async with self.engine.connect() as conn:
494+
result: object = await conn.execute(query)
495+
for row in result.mappings(): # type: ignore[attr-defined]
496+
return Bot(
497+
app_id=row["app_id"],
498+
enterprise_id=row["enterprise_id"],
499+
enterprise_name=row["enterprise_name"],
500+
team_id=row["team_id"],
501+
team_name=row["team_name"],
502+
bot_token=row["bot_token"],
503+
bot_id=row["bot_id"],
504+
bot_user_id=row["bot_user_id"],
505+
bot_scopes=row["bot_scopes"],
506+
bot_refresh_token=row["bot_refresh_token"],
507+
bot_token_expires_at=row["bot_token_expires_at"],
508+
is_enterprise_install=row["is_enterprise_install"],
509+
installed_at=row["installed_at"],
510+
)
511+
return None
512+
513+
async def async_find_installation(
514+
self,
515+
*,
516+
enterprise_id: Optional[str],
517+
team_id: Optional[str],
518+
user_id: Optional[str] = None,
519+
is_enterprise_install: Optional[bool] = False,
520+
) -> Optional[Installation]:
521+
if is_enterprise_install or team_id is None:
522+
team_id = None
523+
524+
c = self.installations.c
525+
where_clause = and_(
526+
c.client_id == self.client_id,
527+
c.enterprise_id == enterprise_id,
528+
c.team_id == team_id,
529+
)
530+
if user_id is not None:
531+
where_clause = and_(
532+
c.client_id == self.client_id,
533+
c.enterprise_id == enterprise_id,
534+
c.team_id == team_id,
535+
c.user_id == user_id,
536+
)
537+
538+
query = self.installations.select().where(where_clause).order_by(desc(c.installed_at)).limit(1)
539+
540+
installation: Optional[Installation] = None
541+
async with self.engine.connect() as conn:
542+
result: object = await conn.execute(query)
543+
for row in result.mappings(): # type: ignore[attr-defined]
544+
installation = Installation(
545+
app_id=row["app_id"],
546+
enterprise_id=row["enterprise_id"],
547+
enterprise_name=row["enterprise_name"],
548+
enterprise_url=row["enterprise_url"],
549+
team_id=row["team_id"],
550+
team_name=row["team_name"],
551+
bot_token=row["bot_token"],
552+
bot_id=row["bot_id"],
553+
bot_user_id=row["bot_user_id"],
554+
bot_scopes=row["bot_scopes"],
555+
bot_refresh_token=row["bot_refresh_token"],
556+
bot_token_expires_at=row["bot_token_expires_at"],
557+
user_id=row["user_id"],
558+
user_token=row["user_token"],
559+
user_scopes=row["user_scopes"],
560+
user_refresh_token=row["user_refresh_token"],
561+
user_token_expires_at=row["user_token_expires_at"],
562+
# Only the incoming webhook issued in the latest installation is set in this logic
563+
incoming_webhook_url=row["incoming_webhook_url"],
564+
incoming_webhook_channel=row["incoming_webhook_channel"],
565+
incoming_webhook_channel_id=row["incoming_webhook_channel_id"],
566+
incoming_webhook_configuration_url=row["incoming_webhook_configuration_url"],
567+
is_enterprise_install=row["is_enterprise_install"],
568+
token_type=row["token_type"],
569+
installed_at=row["installed_at"],
570+
)
571+
572+
has_user_installation = user_id is not None and installation is not None
573+
no_bot_token_installation = installation is not None and installation.bot_token is None
574+
should_find_bot_installation = has_user_installation or no_bot_token_installation
575+
if should_find_bot_installation:
576+
# Retrieve the latest bot token, just in case
577+
# See also: https://github.com/slackapi/bolt-python/issues/664
578+
latest_bot_installation = await self.async_find_bot(
579+
enterprise_id=enterprise_id,
580+
team_id=team_id,
581+
is_enterprise_install=is_enterprise_install,
582+
)
583+
if (
584+
latest_bot_installation is not None
585+
and installation is not None
586+
and installation.bot_token != latest_bot_installation.bot_token
587+
):
588+
installation.bot_id = latest_bot_installation.bot_id
589+
installation.bot_user_id = latest_bot_installation.bot_user_id
590+
installation.bot_token = latest_bot_installation.bot_token
591+
installation.bot_scopes = latest_bot_installation.bot_scopes
592+
installation.bot_refresh_token = latest_bot_installation.bot_refresh_token
593+
installation.bot_token_expires_at = latest_bot_installation.bot_token_expires_at
594+
595+
return installation
596+
597+
async def async_delete_bot(
598+
self,
599+
*,
600+
enterprise_id: Optional[str],
601+
team_id: Optional[str],
602+
) -> None:
603+
table = self.bots
604+
c = table.c
605+
async with self.engine.begin() as conn:
606+
deletion = table.delete().where(
607+
and_(
608+
c.client_id == self.client_id,
609+
c.enterprise_id == enterprise_id,
610+
c.team_id == team_id,
611+
)
612+
)
613+
await conn.execute(deletion)
614+
615+
async def async_delete_installation(
616+
self,
617+
*,
618+
enterprise_id: Optional[str],
619+
team_id: Optional[str],
620+
user_id: Optional[str] = None,
621+
) -> None:
622+
table = self.installations
623+
c = table.c
624+
async with self.engine.begin() as conn:
625+
if user_id is not None:
626+
deletion = table.delete().where(
627+
and_(
628+
c.client_id == self.client_id,
629+
c.enterprise_id == enterprise_id,
630+
c.team_id == team_id,
631+
c.user_id == user_id,
632+
)
633+
)
634+
await conn.execute(deletion)
635+
else:
636+
deletion = table.delete().where(
637+
and_(
638+
c.client_id == self.client_id,
639+
c.enterprise_id == enterprise_id,
640+
c.team_id == team_id,
641+
)
642+
)
643+
await conn.execute(deletion)

slack_sdk/oauth/state_store/sqlalchemy/__init__.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from uuid import uuid4
66

77
from ..state_store import OAuthStateStore
8+
from ..async_state_store import AsyncOAuthStateStore
89
import sqlalchemy
910
from sqlalchemy import Table, Column, Integer, String, DateTime, and_, MetaData
1011
from sqlalchemy.engine import Engine
12+
from sqlalchemy.ext.asyncio import AsyncEngine
1113

1214

1315
class SQLAlchemyOAuthStateStore(OAuthStateStore):
@@ -76,3 +78,72 @@ def consume(self, state: str) -> bool:
7678
message = f"Failed to find any persistent data for state: {state} - {e}"
7779
self.logger.warning(message)
7880
return False
81+
82+
83+
class AsyncSQLAlchemyOAuthStateStore(AsyncOAuthStateStore):
84+
default_table_name: str = "slack_oauth_states"
85+
86+
expiration_seconds: int
87+
engine: AsyncEngine
88+
metadata: MetaData
89+
oauth_states: Table
90+
91+
@classmethod
92+
def build_oauth_states_table(cls, metadata: MetaData, table_name: str) -> Table:
93+
return sqlalchemy.Table(
94+
table_name,
95+
metadata,
96+
metadata,
97+
Column("id", Integer, primary_key=True, autoincrement=True),
98+
Column("state", String(200), nullable=False),
99+
Column("expire_at", DateTime, nullable=False),
100+
)
101+
102+
def __init__(
103+
self,
104+
expiration_seconds: int,
105+
engine: Engine,
106+
logger: Logger = logging.getLogger(__name__),
107+
table_name: str = default_table_name,
108+
):
109+
self.expiration_seconds = expiration_seconds
110+
self._logger = logger
111+
self.engine = engine
112+
self.metadata = MetaData()
113+
self.oauth_states = self.build_oauth_states_table(self.metadata, table_name)
114+
115+
async def create_tables(self):
116+
async with self.engine.begin() as conn:
117+
await conn.run_sync(self.metadata.create_all)
118+
119+
@property
120+
def logger(self) -> Logger:
121+
if self._logger is None:
122+
self._logger = logging.getLogger(__name__)
123+
return self._logger
124+
125+
async def async_issue(self, *args, **kwargs) -> str:
126+
state: str = str(uuid4())
127+
now = datetime.utcfromtimestamp(time.time() + self.expiration_seconds)
128+
async with self.engine.begin() as conn:
129+
await conn.execute(
130+
self.oauth_states.insert(),
131+
{"state": state, "expire_at": now},
132+
)
133+
return state
134+
135+
async def async_consume(self, state: str) -> bool:
136+
try:
137+
async with self.engine.begin() as conn:
138+
c = self.oauth_states.c
139+
query = self.oauth_states.select().where(and_(c.state == state, c.expire_at > datetime.utcnow()))
140+
result = await conn.execute(query)
141+
for row in result.mappings():
142+
self.logger.debug(f"consume's query result: {row}")
143+
await conn.execute(self.oauth_states.delete().where(c.id == row["id"]))
144+
return True
145+
return False
146+
except Exception as e:
147+
message = f"Failed to find any persistent data for state: {state} - {e}"
148+
self.logger.warning(message)
149+
return False

0 commit comments

Comments
 (0)