Skip to content

Commit f02407f

Browse files
authored
Merge pull request #9 from shaojian-ant/main
repo-sync-2024-04-12T15:11:18+0800
2 parents b98a843 + 9e71147 commit f02407f

23 files changed

Lines changed: 429 additions & 206 deletions

.bazeliskrc

Lines changed: 0 additions & 2 deletions
This file was deleted.

.bazelrc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
common --experimental_repo_remote_exec
2+
common --experimental_cc_shared_library
23

34
build --incompatible_new_actions_api=false
45
build --copt=-fdiagnostics-color=always
@@ -29,3 +30,12 @@ build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-l%:libgcc.a
2930
build:linux --copt=-fopenmp
3031
build:linux --linkopt=-fopenmp
3132

33+
# Required by OpenXLA
34+
build --nocheck_visibility
35+
36+
# default off CUDA build
37+
build --@rules_cuda//cuda:enable=false
38+
39+
# Only on when asked
40+
build:gpu --@rules_cuda//cuda:archs=compute_80:compute_80
41+
build:gpu --@rules_cuda//cuda:enable=true

.bazelversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
5.4.1
1+
6.5.0

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ cmake-build-debug-remote
66
bazel-*
77

88
build
9-
.DS_Store
9+
.DS_Store
10+
11+
.idea

README.md

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44

55
## 构建
66

