diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f67c53015..8c357fe043 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,34 @@ # Release History +## 1.40.0 (YYYY-MM-DD) + +### Snowpark Python API Updates + +#### New Features + ## 1.39.0 (YYYY-MM-DD) ### Snowpark Python API Updates #### New Features +- Added support for unstructured data engineering in Snowpark, powered by Snowflake AISQL and Cortex functions: + - `DataFrame.ai.complete`: Generate per-row LLM completions from prompts built over columns and files. + - `DataFrame.ai.filter`: Keep rows where an AI classifier returns TRUE for the given predicate. + - `DataFrame.ai.agg`: Reduce a text column into one result using a natural-language task description. + - `RelationalGroupedDataFrame.ai_agg`: Perform the same natural-language aggregation per group. + - `DataFrame.ai.classify`: Assign single or multiple labels from given categories to text or images. + - `DataFrame.ai.similarity`: Compute cosine-based similarity scores between two columns via embeddings. + - `DataFrame.ai.sentiment`: Extract overall and aspect-level sentiment from text into JSON. + - `DataFrame.ai.embed`: Generate VECTOR embeddings for text or images using configurable models. + - `DataFrame.ai.summarize_agg`: Aggregate and produce a single comprehensive summary over many rows. + - `DataFrame.ai.transcribe`: Transcribe audio files to text with optional timestamps and speaker labels. + - `DataFrame.ai.parse_document`: OCR/layout-parse documents or images into structured JSON. + - `DataFrame.ai.extract`: Pull structured fields from text or files using a response schema. + - `DataFrame.ai.count_tokens`: Estimate token usage for a given model and input text per row. + - `DataFrame.ai.split_text_markdown_header`: Split Markdown into hierarchical header-aware chunks. + - `DataFrame.ai.split_text_recursive_character`: Split text into size-bounded chunks using recursive separators. 
+ - `DataFrameReader.file`: Create a DataFrame containing all files from a stage as the FILE data type for downstream unstructured data processing. - Added a new datatype `YearMonthIntervalType` that allows users to create intervals for datetime operations. - Added a new function `interval_year_month_from_parts` that allows users to easily create `YearMonthIntervalType` without using SQL. - Added a new datatype `DayTimeIntervalType` that allows users to create intervals for datetime operations. diff --git a/docs/source/snowpark/dataframe.rst b/docs/source/snowpark/dataframe.rst index c458a0dc2a..43d9858aa4 100644 --- a/docs/source/snowpark/dataframe.rst +++ b/docs/source/snowpark/dataframe.rst @@ -13,6 +13,7 @@ DataFrame DataFrameNaFunctions DataFrameStatFunctions DataFrameAnalyticsFunctions + DataFrameAIFunctions .. rubric:: Methods @@ -120,6 +121,20 @@ DataFrame DataFrameAnalyticsFunctions.compute_lag DataFrameAnalyticsFunctions.compute_lead DataFrameAnalyticsFunctions.time_series_agg + DataFrameAIFunctions.agg + DataFrameAIFunctions.classify + DataFrameAIFunctions.complete + DataFrameAIFunctions.count_tokens + DataFrameAIFunctions.embed + DataFrameAIFunctions.extract + DataFrameAIFunctions.filter + DataFrameAIFunctions.parse_document + DataFrameAIFunctions.sentiment + DataFrameAIFunctions.similarity + DataFrameAIFunctions.split_text_markdown_header + DataFrameAIFunctions.split_text_recursive_character + DataFrameAIFunctions.summarize_agg + DataFrameAIFunctions.transcribe dataframe.map dataframe.map_in_pandas @@ -133,6 +148,7 @@ DataFrame .. 
autosummary:: :toctree: api/ + DataFrame.ai DataFrame.columns DataFrame.na DataFrame.queries diff --git a/docs/source/snowpark/grouping.rst b/docs/source/snowpark/grouping.rst index 6e1168018e..68a4c86a30 100644 --- a/docs/source/snowpark/grouping.rst +++ b/docs/source/snowpark/grouping.rst @@ -18,6 +18,7 @@ Grouping :toctree: api/ RelationalGroupedDataFrame.agg + RelationalGroupedDataFrame.ai_agg RelationalGroupedDataFrame.apply_in_pandas RelationalGroupedDataFrame.applyInPandas RelationalGroupedDataFrame.avg diff --git a/src/snowflake/snowpark/__init__.py b/src/snowflake/snowpark/__init__.py index dd9fa994d3..cbd0df6227 100644 --- a/src/snowflake/snowpark/__init__.py +++ b/src/snowflake/snowpark/__init__.py @@ -22,6 +22,7 @@ "DataFrameStatFunctions", "DataFrameAnalyticsFunctions", "DataFrameNaFunctions", + "DataFrameAIFunctions", "DataFrameWriter", "DataFrameReader", "GroupingSets", @@ -54,6 +55,7 @@ from snowflake.snowpark.column import CaseExpr, Column from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.dataframe_ai_functions import DataFrameAIFunctions from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_reader import DataFrameReader diff --git a/src/snowflake/snowpark/_internal/proto/ast.proto b/src/snowflake/snowpark/_internal/proto/ast.proto index 0f8b56b268..9ebf65eae3 100644 --- a/src/snowflake/snowpark/_internal/proto/ast.proto +++ b/src/snowflake/snowpark/_internal/proto/ast.proto @@ -48,6 +48,14 @@ message Tuple_String_String { string _2 = 2; } +// dataframe-ai.ir:5 +message AiSplitTextRecursiveFormat { + oneof variant { + bool ai_split_text_recursive_format_markdown = 1; + bool ai_split_text_recursive_format_none = 2; + } +} + // fn.ir:25 message Callable { int64 id = 1; @@ -739,6 +747,139 @@ message 
DataframeAgg { SrcPosition src = 3; } +// dataframe-ai.ir:25 +message DataframeAiAgg { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + SrcPosition src = 4; + string task_description = 5; +} + +// dataframe-ai.ir:32 +message DataframeAiClassify { + Expr categories = 1; + Expr df = 2; + Expr input_column = 3; + repeated Tuple_String_Expr kwargs = 4; + google.protobuf.StringValue output_column = 5; + SrcPosition src = 6; +} + +// dataframe-ai.ir:10 +message DataframeAiComplete { + Expr df = 1; + Expr input_columns = 2; + string model = 3; + repeated Tuple_String_Expr model_parameters = 4; + google.protobuf.StringValue output_column = 5; + string prompt = 6; + SrcPosition src = 7; +} + +// dataframe-ai.ir:89 +message DataframeAiCountTokens { + Expr df = 1; + string model = 2; + google.protobuf.StringValue output_column = 3; + Expr prompt = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:55 +message DataframeAiEmbed { + Expr df = 1; + Expr input_column = 2; + string model = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:82 +message DataframeAiExtract { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + Expr response_format = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:19 +message DataframeAiFilter { + Expr df = 1; + Expr input_columns = 2; + string predicate = 3; + SrcPosition src = 4; +} + +// dataframe-ai.ir:75 +message DataframeAiParseDocument { + Expr df = 1; + Expr input_column = 2; + repeated Tuple_String_Expr kwargs = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:48 +message DataframeAiSentiment { + Expr categories = 1; + Expr df = 2; + Expr input_column = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + +// dataframe-ai.ir:40 +message DataframeAiSimilarity { + Expr df = 1; + Expr input1 = 2; + Expr input2 = 3; + repeated 
Tuple_String_Expr kwargs = 4; + google.protobuf.StringValue output_column = 5; + SrcPosition src = 6; +} + +// dataframe-ai.ir:96 +message DataframeAiSplitTextMarkdownHeader { + Expr chunk_size = 1; + Expr df = 2; + Expr headers_to_split_on = 3; + google.protobuf.StringValue output_column = 4; + Expr overlap = 5; + SrcPosition src = 6; + Expr text_to_split = 7; +} + +// dataframe-ai.ir:105 +message DataframeAiSplitTextRecursiveCharacter { + Expr chunk_size = 1; + Expr df = 2; + AiSplitTextRecursiveFormat format = 3; + google.protobuf.StringValue output_column = 4; + Expr overlap = 5; + Expr separators = 6; + SrcPosition src = 7; + Expr text_to_split = 8; +} + +// dataframe-ai.ir:62 +message DataframeAiSummarizeAgg { + Expr df = 1; + Expr input_column = 2; + google.protobuf.StringValue output_column = 3; + SrcPosition src = 4; +} + +// dataframe-ai.ir:68 +message DataframeAiTranscribe { + Expr df = 1; + Expr input_column = 2; + repeated Tuple_String_Expr kwargs = 3; + google.protobuf.StringValue output_column = 4; + SrcPosition src = 5; +} + // dataframe.ir:180 message DataframeAlias { Expr df = 1; @@ -1386,144 +1527,159 @@ message Expr { ColumnWithinGroup column_within_group = 44; CreateDataframe create_dataframe = 45; DataframeAgg dataframe_agg = 46; - DataframeAlias dataframe_alias = 47; - DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 48; - DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 49; - DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 50; - DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 51; - DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 52; - DataframeCacheResult dataframe_cache_result = 53; - DataframeCol dataframe_col = 54; - DataframeColIlike dataframe_col_ilike = 55; - DataframeCollect dataframe_collect = 56; - DataframeCopyIntoTable dataframe_copy_into_table = 57; - DataframeCount dataframe_count = 58; - DataframeCreateOrReplaceDynamicTable 
dataframe_create_or_replace_dynamic_table = 59; - DataframeCreateOrReplaceView dataframe_create_or_replace_view = 60; - DataframeCrossJoin dataframe_cross_join = 61; - DataframeCube dataframe_cube = 62; - DataframeDescribe dataframe_describe = 63; - DataframeDistinct dataframe_distinct = 64; - DataframeDrop dataframe_drop = 65; - DataframeDropDuplicates dataframe_drop_duplicates = 66; - DataframeExcept dataframe_except = 67; - DataframeFilter dataframe_filter = 68; - DataframeFirst dataframe_first = 69; - DataframeFlatten dataframe_flatten = 70; - DataframeGroupBy dataframe_group_by = 71; - DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 72; - DataframeIntersect dataframe_intersect = 73; - DataframeJoin dataframe_join = 74; - DataframeJoinTableFunction dataframe_join_table_function = 75; - DataframeLimit dataframe_limit = 76; - DataframeNaDrop_Python dataframe_na_drop__python = 77; - DataframeNaDrop_Scala dataframe_na_drop__scala = 78; - DataframeNaFill dataframe_na_fill = 79; - DataframeNaReplace dataframe_na_replace = 80; - DataframeNaturalJoin dataframe_natural_join = 81; - DataframePivot dataframe_pivot = 82; - DataframeRandomSplit dataframe_random_split = 83; - DataframeReader dataframe_reader = 84; - DataframeRef dataframe_ref = 85; - DataframeRename dataframe_rename = 86; - DataframeRollup dataframe_rollup = 87; - DataframeSample dataframe_sample = 88; - DataframeSelect dataframe_select = 89; - DataframeShow dataframe_show = 90; - DataframeSort dataframe_sort = 91; - DataframeStatApproxQuantile dataframe_stat_approx_quantile = 92; - DataframeStatCorr dataframe_stat_corr = 93; - DataframeStatCov dataframe_stat_cov = 94; - DataframeStatCrossTab dataframe_stat_cross_tab = 95; - DataframeStatSampleBy dataframe_stat_sample_by = 96; - DataframeToDf dataframe_to_df = 97; - DataframeToLocalIterator dataframe_to_local_iterator = 98; - DataframeToPandas dataframe_to_pandas = 99; - DataframeToPandasBatches dataframe_to_pandas_batches = 100; - 
DataframeUnion dataframe_union = 101; - DataframeUnpivot dataframe_unpivot = 102; - DataframeWithColumn dataframe_with_column = 103; - DataframeWithColumnRenamed dataframe_with_column_renamed = 104; - DataframeWithColumns dataframe_with_columns = 105; - DataframeWriter dataframe_writer = 106; - DatatypeVal datatype_val = 107; - Directory directory = 108; - Div div = 109; - Eq eq = 110; - Flatten flatten = 111; - Float64Val float64_val = 112; - FnRef fn_ref = 113; - Generator generator = 114; - Geq geq = 115; - GroupingSets grouping_sets = 116; - Gt gt = 117; - IndirectTableFnIdRef indirect_table_fn_id_ref = 118; - IndirectTableFnNameRef indirect_table_fn_name_ref = 119; - Int64Val int64_val = 120; - Leq leq = 121; - ListVal list_val = 122; - Lt lt = 123; - MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 124; - MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 125; - MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 126; - Mod mod = 127; - Mul mul = 128; - Neg neg = 129; - Neq neq = 130; - Not not = 131; - NullVal null_val = 132; - ObjectGetItem object_get_item = 133; - Or or = 134; - Pow pow = 135; - PythonDateVal python_date_val = 136; - PythonTimeVal python_time_val = 137; - PythonTimestampVal python_timestamp_val = 138; - Range range = 139; - ReadAvro read_avro = 140; - ReadCsv read_csv = 141; - ReadDirectory read_directory = 142; - ReadJson read_json = 143; - ReadLoad read_load = 144; - ReadOrc read_orc = 145; - ReadParquet read_parquet = 146; - ReadTable read_table = 147; - ReadXml read_xml = 148; - RedactedConst redacted_const = 149; - RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 150; - RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 151; - RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 152; - RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 153; - RelationalGroupedDataframeRef 
relational_grouped_dataframe_ref = 154; - Row row = 155; - SeqMapVal seq_map_val = 156; - SessionTableFunction session_table_function = 157; - Sql sql = 158; - SqlExpr sql_expr = 159; - StoredProcedure stored_procedure = 160; - StringVal string_val = 161; - Sub sub = 162; - Table table = 163; - TableDelete table_delete = 164; - TableDropTable table_drop_table = 165; - TableFnCallAlias table_fn_call_alias = 166; - TableFnCallOver table_fn_call_over = 167; - TableMerge table_merge = 168; - TableSample table_sample = 169; - TableUpdate table_update = 170; - ToSnowparkPandas to_snowpark_pandas = 171; - TruncatedExpr truncated_expr = 172; - TupleVal tuple_val = 173; - Udaf udaf = 174; - Udf udf = 175; - Udtf udtf = 176; - WriteCopyIntoLocation write_copy_into_location = 177; - WriteCsv write_csv = 178; - WriteInsertInto write_insert_into = 179; - WriteJson write_json = 180; - WritePandas write_pandas = 181; - WriteParquet write_parquet = 182; - WriteSave write_save = 183; - WriteTable write_table = 184; + DataframeAiAgg dataframe_ai_agg = 47; + DataframeAiClassify dataframe_ai_classify = 48; + DataframeAiComplete dataframe_ai_complete = 49; + DataframeAiCountTokens dataframe_ai_count_tokens = 50; + DataframeAiEmbed dataframe_ai_embed = 51; + DataframeAiExtract dataframe_ai_extract = 52; + DataframeAiFilter dataframe_ai_filter = 53; + DataframeAiParseDocument dataframe_ai_parse_document = 54; + DataframeAiSentiment dataframe_ai_sentiment = 55; + DataframeAiSimilarity dataframe_ai_similarity = 56; + DataframeAiSplitTextMarkdownHeader dataframe_ai_split_text_markdown_header = 57; + DataframeAiSplitTextRecursiveCharacter dataframe_ai_split_text_recursive_character = 58; + DataframeAiSummarizeAgg dataframe_ai_summarize_agg = 59; + DataframeAiTranscribe dataframe_ai_transcribe = 60; + DataframeAlias dataframe_alias = 61; + DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 62; + DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 63; + 
DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 64; + DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 65; + DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 66; + DataframeCacheResult dataframe_cache_result = 67; + DataframeCol dataframe_col = 68; + DataframeColIlike dataframe_col_ilike = 69; + DataframeCollect dataframe_collect = 70; + DataframeCopyIntoTable dataframe_copy_into_table = 71; + DataframeCount dataframe_count = 72; + DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 73; + DataframeCreateOrReplaceView dataframe_create_or_replace_view = 74; + DataframeCrossJoin dataframe_cross_join = 75; + DataframeCube dataframe_cube = 76; + DataframeDescribe dataframe_describe = 77; + DataframeDistinct dataframe_distinct = 78; + DataframeDrop dataframe_drop = 79; + DataframeDropDuplicates dataframe_drop_duplicates = 80; + DataframeExcept dataframe_except = 81; + DataframeFilter dataframe_filter = 82; + DataframeFirst dataframe_first = 83; + DataframeFlatten dataframe_flatten = 84; + DataframeGroupBy dataframe_group_by = 85; + DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 86; + DataframeIntersect dataframe_intersect = 87; + DataframeJoin dataframe_join = 88; + DataframeJoinTableFunction dataframe_join_table_function = 89; + DataframeLimit dataframe_limit = 90; + DataframeNaDrop_Python dataframe_na_drop__python = 91; + DataframeNaDrop_Scala dataframe_na_drop__scala = 92; + DataframeNaFill dataframe_na_fill = 93; + DataframeNaReplace dataframe_na_replace = 94; + DataframeNaturalJoin dataframe_natural_join = 95; + DataframePivot dataframe_pivot = 96; + DataframeRandomSplit dataframe_random_split = 97; + DataframeReader dataframe_reader = 98; + DataframeRef dataframe_ref = 99; + DataframeRename dataframe_rename = 100; + DataframeRollup dataframe_rollup = 101; + DataframeSample dataframe_sample = 102; + DataframeSelect dataframe_select = 103; + DataframeShow 
dataframe_show = 104; + DataframeSort dataframe_sort = 105; + DataframeStatApproxQuantile dataframe_stat_approx_quantile = 106; + DataframeStatCorr dataframe_stat_corr = 107; + DataframeStatCov dataframe_stat_cov = 108; + DataframeStatCrossTab dataframe_stat_cross_tab = 109; + DataframeStatSampleBy dataframe_stat_sample_by = 110; + DataframeToDf dataframe_to_df = 111; + DataframeToLocalIterator dataframe_to_local_iterator = 112; + DataframeToPandas dataframe_to_pandas = 113; + DataframeToPandasBatches dataframe_to_pandas_batches = 114; + DataframeUnion dataframe_union = 115; + DataframeUnpivot dataframe_unpivot = 116; + DataframeWithColumn dataframe_with_column = 117; + DataframeWithColumnRenamed dataframe_with_column_renamed = 118; + DataframeWithColumns dataframe_with_columns = 119; + DataframeWriter dataframe_writer = 120; + DatatypeVal datatype_val = 121; + Directory directory = 122; + Div div = 123; + Eq eq = 124; + Flatten flatten = 125; + Float64Val float64_val = 126; + FnRef fn_ref = 127; + Generator generator = 128; + Geq geq = 129; + GroupingSets grouping_sets = 130; + Gt gt = 131; + IndirectTableFnIdRef indirect_table_fn_id_ref = 132; + IndirectTableFnNameRef indirect_table_fn_name_ref = 133; + Int64Val int64_val = 134; + Leq leq = 135; + ListVal list_val = 136; + Lt lt = 137; + MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 138; + MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 139; + MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 140; + Mod mod = 141; + Mul mul = 142; + Neg neg = 143; + Neq neq = 144; + Not not = 145; + NullVal null_val = 146; + ObjectGetItem object_get_item = 147; + Or or = 148; + Pow pow = 149; + PythonDateVal python_date_val = 150; + PythonTimeVal python_time_val = 151; + PythonTimestampVal python_timestamp_val = 152; + Range range = 153; + ReadAvro read_avro = 154; + ReadCsv read_csv = 155; + ReadDirectory read_directory = 156; + ReadJson read_json = 157; + ReadLoad read_load 
= 158; + ReadOrc read_orc = 159; + ReadParquet read_parquet = 160; + ReadTable read_table = 161; + ReadXml read_xml = 162; + RedactedConst redacted_const = 163; + RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 164; + RelationalGroupedDataframeAiAgg relational_grouped_dataframe_ai_agg = 165; + RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 166; + RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 167; + RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 168; + RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 169; + Row row = 170; + SeqMapVal seq_map_val = 171; + SessionTableFunction session_table_function = 172; + Sql sql = 173; + SqlExpr sql_expr = 174; + StoredProcedure stored_procedure = 175; + StringVal string_val = 176; + Sub sub = 177; + Table table = 178; + TableDelete table_delete = 179; + TableDropTable table_drop_table = 180; + TableFnCallAlias table_fn_call_alias = 181; + TableFnCallOver table_fn_call_over = 182; + TableMerge table_merge = 183; + TableSample table_sample = 184; + TableUpdate table_update = 185; + ToSnowparkPandas to_snowpark_pandas = 186; + TruncatedExpr truncated_expr = 187; + TupleVal tuple_val = 188; + Udaf udaf = 189; + Udf udf = 190; + Udtf udtf = 191; + WriteCopyIntoLocation write_copy_into_location = 192; + WriteCsv write_csv = 193; + WriteInsertInto write_insert_into = 194; + WriteJson write_json = 195; + WritePandas write_pandas = 196; + WriteParquet write_parquet = 197; + WriteSave write_save = 198; + WriteTable write_table = 199; } } @@ -1679,150 +1835,165 @@ message HasSrcPosition { ColumnWithinGroup column_within_group = 47; CreateDataframe create_dataframe = 48; DataframeAgg dataframe_agg = 49; - DataframeAlias dataframe_alias = 50; - DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 51; - DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 52; - DataframeAnalyticsCumulativeAgg 
dataframe_analytics_cumulative_agg = 53; - DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 54; - DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 55; - DataframeCacheResult dataframe_cache_result = 56; - DataframeCol dataframe_col = 57; - DataframeColIlike dataframe_col_ilike = 58; - DataframeCollect dataframe_collect = 59; - DataframeCopyIntoTable dataframe_copy_into_table = 60; - DataframeCount dataframe_count = 61; - DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 62; - DataframeCreateOrReplaceView dataframe_create_or_replace_view = 63; - DataframeCrossJoin dataframe_cross_join = 64; - DataframeCube dataframe_cube = 65; - DataframeDescribe dataframe_describe = 66; - DataframeDistinct dataframe_distinct = 67; - DataframeDrop dataframe_drop = 68; - DataframeDropDuplicates dataframe_drop_duplicates = 69; - DataframeExcept dataframe_except = 70; - DataframeFilter dataframe_filter = 71; - DataframeFirst dataframe_first = 72; - DataframeFlatten dataframe_flatten = 73; - DataframeGroupBy dataframe_group_by = 74; - DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 75; - DataframeIntersect dataframe_intersect = 76; - DataframeJoin dataframe_join = 77; - DataframeJoinTableFunction dataframe_join_table_function = 78; - DataframeLimit dataframe_limit = 79; - DataframeNaDrop_Python dataframe_na_drop__python = 80; - DataframeNaDrop_Scala dataframe_na_drop__scala = 81; - DataframeNaFill dataframe_na_fill = 82; - DataframeNaReplace dataframe_na_replace = 83; - DataframeNaturalJoin dataframe_natural_join = 84; - DataframePivot dataframe_pivot = 85; - DataframeRandomSplit dataframe_random_split = 86; - DataframeReader dataframe_reader = 87; - DataframeRef dataframe_ref = 88; - DataframeRename dataframe_rename = 89; - DataframeRollup dataframe_rollup = 90; - DataframeSample dataframe_sample = 91; - DataframeSelect dataframe_select = 92; - DataframeShow dataframe_show = 93; - DataframeSort 
dataframe_sort = 94; - DataframeStatApproxQuantile dataframe_stat_approx_quantile = 95; - DataframeStatCorr dataframe_stat_corr = 96; - DataframeStatCov dataframe_stat_cov = 97; - DataframeStatCrossTab dataframe_stat_cross_tab = 98; - DataframeStatSampleBy dataframe_stat_sample_by = 99; - DataframeToDf dataframe_to_df = 100; - DataframeToLocalIterator dataframe_to_local_iterator = 101; - DataframeToPandas dataframe_to_pandas = 102; - DataframeToPandasBatches dataframe_to_pandas_batches = 103; - DataframeUnion dataframe_union = 104; - DataframeUnpivot dataframe_unpivot = 105; - DataframeWithColumn dataframe_with_column = 106; - DataframeWithColumnRenamed dataframe_with_column_renamed = 107; - DataframeWithColumns dataframe_with_columns = 108; - DataframeWriter dataframe_writer = 109; - DatatypeVal datatype_val = 110; - Directory directory = 111; - Div div = 112; - Eq eq = 113; - Flatten flatten = 114; - Float64Val float64_val = 115; - FnRef fn_ref = 116; - Generator generator = 117; - Geq geq = 118; - GroupingSets grouping_sets = 119; - Gt gt = 120; - IndirectTableFnIdRef indirect_table_fn_id_ref = 121; - IndirectTableFnNameRef indirect_table_fn_name_ref = 122; - Int64Val int64_val = 123; - Leq leq = 124; - ListVal list_val = 125; - Lt lt = 126; - MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 127; - MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 128; - MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 129; - Mod mod = 130; - Mul mul = 131; - NameRef name_ref = 132; - Neg neg = 133; - Neq neq = 134; - Not not = 135; - NullVal null_val = 136; - ObjectGetItem object_get_item = 137; - Or or = 138; - Pow pow = 139; - PythonDateVal python_date_val = 140; - PythonTimeVal python_time_val = 141; - PythonTimestampVal python_timestamp_val = 142; - Range range = 143; - ReadAvro read_avro = 144; - ReadCsv read_csv = 145; - ReadDirectory read_directory = 146; - ReadJson read_json = 147; - ReadLoad read_load = 148; - ReadOrc 
read_orc = 149; - ReadParquet read_parquet = 150; - ReadTable read_table = 151; - ReadXml read_xml = 152; - RedactedConst redacted_const = 153; - RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 154; - RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 155; - RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 156; - RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 157; - RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 158; - Row row = 159; - SeqMapVal seq_map_val = 160; - SessionTableFunction session_table_function = 161; - Sql sql = 162; - SqlExpr sql_expr = 163; - StoredProcedure stored_procedure = 164; - StringVal string_val = 165; - Sub sub = 166; - Table table = 167; - TableDelete table_delete = 168; - TableDropTable table_drop_table = 169; - TableFnCallAlias table_fn_call_alias = 170; - TableFnCallOver table_fn_call_over = 171; - TableMerge table_merge = 172; - TableSample table_sample = 173; - TableUpdate table_update = 174; - ToSnowparkPandas to_snowpark_pandas = 175; - TruncatedExpr truncated_expr = 176; - TupleVal tuple_val = 177; - Udaf udaf = 178; - Udf udf = 179; - Udtf udtf = 180; - WindowSpecEmpty window_spec_empty = 181; - WindowSpecOrderBy window_spec_order_by = 182; - WindowSpecPartitionBy window_spec_partition_by = 183; - WindowSpecRangeBetween window_spec_range_between = 184; - WindowSpecRowsBetween window_spec_rows_between = 185; - WriteCopyIntoLocation write_copy_into_location = 186; - WriteCsv write_csv = 187; - WriteInsertInto write_insert_into = 188; - WriteJson write_json = 189; - WritePandas write_pandas = 190; - WriteParquet write_parquet = 191; - WriteSave write_save = 192; - WriteTable write_table = 193; + DataframeAiAgg dataframe_ai_agg = 50; + DataframeAiClassify dataframe_ai_classify = 51; + DataframeAiComplete dataframe_ai_complete = 52; + DataframeAiCountTokens dataframe_ai_count_tokens = 53; + DataframeAiEmbed 
dataframe_ai_embed = 54; + DataframeAiExtract dataframe_ai_extract = 55; + DataframeAiFilter dataframe_ai_filter = 56; + DataframeAiParseDocument dataframe_ai_parse_document = 57; + DataframeAiSentiment dataframe_ai_sentiment = 58; + DataframeAiSimilarity dataframe_ai_similarity = 59; + DataframeAiSplitTextMarkdownHeader dataframe_ai_split_text_markdown_header = 60; + DataframeAiSplitTextRecursiveCharacter dataframe_ai_split_text_recursive_character = 61; + DataframeAiSummarizeAgg dataframe_ai_summarize_agg = 62; + DataframeAiTranscribe dataframe_ai_transcribe = 63; + DataframeAlias dataframe_alias = 64; + DataframeAnalyticsComputeLag dataframe_analytics_compute_lag = 65; + DataframeAnalyticsComputeLead dataframe_analytics_compute_lead = 66; + DataframeAnalyticsCumulativeAgg dataframe_analytics_cumulative_agg = 67; + DataframeAnalyticsMovingAgg dataframe_analytics_moving_agg = 68; + DataframeAnalyticsTimeSeriesAgg dataframe_analytics_time_series_agg = 69; + DataframeCacheResult dataframe_cache_result = 70; + DataframeCol dataframe_col = 71; + DataframeColIlike dataframe_col_ilike = 72; + DataframeCollect dataframe_collect = 73; + DataframeCopyIntoTable dataframe_copy_into_table = 74; + DataframeCount dataframe_count = 75; + DataframeCreateOrReplaceDynamicTable dataframe_create_or_replace_dynamic_table = 76; + DataframeCreateOrReplaceView dataframe_create_or_replace_view = 77; + DataframeCrossJoin dataframe_cross_join = 78; + DataframeCube dataframe_cube = 79; + DataframeDescribe dataframe_describe = 80; + DataframeDistinct dataframe_distinct = 81; + DataframeDrop dataframe_drop = 82; + DataframeDropDuplicates dataframe_drop_duplicates = 83; + DataframeExcept dataframe_except = 84; + DataframeFilter dataframe_filter = 85; + DataframeFirst dataframe_first = 86; + DataframeFlatten dataframe_flatten = 87; + DataframeGroupBy dataframe_group_by = 88; + DataframeGroupByGroupingSets dataframe_group_by_grouping_sets = 89; + DataframeIntersect dataframe_intersect = 90; + 
DataframeJoin dataframe_join = 91; + DataframeJoinTableFunction dataframe_join_table_function = 92; + DataframeLimit dataframe_limit = 93; + DataframeNaDrop_Python dataframe_na_drop__python = 94; + DataframeNaDrop_Scala dataframe_na_drop__scala = 95; + DataframeNaFill dataframe_na_fill = 96; + DataframeNaReplace dataframe_na_replace = 97; + DataframeNaturalJoin dataframe_natural_join = 98; + DataframePivot dataframe_pivot = 99; + DataframeRandomSplit dataframe_random_split = 100; + DataframeReader dataframe_reader = 101; + DataframeRef dataframe_ref = 102; + DataframeRename dataframe_rename = 103; + DataframeRollup dataframe_rollup = 104; + DataframeSample dataframe_sample = 105; + DataframeSelect dataframe_select = 106; + DataframeShow dataframe_show = 107; + DataframeSort dataframe_sort = 108; + DataframeStatApproxQuantile dataframe_stat_approx_quantile = 109; + DataframeStatCorr dataframe_stat_corr = 110; + DataframeStatCov dataframe_stat_cov = 111; + DataframeStatCrossTab dataframe_stat_cross_tab = 112; + DataframeStatSampleBy dataframe_stat_sample_by = 113; + DataframeToDf dataframe_to_df = 114; + DataframeToLocalIterator dataframe_to_local_iterator = 115; + DataframeToPandas dataframe_to_pandas = 116; + DataframeToPandasBatches dataframe_to_pandas_batches = 117; + DataframeUnion dataframe_union = 118; + DataframeUnpivot dataframe_unpivot = 119; + DataframeWithColumn dataframe_with_column = 120; + DataframeWithColumnRenamed dataframe_with_column_renamed = 121; + DataframeWithColumns dataframe_with_columns = 122; + DataframeWriter dataframe_writer = 123; + DatatypeVal datatype_val = 124; + Directory directory = 125; + Div div = 126; + Eq eq = 127; + Flatten flatten = 128; + Float64Val float64_val = 129; + FnRef fn_ref = 130; + Generator generator = 131; + Geq geq = 132; + GroupingSets grouping_sets = 133; + Gt gt = 134; + IndirectTableFnIdRef indirect_table_fn_id_ref = 135; + IndirectTableFnNameRef indirect_table_fn_name_ref = 136; + Int64Val int64_val = 137; + 
Leq leq = 138; + ListVal list_val = 139; + Lt lt = 140; + MergeDeleteWhenMatchedClause merge_delete_when_matched_clause = 141; + MergeInsertWhenNotMatchedClause merge_insert_when_not_matched_clause = 142; + MergeUpdateWhenMatchedClause merge_update_when_matched_clause = 143; + Mod mod = 144; + Mul mul = 145; + NameRef name_ref = 146; + Neg neg = 147; + Neq neq = 148; + Not not = 149; + NullVal null_val = 150; + ObjectGetItem object_get_item = 151; + Or or = 152; + Pow pow = 153; + PythonDateVal python_date_val = 154; + PythonTimeVal python_time_val = 155; + PythonTimestampVal python_timestamp_val = 156; + Range range = 157; + ReadAvro read_avro = 158; + ReadCsv read_csv = 159; + ReadDirectory read_directory = 160; + ReadJson read_json = 161; + ReadLoad read_load = 162; + ReadOrc read_orc = 163; + ReadParquet read_parquet = 164; + ReadTable read_table = 165; + ReadXml read_xml = 166; + RedactedConst redacted_const = 167; + RelationalGroupedDataframeAgg relational_grouped_dataframe_agg = 168; + RelationalGroupedDataframeAiAgg relational_grouped_dataframe_ai_agg = 169; + RelationalGroupedDataframeApplyInPandas relational_grouped_dataframe_apply_in_pandas = 170; + RelationalGroupedDataframeBuiltin relational_grouped_dataframe_builtin = 171; + RelationalGroupedDataframePivot relational_grouped_dataframe_pivot = 172; + RelationalGroupedDataframeRef relational_grouped_dataframe_ref = 173; + Row row = 174; + SeqMapVal seq_map_val = 175; + SessionTableFunction session_table_function = 176; + Sql sql = 177; + SqlExpr sql_expr = 178; + StoredProcedure stored_procedure = 179; + StringVal string_val = 180; + Sub sub = 181; + Table table = 182; + TableDelete table_delete = 183; + TableDropTable table_drop_table = 184; + TableFnCallAlias table_fn_call_alias = 185; + TableFnCallOver table_fn_call_over = 186; + TableMerge table_merge = 187; + TableSample table_sample = 188; + TableUpdate table_update = 189; + ToSnowparkPandas to_snowpark_pandas = 190; + TruncatedExpr truncated_expr 
def create_prompt_column_from_template(
    template: str,
    placeholder_to_column: Union[
        Dict[str, "snowflake.snowpark.Column"], List["snowflake.snowpark.Column"]
    ],
    _emit_ast: bool = True,
) -> "snowflake.snowpark.Column":
    """
    Creates a prompt Column object from a template string with placeholders.

    Named placeholders (dict input) are rewritten to positional placeholders,
    numbered by order of first appearance in the template, and the columns are
    passed to ``prompt`` in that same order. Positional placeholders (list input)
    are passed through unchanged; if the template contains no positional
    placeholder at all, ``{0} {1} ...`` is appended so every column is referenced.

    Args:
        template: A string containing placeholders (either named like {name} or positional like {0})
        placeholder_to_column: Either:
            - A dict mapping placeholder names to Column objects
            - A list of Column objects for positional placeholders
        _emit_ast: Whether to emit AST

    Returns:
        A prompt Column object

    Raises:
        ValueError: If a dict key is never used in the template, if the template
            names a placeholder missing from the dict, if the number of distinct
            positional placeholders differs from the number of columns, or if a
            positional placeholder index is not smaller than the number of columns.
        TypeError: If ``placeholder_to_column`` is neither a dict nor a list.
    """
    if isinstance(placeholder_to_column, dict):
        named_pattern = re.compile(r"\{([^{}]+)\}")
        names_in_template = named_pattern.findall(template)

        # Every provided column must be referenced by the template.
        unused_columns = set(placeholder_to_column.keys()) - set(names_in_template)
        if unused_columns:
            raise ValueError(
                f"The following column placeholders were provided but not used in the template: {unused_columns}"
            )

        # Every placeholder in the template must have a column.
        for placeholder_name in names_in_template:
            if placeholder_name not in placeholder_to_column:
                raise ValueError(
                    f"Placeholder '{{{placeholder_name}}}' in template not found in provided columns. "
                    f"Available placeholders: {list(placeholder_to_column.keys())}"
                )

        # dict.fromkeys keeps first-appearance order while dropping repeats, so a
        # placeholder used several times maps to a single positional index.
        seen_placeholders = list(dict.fromkeys(names_in_template))
        positions = {name: index for index, name in enumerate(seen_placeholders)}
        final_template = named_pattern.sub(
            lambda match: "{" + str(positions[match.group(1)]) + "}", template
        )
        columns = [placeholder_to_column[name] for name in seen_placeholders]

    elif isinstance(placeholder_to_column, list):
        positional_placeholders = set(re.findall(r"\{(\d+)\}", template))
        num_placeholders = len(positional_placeholders)
        num_columns = len(placeholder_to_column)
        final_template = template
        columns = placeholder_to_column

        if num_placeholders == 0:
            # No placeholders found - auto-append them at the end. Nothing is
            # appended for an empty column list, so the template gains no stray
            # trailing space.
            if num_columns:
                placeholders_to_add = " ".join(f"{{{i}}}" for i in range(num_columns))
                final_template = f"{template} {placeholders_to_add}"
        elif num_placeholders != num_columns:
            # Sort numerically so that e.g. {2} is listed before {10}.
            raise ValueError(
                f"Number of positional placeholders ({num_placeholders}) does not match "
                f"number of columns provided ({num_columns}). "
                f"Found placeholders: {{{', '.join(sorted(positional_placeholders, key=int))}}}"
            )
        else:
            # Matching counts are not enough: "{1}" with a single column has the
            # right count but refers to a column that does not exist.
            out_of_range = sorted(
                (p for p in positional_placeholders if int(p) >= num_columns),
                key=int,
            )
            if out_of_range:
                found = ", ".join("{" + p + "}" for p in out_of_range)
                raise ValueError(
                    f"Positional placeholder(s) {found} are out of range: "
                    f"{num_columns} column(s) were provided, so valid placeholders are "
                    f"{{0}} to {{{num_columns - 1}}}."
                )

    else:
        raise TypeError(
            "placeholder_to_column must be a list of Columns or a dict mapping placeholder names to Columns"
        )

    # Imported lazily (as in the original) and only once the input is known to be
    # valid, so argument errors surface without loading the functions module.
    from snowflake.snowpark.functions import prompt

    return prompt(final_template, *columns, _emit_ast=_emit_ast)
+ """ + return self._ai + @property def session(self) -> "snowflake.snowpark.Session": """ diff --git a/src/snowflake/snowpark/dataframe_ai_functions.py b/src/snowflake/snowpark/dataframe_ai_functions.py new file mode 100644 index 0000000000..e06f8fcad2 --- /dev/null +++ b/src/snowflake/snowpark/dataframe_ai_functions.py @@ -0,0 +1,1858 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +from snowflake.snowpark._internal.utils import ( + create_prompt_column_from_template, + experimental, + publicapi, +) +from snowflake.snowpark._internal.ast.utils import ( + build_expr_from_python_val, + build_expr_from_snowpark_column_or_col_name, + build_expr_from_snowpark_column_or_python_val, + with_src_position, +) +from snowflake.snowpark._internal.type_utils import ColumnOrName +from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_lit +from snowflake.snowpark.functions import ( + ai_complete, + ai_filter, + ai_agg, + ai_classify, + ai_extract, + ai_similarity, + ai_sentiment, + ai_embed, + ai_summarize_agg, + ai_transcribe, + ai_parse_document, + function, +) +from snowflake.snowpark._internal.telemetry import add_api_call + +if TYPE_CHECKING: + import snowflake.snowpark + + +class DataFrameAIFunctions: + """Provides AI-powered functions for a :class:`DataFrame`.""" + + def __init__(self, dataframe: "snowflake.snowpark.DataFrame") -> None: + self._dataframe = dataframe + + @experimental(version="1.39.0") + @publicapi + def complete( + self, + prompt: str, + input_columns: Union[List[Column], Dict[str, Column]], + model: str, + *, + output_column: Optional[str] = None, + model_parameters: Optional[Dict[str, Any]] = None, + _emit_ast: bool = True, + ) -> "snowflake.snowpark.DataFrame": + """Generate a response (completion) on each row using the specified language model. + + Args: + prompt: The prompt template string. 
Use placeholders like ``{name}`` when passing a dict of columns, + or ``{0}``, ``{1}`` when passing a list. + input_columns: A list of Columns (positional placeholders ``{0}``, ``{1}``, ...) + or a dict mapping placeholder names to Columns. + model: A string specifying the model to be used. Different input types have different supported models. + See details in `AI_COMPLETE `_. + output_column: The name of the output column to be appended. + If not provided, a column named ``AI_COMPLETE_OUTPUT`` is appended + model_parameters: Optional dict containing model hyperparameters: + + - temperature: Value from 0 to 1 controlling randomness (default: 0) + - top_p: Value from 0 to 1 controlling diversity (default: 0) + - max_tokens: Maximum number of output tokens (default: 4096, max: 8192) + - guardrails: Enable Cortex Guard filtering (default: False) + + Returns: + A new DataFrame with appended output columns at the end. + + Examples:: + + >>> # Single column output with named placeholder + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe( + ... [["What is machine learning?"], ["Explain quantum computing"]], + ... schema=["question"] + ... ) + >>> result_df = df.ai.complete( + ... prompt="Answer this question briefly: {q}", + ... input_columns={"q": col("question")}, + ... output_column="answer", + ... model="snowflake-arctic" + ... ) + >>> result_df.columns + ['QUESTION', 'ANSWER'] + >>> result_df.count() + 2 + + >>> # Processing images with file input + >>> from snowflake.snowpark.functions import to_file + >>> # Upload images to a stage first + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/kitchen.png", "@mystage", auto_compress=False) + >>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False) + >>> # Create DataFrame with image paths and questions + >>> df = session.create_dataframe( + ... [ + ... 
["@mystage/kitchen.png", "What appliances are visible in this image?"], + ... ["@mystage/dog.jpg", "What animal is in this image?"] + ... ], + ... schema=["image_path", "question"] + ... ) + >>> # Use ai.complete with image files + >>> result_df = df.ai.complete( + ... prompt="Image: {0}, Question: {1}", + ... input_columns=[ + ... to_file(col("image_path")), + ... col("question") + ... ], + ... output_column="answer", + ... model="claude-4-sonnet" + ... ) + >>> result_df.columns + ['IMAGE_PATH', 'QUESTION', 'ANSWER'] + >>> result_df.count() + 2 + >>> results = result_df.collect() + >>> 'microwave' in results[0]["ANSWER"].lower() + True + >>> 'dog' in results[1]["ANSWER"].lower() + True + """ + + # Build the prompt Column + if isinstance(input_columns, (dict, list)): + prompt_obj = create_prompt_column_from_template( + prompt, input_columns, _emit_ast=_emit_ast + ) + else: + raise TypeError( + "input_columns must be a list of Columns or a dict mapping placeholder names to Columns" + ) + + output_column_name = output_column or "AI_COMPLETE_OUTPUT" + + # AST at top + stmt = None + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_complete, stmt) + self._dataframe._set_ast_ref(ast.df) + ast.model = model + ast.prompt = prompt + # populate input_columns with the prompt column expression + build_expr_from_snowpark_column_or_col_name(ast.input_columns, prompt_obj) + if model_parameters: + for k, v in model_parameters.items(): + entry = ast.model_parameters.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name + + # Call the ai_complete function with all explicit parameters + result_col = ai_complete( + model=model, + prompt=prompt_obj, + model_parameters=model_parameters, + _emit_ast=False, + ) + + # Add the output column to the DataFrame + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, 
@experimental(version="1.39.0")
@publicapi
def filter(
    self,
    predicate: str,
    input_columns: Union[List[Column], Dict[str, Column]],
    *,
    _emit_ast: bool = True,
) -> "snowflake.snowpark.DataFrame":
    """Filter rows using AI-powered boolean classification.

    This method applies AI-based filtering to each row, classifying them as True or False
    based on the provided predicate. Supports both text-based filtering and image filtering.

    Args:
        predicate: The classification instruction string. Use placeholders like ``{name}`` when passing
            a dict of columns, or ``{0}``, ``{1}`` when passing a list. For file-based filtering,
            this should contain instructions to classify the file as TRUE or FALSE.
        input_columns: A list of Columns (positional placeholders ``{0}``, ``{1}``, ...)
            or a dict mapping placeholder names to Columns. When a list is passed and the
            predicate contains no positional placeholder, the columns are appended to the
            end of the predicate.

    Returns:
        A new DataFrame containing only the rows for which the AI filter evaluated to true.

    Raises:
        TypeError: If ``input_columns`` is neither a list nor a dict.

    Examples::

        >>> # Simple text filtering without placeholders
        >>> df = session.create_dataframe(
        ...     [["This is great!"], ["This is terrible!"], ["This is okay."]],
        ...     schema=["review"]
        ... )
        >>> positive_df = df.ai.filter("Is this review positive?", input_columns=[df["review"]])
        >>> positive_df.count()  # Should be 1 (only "This is great!")
        1

        >>> # Text filtering with named placeholders
        >>> df = session.create_dataframe(
        ...     [["Switzerland", "Europe"], ["Korea", "Asia"], ["Brazil", "South America"]],
        ...     schema=["country", "continent"]
        ... )
        >>> european_df = df.ai.filter(
        ...     "Is {country} located in {continent} and specifically in Europe?",
        ...     input_columns={"country": df["country"], "continent": df["continent"]}
        ... )
        >>> european_df.collect()[0]["COUNTRY"]
        'Switzerland'

        >>> # Image filtering with positional placeholders
        >>> from snowflake.snowpark.functions import to_file
        >>> # Upload images to a stage first
        >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
        >>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
        >>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False)
        >>> df = session.read.file("@mystage")
        >>> dog_images_df = df.ai.filter(
        ...     "Does this image contain a dog?",
        ...     input_columns=[df["FILE"]]
        ... )
        >>> dog_images_df.count()  # Should be 1 (only dog image)
        1
    """
    # Validate once, up front, regardless of whether AST emission is enabled.
    if not isinstance(input_columns, (dict, list)):
        raise TypeError(
            "input_columns must be a list of Columns or a dict mapping placeholder names to Columns"
        )

    # Build the predicate Column a single time (same pattern as ``complete``):
    # it serves both the AST encoding below and the ai_filter call.
    predicate_col = create_prompt_column_from_template(
        predicate, input_columns, _emit_ast=_emit_ast
    )

    # AST at top
    stmt = None
    if _emit_ast:
        stmt = self._dataframe._session._ast_batch.bind()
        ast = with_src_position(stmt.expr.dataframe_ai_filter, stmt)
        self._dataframe._set_ast_ref(ast.df)
        ast.predicate = predicate
        build_expr_from_snowpark_column_or_col_name(ast.input_columns, predicate_col)

    # Filter the DataFrame to only include rows where the result is True.
    # Inner calls pass _emit_ast=False because this method already recorded
    # a single statement for the whole operation.
    filter_result = ai_filter(
        predicate=predicate_col,
        _emit_ast=False,
    )
    filtered_df = self._dataframe.filter(filter_result, _emit_ast=False)

    add_api_call(
        filtered_df,
        "DataFrame.ai.filter",
    )
    if _emit_ast:
        filtered_df._ast_id = stmt.uid
    return filtered_df
task_description="Extract the main positive and negative points from customer feedback", + ... input_column=col("feedback"), + ... output_column="insights" + ... ) + >>> insights_df.count() + 1 + + Note: + For optimal performance, follow these guidelines: + + - Use plain English text for the task description. + + - Describe the text provided in the task description. For example, instead of a task + description like "summarize", use "Summarize the phone call transcripts". + + - Describe the intended use case. For example, instead of "find the best review", + use "Find the most positive and well-written restaurant review to highlight on + the restaurant website". + + - Consider breaking the task description into multiple steps. For example, instead of + "Summarize the new articles", use "You will be provided with news articles from + various publishers presenting events from different points of view. Please create + a concise and elaborative summary of source texts without missing any crucial information.". 
@experimental(version="1.39.0")
@publicapi
def classify(
    self,
    input_column: ColumnOrName,
    categories: Union[List[str], Column],
    *,
    output_column: Optional[str] = None,
    _emit_ast: bool = True,
    **kwargs,
) -> "snowflake.snowpark.DataFrame":
    """Classify text or images into specified categories using AI.

    Each row of the input column is classified independently, and the result is
    appended to the DataFrame as a new column.

    Args:
        input_column: The column (Column object or column name as string) holding the
            text or image data to classify.
        categories: Either a list of category strings or a Column containing an array of
            categories. Must contain at least 2 and no more than 100 categories.
        output_column: Name of the appended output column. Defaults to
            ``AI_CLASSIFY_OUTPUT`` when not provided.
        **kwargs: Configuration settings specified as key/value pairs. Supported keys:

            - task_description: An explanation of the classification task that is 50 words or fewer.
              This can help the model understand the context of the classification task and improve accuracy.

            - output_mode: Set to ``multi`` for multi-label classification. Defaults to ``single`` for
              single-label classification.

            - examples: A list of example objects for few-shot learning. Each example must include:

                - input: Example text to classify.
                - labels: List of correct categories for the input.
                - explanation: Explanation of why the input maps to those categories.

    Returns:
        A new DataFrame with an appended output column containing classification results.
        The output is a JSON object with a ``labels`` field containing the assigned categories.

    Examples::

        >>> # Simple text classification with list of categories
        >>> from snowflake.snowpark.functions import col
        >>> import json
        >>> df = session.create_dataframe(
        ...     [
        ...         ["I love hiking in the mountains"],
        ...         ["My favorite dish is pasta carbonara"],
        ...         ["Just finished reading a great book"],
        ...     ],
        ...     schema=["text"]
        ... )
        >>> result_df = df.ai.classify(
        ...     input_column="text",
        ...     categories=["hiking", "cooking", "reading"],
        ...     output_column="category"
        ... )
        >>> result_df.columns
        ['TEXT', 'CATEGORY']
        >>> results = result_df.collect()
        >>> json.loads(results[0]["CATEGORY"])["labels"][0]
        'hiking'

        >>> # Image classification with Column containing categories
        >>> from snowflake.snowpark.functions import to_file
        >>> # Upload images to a stage first
        >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect()
        >>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False)
        >>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False)
        >>> _ = session.file.put("tests/resources/kitchen.png", "@mystage", auto_compress=False)
        >>> # Create DataFrame with image paths and possible categories for each image
        >>> df = session.create_dataframe(
        ...     [
        ...         ["@mystage/dog.jpg", ["cat", "dog", "bird", "fish"]],
        ...         ["@mystage/cat.jpeg", ["cat", "dog", "rabbit", "hamster"]],
        ...         ["@mystage/kitchen.png", ["kitchen", "bedroom", "bathroom", "living room"]],
        ...     ],
        ...     schema=["image_path", "categories"]
        ... )
        >>> # Classify images using their respective category options
        >>> result_df = df.ai.classify(
        ...     input_column=to_file(col("image_path")),
        ...     categories=col("categories"),
        ...     output_column="classification"
        ... )
        >>> result_df.columns
        ['IMAGE_PATH', 'CATEGORIES', 'CLASSIFICATION']
        >>> results = result_df.collect()
        >>> # Verify the dog image is classified as 'dog'
        >>> dog_result = [r for r in results if 'dog.jpg' in r["IMAGE_PATH"]][0]
        >>> json.loads(dog_result["CLASSIFICATION"])["labels"][0]
        'dog'

        >>> # Multi-label classification with advanced configuration
        >>> df = session.create_dataframe(
        ...     [
        ...         ["I enjoy traveling and trying local cuisines"],
        ...         ["Reading books while on a flight"],
        ...         ["Cooking recipes from different countries"],
        ...     ],
        ...     schema=["text"]
        ... )
        >>> result_df = df.ai.classify(
        ...     input_column="text",
        ...     categories=["travel", "cooking", "reading", "sports"],
        ...     output_column="topics",
        ...     task_description="Identify all topics mentioned in the text",
        ...     output_mode="multi",
        ...     examples=[{
        ...         "input": "I love reading cookbooks during my travels",
        ...         "labels": ["travel", "cooking", "reading"],
        ...         "explanation": "The text mentions traveling, cookbooks (cooking), and reading"
        ...     }]
        ... )
        >>> result_df.columns
        ['TEXT', 'TOPICS']
        >>> results = result_df.collect()
        >>> len(json.loads(results[0]["TOPICS"])["labels"]) >= 1  # Multi-label can have multiple labels
        True
    """
    source_df = self._dataframe
    target_name = output_column if output_column else "AI_CLASSIFY_OUTPUT"

    # Accept either a column name or a Column object.
    content_col = _to_col_if_str(input_column, "DataFrame.ai.classify")

    # Record one AST statement describing the whole operation.
    statement = None
    if _emit_ast:
        statement = source_df._session._ast_batch.bind()
        node = with_src_position(statement.expr.dataframe_ai_classify, statement)
        source_df._set_ast_ref(node.df)
        build_expr_from_snowpark_column_or_col_name(node.input_column, content_col)
        build_expr_from_snowpark_column_or_python_val(node.categories, categories)
        for option_name in kwargs:
            option = node.kwargs.add()
            option._1 = option_name
            build_expr_from_python_val(option._2, kwargs[option_name])
        node.output_column.value = target_name

    # Inner calls pass _emit_ast=False: the statement above already covers them.
    labelled_df = source_df.with_column(
        target_name,
        ai_classify(content_col, categories, _emit_ast=False, **kwargs),
        _emit_ast=False,
    )

    add_api_call(labelled_df, "DataFrame.ai.classify")
    if _emit_ast:
        labelled_df._ast_id = statement.uid
    return labelled_df
