@@ -2086,31 +2086,21 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
20862086 debug (3 ) << " \n " ;
20872087
20882088 if (arg_ids.size () == 1 ) {
2089-
20902089 // 1 argument, just do a simple assignment via a cast
20912090 SpvId result_id = cast_type (op->type , op->vectors [0 ].type (), arg_ids[0 ]);
20922091 builder.update_id (result_id);
20932092
20942093 } else if (arg_ids.size () == 2 ) {
2095-
2096- // 2 arguments, use a composite insert to update even and odd indices
2097- uint32_t even_idx = 0 ;
2098- uint32_t odd_idx = 1 ;
2099- SpvFactory::Indices even_indices;
2100- SpvFactory::Indices odd_indices;
2101- for (int i = 0 ; i < op_lanes; ++i) {
2102- even_indices.push_back (even_idx);
2103- odd_indices.push_back (odd_idx);
2104- even_idx += 2 ;
2105- odd_idx += 2 ;
2094+ // 2 arguments, use vector-shuffle with logical indices indexing into (vec1[0], vec1[1], ..., vec2[0], vec2[1], ...)
2095+ SpvFactory::Indices logical_indices;
2096+ for (int i = 0 ; i < arg_lanes; ++i) {
2097+ logical_indices.push_back (uint32_t (i));
2098+ logical_indices.push_back (uint32_t (i + arg_lanes));
21062099 }
21072100
21082101 SpvId type_id = builder.declare_type (op->type );
2109- SpvId value_id = builder.declare_null_constant (op->type );
2110- SpvId partial_id = builder.reserve_id (SpvResultId);
21112102 SpvId result_id = builder.reserve_id (SpvResultId);
2112- builder.append (SpvFactory::composite_insert (type_id, partial_id, arg_ids[0 ], value_id, even_indices));
2113- builder.append (SpvFactory::composite_insert (type_id, result_id, arg_ids[1 ], partial_id, odd_indices));
2103+ builder.append (SpvFactory::vector_shuffle (type_id, result_id, arg_ids[0 ], arg_ids[1 ], logical_indices));
21142104 builder.update_id (result_id);
21152105
21162106 } else {
@@ -2140,7 +2130,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
21402130 } else if (op->is_extract_element ()) {
21412131 int idx = op->indices [0 ];
21422132 internal_assert (idx >= 0 );
2143- internal_assert (idx <= op->vectors [0 ].type ().lanes ());
2133+ internal_assert (idx < op->vectors [0 ].type ().lanes ());
21442134 if (op->vectors [0 ].type ().is_vector ()) {
21452135 SpvFactory::Indices indices = {(uint32_t )idx};
21462136 SpvId type_id = builder.declare_type (op->type );
0 commit comments