@@ -64,14 +64,14 @@ void add_embedding_q4gsw_node(
6464 const ValueRef weight,
6565 const ValueRef weight_scales,
6666 const int32_t group_size,
67- const int32_t embed_dim,
68- const int32_t num_indices,
69- const int32_t out_height,
7067 const int32_t is_linear_weight,
71- const ValueRef out) {
68+ const ValueRef out,
69+ const ValueRef embed_dim_ref) {
7270 VK_CHECK_COND (graph.packed_dim_of (out) == WHCN::kWidthDim );
7371 VK_CHECK_COND (graph.packed_dim_of (indices) == WHCN::kWidthDim );
74- VK_CHECK_COND (embed_dim % 32 == 0 , " embed_dim must be a multiple of 32" );
72+ VK_CHECK_COND (
73+ graph.get_int (embed_dim_ref) % 32 == 0 ,
74+ " embed_dim must be a multiple of 32" );
7575
7676 std::string kernel_name = " embedding_q4gsw" ;
7777 kernel_name.reserve (kShaderNameReserve );
@@ -91,21 +91,18 @@ void add_embedding_q4gsw_node(
9191
9292 std::vector<PushConstantDataInfo> push_constants = {
9393 PushConstantDataInfo (&group_size, sizeof (group_size)),
94- PushConstantDataInfo (&embed_dim, sizeof (embed_dim)),
95- PushConstantDataInfo (&num_indices, sizeof (num_indices)),
96- PushConstantDataInfo (&out_height, sizeof (out_height)),
9794 PushConstantDataInfo (&is_linear_weight, sizeof (is_linear_weight)),
9895 };
9996
100- ValueRef embed_dim_ref = graph.add_scalar < int64_t >(embed_dim) ;
97+ vkapi::ParamsBindList param_ubos = { graph.sizes_ubo (out)} ;
10198
10299 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
103100 graph,
104101 VK_KERNEL_FROM_STR (kernel_name),
105102 pick_embedding_q4gsw_global_wg_size,
106103 default_pick_local_wg_size,
107104 {{out, vkapi::kWrite }, {{indices, weight, weight_scales}, vkapi::kRead }},
108- {} ,
105+ param_ubos ,
109106 push_constants,
110107 {},
111108 {embed_dim_ref},
@@ -125,14 +122,8 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
125122 graph.extract_scalar <bool >(is_linear_weight_ref) ? 1 : 0 ;
126123
127124 const std::vector<int64_t > weight_sizes = graph.sizes_of (weight_data);
128- int32_t embed_dim = static_cast <int32_t >(weight_sizes.back () * 2 );
129-
130- const std::vector<int64_t > indices_sizes = graph.sizes_of (indices);
131- int32_t num_indices = 1 ;
132- for (auto s : indices_sizes) {
133- num_indices *= static_cast <int32_t >(s);
134- }
135- int32_t out_height = static_cast <int32_t >(indices_sizes.back ());
125+ int64_t embed_dim = weight_sizes.back () * 2 ;
126+ ValueRef embed_dim_ref = graph.add_scalar <int64_t >(embed_dim);
136127
137128 ValueRef weight;
138129 if (is_linear_weight) {
@@ -152,11 +143,9 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
152143 weight,
153144 weight_scales,
154145 group_size,
155- embed_dim,
156- num_indices,
157- out_height,
158146 is_linear_weight,
159- out);
147+ out,
148+ embed_dim_ref);
160149}
161150
162151REGISTER_OPERATORS {
0 commit comments