Skip to content

Commit 6806f60

Browse files
committed
feat: enhance PyDataFrame and PySessionContext to maintain session state across operations
1 parent 42e6d88 commit 6806f60

2 files changed

Lines changed: 52 additions & 41 deletions

File tree

src/context.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ impl PySessionContext {
434434
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
435435
let result = self.ctx.sql(query);
436436
let df = wait_for_future(py, result)??;
437-
Ok(PyDataFrame::new(df))
437+
Ok(PyDataFrame::new(df, self.ctx.state().into()))
438438
}
439439

440440
#[pyo3(signature = (query, options=None))]
@@ -451,7 +451,7 @@ impl PySessionContext {
451451
};
452452
let result = self.ctx.sql_with_options(query, options);
453453
let df = wait_for_future(py, result)??;
454-
Ok(PyDataFrame::new(df))
454+
Ok(PyDataFrame::new(df, self.ctx.state().into()))
455455
}
456456

457457
#[pyo3(signature = (partitions, name=None, schema=None))]
@@ -486,13 +486,16 @@ impl PySessionContext {
486486

487487
let table = wait_for_future(py, self._table(&table_name))??;
488488

489-
let df = PyDataFrame::new(table);
489+
let df = PyDataFrame::new(table, self.ctx.state().into());
490490
Ok(df)
491491
}
492492

493493
/// Create a DataFrame from an existing logical plan
494494
pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
495-
PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
495+
PyDataFrame::new(
496+
DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()),
497+
self.ctx.state().into(),
498+
)
496499
}
497500

498501
/// Construct datafusion dataframe from Python list
@@ -913,7 +916,7 @@ impl PySessionContext {
913916
let res = wait_for_future(py, self.ctx.table(name))
914917
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
915918
match res {
916-
Ok(df) => Ok(PyDataFrame::new(df)),
919+
Ok(df) => Ok(PyDataFrame::new(df, self.ctx.state().into())),
917920
Err(e) => {
918921
if let datafusion::error::DataFusionError::Plan(msg) = &e {
919922
if msg.contains("No table named") {
@@ -930,7 +933,10 @@ impl PySessionContext {
930933
}
931934

932935
pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
933-
Ok(PyDataFrame::new(self.ctx.read_empty()?))
936+
Ok(PyDataFrame::new(
937+
self.ctx.read_empty()?,
938+
self.ctx.state().into(),
939+
))
934940
}
935941

936942
pub fn session_id(&self) -> String {
@@ -970,7 +976,7 @@ impl PySessionContext {
970976
let result = self.ctx.read_json(path, options);
971977
wait_for_future(py, result)??
972978
};
973-
Ok(PyDataFrame::new(df))
979+
Ok(PyDataFrame::new(df, self.ctx.state().into()))
974980
}
975981

976982
#[allow(clippy::too_many_arguments)]
@@ -1020,12 +1026,12 @@ impl PySessionContext {
10201026
let paths = path.extract::<Vec<String>>()?;
10211027
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
10221028
let result = self.ctx.read_csv(paths, options);
1023-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1029+
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
10241030
Ok(df)
10251031
} else {
10261032
let path = path.extract::<String>()?;
10271033
let result = self.ctx.read_csv(path, options);
1028-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1034+
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
10291035
Ok(df)
10301036
}
10311037
}
@@ -1068,7 +1074,7 @@ impl PySessionContext {
10681074
.collect();
10691075

10701076
let result = self.ctx.read_parquet(path, options);
1071-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1077+
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
10721078
Ok(df)
10731079
}
10741080

@@ -1097,12 +1103,12 @@ impl PySessionContext {
10971103
let read_future = self.ctx.read_avro(path, options);
10981104
wait_for_future(py, read_future)??
10991105
};
1100-
Ok(PyDataFrame::new(df))
1106+
Ok(PyDataFrame::new(df, self.ctx.state().into()))
11011107
}
11021108

11031109
pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
11041110
let df = self.ctx.read_table(table.table())?;
1105-
Ok(PyDataFrame::new(df))
1111+
Ok(PyDataFrame::new(df, self.ctx.state().into()))
11061112
}
11071113

11081114
fn __repr__(&self) -> PyResult<String> {
@@ -1133,7 +1139,7 @@ impl PySessionContext {
11331139
let ctx: TaskContext = TaskContext::from(&state);
11341140
let plan = plan.plan.clone();
11351141
let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
1136-
Ok(PyRecordBatchStream::new(stream, state))
1142+
Ok(PyRecordBatchStream::new(stream, state.into()))
11371143
}
11381144
}
11391145

src/dataframe.rs

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,20 @@ impl PyParquetColumnOptions {
292292
pub struct PyDataFrame {
293293
df: Arc<DataFrame>,
294294

295+
// Hold the session state so streams/readers can keep the
296+
// underlying SessionContext alive while Python iterates.
297+
session_state: Arc<SessionState>,
298+
295299
// In IPython environment cache batches between __repr__ and _repr_html_ calls.
296300
batches: Option<(Vec<RecordBatch>, bool)>,
297301
}
298302

299303
impl PyDataFrame {
300304
/// creates a new PyDataFrame
301-
pub fn new(df: DataFrame) -> Self {
305+
pub fn new(df: DataFrame, session_state: Arc<SessionState>) -> Self {
302306
Self {
303307
df: Arc::new(df),
308+
session_state,
304309
batches: None,
305310
}
306311
}
@@ -481,7 +486,7 @@ impl PyDataFrame {
481486
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
482487
let df = self.df.as_ref().clone();
483488
let stat_df = spawn_future(py, async move { df.describe().await })?;
484-
Ok(Self::new(stat_df))
489+
Ok(Self::new(stat_df, self.session_state.clone()))
485490
}
486491

487492
/// Returns the schema from the logical plan
@@ -511,31 +516,31 @@ impl PyDataFrame {
511516
fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
512517
let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
513518
let df = self.df.as_ref().clone().select_columns(&args)?;
514-
Ok(Self::new(df))
519+
Ok(Self::new(df, self.session_state.clone()))
515520
}
516521

517522
#[pyo3(signature = (*args))]
518523
fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
519524
let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
520525
let df = self.df.as_ref().clone().select(expr)?;
521-
Ok(Self::new(df))
526+
Ok(Self::new(df, self.session_state.clone()))
522527
}
523528

524529
#[pyo3(signature = (*args))]
525530
fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
526531
let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
527532
let df = self.df.as_ref().clone().drop_columns(&cols)?;
528-
Ok(Self::new(df))
533+
Ok(Self::new(df, self.session_state.clone()))
529534
}
530535

531536
fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
532537
let df = self.df.as_ref().clone().filter(predicate.into())?;
533-
Ok(Self::new(df))
538+
Ok(Self::new(df, self.session_state.clone()))
534539
}
535540

536541
fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
537542
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
538-
Ok(Self::new(df))
543+
Ok(Self::new(df, self.session_state.clone()))
539544
}
540545

