Skip to content

Commit 9667b2d

Browse files
committed
Fix KeyError in insert_all when pk column is missing from a single record
When pk= named a column that was not present in the data (and so was not created on the table), inserting exactly one record raised a KeyError, while inserting any other number of records did not. The cause was the single-record branch that populates last_pk, which read the pk column straight out of the inserted row. That branch now checks that the named pk column(s) actually exist on the row first and falls back to the rowid otherwise, matching the behaviour for multiple records. Closes #732
1 parent 8f0c06e commit 9667b2d

2 files changed

Lines changed: 29 additions & 5 deletions

File tree

sqlite_utils/db.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3608,12 +3608,20 @@ def insert_all(
36083608
if (hash_id or pk) and self.last_rowid:
36093609
# Set self.last_pk to the pk(s) for that rowid
36103610
row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
3611-
if hash_id:
3612-
self.last_pk = row[hash_id]
3613-
elif isinstance(pk, str):
3614-
self.last_pk = row[pk]
3611+
pk_cols = (
3612+
[hash_id]
3613+
if hash_id
3614+
else ([pk] if isinstance(pk, str) else list(pk))
3615+
)
3616+
if all(col in row for col in pk_cols):
3617+
if hash_id or isinstance(pk, str):
3618+
self.last_pk = row[pk_cols[0]]
3619+
else:
3620+
self.last_pk = tuple(row[col] for col in pk_cols)
36153621
else:
3616-
self.last_pk = tuple(row[p] for p in pk)
3622+
# Named pk column(s) are not present in the table - fall
3623+
# back to the rowid, matching the multi-row behaviour
3624+
self.last_pk = self.last_rowid
36173625
else:
36183626
self.last_pk = self.last_rowid
36193627
else:

tests/test_create.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,22 @@ def test_insert_all_with_extra_columns_in_later_chunks(fresh_db):
670670
]
671671

672672

673+
@pytest.mark.parametrize("num_rows", (0, 1, 2, 3, 10))
def test_insert_all_pk_not_in_records(fresh_db, num_rows):
    """insert_all() accepts a pk= column that the records do not contain.

    Regression test for https://github.com/simonw/sqlite-utils/issues/732:
    the outcome should not depend on how many rows are inserted. Previously
    a single row raised a KeyError while other row counts did not.
    """
    fresh_db.conn.execute("CREATE TABLE t (a TEXT, b INT, PRIMARY KEY (a, b))")
    table = fresh_db.table("t")

    records = []
    for i in range(num_rows):
        records.append({"a": "x{}".format(i), "b": i})

    table.insert_all(records, pk="not_a_column", alter=True)
    assert table.count == num_rows

    if num_rows != 1:
        return
    # The named pk column does not exist on the row, so last_pk is expected
    # to fall back to the rowid of the single inserted record.
    assert table.last_pk == table.last_rowid
687+
688+
673689
def test_bulk_insert_more_than_999_values(fresh_db):
674690
"Inserting 100 items with 11 columns should work"
675691
fresh_db["big"].insert_all(

0 commit comments

Comments
 (0)