@@ -2890,12 +2890,17 @@ struct test_cpy : public test_case {
28902890 const std::array<int64_t , 4 > ne_dst;
28912891 const std::array<int64_t , 4 > permute_src;
28922892 const std::array<int64_t , 4 > permute_dst;
2893+ const std::array<int64_t , 4 > dst_alloc; // if set, dst is a view into a larger buffer (strided)
28932894 bool _src_use_permute;
28942895 bool _dst_use_permute;
28952896 bool _src_transpose;
28962897 bool _use_dst_shape;
2898+ bool _use_dst_alloc;
28972899
28982900 std::string vars () override {
2901+ if (_use_dst_alloc) {
2902+ return VARS_TO_STR8 (type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose, dst_alloc);
2903+ }
28992904 if (_use_dst_shape) {
29002905 return VARS_TO_STR7 (type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose);
29012906 }
@@ -2943,12 +2948,15 @@ struct test_cpy : public test_case {
29432948 std::array<int64_t , 4 > ne_dst = {-1 , -1 , -1 , -1 },
29442949 std::array<int64_t , 4 > permute_src = {0 , 0 , 0 , 0 },
29452950 std::array<int64_t , 4 > permute_dst = {0 , 0 , 0 , 0 },
2946- bool transpose_src = false )
2951+ bool transpose_src = false ,
2952+ std::array<int64_t , 4 > dst_alloc = {0 , 0 , 0 , 0 })
29472953 : type_src(type_src), type_dst(type_dst), ne_src(ne_src), ne_dst(ne_dst), permute_src(permute_src), permute_dst(permute_dst),
2954+ dst_alloc (dst_alloc),
29482955 _src_use_permute(permute_src[0 ] + permute_src[1 ] + permute_src[2 ] + permute_src[3 ] > 0 ),
29492956 _dst_use_permute(permute_dst[0 ] + permute_dst[1 ] + permute_dst[2 ] + permute_dst[3 ] > 0 ),
29502957 _src_transpose(transpose_src),
2951- _use_dst_shape(ne_dst[0 ] >= 0 && ne_dst[1 ] >= 0 && ne_dst[2 ] >= 0 && ne_dst[3 ] >= 0 ){}
2958+ _use_dst_shape(ne_dst[0 ] >= 0 && ne_dst[1 ] >= 0 && ne_dst[2 ] >= 0 && ne_dst[3 ] >= 0 ),
2959+ _use_dst_alloc(dst_alloc[0 ] > 0 ){}
29522960
29532961 ggml_tensor * build_graph (ggml_context * ctx) override {
29542962 ggml_tensor * src = ggml_new_tensor (ctx, type_src, 4 , ne_src.data ());
@@ -2966,12 +2974,23 @@ struct test_cpy : public test_case {
29662974 }
29672975
29682976 std::array<int64_t , 4 > dst_ne = _use_dst_shape ? ne_dst : std::array<int64_t , 4 >{src->ne [0 ], src->ne [1 ], src->ne [2 ], src->ne [3 ]};
2969- ggml_tensor * dst = ggml_new_tensor (ctx, type_dst, 4 , dst_ne.data ());
2970- ggml_set_name (dst, " dst" );
2977+ ggml_tensor * dst;
29712978
2972- if (_dst_use_permute) {
2973- dst = ggml_permute (ctx, dst, permute_dst[0 ], permute_dst[1 ], permute_dst[2 ], permute_dst[3 ]);
2974- ggml_set_name (dst, " dst_permuted" );
2979+ if (_use_dst_alloc) {
2980+ // view the leading dst_ne block of dst_buf (offset 0) using dst_buf's own strides -> strided dst; permute_dst is not applied on this path
2981+ ggml_tensor * dst_buf = ggml_new_tensor (ctx, type_dst, 4 , dst_alloc.data ());
2982+ ggml_set_name (dst_buf, " dst_buf" );
2983+ dst = ggml_view_4d (ctx, dst_buf, dst_ne[0 ], dst_ne[1 ], dst_ne[2 ], dst_ne[3 ],
2984+ dst_buf->nb [1 ], dst_buf->nb [2 ], dst_buf->nb [3 ], 0 );
2985+ ggml_set_name (dst, " dst_view" );
2986+ } else {
2987+ dst = ggml_new_tensor (ctx, type_dst, 4 , dst_ne.data ());
2988+ ggml_set_name (dst, " dst" );
2989+
2990+ if (_dst_use_permute) {
2991+ dst = ggml_permute (ctx, dst, permute_dst[0 ], permute_dst[1 ], permute_dst[2 ], permute_dst[3 ]);
2992+ ggml_set_name (dst, " dst_permuted" );
2993+ }
29752994 }
29762995
29772996 ggml_tensor * out = ggml_cpy (ctx, src, dst);
@@ -8181,6 +8200,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
81818200 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32 , GGML_TYPE_F32 , {256 , 1 , 4 , 1 }, {-1 ,-1 ,-1 ,-1 }, {1 , 2 , 0 , 3 }, {0 , 0 , 0 , 0 }));
81828201 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32 , GGML_TYPE_F32 , {2 , 2097121 , 1 , 1 }, {-1 ,-1 ,-1 ,-1 }, {1 , 0 , 2 , 3 }));
81838202 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32 , GGML_TYPE_F32 , {2 , 2 , 524281 , 1 }, {-1 ,-1 ,-1 ,-1 }, {1 , 0 , 2 , 3 }));
8203+ test_cases.emplace_back (new test_cpy (GGML_TYPE_F32 , GGML_TYPE_F32 , {128 , 2 , 3 , 1 }, {128 , 2 , 3 , 1 }, {0 , 0 , 0 , 0 }, {0 , 0 , 0 , 0 }, false , {128 , 4 , 3 , 1 })); // strided dst
8204+ test_cases.emplace_back (new test_cpy (GGML_TYPE_F16 , GGML_TYPE_F16 , {128 , 2 , 3 , 1 }, {128 , 2 , 3 , 1 }, {0 , 0 , 0 , 0 }, {0 , 0 , 0 , 0 }, false , {128 , 4 , 3 , 1 })); // strided dst
81848205
81858206 // CPY - different src/dst shapes (reshaping via CPY)
81868207 // Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
0 commit comments