@@ -43,7 +43,47 @@ layout will be empty.")code")
4343 .AddOptionalArg(" new_axis_names" , R"code( Names of the new dimensions in the data layout.
4444
4545The length of `new_axis_names` must match the length of `axes`.
46- If argument isn't be provided, the layout will be cleared.)code" , TensorLayout(" " ));
46+ If argument isn't be provided, the layout will be cleared.)code" , TensorLayout(" " ))
47+ .OutputNDim(0 , [](const OpSpec &spec)->std::optional<int > {
48+ auto &desc = spec.InputDesc (0 );
49+ if (!desc.ndim )
50+ return std::nullopt ;
51+ return *desc.ndim + spec.GetRepeatedArgument <int >(" axes" ).size ();
52+ })
53+ .OutputLayout(0 , [](const OpSpec &spec)->std::optional<TensorLayout> {
54+ auto &desc = spec.InputDesc (0 );
55+ if (!desc.layout )
56+ return std::nullopt ;
57+
58+ auto axes = spec.GetRepeatedArgument <int >(" axes" );
59+ if (axes.empty ())
60+ return desc.layout ;
61+
62+ auto names = spec.GetArgument <TensorLayout>(" axis_names" );
63+ int num_new_axes = ssize (axes);
64+ if (num_new_axes != names.ndim ())
65+ return " " ;
66+
67+ SmallVector<std::pair<int , char >, 6 > ind_with_layout;
68+ for (size_t i = 0 ; i < axes.size (); i++) {
69+ ind_with_layout.push_back ({ i, names[i] });
70+ }
71+ std::sort (ind_with_layout.begin (), ind_with_layout.end ());
72+
73+ TensorLayout out_layout = " " ;
74+ int out_ndim = desc.layout ->ndim () + names.ndim ();
75+ int src_axis = 0 ;
76+ int new_axis = 0 ;
77+ for (int j = 0 ; j < out_ndim; j++) {
78+ if (new_axis < num_new_axes && axes[new_axis] == j) { // inserting new axis
79+ out_layout += names[new_axis++];
80+ } else {
81+ assert (src_axis < desc.layout ->ndim ());
82+ out_layout += (*desc.layout )[src_axis++];
83+ }
84+ }
85+ return out_layout;
86+ });
4787
4888template <typename Backend>
4989ExpandDims<Backend>::ExpandDims(const OpSpec &spec)
0 commit comments