Skip to content

Commit a46c64e

Browse files
committed
Fix use after move
1 parent 9ab3913 commit a46c64e

1 file changed

Lines changed: 60 additions & 41 deletions

File tree

mlx/ops.cpp

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,11 +1658,9 @@ std::vector<array> broadcast_arrays(
16581658
if (in.shape() == out_shape) {
16591659
outputs.push_back(in);
16601660
} else {
1661-
outputs.push_back(array(
1662-
std::move(out_shape),
1663-
in.dtype(),
1664-
std::make_shared<Broadcast>(to_stream(s), out_shape),
1665-
{in}));
1661+
auto prim = std::make_shared<Broadcast>(to_stream(s), out_shape);
1662+
outputs.push_back(
1663+
array(std::move(out_shape), in.dtype(), std::move(prim), {in}));
16661664
}
16671665
}
16681666
return outputs;
@@ -1766,17 +1764,20 @@ std::pair<array, array> broadcast_arrays(
17661764
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
17671765
auto dtype = promote_types(a.dtype(), b.dtype());
17681766
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1769-
auto& shape = inputs[0].shape();
1767+
auto shape = inputs[0].shape();
17701768
return array(
1771-
shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
1769+
std::move(shape),
1770+
bool_,
1771+
std::make_shared<Equal>(to_stream(s)),
1772+
std::move(inputs));
17721773
}
17731774

17741775
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
17751776
auto dtype = promote_types(a.dtype(), b.dtype());
17761777
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1777-
auto& shape = inputs[0].shape();
1778+
auto shape = inputs[0].shape();
17781779
return array(
1779-
shape,
1780+
std::move(shape),
17801781
bool_,
17811782
std::make_shared<NotEqual>(to_stream(s)),
17821783
std::move(inputs));
@@ -1785,9 +1786,12 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
17851786
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
17861787
auto dtype = promote_types(a.dtype(), b.dtype());
17871788
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1788-
auto& shape = inputs[0].shape();
1789+
auto shape = inputs[0].shape();
17891790
return array(
1790-
shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
1791+
std::move(shape),
1792+
bool_,
1793+
std::make_shared<Greater>(to_stream(s)),
1794+
std::move(inputs));
17911795
}
17921796

