Skip to content

Commit 7dd6cf1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5e2a592 commit 7dd6cf1

4 files changed

Lines changed: 28 additions & 17 deletions

File tree

src/vdf_io/export_vdf/weaviate_export.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -27,14 +27,18 @@ def make_parser(cls, subparsers):
2727

2828
parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
2929
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
30-
parser_weaviate.add_argument("--openai_api_key", type=str, help="Openai API key")
30+
parser_weaviate.add_argument(
31+
"--openai_api_key", type=str, help="Openai API key"
32+
)
3133
parser_weaviate.add_arguments(
32-
"--batch_size", type=int, help="batch size for fetching",
33-
default=1000
34+
"--batch_size", type=int, help="batch size for fetching", default=1000
3435
)
3536
parser_weaviate.add_argument(
36-
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
37-
help="Type of connection to Weaviate (local or cloud)"
37+
"--connection-type",
38+
type=str,
39+
choices=["local", "cloud"],
40+
default="cloud",
41+
help="Type of connection to Weaviate (local or cloud)",
3842
)
3943
parser_weaviate.add_argument(
4044
"--classes", type=str, help="Classes to export (comma-separated)"
@@ -52,7 +56,7 @@ def export_vdb(cls, args):
5256
args,
5357
"connection_type",
5458
"Enter 'local' or 'cloud' for connection types: ",
55-
choices=['local', 'cloud'],
59+
choices=["local", "cloud"],
5660
)
5761
set_arg_from_password(
5862
args,
@@ -83,7 +87,7 @@ def __init__(self, args):
8387
self.client = weaviate.connect_to_wcs(
8488
cluster_url=self.args["url"],
8589
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
86-
headers={'X-OpenAI-Api-key': self.args["openai_api_key"]}
90+
headers={"X-OpenAI-Api-key": self.args["openai_api_key"]}
8791
if self.args["openai_api_key"]
8892
else None,
8993
skip_init_checks=True,
@@ -128,8 +132,11 @@ def get_data(self):
128132
metadata = {}
129133
# Need a better way
130134
for obj in objects:
131-
metadata[obj.id] = {attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("__")}
132-
135+
metadata[obj.id] = {
136+
attr: getattr(obj, attr)
137+
for attr in dir(obj)
138+
if not attr.startswith("__")
139+
}
133140

134141
# Save vectors and metadata to Parquet file
135142
num_vectors_exported += self.save_vectors_to_parquet(
@@ -143,7 +150,7 @@ def get_data(self):
143150
vectors_directory,
144151
total=total_vector_count,
145152
num_vectors_exported=num_vectors_exported,
146-
dim=300, # Not sure of the dimensions
153+
dim=300, # Not sure of the dimensions
147154
distance="Cosine",
148155
)
149156
]

src/vdf_io/import_vdf/astradb_import.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from astrapy.db import AstraDB
88
from cassandra.cluster import Cluster
99
from cassandra.auth import PlainTextAuthProvider
10-
from qdrant_client.http.models import Distance
1110

1211
from vdf_io.constants import INT_MAX
1312
from vdf_io.names import DBNames

src/vdf_io/import_vdf/lancedb_import.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,9 @@ def upsert_data(self):
117117
for col in df.columns:
118118
if col not in [field.name for field in table.schema]:
119119
col_type = df[col].dtype
120-
tqdm.write(f"Adding column {col} of type {col_type} to {new_index_name}")
120+
tqdm.write(
121+
f"Adding column {col} of type {col_type} to {new_index_name}"
122+
)
121123
table.add_columns(
122124
{
123125
col: get_default_value(col_type),

src/vdf_io/import_vdf/weaviate_import.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import os
22
import weaviate
3-
import json
43
from tqdm import tqdm
54
from vdf_io.import_vdf.vdf_import_cls import ImportVDB
65
from vdf_io.names import DBNames
@@ -58,7 +57,7 @@ def __init__(self, args):
5857
self.client = weaviate.connect_to_wcs(
5958
cluster_url=self.args["url"],
6059
auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
61-
headers={'X-OpenAI-Api-key': self.args["openai_api_key"]}
60+
headers={"X-OpenAI-Api-key": self.args["openai_api_key"]}
6261
if self.args["openai_api_key"]
6362
else None,
6463
skip_init_checks=True,
@@ -69,13 +68,17 @@ def upsert_data(self):
6968
total_imported_count = 0
7069

7170
# Iterate over the indexes and import the data
72-
for index_name, index_meta in tqdm(self.vdf_meta["indexes"].items(), desc="Importing indexes"):
71+
for index_name, index_meta in tqdm(
72+
self.vdf_meta["indexes"].items(), desc="Importing indexes"
73+
):
7374
tqdm.write(f"Importing data for index '{index_name}'")
7475
for namespace_meta in index_meta:
7576
self.set_dims(namespace_meta, index_name)
7677

7778
# Create or get the index
78-
index_name = self.create_new_name(index_name, self.client.collections.list_all().keys())
79+
index_name = self.create_new_name(
80+
index_name, self.client.collections.list_all().keys()
81+
)
7982
index = self.client.collections.get(index_name)
8083

8184
# Load data from the Parquet files
@@ -136,4 +139,4 @@ def upsert_data(self):
136139
# continue
137140

138141
# tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
139-
# self.args["imported_count"] = total_imported_count
142+
# self.args["imported_count"] = total_imported_count

0 commit comments

Comments
 (0)