|
21 | 21 | #include "gflags/gflags.h" |
22 | 22 | #include "libspu/core/config.h" |
23 | 23 | #include "libspu/core/encoding.h" |
24 | | -#include "libspu/kernel/hal/hal.h" |
| 24 | +#include "libspu/kernel/hal/constants.h" |
| 25 | +#include "libspu/kernel/hal/polymorphic.h" |
| 26 | +#include "libspu/kernel/hal/public_helper.h" |
| 27 | +#include "libspu/kernel/hal/shape_ops.h" |
| 28 | +#include "libspu/kernel/hal/type_cast.h" |
25 | 29 | #include "libspu/mpc/aby3/type.h" |
26 | 30 | #include "libspu/mpc/factory.h" |
27 | 31 | #include "libspu/mpc/semi2k/type.h" |
| 32 | +#include "xtensor/xarray.hpp" |
28 | 33 | #include "xtensor/xcsv.hpp" |
29 | | -#include "xtensor/xio.hpp" |
| 34 | +#include "xtensor/xview.hpp" |
30 | 35 |
|
31 | 36 | DEFINE_string(dataset, "data.csv", "dataset file, only csv is supported"); |
32 | 37 | DEFINE_int32(skip_rows, 1, "skip number of rows from dataset"); |
33 | 38 | DEFINE_string(lr_output, "/tmp/sslr_result", "full path name of output file"); |
34 | 39 | DECLARE_int32(rank); |
| 40 | +DECLARE_bool(disable_handshake); |
35 | 41 |
|
36 | 42 | namespace ic_impl::algo::lr { |
37 | 43 |
|
@@ -196,11 +202,17 @@ bool LrHandler::PrepareDataset() { |
196 | 202 | int32_t feature_num = |
197 | 203 | ctx_->HasLabel() ? dataset_->shape(1) - 1 : dataset_->shape(1); |
198 | 204 | YACL_ENFORCE(sample_size > 0); |
199 | | - YACL_ENFORCE(feature_num > 0); |
200 | 205 |
|
201 | 206 | ctx_->io_param.sample_size = sample_size; |
202 | | - ctx_->io_param.feature_nums.resize(ctx_->ic_ctx->lctx->WorldSize()); |
203 | | - ctx_->io_param.feature_nums[ctx_->ic_ctx->lctx->Rank()] = feature_num; |
| 207 | + auto self_rank = ctx_->ic_ctx->lctx->Rank(); |
| 208 | + |
| 209 | + if (FLAGS_disable_handshake) { |
| 210 | + YACL_ENFORCE(ctx_->io_param.feature_nums.at(self_rank) == feature_num); |
| 211 | + } else { |
| 212 | + YACL_ENFORCE(feature_num > 0); |
| 213 | + ctx_->io_param.feature_nums.resize(ctx_->ic_ctx->lctx->WorldSize()); |
| 214 | + ctx_->io_param.feature_nums.at(self_rank) = feature_num; |
| 215 | + } |
204 | 216 |
|
205 | 217 | return true; |
206 | 218 | } |
@@ -676,15 +688,17 @@ float Accuracy(const xt::xarray<float>& y_true, |
676 | 688 |
|
677 | 689 | void ProduceOutput(spu::SPUContext* sctx, const spu::Value& w) { |
678 | 690 | // output result shares to the file |
679 | | - spu::PtType out_pt_type; |
680 | | - auto w_output = spu::decodeFromRing( |
681 | | - w.data(), w.dtype(), sctx->config().fxp_fraction_bits(), &out_pt_type); |
682 | | - YACL_ENFORCE(out_pt_type == spu::PT_F32); |
| 691 | + spu::PtType pt_type = getDecodeType(w.dtype()); |
| 692 | + spu::NdArrayRef dst(makePtType(pt_type), w.shape()); |
| 693 | + spu::PtBufferView pv(static_cast<void*>(dst.data()), pt_type, dst.shape(), |
| 694 | + dst.strides()); |
| 695 | + spu::decodeFromRing(w.data(), w.dtype(), sctx->config().fxp_fraction_bits(), |
| 696 | + &pv); |
683 | 697 |
|
684 | 698 | std::string out_file_name = GetLrOutputFileName(); |
685 | 699 | std::ofstream of(out_file_name); |
686 | 700 | YACL_ENFORCE(of, "open file={} failed", out_file_name); |
687 | | - for (auto it = w_output.begin(); it != w_output.end(); ++it) { |
| 701 | + for (auto it = dst.begin(); it != dst.end(); ++it) { |
688 | 702 | auto* item = reinterpret_cast<float*>(it.getRawPtr()); |
689 | 703 | of << *item << '\n'; |
690 | 704 | } |
@@ -756,10 +770,17 @@ std::unique_ptr<spu::SPUContext> LrHandler::MakeSpuContext() { |
756 | 770 |
|
757 | 771 | spu::Value LrHandler::EncodingDataset(spu::PtBufferView dataset) { |
758 | 772 | // encode to ring. |
| 773 | + auto array = convertToNdArray(dataset); |
| 774 | + SPU_ENFORCE(array.eltype().isa<spu::PtTy>(), "expect PtType, got={}", |
| 775 | + array.eltype()); |
| 776 | + |
| 777 | + const spu::PtType pt_type = array.eltype().as<spu::PtTy>()->pt_type(); |
| 778 | + spu::PtBufferView pv(static_cast<const void*>(array.data()), pt_type, |
| 779 | + array.shape(), array.strides()); |
| 780 | + |
759 | 781 | spu::DataType dtype; |
760 | 782 | spu::NdArrayRef encoded = |
761 | | - encodeToRing(convertToNdArray(dataset), |
762 | | - static_cast<spu::FieldType>(ctx_->ss_param.field_type), |
| 783 | + encodeToRing(pv, static_cast<spu::FieldType>(ctx_->ss_param.field_type), |
763 | 784 | ctx_->ss_param.fxp_bits, &dtype); |
764 | 785 |
|
765 | 786 | return spu::Value(encoded, dtype); |
|
0 commit comments