Skip to content

Commit d8aa2bb

Browse files
committed
Merge remote-tracking branch 'upstream/main' into delta-kernel-phase-1
2 parents db4b6eb + 9a7e616 commit d8aa2bb

22 files changed

Lines changed: 737 additions & 105 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ Cast operations in Comet fall into three levels of support:
136136
Spark.
137137
- **N/A**: Spark does not support this cast.
138138

139+
### String to Decimal
140+
141+
Comet's native `CAST(string AS DECIMAL)` implementation matches Apache Spark's behavior,
142+
including:
143+
144+
- Leading and trailing ASCII whitespace is trimmed before parsing.
145+
- Null bytes (`\u0000`) at the start or end of a string are trimmed, matching Spark's
146+
`UTF8String` behavior. Null bytes embedded in the middle of a string produce `NULL`.
147+
- Fullwidth Unicode digits (U+FF10–U+FF19, e.g. `１２３.４５`) are treated as their ASCII
148+
equivalents, so `CAST('１２３.４５' AS DECIMAL(10,2))` returns `123.45`.
149+
- Scientific notation (e.g. `1.23E+5`) is supported.
150+
- Special values (`inf`, `infinity`, `nan`) produce `NULL`.
151+
139152
### String to Timestamp
140153

141154
Comet's native `CAST(string AS TIMESTAMP)` implementation supports all timestamp formats accepted

docs/source/user-guide/latest/tuning.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ The valid pool types are:
6161
- `fair_unified` (default when `spark.memory.offHeap.enabled=true` is set)
6262
- `greedy_unified`
6363

64-
The `fair_unified` pool types prevents operators from using more than an even fraction of the available memory
64+
Both pool types are shared across all native execution contexts within the same Spark task. When
65+
Comet executes a shuffle, it runs two native execution contexts concurrently (e.g. one for
66+
pre-shuffle operators and one for the shuffle writer). The shared pool ensures that the combined
67+
memory usage stays within the per-task limit.
68+
69+
The `fair_unified` pool prevents operators from using more than an even fraction of the available memory
6570
(i.e. `pool_size / num_reservations`). This pool works best when you know beforehand
6671
the query has multiple operators that will likely all need to spill. Sometimes it will cause spills even
6772
when there is sufficient memory in order to leave enough memory for other operators.

native/core/src/execution/expressions/temporal.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ use datafusion::logical_expr::ScalarUDF;
2525
use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
2626
use datafusion_comet_proto::spark_expression::Expr;
2727
use datafusion_comet_spark_expr::{
28-
SparkHour, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
28+
SparkHour, SparkHoursTransform, SparkMinute, SparkSecond, SparkUnixTimestamp,
29+
TimestampTruncExpr,
2930
};
3031

3132
use crate::execution::{
@@ -160,3 +161,29 @@ impl ExpressionBuilder for TruncTimestampBuilder {
160161
Ok(Arc::new(TimestampTruncExpr::new(child, format, timezone)))
161162
}
162163
}
164+
165+
pub struct HoursTransformBuilder;
166+
167+
impl ExpressionBuilder for HoursTransformBuilder {
168+
fn build(
169+
&self,
170+
spark_expr: &Expr,
171+
input_schema: SchemaRef,
172+
planner: &PhysicalPlanner,
173+
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
174+
let expr = extract_expr!(spark_expr, HoursTransform);
175+
let child = planner.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
176+
let args = vec![child];
177+
let comet_hours_transform = Arc::new(ScalarUDF::new_from_impl(SparkHoursTransform::new()));
178+
let field_ref = Arc::new(Field::new("hours_transform", DataType::Int32, true));
179+
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
180+
"hours_transform",
181+
comet_hours_transform,
182+
args,
183+
field_ref,
184+
Arc::new(ConfigOptions::default()),
185+
);
186+
187+
Ok(Arc::new(expr))
188+
}
189+
}

