forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPlacement.h
More file actions
121 lines (96 loc) · 2.82 KB
/
Placement.h
File metadata and controls
121 lines (96 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#pragma once
/**
* The implementations in this file are coupled with
* torch/distributed/tensor/placement_types.py.
*/
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
namespace torch::distributed {
/**
 * Polymorphic base for the placement types declared in this header
 * (Shard, StridedShard, Replicate, Partial).
 *
 * Every predicate returns false here; each concrete subclass overrides the
 * one that describes it. Behaviour is kept in sync with
 * torch/distributed/tensor/placement_types.py.
 */
class Placement {
 public:
  Placement() = default;
  virtual ~Placement() = default;
  Placement(const Placement&) = default;
  Placement& operator=(const Placement&) = default;
  Placement(Placement&&) noexcept = default;
  Placement& operator=(Placement&&) noexcept = default;

  /**
   * @param dim Dimension to test for; an empty optional means "any dim".
   * @return false for the base class.
   *
   * Default arguments on virtual functions are bound by the static type of
   * the call expression, so overrides should repeat the same `std::nullopt`
   * default (as `is_partial` overrides already do).
   */
  virtual bool is_shard(
      [[maybe_unused]] std::optional<std::int64_t> dim = std::nullopt) const {
    return false;
  }

  /// @return false for the base class.
  virtual bool is_replicate() const {
    return false;
  }

  /**
   * @param reduce_op Reduction-op name to test for; empty means "any op".
   * @return false for the base class.
   */
  virtual bool is_partial(
      [[maybe_unused]] std::optional<std::string_view> reduce_op =
          std::nullopt) const {
    return false;
  }
};
/**
 * Shard placement identified by the tensor dimension `dim`.
 */
class Shard : public Placement {
 public:
  /// Dimension this placement refers to.
  std::int64_t dim;

  explicit Shard(std::int64_t dim_) : dim(dim_) {}

  /**
   * @param dim_ Dimension to test for; an empty optional matches any Shard.
   * @return true when `dim_` is empty or equals `dim`.
   *
   * The `std::nullopt` default lets callers holding a `Shard` directly write
   * `is_shard()`, mirroring `Partial::is_partial()` and Python's
   * `is_shard(dim=None)`.
   */
  bool is_shard(
      std::optional<std::int64_t> dim_ = std::nullopt) const override {
    return !dim_.has_value() || *dim_ == dim;
  }

  /// Compares only `dim`. This is non-virtual, so when a StridedShard is the
  /// right-hand operand its split factor is not considered.
  bool operator==(const Shard& rhs) const {
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};
/**
 * Shard variant that additionally records a `split_factor`.
 *
 * Equality against another StridedShard compares both `dim` and
 * `split_factor`. Equality against a plain Shard compares only `dim`; see the
 * TODO inside `operator==(const Shard&)` for why.
 */
class StridedShard : public Shard {
 public:
  std::int64_t split_factor;

  // The first parameter is `dim_` rather than `dim` so it does not shadow the
  // inherited `Shard::dim` member (-Wshadow). This also matches the `name_`
  // parameter convention of the Shard and Partial constructors.
  explicit StridedShard(std::int64_t dim_, std::int64_t split_factor_)
      : Shard(dim_), split_factor(split_factor_) {}

  bool operator==(const StridedShard& rhs) const {
    return dim == rhs.dim && split_factor == rhs.split_factor;
  }

  bool operator==(const Shard& rhs) const {
    // Check the dynamic type, so that a StridedShard seen through a Shard&
    // still has its split_factor compared.
    if (const auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
      return operator==(*rhs_strided);
    }
    // TODO: this is to avoid extra all-gather in dtensor op dispatch
    // note that sharding prop would not produce _StridedShard and a
    // placement inequality would introduce an all-gather for resharding
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};
/**
 * Replicate placement. It holds no state: `is_replicate()` returns true and
 * any two instances compare equal.
 */
class Replicate : public Placement {
 public:
  bool is_replicate() const override {
    return true;
  }

  // With no data members there is nothing to compare. The operand is
  // deliberately unused; the attribute silences -Wunused-parameter.
  bool operator==([[maybe_unused]] const Replicate& rhs) const {
    return true;
  }

  bool operator!=([[maybe_unused]] const Replicate& rhs) const {
    return false;
  }
};
/**
 * Partial placement tagged with the name of a reduction op.
 *
 * `reduce_op` is "sum" when the default constructor is used and also when
 * the optional constructor argument is empty.
 */
class Partial : public Placement {
 public:
  std::string reduce_op;

  Partial() : reduce_op("sum") {}

  explicit Partial(std::optional<std::string> reduce_op_)
      : reduce_op(std::move(reduce_op_).value_or("sum")) {}

  /**
   * @param op Reduction-op name to test for; empty matches any Partial.
   * @return true when `op` is empty or equals `reduce_op`.
   */
  bool is_partial(
      std::optional<std::string_view> op = std::nullopt) const override {
    if (!op) {
      return true;
    }
    return *op == reduce_op;
  }

  bool operator==(const Partial& rhs) const {
    return rhs.reduce_op == reduce_op;
  }

  bool operator!=(const Partial& rhs) const {
    return rhs.reduce_op != reduce_op;
  }
};
} // namespace torch::distributed