Skip to content

Commit bda405f

Browse files
committed
Address review: use make_tensor_ptr and store defaults by value
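Replace manual TensorImpl construction with make_tensor_ptr from extension/tensor. This removes the #ifdef USE_ATEN_LIB block, and because each moment buffer is now released by the deleter passed to make_tensor_ptr, the AdamW destructor becomes `= default`. Store defaults_ by value (AdamWOptions instead of std::unique_ptr<AdamWOptions>) since it is always initialized.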
1 parent 5f1c90c commit bda405f

4 files changed

Lines changed: 48 additions & 57 deletions

File tree

extension/training/optimizer/adamw.cpp

Lines changed: 21 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,15 @@
88

99
#include <executorch/extension/training/optimizer/adamw.h>
1010

11+
#include <executorch/extension/tensor/tensor_ptr.h>
1112
#include <executorch/runtime/core/error.h>
1213

1314
#include <cmath>
1415
#include <cstring>
1516

1617
using executorch::aten::Tensor;
17-
using executorch::aten::TensorImpl;
18+
using executorch::extension::make_tensor_ptr;
19+
using executorch::extension::TensorPtr;
1820
using ::executorch::runtime::Error;
1921

2022
namespace executorch {
@@ -116,16 +118,15 @@ AdamWParamGroup::named_parameters() const {
116118
void AdamW::add_param_group(const AdamWParamGroup& param_group) {
117119
AdamWParamGroup param_group_(param_group.named_parameters());
118120
if (!param_group.has_options()) {
119-
param_group_.set_options(defaults_->clone());
121+
param_group_.set_options(defaults_.clone());
120122
} else {
121123
param_group_.set_options(param_group.options().clone());
122124
}
123125
param_groups_.emplace_back(std::move(param_group_));
124126
}
125127

126-
Error AdamW::step(
127-
const std::map<std::string_view, executorch::aten::Tensor>&
128-
named_gradients) {
128+
Error AdamW::step(const std::map<std::string_view, executorch::aten::Tensor>&
129+
named_gradients) {
129130
for (auto& group : param_groups_) {
130131
auto& options = static_cast<AdamWOptions&>(group.options());
131132
const double lr = options.lr();
@@ -162,29 +163,23 @@ Error AdamW::step(
162163
std::memset(m_buf_ptr, 0, g.nbytes());
163164
std::memset(v_buf_ptr, 0, g.nbytes());
164165

165-
Tensor m_buf(nullptr);
166-
Tensor v_buf(nullptr);
167-
#ifdef USE_ATEN_LIB
168-
std::vector<int64_t> sizes(g.sizes().begin(), g.sizes().end());
169-
m_buf = torch::from_blob(m_buf_ptr, sizes, g.scalar_type());
170-
v_buf = torch::from_blob(v_buf_ptr, sizes, g.scalar_type());
171-
#else
172-
TensorImpl* m_impl = new TensorImpl(
173-
g.scalar_type(),
174-
g.sizes().size(),
175-
const_cast<TensorImpl::SizesType*>(g.sizes().data()),
166+
std::vector<executorch::aten::SizesType> sizes(
167+
g.sizes().begin(), g.sizes().end());
168+
auto m_ptr = make_tensor_ptr(
169+
sizes,
176170
m_buf_ptr,
177-
const_cast<TensorImpl::DimOrderType*>(g.dim_order().data()));
178-
TensorImpl* v_impl = new TensorImpl(
179171
g.scalar_type(),
180-
g.sizes().size(),
181-
const_cast<TensorImpl::SizesType*>(g.sizes().data()),
172+
executorch::aten::TensorShapeDynamism::STATIC,
173+
[](void* p) { free(p); });
174+
auto v_ptr = make_tensor_ptr(
175+
sizes,
182176
v_buf_ptr,
183-
const_cast<TensorImpl::DimOrderType*>(g.dim_order().data()));
184-
m_buf = Tensor(m_impl);
185-
v_buf = Tensor(v_impl);
186-
#endif
187-
auto state = std::make_unique<AdamWParamState>(m_buf, v_buf);
177+
g.scalar_type(),
178+
executorch::aten::TensorShapeDynamism::STATIC,
179+
[](void* p) { free(p); });
180+
181+
auto state = std::make_unique<AdamWParamState>(
182+
std::move(m_ptr), std::move(v_ptr));
188183
state_ptr = state.get();
189184
state_[p.unsafeGetTensorImpl()] = std::move(state);
190185
} else {
@@ -214,17 +209,7 @@ Error AdamW::step(
214209
return Error::Ok;
215210
}
216211

217-
AdamW::~AdamW() {
218-
for (const auto& state_kv : state_) {
219-
auto& state = *state_kv.second;
220-
free(state.exp_avg().unsafeGetTensorImpl()->mutable_data());
221-
free(state.exp_avg_sq().unsafeGetTensorImpl()->mutable_data());
222-
#ifndef USE_ATEN_LIB
223-
delete state.exp_avg().unsafeGetTensorImpl();
224-
delete state.exp_avg_sq().unsafeGetTensorImpl();
225-
#endif
226-
}
227-
}
212+
AdamW::~AdamW() = default;
228213

229214
} // namespace optimizer
230215
} // namespace training

extension/training/optimizer/adamw.h

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
*/
1616
#pragma once
1717

18+
#include <executorch/extension/tensor/tensor_ptr.h>
1819
#include <executorch/runtime/core/error.h>
1920
#include <executorch/runtime/core/exec_aten/exec_aten.h>
2021
#include <cstdint>
@@ -41,16 +42,18 @@ class ET_EXPERIMENTAL AdamWParamState {
4142
* @param[in] exp_avg_sq The second moment (EMA of squared gradients) buffer.
4243
*/
4344
AdamWParamState(
44-
executorch::aten::Tensor& exp_avg,
45-
executorch::aten::Tensor& exp_avg_sq)
46-
: exp_avg_(exp_avg), exp_avg_sq_(exp_avg_sq), step_count_(0) {}
45+
executorch::extension::TensorPtr exp_avg,
46+
executorch::extension::TensorPtr exp_avg_sq)
47+
: exp_avg_(std::move(exp_avg)),
48+
exp_avg_sq_(std::move(exp_avg_sq)),
49+
step_count_(0) {}
4750

4851
executorch::aten::Tensor& exp_avg() {
49-
return exp_avg_;
52+
return *exp_avg_;
5053
}
5154

5255
executorch::aten::Tensor& exp_avg_sq() {
53-
return exp_avg_sq_;
56+
return *exp_avg_sq_;
5457
}
5558

5659
int64_t step_count() const {
@@ -62,8 +65,8 @@ class ET_EXPERIMENTAL AdamWParamState {
6265
}
6366

6467
private:
65-
executorch::aten::Tensor exp_avg_;
66-
executorch::aten::Tensor exp_avg_sq_;
68+
executorch::extension::TensorPtr exp_avg_;
69+
executorch::extension::TensorPtr exp_avg_sq_;
6770
int64_t step_count_;
6871
};
6972

@@ -179,7 +182,7 @@ class ET_EXPERIMENTAL AdamW {
179182
explicit AdamW(
180183
const std::vector<AdamWParamGroup>& param_groups,
181184
AdamWOptions defaults)
182-
: defaults_(std::make_unique<AdamWOptions>(defaults)) {
185+
: defaults_(defaults) {
183186
for (const auto& param_group : param_groups) {
184187
add_param_group(param_group);
185188
}
@@ -209,7 +212,7 @@ class ET_EXPERIMENTAL AdamW {
209212
private:
210213
std::vector<AdamWParamGroup> param_groups_;
211214
std::unordered_map<void*, std::unique_ptr<AdamWParamState>> state_;
212-
std::unique_ptr<AdamWOptions> defaults_;
215+
AdamWOptions defaults_;
213216
};
214217

215218
} // namespace optimizer

extension/training/optimizer/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def define_common_targets():
4848
"adamw.h",
4949
],
5050
exported_deps = [
51+
"//executorch/extension/tensor:tensor" + aten_suffix,
5152
"//executorch/runtime/core:core",
5253
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
5354
],

extension/training/optimizer/test/adamw_test.cpp

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/extension/tensor/tensor_ptr.h>
910
#include <executorch/extension/training/optimizer/adamw.h>
1011
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1112
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -33,10 +34,11 @@ class AdamWOptimizerTest : public ::testing::Test {
3334
};
3435

3536
TEST_F(AdamWOptimizerTest, AdamWParamStateTest) {
36-
TensorFactory<ScalarType::Float> tf;
37-
Tensor exp_avg = tf.make({2, 2}, {0, 0, 0, 0});
38-
Tensor exp_avg_sq = tf.make({2, 2}, {0, 0, 0, 0});
39-
AdamWParamState state(exp_avg, exp_avg_sq);
37+
auto exp_avg =
38+
executorch::extension::make_tensor_ptr({2, 2}, {0.f, 0.f, 0.f, 0.f});
39+
auto exp_avg_sq =
40+
executorch::extension::make_tensor_ptr({2, 2}, {0.f, 0.f, 0.f, 0.f});
41+
AdamWParamState state(std::move(exp_avg), std::move(exp_avg_sq));
4042

4143
EXPECT_EQ(state.step_count(), 0);
4244
state.increment_step_count();
@@ -78,8 +80,8 @@ TEST_F(AdamWOptimizerTest, AdamWOptimizerSimple) {
7880
optimizer.step(named_gradients);
7981
}
8082

81-
auto p1 = static_cast<const float*>(
82-
named_parameters.at("param1").const_data_ptr());
83+
auto p1 =
84+
static_cast<const float*>(named_parameters.at("param1").const_data_ptr());
8385
// With a constant gradient of -1 and no weight decay, the bias-corrected
8486
// m_hat / sqrt(v_hat) is ~= -1 at every step, so each step shifts p by
8587
// +lr. After 10 steps of lr=0.1, p should be near 2.0.
@@ -103,8 +105,8 @@ TEST_F(AdamWOptimizerTest, AdamWOptimizerDecoupledWeightDecay) {
103105
named_gradients.insert({"param1", tf.make({1, 1}, {0.0})});
104106
optimizer.step(named_gradients);
105107

106-
auto p1 = static_cast<const float*>(
107-
named_parameters.at("param1").const_data_ptr());
108+
auto p1 =
109+
static_cast<const float*>(named_parameters.at("param1").const_data_ptr());
108110
EXPECT_NEAR(p1[0], 0.95, 1e-5);
109111
}
110112

@@ -124,10 +126,10 @@ TEST_F(AdamWOptimizerTest, AdamWOptimizerMultipleParams) {
124126
optimizer.step(named_gradients);
125127
}
126128

127-
auto p1 = static_cast<const float*>(
128-
named_parameters.at("param1").const_data_ptr());
129-
auto p2 = static_cast<const float*>(
130-
named_parameters.at("param2").const_data_ptr());
129+
auto p1 =
130+
static_cast<const float*>(named_parameters.at("param1").const_data_ptr());
131+
auto p2 =
132+
static_cast<const float*>(named_parameters.at("param2").const_data_ptr());
131133
// Each param sees a constant gradient of +/- 1 for 5 steps -> p shifts by
132134
// roughly +/- 5 * lr = +/- 0.5. State is tracked independently per param.
133135
EXPECT_NEAR(p1[0], 1.5, 0.1);

0 commit comments

Comments
 (0)