native/core/src/execution/jni_api.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::{
2626
},
2727
jvm_bridge::JVMClasses,
2828
};
29+
use std::collections::HashSet;
30+
2931
use arrow::array::{Array, RecordBatch, UInt32Array};
3032
use arrow::compute::{take, TakeOptions};
3133
use arrow::datatypes::DataType as ArrowDataType;
@@ -141,15 +143,33 @@ fn unregister_and_total(thread_id: u64, context_id: i64) -> usize {
141143
map.remove(&thread_id);
142144
return 0;
143145
}
144-
return pools.values().map(|p| p.reserved()).sum::<usize>();
146+
let mut seen = HashSet::new();
147+
return pools
148+
.values()
149+
.filter_map(|p| {
150+
let ptr = Arc::as_ptr(p) as *const ();
151+
seen.insert(ptr).then(|| p.reserved())
152+
})
153+
.sum::<usize>();
145154
}
146155
0
147156
}
148157

149158
fn total_reserved_for_thread(thread_id: u64) -> usize {
150159
let map = get_thread_memory_pools().lock();
151160
map.get(&thread_id)
152-
.map(|pools| pools.values().map(|p| p.reserved()).sum::<usize>())
161+
.map(|pools| {
162+
// Deduplicate pools that share the same underlying allocation
163+
// (e.g. task-shared pools registered by multiple execution contexts)
164+
let mut seen = HashSet::new();
165+
pools
166+
.values()
167+
.filter_map(|p| {
168+
let ptr = Arc::as_ptr(p) as *const ();
169+
seen.insert(ptr).then(|| p.reserved())
170+
})
171+
.sum::<usize>()
172+
})
153173
.unwrap_or(0)
154174
}
155175

native/core/src/execution/memory_pools/config.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ impl MemoryPoolType {
3434
pub(crate) fn is_task_shared(&self) -> bool {
3535
matches!(
3636
self,
37-
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared
37+
MemoryPoolType::GreedyTaskShared
38+
| MemoryPoolType::FairSpillTaskShared
39+
| MemoryPoolType::FairUnified
40+
| MemoryPoolType::GreedyUnified
3841
)
3942
}
4043
}

native/core/src/execution/memory_pools/mod.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,36 @@ pub(crate) fn create_memory_pool(
4242
const NUM_TRACKED_CONSUMERS: usize = 10;
4343
match memory_pool_config.pool_type {
4444
MemoryPoolType::GreedyUnified => {
45-
// Set Comet memory pool for native
46-
let memory_pool =
47-
CometUnifiedMemoryPool::new(comet_task_memory_manager, task_attempt_id);
48-
Arc::new(TrackConsumersPool::new(
49-
memory_pool,
50-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
51-
))
45+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
46+
let per_task_memory_pool =
47+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
48+
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
49+
CometUnifiedMemoryPool::new(
50+
Arc::clone(&comet_task_memory_manager),
51+
task_attempt_id,
52+
),
53+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
54+
));
55+
PerTaskMemoryPool::new(pool)
56+
});
57+
per_task_memory_pool.num_plans += 1;
58+
Arc::clone(&per_task_memory_pool.memory_pool)
5259
}
5360
MemoryPoolType::FairUnified => {
54-
// Set Comet fair memory pool for native
55-
let memory_pool =
56-
CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size);
57-
Arc::new(TrackConsumersPool::new(
58-
memory_pool,
59-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
60-
))
61+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
62+
let per_task_memory_pool =
63+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
64+
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
65+
CometFairMemoryPool::new(
66+
Arc::clone(&comet_task_memory_manager),
67+
memory_pool_config.pool_size,
68+
),
69+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
70+
));
71+
PerTaskMemoryPool::new(pool)
72+
});
73+
per_task_memory_pool.num_plans += 1;
74+
Arc::clone(&per_task_memory_pool.memory_pool)
6175
}
6276
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
6377
GreedyMemoryPool::new(memory_pool_config.pool_size),

native/core/src/execution/planner/expression_registry.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ pub enum ExpressionType {
110110
Second,
111111
TruncTimestamp,
112112
UnixTimestamp,
113+
HoursTransform,
113114
}
114115

115116
/// Registry for expression builders
@@ -310,6 +311,10 @@ impl ExpressionRegistry {
310311
ExpressionType::TruncTimestamp,
311312
Box::new(TruncTimestampBuilder),
312313
);
314+
self.builders.insert(
315+
ExpressionType::HoursTransform,
316+
Box::new(HoursTransformBuilder),
317+
);
313318
}
314319