7-
interconnection-impl引用了spu仓库代码,需要根据[spu构建前提](https://github.com/secretflow/spu/blob/main/CONTRIBUTING.md#build)在编译环境上安装好依赖库
7+
interconnection-impl 引用了 spu 仓库代码,需要根据[ spu 构建前提](https://github.com/secretflow/spu/blob/main/CONTRIBUTING.md#build)在编译环境上安装好依赖库
88

99
然后执行以下构建指令:
1010

1111
```shell
1212
bazel build ic_impl/ic_main
1313
```
1414

15-
## 运行
15+
## 运行 ECDH-PSI
1616

17-
### ECDH-PSI
17+
### 命令行传参
1818

1919
本地同时执行以下两条指令:
2020

@@ -30,16 +30,49 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=ECDH_PSI -protocol_families=ECC \
3030
-parties=127.0.0.1:9530,127.0.0.1:9531
3131
```
3232

33-
### SS-LR
33+
### 环境变量传参
3434

35-
运行SS-LR之前,需要先启动Beaver服务。Beaver服务的代码位于SPU仓库中,需要将SPU代码克隆到本地,然后编译并启动Beaver服务:
35+
为满足北京金融科技产业联盟的调度层互联互通标准对算法组件接口的要求,interconnection-impl 支持 ECDH-PSI 算法从环境变量读取配置参数
36+
37+
当某个参数在环境变量和命令行选项都被指定时,优先选择读取环境变量参数
38+
39+
程序运行需要关闭握手过程:
40+
```shell
41+
bazel run ic_impl/ic_main -- -disable_handshake=1
42+
```
43+
44+
ECDH-PSI 算法配置的环境变量如下表所示。环境变量设置可参考 [ecdh-psi-env-alice.sh](./ic_impl/env/ecdh-psi-env-alice.sh)[ecdh-psi-env-bob.sh](./ic_impl/env/ecdh-psi-env-bob.sh)
45+
46+
| 环境变量 | 参考值 | 描述 |
47+
|:-------------------------------------------------------|:----------------------------------------:|:----------------------------------------------------:|
48+
| runtime.component.parameter.algo | ecdh_psi | algorithm |
49+
| runtime.component.parameter.protocol_families | ecc | comma-separated list of protocol families |
50+
| runtime.component.parameter.curve_type | curve25519 | elliptic curve type |
51+
| runtime.component.parameter.hash_type | sha_256 | hash type |
52+
| runtime.component.parameter.hash2curve_strategy | direct_hash_as_point_x | hash to curve strategy |
53+
| runtime.component.parameter.point_octet_format | uncompressed | point Octet-String format |
54+
| runtime.component.parameter.bit_length_after_truncated | -1 | optimization method: secondary ciphertext truncation |
55+
| system.storage.host.url | file://path/to/root | root path of input/output file |
56+
| runtime.component.input.train_data | {"namespace":"data","name":"psi_1.csv"} | relative path and name of input file |
57+
| runtime.component.parameter.field_names | id | field names |
58+
| runtime.component.parameter.result_to_rank | -1 | which rank gets the result |
59+
| runtime.component.output.train_data | {"namespace":"output","name":"result_a"} | relative path and name of output file |
60+
61+
## 运行 SS-LR
62+
63+
### 启动 Beaver 服务
64+
65+
运行 SS-LR 之前,需要先启动 Beaver 服务。Beaver 服务的代码位于 SPU 仓库中,需要将 SPU 代码克隆到本地,然后编译并启动 Beaver
66+
服务:
3667

3768
```shell
3869
git clone git@github.com:secretflow/spu.git
3970
cd spu && bazel run libspu/mpc/semi2k/beaver/ttp_server:beaver_server_main -- -port=9449
4071
```
4172

42-
启动Beaver服务后,本地同时执行以下两条指令:
73+
### 命令行传参
74+
75+
启动 Beaver 服务后,本地同时执行以下两条指令:
4376

4477
```shell
4578
bazel run ic_impl/ic_main -- -rank=0 -algo=SS_LR -protocol_families=SS \
@@ -55,13 +88,18 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \
5588
-parties=127.0.0.1:9530,127.0.0.1:9531
5689
```
5790

58-
## 环境变量传参
91+
### 环境变量传参
5992

60-
为满足北京金融科技产业联盟的调度层互联互通标准对算法组件接口的要求,interconnection-impl支持SS-LR算法从环境变量读取配置参数
93+
为满足北京金融科技产业联盟的调度层互联互通标准对算法组件接口的要求,interconnection-impl 支持 SS-LR 算法从环境变量读取配置参数
6194

6295
当某个参数在环境变量和命令行选项都被指定时,优先选择读取环境变量参数
6396

64-
### 环境变量定义
97+
程序运行需要关闭握手过程:
98+
```shell
99+
bazel run ic_impl/ic_main -- -disable_handshake=1
100+
```
101+
102+
SS-LR 算法配置的环境变量如下表所示。环境变量设置可参考 [ss-lr-env-alice.sh](./ic_impl/env/ss-lr-env-alice.sh)[ss-lr-env-bob.sh](./ic_impl/env/ss-lr-env-bob.sh)
65103

66104
| 环境变量 | 参考值 | 描述 |
67105
|:---------------------------------------------------|:-------------------------------------------------:|:---------------------------------------------------------:|
@@ -85,10 +123,11 @@ bazel run ic_impl/ic_main -- -rank=1 -algo=SS_LR -protocol_families=SS \
85123
| runtime.component.parameter.ttp_server_host | ip:port | remote ip:port or load-balance uri of beaver service |
86124
| runtime.component.parameter.ttp_session_id | interconnection-root | session id of beaver service |
87125
| runtime.component.parameter.ttp_adjust_rank | 0 | which rank do adjust rpc call to beaver service |
88-
| system.storage | file://path/to/root | root path of input / output file |
126+
| system.storage.host.url | file://path/to/root | root path of input/output file |
89127
| runtime.component.input.train_data | {"namespace":"data","name":"perfect_logit_a.csv"} | relative path and name of input file |
90128
| runtime.component.parameter.skip_rows=1 | 1 | number of skipped rows from dataset |
91-
| runtime.component.parameter.has_label | true | if true, label is the last column of dataset |
129+
| runtime.component.parameter.label_owner | host.0 | which party owns the label column |
130+
| runtime.component.parameter.feature_nums | {"host.0":10, "guest.0":10} | feature column nums of each party |
92131
| runtime.component.output.train_data | {"namespace":"output","name":"result_a"} | relative path and name of output file |
93132

94133
## FAQ

WORKSPACE

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ protocol_deps()
2020
load("//bazel:repositories.bzl", "ic_impl_deps")
2121
ic_impl_deps()
2222

23+
load("@psi//bazel:repositories.bzl", "psi_deps")
24+
25+
psi_deps()
26+
2327
# spu
2428
load("@spulib//bazel:repositories.bzl", "spu_deps")
2529
spu_deps()
@@ -28,6 +32,20 @@ spu_deps()
2832
load("@yacl//bazel:repositories.bzl", "yacl_deps")
2933
yacl_deps()
3034

35+
load("@rules_python//python:repositories.bzl", "py_repositories")
36+
37+
py_repositories()
38+
39+
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
40+
41+
boost_deps()
42+
43+
load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")
44+
45+
rules_cuda_dependencies()
46+
47+
register_detected_cuda_toolchains()
48+
3149
# xla
3250
load("@xla//:workspace4.bzl", "xla_workspace4")
3351
xla_workspace4()

bazel/repositories.bzl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
1717

1818
SECRETFLOW_GIT = "https://github.com/secretflow"
1919

20-
SPU_COMMIT_ID = "b28e086c6b0fc2b01e7be80ea438afd9ef54f3e2"
20+
SPU_COMMIT_ID = "8bf3c97da503f1cffd1292c8e365ecbc30675400"
2121

22-
IC_COMMIT_ID = "a2ffff21a528456c383a113f6a57b98b3c9de6fe"
22+
PSI_COMMIT_ID = "58f687b9949bfda4cf31a09bb6d4a3bc1f375757"
23+
24+
IC_COMMIT_ID = "30e4220b7444d0bb077a9040f1b428632124e31a"
2325

2426
SPU_REPOSITORY = "SPU"
2527

@@ -30,7 +32,14 @@ def ic_impl_deps():
3032
git_repository,
3133
name = "spulib",
3234
commit = SPU_COMMIT_ID,
33-
remote = "{}/{}.git".format("https://github.com/shaojian-ant", SPU_REPOSITORY),
35+
remote = "{}/{}.git".format(SECRETFLOW_GIT, SPU_REPOSITORY),
36+
)
37+
38+
maybe(
39+
git_repository,
40+
name = "psi",
41+
commit = PSI_COMMIT_ID,
42+
remote = "{}/psi.git".format(SECRETFLOW_GIT),
3443
)
3544

3645
def protocol_deps():

ic_impl/algo/lr/BUILD.bazel

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ cc_library(
2525
"//ic_impl:handler",
2626
"@spulib//libspu/mpc:factory",
2727
"@com_google_absl//absl/functional:bind_front",
28+
"@spulib//libspu/kernel/hal:constants",
29+
"@spulib//libspu/kernel/hal:polymorphic",
30+
"@spulib//libspu/kernel/hal:type_cast",
31+
"@spulib//libspu/kernel/hal:shape_ops",
32+
"@spulib//libspu/kernel/hal:public_helper",
2833
]
2934
)
3035

@@ -37,7 +42,6 @@ cc_library(
3742
"//ic_impl:context",
3843
"//ic_impl/op/sigmoid",
3944
"//ic_impl/protocol_family/ss",
40-
"@spulib//libspu/kernel/hal",
4145
]
4246
)
4347

ic_impl/algo/lr/lr_context.cc

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ic_impl/algo/lr/lr_context.h"
1616

1717
#include "gflags/gflags.h"
18+
#include "nlohmann/json.hpp"
1819

1920
#include "ic_impl/op/sigmoid/sigmoid.h"
2021
#include "ic_impl/util.h"
@@ -30,14 +31,12 @@ DEFINE_double(l0_norm, 0.0, "l0 norm");
3031
DEFINE_double(l1_norm, 0.0, "l1 norm");
3132
DEFINE_double(l2_norm, 0.5, "l2 norm");
3233

34+
DECLARE_bool(disable_handshake);
35+
3336
namespace ic_impl::algo::lr {
3437

3538
namespace {
3639

37-
bool SuggestedHasLabel() {
38-
return util::GetParamEnv("has_label", FLAGS_has_label);
39-
}
40-
4140
int64_t SuggestedNumEpoch() {
4241
return util::GetParamEnv("num_epoch", FLAGS_num_epoch);
4342
}
@@ -71,13 +70,54 @@ LrHyperParam SuggestedLrHyperParam() {
7170
return lr_param;
7271
}
7372

73+
int32_t GetLabelRank(const std::shared_ptr<IcContext>& ic_ctx) {
74+
if (!FLAGS_disable_handshake) {
75+
return FLAGS_has_label ? ic_ctx->lctx->Rank() : -1;
76+
}
77+
78+
// get parameters from ENV
79+
char* label_owner = util::GetParamEnv("label_owner");
80+
YACL_ENFORCE(label_owner != nullptr, "label_owner not in ENV");
81+
82+
size_t word_size = ic_ctx->lctx->WorldSize();
83+
for (size_t i = 0; i < word_size; ++i) {
84+
if (ic_ctx->lctx->PartyIdByRank(i) == label_owner) {
85+
return i;
86+
}
87+
}
88+
89+
YACL_THROW("get label_rank failed");
90+
}
91+
92+
std::vector<int32_t> GetFeatureNums(const std::shared_ptr<IcContext>& ic_ctx) {
93+
std::vector<int32_t> feature_nums;
94+
if (!FLAGS_disable_handshake) {
95+
return feature_nums; // got through the handshake that follows
96+
}
97+
98+
// get parameters from ENV
99+
char* json_str = util::GetParamEnv("feature_nums");
100+
YACL_ENFORCE(json_str != nullptr, "feature_nums not in ENV");
101+
102+
size_t word_size = ic_ctx->lctx->WorldSize();
103+
feature_nums.resize(word_size);
104+
auto json_object = nlohmann::json::parse(json_str);
105+
for (size_t i = 0; i < word_size; ++i) {
106+
feature_nums.at(i) = json_object.at(ic_ctx->lctx->PartyIdByRank(i));
107+
}
108+
109+
return feature_nums;
110+
}
111+
74112
} // namespace
75113

76114
std::shared_ptr<LrContext> CreateLrContext(std::shared_ptr<IcContext> ic_ctx) {
77115
auto ctx = std::make_shared<LrContext>();
78116
ctx->lr_param = SuggestedLrHyperParam();
79117

80-
ctx->io_param.label_rank = SuggestedHasLabel() ? ic_ctx->lctx->Rank() : -1;
118+
ctx->io_param.label_rank = GetLabelRank(ic_ctx);
119+
120+
ctx->io_param.feature_nums = GetFeatureNums(ic_ctx);
81121

82122
ctx->sigmoid_mode = op::sigmoid::SuggestedSigmoidMode();
83123

ic_impl/algo/lr/lr_handler.cc

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,23 @@
2121
#include "gflags/gflags.h"
2222
#include "libspu/core/config.h"
2323
#include "libspu/core/encoding.h"
24-
#include "libspu/kernel/hal/hal.h"
24+
#include "libspu/kernel/hal/constants.h"
25+
#include "libspu/kernel/hal/polymorphic.h"
26+
#include "libspu/kernel/hal/public_helper.h"
27+
#include "libspu/kernel/hal/shape_ops.h"
28+
#include "libspu/kernel/hal/type_cast.h"
2529
#include "libspu/mpc/aby3/type.h"
2630
#include "libspu/mpc/factory.h"
2731
#include "libspu/mpc/semi2k/type.h"
32+
#include "xtensor/xarray.hpp"
2833
#include "xtensor/xcsv.hpp"
29-
#include "xtensor/xio.hpp"
34+
#include "xtensor/xview.hpp"
3035

3136
DEFINE_string(dataset, "data.csv", "dataset file, only csv is supported");
3237
DEFINE_int32(skip_rows, 1, "skip number of rows from dataset");
3338
DEFINE_string(lr_output, "/tmp/sslr_result", "full path name of output file");
3439
DECLARE_int32(rank);
40+
DECLARE_bool(disable_handshake);
3541

3642
namespace ic_impl::algo::lr {
3743

@@ -196,11 +202,17 @@ bool LrHandler::PrepareDataset() {
196202
int32_t feature_num =
197203
ctx_->HasLabel() ? dataset_->shape(1) - 1 : dataset_->shape(1);
198204
YACL_ENFORCE(sample_size > 0);
199-
YACL_ENFORCE(feature_num > 0);
200205

201206
ctx_->io_param.sample_size = sample_size;
202-
ctx_->io_param.feature_nums.resize(ctx_->ic_ctx->lctx->WorldSize());
203-
ctx_->io_param.feature_nums[ctx_->ic_ctx->lctx->Rank()] = feature_num;
207+
auto self_rank = ctx_->ic_ctx->lctx->Rank();
208+
209+
if (FLAGS_disable_handshake) {
210+
YACL_ENFORCE(ctx_->io_param.feature_nums.at(self_rank) == feature_num);
211+
} else {
212+
YACL_ENFORCE(feature_num > 0);
213+
ctx_->io_param.feature_nums.resize(ctx_->ic_ctx->lctx->WorldSize());
214+
ctx_->io_param.feature_nums.at(self_rank) = feature_num;
215+
}
204216

205217
return true;
206218
}
@@ -676,15 +688,17 @@ float Accuracy(const xt::xarray<float>& y_true,
676688

677689
void ProduceOutput(spu::SPUContext* sctx, const spu::Value& w) {
678690
// output result shares to the file
679-
spu::PtType out_pt_type;
680-
auto w_output = spu::decodeFromRing(
681-
w.data(), w.dtype(), sctx->config().fxp_fraction_bits(), &out_pt_type);
682-
YACL_ENFORCE(out_pt_type == spu::PT_F32);
691+
spu::PtType pt_type = getDecodeType(w.dtype());
692+
spu::NdArrayRef dst(makePtType(pt_type), w.shape());
693+
spu::PtBufferView pv(static_cast<void*>(dst.data()), pt_type, dst.shape(),
694+
dst.strides());
695+
spu::decodeFromRing(w.data(), w.dtype(), sctx->config().fxp_fraction_bits(),
696+
&pv);
683697

684698
std::string out_file_name = GetLrOutputFileName();
685699
std::ofstream of(out_file_name);
686700
YACL_ENFORCE(of, "open file={} failed", out_file_name);
687-
for (auto it = w_output.begin(); it != w_output.end(); ++it) {
701+
for (auto it = dst.begin(); it != dst.end(); ++it) {
688702
auto* item = reinterpret_cast<float*>(it.getRawPtr());
689703
of << *item << '\n';
690704
}
@@ -756,10 +770,17 @@ std::unique_ptr<spu::SPUContext> LrHandler::MakeSpuContext() {
756770

757771
spu::Value LrHandler::EncodingDataset(spu::PtBufferView dataset) {
758772
// encode to ring.
773+
auto array = convertToNdArray(dataset);
774+
SPU_ENFORCE(array.eltype().isa<spu::PtTy>(), "expect PtType, got={}",
775+
array.eltype());
776+
777+
const spu::PtType pt_type = array.eltype().as<spu::PtTy>()->pt_type();
778+
spu::PtBufferView pv(static_cast<const void*>(array.data()), pt_type,
779+
array.shape(), array.strides());
780+
759781
spu::DataType dtype;
760782
spu::NdArrayRef encoded =
761-
encodeToRing(convertToNdArray(dataset),
762-
static_cast<spu::FieldType>(ctx_->ss_param.field_type),
783+
encodeToRing(pv, static_cast<spu::FieldType>(ctx_->ss_param.field_type),
763784
ctx_->ss_param.fxp_bits, &dtype);
764785

765786
return spu::Value(encoded, dtype);

0 commit comments

Comments
 (0)