2626#include < algorithm>
2727#include < boost/iterator/iterator_facade.hpp>
2828#include < memory>
29+ #include < optional>
2930#include < ostream>
3031#include < string>
3132#include < utility>
5051#include " core/data_type/data_type_map.h"
5152#include " core/data_type/data_type_nullable.h"
5253#include " core/data_type/data_type_number.h"
54+ #include " core/data_type/data_type_struct.h"
5355#include " core/data_type/primitive_type.h"
5456#include " core/types.h"
5557#include " exprs/function/function.h"
@@ -77,6 +79,31 @@ class FunctionArrayElement : public IFunction {
7779
7880 size_t get_number_of_arguments () const override { return 2 ; }
7981
82+ // A struct field access (element_at(struct, const) / the struct_element alias) resolves to a
83+ // different return type depending on which field the constant index selects, so we need the
84+ // index column here. Array/map return types depend only on the argument types and fall through
85+ // to the DataTypes-based overload below.
86+ DataTypePtr get_return_type_impl (const ColumnsWithTypeAndName& arguments) const override {
87+ DataTypePtr arg_0 = remove_nullable (arguments[0 ].type );
88+ if (arg_0->get_primitive_type () == TYPE_STRUCT) {
89+ const auto * struct_type = check_and_get_data_type<DataTypeStruct>(arg_0.get ());
90+ size_t index = 0 ;
91+ // Throw the concrete error (field not found / out of bound) instead of returning
92+ // nullptr, which the framework would report as an opaque "return type check failed".
93+ Status st = get_struct_element_index (*struct_type, arguments[1 ].column ,
94+ arguments[1 ].type , &index);
95+ if (!st.ok ()) {
96+ throw doris::Exception (st);
97+ }
98+ return make_nullable (struct_type->get_elements ()[index]);
99+ }
100+ DataTypes data_types (arguments.size ());
101+ for (size_t i = 0 ; i < arguments.size (); ++i) {
102+ data_types[i] = arguments[i].type ;
103+ }
104+ return get_return_type_impl (data_types);
105+ }
106+
80107 DataTypePtr get_return_type_impl (const DataTypes& arguments) const override {
81108 DataTypePtr arg_0 = remove_nullable (arguments[0 ]);
82109 DCHECK (arg_0->get_primitive_type () == TYPE_ARRAY || arg_0->get_primitive_type () == TYPE_MAP)
@@ -101,6 +128,10 @@ class FunctionArrayElement : public IFunction {
101128
102129 Status execute_impl (FunctionContext* context, Block& block, const ColumnNumbers& arguments,
103130 uint32_t result, size_t input_rows_count) const override {
131+ if (remove_nullable (block.get_by_position (arguments[0 ]).type )->get_primitive_type () ==
132+ TYPE_STRUCT) {
133+ return _execute_struct (block, arguments, result, input_rows_count);
134+ }
104135 auto dst_null_column = ColumnUInt8::create (input_rows_count, 0 );
105136 UInt8* dst_null_map = dst_null_column->get_data ().data ();
106137 const UInt8* src_null_map = nullptr ;
@@ -153,6 +184,93 @@ class FunctionArrayElement : public IFunction {
153184 }
154185
155186private:
187+ // =========================== struct element===========================//
188+ // Resolve the 0-based field offset selected by a constant int/string index. Mirrors the logic
189+ // of the former struct_element function, which element_at now subsumes.
190+ Status get_struct_element_index (const DataTypeStruct& struct_type,
191+ const ColumnPtr& index_column, const DataTypePtr& index_type,
192+ size_t * result) const {
193+ if (!index_column) {
194+ return Status::RuntimeError (" Function {}: second argument column is nullptr." ,
195+ get_name ());
196+ }
197+ size_t index = 0 ;
198+ if (is_int_or_bool (index_type->get_primitive_type ())) {
199+ int64_t offset = index_column->get_int (0 );
200+ size_t limit = struct_type.get_elements ().size () + 1 ;
201+ if (offset < 1 || offset >= static_cast <int64_t >(limit)) {
202+ return Status::RuntimeError (
203+ " Index out of bound for function {}: index {} should base from 1 and less "
204+ " than {}." ,
205+ get_name (), offset, limit);
206+ }
207+ index = offset - 1 ; // the index starts from 1
208+ } else if (is_string_type (index_type->get_primitive_type ())) {
209+ std::string field_name = index_column->get_data_at (0 ).to_string ();
210+ std::optional<size_t > pos = struct_type.try_get_position_by_name (field_name);
211+ if (!pos.has_value ()) {
212+ return Status::RuntimeError (
213+ " Element not found for function {}: name {} not found in {}." , get_name (),
214+ field_name, struct_type.get_name ());
215+ }
216+ index = pos.value ();
217+ } else {
218+ return Status::RuntimeError (
219+ " Argument not supported for function {}: second arg type {} should be int or "
220+ " string." ,
221+ get_name (), index_type->get_name ());
222+ }
223+ *result = index;
224+ return Status::OK ();
225+ }
226+
227+ Status _execute_struct (Block& block, const ColumnNumbers& arguments, uint32_t result,
228+ size_t input_rows_count) const {
229+ const auto & struct_arg = block.get_by_position (arguments[0 ]);
230+ ColumnPtr struct_col_ptr = struct_arg.column ->convert_to_full_column_if_const ();
231+ // element_at manages nulls itself (use_default_implementation_for_nulls() == false), so a
232+ // null struct row must be merged into the result null map manually.
233+ const ColumnUInt8* outer_null_map = nullptr ;
234+ if (struct_col_ptr->is_nullable ()) {
235+ const auto * nullable = assert_cast<const ColumnNullable*>(struct_col_ptr.get ());
236+ outer_null_map = &nullable->get_null_map_column ();
237+ struct_col_ptr = nullable->get_nested_column_ptr ();
238+ }
239+ const auto * struct_type =
240+ check_and_get_data_type<DataTypeStruct>(remove_nullable (struct_arg.type ).get ());
241+ const auto * struct_col = check_and_get_column<ColumnStruct>(struct_col_ptr.get ());
242+ if (!struct_col || !struct_type) {
243+ return Status::RuntimeError (" unsupported types for function {}({}, {})" , get_name (),
244+ struct_arg.type ->get_name (),
245+ block.get_by_position (arguments[1 ]).type ->get_name ());
246+ }
247+ const auto & index_arg = block.get_by_position (arguments[1 ]);
248+ size_t index = 0 ;
249+ RETURN_IF_ERROR (
250+ get_struct_element_index (*struct_type, index_arg.column , index_arg.type , &index));
251+
252+ ColumnPtr field_col = struct_col->get_column_ptr (index);
253+ auto res_null_column = ColumnUInt8::create (input_rows_count, 0 );
254+ auto & res_null_map = res_null_column->get_data ();
255+ ColumnPtr res_nested = field_col;
256+ if (field_col->is_nullable ()) {
257+ const auto * field_nullable = assert_cast<const ColumnNullable*>(field_col.get ());
258+ const auto & field_null_map = field_nullable->get_null_map_column ().get_data ();
259+ memcpy (res_null_map.data (), field_null_map.data (), input_rows_count);
260+ res_nested = field_nullable->get_nested_column_ptr ();
261+ }
262+ if (outer_null_map) {
263+ const auto & outer = outer_null_map->get_data ();
264+ for (size_t i = 0 ; i < input_rows_count; ++i) {
265+ res_null_map[i] |= outer[i];
266+ }
267+ }
268+ block.replace_by_position (
269+ result, ColumnNullable::create (res_nested->clone_resized (input_rows_count),
270+ std::move (res_null_column)));
271+ return Status::OK ();
272+ }
273+
156274 // =========================== map element===========================//
157275 ColumnPtr _get_mapped_idx (const ColumnArray& column,
158276 const ColumnWithTypeAndName& argument) const {
0 commit comments