315320
/// Extract expression type from Spark protobuf expression
@@ -382,6 +387,7 @@ impl ExpressionRegistry {
382387
Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second),
383388
Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp),
384389
Some(ExprStruct::UnixTimestamp(_)) => Ok(ExpressionType::UnixTimestamp),
390+
Some(ExprStruct::HoursTransform(_)) => Ok(ExpressionType::HoursTransform),
385391

386392
Some(other) => Err(ExecutionError::GeneralError(format!(
387393
"Unsupported expression type: {:?}",

native/proto/src/proto/expr.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ message Expr {
8888
UnixTimestamp unix_timestamp = 65;
8989
FromJson from_json = 66;
9090
ToCsv to_csv = 67;
91+
HoursTransform hours_transform = 68;
9192
}
9293

9394
// Optional QueryContext for error reporting (contains SQL text and position)
@@ -356,6 +357,10 @@ message Hour {
356357
string timezone = 2;
357358
}
358359

360+
message HoursTransform {
361+
Expr child = 1;
362+
}
363+
359364
message Minute {
360365
Expr child = 1;
361366
string timezone = 2;

native/spark-expr/src/conversion_funcs/string.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,40 @@ fn cast_string_to_decimal256_impl(
438438
))
439439
}
440440

441+
/// Normalize fullwidth Unicode digits (U+FF10–U+FF19) to their ASCII equivalents.
///
/// Spark's UTF8String parser treats fullwidth digits as numerically equivalent to
/// ASCII digits, e.g. "１２３.４５" parses as 123.45. The fullwidth digit for `n` is
/// the code point U+FF10 + n, so it maps to the ASCII digit character for `n`.
///
/// Every other character (ASCII or not) is copied through unchanged. Working on
/// `char`s rather than raw bytes keeps the result valid UTF-8 by construction, so
/// no `unsafe` is required.
fn normalize_fullwidth_digits(s: &str) -> String {
    // The output is never longer than the input: each 3-byte fullwidth digit
    // shrinks to a single ASCII byte.
    let mut out = String::with_capacity(s.len());
    out.extend(s.chars().map(|c| match c {
        // `from_digit` always succeeds for 0..=9; fall back to `c` rather than panic.
        '\u{FF10}'..='\u{FF19}' => char::from_digit(u32::from(c) - 0xFF10, 10).unwrap_or(c),
        other => other,
    }));
    out
}
474+
441475
/// Parse a decimal string into mantissa and scale
442476
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc
443477
/// Parse a string to decimal following Spark's behavior
@@ -446,16 +480,30 @@ fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkRe
446480
let mut start = 0;
447481
let mut end = string_bytes.len();
448482

449-
// trim whitespaces
450-
while start < end && string_bytes[start].is_ascii_whitespace() {
483+
// Trim ASCII whitespace and null bytes from both ends. Spark's UTF8String
484+
// trims null bytes the same way it trims whitespace: "123\u0000" and
485+
// "\u0000123" both parse as 123. Null bytes in the middle are not trimmed
486+
// and will fail the digit validation in parse_decimal_str, producing NULL.
487+
while start < end && (string_bytes[start].is_ascii_whitespace() || string_bytes[start] == 0) {
451488
start += 1;
452489
}
453-
while end > start && string_bytes[end - 1].is_ascii_whitespace() {
490+
while end > start && (string_bytes[end - 1].is_ascii_whitespace() || string_bytes[end - 1] == 0)
491+
{
454492
end -= 1;
455493
}
456494

457495
let trimmed = &input_str[start..end];
458496

497+
// Normalize fullwidth digits to ASCII. Fast path skips the allocation for
498+
// pure-ASCII strings, which is the common case.
499+
let normalized;
500+
let trimmed = if trimmed.bytes().any(|b| b > 0x7F) {
501+
normalized = normalize_fullwidth_digits(trimmed);
502+
normalized.as_str()
503+
} else {
504+
trimmed
505+
};
506+
459507
if trimmed.is_empty() {
460508
return Ok(None);
461509
}

0 commit comments

Comments
 (0)