77#include " aclnn/aclnn_base.h"
88#include " aclnn_add.h"
99#include " aclnn_rms_norm.h"
10- #include " ascend/add_rms_norm/registry.h"
1110#include " ascend/common.h"
1211#include " ascend/workspace_pool_.h"
12+ #include " base/add_rms_norm.h"
1313#include " operator.h"
1414
1515namespace infini ::ops {
@@ -74,12 +74,12 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
7474 aclSetInputTensorAddr (add_exec_, 1 , t_x2, const_cast <void *>(x2.data ()));
7575 aclSetOutputTensorAddr (add_exec_, 0 , t_x_out, x_out.data ());
7676 }
77- auto & add_arena = ascend::workspacePool ().ensure (stream, add_ws_);
77+ auto & add_arena = ascend::GetWorkspacePool ().Ensure (stream, add_ws_);
7878 aclnnAdd (add_arena.buf , add_ws_, add_exec_, stream);
7979
8080 // Obtain shared rstd buffer from pool.
8181 auto & rstd_arena =
82- ascend::workspacePool ().ensure (stream, rstd_size_, " temp" );
82+ ascend::GetWorkspacePool ().Ensure (stream, rstd_size_, " temp" );
8383
8484 // Lazily create rstd tensor descriptor on first call.
8585 if (!rstd_tensor_) {
@@ -102,7 +102,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
102102 aclSetOutputTensorAddr (norm_exec_, 0 , t_y_out, y_out.data ());
103103 aclSetOutputTensorAddr (norm_exec_, 1 , rstd_tensor_, rstd_arena.buf );
104104 }
105- auto & norm_arena = ascend::workspacePool ().ensure (stream, norm_ws_);
105+ auto & norm_arena = ascend::GetWorkspacePool ().Ensure (stream, norm_ws_);
106106 aclnnRmsNorm (norm_arena.buf , norm_ws_, norm_exec_, stream);
107107 }
108108
0 commit comments