+ Must be the same type as input1 (both text or both images). + output_column: The name of the output column to be appended. + If not provided, a column named ``AI_SIMILARITY_OUTPUT`` is appended. + **kwargs: Configuration settings specified as key/value pairs. Supported keys: + + - model: The embedding model used for embeddings. + For text input, defaults to 'snowflake-arctic-embed-l-v2'. + For image input, defaults to 'voyage-multimodal-3'. + Supported models include: + + - Text: 'snowflake-arctic-embed-l-v2', 'nv-embed-qa-4', + 'multilingual-e5-large', 'voyage-multilingual-2', + 'snowflake-arctic-embed-m-v1.5', 'snowflake-arctic-embed-m', + 'e5-base-v2' + - Images: 'voyage-multimodal-3' + + Returns: + A new DataFrame with an appended output column containing similarity scores. + The scores range from -1 to 1, where higher values indicate greater similarity. + + Examples:: + + >>> # Text similarity between two columns + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe( + ... [ + ... ["I love programming", "I enjoy coding"], + ... ["The weather is nice", "It's raining heavily"], + ... ["Python is great", "Python is awesome"], + ... ], + ... schema=["text1", "text2"] + ... ) + >>> result_df = df.ai.similarity( + ... input1="text1", + ... input2="text2", + ... output_column="similarity_score" + ... ) + >>> result_df.columns + ['TEXT1', 'TEXT2', 'SIMILARITY_SCORE'] + >>> results = result_df.collect() + >>> results[0]["SIMILARITY_SCORE"] > 0.5 # Similar texts + True + + >>> # Multilingual text similarity with custom model + >>> df = session.create_dataframe( + ... [ + ... ["I love programming", "我喜欢编程"], # Same meaning in English and Chinese + ... ["Good morning", "Buenas noches"], # Different meanings + ... ], + ... schema=["english", "other_language"] + ... ) + >>> result_df = df.ai.similarity( + ... input1=col("english"), + ... input2=col("other_language"), + ... output_column="cross_lingual_similarity", + ... 
model="multilingual-e5-large" + ... ) + >>> result_df.columns + ['ENGLISH', 'OTHER_LANGUAGE', 'CROSS_LINGUAL_SIMILARITY'] + + >>> # Image similarity + >>> from snowflake.snowpark.functions import to_file + >>> # Upload images to a stage first + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False) + >>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False) + >>> _ = session.file.put("tests/resources/kitchen.png", "@mystage", auto_compress=False) + >>> # Create DataFrame with image pairs + >>> df = session.create_dataframe( + ... [ + ... ["@mystage/dog.jpg", "@mystage/cat.jpeg"], # Animal comparison + ... ["@mystage/dog.jpg", "@mystage/kitchen.png"], # Animal vs non-animal + ... ], + ... schema=["image1", "image2"] + ... ) + >>> result_df = df.ai.similarity( + ... input1=to_file(col("image1")), + ... input2=to_file(col("image2")), + ... output_column="visual_similarity" + ... 
@experimental(version="1.39.0")
@publicapi
def sentiment(
    self,
    input_column: ColumnOrName,
    categories: Optional[List[str]] = None,
    *,
    output_column: Optional[str] = None,
    _emit_ast: bool = True,
) -> "snowflake.snowpark.DataFrame":
    """Extract sentiment analysis from text content.

    The text in each row is analyzed for its overall sentiment and, when categories
    are given, for the sentiment expressed towards each of those categories. The
    result is appended to the DataFrame as a new column.

    Args:
        input_column: The column (Column object or column name as string) holding the text
            to analyze for sentiment.
        categories: Optional list of up to 10 categories (also called entities or aspects) for which
            sentiment should be extracted. Each category may be a maximum of 30 characters long.
            For example, if extracting sentiment from restaurant reviews, you might specify
            ``['cost', 'quality', 'service', 'wait time']`` as categories. If not provided,
            only overall sentiment is returned.
        output_column: Name of the appended output column. Defaults to
            ``AI_SENTIMENT_OUTPUT`` when not provided.

    Returns:
        A new DataFrame with an appended output column containing sentiment results.
        The output is a JSON object with a ``categories`` field containing an array of records.
        Each record includes:

        - ``name``: The category name (``overall`` for overall sentiment)
        - ``sentiment``: One of ``unknown``, ``positive``, ``negative``, ``neutral``, or ``mixed``

    Examples::

        >>> # Overall sentiment analysis
        >>> df = session.create_dataframe([
        ...     ["The movie had amazing visual effects but the plot was terrible."],
        ...     ["The food was delicious but the service was slow."],
        ...     ["Everything about this experience was perfect!"],
        ... ], schema=["review"])
        >>> result_df = df.ai.sentiment(
        ...     input_column="review",
        ...     output_column="sentiment"
        ... )
        >>> result_df.columns
        ['REVIEW', 'SENTIMENT']
        >>> results = result_df.collect()
        >>> import json
        >>> overall_sentiment = json.loads(results[2]["SENTIMENT"])["categories"][0]
        >>> overall_sentiment["name"]
        'overall'
        >>> overall_sentiment["sentiment"]
        'positive'

        >>> # Sentiment analysis with specific categories
        >>> from snowflake.snowpark.functions import col
        >>> df = session.create_dataframe([
        ...     ["The hotel room was spacious and clean, but the wifi was terrible and the breakfast was mediocre."],
        ...     ["Great location and friendly staff, though the parking was expensive."],
        ... ], schema=["review"])
        >>> result_df = df.ai.sentiment(
        ...     input_column=col("review"),
        ...     categories=["room", "wifi", "breakfast", "location", "staff", "parking"],
        ...     output_column="detailed_sentiment"
        ... )
        >>> result_df.columns
        ['REVIEW', 'DETAILED_SENTIMENT']
        >>> results = result_df.collect()
        >>> sentiments = json.loads(results[0]["DETAILED_SENTIMENT"])["categories"]
        >>> # Check that we have sentiments for overall plus the specified categories
        >>> len(sentiments) > 1
        True
        >>> category_names = [s["name"] for s in sentiments]
        >>> "overall" in category_names
        True
        >>> "room" in category_names
        True

    Note:
        AI_SENTIMENT can analyze sentiment in English, French, German, Hindi, Italian, Spanish,
        and Portuguese. You can specify categories in the language of the text or in English.
    """
    source_df = self._dataframe
    target_name = output_column if output_column else "AI_SENTIMENT_OUTPUT"

    # Accept either a column name or a Column object.
    text_col = _to_col_if_str(input_column, "DataFrame.ai.sentiment")

    # Record one AST statement describing the whole operation.
    statement = None
    if _emit_ast:
        statement = source_df._session._ast_batch.bind()
        node = with_src_position(statement.expr.dataframe_ai_sentiment, statement)
        source_df._set_ast_ref(node.df)
        build_expr_from_snowpark_column_or_col_name(node.input_column, text_col)
        # Categories are only encoded when the caller supplied them.
        if categories is not None:
            build_expr_from_python_val(node.categories, categories)
        node.output_column.value = target_name

    # Inner calls pass _emit_ast=False: the statement above already covers them.
    scored_df = source_df.with_column(
        target_name,
        ai_sentiment(text_col, categories=categories, _emit_ast=False),
        _emit_ast=False,
    )

    add_api_call(scored_df, "DataFrame.ai.sentiment")
    if _emit_ast:
        scored_df._ast_id = statement.uid
    return scored_df
