Skip to content

Commit 5da3608

Browse files
committed
refactored hf_api and added tests
1 parent 98191d5 commit 5da3608

4 files changed

Lines changed: 46 additions & 18 deletions

File tree

src/rowgen/hf_api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from huggingface_hub import InferenceClient
2+
3+
from .utils import API_KEY
4+
5+
6+
class HFapi:
7+
def __init__(self, provider: str = "novita", api_key: str = API_KEY):
8+
self.client = InferenceClient(provider=provider, api_key=api_key)
9+
10+
def send_message_to_api(self, message: str) -> str:
11+
"""
12+
Sends a message to the AI and returns the response
13+
:param message: The prompt to be sent.
14+
:return: the response of AI model.
15+
"""
16+
completion = self.client.chat.completions.create(
17+
model="deepseek-ai/DeepSeek-V3",
18+
messages=[
19+
{
20+
"role": "user",
21+
"content": message,
22+
}
23+
],
24+
)
25+
return completion.choices[0].message["content"]

src/rowgen/insert_todb.py

Whitespace-only changes.

src/rowgen/main.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from dotenv import load_dotenv
77
from utils import API_KEY
88

9-
print()
10-
119

1210
class JsonDB:
1311
def __init__(self):
@@ -23,19 +21,6 @@ def __getitem__(self, item):
2321

2422
print(JsonDB()["columns"])
2523

26-
client = InferenceClient(provider="novita", api_key=API_KEY)
27-
completion = client.chat.completions.create(
28-
model="deepseek-ai/DeepSeek-V3",
29-
messages=[
30-
{
31-
"role": "user",
32-
"content": f"generate 10 rows for the following schema: {JsonDB().get_column_names()}. Use real looking data (no john doe or @example.com. also make emails beliavble such as using numbers underscores. sometimes even use unrelated to emails that are unrelate to names.). do not say anything. just generate data in a json format.",
33-
}
34-
],
35-
)
36-
37-
print(completion.choices[0])
38-
3924

4025
class JsonParse:
4126
def __init__(self, query: str):
@@ -64,6 +49,3 @@ def save_to_json(self, file_path: str = "output.json"):
6449

6550
with open(file_path, "w") as f:
6651
json.dump(data, f)
67-
68-
69-
JsonParse(completion.choices[0].message["content"]).save_to_json()

tests/test_ai_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from rowgen.utils import API_KEY
88
from huggingface_hub import InferenceClient, HfApi
9+
from rowgen.hf_api import HFapi
910

1011

1112
DEEP_SEEK_MODEL = "deepseek-ai/DeepSeek-V3"
@@ -23,6 +24,12 @@ def hf_api_client():
2324
return HfApi()
2425

2526

27+
@pytest.fixture
28+
def hf_api():
29+
client = HFapi()
30+
return client
31+
32+
2633
def test_hf_hub_ping(hf_api_client):
2734
try:
2835
info = hf_api_client.model_info(DEEP_SEEK_MODEL)
@@ -48,3 +55,17 @@ def test_hf_inference_connection(inference_client):
4855

4956
except Exception as d:
5057
pytest.fail(f"Failed inference call: {d}")
58+
59+
60+
def test_hf_api_send_message_to_api(hf_api):
61+
prompt_message = "Hi! This is a test. say banana if you can hear me."
62+
try:
63+
response_message = hf_api.send_message_to_api(prompt_message)
64+
assert len(response_message) > 0
65+
assert "banana" in response_message.casefold()
66+
67+
except BadRequestError as e:
68+
print(e)
69+
70+
except Exception as d:
71+
pytest.fail(f"Failed inference call: {d}")

0 commit comments

Comments
 (0)