541546
fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
@@ -545,7 +550,7 @@ impl PyDataFrame {
545550
let name = format!("{}", expr.schema_name());
546551
df = df.with_column(name.as_str(), expr)?
547552
}
548-
Ok(Self::new(df))
553+
Ok(Self::new(df, self.session_state.clone()))
549554
}
550555

551556
/// Rename one column by applying a new projection. This is a no-op if the column to be
@@ -556,27 +561,27 @@ impl PyDataFrame {
556561
.as_ref()
557562
.clone()
558563
.with_column_renamed(old_name, new_name)?;
559-
Ok(Self::new(df))
564+
Ok(Self::new(df, self.session_state.clone()))
560565
}
561566

562567
fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
563568
let group_by = group_by.into_iter().map(|e| e.into()).collect();
564569
let aggs = aggs.into_iter().map(|e| e.into()).collect();
565570
let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
566-
Ok(Self::new(df))
571+
Ok(Self::new(df, self.session_state.clone()))
567572
}
568573

569574
#[pyo3(signature = (*exprs))]
570575
fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
571576
let exprs = to_sort_expressions(exprs);
572577
let df = self.df.as_ref().clone().sort(exprs)?;
573-
Ok(Self::new(df))
578+
Ok(Self::new(df, self.session_state.clone()))
574579
}
575580

576581
#[pyo3(signature = (count, offset=0))]
577582
fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
578583
let df = self.df.as_ref().clone().limit(offset, Some(count))?;
579-
Ok(Self::new(df))
584+
Ok(Self::new(df, self.session_state.clone()))
580585
}
581586

582587
/// Executes the plan, returning a list of `RecordBatch`es.
@@ -593,7 +598,7 @@ impl PyDataFrame {
593598
/// Cache DataFrame.
594599
fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
595600
let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
596-
Ok(Self::new(df))
601+
Ok(Self::new(df, self.session_state.clone()))
597602
}
598603

