Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gbd_core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def set_values(self, name, value, hashes, target_db=None):
raise GBDException("Feature '{}' does not exist".format(name))
if not len(hashes):
raise GBDException("No hashes given")
self.database.set_values(name, value, hashes, target_db)
self.database.set_values({name: value}, hashes, target_db)

def reset_values(self, feature, values=[], hashes=[], target_db=None):
"""Reset feature value for given hashes
Expand Down
12 changes: 9 additions & 3 deletions gbd_core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,15 @@ def create_feature(self, name, default_value=None, target_db=None, permissive=Fa
# this code disregards feature precedence by database position:
self.features[finfo.name].append(finfo)

def set_values(self, fname, value, hashes, target_db=None):
finfo = self.finfo(fname, target_db)
self.schemas[finfo.database].set_values(fname, value, hashes)
def set_values(self, mappings: typing.Dict[str, typing.Any], hashes: list[str], target_db=None):
db_mappings = {}
for fname, value in mappings.items():
finfo = self.finfo(fname, target_db)
if finfo.database not in db_mappings:
db_mappings[finfo.database] = {}
db_mappings[finfo.database][fname] = value
for database, mappings in db_mappings.items():
self.schemas[database].set_values(mappings, hashes)

def rename_feature(self, fname, new_fname, target_db=None):
Schema.valid_feature_or_raise(new_fname)
Expand Down
38 changes: 23 additions & 15 deletions gbd_core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,20 +392,28 @@ def create_feature(self, name, default_value=None, permissive=False):

return created

def set_values(self, feature, value, hashes):
if not self.has_feature(feature):
raise SchemaException("Feature '{}' does not exist".format(feature))
def set_values(self, mappings: typing.Dict[str, typing.Any], hashes: list[str]):
if not len(hashes):
raise SchemaException("No hashes given")
table = self.features[feature].table
column = self.features[feature].column
values = ", ".join(["('{}', '{}')".format(hash, value) for hash in hashes])
if self.features[feature].default is None:
self.execute("INSERT OR IGNORE INTO {tab} (hash, {col}) VALUES {vals}".format(tab=table, col=column, vals=values))
self.execute("UPDATE features SET {col}=hash WHERE hash in ('{h}')".format(col=table, h="', '".join(hashes)))
else:
self.execute(
"INSERT INTO {tab} (hash, {col}) VALUES {vals} ON CONFLICT (hash) DO UPDATE SET {col}='{val}' WHERE hash in ('{h}')".format(
tab=table, col=column, val=value, vals=values, h="', '".join(hashes)
)
)
hash_list = ", ".join(f"'{hash}'" for hash in hashes)
# Apply values to non-unique features, but collect unique feature values
unique_mappings = {}
for feature, value in mappings.items():
if not self.has_feature(feature):
raise SchemaException("Feature '{}' does not exist".format(feature))
table = self.features[feature].table
column = self.features[feature].column
if self.features[feature].default is None:
values = ", ".join(f"('{hash}', '{value}')" for hash in hashes)
self.execute(f"INSERT OR IGNORE INTO {table} (hash, {column}) VALUES {values}")
self.execute(f"UPDATE features SET {table}=hash WHERE hash in ({hash_list})")
continue
assert table == "features"
unique_mappings[column] = value
if not unique_mappings:
return
# Execute one query for the main table
columns = ", ".join(unique_mappings)
values = ", ".join(f"('{hash}', {", ".join(f"'{val}'" for val in unique_mappings.values())})" for hash in hashes)
updates = ", ".join(f"{col}='{val}'" for col, val in unique_mappings.items())
self.execute(f"INSERT INTO features (hash, {columns}) VALUES {values} ON CONFLICT (hash) DO UPDATE SET {updates} WHERE hash in ({hash_list})")
9 changes: 6 additions & 3 deletions gbd_init/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def create_features(self):
self.api.database.commit()

def save_features(self, result: list):
for attr in result:
name, hashv, value = attr[0], attr[1], attr[2]
self.api.database.set_values(name, value, [hashv], self.target_db)
hashv = result[0][1]
mappings = {}
for name, hashv2, value in result:
assert hashv == hashv2
mappings[name] = value
self.api.database.set_values(mappings, [hashv], self.target_db)
self.api.database.commit()

def run(self, instances: pl.DataFrame):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def tearDown(self) -> None:
return super().tearDown()

def test_databases_exist(self):
self.assertEquals(self.api.get_databases(), [ self.name1, self.name2 ])
self.assertEquals(self.api.get_database_path(self.name1), self.file1)
self.assertEquals(self.api.get_database_path(self.name2), self.file2)
self.assertEqual(self.api.get_databases(), [ self.name1, self.name2 ])
self.assertEqual(self.api.get_database_path(self.name1), self.file1)
self.assertEqual(self.api.get_database_path(self.name2), self.file2)

def test_create_feature(self):
self.api.create_feature("A", None, self.name1)
Expand Down Expand Up @@ -111,4 +111,4 @@ def test_reset_values(self):
self.api.database.commit()
api2 = GBD([self.file2])
df: pl.DataFrame = api2.query("A = value1", resolve=["A"])
self.assertCountEqual(df["A"].to_list(), [ "value1" for _ in range(50) ])
self.assertCountEqual(df["A"].to_list(), [ "value1" for _ in range(50) ])
6 changes: 3 additions & 3 deletions tests/test_db_nonunique_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def setUp(self) -> None:
self.name = Schema.dbname_from_path(self.file)
self.db = Database([self.file], verbose=False)
self.db.create_feature(self.feat, default_value=None)
self.db.set_values(self.feat, self.val1, ["a", "b", "c"])
self.db.set_values(self.feat, self.val2, ["a", "b", "c"])
self.db.set_values({self.feat: self.val1}, ["a", "b", "c"])
self.db.set_values({self.feat: self.val2}, ["a", "b", "c"])
return super().setUp()

def tearDown(self) -> None:
Expand Down Expand Up @@ -87,4 +87,4 @@ def test_feature_values_delete_value(self):
# Delete feature
def test_nonunique_feature_delete(self):
self.db.delete_feature(self.feat)
self.assertRaises(DatabaseException, self.db.find, self.feat)
self.assertRaises(DatabaseException, self.db.find, self.feat)
4 changes: 2 additions & 2 deletions tests/test_db_unique_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setUp(self) -> None:
self.name = Schema.dbname_from_path(self.file)
self.db = Database([self.file], verbose=False)
self.db.create_feature(self.feat, default_value=self.defv)
self.db.set_values(self.feat, self.val1, ["a", "b", "c"])
self.db.set_values({self.feat: self.val1}, ["a", "b", "c"])
return super().setUp()

def tearDown(self) -> None:
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_unique_feature_values_exist(self):

# Overwrite one value and check if it is set correctly and that the other values are still there
def test_unique_feature_values_overwrite(self):
self.db.set_values(self.feat, self.val2, ["a"])
self.db.set_values({self.feat: self.val2}, ["a"])
res = self.query(self.feat, self.val1)
self.assertEqual(len(res), 2)
self.assertSetEqual(set(res), set(["b", "c"]))
Expand Down
16 changes: 8 additions & 8 deletions tests/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def setUp(self) -> None:
self.db = Database([self.file1,self.file2], verbose=False)

self.db.create_feature(self.feat, default_value=None, target_db=self.dbname1)
self.db.set_values(self.feat, self.val1, self.hashes)
self.db.set_values({self.feat: self.val1}, self.hashes)

self.db.create_feature(self.feat, default_value=None, target_db=self.dbname2)
self.db.set_values(self.feat, self.val1, self.hashes[:1], target_db=self.dbname2)
self.db.set_values(self.feat, self.val2, self.hashes, target_db=self.dbname2)
self.db.set_values({self.feat: self.val1}, self.hashes[:1], target_db=self.dbname2)
self.db.set_values({self.feat: self.val2}, self.hashes, target_db=self.dbname2)

self.db.create_feature(self.feat2, default_value=None, target_db=self.dbname2)
self.db.set_values(self.feat2, self.val2, self.hashes)
self.db.set_values({self.feat2: self.val2}, self.hashes)

self.db.create_feature(self.feat3, default_value=0, target_db=self.dbname1)
self.db.set_values(self.feat3, 1, self.hashes[0])
self.db.set_values(self.feat3, 10, self.hashes[1])
self.db.set_values(self.feat3, 100, self.hashes[2])
self.db.set_values({self.feat3: 1}, self.hashes[0])
self.db.set_values({self.feat3: 10}, self.hashes[1])
self.db.set_values({self.feat3: 100}, self.hashes[2])

return super().setUp()

Expand Down Expand Up @@ -94,4 +94,4 @@ def test_multivalued_subselect(self):

def test_feature_accessible(self):
res = self.simple_query(self.feat2, self.val2)
self.assertEqual(len(res), 3)
self.assertEqual(len(res), 3)