Skip to content

Commit e7ac714

Browse files
feat: strip code block backticks and auto-render tagged code blocks (#4)
Co-authored-by: Oliver Ni <oliver.ni@gmail.com>
1 parent 6a12456 commit e7ac714

1 file changed

Lines changed: 26 additions & 3 deletions

File tree

bmt_discord_bot/cogs/math.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@
2424
MIN_IMAGE_WIDTH = 1500
2525

2626

27+
CODE_BLOCK_RE = re.compile(r"```(\w+)\n(.*?)```", re.DOTALL)
28+
29+
30+
def strip_code_block(source: str) -> str:
31+
source = source.strip()
32+
if match := CODE_BLOCK_RE.fullmatch(source):
33+
return match.group(2).strip()
34+
if match := re.fullmatch(r"`(.*?)`", source, re.DOTALL):
35+
return match.group(1).strip()
36+
return source
37+
38+
2739
class CompileError(Exception):
2840
pass
2941

@@ -225,13 +237,23 @@ async def get_default_renderer(self, message: discord.Message):
225237

226238
@commands.Cog.listener()
227239
async def on_message(self, message):
228-
if message.author.bot or re.search(r"\$.+\$", message.content) is None:
240+
if message.author.bot:
229241
return
230242
ctx = await self.bot.get_context(message)
231243
if ctx.command is not None:
232244
return
233-
renderer = await self.get_default_renderer(message)
234-
await self.process_math(ctx, renderer, message.clean_content)
245+
246+
if match := CODE_BLOCK_RE.search(message.content):
247+
lang = match.group(1).lower()
248+
if lang in self.renderer_by_key:
249+
renderer = self.renderer_by_key[lang]
250+
source = match.group(2).strip()
251+
await self.process_math(ctx, renderer, source)
252+
return
253+
254+
if re.search(r"\$.+\$", message.content) is not None:
255+
renderer = await self.get_default_renderer(message)
256+
await self.process_math(ctx, renderer, message.clean_content)
235257

236258
@commands.command(aliases=("latex",))
237259
async def tex(self, ctx, file: Optional[discord.Attachment], *, source: str | None = None):
@@ -295,6 +317,7 @@ async def process_math_command(
295317
raise commands.MissingRequiredArgument(ctx.command.clean_params["source"])
296318

297319
async def process_math(self, ctx: Context, renderer: MathRenderer, source: str):
320+
source = strip_code_block(source)
298321
async with ctx.typing():
299322
view = MathView(ctx, source, renderer, self.renderers)
300323
await view.send(ctx.channel)

0 commit comments

Comments
 (0)