#include "infinicore/ops/layer_norm.hpp"
#include "../../utils.hpp"

namespace infinicore::op {
// Instantiates the framework's dispatcher boilerplate for the LayerNorm graph
// op — presumably the per-device dispatcher table used by
// INFINICORE_GRAPH_OP_DISPATCH below (confirm against the macro definition).
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(LayerNorm);
| 6 | + |
| 7 | +LayerNorm::LayerNorm(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { |
| 8 | + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, standardization, std_deviation, x, weight); |
| 9 | + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, standardization, std_deviation, x, weight, bias, epsilon); |
| 10 | +} |
| 11 | + |
// Single entry point used by the functional wrappers below.
// NOTE(review): INFINICORE_GRAPH_OP_RECORD_OR_RUN presumably either records
// the op into a graph being captured or executes it eagerly — confirm the
// exact semantics against the framework macro.
void LayerNorm::execute(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(LayerNorm, y, standardization, std_deviation, x, weight, bias, epsilon);
}
| 15 | + |
| 16 | +Tensor layer_norm(const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { |
| 17 | + auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); |
| 18 | + auto reduced_shape = x->shape(); |
| 19 | + reduced_shape.pop_back(); |
| 20 | + auto standardization = Tensor::empty(x->shape(), x->dtype(), x->device()); |
| 21 | + auto std_deviation = Tensor::empty(reduced_shape, x->dtype(), x->device()); |
| 22 | + layer_norm_(y, standardization, std_deviation, x, weight, bias, epsilon); |
| 23 | + return y; |
| 24 | +} |
| 25 | + |
// In-place layer norm with caller-provided workspaces; thin forwarder to
// LayerNorm::execute.
void layer_norm_(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) {
    LayerNorm::execute(y, standardization, std_deviation, x, weight, bias, epsilon);
}
| 29 | + |
| 30 | +void layer_norm_(Tensor y, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { |
| 31 | + auto reduced_shape = x->shape(); |
| 32 | + reduced_shape.pop_back(); |
| 33 | + auto standardization = Tensor::empty(x->shape(), x->dtype(), x->device()); |
| 34 | + auto std_deviation = Tensor::empty(reduced_shape, x->dtype(), x->device()); |
| 35 | + LayerNorm::execute(y, standardization, std_deviation, x, weight, bias, epsilon); |
| 36 | +} |
| 37 | + |
// Shim forwarding to the workspace-free in-place overload — presumably given
// a distinct name so the Python binding does not have to disambiguate the two
// layer_norm_ overloads (confirm at the binding site).
void layer_norm_for_pybind(Tensor y, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) {
    layer_norm_(y, x, weight, bias, epsilon);
}
| 41 | + |
| 42 | +} // namespace infinicore::op |