diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs index c41eb44913..60d0a0ec0c 100644 --- a/rust/lance-core/src/datatypes/field.rs +++ b/rust/lance-core/src/datatypes/field.rs @@ -1018,6 +1018,35 @@ impl Field { pub fn is_unenforced_primary_key(&self) -> bool { self.unenforced_primary_key_position.is_some() } + + /// Re-parse well-known metadata keys and update the corresponding structured fields. + /// + /// Call this after modifying `field.metadata` directly (e.g., via UpdateConfig) + /// to keep structured fields like `unenforced_primary_key_position` in sync. + pub fn sync_embedded_metadata(&mut self) -> Result<()> { + self.unenforced_primary_key_position = + parse_unenforced_primary_key_position(&self.metadata)?; + Ok(()) + } +} + +fn parse_unenforced_primary_key_position( + metadata: &HashMap, +) -> Result> { + if let Some(s) = metadata.get(LANCE_UNENFORCED_PRIMARY_KEY_POSITION) { + let parsed = s.parse::().map_err(|e| { + Error::invalid_input(format!( + "Invalid value '{}' for {}: {}", + s, LANCE_UNENFORCED_PRIMARY_KEY_POSITION, e + )) + })?; + Ok(Some(parsed)) + } else { + Ok(metadata + .get(LANCE_UNENFORCED_PRIMARY_KEY) + .filter(|s| matches!(s.to_lowercase().as_str(), "true" | "1" | "yes")) + .map(|_| 0)) + } } impl fmt::Display for Field { @@ -1098,16 +1127,6 @@ impl TryFrom<&ArrowField> for Field { } _ => vec![], }; - let unenforced_primary_key_position = metadata - .get(LANCE_UNENFORCED_PRIMARY_KEY_POSITION) - .and_then(|s| s.parse::().ok()) - .or_else(|| { - // Backward compatibility: use 0 for legacy boolean flag - metadata - .get(LANCE_UNENFORCED_PRIMARY_KEY) - .filter(|s| matches!(s.to_lowercase().as_str(), "true" | "1" | "yes")) - .map(|_| 0) - }); let is_blob_v2 = has_blob_v2_extension(field); if is_blob_v2 { @@ -1125,6 +1144,8 @@ impl TryFrom<&ArrowField> for Field { LogicalType::try_from(field.data_type())? }; + let unenforced_primary_key_position = parse_unenforced_primary_key_position(&metadata)?; + Ok(Self { id, parent_id: -1, @@ -1831,4 +1852,25 @@ mod tests { .unwrap(); assert_eq!(unloaded_projected, unloaded); } + + #[test] + fn test_try_from_arrow_field_invalid_pk_position_returns_error() { + let arrow_field = + ArrowField::new("id", DataType::Int32, false).with_metadata(HashMap::from([( + LANCE_UNENFORCED_PRIMARY_KEY_POSITION.to_string(), + "not_a_number".to_string(), + )])); + + let result = Field::try_from(&arrow_field); + assert!( + result.is_err(), + "Invalid pk position should fail in TryFrom" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not_a_number"), + "Error should include the invalid value: {}", + err_msg + ); + } } diff --git a/rust/lance/src/dataset/metadata.rs b/rust/lance/src/dataset/metadata.rs index abf16081fd..33a713476b 100644 --- a/rust/lance/src/dataset/metadata.rs +++ b/rust/lance/src/dataset/metadata.rs @@ -557,4 +557,147 @@ mod tests { assert!(matches!(result, Err(Error::InvalidInput { .. }))); } + + /// Helper to create a simple dataset with a non-nullable `id` field suitable for PK tests. + async fn test_dataset_for_pk() -> Dataset { + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("id", DataType::Int32, false), + ArrowField::new("value", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + Dataset::write( + RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), + "memory://", + None, + ) + .await + .unwrap() + } + + #[tokio::test] + async fn test_update_field_metadata_sets_unenforced_primary_key() { + let mut dataset = test_dataset_for_pk().await; + + // Legacy boolean flag should map to position 0. + dataset + .update_field_metadata() + .update("id", [("lance-schema:unenforced-primary-key", "true")]) + .unwrap() + .await + .unwrap(); + + let field = dataset.schema().field("id").unwrap(); + assert!( + field.is_unenforced_primary_key(), + "Field should be recognized as unenforced primary key after metadata update" + ); + assert_eq!( + field.unenforced_primary_key_position, + Some(0), + "Legacy boolean flag should map to position 0" + ); + + // Explicit position should override the legacy flag. + dataset + .update_field_metadata() + .update( + "id", + [("lance-schema:unenforced-primary-key:position", "2")], + ) + .unwrap() + .await + .unwrap(); + + let field = dataset.schema().field("id").unwrap(); + assert!(field.is_unenforced_primary_key()); + assert_eq!( + field.unenforced_primary_key_position, + Some(2), + "Explicit position should take precedence over the legacy boolean flag" + ); + } + + #[tokio::test] + async fn test_update_field_metadata_primary_key_used_by_merge_insert() { + use crate::dataset::write::merge_insert::*; + + let mut dataset = test_dataset_for_pk().await; + + // Set PK via metadata update (the bug scenario) + dataset + .update_field_metadata() + .update("id", [("lance-schema:unenforced-primary-key", "true")]) + .unwrap() + .await + .unwrap(); + + let dataset = Arc::new(dataset); + + // Prepare new data that overlaps with existing + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("id", DataType::Int32, false), + ArrowField::new("value", DataType::Utf8, true), + ])); + let new_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 4])), + Arc::new(arrow_array::StringArray::from(vec!["updated", "new"])), + ], + ) + .unwrap(); + + // MergeInsert with empty `on` keys — should default to the unenforced PK + let mut builder = MergeInsertBuilder::try_new(dataset.clone(), Vec::new()).unwrap(); + builder + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll); + let job = builder.try_build().unwrap(); + + let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone())); + let new_stream = lance_datafusion::utils::reader_to_stream(new_reader); + + let (updated_dataset, stats) = job.execute(new_stream).await.unwrap(); + + assert_eq!(stats.num_inserted_rows, 1, "id=4 should be inserted"); + assert_eq!(stats.num_updated_rows, 1, "id=2 should be updated"); + + let result = updated_dataset.scan().try_into_batch().await.unwrap(); + let ids = result + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let values = result + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let mut pairs: Vec<(i32, String)> = (0..ids.len()) + .map(|i| (ids.value(i), values.value(i).to_string())) + .collect(); + pairs.sort_by_key(|(id, _)| *id); + + assert_eq!( + pairs, + vec![ + (1, "a".to_string()), + (2, "updated".to_string()), + (3, "c".to_string()), + (4, "new".to_string()), + ] + ); + } } diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index bb43ba16d8..a076eca794 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -2358,6 +2358,7 @@ impl Transaction { for (field_id, field_metadata_update) in field_metadata_updates { if let Some(field) = manifest.schema.field_by_id_mut(*field_id) { apply_update_map(&mut field.metadata, field_metadata_update); + field.sync_embedded_metadata()?; } else { return Err(Error::invalid_input_source( format!("Field with id {} does not exist", field_id).into(),