599604
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
@@ -618,7 +623,7 @@ impl PyDataFrame {
618623
/// Filter out duplicate rows
619624
fn distinct(&self) -> PyDataFusionResult<Self> {
620625
let df = self.df.as_ref().clone().distinct()?;
621-
Ok(Self::new(df))
626+
Ok(Self::new(df, self.session_state.clone()))
622627
}
623628

624629
fn join(
@@ -652,7 +657,7 @@ impl PyDataFrame {
652657
&right_keys,
653658
None,
654659
)?;
655-
Ok(Self::new(df))
660+
Ok(Self::new(df, self.session_state.clone()))
656661
}
657662

658663
fn join_on(
@@ -681,7 +686,7 @@ impl PyDataFrame {
681686
.as_ref()
682687
.clone()
683688
.join_on(right.df.as_ref().clone(), join_type, exprs)?;
684-
Ok(Self::new(df))
689+
Ok(Self::new(df, self.session_state.clone()))
685690
}
686691

687692
/// Print the query plan
@@ -714,7 +719,7 @@ impl PyDataFrame {
714719
.as_ref()
715720
.clone()
716721
.repartition(Partitioning::RoundRobinBatch(num))?;
717-
Ok(Self::new(new_df))
722+
Ok(Self::new(new_df, self.session_state.clone()))
718723
}
719724

720725
/// Repartition a `DataFrame` based on a logical partitioning scheme.
@@ -726,7 +731,7 @@ impl PyDataFrame {
726731
.as_ref()
727732
.clone()
728733
.repartition(Partitioning::Hash(expr, num))?;
729-
Ok(Self::new(new_df))
734+
Ok(Self::new(new_df, self.session_state.clone()))
730735
}
731736

732737
/// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
@@ -742,7 +747,7 @@ impl PyDataFrame {
742747
self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
743748
};
744749

745-
Ok(Self::new(new_df))
750+
Ok(Self::new(new_df, self.session_state.clone()))
746751
}
747752

748753
/// Calculate the distinct union of two `DataFrame`s. The
@@ -753,7 +758,7 @@ impl PyDataFrame {
753758
.as_ref()
754759
.clone()
755760
.union_distinct(py_df.df.as_ref().clone())?;
756-
Ok(Self::new(new_df))
761+
Ok(Self::new(new_df, self.session_state.clone()))
757762
}
758763

759764
#[pyo3(signature = (column, preserve_nulls=true))]
@@ -766,7 +771,7 @@ impl PyDataFrame {
766771
.as_ref()
767772
.clone()
768773
.unnest_columns_with_options(&[column], unnest_options)?;
769-
Ok(Self::new(df))
774+
Ok(Self::new(df, self.session_state.clone()))
770775
}
771776

772777
#[pyo3(signature = (columns, preserve_nulls=true))]
@@ -784,7 +789,7 @@ impl PyDataFrame {
784789
.as_ref()
785790
.clone()
786791
.unnest_columns_with_options(&cols, unnest_options)?;
787-
Ok(Self::new(df))
792+
Ok(Self::new(df, self.session_state.clone()))
788793
}
789794

790795
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
@@ -794,13 +799,13 @@ impl PyDataFrame {
794799
.as_ref()
795800
.clone()
796801
.intersect(py_df.df.as_ref().clone())?;
797-
Ok(Self::new(new_df))
802+
Ok(Self::new(new_df, self.session_state.clone()))
798803
}
799804

800805
/// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
801806
fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
802807
let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
803-
Ok(Self::new(new_df))
808+
Ok(Self::new(new_df, self.session_state.clone()))
804809
}
805810

806811
/// Write a `DataFrame` to a CSV file.
@@ -957,7 +962,7 @@ impl PyDataFrame {
957962
requested_schema: Option<Bound<'py, PyCapsule>>,
958963
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
959964
let df = self.df.as_ref().clone();
960-
let state = df.session_state().clone();
965+
let state = self.session_state.clone();
961966
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
962967
let streams = streams
963968
.into_iter()
@@ -997,14 +1002,14 @@ impl PyDataFrame {
9971002

9981003
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
9991004
let df = self.df.as_ref().clone();
1000-
let state = df.session_state().clone();
1005+
let state = self.session_state.clone();
10011006
let stream = spawn_future(py, async move { df.execute_stream().await })?;
10021007
Ok(PyRecordBatchStream::new(stream, state))
10031008
}
10041009

10051010
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
10061011
let df = self.df.as_ref().clone();
1007-
let state = df.session_state().clone();
1012+
let state = self.session_state.clone();
10081013
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
10091014
Ok(streams
10101015
.into_iter()
@@ -1073,7 +1078,7 @@ impl PyDataFrame {
10731078
};
10741079

10751080
let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
1076-
Ok(Self::new(df))
1081+
Ok(Self::new(df, self.session_state.clone()))
10771082
}
10781083
}
10791084

0 commit comments

Comments (0)