|
24 | 24 | MIN_IMAGE_WIDTH = 1500 |
25 | 25 |
|
26 | 26 |
|
| 27 | +CODE_BLOCK_RE = re.compile(r"```(\w+)\n(.*?)```", re.DOTALL) |
| 28 | + |
| 29 | + |
| 30 | +def strip_code_block(source: str) -> str: |
| 31 | + source = source.strip() |
| 32 | + if match := CODE_BLOCK_RE.fullmatch(source): |
| 33 | + return match.group(2).strip() |
| 34 | + if match := re.fullmatch(r"`(.*?)`", source, re.DOTALL): |
| 35 | + return match.group(1).strip() |
| 36 | + return source |
| 37 | + |
| 38 | + |
27 | 39 | class CompileError(Exception): |
28 | 40 | pass |
29 | 41 |
|
@@ -225,13 +237,23 @@ async def get_default_renderer(self, message: discord.Message): |
225 | 237 |
|
226 | 238 | @commands.Cog.listener() |
227 | 239 | async def on_message(self, message): |
228 | | - if message.author.bot or re.search(r"\$.+\$", message.content) is None: |
| 240 | + if message.author.bot: |
229 | 241 | return |
230 | 242 | ctx = await self.bot.get_context(message) |
231 | 243 | if ctx.command is not None: |
232 | 244 | return |
233 | | - renderer = await self.get_default_renderer(message) |
234 | | - await self.process_math(ctx, renderer, message.clean_content) |
| 245 | + |
| 246 | + if match := CODE_BLOCK_RE.search(message.content): |
| 247 | + lang = match.group(1).lower() |
| 248 | + if lang in self.renderer_by_key: |
| 249 | + renderer = self.renderer_by_key[lang] |
| 250 | + source = match.group(2).strip() |
| 251 | + await self.process_math(ctx, renderer, source) |
| 252 | + return |
| 253 | + |
| 254 | + if re.search(r"\$.+\$", message.content) is not None: |
| 255 | + renderer = await self.get_default_renderer(message) |
| 256 | + await self.process_math(ctx, renderer, message.clean_content) |
235 | 257 |
|
236 | 258 | @commands.command(aliases=("latex",)) |
237 | 259 | async def tex(self, ctx, file: Optional[discord.Attachment], *, source: str | None = None): |
@@ -295,6 +317,7 @@ async def process_math_command( |
295 | 317 | raise commands.MissingRequiredArgument(ctx.command.clean_params["source"]) |
296 | 318 |
|
297 | 319 | async def process_math(self, ctx: Context, renderer: MathRenderer, source: str): |
| 320 | + source = strip_code_block(source) |
298 | 321 | async with ctx.typing(): |
299 | 322 | view = MathView(ctx, source, renderer, self.renderers) |
300 | 323 | await view.send(ctx.channel) |
|
0 commit comments