@@ -174,6 +174,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
174174 memory_limit : jlong ,
175175 memory_limit_per_task : jlong ,
176176 task_attempt_id : jlong ,
177+ task_cpus : jlong ,
177178 key_unwrapper_obj : JObject ,
178179) -> jlong {
179180 try_unwrap_or_throw ( & e, |mut env| {
@@ -241,6 +242,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
241242 memory_pool,
242243 local_dirs_vec,
243244 max_temp_directory_size,
245+ task_cpus as usize ,
244246 ) ?;
245247
246248 let plan_creation_time = start. elapsed ( ) ;
@@ -294,6 +296,7 @@ fn prepare_datafusion_session_context(
294296 memory_pool : Arc < dyn MemoryPool > ,
295297 local_dirs : Vec < String > ,
296298 max_temp_directory_size : u64 ,
299+ task_cpus : usize ,
297300) -> CometResult < SessionContext > {
298301 let paths = local_dirs. into_iter ( ) . map ( PathBuf :: from) . collect ( ) ;
299302 let disk_manager = DiskManagerBuilder :: default ( )
@@ -306,6 +309,10 @@ fn prepare_datafusion_session_context(
306309 // can be configured in Comet Spark JVM using Spark --conf parameters
307310 // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
308311 let session_config = SessionConfig :: new ( )
312+ . with_target_partitions ( task_cpus)
313+ // This DataFusion context is within the scope of an executing Spark Task. We want to set
314+ // its internal parallelism to the number of CPUs allocated to Spark Tasks. This can be
315+ // modified by changing spark.task.cpus in the Spark config.
309316 . with_batch_size ( batch_size)
310317 // DataFusion partial aggregates can emit duplicate rows so we disable the
311318 // skip partial aggregation feature because this is not compatible with Spark's
0 commit comments