17931797
array greater_equal(
@@ -1796,9 +1800,9 @@ array greater_equal(
17961800
StreamOrDevice s /* = {} */) {
17971801
auto dtype = promote_types(a.dtype(), b.dtype());
17981802
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1799-
auto& shape = inputs[0].shape();
1803+
auto shape = inputs[0].shape();
18001804
return array(
1801-
shape,
1805+
std::move(shape),
18021806
bool_,
18031807
std::make_shared<GreaterEqual>(to_stream(s)),
18041808
std::move(inputs));
@@ -1807,17 +1811,20 @@ array greater_equal(
18071811
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
18081812
auto dtype = promote_types(a.dtype(), b.dtype());
18091813
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1810-
auto& shape = inputs[0].shape();
1814+
auto shape = inputs[0].shape();
18111815
return array(
1812-
shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
1816+
std::move(shape),
1817+
bool_,
1818+
std::make_shared<Less>(to_stream(s)),
1819+
std::move(inputs));
18131820
}
18141821

18151822
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
18161823
auto dtype = promote_types(a.dtype(), b.dtype());
18171824
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
1818-
auto& shape = inputs[0].shape();
1825+
auto shape = inputs[0].shape();
18191826
return array(
1820-
shape,
1827+
std::move(shape),
18211828
bool_,
18221829
std::make_shared<LessEqual>(to_stream(s)),
18231830
std::move(inputs));
@@ -2811,9 +2818,9 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
28112818
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28122819
// Broadcast arrays to a common shape
28132820
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
2814-
auto& shape = inputs[0].shape();
2821+
auto shape = inputs[0].shape();
28152822
return array(
2816-
shape,
2823+
std::move(shape),
28172824
bool_,
28182825
std::make_shared<LogicalAnd>(to_stream(s)),
28192826
std::move(inputs));
@@ -2825,9 +2832,9 @@ array operator&&(const array& a, const array& b) {
28252832
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28262833
// Broadcast arrays to a common shape
28272834
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
2828-
auto& shape = inputs[0].shape();
2835+
auto shape = inputs[0].shape();
28292836
return array(
2830-
shape,
2837+
std::move(shape),
28312838
bool_,
28322839
std::make_shared<LogicalOr>(to_stream(s)),
28332840
std::move(inputs));
@@ -2845,9 +2852,12 @@ array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28452852
auto out_type = promote_types(a.dtype(), b.dtype());
28462853
auto inputs =
28472854
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
2848-
auto& shape = inputs[0].shape();
2855+
auto shape = inputs[0].shape();
28492856
return array(
2850-
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
2857+
std::move(shape),
2858+
out_type,
2859+
std::make_shared<Add>(to_stream(s)),
2860+
std::move(inputs));
28512861
}
28522862

28532863
array operator+(const array& a, const array& b) {
@@ -2858,9 +2868,9 @@ array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28582868
auto out_type = promote_types(a.dtype(), b.dtype());
28592869
auto inputs =
28602870
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
2861-
auto& shape = inputs[0].shape();
2871+
auto shape = inputs[0].shape();
28622872
return array(
2863-
shape,
2873+
std::move(shape),
28642874
out_type,
28652875
std::make_shared<Subtract>(to_stream(s)),
28662876
std::move(inputs));
@@ -2874,9 +2884,9 @@ array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28742884
auto out_type = promote_types(a.dtype(), b.dtype());
28752885
auto inputs =
28762886
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
2877-
auto& shape = inputs[0].shape();
2887+
auto shape = inputs[0].shape();
28782888
return array(
2879-
shape,
2889+
std::move(shape),
28802890
out_type,
28812891
std::make_shared<Multiply>(to_stream(s)),
28822892
std::move(inputs));
@@ -2890,9 +2900,12 @@ array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
28902900
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
28912901
auto inputs = broadcast_arrays(
28922902
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
2893-
auto& shape = inputs[0].shape();
2903+
auto shape = inputs[0].shape();
28942904
return array(
2895-
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
2905+
std::move(shape),
2906+
dtype,
2907+
std::make_shared<Divide>(to_stream(s)),
2908+
std::move(inputs));
28962909
}
28972910
array operator/(const array& a, const array& b) {
28982911
return divide(a, b);
@@ -2914,18 +2927,21 @@ array floor_divide(
29142927
}
29152928

29162929
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
2917-
auto& shape = inputs[0].shape();
2930+
auto shape = inputs[0].shape();
29182931
return array(
2919-
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
2932+
std::move(shape),
2933+
dtype,
2934+
std::make_shared<Divide>(to_stream(s)),
2935+
std::move(inputs));
29202936
}
29212937

29222938
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
29232939
auto dtype = promote_types(a.dtype(), b.dtype());
29242940
auto inputs = broadcast_arrays(
29252941
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
2926-
auto& shape = inputs[0].shape();
2942+
auto shape = inputs[0].shape();
29272943
return array(
2928-
shape,
2944+
std::move(shape),
29292945
dtype,
29302946
std::make_shared<Remainder>(to_stream(s)),
29312947
std::move(inputs));
@@ -2953,9 +2969,9 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
29532969
auto out_type = promote_types(a.dtype(), b.dtype());
29542970
auto inputs =
29552971
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
2956-
auto& shape = inputs[0].shape();
2972+
auto shape = inputs[0].shape();
29572973
return array(
2958-
shape,
2974+
std::move(shape),
29592975
out_type,
29602976
std::make_shared<Maximum>(to_stream(s)),
29612977
std::move(inputs));
@@ -2965,9 +2981,9 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
29652981
auto out_type = promote_types(a.dtype(), b.dtype());
29662982
auto inputs =
29672983
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
2968-
auto& shape = inputs[0].shape();
2984+
auto shape = inputs[0].shape();
29692985
return array(
2970-
shape,
2986+
std::move(shape),
29712987
out_type,
29722988
std::make_shared<Minimum>(to_stream(s)),
29732989
std::move(inputs));
@@ -3048,9 +3064,12 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
30483064
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
30493065
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
30503066
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
3051-
auto& shape = inputs[0].shape();
3067+
auto shape = inputs[0].shape();
30523068
return array(
3053-
shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
3069+
std::move(shape),
3070+
dtype,
3071+
std::make_shared<ArcTan2>(to_stream(s)),
3072+
std::move(inputs));
30543073
}
30553074

30563075
array sinh(const array& a, StreamOrDevice s /* = {} */) {
@@ -3144,9 +3163,9 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
31443163
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
31453164
auto inputs =
31463165
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
3147-
auto& shape = inputs[0].shape();
3166+
auto shape = inputs[0].shape();
31483167
return array(
3149-
shape,
3168+
std::move(shape),
31503169
out_type,
31513170
std::make_shared<LogAddExp>(to_stream(s)),
31523171
std::move(inputs));

0 commit comments

Comments
 (0)