diff --git a/src/snowflake/snowpark/_internal/proto/ast.proto b/src/snowflake/snowpark/_internal/proto/ast.proto index 0f8b56b268..9ebf65eae3 100644 --- a/src/snowflake/snowpark/_internal/proto/ast.proto +++ b/src/snowflake/snowpark/_internal/proto/ast.proto @@ -48,6 +48,14 @@ message Tuple_String_String { string _2 = 2; } +// dataframe-ai.ir:5 +message AiSplitTextRecursiveFormat { + oneof variant { + bool ai_split_text_recursive_format_markdown = 1; + bool ai_split_text_recursive_format_none = 2; + } +} + // fn.ir:25 message Callable { int64 id = 1; @@ -739,6 +747,139 @@ message DataframeAgg { SrcPosition src = 3; } +// dataframe-ai.ir:25 +message DataframeAiAgg { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + SrcPosition src = 4; + string task_description = 5; +} + +// dataframe-ai.ir:32 +message DataframeAiClassify { + Expr categories = 1; + Expr df = 2; + Expr input_column = 3; + repeated Tuple_String_Expr kwargs = 4; + google.protobuf.StringValue output_column = 5; + SrcPosition src = 6; +} + +// dataframe-ai.ir:10 +message DataframeAiComplete { + Expr df = 1; + Expr input_columns = 2; + string model = 3; + repeated Tuple_String_Expr model_parameters = 4; + google.protobuf.StringValue output_column = 5; + string prompt = 6; + SrcPosition src = 7; +} + +// dataframe-ai.ir:89 +message DataframeAiCountTokens { + Expr df = 1; + string model = 2; + google.protobuf.StringValue output_column = 3; + Expr prompt = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:55 +message DataframeAiEmbed { + Expr df = 1; + Expr input_column = 2; + string model = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:82 +message DataframeAiExtract { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + Expr response_format = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:19 +message DataframeAiFilter { + Expr df = 1; + Expr input_columns = 2; + string predicate = 3; + SrcPosition src = 4; +} + +// dataframe-ai.ir:75 +message DataframeAiParseDocument { + Expr df = 1; + Expr input_column = 2; + repeated Tuple_String_Expr kwargs = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:48 +message DataframeAiSentiment { + Expr categories = 1; + Expr df = 2; + Expr input_column = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:40 +message DataframeAiSimilarity { + Expr df = 1; + Expr input1 = 2; + Expr input2 = 3; + repeated Tuple_String_Expr kwargs = 4; + google.protobuf.StringValue output_column = 5; + SrcPosition src = 6; +} + +// dataframe-ai.ir:96 +message DataframeAiSplitTextMarkdownHeader { + Expr chunk_size = 1; + Expr df = 2; + Expr headers_to_split_on = 3; + google.protobuf.StringValue output_column = 4; + Expr overlap = 5; + SrcPosition src = 6; + Expr text_to_split = 7; +} + +// dataframe-ai.ir:105 +message DataframeAiSplitTextRecursiveCharacter { + Expr chunk_size = 1; + Expr df = 2; + AiSplitTextRecursiveFormat format = 3; + google.protobuf.StringValue output_column = 4; + Expr overlap = 5; + Expr separators = 6; + SrcPosition src = 7; + Expr text_to_split = 8; +} + +// dataframe-ai.ir:62 +message DataframeAiSummarizeAgg { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + SrcPosition src = 4; +} + +// dataframe-ai.ir:68 +message DataframeAiTranscribe { + Expr df = 1; + Expr input_column = 2; + repeated Tuple_String_Expr kwargs = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + // dataframe.ir:180 message DataframeAlias { Expr df = 1; @@ -1386,144 +1527,159 @@ message Expr { ColumnWithinGroup column_within_group = 44; CreateDataframe create_dataframe = 45; DataframeAgg dataframe_agg = 46; - DataframeAlias dataframe_alias = 47; - DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 48; - DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 49; - DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 50; - DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 51; - DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 52; - DataframeCacheResult dataframe_cache_result = 53; - DataframeCol dataframe_col = 54; - DataframeColIlike dataframe_col_ilike = 55; - DataframeCollect dataframe_collect = 56; - DataframeCopyIntoTable dataframe_copy_into_table = 57; - DataframeCount dataframe_count = 58; - DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 59; - DataframeCreateOrReplaceView dataframe_create_or_replace_view = 60; - DataframeCrossJoin dataframe_cross_join = 61; - DataframeCube dataframe_cube = 62; - DataframeDescribe dataframe_describe = 63; - DataframeDistinct dataframe_distinct = 64; - DataframeDrop dataframe_drop = 65; - DataframeDropDuplicates dataframe_drop_duplicates = 66; - DataframeExcept dataframe_except = 67; - DataframeFilter dataframe_filter = 68; - DataframeFirst dataframe_first = 69; - DataframeFlatten dataframe_flatten = 70; - DataframeGroupBy dataframe_group_by = 71; - DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 72; - DataframeIntersect dataframe_intersect = 73; - DataframeJoin dataframe_join = 74; - DataframeJoinTableFunction dataframe_join_table_function = 75; - DataframeLimit dataframe_limit = 76; - DataframeNaDrop_Python dataframe_na_drop__python = 77; - DataframeNaDrop_Scala dataframe_na_drop__scala = 78; - DataframeNaFill dataframe_na_fill = 79; - DataframeNaReplace dataframe_na_replace = 80; - DataframeNaturalJoin dataframe_natural_join = 81; - DataframePivot dataframe_pivot = 82; - DataframeRandomSplit dataframe_random_split = 83; - DataframeReader dataframe_reader = 84; - DataframeRef dataframe_ref = 85; - DataframeRename dataframe_rename = 86; - DataframeRollup dataframe_rollup = 87; - DataframeSample dataframe_sample = 88; - DataframeSelect dataframe_select = 89; - DataframeShow dataframe_show = 90; - DataframeSort dataframe_sort = 91; - DataframeStatApproxQuantile dataframe_stat_approx_quantile = 92; - DataframeStatCorr dataframe_stat_corr = 93; - DataframeStatCov dataframe_stat_cov = 94; - DataframeStatCrossTab dataframe_stat_cross_tab = 95; - DataframeStatSampleBy dataframe_stat_sample_by = 96; - DataframeToDf dataframe_to_df = 97; - DataframeToLocalIterator dataframe_to_local_iterator = 98; - DataframeToPandas dataframe_to_pandas = 99; - DataframeToPandasBatches dataframe_to_pandas_batches = 100; - DataframeUnion dataframe_union = 101; - DataframeUnpivot dataframe_unpivot = 102; - DataframeWithColumn dataframe_with_column = 103; - DataframeWithColumnRenamed dataframe_with_column_renamed = 104; - DataframeWithColumns dataframe_with_columns = 105; - DataframeWriter dataframe_writer = 106; - DatatypeVal datatype_val = 107; - Directory directory = 108; - Div div = 109; - Eq eq = 110; - Flatten flatten = 111; - Float64Val float64_val = 112; - FnRef fn_ref = 113; - Generator generator = 114; - Geq geq = 115; - GroupingSets grouping_sets = 116; - Gt gt = 117; - IndirectTableFnIdRef indirect_table_fn_id_ref = 118; - IndirectTableFnNameRef indirect_table_fn_name_ref = 119; - Int64Val int64_val = 120; - Leq leq = 121; - ListVal list_val = 122; - Lt lt = 123; - MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 124; - MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 125; - MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 126; - Mod mod = 127; - Mul mul = 128; - Neg neg = 129; - Neq neq = 130; - Not not = 131; - NullVal null_val = 132; - ObjectGetItem object_get_item = 133; - Or or = 134; - Pow pow = 135; - PythonDateVal python_date_val = 136; - PythonTimeVal python_time_val = 137; - PythonTimestampVal python_timestamp_val = 138; - Range range = 139; - ReadAvro read_avro = 140; - ReadCsv read_csv = 141; - ReadDirectory read_directory = 142; - ReadJson read_json = 143; - ReadLoad read_load = 144; - ReadOrc read_orc = 145; - ReadParquet read_parquet = 146; - ReadTable read_table = 147; - ReadXml read_xml = 148; - RedactedConst redacted_const = 149; - RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 150; - RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 151; - RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 152; - RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 153; - RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 154; - Row row = 155; - SeqMapVal seq_map_val = 156; - SessionTableFunction session_table_function = 157; - Sql sql = 158; - SqlExpr sql_expr = 159; - StoredProcedure stored_procedure = 160; - StringVal string_val = 161; - Sub sub = 162; - Table table = 163; - TableDelete table_delete = 164; - TableDropTable table_drop_table = 165; - TableFnCallAlias table_fn_call_alias = 166; - TableFnCallOver table_fn_call_over = 167; - TableMerge table_merge = 168; - TableSample table_sample = 169; - TableUpdate table_update = 170; - ToSnowparkPandas to_snowpark_pandas = 171; - TruncatedExpr truncated_expr = 172; - TupleVal tuple_val = 173; - Udaf udaf = 174; - Udf udf = 175; - Udtf udtf = 176; - WriteCopyIntoLocation write_copy_into_location = 177; - WriteCsv write_csv = 178; - WriteInsertInto write_insert_into = 179; - WriteJson write_json = 180; - WritePandas write_pandas = 181; - WriteParquet write_parquet = 182; - WriteSave write_save = 183; - WriteTable write_table = 184; + DataframeAiAgg dataframe_ai_agg = 47; + DataframeAiClassify dataframe_ai_classify = 48; + DataframeAiComplete dataframe_ai_complete = 49; + DataframeAiCountTokens dataframe_ai_count_tokens = 50; + DataframeAiEmbed dataframe_ai_embed = 51; + DataframeAiExtract dataframe_ai_extract = 52; + DataframeAiFilter dataframe_ai_filter = 53; + DataframeAiParseDocument dataframe_ai_parse_document = 54; + DataframeAiSentiment dataframe_ai_sentiment = 55; + DataframeAiSimilarity dataframe_ai_similarity = 56; + DataframeAiSplitTextMarkdownHeader dataframe_ai_split_text_markdown_header = 57; + DataframeAiSplitTextRecursiveCharacter dataframe_ai_split_text_recursive_character = 58; + DataframeAiSummarizeAgg dataframe_ai_summarize_agg = 59; + DataframeAiTranscribe dataframe_ai_transcribe = 60; + DataframeAlias dataframe_alias = 61; + DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 62; + DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 63; + DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 64; + DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 65; + DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 66; + DataframeCacheResult dataframe_cache_result = 67; + DataframeCol dataframe_col = 68; + DataframeColIlike dataframe_col_ilike = 69; + DataframeCollect dataframe_collect = 70; + DataframeCopyIntoTable dataframe_copy_into_table = 71; + DataframeCount dataframe_count = 72; + DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 73; + DataframeCreateOrReplaceView dataframe_create_or_replace_view = 74; + DataframeCrossJoin dataframe_cross_join = 75; + DataframeCube dataframe_cube = 76; + DataframeDescribe dataframe_describe = 77; + DataframeDistinct dataframe_distinct = 78; + DataframeDrop dataframe_drop = 79; + DataframeDropDuplicates dataframe_drop_duplicates = 80; + DataframeExcept dataframe_except = 81; + DataframeFilter dataframe_filter = 82; + DataframeFirst dataframe_first = 83; + DataframeFlatten dataframe_flatten = 84; + DataframeGroupBy dataframe_group_by = 85; + DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 86; + DataframeIntersect dataframe_intersect = 87; + DataframeJoin dataframe_join = 88; + DataframeJoinTableFunction dataframe_join_table_function = 89; + DataframeLimit dataframe_limit = 90; + DataframeNaDrop_Python dataframe_na_drop__python = 91; + DataframeNaDrop_Scala dataframe_na_drop__scala = 92; + DataframeNaFill dataframe_na_fill = 93; + DataframeNaReplace dataframe_na_replace = 94; + DataframeNaturalJoin dataframe_natural_join = 95; + DataframePivot dataframe_pivot = 96; + DataframeRandomSplit dataframe_random_split = 97; + DataframeReader dataframe_reader = 98; + DataframeRef dataframe_ref = 99; + DataframeRename dataframe_rename = 100; + DataframeRollup dataframe_rollup = 101; + DataframeSample dataframe_sample = 102; + DataframeSelect dataframe_select = 103; + DataframeShow dataframe_show = 104; + DataframeSort dataframe_sort = 105; + DataframeStatApproxQuantile dataframe_stat_approx_quantile = 106; + DataframeStatCorr dataframe_stat_corr = 107; + DataframeStatCov dataframe_stat_cov = 108; + DataframeStatCrossTab dataframe_stat_cross_tab = 109; + DataframeStatSampleBy dataframe_stat_sample_by = 110; + DataframeToDf dataframe_to_df = 111; + DataframeToLocalIterator dataframe_to_local_iterator = 112; + DataframeToPandas dataframe_to_pandas = 113; + DataframeToPandasBatches dataframe_to_pandas_batches = 114; + DataframeUnion dataframe_union = 115; + DataframeUnpivot dataframe_unpivot = 116; + DataframeWithColumn dataframe_with_column = 117; + DataframeWithColumnRenamed dataframe_with_column_renamed = 118; + DataframeWithColumns dataframe_with_columns = 119; + DataframeWriter dataframe_writer = 120; + DatatypeVal datatype_val = 121; + Directory directory = 122; + Div div = 123; + Eq eq = 124; + Flatten flatten = 125; + Float64Val float64_val = 126; + FnRef fn_ref = 127; + Generator generator = 128; + Geq geq = 129; + GroupingSets grouping_sets = 130; + Gt gt = 131; + IndirectTableFnIdRef indirect_table_fn_id_ref = 132; + IndirectTableFnNameRef indirect_table_fn_name_ref = 133; + Int64Val int64_val = 134; + Leq leq = 135; + ListVal list_val = 136; + Lt lt = 137; + MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 138; + MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 139; + MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 140; + Mod mod = 141; + Mul mul = 142; + Neg neg = 143; + Neq neq = 144; + Not not = 145; + NullVal null_val = 146; + ObjectGetItem object_get_item = 147; + Or or = 148; + Pow pow = 149; + PythonDateVal python_date_val = 150; + PythonTimeVal python_time_val = 151; + PythonTimestampVal python_timestamp_val = 152; + Range range = 153; + ReadAvro read_avro = 154; + ReadCsv read_csv = 155; + ReadDirectory read_directory = 156; + ReadJson read_json = 157; + ReadLoad read_load = 158; + ReadOrc read_orc = 159; + ReadParquet read_parquet = 160; + ReadTable read_table = 161; + ReadXml read_xml = 162; + RedactedConst redacted_const = 163; + RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 164; + RelationalGroupedDataframeAiAgg relational_grouped_dataframe_ai_agg = 165; + RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 166; + RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 167; + RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 168; + RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 169; + Row row = 170; + SeqMapVal seq_map_val = 171; + SessionTableFunction session_table_function = 172; + Sql sql = 173; + SqlExpr sql_expr = 174; + StoredProcedure stored_procedure = 175; + StringVal string_val = 176; + Sub sub = 177; + Table table = 178; + TableDelete table_delete = 179; + TableDropTable table_drop_table = 180; + TableFnCallAlias table_fn_call_alias = 181; + TableFnCallOver table_fn_call_over = 182; + TableMerge table_merge = 183; + TableSample table_sample = 184; + TableUpdate table_update = 185; + ToSnowparkPandas to_snowpark_pandas = 186; + TruncatedExpr truncated_expr = 187; + TupleVal tuple_val = 188; + Udaf udaf = 189; + Udf udf = 190; + Udtf udtf = 191; + WriteCopyIntoLocation write_copy_into_location = 192; + WriteCsv write_csv = 193; + WriteInsertInto write_insert_into = 194; + WriteJson write_json = 195; + WritePandas write_pandas = 196; + WriteParquet write_parquet = 197; + WriteSave write_save = 198; + WriteTable write_table = 199; } } @@ -1679,150 +1835,165 @@ message HasSrcPosition { ColumnWithinGroup column_within_group = 47; CreateDataframe create_dataframe = 48; DataframeAgg dataframe_agg = 49; - DataframeAlias dataframe_alias = 50; - DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 51; - DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 52; - DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 53; - DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 54; - DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 55; - DataframeCacheResult dataframe_cache_result = 56; - DataframeCol dataframe_col = 57; - DataframeColIlike dataframe_col_ilike = 58; - DataframeCollect dataframe_collect = 59; - DataframeCopyIntoTable dataframe_copy_into_table = 60; - DataframeCount dataframe_count = 61; - DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 62; - DataframeCreateOrReplaceView dataframe_create_or_replace_view = 63; - DataframeCrossJoin dataframe_cross_join = 64; - DataframeCube dataframe_cube = 65; - DataframeDescribe dataframe_describe = 66; - DataframeDistinct dataframe_distinct = 67; - DataframeDrop dataframe_drop = 68; - DataframeDropDuplicates dataframe_drop_duplicates = 69; - DataframeExcept dataframe_except = 70; - DataframeFilter dataframe_filter = 71; - DataframeFirst dataframe_first = 72; - DataframeFlatten dataframe_flatten = 73; - DataframeGroupBy dataframe_group_by = 74; - DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 75; - DataframeIntersect dataframe_intersect = 76; - DataframeJoin dataframe_join = 77; - DataframeJoinTableFunction dataframe_join_table_function = 78; - DataframeLimit dataframe_limit = 79; - DataframeNaDrop_Python dataframe_na_drop__python = 80; - DataframeNaDrop_Scala dataframe_na_drop__scala = 81; - DataframeNaFill dataframe_na_fill = 82; - DataframeNaReplace dataframe_na_replace = 83; - DataframeNaturalJoin dataframe_natural_join = 84; - DataframePivot dataframe_pivot = 85; - DataframeRandomSplit dataframe_random_split = 86; - DataframeReader dataframe_reader = 87; - DataframeRef dataframe_ref = 88; - DataframeRename dataframe_rename = 89; - DataframeRollup dataframe_rollup = 90; - DataframeSample dataframe_sample = 91; - DataframeSelect dataframe_select = 92; - DataframeShow dataframe_show = 93; - DataframeSort dataframe_sort = 94; - DataframeStatApproxQuantile dataframe_stat_approx_quantile = 95; - DataframeStatCorr dataframe_stat_corr = 96; - DataframeStatCov dataframe_stat_cov = 97; - DataframeStatCrossTab dataframe_stat_cross_tab = 98; - DataframeStatSampleBy dataframe_stat_sample_by = 99; - DataframeToDf dataframe_to_df = 100; - DataframeToLocalIterator dataframe_to_local_iterator = 101; - DataframeToPandas dataframe_to_pandas = 102; - DataframeToPandasBatches dataframe_to_pandas_batches = 103; - DataframeUnion dataframe_union = 104; - DataframeUnpivot dataframe_unpivot = 105; - DataframeWithColumn dataframe_with_column = 106; - DataframeWithColumnRenamed dataframe_with_column_renamed = 107; - DataframeWithColumns dataframe_with_columns = 108; - DataframeWriter dataframe_writer = 109; - DatatypeVal datatype_val = 110; - Directory directory = 111; - Div div = 112; - Eq eq = 113; - Flatten flatten = 114; - Float64Val float64_val = 115; - FnRef fn_ref = 116; - Generator generator = 117; - Geq geq = 118; - GroupingSets grouping_sets = 119; - Gt gt = 120; - IndirectTableFnIdRef indirect_table_fn_id_ref = 121; - IndirectTableFnNameRef indirect_table_fn_name_ref = 122; - Int64Val int64_val = 123; - Leq leq = 124; - ListVal list_val = 125; - Lt lt = 126; - MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 127; - MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 128; - MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 129; - Mod mod = 130; - Mul mul = 131; - NameRef name_ref = 132; - Neg neg = 133; - Neq neq = 134; - Not not = 135; - NullVal null_val = 136; - ObjectGetItem object_get_item = 137; - Or or = 138; - Pow pow = 139; - PythonDateVal python_date_val = 140; - PythonTimeVal python_time_val = 141; - PythonTimestampVal python_timestamp_val = 142; - Range range = 143; - ReadAvro read_avro = 144; - ReadCsv read_csv = 145; - ReadDirectory read_directory = 146; - ReadJson read_json = 147; - ReadLoad read_load = 148; - ReadOrc read_orc = 149; - ReadParquet read_parquet = 150; - ReadTable read_table = 151; - ReadXml read_xml = 152; - RedactedConst redacted_const = 153; - RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 154; - RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 155; - RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 156; - RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 157; - RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 158; - Row row = 159; - SeqMapVal seq_map_val = 160; - SessionTableFunction session_table_function = 161; - Sql sql = 162; - SqlExpr sql_expr = 163; - StoredProcedure stored_procedure = 164; - StringVal string_val = 165; - Sub sub = 166; - Table table = 167; - TableDelete table_delete = 168; - TableDropTable table_drop_table = 169; - TableFnCallAlias table_fn_call_alias = 170; - TableFnCallOver table_fn_call_over = 171; - TableMerge table_merge = 172; - TableSample table_sample = 173; - TableUpdate table_update = 174; - ToSnowparkPandas to_snowpark_pandas = 175; - TruncatedExpr truncated_expr = 176; - TupleVal tuple_val = 177; - Udaf udaf = 178; - Udf udf = 179; - Udtf udtf = 180; - WindowSpecEmpty window_spec_empty = 181; - WindowSpecOrderBy window_spec_order_by = 182; - WindowSpecPartitionBy window_spec_partition_by = 183; - WindowSpecRangeBetween window_spec_range_between = 184; - WindowSpecRowsBetween window_spec_rows_between = 185; - WriteCopyIntoLocation write_copy_into_location = 186; - WriteCsv write_csv = 187; - WriteInsertInto write_insert_into = 188; - WriteJson write_json = 189; - WritePandas write_pandas = 190; - WriteParquet write_parquet = 191; - WriteSave write_save = 192; - WriteTable write_table = 193; + DataframeAiAgg dataframe_ai_agg = 50; + DataframeAiClassify dataframe_ai_classify = 51; + DataframeAiComplete dataframe_ai_complete = 52; + DataframeAiCountTokens dataframe_ai_count_tokens = 53; + DataframeAiEmbed dataframe_ai_embed = 54; + DataframeAiExtract dataframe_ai_extract = 55; + DataframeAiFilter dataframe_ai_filter = 56; + DataframeAiParseDocument dataframe_ai_parse_document = 57; + DataframeAiSentiment dataframe_ai_sentiment = 58; + DataframeAiSimilarity dataframe_ai_similarity = 59; + DataframeAiSplitTextMarkdownHeader dataframe_ai_split_text_markdown_header = 60; + DataframeAiSplitTextRecursiveCharacter dataframe_ai_split_text_recursive_character = 61; + DataframeAiSummarizeAgg dataframe_ai_summarize_agg = 62; + DataframeAiTranscribe dataframe_ai_transcribe = 63; + DataframeAlias dataframe_alias = 64; + DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 65; + DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 66; + DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 67; + DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 68; + DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 69; + DataframeCacheResult dataframe_cache_result = 70; + DataframeCol dataframe_col = 71; + DataframeColIlike dataframe_col_ilike = 72; + DataframeCollect dataframe_collect = 73; + DataframeCopyIntoTable dataframe_copy_into_table = 74; + DataframeCount dataframe_count = 75; + DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 76; + DataframeCreateOrReplaceView dataframe_create_or_replace_view = 77; + DataframeCrossJoin dataframe_cross_join = 78; + DataframeCube dataframe_cube = 79; + DataframeDescribe dataframe_describe = 80; + DataframeDistinct dataframe_distinct = 81; + DataframeDrop dataframe_drop = 82; + DataframeDropDuplicates dataframe_drop_duplicates = 83; + DataframeExcept dataframe_except = 84; + DataframeFilter dataframe_filter = 85; + DataframeFirst dataframe_first = 86; + DataframeFlatten dataframe_flatten = 87; + DataframeGroupBy dataframe_group_by = 88; + DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 89; + DataframeIntersect dataframe_intersect = 90; + DataframeJoin dataframe_join = 91; + DataframeJoinTableFunction dataframe_join_table_function = 92; + DataframeLimit dataframe_limit = 93; + DataframeNaDrop_Python dataframe_na_drop__python = 94; + DataframeNaDrop_Scala dataframe_na_drop__scala = 95; + DataframeNaFill dataframe_na_fill = 96; + DataframeNaReplace dataframe_na_replace = 97; + DataframeNaturalJoin dataframe_natural_join = 98; + DataframePivot dataframe_pivot = 99; + DataframeRandomSplit dataframe_random_split = 100; + DataframeReader dataframe_reader = 101; + DataframeRef dataframe_ref = 102; + DataframeRename dataframe_rename = 103; + DataframeRollup dataframe_rollup = 104; + DataframeSample dataframe_sample = 105; + DataframeSelect dataframe_select = 106; + DataframeShow dataframe_show = 107; + DataframeSort dataframe_sort = 108; + DataframeStatApproxQuantile dataframe_stat_approx_quantile = 109; + DataframeStatCorr dataframe_stat_corr = 110; + DataframeStatCov dataframe_stat_cov = 111; + DataframeStatCrossTab dataframe_stat_cross_tab = 112; + DataframeStatSampleBy dataframe_stat_sample_by = 113; + DataframeToDf dataframe_to_df = 114; + DataframeToLocalIterator dataframe_to_local_iterator = 115; + DataframeToPandas dataframe_to_pandas = 116; + DataframeToPandasBatches dataframe_to_pandas_batches = 117; + DataframeUnion dataframe_union = 118; + DataframeUnpivot dataframe_unpivot = 119; + DataframeWithColumn dataframe_with_column = 120; + DataframeWithColumnRenamed dataframe_with_column_renamed = 121; + DataframeWithColumns dataframe_with_columns = 122; + DataframeWriter dataframe_writer = 123; + DatatypeVal datatype_val = 124; + Directory directory = 125; + Div div = 126; + Eq eq = 127; + Flatten flatten = 128; + Float64Val float64_val = 129; + FnRef fn_ref = 130; + Generator generator = 131; + Geq geq = 132; + GroupingSets grouping_sets = 133; + Gt gt = 134; + IndirectTableFnIdRef indirect_table_fn_id_ref = 135; + IndirectTableFnNameRef indirect_table_fn_name_ref = 136; + Int64Val int64_val = 137; + Leq leq = 138; + ListVal list_val = 139; + Lt lt = 140; + MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 141; + MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 142; + MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 143; + Mod mod = 144; + Mul mul = 145; + NameRef name_ref = 146; + Neg neg = 147; + Neq neq = 148; + Not not = 149; + NullVal null_val = 150; + ObjectGetItem object_get_item = 151; + Or or = 152; + Pow pow = 153; + PythonDateVal python_date_val = 154; + PythonTimeVal python_time_val = 155; + PythonTimestampVal python_timestamp_val = 156; + Range range = 157; + ReadAvro read_avro = 158; + ReadCsv read_csv = 159; + ReadDirectory read_directory = 160; + ReadJson read_json = 161; + ReadLoad read_load = 162; + ReadOrc read_orc = 163; + ReadParquet read_parquet = 164; + ReadTable read_table = 165; + ReadXml read_xml = 166; + RedactedConst redacted_const = 167; + RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 168; + RelationalGroupedDataframeAiAgg relational_grouped_dataframe_ai_agg = 169; + RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 170; + RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 171; + RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 172; + RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 173; + Row row = 174; + SeqMapVal seq_map_val = 175; + SessionTableFunction session_table_function = 176; + Sql sql = 177; + SqlExpr sql_expr = 178; + StoredProcedure stored_procedure = 179; + StringVal string_val = 180; + Sub sub = 181; + Table table = 182; + TableDelete table_delete = 183; + TableDropTable table_drop_table = 184; + TableFnCallAlias table_fn_call_alias = 185; + TableFnCallOver table_fn_call_over = 186; + TableMerge table_merge = 187; + TableSample table_sample = 188; + TableUpdate table_update = 189; + ToSnowparkPandas to_snowpark_pandas = 190; + TruncatedExpr truncated_expr = 191; + TupleVal tuple_val = 192; + Udaf udaf = 193; + Udf udf = 194; + Udtf udtf = 195; + WindowSpecEmpty window_spec_empty = 196; + WindowSpecOrderBy window_spec_order_by = 197; + WindowSpecPartitionBy window_spec_partition_by = 198; + WindowSpecRangeBetween window_spec_range_between = 199; + WindowSpecRowsBetween window_spec_rows_between = 200; + WriteCopyIntoLocation write_copy_into_location = 201; + WriteCsv write_csv = 202; + WriteInsertInto write_insert_into = 203; + WriteJson write_json = 204; + WritePandas write_pandas = 205; + WriteParquet write_parquet = 206; + WriteSave write_save = 207; + WriteTable write_table = 208; } } @@ -2083,7 +2254,15 @@ message RelationalGroupedDataframeAgg { SrcPosition src = 3; } -// dataframe-grouped.ir:45 +// dataframe-grouped.ir:46 +message RelationalGroupedDataframeAiAgg { + Expr expr = 1; + Expr grouped_df = 2; + SrcPosition src = 3; + string task_description = 4; +} + +// dataframe-grouped.ir:52 message RelationalGroupedDataframeApplyInPandas { Callable func = 1; Expr grouped_df = 2; @@ -2100,7 +2279,7 @@ message RelationalGroupedDataframeBuiltin { SrcPosition src = 4; } -// dataframe-grouped.ir:52 +// dataframe-grouped.ir:59 message RelationalGroupedDataframePivot { Expr default_on_null = 1; Expr grouped_df = 2; diff --git a/src/snowflake/snowpark/dataframe_ai_functions.py b/src/snowflake/snowpark/dataframe_ai_functions.py index a58daa9110..872cf9225c 100644 --- a/src/snowflake/snowpark/dataframe_ai_functions.py +++ b/src/snowflake/snowpark/dataframe_ai_functions.py @@ -7,6 +7,13 @@ from snowflake.snowpark._internal.utils import ( create_prompt_column_from_template, experimental, + publicapi, +) +from snowflake.snowpark._internal.ast.utils import ( + build_expr_from_python_val, + build_expr_from_snowpark_column_or_col_name, + build_expr_from_snowpark_column_or_python_val, + with_src_position, ) from snowflake.snowpark._internal.type_utils import ColumnOrName from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_lit @@ -37,13 +44,14 @@ def __init__(self, dataframe: "snowflake.snowpark.DataFrame") -> None: self._dataframe = dataframe @experimental(version="1.37.0") + @publicapi def complete( self, prompt: str, input_columns: Union[List[Column], Dict[str, Column]], + model: str, *, output_column: Optional[str] = None, - model: Optional[str] = None, model_parameters: Optional[Dict[str, Any]] = None, _emit_ast: bool = True, ) -> "snowflake.snowpark.DataFrame": @@ -54,10 +62,10 @@ def complete( or ``{0}``, ``{1}`` when passing a list. input_columns: A list of Columns (positional placeholders ``{0}``, ``{1}``, ...) or a dict mapping placeholder names to Columns. + model: A string specifying the model to be used. Different input types have different supported models. + See details in `AI_COMPLETE `_. output_column: The name of the output column to be appended. - If not provided, a column named ``AI_COMPLETE_OUTPUT`` is appended. - model: Model name to pass to the underlying function. - It must be specified. + If not provided, a column named ``AI_COMPLETE_OUTPUT`` is appended model_parameters: Optional dict containing model hyperparameters: - temperature: Value from 0 to 1 controlling randomness (default: 0) @@ -122,19 +130,36 @@ def complete( True """ - if not model: - raise ValueError("model must be specified for ai.complete") - # Build the prompt Column if isinstance(input_columns, (dict, list)): prompt_obj = create_prompt_column_from_template( - prompt, input_columns, _emit_ast=False + prompt, input_columns, _emit_ast=_emit_ast ) else: raise TypeError( "input_columns must be a list of Columns or a dict mapping placeholder names to Columns" ) + output_column_name = output_column or "AI_COMPLETE_OUTPUT" + + # AST at top + stmt = None + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_complete, stmt) + self._dataframe._set_ast_ref(ast.df) + ast.model = model + ast.prompt = prompt + # populate input_columns with the prompt column expression + build_expr_from_snowpark_column_or_col_name(ast.input_columns, prompt_obj) + if model_parameters: + for k, v in model_parameters.items(): + entry = ast.model_parameters.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name + # Call the ai_complete function with all explicit parameters result_col = ai_complete( model=model, @@ -144,7 +169,6 @@ def complete( ) # Add the output column to the DataFrame - output_column_name = output_column or "AI_COMPLETE_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -153,9 +177,12 @@ def complete( df, "DataFrame.ai.complete", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def filter( self, predicate: str, @@ -213,7 +240,27 @@ def filter( 1 """ - # Build the predicate Column + # AST at top + stmt = None + predicate_ast_col = None + if _emit_ast: + if isinstance(input_columns, (dict, list)): + predicate_ast_col = create_prompt_column_from_template( + predicate, input_columns, _emit_ast=True + ) + else: + raise TypeError( + "input_columns must be a list of Columns or a dict mapping placeholder names to Columns" + ) + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_filter, stmt) + self._dataframe._set_ast_ref(ast.df) + ast.predicate = predicate + build_expr_from_snowpark_column_or_col_name( + ast.input_columns, predicate_ast_col + ) + + # Build the predicate Column for execution if isinstance(input_columns, (dict, list)): predicate_col = create_prompt_column_from_template( predicate, input_columns, _emit_ast=False @@ -234,9 +281,12 @@ def filter( filtered_df, "DataFrame.ai.filter", ) + if _emit_ast: + filtered_df._ast_id = stmt.uid return filtered_df @experimental(version="1.37.0") + @publicapi def agg( self, task_description: str, @@ -312,9 +362,21 @@ def agg( various publishers presenting events from different points of view. Please create a concise and elaborative summary of source texts without missing any crucial information.". """ + output_column_name = output_column or "AI_AGG_OUTPUT" - # Call the ai_agg function + # AST at top + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.agg") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_agg, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + ast.task_description = task_description + + ast.output_column.value = output_column_name + + # Call the ai_agg function result_col = ai_agg( input_col, task_description=task_description, @@ -322,7 +384,6 @@ def agg( ) # Create a new DataFrame with the aggregated result - output_column_name = output_column or "AI_AGG_OUTPUT" df = self._dataframe.select( result_col.alias(output_column_name), _emit_ast=False ) @@ -331,9 +392,12 @@ def agg( df, "DataFrame.ai.agg", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def classify( self, input_column: ColumnOrName, @@ -454,8 +518,23 @@ def classify( True """ - # Convert string input column to Column object + output_column_name = output_column or "AI_CLASSIFY_OUTPUT" + + # Convert string input column to Column object and AST at top + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.classify") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_classify, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + build_expr_from_snowpark_column_or_python_val(ast.categories, categories) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name # Call the ai_classify function result_col = ai_classify( @@ -466,7 +545,6 @@ def classify( ) # Add the output column to the DataFrame - output_column_name = output_column or "AI_CLASSIFY_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -475,9 +553,12 @@ def classify( df, "DataFrame.ai.classify", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def similarity( self, input1: ColumnOrName, @@ -591,10 +672,24 @@ def similarity( - 0 indicates no similarity - -1 indicates opposite or very dissimilar content """ + output_column_name = output_column or "AI_SIMILARITY_OUTPUT" - # Convert string inputs to Column objects + # Convert string inputs to Column objects and AST at top + stmt = None input1_col = _to_col_if_str(input1, "DataFrame.ai.similarity") input2_col = _to_col_if_str(input2, "DataFrame.ai.similarity") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_similarity, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input1, input1_col) + build_expr_from_snowpark_column_or_col_name(ast.input2, input2_col) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name # Call the ai_similarity function result_col = ai_similarity( @@ -605,7 +700,6 @@ def similarity( ) # Add the output column to the DataFrame - output_column_name = output_column or "AI_SIMILARITY_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -614,9 +708,12 @@ def similarity( df, "DataFrame.ai.similarity", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def sentiment( self, input_column: ColumnOrName, @@ -699,9 +796,20 @@ def sentiment( AI_SENTIMENT can analyze sentiment in English, French, German, Hindi, Italian, Spanish, and Portuguese. You can specify categories in the language of the text or in English. """ + output_column_name = output_column or "AI_SENTIMENT_OUTPUT" - # Convert string input column to Column object + # Convert string input column and AST at top + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.sentiment") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_sentiment, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + if categories is not None: + build_expr_from_python_val(ast.categories, categories) + + ast.output_column.value = output_column_name # Call the ai_sentiment function result_col = ai_sentiment( @@ -711,7 +819,6 @@ def sentiment( ) # Add the output column to the DataFrame - output_column_name = output_column or "AI_SENTIMENT_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -720,9 +827,12 @@ def sentiment( df, "DataFrame.ai.sentiment", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def embed( self, input_column: ColumnOrName, @@ -815,9 +925,19 @@ def embed( - Different models produce embeddings of different dimensions - For best results, use the same model for all items you want to compare """ + output_column_name = output_column or "AI_EMBED_OUTPUT" - # Convert string input column to Column object + # Convert string input column to Column object & AST at top + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.embed") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_embed, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + ast.model = model + + ast.output_column.value = output_column_name # Call the ai_embed function result_col = ai_embed( @@ -827,7 +947,6 @@ def embed( ) # Add the output column to the DataFrame - output_column_name = output_column or "AI_EMBED_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -836,9 +955,12 @@ def embed( df, "DataFrame.ai.embed", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def summarize_agg( self, input_column: ColumnOrName, @@ -907,9 +1029,18 @@ def summarize_agg( - Unlike the ``agg`` method which requires a task description, ``summarize_agg`` automatically generates a comprehensive summary """ + output_column_name = output_column or "AI_SUMMARIZE_AGG_OUTPUT" - # Convert string input column to Column object + # Convert string input column to Column object & AST at top + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.summarize_agg") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_summarize_agg, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + + ast.output_column.value = output_column_name # Call the ai_summarize_agg function result_col = ai_summarize_agg( @@ -918,7 +1049,6 @@ def summarize_agg( ) # Create a new DataFrame with the summarized result - output_column_name = output_column or "AI_SUMMARIZE_AGG_OUTPUT" df = self._dataframe.select( result_col.alias(output_column_name), _emit_ast=False ) @@ -927,9 +1057,12 @@ def summarize_agg( df, "DataFrame.ai.summarize_agg", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def transcribe( self, input_column: ColumnOrName, @@ -999,8 +1132,21 @@ def transcribe( >>> 'start' in result["segments"][0] and 'end' in result["segments"][0] True """ + output_column_name = output_column or "AI_TRANSCRIBE_OUTPUT" + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.transcribe") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_transcribe, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name result_col = ai_transcribe( input_col, @@ -1008,7 +1154,6 @@ def transcribe( **kwargs, ) - output_column_name = output_column or "AI_TRANSCRIBE_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -1017,9 +1162,12 @@ def transcribe( df, "DataFrame.ai.transcribe", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def parse_document( self, input_column: ColumnOrName, @@ -1078,8 +1226,20 @@ def parse_document( >>> len(result["pages"]) == 3 and result["pages"][0]["index"] == 0 True """ + output_column_name = output_column or "AI_PARSE_DOCUMENT_OUTPUT" + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.parse_document") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_parse_document, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + ast.output_column.value = output_column_name result_col = ai_parse_document( input_col, @@ -1087,7 +1247,6 @@ def parse_document( **kwargs, ) - output_column_name = output_column or "AI_PARSE_DOCUMENT_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -1096,9 +1255,12 @@ def parse_document( df, "DataFrame.ai.parse_document", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def extract( self, input_column: ColumnOrName, @@ -1238,7 +1400,19 @@ def extract( """ + output_column_name = output_column or "AI_EXTRACT_OUTPUT" + + stmt = None input_col = _to_col_if_str(input_column, "DataFrame.ai.extract") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_extract, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + if response_format is not None: + build_expr_from_python_val(ast.response_format, response_format) + + ast.output_column.value = output_column_name result_col = ai_extract( input=input_col, @@ -1255,9 +1429,12 @@ def extract( df, "DataFrame.ai.extract", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def count_tokens( self, model: str, @@ -1321,15 +1498,26 @@ def count_tokens( automatically added when using other Cortex AI functions. The actual token usage may be higher when using those functions. """ - # Convert string input to Column object + + output_column_name = output_column or "COUNT_TOKENS_OUTPUT" + + # AST at top + stmt = None prompt_col = _to_col_if_str(prompt, "DataFrame.ai.count_tokens") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_count_tokens, stmt) + self._dataframe._set_ast_ref(ast.df) + ast.model = model + build_expr_from_snowpark_column_or_col_name(ast.prompt, prompt_col) + + ast.output_column.value = output_column_name # Call SNOWFLAKE.CORTEX.COUNT_TOKENS function count_tokens_func = function("SNOWFLAKE.CORTEX.COUNT_TOKENS", _emit_ast=False) result_col = count_tokens_func(model, prompt_col) # Add the output column to the DataFrame - output_column_name = output_column or "COUNT_TOKENS_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -1338,9 +1526,12 @@ def count_tokens( df, "DataFrame.ai.count_tokens", ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def split_text_markdown_header( self, text_to_split: ColumnOrName, @@ -1425,12 +1616,30 @@ def split_text_markdown_header( - Overlap helps maintain context across chunk boundaries """ method_name = "DataFrame.ai.split_text_markdown_header" + output_column_name = output_column or "SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT" - # Convert inputs to Column objects + # Convert inputs to Column objects and AST at top + stmt = None text_col = _to_col_if_str(text_to_split, method_name) headers_col = _to_col_if_lit(headers_to_split_on, method_name) chunk_size_col = _to_col_if_lit(chunk_size, method_name) overlap_col = _to_col_if_lit(overlap, method_name) + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position( + stmt.expr.dataframe_ai_split_text_markdown_header, stmt + ) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name( + ast.text_to_split, text_to_split + ) + build_expr_from_snowpark_column_or_python_val( + ast.headers_to_split_on, headers_to_split_on + ) + build_expr_from_snowpark_column_or_python_val(ast.chunk_size, chunk_size) + build_expr_from_snowpark_column_or_python_val(ast.overlap, overlap) + + ast.output_column.value = output_column_name # Call SNOWFLAKE.CORTEX.SPLIT_TEXT_MARKDOWN_HEADER function split_func = function( @@ -1439,7 +1648,6 @@ def split_text_markdown_header( result_col = split_func(text_col, headers_col, chunk_size_col, overlap_col) # Add the output column to the DataFrame - output_column_name = output_column or "SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -1448,9 +1656,12 @@ def split_text_markdown_header( df, method_name, ) + if _emit_ast: + df._ast_id = stmt.uid return df @experimental(version="1.37.0") + @publicapi def split_text_recursive_character( self, text_to_split: ColumnOrName, @@ -1598,12 +1809,32 @@ def split_text_recursive_character( functions for code) """ method_name = "DataFrame.ai.split_text_recursive_character" + output_column_name = output_column or "SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT" - # Convert input to Column object + # Convert input to Column object & AST at top + stmt = None text_col = _to_col_if_str(text_to_split, method_name) chunk_size_col = _to_col_if_lit(chunk_size, method_name) separators_col = _to_col_if_lit(separators, method_name) overlap_col = _to_col_if_lit(overlap, method_name) + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position( + stmt.expr.dataframe_ai_split_text_recursive_character, stmt + ) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name( + ast.text_to_split, text_to_split + ) + if format == "markdown": + ast.format.ai_split_text_recursive_format_markdown = True + else: + ast.format.ai_split_text_recursive_format_none = True + build_expr_from_snowpark_column_or_python_val(ast.chunk_size, chunk_size) + build_expr_from_snowpark_column_or_python_val(ast.overlap, overlap) + build_expr_from_snowpark_column_or_python_val(ast.separators, separators) + + ast.output_column.value = output_column_name # Call the function split_func = function( @@ -1614,7 +1845,6 @@ def split_text_recursive_character( ) # Add the output column to the DataFrame - output_column_name = output_column or "SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT" df = self._dataframe.with_column( output_column_name, result_col, _emit_ast=False ) @@ -1623,4 +1853,6 @@ def split_text_recursive_character( df, method_name, ) + if _emit_ast: + df._ast_id = stmt.uid return df diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index 9f9411eecc..d58d4be6ce 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -910,13 +910,12 @@ def ai_agg( if _emit_ast: stmt = self._dataframe._session._ast_batch.bind() - ast = with_src_position( - stmt.expr.relational_grouped_dataframe_builtin, stmt - ) + ast = with_src_position(stmt.expr.relational_grouped_dataframe_ai_agg, stmt) + # Reference the grouped dataframe self._set_ast_ref(ast.grouped_df) - ast.agg_name = "ai_agg" - build_expr_from_python_val(ast.cols.args.add(), expr) - build_expr_from_python_val(ast.cols.args.add(), task_description) + # Set arguments + build_expr_from_python_val(ast.expr, expr) + ast.task_description = task_description df._ast_id = stmt.uid return df diff --git a/tests/ast/data/DataFrame.ai.test b/tests/ast/data/DataFrame.ai.test new file mode 100644 index 0000000000..162ed8e4c4 --- /dev/null +++ b/tests/ast/data/DataFrame.ai.test @@ -0,0 +1,2151 @@ +## TEST CASE + +# Basic table to drive some AI ops that don't require files +df = session.table(tables.table1) + +# DataFrame.ai.complete with named placeholders +from snowflake.snowpark.functions import col +df1 = df.select(col("STR").as_("question")) +df2 = df1.ai.complete( + prompt="Answer briefly: {q}", + input_columns={"q": col("question")}, + model="snowflake-arctic", +) + +# DataFrame.ai.filter with named placeholders +df4 = session.create_dataframe([ + ["Switzerland", "Europe"], + ["Korea", "Asia"], + ["Brazil", "South America"], +], schema=["country", "continent"]) +df5 = df4.ai.filter( + "Is {country} located in {continent} and specifically in Europe?", + input_columns={"country": col("country"), "continent": col("continent")}, +) + +# DataFrame.ai.agg on text column +df6 = session.create_dataframe([ + ["Excellent product, highly recommend!"], + ["Great quality and fast shipping"], + ["Average product, nothing special"], + ["Poor quality, very disappointed"], +], schema=["review"]) +df7 = df6.ai.agg( + task_description="Summarize these product reviews for a blog post", + input_column="review", +) + +# DataFrame.ai.classify with list categories +df8 = session.create_dataframe([ + ["I love hiking in the mountains"], + ["My favorite dish is pasta"], + ["Just finished reading a great book"], +], schema=["text"]) +df9 = df8.ai.classify( + input_column="text", + categories=["hiking", "cooking", "reading"], +) + +# DataFrame.ai.similarity between two text columns +df10 = session.create_dataframe([ + ["I love programming", "I enjoy coding"], + ["The weather is nice", "It's raining heavily"], +], schema=["text1", "text2"]) +df11 = df10.ai.similarity( + input1="text1", + input2="text2", +) + +# DataFrame.ai.sentiment overall sentiment +df12 = session.create_dataframe([ + ["The movie had amazing visual effects but the plot was terrible."], + ["Everything about this experience was perfect!"], +], schema=["review"]) +df13 = df12.ai.sentiment(input_column="review") + +# DataFrame.ai.embed text embeddings +df14 = session.create_dataframe([ + ["Machine learning is fascinating"], + ["Snowflake provides cloud data platform"], +], schema=["text"]) +df15 = df14.ai.embed(input_column="text", model="snowflake-arctic-embed-l-v2.0") + +# DataFrame.ai.summarize_agg aggregation summary +df16 = session.create_dataframe([ + ["Meeting started with project updates"], + ["Discussed timeline and deliverables"], + ["Identified key risks"], +], schema=["notes"]) +df17 = df16.ai.summarize_agg(input_column="notes") + +# DataFrame.ai.extract with dict response_format +df20 = session.create_dataframe([["John Smith lives in San Francisco and works for Snowflake"]], schema=["text"]) +df21 = df20.ai.extract( + input_column="text", + response_format={"name": "What is the first name of the employee?", "city": "What is the address of the employee?"}, +) + +# DataFrame.ai.count_tokens simple +df22 = session.create_dataframe([["What is a large language model?"], ["Explain quantum computing in simple terms."]], schema=["text"]) +df23 = df22.ai.count_tokens(model="llama3.1-70b", prompt="text") + +# DataFrame.ai.split_text_markdown_header +df24 = session.create_dataframe([["# Intro\nThis is the intro.\n## Background\nSome background info."]], schema=["document"]) +df25 = df24.ai.split_text_markdown_header( + text_to_split="document", + headers_to_split_on={"#": "section", "##": "subsection"}, + chunk_size=20, + overlap=5, +) + +# DataFrame.ai.split_text_recursive_character +df26 = session.create_dataframe([["This is a long document. It has multiple sentences.\n\nAnd multiple paragraphs."]], schema=["text"]) +df27 = df26.ai.split_text_recursive_character( + text_to_split="text", + format="none", + chunk_size=30, + overlap=5, +) + +## EXPECTED UNPARSER OUTPUT + +df = session.table("table1") + +df1 = df.select(col("STR").as_("question")) + +df2 = df1.ai.complete(prompt="Answer briefly: {q}", model="snowflake-arctic", input_columns=prompt("Answer briefly: {0}", col("question")), output_column="AI_COMPLETE_OUTPUT") + +df4 = session.create_dataframe([["Switzerland", "Europe"], ["Korea", "Asia"], ["Brazil", "South America"]], schema=["country", "continent"]) + +df5 = df4.ai.filter("Is {country} located in {continent} and specifically in Europe?", input_columns=prompt("Is {0} located in {1} and specifically in Europe?", col("country"), col("continent"))) + +df6 = session.create_dataframe([["Excellent product, highly recommend!"], ["Great quality and fast shipping"], ["Average product, nothing special"], ["Poor quality, very disappointed"]], schema=["review"]) + +df7 = df6.ai.agg("Summarize these product reviews for a blog post", input_column="review", output_column="AI_AGG_OUTPUT") + +df8 = session.create_dataframe([["I love hiking in the mountains"], ["My favorite dish is pasta"], ["Just finished reading a great book"]], schema=["text"]) + +df9 = df8.ai.classify(input_column="text", categories=["hiking", "cooking", "reading"], output_column="AI_CLASSIFY_OUTPUT") + +df10 = session.create_dataframe([["I love programming", "I enjoy coding"], ["The weather is nice", "It's raining heavily"]], schema=["text1", "text2"]) + +df11 = df10.ai.similarity(input1="text1", input2="text2", output_column="AI_SIMILARITY_OUTPUT") + +df12 = session.create_dataframe([["The movie had amazing visual effects but the plot was terrible."], ["Everything about this experience was perfect!"]], schema=["review"]) + +df13 = df12.ai.sentiment(input_column="review", output_column="AI_SENTIMENT_OUTPUT") + +df14 = session.create_dataframe([["Machine learning is fascinating"], ["Snowflake provides cloud data platform"]], schema=["text"]) + +df15 = df14.ai.embed(input_column="text", model="snowflake-arctic-embed-l-v2.0", output_column="AI_EMBED_OUTPUT") + +df16 = session.create_dataframe([["Meeting started with project updates"], ["Discussed timeline and deliverables"], ["Identified key risks"]], schema=["notes"]) + +df17 = df16.ai.summarize_agg(input_column="notes", output_column="AI_SUMMARIZE_AGG_OUTPUT") + +df20 = session.create_dataframe([["John Smith lives in San Francisco and works for Snowflake"]], schema=["text"]) + +df21 = df20.ai.extract(input_column="text", response_format={"name": "What is the first name of the employee?", "city": "What is the address of the employee?"}, output_column="AI_EXTRACT_OUTPUT") + +df22 = session.create_dataframe([["What is a large language model?"], ["Explain quantum computing in simple terms."]], schema=["text"]) + +df23 = df22.ai.count_tokens(model="llama3.1-70b", prompt="text", output_column="COUNT_TOKENS_OUTPUT") + +df24 = session.create_dataframe([["# Intro\nThis is the intro.\n## Background\nSome background info."]], schema=["document"]) + +df25 = df24.ai.split_text_markdown_header(text_to_split="document", headers_to_split_on={"#": "section", "##": "subsection"}, chunk_size=20, overlap=5, output_column="SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT") + +df26 = session.create_dataframe([["This is a long document. It has multiple sentences.\n\nAnd multiple paragraphs."]], schema=["text"]) + +df27 = df26.ai.split_text_recursive_character(text_to_split="text", format="none", chunk_size=30, overlap=5, separators=("\n\n", "\n", " ", ""), output_column="SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT") + +## EXPECTED ENCODED AST + +interned_value_table { + string_values { + key: -1 + } + string_values { + key: 2 + value: "SRC_POSITION_TEST_MODE" + } +} +body { + bind { + expr { + table { + name { + name { + name_flat { + name: "table1" + } + } + } + src { + end_column: 41 + end_line: 26 + file: 2 + start_column: 13 + start_line: 26 + } + variant { + session_table: true + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df" + } + uid: 1 + } +} +body { + bind { + expr { + dataframe_select { + cols { + args { + column_alias { + col { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + v: "STR" + } + } + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + } + } + fn { + column_alias_fn_as: true + } + name: "question" + src { + end_column: 50 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + } + } + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 51 + end_line: 30 + file: 2 + start_column: 14 + start_line: 30 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df1" + } + uid: 2 + } +} +body { + bind { + expr { + dataframe_ai_complete { + df { + dataframe_ref { + id: 2 + } + } + input_columns { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "prompt" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + v: "Answer briefly: {0}" + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 47 + end_line: 33 + file: 2 + start_column: 32 + start_line: 33 + } + v: "question" + } + } + src { + end_column: 47 + end_line: 33 + file: 2 + start_column: 32 + start_line: 33 + } + } + } + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + } + } + model: "snowflake-arctic" + output_column { + value: "AI_COMPLETE_OUTPUT" + } + prompt: "Answer briefly: {q}" + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df2" + } + uid: 3 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Switzerland" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Europe" + } + } + } + } + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Korea" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Asia" + } + } + } + } + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Brazil" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "South America" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "country" + vs: "continent" + } + } + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df4" + } + uid: 4 + } +} +body { + bind { + expr { + dataframe_ai_filter { + df { + dataframe_ref { + id: 4 + } + } + input_columns { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "prompt" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + v: "Is {0} located in {1} and specifically in Europe?" + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 52 + end_line: 45 + file: 2 + start_column: 38 + start_line: 45 + } + v: "country" + } + } + src { + end_column: 52 + end_line: 45 + file: 2 + start_column: 38 + start_line: 45 + } + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 83 + end_line: 45 + file: 2 + start_column: 67 + start_line: 45 + } + v: "continent" + } + } + src { + end_column: 83 + end_line: 45 + file: 2 + start_column: 67 + start_line: 45 + } + } + } + src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + } + } + predicate: "Is {country} located in {continent} and specifically in Europe?" + src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df5" + } + uid: 5 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Excellent product, highly recommend!" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Great quality and fast shipping" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Average product, nothing special" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Poor quality, very disappointed" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "review" + } + } + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df6" + } + uid: 6 + } +} +body { + bind { + expr { + dataframe_ai_agg { + df { + dataframe_ref { + id: 6 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 58 + file: 2 + start_column: 14 + start_line: 55 + } + v: "review" + } + } + output_column { + value: "AI_AGG_OUTPUT" + } + src { + end_column: 9 + end_line: 58 + file: 2 + start_column: 14 + start_line: 55 + } + task_description: "Summarize these product reviews for a blog post" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df7" + } + uid: 7 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "I love hiking in the mountains" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "My favorite dish is pasta" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "Just finished reading a great book" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df8" + } + uid: 8 + } +} +body { + bind { + expr { + dataframe_ai_classify { + categories { + list_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "hiking" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "cooking" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "reading" + } + } + } + } + df { + dataframe_ref { + id: 8 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "text" + } + } + output_column { + value: "AI_CLASSIFY_OUTPUT" + } + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df9" + } + uid: 9 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "I love programming" + } + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "I enjoy coding" + } + } + } + } + vs { + list_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "The weather is nice" + } + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "It\'s raining heavily" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text1" + vs: "text2" + } + } + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df10" + } + uid: 10 + } +} +body { + bind { + expr { + dataframe_ai_similarity { + df { + dataframe_ref { + id: 10 + } + } + input1 { + string_val { + src { + end_column: 9 + end_line: 79 + file: 2 + start_column: 15 + start_line: 76 + } + v: "text1" + } + } + input2 { + string_val { + src { + end_column: 9 + end_line: 79 + file: 2 + start_column: 15 + start_line: 76 + } + v: "text2" + } + } + output_column { + value: "AI_SIMILARITY_OUTPUT" + } + src { + end_column: 9 + end_line: 79 + file: 2 + start_column: 15 + start_line: 76 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df11" + } + uid: 11 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + vs { + string_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + v: "The movie had amazing visual effects but the plot was terrible." + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + vs { + string_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + v: "Everything about this experience was perfect!" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "review" + } + } + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df12" + } + uid: 12 + } +} +body { + bind { + expr { + dataframe_ai_sentiment { + df { + dataframe_ref { + id: 12 + } + } + input_column { + string_val { + src { + end_column: 55 + end_line: 86 + file: 2 + start_column: 15 + start_line: 86 + } + v: "review" + } + } + output_column { + value: "AI_SENTIMENT_OUTPUT" + } + src { + end_column: 55 + end_line: 86 + file: 2 + start_column: 15 + start_line: 86 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df13" + } + uid: 13 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + vs { + string_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + v: "Machine learning is fascinating" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + vs { + string_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + v: "Snowflake provides cloud data platform" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df14" + } + uid: 14 + } +} +body { + bind { + expr { + dataframe_ai_embed { + df { + dataframe_ref { + id: 14 + } + } + input_column { + string_val { + src { + end_column: 88 + end_line: 93 + file: 2 + start_column: 15 + start_line: 93 + } + v: "text" + } + } + model: "snowflake-arctic-embed-l-v2.0" + output_column { + value: "AI_EMBED_OUTPUT" + } + src { + end_column: 88 + end_line: 93 + file: 2 + start_column: 15 + start_line: 93 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df15" + } + uid: 15 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Meeting started with project updates" + } + } + } + } + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Discussed timeline and deliverables" + } + } + } + } + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Identified key risks" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "notes" + } + } + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df16" + } + uid: 16 + } +} +body { + bind { + expr { + dataframe_ai_summarize_agg { + df { + dataframe_ref { + id: 16 + } + } + input_column { + string_val { + src { + end_column: 58 + end_line: 101 + file: 2 + start_column: 15 + start_line: 101 + } + v: "notes" + } + } + output_column { + value: "AI_SUMMARIZE_AGG_OUTPUT" + } + src { + end_column: 58 + end_line: 101 + file: 2 + start_column: 15 + start_line: 101 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df17" + } + uid: 17 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + vs { + string_val { + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + v: "John Smith lives in San Francisco and works for Snowflake" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df20" + } + uid: 18 + } +} +body { + bind { + expr { + dataframe_ai_extract { + df { + dataframe_ref { + id: 18 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "text" + } + } + output_column { + value: "AI_EXTRACT_OUTPUT" + } + response_format { + seq_map_val { + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "name" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "What is the first name of the employee?" + } + } + } + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "city" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "What is the address of the employee?" + } + } + } + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + } + } + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df21" + } + uid: 19 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + vs { + string_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + v: "What is a large language model?" + } + } + } + } + vs { + list_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + vs { + string_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + v: "Explain quantum computing in simple terms." + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df22" + } + uid: 20 + } +} +body { + bind { + expr { + dataframe_ai_count_tokens { + df { + dataframe_ref { + id: 20 + } + } + model: "llama3.1-70b" + output_column { + value: "COUNT_TOKENS_OUTPUT" + } + prompt { + string_val { + src { + end_column: 72 + end_line: 112 + file: 2 + start_column: 15 + start_line: 112 + } + v: "text" + } + } + src { + end_column: 72 + end_line: 112 + file: 2 + start_column: 15 + start_line: 112 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df23" + } + uid: 21 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + vs { + string_val { + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + v: "# Intro\nThis is the intro.\n## Background\nSome background info." + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "document" + } + } + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df24" + } + uid: 22 + } +} +body { + bind { + expr { + dataframe_ai_split_text_markdown_header { + chunk_size { + int64_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: 20 + } + } + df { + dataframe_ref { + id: 22 + } + } + headers_to_split_on { + seq_map_val { + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "#" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "section" + } + } + } + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "##" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "subsection" + } + } + } + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + } + } + output_column { + value: "SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT" + } + overlap { + int64_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: 5 + } + } + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + text_to_split { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "document" + } + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df25" + } + uid: 23 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + start_line: 124 + } + vs { + string_val { + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + start_line: 124 + } + v: "This is a long document. It has multiple sentences.\n\nAnd multiple paragraphs." + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + start_line: 124 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df26" + } + uid: 24 + } +} +body { + bind { + expr { + dataframe_ai_split_text_recursive_character { + chunk_size { + int64_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: 30 + } + } + df { + dataframe_ref { + id: 24 + } + } + format { + ai_split_text_recursive_format_none: true + } + output_column { + value: "SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT" + } + overlap { + int64_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: 5 + } + } + separators { + tuple_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "\n\n" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "\n" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: " " + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + } + } + } + } + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + text_to_split { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "text" + } + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df27" + } + uid: 25 + } +} +client_ast_version: 1 +client_language { + python_language { + version { + label: "final" + major: 3 + minor: 9 + patch: 1 + } + } +} +client_version { + major: 1 + minor: 37 +} +id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" diff --git a/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test b/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test new file mode 100644 index 0000000000..702bbd473a --- /dev/null +++ b/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test @@ -0,0 +1,292 @@ +## TEST CASE + +from snowflake.snowpark.functions import col + +df = session.table(tables.table1) + +# group_by with a single column, expr as string +rgdf1 = df.group_by("str") +df1 = rgdf1.ai_agg("str", "Summarize strings per group") + +# group_by with a single column, expr as Column +df2 = rgdf1.ai_agg(col("str"), "Summarize strings per group using Column") + +# group_by with no columns (global aggregation) +rgdf2 = df.group_by() +df3 = rgdf2.ai_agg("str", "Summarize strings across all rows") + +## EXPECTED UNPARSER OUTPUT + +df = session.table("table1") + +rgdf1 = df.group_by("str") + +df1 = rgdf1.ai_agg("str", task_description="Summarize strings per group") + +df2 = rgdf1.ai_agg(col("str"), task_description="Summarize strings per group using Column") + +rgdf2 = df.group_by() + +df3 = rgdf2.ai_agg("str", task_description="Summarize strings across all rows") + +## EXPECTED ENCODED AST + +interned_value_table { + string_values { + key: -1 + } + string_values { + key: 2 + value: "SRC_POSITION_TEST_MODE" + } +} +body { + bind { + expr { + table { + name { + name { + name_flat { + name: "table1" + } + } + } + src { + end_column: 41 + end_line: 27 + file: 2 + start_column: 13 + start_line: 27 + } + variant { + session_table: true + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df" + } + uid: 1 + } +} +body { + bind { + expr { + dataframe_group_by { + cols { + args { + string_val { + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 16 + start_line: 30 + } + v: "str" + } + } + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 16 + start_line: 30 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "rgdf1" + } + uid: 2 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + string_val { + src { + end_column: 64 + end_line: 31 + file: 2 + start_column: 14 + start_line: 31 + } + v: "str" + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 2 + } + } + src { + end_column: 64 + end_line: 31 + file: 2 + start_column: 14 + start_line: 31 + } + task_description: "Summarize strings per group" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df1" + } + uid: 3 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 37 + end_line: 34 + file: 2 + start_column: 27 + start_line: 34 + } + v: "str" + } + } + src { + end_column: 37 + end_line: 34 + file: 2 + start_column: 27 + start_line: 34 + } + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 2 + } + } + src { + end_column: 82 + end_line: 34 + file: 2 + start_column: 14 + start_line: 34 + } + task_description: "Summarize strings per group using Column" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df2" + } + uid: 4 + } +} +body { + bind { + expr { + dataframe_group_by { + cols { + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 29 + end_line: 37 + file: 2 + start_column: 16 + start_line: 37 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "rgdf2" + } + uid: 5 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + string_val { + src { + end_column: 70 + end_line: 38 + file: 2 + start_column: 14 + start_line: 38 + } + v: "str" + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 5 + } + } + src { + end_column: 70 + end_line: 38 + file: 2 + start_column: 14 + start_line: 38 + } + task_description: "Summarize strings across all rows" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df3" + } + uid: 6 + } +} +client_ast_version: 1 +client_language { + python_language { + version { + label: "final" + major: 3 + minor: 9 + patch: 1 + } + } +} +client_version { + major: 1 + minor: 37 +} +id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" diff --git a/tests/integ/test_dataframe_ai.py b/tests/integ/test_dataframe_ai.py index a4c69760e0..95f60ef8ee 100644 --- a/tests/integ/test_dataframe_ai.py +++ b/tests/integ/test_dataframe_ai.py @@ -46,7 +46,7 @@ def test_dataframe_ai_complete_with_named_placeholders(session): assert result_df.columns == ["REVIEW", "RATING", "CATEGORY", "SENTIMENT_ANALYSIS"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 3 for row in results: @@ -78,7 +78,7 @@ def test_dataframe_ai_complete_with_positional_placeholders(session): assert result_df.columns == ["TOPIC", "CATEGORY", "DEFINITION"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 3 for row in results: @@ -108,7 +108,7 @@ def test_dataframe_ai_complete_default_output_column(session): # Check that default column name is used assert "AI_COMPLETE_OUTPUT" in result_df.schema.names - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 2 for row in results: assert row["AI_COMPLETE_OUTPUT"] is not None @@ -119,7 +119,9 @@ def test_dataframe_ai_complete_error_handling(session): # Test missing model parameter df = session.create_dataframe([["test"]], schema=["text"]) - with pytest.raises(ValueError, match="model must be specified"): + with pytest.raises( + TypeError, match="missing 1 required positional argument: 'model'" + ): df.ai.complete( prompt="Test {text}", input_columns={"text": col("text")} @@ -157,7 +159,7 @@ def test_dataframe_ai_filter_simple_text(session): ) # Check that we get some results (should be the positive ones) - results = positive_df.collect(_emit_ast=False) + results = positive_df.collect() assert len(results) >= 1 # At least some positive reviews should be found assert len(results) <= 5 # Not more than the original count @@ -191,7 +193,7 @@ def test_dataframe_ai_filter_with_named_placeholders(session): ) # Check results - results = european_df.collect(_emit_ast=False) + results = european_df.collect() assert len(results) >= 1 # Should find at least one European country # Verify the results are European countries @@ -221,7 +223,7 @@ def test_dataframe_ai_filter_with_positional_placeholders(session): ) # Check results - results = programming_df.collect(_emit_ast=False) + results = programming_df.collect() assert len(results) >= 1 # Should find at least one programming language # Verify the results contain programming languages @@ -265,7 +267,7 @@ def test_dataframe_ai_agg_basic(session): ) # Verify results - results = summary_df.collect(_emit_ast=False) + results = summary_df.collect() assert len(results) == 1 # Should be a single aggregated row assert results[0]["SUMMARY"] is not None assert len(results[0]["SUMMARY"]) > 10 # Should have meaningful content @@ -276,7 +278,7 @@ def test_dataframe_ai_agg_basic(session): ) # Verify results - results = summary_df.collect(_emit_ast=False) + results = summary_df.collect() assert len(results) == 1 # Should be a single aggregated row assert results[0]["AI_AGG_OUTPUT"] is not None assert len(results[0]["AI_AGG_OUTPUT"]) > 10 # Should have meaningful content @@ -298,7 +300,7 @@ def test_dataframe_ai_agg_error_handling(session): df.ai.agg( task_description="Summarize the text", input_column=col("invalid"), # Invalid column name - ).collect(_emit_ast=False) + ).collect() def test_grouped_dataframe_ai_agg(session): @@ -329,7 +331,7 @@ def test_grouped_dataframe_ai_agg(session): task_description="Create an overall summary of all customer reviews", ) - count = global_summary_df.count(_emit_ast=False) + count = global_summary_df.count() assert count == 1 # Single row for global aggregation # Test 2: Group by single column with string expr @@ -343,7 +345,7 @@ def test_grouped_dataframe_ai_agg(session): .sort(col("CATEGORY").asc()) # Chain sort operation ) - category_results = category_summary_df.collect(_emit_ast=False) + category_results = category_summary_df.collect() assert len(category_results) == 3 # 3 categories after filtering out toys categories = [row["CATEGORY"] for row in category_results] assert categories == ["books", "clothing", "electronics"] # Alphabetically sorted @@ -365,7 +367,7 @@ def test_grouped_dataframe_ai_agg(session): .select("QUALITY_LEVEL") # Chain select to keep only grouping column ) - quality_results = quality_summary_df.collect(_emit_ast=False) + quality_results = quality_summary_df.collect() assert len(quality_results) == 2 # Two quality levels: high and low assert all( len(row.as_dict()) == 1 for row in quality_results @@ -385,7 +387,7 @@ def test_grouped_dataframe_ai_agg(session): .limit(2) # Limit to top 2 results ) - multi_results = multi_group_df.collect(_emit_ast=False) + multi_results = multi_group_df.collect() assert len(multi_results) == 2 # Limited to 2 results # Verify the results are high quality and sorted correctly @@ -411,7 +413,7 @@ def test_grouped_dataframe_ai_agg(session): .sort("PRODUCT_TYPE") ) - complex_results = complex_chain_df.collect(_emit_ast=False) + complex_results = complex_chain_df.collect() assert len(complex_results) == 2 # Only electronics and books assert "PRODUCT_TYPE" in complex_results[0].as_dict() assert "CATEGORY" not in complex_results[0].as_dict() @@ -444,7 +446,7 @@ def test_dataframe_ai_classify_basic(session): assert result_df.columns == ["TEXT", "CATEGORY"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # Verify some expected classifications @@ -487,7 +489,7 @@ def test_dataframe_ai_classify_multi_label(session): assert result_df.columns == ["TEXT", "CATEGORIES", "TOPICS"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # First text mentions both travel and cooking @@ -539,7 +541,7 @@ def test_dataframe_ai_classify_with_examples(session): assert result_df.columns == ["FEEDBACK", "SENTIMENT"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # Check sentiment classifications @@ -591,7 +593,7 @@ def test_dataframe_ai_similarity_basic(session): assert result_df.columns == ["TEXT1", "TEXT2", "SIMILARITY_SCORE"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # Verify similarity scores are in range @@ -630,7 +632,7 @@ def test_dataframe_ai_similarity_default_output_column(session): # Check that default column name is used assert "AI_SIMILARITY_OUTPUT" in result_df.columns - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 2 for row in results: assert row["AI_SIMILARITY_OUTPUT"] is not None @@ -664,7 +666,7 @@ def test_dataframe_ai_similarity_with_custom_model(session): ] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 3 # With multilingual model, translations should have high similarity @@ -695,7 +697,7 @@ def test_dataframe_ai_similarity_error_handling(session): df.ai.similarity( input1="text", input2="nonexistent_column", - ).collect(_emit_ast=False) + ).collect() def test_dataframe_ai_similarity_with_nulls(session): @@ -717,7 +719,7 @@ def test_dataframe_ai_similarity_with_nulls(session): output_column="similarity", ) - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # Rows with NULLs should return NULL similarity scores @@ -752,7 +754,7 @@ def test_dataframe_ai_sentiment_basic(session): assert result_df.columns == ["REVIEW", "SENTIMENT"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 4 # Parse sentiment results @@ -802,7 +804,7 @@ def test_dataframe_ai_sentiment_with_categories(session): assert result_df.columns == ["REVIEW", "DETAILED_SENTIMENT"] # Collect and verify results - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 3 # Check first review sentiments @@ -970,9 +972,9 @@ def test_dataframe_ai_summarize_agg_basic(session): # Check schema and results assert summary_df.columns == ["REVIEWS_SUMMARY"] - assert summary_df.count(_emit_ast=False) == 1 + assert summary_df.count() == 1 - results = summary_df.collect(_emit_ast=False) + results = summary_df.collect() assert len(results) == 1 assert results[0]["REVIEWS_SUMMARY"] is not None assert len(results[0]["REVIEWS_SUMMARY"]) > 10 # Should have meaningful content @@ -994,9 +996,9 @@ def test_dataframe_ai_summarize_agg_default_output_column(session): # Check that default column name is used assert "AI_SUMMARIZE_AGG_OUTPUT" in summary_df.columns - assert summary_df.count(_emit_ast=False) == 1 + assert summary_df.count() == 1 - results = summary_df.collect(_emit_ast=False) + results = summary_df.collect() assert results[0]["AI_SUMMARIZE_AGG_OUTPUT"] is not None @@ -1030,7 +1032,7 @@ def test_dataframe_ai_transcribe_basic(session, resources_path): assert result_df.columns == ["AUDIO_PATH", "TRANSCRIPT"] - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 1 data = json.loads(results[0]["TRANSCRIPT"]) if results[0]["TRANSCRIPT"] else {} assert isinstance(data, dict) @@ -1057,7 +1059,7 @@ def test_dataframe_ai_transcribe_default_output_column(session, resources_path): ) assert "AI_TRANSCRIBE_OUTPUT" in result_df.columns - results = result_df.collect(_emit_ast=False) + results = result_df.collect() data = ( json.loads(results[0]["AI_TRANSCRIBE_OUTPUT"]) if results[0]["AI_TRANSCRIBE_OUTPUT"] @@ -1087,7 +1089,7 @@ def test_dataframe_ai_parse_document_basic(session, resources_path): assert result_df.columns == ["FILE_PATH", "PARSED"] - results = result_df.collect(_emit_ast=False) + results = result_df.collect() data = json.loads(results[0]["PARSED"]) if results[0]["PARSED"] else {} assert isinstance(data, dict) assert "content" in data and isinstance(data["content"], str) @@ -1112,7 +1114,7 @@ def test_dataframe_ai_parse_document_default_output_column(session, resources_pa ) assert "AI_PARSE_DOCUMENT_OUTPUT" in result_df.columns - results = result_df.collect(_emit_ast=False) + results = result_df.collect() data = ( json.loads(results[0]["AI_PARSE_DOCUMENT_OUTPUT"]) if results[0]["AI_PARSE_DOCUMENT_OUTPUT"] @@ -1268,7 +1270,7 @@ def test_dataframe_ai_count_tokens_default_output_column(session): prompt="text", ) - results = result_df.collect(_emit_ast=False) + results = result_df.collect() assert len(results) == 1 assert results[0]["COUNT_TOKENS_OUTPUT"] == 5 @@ -1318,7 +1320,7 @@ def test_dataframe_ai_split_text_markdown_header_basic(session): ) # Verify results - results = result_df.select("chunks").collect(_emit_ast=False) + results = result_df.select("chunks").collect() chunks = json.loads(results[0][0]) assert chunks == [ { @@ -1374,9 +1376,7 @@ def test_dataframe_ai_split_text_markdown_header_default_output(session): chunk_size=20, ) - results = result_df.select("SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT").collect( - _emit_ast=False - ) + results = result_df.select("SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT").collect() chunks = json.loads(results[0][0]) assert chunks == [{"chunk": "Content", "headers": {"h1": "Header"}}] @@ -1419,7 +1419,7 @@ def test_dataframe_ai_split_text_recursive_character_basic(session): ) # Verify results - results = result_df.select("chunks").collect(_emit_ast=False) + results = result_df.select("chunks").collect() chunks = json.loads(results[0][0]) assert chunks == [ "This is a long document with multiple sentences.", @@ -1471,7 +1471,7 @@ def hello(): ) # Verify results - results = result_df.select("md_chunks").collect(_emit_ast=False) + results = result_df.select("md_chunks").collect() chunks = json.loads(results[0][0]) assert chunks == [ "# Main Title", @@ -1519,7 +1519,7 @@ def function_three(): ) # Verify results - results = result_df.select("code_chunks").collect(_emit_ast=False) + results = result_df.select("code_chunks").collect() chunks = json.loads(results[0][0]) assert chunks == [ "def function_one():", @@ -1546,9 +1546,7 @@ def test_dataframe_ai_split_text_recursive_character_default_output(session): ) # Verify results - results = result_df.select("SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT").collect( - _emit_ast=False - ) + results = result_df.select("SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT").collect() chunks = json.loads(results[0][0]) assert chunks == ["Short text", "to split"]