Supported models: + + For text embeddings: + - ``snowflake-arctic-embed-l-v2.0``: Arctic large model + - ``snowflake-arctic-embed-l-v2.0-8k``: Arctic large model with 8K context + - ``nv-embed-qa-4``: NVIDIA embedding model for Q&A + - ``multilingual-e5-large``: Multilingual embedding model + - ``voyage-multilingual-2``: Voyage multilingual model + + For image embeddings: + - ``voyage-multimodal-3``: Voyage multimodal model (only for images) + + output_column: The name of the output column to be appended. + If not provided, a column named ``AI_EMBED_OUTPUT`` is appended. + + Returns: + A new DataFrame with an appended output column containing VECTOR embeddings. + + Examples:: + + >>> # Text embeddings with the Arctic large model (``model`` is required; there is no default) + >>> df = session.create_dataframe([ + ... ["Machine learning is fascinating"], + ... ["Snowflake provides cloud data platform"], + ... ["Python is a versatile programming language"], + ... ], schema=["text"]) + >>> result_df = df.ai.embed( + ... input_column="text", + ... model="snowflake-arctic-embed-l-v2.0", + ... output_column="text_vector" + ... ) + >>> results = result_df.collect() + >>> # Verify we got embeddings + >>> all(len(row["TEXT_VECTOR"]) > 0 for row in results) + True + + >>> # Multilingual text embeddings + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe([ + ... ["Hello world"], + ... ["Bonjour le monde"], + ... ["Hola mundo"], + ... ["你好世界"], + ... ], schema=["greeting"]) + >>> result_df = df.ai.embed( + ... input_column=col("greeting"), + ... model="multilingual-e5-large", + ... output_column="multilingual_vector" + ... 
) + >>> results = result_df.collect() + >>> # All greetings should have embeddings + >>> all(len(row["MULTILINGUAL_VECTOR"]) > 0 for row in results) + True + + >>> # Image embeddings + >>> from snowflake.snowpark.functions import to_file + >>> # Upload images to a stage first + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/dog.jpg", "@mystage", auto_compress=False) + >>> _ = session.file.put("tests/resources/cat.jpeg", "@mystage", auto_compress=False) + >>> df = session.read.file("@mystage") + >>> result_df = df.ai.embed( + ... input_column="file", + ... model="voyage-multimodal-3", + ... output_column="image_vector" + ... ) + >>> results = result_df.collect() + >>> # Both images should have embeddings + >>> all(len(row["IMAGE_VECTOR"]) > 0 for row in results) + True + + Note: + - Embeddings can be used with vector similarity functions to find similar items + - Different models produce embeddings of different dimensions + - For best results, use the same model for all items you want to compare + """ + output_column_name = output_column or "AI_EMBED_OUTPUT" + + # Convert string input column to Column object & AST at top + stmt = None + input_col = _to_col_if_str(input_column, "DataFrame.ai.embed") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_embed, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + ast.model = model + + ast.output_column.value = output_column_name + + # Call the ai_embed function + result_col = ai_embed( + model=model, + input=input_col, + _emit_ast=False, + ) + + # Add the output column to the DataFrame + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.embed", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + 
@experimental(version="1.39.0") + @publicapi + def summarize_agg( + self, + input_column: ColumnOrName, + *, + output_column: Optional[str] = None, + _emit_ast: bool = True, + ) -> "snowflake.snowpark.DataFrame": + """Summarize a column of text data using AI. + + This method aggregates and summarizes text data from multiple rows into a single + comprehensive summary. It's particularly useful for creating summaries from + collections of reviews, feedback, transcripts, or other text content. + + Args: + input_column: The column (Column object or column name as string) containing the text + data to summarize. + output_column: The name of the single column in the returned DataFrame. + If not provided, the column is named ``AI_SUMMARIZE_AGG_OUTPUT``. + + Returns: + A new DataFrame with a single row and a single column containing the summarized text (the input columns are not carried over). + + Examples:: + + >>> # Summarize product reviews + >>> df = session.create_dataframe([ + ... ["The product quality is excellent and shipping was fast."], + ... ["Great value for money, highly recommend!"], + ... ["Customer service was very helpful and responsive."], + ... ["The packaging could be better, but the product itself is good."], + ... ["Easy to use and works as advertised."], + ... ], schema=["review"]) + >>> summary_df = df.ai.summarize_agg( + ... input_column="review", + ... output_column="reviews_summary" + ... ) + >>> summary_df.columns + ['REVIEWS_SUMMARY'] + >>> summary_df.count() + 1 + >>> results = summary_df.collect() + >>> len(results[0]["REVIEWS_SUMMARY"]) > 10 + True + + >>> # Summarize with Column object + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe([ + ... ["Meeting started with project updates"], + ... ["Discussed timeline and deliverables"], + ... ["Identified key risks and mitigation strategies"], + ... ["Assigned action items to team members"], + ... ], schema=["meeting_notes"]) + >>> summary_df = df.ai.summarize_agg( + ... input_column=col("meeting_notes"), + ... 
output_column="meeting_summary" + ... ) + >>> summary_df.columns + ['MEETING_SUMMARY'] + >>> summary_df.count() + 1 + + Note: + - This is an aggregation function that combines multiple rows into a single summary + - For best results, provide clear and coherent text in the input column + - The summary will capture the main themes and important points from all input rows + - Unlike the ``agg`` method which requires a task description, ``summarize_agg`` + automatically generates a comprehensive summary + """ + output_column_name = output_column or "AI_SUMMARIZE_AGG_OUTPUT" + + # Convert string input column to Column object & AST at top + stmt = None + input_col = _to_col_if_str(input_column, "DataFrame.ai.summarize_agg") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_summarize_agg, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + + ast.output_column.value = output_column_name + + # Call the ai_summarize_agg function + result_col = ai_summarize_agg( + input_col, + _emit_ast=False, + ) + + # Create a new DataFrame with the summarized result + df = self._dataframe.select( + result_col.alias(output_column_name), _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.summarize_agg", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + @experimental(version="1.39.0") + @publicapi + def transcribe( + self, + input_column: ColumnOrName, + *, + output_column: Optional[str] = None, + _emit_ast: bool = True, + **kwargs, + ) -> "snowflake.snowpark.DataFrame": + """Transcribe text from an audio file with optional timestamps and speaker labels. + + Args: + input_column: The column (Column object or column name as string) containing FILE references + to audio files. Use ``to_file`` to convert staged paths to FILE type. + output_column: The name of the output column to be appended. 
+ If not provided, a column named ``AI_TRANSCRIBE_OUTPUT`` is appended. + **kwargs: Additional options forwarded to the underlying function, e.g. ``timestamp_granularity``. + + Returns: + A new DataFrame with an appended output column containing the transcription result as JSON text + (the examples below parse it with ``json.loads``). + + Examples:: + + >>> import json + >>> # Basic transcription without timestamps + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/audio.ogg", "@mystage", auto_compress=False) + >>> from snowflake.snowpark.functions import col, to_file + >>> df = session.create_dataframe([["@mystage/audio.ogg"]], schema=["audio_path"]) # staged file path + >>> result_df = df.ai.transcribe( + ... input_column=to_file(col("audio_path")), + ... output_column="transcript", + ... ) + >>> result_df.columns + ['AUDIO_PATH', 'TRANSCRIPT'] + >>> result = json.loads(result_df.collect()[0]["TRANSCRIPT"]) + >>> result['audio_duration'] > 120 + True + >>> "glad to see things are going well" in result['text'].lower() + True + + >>> # Transcription with word-level timestamps + >>> result_df = df.ai.transcribe( + ... input_column=to_file(col("audio_path")), + ... output_column="transcript", + ... timestamp_granularity='word', + ... ) + >>> result = json.loads(result_df.collect()[0]["TRANSCRIPT"]) + >>> len(result["segments"]) > 0 + True + >>> result["segments"][0]["text"].lower() + 'glad' + >>> 'start' in result["segments"][0] and 'end' in result["segments"][0] + True + + >>> # Transcription with speaker diarization (requires a multi-speaker audio file) + >>> _ = session.file.put("tests/resources/conversation.ogg", "@mystage", auto_compress=False) + >>> df = session.create_dataframe([["@mystage/conversation.ogg"]], schema=["audio_path"]) + >>> result_df = df.ai.transcribe( + ... input_column=to_file(col("audio_path")), + ... output_column="transcript", + ... timestamp_granularity='speaker', + ... 
) + >>> result = json.loads(result_df.collect()[0]["TRANSCRIPT"]) + >>> result["audio_duration"] > 100 and len(result["segments"]) > 0 + True + >>> result["segments"][0]["speaker_label"] + 'SPEAKER_00' + >>> 'jenny' in result["segments"][0]["text"].lower() + True + >>> 'start' in result["segments"][0] and 'end' in result["segments"][0] + True + """ + output_column_name = output_column or "AI_TRANSCRIBE_OUTPUT" + + stmt = None + input_col = _to_col_if_str(input_column, "DataFrame.ai.transcribe") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_transcribe, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + + ast.output_column.value = output_column_name + + result_col = ai_transcribe( + input_col, + _emit_ast=False, + **kwargs, + ) + + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.transcribe", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + @experimental(version="1.39.0") + @publicapi + def parse_document( + self, + input_column: ColumnOrName, + *, + output_column: Optional[str] = None, + _emit_ast: bool = True, + **kwargs, + ) -> "snowflake.snowpark.DataFrame": + """Extract content from a document (OCR or layout parsing) as JSON text. + + Args: + input_column: The column (Column object or column name as string) containing FILE references + to documents or images on a stage. Use ``to_file`` to convert staged paths to FILE type. + output_column: The name of the output column to be appended. + If not provided, a column named ``AI_PARSE_DOCUMENT_OUTPUT`` is appended. + **kwargs: Additional options forwarded to the underlying function, such as ``mode`` and ``page_split``. 
+ + Returns: + A new DataFrame with an appended output column containing the parsed document content as JSON text + (the examples below parse it with ``json.loads``). + + Examples:: + + >>> import json + >>> # Parse a PDF document with default OCR mode + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/doc.pdf", "@mystage", auto_compress=False) + >>> from snowflake.snowpark.functions import col, to_file + >>> df = session.create_dataframe([["@mystage/doc.pdf"]], schema=["file_path"]) # staged file path + >>> result_df = df.ai.parse_document( + ... input_column=to_file(col("file_path")), + ... output_column="parsed", + ... ) + >>> result_df.columns + ['FILE_PATH', 'PARSED'] + >>> result = json.loads(result_df.collect()[0]["PARSED"]) + >>> "Sample PDF" in result["content"] and result["metadata"]["pageCount"] == 3 + True + + >>> # Parse with LAYOUT mode to extract tables and structure + >>> _ = session.file.put("tests/resources/invoice.pdf", "@mystage", auto_compress=False) + >>> df = session.create_dataframe([["@mystage/invoice.pdf"]], schema=["file_path"]) + >>> result_df = df.ai.parse_document( + ... input_column=to_file(col("file_path")), + ... output_column="parsed", + ... mode='LAYOUT', + ... ) + >>> result = json.loads(result_df.collect()[0]["PARSED"]) + >>> "| Customer Name |" in result["content"] and "| Country |" in result["content"] + True + + >>> # Parse with page splitting for long documents (PDF only) + >>> df = session.create_dataframe([["@mystage/doc.pdf"]], schema=["file_path"]) + >>> result_df = df.ai.parse_document( + ... input_column=to_file(col("file_path")), + ... output_column="parsed", + ... page_split=True, + ... 
) + >>> result = json.loads(result_df.collect()[0]["PARSED"]) + >>> len(result["pages"]) == 3 and result["pages"][0]["index"] == 0 + True + """ + output_column_name = output_column or "AI_PARSE_DOCUMENT_OUTPUT" + + stmt = None + input_col = _to_col_if_str(input_column, "DataFrame.ai.parse_document") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_parse_document, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + for k, v in kwargs.items(): + entry = ast.kwargs.add() + entry._1 = k + build_expr_from_python_val(entry._2, v) + ast.output_column.value = output_column_name + + result_col = ai_parse_document( + input_col, + _emit_ast=False, + **kwargs, + ) + + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.parse_document", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + @experimental(version="1.39.0") + @publicapi + def extract( + self, + input_column: ColumnOrName, + *, + response_format: Optional[Union[Dict[str, str], List]] = None, + output_column: Optional[str] = None, + _emit_ast: bool = True, + ) -> "snowflake.snowpark.DataFrame": + """Extract structured information from text or files using a response schema. + + Args: + input_column: The column (Column object or column name as string) containing the text + or FILE data to extract information from. Use ``to_file`` for staged file paths. + response_format: The schema describing information to extract. 
Supports: + + - Simple object schema (dict) mapping feature names to extraction prompts: + ``{'name': 'What is the last name of the employee?', 'address': 'What is the address of the employee?'}`` + - Array of strings containing the information to be extracted: + ``['What is the last name of the employee?', 'What is the address of the employee?']`` + - Array of arrays containing two strings (feature name and extraction prompt): + ``[['name', 'What is the last name of the employee?'], ['address', 'What is the address of the employee?']]`` + - Array of strings with colon-separated feature names and extraction prompts: + ``['name: What is the last name of the employee?', 'address: What is the address of the employee?']`` + + output_column: The name of the output column to be appended. + If not provided, a column named ``AI_EXTRACT_OUTPUT`` is appended. + + Returns: + A new DataFrame with an appended JSON object containing the extracted fields + under ``response``. + + Examples:: + + >>> # Extract from text string + >>> from snowflake.snowpark.functions import col + >>> df = session.create_dataframe([ + ... ["John Smith lives in San Francisco and works for Snowflake"], + ... ], schema=["text"]) + >>> result_df = df.ai.extract( + ... input_column="text", + ... response_format={'name': 'What is the first name of the employee?', 'city': 'What is the address of the employee?'}, + ... output_column="extracted", + ... ) + >>> result_df.select("EXTRACTED").show() + -------------------------------- + |"EXTRACTED" | + -------------------------------- + |{ | + | "response": { | + | "city": "San Francisco", | + | "name": "John" | + | } | + |} | + -------------------------------- + + + >>> # Extract using array format + >>> df = session.create_dataframe( + ... [ + ... ["Alice Johnson works in Seattle"], + ... ["Bob Williams works in Portland"], + ... ], + ... schema=["text"] + ... ) + >>> result_df = df.ai.extract( + ... input_column=col("text"), + ... 
response_format=[["name", "What is the first name?"], ["city", "What city do they work in?"]], + ... output_column="info", + ... ) + >>> result_df.show() + ------------------------------------------------------------ + |"TEXT" |"INFO" | + ------------------------------------------------------------ + |Alice Johnson works in Seattle |{ | + | | "response": { | + | | "city": "Seattle", | + | | "name": "Alice" | + | | } | + | |} | + |Bob Williams works in Portland |{ | + | | "response": { | + | | "city": "Portland", | + | | "name": "Bob" | + | | } | + | |} | + ------------------------------------------------------------ + + + >>> # Extract lists using List: prefix + >>> df = session.create_dataframe( + ... [["Python, Java, and JavaScript are popular programming languages"]], + ... schema=["text"] + ... ) + >>> result_df = df.ai.extract( + ... input_column="text", + ... response_format=[["languages", "List: What programming languages are mentioned?"]], + ... output_column="extracted", + ... ) + >>> result_df.select("EXTRACTED").show() + ---------------------- + |"EXTRACTED" | + ---------------------- + |{ | + | "response": { | + | "languages": [ | + | "Python", | + | "Java", | + | "JavaScript" | + | ] | + | } | + |} | + ---------------------- + + + >>> # Extract from file + >>> from snowflake.snowpark.functions import to_file + >>> _ = session.sql("CREATE OR REPLACE TEMP STAGE mystage ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')").collect() + >>> _ = session.file.put("tests/resources/invoice.pdf", "@mystage", auto_compress=False) + >>> df = session.create_dataframe([["@mystage/invoice.pdf"]], schema=["file_path"]) + >>> result_df = df.ai.extract( + ... input_column=to_file(col("file_path")), + ... response_format=[["date", "What is the invoice date?"], ["amount", "What is the amount?"]], + ... output_column="info", + ... 
) + >>> result_df.select("INFO").show() + -------------------------------- + |"INFO" | + -------------------------------- + |{ | + | "response": { | + | "amount": "USD $950.00", | + | "date": "Nov 26, 2016" | + | } | + |} | + -------------------------------- + + + """ + + output_column_name = output_column or "AI_EXTRACT_OUTPUT" + + stmt = None + input_col = _to_col_if_str(input_column, "DataFrame.ai.extract") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_extract, stmt) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name(ast.input_column, input_col) + if response_format is not None: + build_expr_from_python_val(ast.response_format, response_format) + + ast.output_column.value = output_column_name + + result_col = ai_extract( + input=input_col, + response_format=response_format, + _emit_ast=False, + ) + + # Add the output column to the DataFrame (output_column_name was resolved once above) + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.extract", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + @experimental(version="1.39.0") + @publicapi + def count_tokens( + self, + model: str, + prompt: ColumnOrName, + *, + output_column: Optional[str] = None, + _emit_ast: bool = True, + ) -> "snowflake.snowpark.DataFrame": + """Count the number of tokens in text for a specified language model. + + This method returns the number of tokens that would be consumed by the specified + model when processing the input text. This is useful for estimating costs and + ensuring inputs fit within model token limits. + + Args: + model: The model to base the token count on. Required. 
Supported models include: + + - ``deepseek-r1``, ``e5-base-v2``, ``e5-large-v2`` + - ``gemma-7b``, ``jamba-1.5-large``, ``jamba-1.5-mini``, ``jamba-instruct`` + - ``llama2-70b-chat``, ``llama3-70b``, ``llama3-8b`` + - ``llama3.1-405b``, ``llama3.1-70b``, ``llama3.1-8b`` + - ``llama3.2-1b``, ``llama3.2-3b``, ``llama3.3-70b`` + - ``llama4-maverick``, ``llama4-scout`` + - ``mistral-7b``, ``mistral-large``, ``mistral-large2``, ``mixtral-8x7b`` + - ``nv-embed-qa-4``, ``reka-core``, ``reka-flash`` + - ``snowflake-arctic-embed-l-v2.0``, ``snowflake-arctic-embed-m-v1.5`` + - ``snowflake-arctic-embed-m``, ``snowflake-arctic`` + - ``snowflake-llama-3.1-405b``, ``snowflake-llama-3.3-70b`` + - ``voyage-multilingual-2`` + prompt: The column (Column object or column name as string) containing the text + to count tokens for. + output_column: The name of the output column to be appended. + If not provided, a column named ``COUNT_TOKENS_OUTPUT`` is appended. + + Returns: + A new DataFrame with an appended output column containing the token count as an integer. + + Examples:: + + >>> # Count tokens for a simple text + >>> df = session.create_dataframe([ + ... ["What is a large language model?"], + ... ["Explain quantum computing in simple terms."], + ... ], schema=["text"]) + >>> result_df = df.ai.count_tokens( + ... model="llama3.1-70b", + ... prompt="text", + ... output_column="token_count" + ... ) + >>> result_df.show() + -------------------------------------------------------------- + |"TEXT" |"TOKEN_COUNT" | + -------------------------------------------------------------- + |What is a large language model? |8 | + |Explain quantum computing in simple terms. |9 | + -------------------------------------------------------------- + + + Note: + The token count does not account for any managed system prompt that may be + automatically added when using other Cortex AI functions. The actual token + usage may be higher when using those functions. 
+ """ + + output_column_name = output_column or "COUNT_TOKENS_OUTPUT" + + # AST at top + stmt = None + prompt_col = _to_col_if_str(prompt, "DataFrame.ai.count_tokens") + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position(stmt.expr.dataframe_ai_count_tokens, stmt) + self._dataframe._set_ast_ref(ast.df) + ast.model = model + build_expr_from_snowpark_column_or_col_name(ast.prompt, prompt_col) + + ast.output_column.value = output_column_name + + # Call SNOWFLAKE.CORTEX.COUNT_TOKENS function + count_tokens_func = function("SNOWFLAKE.CORTEX.COUNT_TOKENS", _emit_ast=False) + result_col = count_tokens_func(model, prompt_col) + + # Add the output column to the DataFrame + df = self._dataframe.with_column( + output_column_name, result_col, _emit_ast=False + ) + + add_api_call( + df, + "DataFrame.ai.count_tokens", + ) + if _emit_ast: + df._ast_id = stmt.uid + return df + + @experimental(version="1.39.0") + @publicapi + def split_text_markdown_header( + self, + text_to_split: ColumnOrName, + headers_to_split_on: Union[Dict[str, str], Column], + chunk_size: Union[int, Column], + *, + overlap: Union[int, Column] = 0, + output_column: Optional[str] = None, + _emit_ast: bool = True, + ) -> "snowflake.snowpark.DataFrame": + """Split Markdown-formatted text into structured chunks based on header levels. + + This method segments text using specified Markdown headers and recursively splits + each segment to produce chunks of the desired size. It preserves document structure + by tracking which headers each chunk falls under. + + Args: + text_to_split: The column (Column object or column name as string) containing + the Markdown-formatted text to split. + headers_to_split_on: A dictionary mapping Markdown header syntax to metadata field names, + or a Column containing such a mapping. For example: + ``{"#": "header_1", "##": "header_2"}`` will split on # and ## headers. + chunk_size: The maximum number of characters in each chunk. 
Must be greater than zero. + Can be an integer or a Column containing integer values. + overlap: Optional number of characters to overlap between consecutive chunks. + Defaults to 0 if not provided. Can be an integer or a Column. + output_column: The name of the output column to be appended. + If not provided, a column named ``SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT`` is appended. + + Returns: + A new DataFrame with an appended output column containing an array of objects. + Each object has: + + - ``chunk``: A string containing the extracted text + - ``headers``: A dictionary containing the Markdown header values under which the chunk is nested + + Examples:: + + >>> # Split a simple Markdown document + >>> df = session.create_dataframe([ + ... ["# Introduction\\nThis is the intro.\\n## Background\\nSome background info."], + ... ], schema=["document"]) + >>> result_df = df.ai.split_text_markdown_header( + ... text_to_split="document", + ... headers_to_split_on={"#": "section", "##": "subsection"}, + ... chunk_size=20, + ... overlap=5, + ... output_column="chunks" + ... ) + >>> result_df.show() + -------------------------------------------------------------- + |"DOCUMENT" |"CHUNKS" | + -------------------------------------------------------------- + |# Introduction |[ | + |This is the intro. | { | + |## Background | "chunk": "This is the intro.", | + |Some background info. 
| "headers": { | + | | "section": "Introduction" | + | | } | + | | }, | + | | { | + | | "chunk": "Some background", | + | | "headers": { | + | | "section": "Introduction", | + | | "subsection": "Background" | + | | } | + | | }, | + | | { | + | | "chunk": "info.", | + | | "headers": { | + | | "section": "Introduction", | + | | "subsection": "Background" | + | | } | + | | } | + | |] | + -------------------------------------------------------------- + + + Note: + - The function preserves document hierarchy by including parent headers for each chunk + - Chunks are created using recursive character splitting after initial header segmentation + - Overlap helps maintain context across chunk boundaries + """ + method_name = "DataFrame.ai.split_text_markdown_header" + output_column_name = output_column or "SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT" + + # Convert inputs to Column objects and AST at top + stmt = None + text_col = _to_col_if_str(text_to_split, method_name) + headers_col = _to_col_if_lit(headers_to_split_on, method_name) + chunk_size_col = _to_col_if_lit(chunk_size, method_name) + overlap_col = _to_col_if_lit(overlap, method_name) + if _emit_ast: + stmt = self._dataframe._session._ast_batch.bind() + ast = with_src_position( + stmt.expr.dataframe_ai_split_text_markdown_header, stmt + ) + self._dataframe._set_ast_ref(ast.df) + build_expr_from_snowpark_column_or_col_name( + ast.text_to_split, text_to_split + ) + build_expr_from_snowpark_column_or_python_val( + ast.headers_to_split_on, headers_to_split_on + ) + build_expr_from_snowpark_column_or_python_val(ast.chunk_size, chunk_size) + build_expr_from_snowpark_column_or_python_val(ast.overlap, overlap) + + ast.output_column.value = output_column_name + + # Call SNOWFLAKE.CORTEX.SPLIT_TEXT_MARKDOWN_HEADER function + split_func = function( + "SNOWFLAKE.CORTEX.SPLIT_TEXT_MARKDOWN_HEADER", _emit_ast=False + ) + result_col = split_func(text_col, headers_col, chunk_size_col, overlap_col) + + # Add the output column to the 
@experimental(version="1.39.0")
@publicapi
def split_text_recursive_character(
    self,
    text_to_split: ColumnOrName,
    format: Literal["none", "markdown"],
    chunk_size: Union[int, Column],
    *,
    overlap: Union[int, Column] = 0,
    separators: Union[List[str], Column] = ("\n\n", "\n", " ", ""),
    output_column: Optional[str] = None,
    _emit_ast: bool = True,
) -> "snowflake.snowpark.DataFrame":
    """Append a column of text chunks produced by recursive character splitting.

    The input text is handed to ``SNOWFLAKE.CORTEX.SPLIT_TEXT_RECURSIVE_CHARACTER``,
    which walks an ordered list of separators to cut the text into pieces no longer
    than ``chunk_size``. Typical uses are preparing large documents for embedding,
    retrieval-augmented generation, or search indexing.

    Args:
        text_to_split: The column (Column object or column name as string) holding
            the text to split.
        format: The format of the input text, which determines the default separators
            used by the splitting algorithm. Must be one of the following:

            - ``none``: No format-specific separators. Only the separators in
              ``separators`` are used for splitting.
            - ``markdown``: Separates on headers, code blocks, and tables, in addition
              to any separators in ``separators``.

        chunk_size: The maximum number of characters in each chunk. Must be greater
            than zero. Can be an integer or a Column containing integer values.
        overlap: Optional number of characters shared between consecutive chunks.
            Defaults to 0. Can be an integer or a Column.
        separators: A list of separator strings, or a Column containing an array of
            separators. Separators are tried in the order given until one yields
            appropriately sized chunks. Defaults to ``["\\n\\n", "\\n", " ", ""]``.
        output_column: The name of the output column to be appended. If not provided,
            a column named ``SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT`` is appended.

    Returns:
        A new DataFrame with an appended output column containing an array of text chunks.

    Examples::

        >>> # Plain text, no format-specific separators
        >>> import json
        >>> df = session.create_dataframe([
        ...     ["This is a long document. It has multiple sentences.\\n\\nAnd multiple paragraphs."],
        ... ], schema=["text"])
        >>> result_df = df.ai.split_text_recursive_character(
        ...     text_to_split="text",
        ...     format="none",
        ...     chunk_size=30,
        ...     overlap=5,
        ...     output_column="chunks",
        ... )
        >>> result_df.columns
        ['TEXT', 'CHUNKS']
        >>> json.loads(result_df.collect()[0]["CHUNKS"])
        ['This is a long document. It', 'It has multiple sentences.', 'And multiple paragraphs.']

        >>> # Markdown-aware splitting
        >>> from snowflake.snowpark.functions import col
        >>> df = session.create_dataframe([
        ...     ["# Title\\n\\n## Subtitle\\n\\nMore content."],
        ... ], schema=["text"])
        >>> result_df = df.ai.split_text_recursive_character(
        ...     text_to_split=col("text"),
        ...     format="markdown",
        ...     chunk_size=25,
        ...     overlap=3,
        ...     output_column="md_chunks",
        ... )
        >>> json.loads(result_df.collect()[0]["MD_CHUNKS"])
        ['# Title', '## Subtitle', 'More content.']

        >>> # Column-valued parameters together with custom separators
        >>> df = session.create_dataframe([
        ...     ["First sentence. Second sentence. Third sentence.", "none", 15, 3],
        ... ], schema=["text", "fmt", "max_size", "overlap_size"])
        >>> result_df = df.ai.split_text_recursive_character(
        ...     text_to_split=col("text"),
        ...     format=col("fmt"),
        ...     chunk_size=col("max_size"),
        ...     overlap=col("overlap_size"),
        ...     separators=[". ", " ", ""],
        ...     output_column="split_text",
        ... )
        >>> json.loads(result_df.collect()[0]["SPLIT_TEXT"])
        ['First sentence', '. Second', 'sentence', '. Third', 'sentence.']

    Note:
        - Separators are tried in the order provided.
        - If no separator produces small enough chunks, the text is split by
          individual characters.
        - Overlap helps maintain context between chunks, which is useful for
          embedding and retrieval.
        - Choose separators appropriate for the content type (e.g., paragraphs for
          prose, functions for code).
    """
    method_name = "DataFrame.ai.split_text_recursive_character"
    if output_column:
        output_column_name = output_column
    else:
        output_column_name = "SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT"

    # Normalise every argument to a Column before any AST statement is bound, so
    # an argument rejected by these helpers fails before the batch is touched.
    text_col = _to_col_if_str(text_to_split, method_name)
    chunk_size_col = _to_col_if_lit(chunk_size, method_name)
    separators_col = _to_col_if_lit(separators, method_name)
    overlap_col = _to_col_if_lit(overlap, method_name)

    bound_stmt = None
    if _emit_ast:
        bound_stmt = self._dataframe._session._ast_batch.bind()
        ast_node = with_src_position(
            bound_stmt.expr.dataframe_ai_split_text_recursive_character, bound_stmt
        )
        self._dataframe._set_ast_ref(ast_node.df)
        build_expr_from_snowpark_column_or_col_name(
            ast_node.text_to_split, text_to_split
        )
        # NOTE(review): `format` is only compared against the literal "markdown";
        # every other value is recorded as the "none" variant. The docstring also
        # shows a Column being passed for `format` -- confirm how this comparison
        # behaves for a Column when AST emission is enabled.
        format_flag = (
            "ai_split_text_recursive_format_markdown"
            if format == "markdown"
            else "ai_split_text_recursive_format_none"
        )
        setattr(ast_node.format, format_flag, True)
        build_expr_from_snowpark_column_or_python_val(ast_node.chunk_size, chunk_size)
        build_expr_from_snowpark_column_or_python_val(ast_node.overlap, overlap)
        build_expr_from_snowpark_column_or_python_val(ast_node.separators, separators)

        ast_node.output_column.value = output_column_name

    # Build the Cortex call; `format` is forwarded as given, the other arguments
    # as the Columns created above.
    result_col = function(
        "SNOWFLAKE.CORTEX.SPLIT_TEXT_RECURSIVE_CHARACTER", _emit_ast=False
    )(text_col, format, chunk_size_col, overlap_col, separators_col)

    # Append the chunk array as a new column on the wrapped DataFrame.
    df = self._dataframe.with_column(output_column_name, result_col, _emit_ast=False)
    add_api_call(df, method_name)

    if _emit_ast:
        # Tie the resulting DataFrame to the statement bound above.
        df._ast_id = bound_stmt.uid
    return df
@relational_group_df_api_usage
@experimental(version="1.39.0")
@publicapi
def ai_agg(
    self,
    expr: ColumnOrName,
    task_description: str,
    _emit_ast: bool = True,
    **kwargs,
) -> DataFrame:
    """Reduce a text column per group according to a natural-language instruction.

    For every group of the grouped DataFrame, the values of ``expr`` are aggregated
    as described by ``task_description`` -- for example summarizing a large set of
    reviews or extracting a specific insight for each group.

    Args:
        expr: The column (Column object or column name as string) containing the text
            data on which the aggregation operation is to be performed.
        task_description: A plain English string that describes the aggregation task,
            such as "Summarize the product reviews for a blog post targeting consumers"
            or "Identify the most positive review and translate it into French and
            Polish, one word only".
        **kwargs: Only ``exclude_grouping_columns`` is read (default ``False``); it is
            forwarded unchanged when the result DataFrame is built.

    Returns:
        A DataFrame with one row per group containing the aggregated result.

    Example::

        >>> df = session.create_dataframe([
        ...     ["electronics", "Excellent product, highly recommend!"],
        ...     ["electronics", "Great quality and fast shipping"],
        ...     ["clothing", "Perfect fit and great material"],
        ...     ["clothing", "Poor quality, very disappointed"],
        ... ], schema=["category", "review"])
        >>> summary_df = df.group_by("category").ai_agg(
        ...     expr="review",
        ...     task_description="Summarize these product reviews for a blog post targeting consumers"
        ... )
        >>> summary_df.count()
        2

    Note:
        For optimal performance, follow these guidelines:

        - Use plain English text for the task description.

        - Describe the text provided in the task description. For example, instead of
          a task description like "summarize", use "Summarize the phone call transcripts".

        - Describe the intended use case. For example, instead of "find the best
          review", use "Find the most positive and well-written restaurant review to
          highlight on the restaurant website".

        - Consider breaking the task description into multiple steps.
    """
    drop_grouping_columns = kwargs.get("exclude_grouping_columns", False)

    # A column name is wrapped in a Column first; either way only the underlying
    # expression is carried forward.
    if isinstance(expr, str):
        source_expression = Column(expr)._expression
    else:
        source_expression = expr._expression

    # Build the AI_AGG call without emitting AST for the intermediate columns.
    ai_agg_column = functions.ai_agg(
        Column(source_expression, _emit_ast=False),
        functions.lit(task_description, _emit_ast=False),
        _emit_ast=False,
    )

    df = self._to_df(
        [ai_agg_column._expression],
        exclude_grouping_columns=drop_grouping_columns,
        _emit_ast=False,
    )
    # if no grouping exprs, there is already a LIMIT 1 in the query
    # see aggregate_statement in analyzer_utils.py
    if self._grouping_exprs:
        df._ops_after_agg = set()
    else:
        df._ops_after_agg = {"limit"}

    if not _emit_ast:
        return df

    bound_stmt = self._dataframe._session._ast_batch.bind()
    ast_node = with_src_position(
        bound_stmt.expr.relational_grouped_dataframe_ai_agg, bound_stmt
    )
    # Point the AST node at the grouped DataFrame this call was made on.
    self._set_ast_ref(ast_node.grouped_df)
    # Record the call arguments.
    build_expr_from_python_val(ast_node.expr, expr)
    ast_node.task_description = task_description
    df._ast_id = bound_stmt.uid

    return df
a/tests/ast/data/DataFrame.ai.test b/tests/ast/data/DataFrame.ai.test new file mode 100644 index 0000000000..416b5b5937 --- /dev/null +++ b/tests/ast/data/DataFrame.ai.test @@ -0,0 +1,2151 @@ +## TEST CASE + +# Basic table to drive some AI ops that don't require files +df = session.table(tables.table1) + +# DataFrame.ai.complete with named placeholders +from snowflake.snowpark.functions import col +df1 = df.select(col("STR").as_("question")) +df2 = df1.ai.complete( + prompt="Answer briefly: {q}", + input_columns={"q": col("question")}, + model="snowflake-arctic", +) + +# DataFrame.ai.filter with named placeholders +df4 = session.create_dataframe([ + ["Switzerland", "Europe"], + ["Korea", "Asia"], + ["Brazil", "South America"], +], schema=["country", "continent"]) +df5 = df4.ai.filter( + "Is {country} located in {continent} and specifically in Europe?", + input_columns={"country": col("country"), "continent": col("continent")}, +) + +# DataFrame.ai.agg on text column +df6 = session.create_dataframe([ + ["Excellent product, highly recommend!"], + ["Great quality and fast shipping"], + ["Average product, nothing special"], + ["Poor quality, very disappointed"], +], schema=["review"]) +df7 = df6.ai.agg( + task_description="Summarize these product reviews for a blog post", + input_column="review", +) + +# DataFrame.ai.classify with list categories +df8 = session.create_dataframe([ + ["I love hiking in the mountains"], + ["My favorite dish is pasta"], + ["Just finished reading a great book"], +], schema=["text"]) +df9 = df8.ai.classify( + input_column="text", + categories=["hiking", "cooking", "reading"], +) + +# DataFrame.ai.similarity between two text columns +df10 = session.create_dataframe([ + ["I love programming", "I enjoy coding"], + ["The weather is nice", "It's raining heavily"], +], schema=["text1", "text2"]) +df11 = df10.ai.similarity( + input1="text1", + input2="text2", +) + +# DataFrame.ai.sentiment overall sentiment +df12 = session.create_dataframe([ + 
["The movie had amazing visual effects but the plot was terrible."], + ["Everything about this experience was perfect!"], +], schema=["review"]) +df13 = df12.ai.sentiment(input_column="review") + +# DataFrame.ai.embed text embeddings +df14 = session.create_dataframe([ + ["Machine learning is fascinating"], + ["Snowflake provides cloud data platform"], +], schema=["text"]) +df15 = df14.ai.embed(input_column="text", model="snowflake-arctic-embed-l-v2.0") + +# DataFrame.ai.summarize_agg aggregation summary +df16 = session.create_dataframe([ + ["Meeting started with project updates"], + ["Discussed timeline and deliverables"], + ["Identified key risks"], +], schema=["notes"]) +df17 = df16.ai.summarize_agg(input_column="notes") + +# DataFrame.ai.extract with dict response_format +df20 = session.create_dataframe([["John Smith lives in San Francisco and works for Snowflake"]], schema=["text"]) +df21 = df20.ai.extract( + input_column="text", + response_format={"name": "What is the first name of the employee?", "city": "What is the address of the employee?"}, +) + +# DataFrame.ai.count_tokens simple +df22 = session.create_dataframe([["What is a large language model?"], ["Explain quantum computing in simple terms."]], schema=["text"]) +df23 = df22.ai.count_tokens(model="llama3.1-70b", prompt="text") + +# DataFrame.ai.split_text_markdown_header +df24 = session.create_dataframe([["# Intro\nThis is the intro.\n## Background\nSome background info."]], schema=["document"]) +df25 = df24.ai.split_text_markdown_header( + text_to_split="document", + headers_to_split_on={"#": "section", "##": "subsection"}, + chunk_size=20, + overlap=5, +) + +# DataFrame.ai.split_text_recursive_character +df26 = session.create_dataframe([["This is a long document. 
It has multiple sentences.\n\nAnd multiple paragraphs."]], schema=["text"]) +df27 = df26.ai.split_text_recursive_character( + text_to_split="text", + format="none", + chunk_size=30, + overlap=5, +) + +## EXPECTED UNPARSER OUTPUT + +df = session.table("table1") + +df1 = df.select(col("STR").as_("question")) + +df2 = df1.ai.complete(prompt="Answer briefly: {q}", model="snowflake-arctic", input_columns=prompt("Answer briefly: {0}", col("question")), output_column="AI_COMPLETE_OUTPUT") + +df4 = session.create_dataframe([["Switzerland", "Europe"], ["Korea", "Asia"], ["Brazil", "South America"]], schema=["country", "continent"]) + +df5 = df4.ai.filter("Is {country} located in {continent} and specifically in Europe?", input_columns=prompt("Is {0} located in {1} and specifically in Europe?", col("country"), col("continent"))) + +df6 = session.create_dataframe([["Excellent product, highly recommend!"], ["Great quality and fast shipping"], ["Average product, nothing special"], ["Poor quality, very disappointed"]], schema=["review"]) + +df7 = df6.ai.agg("Summarize these product reviews for a blog post", input_column="review", output_column="AI_AGG_OUTPUT") + +df8 = session.create_dataframe([["I love hiking in the mountains"], ["My favorite dish is pasta"], ["Just finished reading a great book"]], schema=["text"]) + +df9 = df8.ai.classify(input_column="text", categories=["hiking", "cooking", "reading"], output_column="AI_CLASSIFY_OUTPUT") + +df10 = session.create_dataframe([["I love programming", "I enjoy coding"], ["The weather is nice", "It's raining heavily"]], schema=["text1", "text2"]) + +df11 = df10.ai.similarity(input1="text1", input2="text2", output_column="AI_SIMILARITY_OUTPUT") + +df12 = session.create_dataframe([["The movie had amazing visual effects but the plot was terrible."], ["Everything about this experience was perfect!"]], schema=["review"]) + +df13 = df12.ai.sentiment(input_column="review", output_column="AI_SENTIMENT_OUTPUT") + +df14 = 
session.create_dataframe([["Machine learning is fascinating"], ["Snowflake provides cloud data platform"]], schema=["text"]) + +df15 = df14.ai.embed(input_column="text", model="snowflake-arctic-embed-l-v2.0", output_column="AI_EMBED_OUTPUT") + +df16 = session.create_dataframe([["Meeting started with project updates"], ["Discussed timeline and deliverables"], ["Identified key risks"]], schema=["notes"]) + +df17 = df16.ai.summarize_agg(input_column="notes", output_column="AI_SUMMARIZE_AGG_OUTPUT") + +df20 = session.create_dataframe([["John Smith lives in San Francisco and works for Snowflake"]], schema=["text"]) + +df21 = df20.ai.extract(input_column="text", response_format={"name": "What is the first name of the employee?", "city": "What is the address of the employee?"}, output_column="AI_EXTRACT_OUTPUT") + +df22 = session.create_dataframe([["What is a large language model?"], ["Explain quantum computing in simple terms."]], schema=["text"]) + +df23 = df22.ai.count_tokens(model="llama3.1-70b", prompt="text", output_column="COUNT_TOKENS_OUTPUT") + +df24 = session.create_dataframe([["# Intro\nThis is the intro.\n## Background\nSome background info."]], schema=["document"]) + +df25 = df24.ai.split_text_markdown_header(text_to_split="document", headers_to_split_on={"#": "section", "##": "subsection"}, chunk_size=20, overlap=5, output_column="SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT") + +df26 = session.create_dataframe([["This is a long document. 
It has multiple sentences.\n\nAnd multiple paragraphs."]], schema=["text"]) + +df27 = df26.ai.split_text_recursive_character(text_to_split="text", format="none", chunk_size=30, overlap=5, separators=("\n\n", "\n", " ", ""), output_column="SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT") + +## EXPECTED ENCODED AST + +interned_value_table { + string_values { + key: -1 + } + string_values { + key: 2 + value: "SRC_POSITION_TEST_MODE" + } +} +body { + bind { + expr { + table { + name { + name { + name_flat { + name: "table1" + } + } + } + src { + end_column: 41 + end_line: 26 + file: 2 + start_column: 13 + start_line: 26 + } + variant { + session_table: true + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df" + } + uid: 1 + } +} +body { + bind { + expr { + dataframe_select { + cols { + args { + column_alias { + col { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + v: "STR" + } + } + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + } + } + fn { + column_alias_fn_as: true + } + name: "question" + src { + end_column: 50 + end_line: 30 + file: 2 + start_column: 24 + start_line: 30 + } + } + } + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 51 + end_line: 30 + file: 2 + start_column: 14 + start_line: 30 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df1" + } + uid: 2 + } +} +body { + bind { + expr { + dataframe_ai_complete { + df { + dataframe_ref { + id: 2 + } + } + input_columns { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "prompt" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + v: "Answer 
briefly: {0}" + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 47 + end_line: 33 + file: 2 + start_column: 32 + start_line: 33 + } + v: "question" + } + } + src { + end_column: 47 + end_line: 33 + file: 2 + start_column: 32 + start_line: 33 + } + } + } + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + } + } + model: "snowflake-arctic" + output_column { + value: "AI_COMPLETE_OUTPUT" + } + prompt: "Answer briefly: {q}" + src { + end_column: 9 + end_line: 35 + file: 2 + start_column: 14 + start_line: 31 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df2" + } + uid: 3 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Switzerland" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Europe" + } + } + } + } + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Korea" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Asia" + } + } + } + } + vs { + list_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + v: "Brazil" + } + } + vs { + string_val { + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + 
start_line: 38 + } + v: "South America" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "country" + vs: "continent" + } + } + src { + end_column: 43 + end_line: 42 + file: 2 + start_column: 14 + start_line: 38 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df4" + } + uid: 4 + } +} +body { + bind { + expr { + dataframe_ai_filter { + df { + dataframe_ref { + id: 4 + } + } + input_columns { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "prompt" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + v: "Is {0} located in {1} and specifically in Europe?" + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 52 + end_line: 45 + file: 2 + start_column: 38 + start_line: 45 + } + v: "country" + } + } + src { + end_column: 52 + end_line: 45 + file: 2 + start_column: 38 + start_line: 45 + } + } + } + pos_args { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 83 + end_line: 45 + file: 2 + start_column: 67 + start_line: 45 + } + v: "continent" + } + } + src { + end_column: 83 + end_line: 45 + file: 2 + start_column: 67 + start_line: 45 + } + } + } + src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + } + } + predicate: "Is {country} located in {continent} and specifically in Europe?" 
+ src { + end_column: 9 + end_line: 46 + file: 2 + start_column: 14 + start_line: 43 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df5" + } + uid: 5 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Excellent product, highly recommend!" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Great quality and fast shipping" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Average product, nothing special" + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + vs { + string_val { + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + v: "Poor quality, very disappointed" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "review" + } + } + src { + end_column: 29 + end_line: 54 + file: 2 + start_column: 14 + start_line: 49 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df6" + } + uid: 6 + } +} +body { + bind { + expr { + dataframe_ai_agg { + df { + dataframe_ref { + id: 6 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 58 + file: 2 + start_column: 14 + start_line: 55 + } + v: "review" + } + } + output_column { + value: "AI_AGG_OUTPUT" + } + src { + end_column: 9 
+ end_line: 58 + file: 2 + start_column: 14 + start_line: 55 + } + task_description: "Summarize these product reviews for a blog post" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df7" + } + uid: 7 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "I love hiking in the mountains" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "My favorite dish is pasta" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + vs { + string_val { + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + v: "Just finished reading a great book" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 27 + end_line: 65 + file: 2 + start_column: 14 + start_line: 61 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df8" + } + uid: 8 + } +} +body { + bind { + expr { + dataframe_ai_classify { + categories { + list_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "hiking" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "cooking" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + 
} + v: "reading" + } + } + } + } + df { + dataframe_ref { + id: 8 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + v: "text" + } + } + output_column { + value: "AI_CLASSIFY_OUTPUT" + } + src { + end_column: 9 + end_line: 69 + file: 2 + start_column: 14 + start_line: 66 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df9" + } + uid: 9 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "I love programming" + } + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "I enjoy coding" + } + } + } + } + vs { + list_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "The weather is nice" + } + } + vs { + string_val { + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + v: "It\'s raining heavily" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text1" + vs: "text2" + } + } + src { + end_column: 37 + end_line: 75 + file: 2 + start_column: 15 + start_line: 72 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df10" + } + uid: 10 + } +} +body { + bind { + expr { + dataframe_ai_similarity { + df { + dataframe_ref { + id: 10 + } + } + input1 { + string_val { + src { + end_column: 9 + end_line: 79 + file: 2 + start_column: 15 + start_line: 76 + } + v: "text1" + } + } + input2 { + string_val { + src { + end_column: 9 + end_line: 79 + file: 2 + 
start_column: 15 + start_line: 76 + } + v: "text2" + } + } + output_column { + value: "AI_SIMILARITY_OUTPUT" + } + src { + end_column: 9 + end_line: 79 + file: 2 + start_column: 15 + start_line: 76 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df11" + } + uid: 11 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + vs { + string_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + v: "The movie had amazing visual effects but the plot was terrible." + } + } + } + } + vs { + list_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + vs { + string_val { + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + v: "Everything about this experience was perfect!" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "review" + } + } + src { + end_column: 29 + end_line: 85 + file: 2 + start_column: 15 + start_line: 82 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df12" + } + uid: 12 + } +} +body { + bind { + expr { + dataframe_ai_sentiment { + df { + dataframe_ref { + id: 12 + } + } + input_column { + string_val { + src { + end_column: 55 + end_line: 86 + file: 2 + start_column: 15 + start_line: 86 + } + v: "review" + } + } + output_column { + value: "AI_SENTIMENT_OUTPUT" + } + src { + end_column: 55 + end_line: 86 + file: 2 + start_column: 15 + start_line: 86 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df13" + } + uid: 13 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + 
} + vs { + string_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + v: "Machine learning is fascinating" + } + } + } + } + vs { + list_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + vs { + string_val { + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + v: "Snowflake provides cloud data platform" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 27 + end_line: 92 + file: 2 + start_column: 15 + start_line: 89 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df14" + } + uid: 14 + } +} +body { + bind { + expr { + dataframe_ai_embed { + df { + dataframe_ref { + id: 14 + } + } + input_column { + string_val { + src { + end_column: 88 + end_line: 93 + file: 2 + start_column: 15 + start_line: 93 + } + v: "text" + } + } + model: "snowflake-arctic-embed-l-v2.0" + output_column { + value: "AI_EMBED_OUTPUT" + } + src { + end_column: 88 + end_line: 93 + file: 2 + start_column: 15 + start_line: 93 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df15" + } + uid: 15 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Meeting started with project updates" + } + } + } + } + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Discussed timeline and deliverables" + } + } + } + } + vs { + list_val { + src { + end_column: 28 + end_line: 100 + file: 2 
+ start_column: 15 + start_line: 96 + } + vs { + string_val { + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + v: "Identified key risks" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "notes" + } + } + src { + end_column: 28 + end_line: 100 + file: 2 + start_column: 15 + start_line: 96 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df16" + } + uid: 16 + } +} +body { + bind { + expr { + dataframe_ai_summarize_agg { + df { + dataframe_ref { + id: 16 + } + } + input_column { + string_val { + src { + end_column: 58 + end_line: 101 + file: 2 + start_column: 15 + start_line: 101 + } + v: "notes" + } + } + output_column { + value: "AI_SUMMARIZE_AGG_OUTPUT" + } + src { + end_column: 58 + end_line: 101 + file: 2 + start_column: 15 + start_line: 101 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df17" + } + uid: 17 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + vs { + string_val { + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + v: "John Smith lives in San Francisco and works for Snowflake" + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 121 + end_line: 104 + file: 2 + start_column: 15 + start_line: 104 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df20" + } + uid: 18 + } +} +body { + bind { + expr { + dataframe_ai_extract { + df { + dataframe_ref { + id: 18 + } + } + input_column { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "text" + } + } + output_column { + value: "AI_EXTRACT_OUTPUT" + } + response_format { + 
seq_map_val { + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "name" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "What is the first name of the employee?" + } + } + } + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "city" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + v: "What is the address of the employee?" + } + } + } + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + } + } + src { + end_column: 9 + end_line: 108 + file: 2 + start_column: 15 + start_line: 105 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df21" + } + uid: 19 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + vs { + string_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + v: "What is a large language model?" + } + } + } + } + vs { + list_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + vs { + string_val { + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + v: "Explain quantum computing in simple terms." 
+ } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 143 + end_line: 111 + file: 2 + start_column: 15 + start_line: 111 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df22" + } + uid: 20 + } +} +body { + bind { + expr { + dataframe_ai_count_tokens { + df { + dataframe_ref { + id: 20 + } + } + model: "llama3.1-70b" + output_column { + value: "COUNT_TOKENS_OUTPUT" + } + prompt { + string_val { + src { + end_column: 72 + end_line: 112 + file: 2 + start_column: 15 + start_line: 112 + } + v: "text" + } + } + src { + end_column: 72 + end_line: 112 + file: 2 + start_column: 15 + start_line: 112 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df23" + } + uid: 21 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + vs { + string_val { + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + v: "# Intro\nThis is the intro.\n## Background\nSome background info." 
+ } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "document" + } + } + src { + end_column: 133 + end_line: 115 + file: 2 + start_column: 15 + start_line: 115 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df24" + } + uid: 22 + } +} +body { + bind { + expr { + dataframe_ai_split_text_markdown_header { + chunk_size { + int64_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: 20 + } + } + df { + dataframe_ref { + id: 22 + } + } + headers_to_split_on { + seq_map_val { + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "#" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "section" + } + } + } + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "##" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "subsection" + } + } + } + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + } + } + output_column { + value: "SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT" + } + overlap { + int64_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: 5 + } + } + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + text_to_split { + string_val { + src { + end_column: 9 + end_line: 121 + file: 2 + start_column: 15 + start_line: 116 + } + v: "document" + } + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df25" + } + uid: 23 + } +} +body { + bind { + expr { + create_dataframe { + data { + dataframe_data__list { + vs { + list_val { + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + 
start_line: 124 + } + vs { + string_val { + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + start_line: 124 + } + v: "This is a long document. It has multiple sentences.\n\nAnd multiple paragraphs." + } + } + } + } + } + } + schema { + dataframe_schema__list { + vs: "text" + } + } + src { + end_column: 143 + end_line: 124 + file: 2 + start_column: 15 + start_line: 124 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df26" + } + uid: 24 + } +} +body { + bind { + expr { + dataframe_ai_split_text_recursive_character { + chunk_size { + int64_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: 30 + } + } + df { + dataframe_ref { + id: 24 + } + } + format { + ai_split_text_recursive_format_none: true + } + output_column { + value: "SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT" + } + overlap { + int64_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: 5 + } + } + separators { + tuple_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "\n\n" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "\n" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: " " + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + } + } + } + } + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + text_to_split { + string_val { + src { + end_column: 9 + end_line: 130 + file: 2 + start_column: 15 + start_line: 125 + } + v: "text" + } + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" 
+ symbol { + value: "df27" + } + uid: 25 + } +} +client_ast_version: 1 +client_language { + python_language { + version { + label: "final" + major: 3 + minor: 9 + patch: 1 + } + } +} +client_version { + major: 1 + minor: 38 +} +id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" diff --git a/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test b/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test new file mode 100644 index 0000000000..d195e7e054 --- /dev/null +++ b/tests/ast/data/RelationalGroupedDataFrame.ai_agg.test @@ -0,0 +1,292 @@ +## TEST CASE + +from snowflake.snowpark.functions import col + +df = session.table(tables.table1) + +# group_by with a single column, expr as string +rgdf1 = df.group_by("str") +df1 = rgdf1.ai_agg("str", "Summarize strings per group") + +# group_by with a single column, expr as Column +df2 = rgdf1.ai_agg(col("str"), "Summarize strings per group using Column") + +# group_by with no columns (global aggregation) +rgdf2 = df.group_by() +df3 = rgdf2.ai_agg("str", "Summarize strings across all rows") + +## EXPECTED UNPARSER OUTPUT + +df = session.table("table1") + +rgdf1 = df.group_by("str") + +df1 = rgdf1.ai_agg("str", task_description="Summarize strings per group") + +df2 = rgdf1.ai_agg(col("str"), task_description="Summarize strings per group using Column") + +rgdf2 = df.group_by() + +df3 = rgdf2.ai_agg("str", task_description="Summarize strings across all rows") + +## EXPECTED ENCODED AST + +interned_value_table { + string_values { + key: -1 + } + string_values { + key: 2 + value: "SRC_POSITION_TEST_MODE" + } +} +body { + bind { + expr { + table { + name { + name { + name_flat { + name: "table1" + } + } + } + src { + end_column: 41 + end_line: 27 + file: 2 + start_column: 13 + start_line: 27 + } + variant { + session_table: true + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df" + } + uid: 1 + } +} +body { + bind { + expr { + dataframe_group_by { + cols { + args { + 
string_val { + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 16 + start_line: 30 + } + v: "str" + } + } + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 34 + end_line: 30 + file: 2 + start_column: 16 + start_line: 30 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "rgdf1" + } + uid: 2 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + string_val { + src { + end_column: 64 + end_line: 31 + file: 2 + start_column: 14 + start_line: 31 + } + v: "str" + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 2 + } + } + src { + end_column: 64 + end_line: 31 + file: 2 + start_column: 14 + start_line: 31 + } + task_description: "Summarize strings per group" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df1" + } + uid: 3 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + apply_expr { + fn { + builtin_fn { + name { + name { + name_flat { + name: "col" + } + } + } + } + } + pos_args { + string_val { + src { + end_column: 37 + end_line: 34 + file: 2 + start_column: 27 + start_line: 34 + } + v: "str" + } + } + src { + end_column: 37 + end_line: 34 + file: 2 + start_column: 27 + start_line: 34 + } + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 2 + } + } + src { + end_column: 82 + end_line: 34 + file: 2 + start_column: 14 + start_line: 34 + } + task_description: "Summarize strings per group using Column" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df2" + } + uid: 4 + } +} +body { + bind { + expr { + dataframe_group_by { + cols { + variadic: true + } + df { + dataframe_ref { + id: 1 + } + } + src { + end_column: 29 + end_line: 37 + file: 2 + start_column: 16 + start_line: 37 + } + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { 
+ value: "rgdf2" + } + uid: 5 + } +} +body { + bind { + expr { + relational_grouped_dataframe_ai_agg { + expr { + string_val { + src { + end_column: 70 + end_line: 38 + file: 2 + start_column: 14 + start_line: 38 + } + v: "str" + } + } + grouped_df { + relational_grouped_dataframe_ref { + id: 5 + } + } + src { + end_column: 70 + end_line: 38 + file: 2 + start_column: 14 + start_line: 38 + } + task_description: "Summarize strings across all rows" + } + } + first_request_id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" + symbol { + value: "df3" + } + uid: 6 + } +} +client_ast_version: 1 +client_language { + python_language { + version { + label: "final" + major: 3 + minor: 9 + patch: 1 + } + } +} +client_version { + major: 1 + minor: 38 +} +id: "\003U\"\366q\366P\346\260\261?\234\303\254\316\353" diff --git a/tests/integ/test_dataframe_ai.py b/tests/integ/test_dataframe_ai.py new file mode 100644 index 0000000000..95f60ef8ee --- /dev/null +++ b/tests/integ/test_dataframe_ai.py @@ -0,0 +1,1564 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# + +import json +import pytest +from snowflake.snowpark.functions import col, lit, to_file +from snowflake.snowpark.row import Row +from tests.utils import TestFiles, Utils +from snowflake.snowpark.exceptions import SnowparkSQLException + + +pytestmark = [ + pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="AI functions are not yet supported in local testing mode.", + ), +] + + +def test_dataframe_ai_complete_with_named_placeholders(session): + """Test DataFrame.ai.complete with named placeholders.""" + # Create a DataFrame with review data + df = session.create_dataframe( + [ + ["Great product, highly recommend!", 5, "electronics"], + ["Poor quality, very disappointed", 1, "clothing"], + ["Average product, nothing special", 3, "books"], + ], + schema=["review", "rating", "category"], + ) + + # Use DataFrame.ai.complete with named placeholders + result_df = df.ai.complete( + prompt="Analyze this {category} product review: '{review}' with rating {rating}/5. 
What is the sentiment?", + input_columns={ + "review": col("review"), + "rating": col("rating"), + "category": col("category"), + }, + output_column="sentiment_analysis", + model="snowflake-arctic", + ) + + # Check schema + assert result_df.columns == ["REVIEW", "RATING", "CATEGORY", "SENTIMENT_ANALYSIS"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 3 + + for row in results: + assert row["SENTIMENT_ANALYSIS"] is not None + assert len(row["SENTIMENT_ANALYSIS"]) > 0 + + +def test_dataframe_ai_complete_with_positional_placeholders(session): + """Test DataFrame.ai.complete with positional placeholders.""" + # Create a DataFrame + df = session.create_dataframe( + [ + ["Python", "programming"], + ["SQL", "database"], + ["Machine Learning", "AI"], + ], + schema=["topic", "category"], + ) + + # Use DataFrame.ai.complete with positional placeholders + result_df = df.ai.complete( + prompt="Define {0} in the context of {1} in one sentence.", + input_columns=[col("topic"), col("category")], + output_column="definition", + model="claude-4-sonnet", + ) + + # Check schema + assert result_df.columns == ["TOPIC", "CATEGORY", "DEFINITION"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 3 + + for row in results: + assert row["DEFINITION"] is not None + assert len(row["DEFINITION"]) > 10 + + +def test_dataframe_ai_complete_default_output_column(session): + """Test DataFrame.ai.complete with default output column name.""" + df = session.create_dataframe( + [["What is 2+2?"], ["What is the capital of France?"]], schema=["question"] + ) + + # Don't specify output_column, should use default + result_df = df.ai.complete( + prompt="Answer the question", + input_columns=[col("question")], + model="snowflake-arctic", + model_parameters={ + "temperature": 0.8, + "top_p": 0.95, + "max_tokens": 100, + "guardrails": False, + }, + ) + + # Check that default column name is used + assert "AI_COMPLETE_OUTPUT" in 
result_df.schema.names + + results = result_df.collect() + assert len(results) == 2 + for row in results: + assert row["AI_COMPLETE_OUTPUT"] is not None + + +def test_dataframe_ai_complete_error_handling(session): + """Test error handling in DataFrame.ai.complete.""" + + # Test missing model parameter + df = session.create_dataframe([["test"]], schema=["text"]) + with pytest.raises( + TypeError, match="missing 1 required positional argument: 'model'" + ): + df.ai.complete( + prompt="Test {text}", + input_columns={"text": col("text")} + # model parameter missing + ) + + # Test invalid input_columns type + with pytest.raises( + TypeError, match="input_columns must be a list of Columns or a dict" + ): + df.ai.complete( + prompt="Test", + input_columns="invalid", # Should be list or dict + model="snowflake-arctic", + ) + + +def test_dataframe_ai_filter_simple_text(session): + """Test DataFrame.ai.filter with simple text predicate.""" + # Create a DataFrame with sentiment data + df = session.create_dataframe( + [ + ["This is amazing!"], + ["This is terrible!"], + ["This is okay."], + ["I love this product!"], + ["I hate this service."], + ], + schema=["review"], + ) + + # Filter for positive reviews + positive_df = df.ai.filter( + "Is this review positive?", input_columns=[col("review")] + ) + + # Check that we get some results (should be the positive ones) + results = positive_df.collect() + assert len(results) >= 1 # At least some positive reviews should be found + assert len(results) <= 5 # Not more than the original count + + # Verify the filtered results contain positive sentiment words + positive_reviews = [row["REVIEW"] for row in results] + positive_indicators = ["amazing", "love"] + assert any( + any(indicator in review.lower() for indicator in positive_indicators) + for review in positive_reviews + ) + + +def test_dataframe_ai_filter_with_named_placeholders(session): + """Test DataFrame.ai.filter with named placeholders.""" + # Create a DataFrame with country 
and continent data + df = session.create_dataframe( + [ + ["Switzerland", "Europe"], + ["Korea", "Asia"], + ["Brazil", "South America"], + ["Germany", "Europe"], + ["Japan", "Asia"], + ], + schema=["country", "continent"], + ) + + # Filter for European countries + european_df = df.ai.filter( + "Is {country} located in {continent} and specifically in Europe?", + input_columns={"country": col("country"), "continent": col("continent")}, + ) + + # Check results + results = european_df.collect() + assert len(results) >= 1 # Should find at least one European country + + # Verify the results are European countries + countries = [row["COUNTRY"] for row in results] + european_countries = ["Switzerland", "Germany"] + assert any(country in european_countries for country in countries) + + +def test_dataframe_ai_filter_with_positional_placeholders(session): + """Test DataFrame.ai.filter with positional placeholders.""" + # Create a DataFrame with programming languages and their types + df = session.create_dataframe( + [ + ["Python", "programming"], + ["SQL", "database"], + ["HTML", "markup"], + ["JavaScript", "programming"], + ["CSS", "styling"], + ], + schema=["language", "type"], + ) + + # Filter for programming languages + programming_df = df.ai.filter( + "Is {0} primarily used for {1} and is it a programming language?", + input_columns=[col("language"), col("type")], + ) + + # Check results + results = programming_df.collect() + assert len(results) >= 1 # Should find at least one programming language + + # Verify the results contain programming languages + languages = [row["LANGUAGE"] for row in results] + programming_languages = ["Python", "JavaScript"] + assert any(lang in programming_languages for lang in languages) + + +def test_dataframe_ai_filter_error_handling(session): + """Test error handling in DataFrame.ai.filter.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input_columns type + with pytest.raises( + TypeError, 
match="input_columns must be a list of Columns or a dict" + ): + df.ai.filter( + "Test predicate", + input_columns="invalid", # Should be list or dict + ) + + +def test_dataframe_ai_agg_basic(session): + """Test DataFrame.ai.agg with basic usage.""" + # Create a DataFrame with review data + df = session.create_dataframe( + [ + ["Excellent product, highly recommend!"], + ["Great quality and fast shipping"], + ["Average product, nothing special"], + ["Poor quality, very disappointed"], + ["Amazing value for money"], + ], + schema=["review"], + ) + + summary_df = df.ai.agg( + task_description="Summarize these product reviews for a blog post targeting consumers", + input_column=col("review"), + output_column="summary", + ) + + # Verify results + results = summary_df.collect() + assert len(results) == 1 # Should be a single aggregated row + assert results[0]["SUMMARY"] is not None + assert len(results[0]["SUMMARY"]) > 10 # Should have meaningful content + + summary_df = df.ai.agg( + task_description="Summarize these product reviews for a blog post targeting consumers", + input_column=col("review"), + ) + + # Verify results + results = summary_df.collect() + assert len(results) == 1 # Should be a single aggregated row + assert results[0]["AI_AGG_OUTPUT"] is not None + assert len(results[0]["AI_AGG_OUTPUT"]) > 10 # Should have meaningful content + + +def test_dataframe_ai_agg_error_handling(session): + """Test error handling in DataFrame.ai.agg.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input_column type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.agg( + task_description="Summarize the text", + input_column=123, # Invalid type + ) + + # Test invalid column name + with pytest.raises(SnowparkSQLException, match="invalid identifier 'INVALID'"): + df.ai.agg( + task_description="Summarize the text", + input_column=col("invalid"), # Invalid column name + ).collect() + + +def 
test_grouped_dataframe_ai_agg(session): + """Comprehensive test for RelationalGroupedDataFrame.ai_agg with various grouping scenarios and chained operations.""" + + # Create a single DataFrame with product reviews that can be grouped in multiple ways + df = session.create_dataframe( + [ + ["electronics", "high", 4.5, "Excellent product, highly recommend!"], + ["electronics", "high", 4.8, "Outstanding quality and performance"], + ["electronics", "low", 2.5, "Not worth the price, disappointed"], + ["electronics", "low", 2.2, "Broke after a week, poor quality"], + ["clothing", "high", 4.2, "Perfect fit and great material"], + ["clothing", "high", 4.6, "Beautiful design, love it!"], + ["clothing", "low", 2.8, "Poor quality fabric, not as described"], + ["clothing", "low", 2.1, "Color faded quickly, sizing issues"], + ["books", "high", 4.9, "Fantastic read, couldn't put it down!"], + ["books", "high", 4.7, "Well written and engaging"], + ["toys", "low", 1.5, "Poor quality, broke quickly"], + ["toys", "low", 1.8, "Not safe for children, avoid"], + ], + schema=["category", "quality_level", "rating", "review"], + ) + + # Test 1: Group by no columns (aggregate the entire DataFrame) + global_summary_df = df.group_by().ai_agg( + "review", + task_description="Create an overall summary of all customer reviews", + ) + + count = global_summary_df.count() + assert count == 1 # Single row for global aggregation + + # Test 2: Group by single column with string expr + category_summary_df = ( + df.group_by("category") + .ai_agg( + "review", + task_description="Summarize product reviews for a blog post", + ) + .filter(col("CATEGORY") != "toys") # Chain filter operation + .sort(col("CATEGORY").asc()) # Chain sort operation + ) + + category_results = category_summary_df.collect() + assert len(category_results) == 3 # 3 categories after filtering out toys + categories = [row["CATEGORY"] for row in category_results] + assert categories == ["books", "clothing", "electronics"] # Alphabetically sorted + +
# Verify each category has a summary + for row in category_results: + summary_cols = [c for c in row.as_dict().keys() if c != "CATEGORY"] + assert len(summary_cols) == 1 + assert row[summary_cols[0]] is not None + assert len(row[summary_cols[0]]) > 10 + + # Test 3: Group by single column with Column object and select operation + quality_summary_df = ( + df.group_by("quality_level") + .ai_agg( + col("review"), # Using Column object + task_description="Extract key insights from customer feedback", + ) + .select("QUALITY_LEVEL") # Chain select to keep only grouping column + ) + + quality_results = quality_summary_df.collect() + assert len(quality_results) == 2 # Two quality levels: high and low + assert all( + len(row.as_dict()) == 1 for row in quality_results + ) # Only QUALITY_LEVEL column + quality_levels = {row["QUALITY_LEVEL"] for row in quality_results} + assert quality_levels == {"high", "low"} + + # Test 4: Group by multiple columns with filtering and limit + multi_group_df = ( + df.group_by(["category", "quality_level"]) + .ai_agg( + col("review"), + task_description="Summarize reviews highlighting common themes", + ) + .filter(col("QUALITY_LEVEL") == "high") # Only high quality items + .sort(col("CATEGORY").desc()) # Sort by category descending + .limit(2) # Limit to top 2 results + ) + + multi_results = multi_group_df.collect() + assert len(multi_results) == 2 # Limited to 2 results + + # Verify the results are high quality and sorted correctly + for row in multi_results: + assert row["QUALITY_LEVEL"] == "high" + + # Only high-quality groups remain (toys has none); descending sort should put electronics first + first_category = multi_results[0]["CATEGORY"] + assert first_category in ["toys", "electronics", "clothing", "books"] + + # Test 5: Complex chaining with rename and additional filter + complex_chain_df = ( + df.filter(col("rating") > 2.0) # Pre-filter low ratings + .group_by("category") + .ai_agg( + "review", + task_description="Brief summary of positive reviews", + ) +
.filter( + col("CATEGORY").isin(["electronics", "books"]) + ) # Keep only specific categories + .with_column_renamed("CATEGORY", "PRODUCT_TYPE") # Rename column + .sort("PRODUCT_TYPE") + ) + + complex_results = complex_chain_df.collect() + assert len(complex_results) == 2 # Only electronics and books + assert "PRODUCT_TYPE" in complex_results[0].as_dict() + assert "CATEGORY" not in complex_results[0].as_dict() + + product_types = [row["PRODUCT_TYPE"] for row in complex_results] + assert product_types == ["books", "electronics"] # Sorted alphabetically + + +def test_dataframe_ai_classify_basic(session): + """Test DataFrame.ai.classify with basic usage.""" + # Create a DataFrame with text data + df = session.create_dataframe( + [ + ["I love hiking in the mountains"], + ["My favorite dish is pasta carbonara"], + ["Just finished reading a great book"], + ["Learning to cook Italian cuisine"], + ], + schema=["text"], + ) + + # Use DataFrame.ai.classify with list of categories + result_df = df.ai.classify( + input_column="text", + categories=["hiking", "cooking", "reading"], + output_column="category", + ) + + # Check schema + assert result_df.columns == ["TEXT", "CATEGORY"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 4 + + # Verify some expected classifications + text_to_category = { + row["TEXT"]: json.loads(row["CATEGORY"])["labels"][0] for row in results + } + + # These should be fairly obvious classifications + assert text_to_category["My favorite dish is pasta carbonara"] == "cooking" + assert text_to_category["Just finished reading a great book"] == "reading" + assert text_to_category["I love hiking in the mountains"] == "hiking" + + +def test_dataframe_ai_classify_multi_label(session): + """Test DataFrame.ai.classify with multi-label classification.""" + # Create a DataFrame with text that may belong to multiple categories + df = session.create_dataframe( + [ + ["I enjoy traveling and trying local cuisines"], + 
["Reading books while on a flight to Paris"], + ["Cooking recipes from different countries I've visited"], + ["Training for a marathon while exploring new cities"], + ], + schema=["text"], + ) + df = df.with_column( + "categories", lit(["travel", "cooking", "reading", "sports", "education"]) + ) + + # Use multi-label classification with task description + result_df = df.ai.classify( + input_column=col("text"), + categories=col("categories"), + output_column="topics", + task_description="Identify all topics mentioned in the text", + output_mode="multi", + ) + + # Check schema + assert result_df.columns == ["TEXT", "CATEGORIES", "TOPICS"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 4 + + # First text mentions both travel and cooking + first_labels = json.loads(results[0]["TOPICS"])["labels"] + assert "travel" in first_labels or "cooking" in first_labels + # At least one row should have multiple labels + assert any(len(json.loads(row["TOPICS"])["labels"]) > 1 for row in results) + + +def test_dataframe_ai_classify_with_examples(session): + """Test DataFrame.ai.classify with few-shot examples.""" + # Create a DataFrame with ambiguous text + df = session.create_dataframe( + [ + ["The service was outstanding"], + ["The product broke after one day"], + ["Average experience, nothing special"], + ["Exceeded all my expectations"], + ], + schema=["feedback"], + ) + + # Use classification with examples for better accuracy + result_df = df.ai.classify( + input_column="feedback", + categories=["positive", "negative", "neutral"], + output_column="sentiment", + task_description="Classify customer feedback sentiment", + examples=[ + { + "input": "This is the best product ever", + "labels": ["positive"], + "explanation": "The feedback expresses strong satisfaction", + }, + { + "input": "Terrible quality, want my money back", + "labels": ["negative"], + "explanation": "The feedback expresses dissatisfaction and complaint", + }, + { + 
"input": "It is okay, not great but not bad", + "labels": ["neutral"], + "explanation": "The feedback shows neither strong positive nor negative sentiment", + }, + ], + ) + + # Check schema + assert result_df.columns == ["FEEDBACK", "SENTIMENT"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 4 + + # Check sentiment classifications + feedback_to_sentiment = { + row["FEEDBACK"]: json.loads(row["SENTIMENT"])["labels"][0] for row in results + } + + assert feedback_to_sentiment["The service was outstanding"] == "positive" + assert feedback_to_sentiment["The product broke after one day"] == "negative" + assert feedback_to_sentiment["Average experience, nothing special"] == "neutral" + assert feedback_to_sentiment["Exceeded all my expectations"] == "positive" + + +def test_dataframe_ai_classify_default_output_column(session): + """Test DataFrame.ai.classify with default output column name.""" + df = session.create_dataframe([["apple"], ["carrot"], ["chicken"]], schema=["item"]) + + # Don't specify output_column, should use default + result_df = df.ai.classify( + input_column="item", + categories=["fruit", "vegetable", "meat", "dairy"], + ) + + # Check that default column name is used + assert "AI_CLASSIFY_OUTPUT" in result_df.columns + + +def test_dataframe_ai_similarity_basic(session): + """Test DataFrame.ai.similarity with basic text similarity.""" + # Create a DataFrame with text pairs + df = session.create_dataframe( + [ + ["I love programming", "I enjoy coding"], + ["The weather is nice today", "It's a beautiful day"], + ["Python is great", "Python is awesome"], + ["Cats are cute", "Dogs are loyal"], + ], + schema=["text1", "text2"], + ) + + # Use DataFrame.ai.similarity with string column names + result_df = df.ai.similarity( + input1="text1", + input2="text2", + output_column="similarity_score", + ) + + # Check schema + assert result_df.columns == ["TEXT1", "TEXT2", "SIMILARITY_SCORE"] + + # Collect and verify results + 
results = result_df.collect() + assert len(results) == 4 + + # Verify similarity scores are in range + for row in results: + score = row["SIMILARITY_SCORE"] + assert score is not None + assert -1 <= score <= 1 + + # Verify relative similarities make sense + # "I love programming" and "I enjoy coding" should be more similar + # than "Cats are cute" and "Dogs are loyal" + programming_similarity = results[0]["SIMILARITY_SCORE"] + python_similarity = results[2]["SIMILARITY_SCORE"] + pets_similarity = results[3]["SIMILARITY_SCORE"] + + # Python statements should be very similar + assert python_similarity > 0.7 + # Programming statements should be somewhat similar + assert programming_similarity > 0.4 + # Pets statements are less similar (different animals) + assert pets_similarity < programming_similarity + + +def test_dataframe_ai_similarity_default_output_column(session): + """Test DataFrame.ai.similarity with default output column name.""" + df = session.create_dataframe( + [["apple", "orange"], ["car", "vehicle"]], schema=["word1", "word2"] + ) + + # Don't specify output_column, should use default + result_df = df.ai.similarity( + input1="word1", + input2="word2", + ) + + # Check that default column name is used + assert "AI_SIMILARITY_OUTPUT" in result_df.columns + + results = result_df.collect() + assert len(results) == 2 + for row in results: + assert row["AI_SIMILARITY_OUTPUT"] is not None + + +def test_dataframe_ai_similarity_with_custom_model(session): + """Test DataFrame.ai.similarity with custom embedding model.""" + # Create a DataFrame with multilingual text + df = session.create_dataframe( + [ + ["Hello world", "Hola mundo"], # English and Spanish + ["Good morning", "Guten Morgen"], # English and German + ["Thank you", "Merci"], # English and French + ], + schema=["english", "other_language"], + ) + + # Use multilingual model for better cross-lingual similarity + result_df = df.ai.similarity( + input1=col("english"), + input2=col("other_language"), + 
output_column="cross_lingual_similarity", + model="multilingual-e5-large", + ) + + # Check schema + assert result_df.columns == [ + "ENGLISH", + "OTHER_LANGUAGE", + "CROSS_LINGUAL_SIMILARITY", + ] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 3 + + # Check every translation pair against the same similarity threshold + for row in results: + # Translations should have moderate to high similarity with multilingual model + assert row["CROSS_LINGUAL_SIMILARITY"] > 0.3 + + +def test_dataframe_ai_similarity_error_handling(session): + """Test error handling in DataFrame.ai.similarity.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.similarity( + input1=123, # Invalid type + input2="text", + ) + + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.similarity( + input1="text", + input2=[col("text")], # Invalid type (list) + ) + + # Test invalid column name + with pytest.raises(SnowparkSQLException, match="invalid identifier"): + df.ai.similarity( + input1="text", + input2="nonexistent_column", + ).collect() + + +def test_dataframe_ai_similarity_with_nulls(session): + """Test DataFrame.ai.similarity behavior with NULL values.""" + # Create a DataFrame with some NULL values + df = session.create_dataframe( + [ + ["Hello", "World"], + [None, "Test"], + ["Test", None], + ["Snowflake", "Database"], + ], + schema=["col1", "col2"], + ) + + result_df = df.ai.similarity( + input1="col1", + input2="col2", + output_column="similarity", + ) + + results = result_df.collect() + assert len(results) == 4 + + # Rows with NULLs should return NULL similarity scores + assert results[1]["SIMILARITY"] is None # First column is NULL + assert results[2]["SIMILARITY"] is None # Second column is NULL + + # Rows without NULLs should have valid scores + assert results[0]["SIMILARITY"] is not None + assert
results[3]["SIMILARITY"] is not None + + +def test_dataframe_ai_sentiment_basic(session): + """Test DataFrame.ai.sentiment with basic usage.""" + # Create a DataFrame with review data + df = session.create_dataframe( + [ + ["The movie had amazing visual effects but the plot was terrible."], + ["The food was delicious but the service was slow."], + ["Everything about this experience was perfect!"], + ["This product is okay, nothing special."], + ], + schema=["review"], + ) + + # Test overall sentiment analysis + result_df = df.ai.sentiment( + input_column="review", + output_column="sentiment", + ) + + # Check schema + assert result_df.columns == ["REVIEW", "SENTIMENT"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 4 + + # Parse sentiment results + for row in results: + sentiment_data = json.loads(row["SENTIMENT"]) + assert "categories" in sentiment_data + assert len(sentiment_data["categories"]) >= 1 + + # Find overall sentiment + overall = [c for c in sentiment_data["categories"] if c["name"] == "overall"][0] + assert overall["sentiment"] in [ + "positive", + "negative", + "neutral", + "mixed", + "unknown", + ] + + # Third review should be positive + third_sentiment = json.loads(results[2]["SENTIMENT"])["categories"][0] + assert third_sentiment["name"] == "overall" + assert third_sentiment["sentiment"] == "positive" + + +def test_dataframe_ai_sentiment_with_categories(session): + """Test DataFrame.ai.sentiment with specific categories.""" + # Create a DataFrame with hotel reviews + df = session.create_dataframe( + [ + [ + "The hotel room was spacious and clean, but the wifi was terrible and the breakfast was mediocre." 
+ ], + ["Great location and friendly staff, though the parking was expensive."], + ["Room was dirty, staff was rude, but the location was perfect."], + ], + schema=["review"], + ) + + # Test sentiment with specific categories + result_df = df.ai.sentiment( + input_column=col("review"), + categories=["room", "wifi", "breakfast", "location", "staff", "parking"], + output_column="detailed_sentiment", + ) + + # Check schema + assert result_df.columns == ["REVIEW", "DETAILED_SENTIMENT"] + + # Collect and verify results + results = result_df.collect() + assert len(results) == 3 + + # Check first review sentiments + first_sentiment = json.loads(results[0]["DETAILED_SENTIMENT"]) + categories = first_sentiment["categories"] + + # Should have overall plus the specified categories + assert len(categories) > 1 + + category_names = [c["name"] for c in categories] + assert "overall" in category_names + assert "room" in category_names + assert "wifi" in category_names + + # Room should be positive (spacious and clean) + room_sentiment = [c for c in categories if c["name"] == "room"][0] + assert room_sentiment["sentiment"] in ["positive", "neutral"] + + # Wifi should be negative (terrible) + wifi_sentiment = [c for c in categories if c["name"] == "wifi"][0] + assert wifi_sentiment["sentiment"] == "negative" + + +def test_dataframe_ai_sentiment_default_output_column(session): + """Test DataFrame.ai.sentiment with default output column name.""" + df = session.create_dataframe( + [["Great product!"], ["Terrible experience"]], schema=["feedback"] + ) + + # Don't specify output_column, should use default + result_df = df.ai.sentiment(input_column="feedback") + + # Check that default column name is used + assert "AI_SENTIMENT_OUTPUT" in result_df.columns + + results = result_df.collect(_emit_ast=False) + assert len(results) == 2 + for row in results: + assert row["AI_SENTIMENT_OUTPUT"] is not None + + +def test_dataframe_ai_sentiment_error_handling(session): + """Test error handling in 
DataFrame.ai.sentiment.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input_column type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.sentiment( + input_column=123, # Invalid type + ) + + +def test_dataframe_ai_embed_text(session): + """Test DataFrame.ai.embed with text embeddings.""" + # Create a DataFrame with text data + df = session.create_dataframe( + [ + ["Machine learning is fascinating"], + ["Snowflake provides cloud data platform"], + ["Python is a versatile programming language"], + ], + schema=["text"], + ) + + # Generate text embeddings + result_df = df.ai.embed( + input_column="text", + model="snowflake-arctic-embed-l-v2.0", + output_column="text_vector", + ) + + # Check schema + assert result_df.columns == ["TEXT", "TEXT_VECTOR"] + + # Collect and verify results + results = result_df.collect(_emit_ast=False) + assert len(results) == 3 + + # Verify embeddings are generated + for row in results: + assert row["TEXT_VECTOR"] is not None + assert len(row["TEXT_VECTOR"]) > 0 + + +def test_dataframe_ai_embed_multilingual(session): + """Test DataFrame.ai.embed with multilingual text.""" + # Create a DataFrame with multilingual greetings + df = session.create_dataframe( + [ + ["Hello world"], + ["Bonjour le monde"], + ["Hola mundo"], + ["你好世界"], + ], + schema=["greeting"], + ) + + # Generate embeddings with multilingual model + result_df = df.ai.embed( + input_column=col("greeting"), + model="multilingual-e5-large", + output_column="multilingual_vector", + ) + + # Check schema + assert result_df.columns == ["GREETING", "MULTILINGUAL_VECTOR"] + + # Collect and verify results + results = result_df.collect(_emit_ast=False) + assert len(results) == 4 + + # All greetings should have embeddings + for row in results: + assert row["MULTILINGUAL_VECTOR"] is not None + assert len(row["MULTILINGUAL_VECTOR"]) > 0 + + +def test_dataframe_ai_embed_default_output_column(session): + """Test DataFrame.ai.embed 
with default output column name.""" + df = session.create_dataframe([["Sample text"]], schema=["content"]) + + # Don't specify output_column, should use default + result_df = df.ai.embed( + input_column="content", model="snowflake-arctic-embed-l-v2.0" + ) + + # Check that default column name is used + assert "AI_EMBED_OUTPUT" in result_df.columns + + results = result_df.collect(_emit_ast=False) + assert len(results) == 1 + assert results[0]["AI_EMBED_OUTPUT"] is not None + + +def test_dataframe_ai_embed_error_handling(session): + """Test error handling in DataFrame.ai.embed.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input_column type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.embed( + input_column=123, model="snowflake-arctic-embed-l-v2.0" # Invalid type + ) + + +def test_dataframe_ai_summarize_agg_basic(session): + """Test DataFrame.ai.summarize_agg with basic usage.""" + # Create a DataFrame with review data + df = session.create_dataframe( + [ + ["The product quality is excellent and shipping was fast."], + ["Great value for money, highly recommend!"], + ["Customer service was very helpful and responsive."], + ["The packaging could be better, but the product itself is good."], + ["Easy to use and works as advertised."], + ], + schema=["review"], + ) + + # Summarize the reviews + summary_df = df.ai.summarize_agg( + input_column="review", + output_column="reviews_summary", + ) + + # Check schema and results + assert summary_df.columns == ["REVIEWS_SUMMARY"] + assert summary_df.count() == 1 + + results = summary_df.collect() + assert len(results) == 1 + assert results[0]["REVIEWS_SUMMARY"] is not None + assert len(results[0]["REVIEWS_SUMMARY"]) > 10 # Should have meaningful content + + +def test_dataframe_ai_summarize_agg_default_output_column(session): + """Test DataFrame.ai.summarize_agg with default output column name.""" + df = session.create_dataframe( + [ + ["First feedback item"], + 
["Second feedback item"], + ["Third feedback item"], + ], + schema=["feedback"], + ) + + # Don't specify output_column, should use default + summary_df = df.ai.summarize_agg(input_column="feedback") + + # Check that default column name is used + assert "AI_SUMMARIZE_AGG_OUTPUT" in summary_df.columns + assert summary_df.count() == 1 + + results = summary_df.collect() + assert results[0]["AI_SUMMARIZE_AGG_OUTPUT"] is not None + + +def test_dataframe_ai_summarize_agg_error_handling(session): + """Test error handling in DataFrame.ai.summarize_agg.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid input_column type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.summarize_agg( + input_column=123, # Invalid type + ) + + +def test_dataframe_ai_transcribe_basic(session, resources_path): + """Test DataFrame.ai.transcribe with basic usage.""" + stage_name = Utils.random_stage_name() + _ = session.sql( + f"CREATE OR REPLACE TEMP STAGE {stage_name} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')" + ).collect() + test_files = TestFiles(resources_path) + audio_local = test_files.test_audio_ogg + _ = session.file.put(audio_local, f"@{stage_name}", auto_compress=False) + + df = session.create_dataframe([[f"@{stage_name}/audio.ogg"]], schema=["audio_path"]) + + result_df = df.ai.transcribe( + input_column=to_file(col("audio_path")), + output_column="transcript", + ) + + assert result_df.columns == ["AUDIO_PATH", "TRANSCRIPT"] + + results = result_df.collect() + assert len(results) == 1 + data = json.loads(results[0]["TRANSCRIPT"]) if results[0]["TRANSCRIPT"] else {} + assert isinstance(data, dict) + assert data.get("audio_duration", 0) > 100 + assert (data.get("text") or "").lower().find( + "glad to see things are going well" + ) >= 0 + + +def test_dataframe_ai_transcribe_default_output_column(session, resources_path): + """Test DataFrame.ai.transcribe with default output and word timestamps.""" + stage_name = 
Utils.random_stage_name() + _ = session.sql( + f"CREATE OR REPLACE TEMP STAGE {stage_name} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')" + ).collect() + audio_local = TestFiles(resources_path).test_audio_ogg + _ = session.file.put(audio_local, f"@{stage_name}", auto_compress=False) + + df = session.create_dataframe([[f"@{stage_name}/audio.ogg"]], schema=["audio_path"]) + + result_df = df.ai.transcribe( + input_column=to_file(col("audio_path")), + timestamp_granularity="word", + ) + + assert "AI_TRANSCRIBE_OUTPUT" in result_df.columns + results = result_df.collect() + data = ( + json.loads(results[0]["AI_TRANSCRIBE_OUTPUT"]) + if results[0]["AI_TRANSCRIBE_OUTPUT"] + else {} + ) + assert data.get("audio_duration", 0) > 0 + assert isinstance(data.get("segments", []), list) and len(data["segments"]) > 0 + assert all("start" in s and "end" in s for s in data["segments"]) + + +def test_dataframe_ai_parse_document_basic(session, resources_path): + """Test DataFrame.ai.parse_document OCR on a PDF document.""" + stage_name = Utils.random_stage_name() + _ = session.sql( + f"CREATE OR REPLACE TEMP STAGE {stage_name} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')" + ).collect() + file_local = TestFiles(resources_path).test_doc_pdf + _ = session.file.put(file_local, f"@{stage_name}", auto_compress=False) + + df = session.create_dataframe([[f"@{stage_name}/doc.pdf"]], schema=["file_path"]) + + result_df = df.ai.parse_document( + input_column=to_file(col("file_path")), + output_column="parsed", + mode="OCR", + ) + + assert result_df.columns == ["FILE_PATH", "PARSED"] + + results = result_df.collect() + data = json.loads(results[0]["PARSED"]) if results[0]["PARSED"] else {} + assert isinstance(data, dict) + assert "content" in data and isinstance(data["content"], str) + assert isinstance(data.get("metadata", {}), dict) + assert data["metadata"].get("pageCount", 0) >= 3 + + +def test_dataframe_ai_parse_document_default_output_column(session, resources_path): + """Test DataFrame.ai.parse_document 
default output with page splitting.""" + stage_name = Utils.random_stage_name() + _ = session.sql( + f"CREATE OR REPLACE TEMP STAGE {stage_name} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')" + ).collect() + file_local = TestFiles(resources_path).test_doc_pdf + _ = session.file.put(file_local, f"@{stage_name}", auto_compress=False) + + df = session.create_dataframe([[f"@{stage_name}/doc.pdf"]], schema=["file_path"]) + + result_df = df.ai.parse_document( + input_column=to_file(col("file_path")), + page_split=True, + ) + + assert "AI_PARSE_DOCUMENT_OUTPUT" in result_df.columns + results = result_df.collect() + data = ( + json.loads(results[0]["AI_PARSE_DOCUMENT_OUTPUT"]) + if results[0]["AI_PARSE_DOCUMENT_OUTPUT"] + else {} + ) + assert isinstance(data, dict) + assert isinstance(data.get("pages", []), list) and len(data["pages"]) >= 1 + first_page = data["pages"][0] + assert ( + isinstance(first_page, dict) + and "index" in first_page + and "content" in first_page + ) + + +def test_dataframe_ai_extract_text_basic(session): + """Test DataFrame.ai.extract with text input and dict response format.""" + df = session.create_dataframe( + [ + ["John Smith lives in San Francisco"], + ["Alice Johnson works in Seattle"], + ], + schema=["text"], + ) + + result_df = df.ai.extract( + input_column="text", + response_format={ + "name": "What is the first name?", + "city": "What city is mentioned?", + }, + output_column="extracted", + ) + + assert result_df.columns == ["TEXT", "EXTRACTED"] + + results = result_df.collect(_emit_ast=False) + assert len(results) == 2 + text_to_response = {} + for row in results: + assert row["EXTRACTED"] is not None + data = json.loads(row["EXTRACTED"]) if row["EXTRACTED"] else {} + assert isinstance(data, dict) and "response" in data + text_to_response[row["TEXT"]] = data["response"] + + assert text_to_response["John Smith lives in San Francisco"]["name"] == "John" + assert ( + text_to_response["John Smith lives in San Francisco"]["city"] == "San Francisco" + 
) + assert text_to_response["Alice Johnson works in Seattle"]["name"] == "Alice" + assert text_to_response["Alice Johnson works in Seattle"]["city"] == "Seattle" + + +def test_dataframe_ai_extract_default_output_column(session): + """Test DataFrame.ai.extract uses default output column name.""" + df = session.create_dataframe([["Bob lives in Denver"]], schema=["text"]) + + result_df = df.ai.extract( + input_column="text", + response_format=[ + ["person", "What is the first name?"], + ["location", "What city is mentioned?"], + ], + ) + + assert "AI_EXTRACT_OUTPUT" in result_df.columns + results = result_df.collect(_emit_ast=False) + data = ( + json.loads(results[0]["AI_EXTRACT_OUTPUT"]) + if results[0]["AI_EXTRACT_OUTPUT"] + else {} + ) + assert isinstance(data, dict) and isinstance(data.get("response", {}), dict) + assert data["response"]["person"] == "Bob" + assert data["response"]["location"] == "Denver" + + +def test_dataframe_ai_extract_file(session, resources_path): + """Test DataFrame.ai.extract on a staged PDF file.""" + stage_name = Utils.random_stage_name() + _ = session.sql( + f"CREATE OR REPLACE TEMP STAGE {stage_name} ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')" + ).collect() + file_local = TestFiles(resources_path).test_invoice_pdf + _ = session.file.put(file_local, f"@{stage_name}", auto_compress=False) + + df = session.create_dataframe( + [[f"@{stage_name}/invoice.pdf"]], schema=["file_path"] + ) + + result_df = df.ai.extract( + input_column=to_file(col("file_path")), + response_format=[ + ["date", "What is the invoice date?"], + ["amount", "What is the amount?"], + ], + output_column="info", + ) + + assert result_df.columns == ["FILE_PATH", "INFO"] + results = result_df.collect(_emit_ast=False) + data = json.loads(results[0]["INFO"]) if results[0]["INFO"] else {} + assert isinstance(data, dict) and isinstance(data.get("response", {}), dict) + assert data["response"]["date"] == "Nov 26, 2016" + assert data["response"]["amount"] == "USD $950.00" + + +def 
test_dataframe_ai_extract_error_handling(session): + """Test error handling in DataFrame.ai.extract.""" + df = session.create_dataframe([["text"]], schema=["col"]) + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.extract( + input_column=123, # Invalid type + response_format={"a": "What is a?"}, + ) + + +def test_dataframe_ai_count_tokens_basic(session): + """Test DataFrame.ai.count_tokens with basic usage.""" + # Create a DataFrame with text data + df = session.create_dataframe( + [ + ["What is a large language model?"], + ["Explain quantum computing in simple terms."], + [ + "This is a much longer text with many more words to demonstrate token counting accurately." + ], + ], + schema=["text"], + ) + + # Count tokens using llama3.1-70b model + result_df = df.ai.count_tokens( + model="llama3.1-70b", + prompt="text", + output_column="token_count", + ) + + # Verify results + Utils.check_answer( + result_df.select("token_count"), + [Row(TOKEN_COUNT=8), Row(TOKEN_COUNT=9), Row(TOKEN_COUNT=17)], + ) + + +def test_dataframe_ai_count_tokens_default_output_column(session): + """Test DataFrame.ai.count_tokens with default output column name.""" + df = session.create_dataframe([["Sample text for counting"]], schema=["text"]) + + # Don't specify output_column, should use default + result_df = df.ai.count_tokens( + model="llama3.1-8b", + prompt="text", + ) + + results = result_df.collect() + assert len(results) == 1 + assert results[0]["COUNT_TOKENS_OUTPUT"] == 5 + + +def test_dataframe_ai_count_tokens_error_handling(session): + """Test error handling in DataFrame.ai.count_tokens.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid prompt type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.count_tokens( + model="llama3.1-70b", + prompt=123, # Invalid type + ) + + +def test_dataframe_ai_split_text_markdown_header_basic(session): + """Test DataFrame.ai.split_text_markdown_header with basic usage.""" 
+ # Create a DataFrame with Markdown text + markdown_text = """# Introduction +This is the introduction section with some general information about the topic. + +## Background +Here we provide background context about the topic. This section contains important details. + +### Details +Some detailed information goes here with specific examples. + +## Methods +Description of methods used in the analysis. + +# Conclusion +Final thoughts and summary of the findings.""" + + df = session.create_dataframe( + [[markdown_text]], + schema=["document"], + ) + + # Split the markdown document + result_df = df.ai.split_text_markdown_header( + text_to_split="document", + headers_to_split_on={"#": "h1", "##": "h2", "###": "h3"}, + chunk_size=50, + overlap=10, + output_column="chunks", + ) + + # Verify results + results = result_df.select("chunks").collect() + chunks = json.loads(results[0][0]) + assert chunks == [ + { + "chunk": "This is the introduction section with some general", + "headers": {"h1": "Introduction"}, + }, + { + "chunk": "general information about the topic.", + "headers": {"h1": "Introduction"}, + }, + { + "chunk": "Here we provide background context about the", + "headers": {"h1": "Introduction", "h2": "Background"}, + }, + { + "chunk": "about the topic. 
This section contains important", + "headers": {"h1": "Introduction", "h2": "Background"}, + }, + { + "chunk": "important details.", + "headers": {"h1": "Introduction", "h2": "Background"}, + }, + { + "chunk": "Some detailed information goes here with specific", + "headers": {"h1": "Introduction", "h2": "Background", "h3": "Details"}, + }, + { + "chunk": "specific examples.", + "headers": {"h1": "Introduction", "h2": "Background", "h3": "Details"}, + }, + { + "chunk": "Description of methods used in the analysis.", + "headers": {"h1": "Introduction", "h2": "Methods"}, + }, + { + "chunk": "Final thoughts and summary of the findings.", + "headers": {"h1": "Conclusion"}, + }, + ] + + +def test_dataframe_ai_split_text_markdown_header_default_output(session): + """Test DataFrame.ai.split_text_markdown_header with default output column.""" + df = session.create_dataframe( + [["# Header\nContent"]], + schema=["doc"], + ) + + # Don't specify output_column, should use default + result_df = df.ai.split_text_markdown_header( + text_to_split="doc", + headers_to_split_on={"#": "h1"}, + chunk_size=20, + ) + + results = result_df.select("SPLIT_TEXT_MARKDOWN_HEADER_OUTPUT").collect() + chunks = json.loads(results[0][0]) + assert chunks == [{"chunk": "Content", "headers": {"h1": "Header"}}] + + +def test_dataframe_ai_split_text_markdown_header_error_handling(session): + """Test error handling in DataFrame.ai.split_text_markdown_header.""" + df = session.create_dataframe([["test"]], schema=["text"]) + + # Test invalid text_to_split type + with pytest.raises(TypeError, match="expected Column or str"): + df.ai.split_text_markdown_header( + text_to_split=123, # Invalid type + headers_to_split_on={"#": "h1"}, + chunk_size=20, + ) + + +def test_dataframe_ai_split_text_recursive_character_basic(session): + """Test DataFrame.ai.split_text_recursive_character with basic usage.""" + long_text = """This is a long document with multiple sentences. 
It contains various information +that needs to be split into smaller chunks for processing. + +Another paragraph here with more content. The recursive splitter will try different +separators to create appropriately sized chunks. + +Final paragraph with concluding remarks.""" + + df = session.create_dataframe( + [[long_text]], + schema=["text"], + ) + + # Split the text recursively + result_df = df.ai.split_text_recursive_character( + text_to_split="text", + format="none", + chunk_size=50, + overlap=10, + output_column="chunks", + ) + + # Verify results + results = result_df.select("chunks").collect() + chunks = json.loads(results[0][0]) + assert chunks == [ + "This is a long document with multiple sentences.", + "It contains various information", + "that needs to be split into smaller chunks for", + "for processing.", + "Another paragraph here with more content. The", + "The recursive splitter will try different", + "separators to create appropriately sized chunks.", + "Final paragraph with concluding remarks.", + ] + + +def test_dataframe_ai_split_text_recursive_character_markdown_format(session): + """Test DataFrame.ai.split_text_recursive_character with markdown format.""" + markdown_text = """# Main Title + +This is the introduction paragraph with some text. + +## Section 1 + +Content for section 1 goes here. + +```python +def hello(): + print("Hello, World!") +``` + +## Section 2 + +More content in section 2. 
+ +| Column 1 | Column 2 | +|----------|----------| +| Data 1 | Data 2 |""" + + df = session.create_dataframe( + [[markdown_text]], + schema=["content"], + ) + + # Split with markdown format + result_df = df.ai.split_text_recursive_character( + text_to_split=col("content"), + format="markdown", + chunk_size=40, + overlap=5, + output_column="md_chunks", + ) + + # Verify results + results = result_df.select("md_chunks").collect() + chunks = json.loads(results[0][0]) + assert chunks == [ + "# Main Title", + "This is the introduction paragraph with", + "with some text.", + "## Section 1", + "Content for section 1 goes here.", + "```python\ndef hello():", + 'print("Hello, World!")', + "```", + "## Section 2", + "More content in section 2.", + "| Column 1 | Column 2 |", + "|----------|----------|", + "| Data 1 | Data 2 |", + ] + + +def test_dataframe_ai_split_text_recursive_character_custom_separators(session): + """Test DataFrame.ai.split_text_recursive_character with custom separators.""" + code_text = """def function_one(): + # First function + return 1 + +def function_two(): + # Second function + return 2 + +def function_three(): + # Third function + return 3""" + + df = session.create_dataframe( + [[code_text]], + schema=["code"], + ) + + # Split with custom separators for code + result_df = df.ai.split_text_recursive_character( + text_to_split="code", + format="none", + chunk_size=35, + separators=["\n\n", "\n", " ", " ", ""], + output_column="code_chunks", + ) + + # Verify results + results = result_df.select("code_chunks").collect() + chunks = json.loads(results[0][0]) + assert chunks == [ + "def function_one():", + "# First function\n return 1", + "def function_two():", + "# Second function\n return 2", + "def function_three():", + "# Third function\n return 3", + ] + + +def test_dataframe_ai_split_text_recursive_character_default_output(session): + """Test DataFrame.ai.split_text_recursive_character with default output column.""" + df = 
session.create_dataframe(
+        [["Short text to split"]],
+        schema=["text"],
+    )
+
+    # Don't specify output_column, should use default
+    result_df = df.ai.split_text_recursive_character(
+        text_to_split="text",
+        format="none",
+        chunk_size=10,
+    )
+
+    # Verify results
+    results = result_df.select("SPLIT_TEXT_RECURSIVE_CHARACTER_OUTPUT").collect()
+    chunks = json.loads(results[0][0])
+    assert chunks == ["Short text", "to split"]
+
+
+def test_dataframe_ai_split_text_recursive_character_error_handling(session):
+    """Test error handling in DataFrame.ai.split_text_recursive_character."""
+    df = session.create_dataframe([["test"]], schema=["text"])
+
+    # Test invalid text_to_split type
+    with pytest.raises(TypeError, match="expected Column or str"):
+        df.ai.split_text_recursive_character(
+            text_to_split=123,  # Invalid type
+            format="none",
+            chunk_size=10,
+        )
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 6336fba114..a1e00a0dd3 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -3,6 +3,7 @@
 #
 from collections import defaultdict
 from unittest import mock
+from snowflake.snowpark.functions import col
 import pytest
 import uuid
 
@@ -10,6 +11,7 @@
 import time
 from datetime import datetime
 from snowflake.snowpark._internal.utils import (
+    create_prompt_column_from_template,
     ExprAliasUpdateDict,
     str_contains_alphabet,
     get_sorted_key_for_version,
@@ -559,6 +561,132 @@ def test_get_plan_from_line_numbers_no_matching_child():
         get_plan_from_line_numbers(orphan_plan, 4)
 
 
+def test_create_prompt_column_from_template():
+    """Test create_prompt_column_from_template with dict and list column inputs."""
+    # Mock the prompt function to capture its calls
+    with mock.patch("snowflake.snowpark.functions.prompt") as mock_prompt:
+        mock_prompt.return_value = mock.Mock(name="prompt_column")
+
+        # Test: Dict with named placeholders - all placeholders used
+        template = "Analyze {review} with rating {rating}"
+        columns_dict = {"review": col("review_col"), "rating": col("rating_col")}
+        _ = create_prompt_column_from_template(template, columns_dict)
+
+        # Check that prompt was called with positional template
+        mock_prompt.assert_called()
+        call_args = mock_prompt.call_args
+        assert "{0}" in call_args[0][0] and "{1}" in call_args[0][0]
+        assert len(call_args[0]) == 3  # template + 2 columns
+
+        # Test: Dict with unused columns - should raise error
+        template = "Analyze {review} and {rating}"
+        columns_dict = {
+            "review": col("review_col"),
+            "rating": col("rating_col"),
+            "unused": col("unused_col"),
+        }
+        with pytest.raises(ValueError, match="not used in the template.*unused"):
+            create_prompt_column_from_template(template, columns_dict)
+
+        # Test: Dict with missing placeholder - should raise error
+        template = "Analyze {review} and {sentiment}"
+        columns_dict = {"review": col("review_col")}
+        with pytest.raises(ValueError, match="Placeholder.*sentiment.*not found"):
+            create_prompt_column_from_template(template, columns_dict)
+
+        # Test: List with matching positional placeholders
+        mock_prompt.reset_mock()
+        template = "Analyze {0} with rating {1}"
+        columns_list = [col("review_col"), col("rating_col")]
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        # Check that prompt was called with the original template
+        mock_prompt.assert_called_once()
+        call_args = mock_prompt.call_args
+        assert call_args[0][0] == template
+        assert len(call_args[0]) == 3  # template + 2 columns
+
+        # Test: List with no placeholders - should auto-append
+        mock_prompt.reset_mock()
+        template = "Analyze this review and rating:"
+        columns_list = [col("review_col"), col("rating_col")]
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        # Check that placeholders were appended
+        mock_prompt.assert_called_once()
+        call_args = mock_prompt.call_args
+        expected_template = "Analyze this review and rating: {0} {1}"
+        assert call_args[0][0] == expected_template
+
+        # Test: List with mismatched placeholder count - should raise error
+        template = "Analyze {0} with rating {1}"
+        columns_list = [col("review_col"), col("rating_col"), col("extra_col")]
+        with pytest.raises(
+            ValueError, match="Number of positional placeholders.*2.*does not match.*3"
+        ):
+            create_prompt_column_from_template(template, columns_list)
+
+        # Test: Invalid input type - should raise error
+        template = "Some template"
+        invalid_input = "not_a_dict_or_list"
+        with pytest.raises(TypeError, match="must be a list of Columns or a dict"):
+            create_prompt_column_from_template(template, invalid_input)
+
+        # Test: List with single column and no placeholder
+        mock_prompt.reset_mock()
+        template = "Process this:"
+        columns_list = [col("data_col")]
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        mock_prompt.assert_called_once()
+        call_args = mock_prompt.call_args
+        expected_template = "Process this: {0}"
+        assert call_args[0][0] == expected_template
+
+        # Test: Dict with repeated placeholders
+        mock_prompt.reset_mock()
+        template = "Compare {value} with {value} and rate {rating}"
+        columns_dict = {"value": col("value_col"), "rating": col("rating_col")}
+        _ = create_prompt_column_from_template(template, columns_dict)
+
+        # Should work fine, using same column twice
+        mock_prompt.assert_called()
+        call_args = mock_prompt.call_args
+        # Should have converted to "{0} with {0} and rate {1}"
+        assert call_args[0][0].count("{0}") == 2
+        assert "{1}" in call_args[0][0]
+
+        # Test: Empty list - there are no placeholders to append
+        mock_prompt.reset_mock()
+        template = "No placeholders here"
+        columns_list = []
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        mock_prompt.assert_called_once()
+        call_args = mock_prompt.call_args
+        # No columns, so no placeholders are added - only a trailing space
+        assert call_args[0][0] == "No placeholders here "
+
+        # Test: List with duplicate placeholders
+        mock_prompt.reset_mock()
+        template = "Compare {0} with {0} and {1}"
+        columns_list = [col("first_col"), col("second_col")]
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        # Should work - {0} appears twice but we still have 2 unique placeholders
+        mock_prompt.assert_called_once()
+
+        # Test: List with out-of-order placeholders
+        mock_prompt.reset_mock()
+        template = "First: {2}, Second: {0}, Third: {1}"
+        columns_list = [col("col_0"), col("col_1"), col("col_2")]
+        _ = create_prompt_column_from_template(template, columns_list)
+
+        mock_prompt.assert_called_once()
+        call_args = mock_prompt.call_args
+        assert call_args[0][0] == template  # Should keep original template
+
+
 def test_time_travel_config():
     """Test TimeTravelConfig NamedTuple creation."""
     config = TimeTravelConfig(time_travel_mode="at", statement="query_123")
diff --git a/tests/utils.py b/tests/utils.py
index ede0148a11..40d2c70303 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1707,6 +1707,22 @@ def test_dog_image(self):
     def test_books_xsd(self):
         return os.path.join(self.resources_path, "books.xsd")
 
+    @property
+    def test_audio_ogg(self):
+        return os.path.join(self.resources_path, "audio.ogg")
+
+    @property
+    def test_conversation_ogg(self):
+        return os.path.join(self.resources_path, "conversation.ogg")
+
+    @property
+    def test_doc_pdf(self):
+        return os.path.join(self.resources_path, "doc.pdf")
+
+    @property
+    def test_invoice_pdf(self):
+        return os.path.join(self.resources_path, "invoice.pdf")
+
 
 class TypeMap(NamedTuple):
     col_name: str