Skip to content

Commit 4d2a4ac

Browse files
fix: 修复工具调用枚举
1 parent ecfb327 commit 4d2a4ac

8 files changed

Lines changed: 768 additions & 122 deletions

File tree

=0.7.3

Whitespace-only changes.

amrita_plugin_omikuji/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@
99
require("amrita.plugins.chat")
1010
require("amrita.plugins.menu")
1111

12-
from amrita.plugins.chat.API import ToolsManager
13-
1412
from . import commands, llm_tool, sql_models
1513
from .cache import OmikujiCache
16-
from .config import get_cache_dir, get_config
17-
from .llm_tool import TOOL_DATA
14+
from .config import get_cache_dir
1815

1916
__plugin_meta__ = PluginMetadata(
2017
name="御神签",
@@ -33,10 +30,6 @@ async def init():
3330
version = metadata.version("amrita_plugin_omikuji")
3431
if "dev" in version:
3532
logger.warning("当前版本为开发版本,可能存在不稳定情况!")
36-
logger.info(f"Loading OMIKUJI V{version}......")
37-
conf = get_config()
38-
if conf.enable_omikuji:
39-
ToolsManager().register_tool(TOOL_DATA)
4033
logger.info("正在初始化缓存数据......")
4134
os.makedirs(get_cache_dir(), exist_ok=True)
4235
for cache in get_cache_dir().glob("*.json"):

amrita_plugin_omikuji/config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,13 @@ def check(self) -> Config:
4242
return self
4343

4444

45+
PLUGIN_CONFIG = get_plugin_config(Config)
46+
CACHE_DIR = get_plugin_cache_dir()
47+
48+
4549
def get_config() -> Config:
46-
return get_plugin_config(Config)
50+
return PLUGIN_CONFIG
4751

4852

4953
def get_cache_dir() -> Path:
50-
return get_plugin_cache_dir()
54+
return CACHE_DIR

amrita_plugin_omikuji/llm_tool.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
from amrita.plugins.chat.API import (
44
ToolContext,
5-
ToolData,
5+
on_tools,
66
)
77
from nonebot import get_bot, logger
88
from nonebot.adapters.onebot.v11 import MessageEvent
99

10-
from amrita_plugin_omikuji.cache import cache_omikuji, get_cached_omikuji
11-
10+
from .cache import cache_omikuji, get_cached_omikuji
1211
from .config import get_config
13-
from .models import FUNC_META, OmikujiData
12+
from .models import FUNC_DEFINTION, OmikujiData
1413
from .utils import generate_omikuji
1514

1615
LEVEL = ["大吉", "吉", "中吉", "小吉", "末吉", "凶", "大凶"]
@@ -34,7 +33,8 @@ def format_omikuji(data: OmikujiData, user_name: str | None = ""):
3433
return msg
3534

3635

37-
async def omikuji(ctx: ToolContext):
36+
@on_tools(FUNC_DEFINTION, custom_run=True, strict=True)
37+
async def omikuji(ctx: ToolContext) -> str:
3838
logger.info("获取御神签")
3939
nb_event: MessageEvent = typing.cast(MessageEvent, ctx.event.get_nonebot_event())
4040
is_group = hasattr(nb_event, "group_id")
@@ -51,7 +51,4 @@ async def omikuji(ctx: ToolContext):
5151
return data.model_dump_json()
5252
msg = format_omikuji(data)
5353
await bot.send(nb_event, msg)
54-
ctx.matcher.cancel_nonebot_process()
55-
56-
57-
TOOL_DATA = ToolData(data=FUNC_META, func=omikuji, custom_run=True)
54+
return "Generated a omikuji for user."

amrita_plugin_omikuji/models.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,20 @@ class OmikujiData(BaseModel):
124124
),
125125
)
126126

127-
FUNC_META = ToolFunctionSchema(
128-
strict=True,
129-
function=FunctionDefinitionSchema(
130-
name="omikuji",
131-
description="抽取一个御神签",
132-
parameters=FunctionParametersSchema(
133-
type="object",
134-
properties={
135-
"theme": FunctionPropertySchema(
136-
type="string",
137-
description="御神签主题(如果包含不良内容则随机选择)",
138-
)
139-
},
140-
required=["theme"],
141-
),
127+
FUNC_DEFINTION = FunctionDefinitionSchema(
128+
name="omikuji",
129+
description="抽取一个御神签",
130+
parameters=FunctionParametersSchema(
131+
type="object",
132+
properties={
133+
"theme": FunctionPropertySchema(
134+
type="string",
135+
description="御神签主题(如果包含不良内容则随机选择)",
136+
enum=OMIKUJI_THEMES,
137+
)
138+
},
139+
required=["theme"],
142140
),
143141
)
142+
143+
FUNC_META = ToolFunctionSchema(strict=True, function=FUNC_DEFINTION)

amrita_plugin_omikuji/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def _hit_cache_omikuji(
5959
async def generate_omikuji(
6060
theme: THEME_TYPE,
6161
is_group: bool = False,
62-
level: str = "",
62+
level: str | None = None,
6363
) -> OmikujiData:
6464
config = get_config()
6565
level = level or random_level()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "amrita-plugin-omikuji"
3-
version = "0.1.1.1"
3+
version = "0.1.2"
44
description = "AmritaBot的御神签插件"
55
readme = "README.md"
66
requires-python = ">=3.10, <4.0"
@@ -10,7 +10,7 @@ dependencies = [
1010
"nonebot-plugin-localstore>=0.7.4",
1111
"aiofiles>=24.1.0",
1212
"nonebot-plugin-orm>=0.8.2",
13-
"amrita[full]>=0.4.4",
13+
"amrita[full]==0.7.3",
1414
"zipp>=3.23.0",
1515
]
1616
license = "GPL-3.0-or-later"

0 commit comments

Comments
 (0)