diff --git a/.github/actions/tr_post_test_run/action.yml b/.github/actions/tr_post_test_run/action.yml index d227ca45de..3a0ca6c42a 100644 --- a/.github/actions/tr_post_test_run/action.yml +++ b/.github/actions/tr_post_test_run/action.yml @@ -24,7 +24,7 @@ runs: id: tar_files if: ${{ always() }} run: | - tar -cvf result.tar --exclude="cert" --exclude="data" --exclude="__pycache__" --exclude="tensor.db" --exclude="workspace.tar" $HOME/results + tar -cvf result.tar --exclude="cert" --exclude="data" --exclude="__pycache__" --exclude="tensor.db" --exclude="workspace.tar" --exclude="minio_data" $HOME/results # Model name might contain forward slashes, convert them to underscore. tmp=${{ env.MODEL_NAME }} echo "MODEL_NAME_MODIFIED=${tmp//\//_}" >> $GITHUB_ENV diff --git a/.github/workflows/pq_pipeline.yml b/.github/workflows/pq_pipeline.yml index 41d5ac99c9..17efcd361a 100644 --- a/.github/workflows/pq_pipeline.yml +++ b/.github/workflows/pq_pipeline.yml @@ -80,7 +80,7 @@ jobs: (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') name: TaskRunner E2E - needs: set_commit_id_for_all_jobs + needs: task_runner_connectivity_e2e uses: ./.github/workflows/task_runner_basic_e2e.yml with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} @@ -90,7 +90,7 @@ jobs: (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') name: TaskRunner Resiliency E2E - needs: task_runner_e2e + needs: task_runner_connectivity_e2e uses: ./.github/workflows/task_runner_resiliency_e2e.yml with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} @@ -131,7 +131,7 @@ jobs: (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') name: TaskRunner Dockerized E2E - needs: task_runner_straggler_e2e + needs: task_runner_resiliency_e2e uses: ./.github/workflows/task_runner_dockerized_ws_e2e.yml with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} @@ -158,6 +158,26 @@ jobs: with: commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + task_runner_fed_analytics_e2e: + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') + name: TaskRunner Federated Analytics E2E + needs: task_runner_connectivity_e2e + uses: ./.github/workflows/task_runner_fed_analytics_e2e.yml + with: + commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + + tr_verifiable_dataset_e2e: + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') + name: TaskRunner Verifiable Dataset E2E + needs: task_runner_connectivity_e2e + uses: ./.github/workflows/tr_verifiable_dataset_e2e.yml + with: + commit_id: ${{ needs.set_commit_id_for_all_jobs.outputs.commit_id }} + run_trivy: if: | (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || @@ -198,7 +218,9 @@ jobs: wf_mnist_local_runtime, wf_watermark_e2e, wf_secagg_e2e, + task_runner_connectivity_e2e, task_runner_e2e, + task_runner_fed_analytics_e2e, task_runner_resiliency_e2e, task_runner_fedeval_e2e, task_runner_secure_agg_e2e, @@ -206,6 +228,7 @@ jobs: task_runner_dockerized_e2e, task_runner_secret_ssl_e2e, task_runner_flower_app_pytorch, + tr_verifiable_dataset_e2e, run_trivy, run_bandit ] diff --git 
a/.github/workflows/task_runner_connectivity_e2e.yml b/.github/workflows/task_runner_connectivity_e2e.yml index c8e404475d..d13e02a075 100644 --- a/.github/workflows/task_runner_connectivity_e2e.yml +++ b/.github/workflows/task_runner_connectivity_e2e.yml @@ -52,37 +52,36 @@ jobs: with: test_type: "TLS_Connectivity_gRPC" -# Uncomment once Rest API PR is merged - # test_rest_connectivity: - # name: Task Runner Rest connectivity (no-op, 3.11, rest) - # if: | - # (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || - # (github.event_name == 'workflow_dispatch') || - # (github.event.pull_request.draft == false) - # runs-on: ubuntu-22.04 - # timeout-minutes: 30 - # env: - # MODEL_NAME: 'no-op' - # PYTHON_VERSION: '3.11' - # steps: - # - name: Checkout OpenFL repository - # id: checkout_openfl - # uses: actions/checkout@v4 - # with: - # ref: ${{ env.COMMIT_ID }} + test_rest_connectivity: + name: Task Runner Rest connectivity (no-op, 3.11) + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') || + (github.event.pull_request.draft == false) + runs-on: ubuntu-22.04 + timeout-minutes: 30 + env: + MODEL_NAME: 'no-op' + PYTHON_VERSION: '3.11' + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} - # - name: Pre test run - # uses: ./.github/actions/tr_pre_test_run - # if: ${{ always() }} + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} - # - name: Run Task Runner rest connectivity test - # id: run_tests - # run: | - # python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py -k test_federation_connectivity --model_name ${{ env.MODEL_NAME }} --tr_rest_api - # echo "Task runner end to end test run completed" + - name: Run Task Runner rest connectivity test + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py -k test_federation_connectivity --model_name ${{ env.MODEL_NAME }} --tr_rest_protocol + echo "Task runner end to end test run completed" - # - name: Post test run - # uses: ./.github/actions/tr_post_test_run - # if: ${{ always() }} - # with: - # test_type: "TLS_Connectivity_REST" \ No newline at end of file + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "TLS_Connectivity_REST" \ No newline at end of file diff --git a/.github/workflows/task_runner_dockerized_ws_e2e.yml b/.github/workflows/task_runner_dockerized_ws_e2e.yml index 278b39773b..a6904ccae0 100644 --- a/.github/workflows/task_runner_dockerized_ws_e2e.yml +++ b/.github/workflows/task_runner_dockerized_ws_e2e.yml @@ -180,7 +180,7 @@ jobs: test_type: "DWS_Without_Client_Auth" test_memory_logs: - name: With Memory Logs + name: With Memory Logs REST needs: input_selection if: needs.input_selection.outputs.selected_jobs == 'all' || needs.input_selection.outputs.selected_jobs == 'test_memory_logs' runs-on: ubuntu-22.04 @@ -212,11 +212,11 @@ jobs: python -m pytest -s tests/end_to_end/test_suites/memory_logs_tests.py \ -k test_log_memory_usage_dockerized_ws --model_name ${{ env.MODEL_NAME }} \ --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} \ - --log_memory_usage + --log_memory_usage --tr_rest_protocol echo "Task runner memory logs test run completed" - name: Post test run uses: ./.github/actions/tr_post_test_run if: ${{ always() }} with: - test_type: 
"DWS_With_Memory_Logs" + test_type: "DWS_With_Memory_Logs_REST" diff --git a/.github/workflows/task_runner_fed_analytics_e2e.yml b/.github/workflows/task_runner_fed_analytics_e2e.yml new file mode 100644 index 0000000000..57d21ba733 --- /dev/null +++ b/.github/workflows/task_runner_fed_analytics_e2e.yml @@ -0,0 +1,108 @@ +--- +# Task Runner Federated Analytics E2E tests for bare metal approach + +name: Task_Runner_Fed_Analytics_E2E # Please do not modify the name as it is used in the composite action + +on: + workflow_call: + inputs: + commit_id: + required: false + type: string + workflow_dispatch: + inputs: + num_collaborators: + description: "Number of collaborators" + required: false + default: "2" + type: string + python_version: + description: "Python version" + required: false + default: "3.10" + type: choice + options: + - "3.10" + - "3.11" + - "3.12" + +permissions: + contents: read + +# Environment variables common for all the jobs +# DO NOT use double quotes for the values of the environment variables +env: + NUM_COLLABORATORS: ${{ inputs.num_collaborators || 2 }} + COMMIT_ID: ${{ inputs.commit_id || github.sha }} # use commit_id from the calling workflow + +jobs: + test_fed_analytics_histogram: + name: With REST (federated_analytics/histogram, 3.11) + runs-on: ubuntu-22.04 + timeout-minutes: 30 + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') || + (github.event.pull_request.draft == false) + env: + MODEL_NAME: 'federated_analytics/histogram' + PYTHON_VERSION: ${{ inputs.python_version || '3.11' }} + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Run Federated Analytics Histogram + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_fed_analytics_tests.py --tr_rest_protocol \ + -m task_runner_fed_analytics --model_name ${{ env.MODEL_NAME }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Federated analytics histogram test run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "Sepal_Histogram_Analytics" + + test_fed_analytics_smokers_health: + name: With gRPC (federated_analytics/smokers_health, 3.12) + runs-on: ubuntu-22.04 + timeout-minutes: 30 + if: | + (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + (github.event_name == 'workflow_dispatch') || + (github.event.pull_request.draft == false) + env: + MODEL_NAME: 'federated_analytics/smokers_health' + PYTHON_VERSION: ${{ inputs.python_version || '3.12' }} + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Run Federated Analytics Smokers Health + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_fed_analytics_tests.py \ + -m task_runner_fed_analytics --model_name ${{ env.MODEL_NAME }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Federated analytics smokers health test run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "Smokers_Health_Analytics" \ No newline at end of file diff --git 
a/.github/workflows/task_runner_fedeval_e2e.yml b/.github/workflows/task_runner_fedeval_e2e.yml index 07aa643edf..83cb29258e 100644 --- a/.github/workflows/task_runner_fedeval_e2e.yml +++ b/.github/workflows/task_runner_fedeval_e2e.yml @@ -80,8 +80,8 @@ jobs: with: test_type: "FedEval_With_TLS" - test_without_tls: - name: Without TLS + test_without_tls_rest: + name: Without TLS Using Rest Protocol if: | # Skip for PR pipeline ((github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch')) && (github.workflow != 'OpenFL PR Pipeline') @@ -115,17 +115,17 @@ jobs: run: | python -m pytest -s tests/end_to_end/test_suites/tr_with_fedeval_tests.py \ -m task_runner_basic --model_name ${{ env.MODEL_NAME }} \ - --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_tls + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_tls --tr_rest_protocol echo "Task runner end to end test run completed" - name: Post test run uses: ./.github/actions/tr_post_test_run if: ${{ always() }} with: - test_type: "FedEval_Without_TLS" + test_type: "FedEval_Without_TLS_REST" - test_without_client_auth: - name: Without Client Auth + test_without_client_auth_rest: + name: Without ClientAuth Using Rest Protocol if: | # Skip for PR pipeline ((github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch')) && (github.workflow != 'OpenFL PR Pipeline') @@ -159,11 +159,11 @@ jobs: run: | python -m pytest -s tests/end_to_end/test_suites/tr_with_fedeval_tests.py \ -m task_runner_basic --model_name ${{ env.MODEL_NAME }} \ - --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_client_auth + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_client_auth --tr_rest_protocol echo "Task runner end to end test run completed" - name: Post test run uses: ./.github/actions/tr_post_test_run if: ${{ always() }} with: - test_type: "FedEval_Without_Client_Auth" + test_type: "FedEval_Without_Client_Auth_REST" diff --git a/.github/workflows/task_runner_resiliency_e2e.yml b/.github/workflows/task_runner_resiliency_e2e.yml index c001e0b4e7..10656094d3 100644 --- a/.github/workflows/task_runner_resiliency_e2e.yml +++ b/.github/workflows/task_runner_resiliency_e2e.yml @@ -53,50 +53,17 @@ env: COMMIT_ID: ${{ inputs.commit_id || github.sha }} # use commit_id from the calling workflow jobs: - input_selection: + resiliency_in_native_gRPC: + name: Resiliency in gRPC (torch/mnist, 3.10) + runs-on: ubuntu-22.04 + timeout-minutes: 30 if: | (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') || (github.event.pull_request.draft == false) - name: Input value selection - runs-on: ubuntu-22.04 - outputs: - # Output all the variables related to models and python versions to be used in the matrix strategy - # for different jobs, however their usage depends on the selected job. 
- selected_models_for_tls: ${{ steps.input_selection.outputs.models_for_tls }} - selected_python_for_tls: ${{ steps.input_selection.outputs.python_for_tls }} - steps: - - name: Job to select input values - id: input_selection - run: | - if [ "${{ env.MODEL_NAME }}" == "all" ]; then - echo "models_for_tls=[\"torch/mnist\", \"keras/mnist\"]" >> "$GITHUB_OUTPUT" - else - echo "models_for_tls=[\"${{env.MODEL_NAME}}\"]" >> "$GITHUB_OUTPUT" - fi - if [ "${{ env.PYTHON_VERSION }}" == "all" ]; then - echo "python_for_tls=[\"3.10\", \"3.11\"]" >> "$GITHUB_OUTPUT" - else - echo "python_for_tls=[\"${{env.PYTHON_VERSION}}\"]" >> "$GITHUB_OUTPUT" - fi - - resiliency_in_native: - name: With TLS - needs: input_selection - runs-on: ubuntu-22.04 - timeout-minutes: 30 - strategy: - matrix: - model_name: ${{ fromJson(needs.input_selection.outputs.selected_models_for_tls) }} - python_version: ${{ fromJson(needs.input_selection.outputs.selected_python_for_tls) }} - exclude: # Keras does not support Python 3.12 - - model_name: "keras/mnist" - python_version: "3.12" - fail-fast: false # do not immediately fail if one of the combinations fail - env: - MODEL_NAME: ${{ matrix.model_name }} - PYTHON_VERSION: ${{ matrix.python_version }} + MODEL_NAME: ${{ inputs.model_name || 'torch/mnist' }} + PYTHON_VERSION: ${{ inputs.python_version || '3.10' }} steps: - name: Checkout OpenFL repository @@ -121,6 +88,42 @@ jobs: uses: ./.github/actions/tr_post_test_run if: ${{ always() }} with: - test_type: "Resiliency_Native" + test_type: "Resiliency_Native_gRPC" + + # Uncomment below once Issue is resolved - https://github.com/securefederatedai/openfl/issues/1646 + # resiliency_in_native_rest: + # name: Resiliency in REST (keras/tensorflow/mnist, 3.11) + # runs-on: ubuntu-22.04 + # timeout-minutes: 30 + # if: | + # (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || + # (github.event_name == 'workflow_dispatch') || + # (github.event.pull_request.draft == false) + # env: + # MODEL_NAME: ${{ inputs.model_name || 'keras/tensorflow/mnist' }} + # PYTHON_VERSION: ${{ inputs.python_version || '3.11' }} + + # steps: + # - name: Checkout OpenFL repository + # id: checkout_openfl + # uses: actions/checkout@v4 + # with: + # ref: ${{ env.COMMIT_ID }} + + # - name: Pre test run + # uses: ./.github/actions/tr_pre_test_run + # if: ${{ always() }} + + # - name: Run Task Runner E2E tests with TLS + # id: run_tests + # run: | + # python -m pytest -s tests/end_to_end/test_suites/tr_resiliency_tests.py \ + # -m task_runner_basic --model_name ${{ env.MODEL_NAME }} \ + # --num_collaborators ${{ env.NUM_COLLABORATORS }} --num_rounds ${{ env.NUM_ROUNDS }} --tr_rest_protocol + # echo "Task runner end to end test run completed" - # TODO - Add dockerized approach as well once we have GitHub runners with higher configurations. 
+ # - name: Post test run + # uses: ./.github/actions/tr_post_test_run + # if: ${{ always() }} + # with: + # test_type: "Resiliency_Native_REST" \ No newline at end of file diff --git a/.github/workflows/task_runner_straggler_e2e.yml b/.github/workflows/task_runner_straggler_e2e.yml index 90ab99728a..c93786533c 100644 --- a/.github/workflows/task_runner_straggler_e2e.yml +++ b/.github/workflows/task_runner_straggler_e2e.yml @@ -51,8 +51,9 @@ jobs: with: test_type: "With_TLS_Percentage" + # Move it to Rest once issue is resolved https://github.com/securefederatedai/openfl/issues/1646 test_straggler_cutoff: - name: Cutoff Policy (torch/mnist_straggler_check, 3.10) + name: Cutoff Policy (torch/mnist_straggler_check, 3.11) if: | (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') || (github.event_name == 'workflow_dispatch') @@ -60,7 +61,7 @@ jobs: timeout-minutes: 30 env: MODEL_NAME: 'torch/mnist_straggler_check' - PYTHON_VERSION: '3.10' + PYTHON_VERSION: '3.11' steps: - name: Checkout OpenFL repository id: checkout_openfl diff --git a/.github/workflows/tr_verifiable_dataset_e2e.yml b/.github/workflows/tr_verifiable_dataset_e2e.yml new file mode 100644 index 0000000000..eb3f61ac08 --- /dev/null +++ b/.github/workflows/tr_verifiable_dataset_e2e.yml @@ -0,0 +1,173 @@ +--- +# Task Runner Verifiable Dataset E2E + +name: TR_Verifiable_Dataset_E2E # Please do not modify the name as it is used in the composite action + +on: + workflow_call: + inputs: + commit_id: + required: false + type: string + workflow_dispatch: + inputs: + num_rounds: + description: "Number of rounds to train" + required: false + default: "2" + type: string + num_collaborators: + description: "Number of collaborators" + required: false + default: "2" + type: string + +permissions: + contents: read + +# Environment variables common for all the jobs +# DO NOT use double quotes for the values of the environment variables +env: + NUM_ROUNDS: ${{ inputs.num_rounds || 2 }} + NUM_COLLABORATORS: ${{ inputs.num_collaborators || 2 }} + COMMIT_ID: ${{ inputs.commit_id || github.sha }} # use commit_id from the calling workflow + +jobs: + test_with_s3: + name: With S3 (torch/histology_s3, 3.11) + runs-on: ubuntu-22.04 + timeout-minutes: 120 + if: ${{ github.workflow != 'OpenFL Product Quality Pipeline' }} + env: + MODEL_NAME: "torch/histology_s3" + PYTHON_VERSION: "3.11" + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Install MinIO + id: install_minio + run: | + wget https://dl.min.io/server/minio/release/linux-amd64/minio + chmod +x minio + sudo mv minio /usr/local/bin/ + + - name: Install MinIO Client + id: install_minio_client + run: | + wget https://dl.min.io/client/mc/release/linux-amd64/mc + chmod +x mc + sudo mv mc /usr/local/bin/ + + - name: Run Task Runner E2E tests with S3 + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py \ + -m task_runner_with_s3 --model_name ${{ env.MODEL_NAME }} \ + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Task Runner E2E tests with S3 run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "With_S3" + + test_with_azure_blob: + name: With Azure Blob (torch/histology_s3, 3.11) + runs-on: ubuntu-22.04 + timeout-minutes: 
120 + if: ${{ github.workflow != 'OpenFL Product Quality Pipeline' }} + env: + MODEL_NAME: "torch/histology_s3" + PYTHON_VERSION: "3.11" + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Install Azurite + id: install_azurite + run: | + docker pull mcr.microsoft.com/azure-storage/azurite + + - name: Run Task Runner E2E tests with Azure Blob + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py \ + -m task_runner_with_azure_blob --model_name ${{ env.MODEL_NAME }} \ + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Task Runner E2E tests with Azure Blob run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "With_Azure_Blob" + + test_with_all_ds: + name: With All Data Sources (torch/histology_s3, 3.11) + runs-on: ubuntu-22.04 + timeout-minutes: 120 + env: + MODEL_NAME: "torch/histology_s3" + PYTHON_VERSION: "3.11" + + steps: + - name: Checkout OpenFL repository + id: checkout_openfl + uses: actions/checkout@v4 + with: + ref: ${{ env.COMMIT_ID }} + + - name: Pre test run + uses: ./.github/actions/tr_pre_test_run + if: ${{ always() }} + + - name: Install MinIO + id: install_minio + run: | + wget https://dl.min.io/server/minio/release/linux-amd64/minio + chmod +x minio + sudo mv minio /usr/local/bin/ + + - name: Install MinIO Client + id: install_minio_client + run: | + wget https://dl.min.io/client/mc/release/linux-amd64/mc + chmod +x mc + sudo mv mc /usr/local/bin/ + + - name: Install Azurite + id: install_azurite + run: | + docker pull mcr.microsoft.com/azure-storage/azurite + + - name: Run Task Runner E2E tests with all data sources + id: run_tests + run: | + python -m pytest -s tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py \ + -m task_runner_with_all_ds --model_name ${{ env.MODEL_NAME }} \ + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} + echo "Task Runner E2E tests with Azure Blob run completed" + + - name: Post test run + uses: ./.github/actions/tr_post_test_run + if: ${{ always() }} + with: + test_type: "With_All_Datasources" diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 14e9482901..9551ce24c4 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -54,6 +54,24 @@ jobs: run: | python -m tests.github.test_hello_federation --template keras/mnist --fed_workspace aggregator --col1 col1 --col2 col2 --rounds-to-train 3 --save-model output_model + torch_mnist_rest: # from taskrunner.yml - torch/mnist + if: github.event.pull_request.draft == false + runs-on: windows-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install . 
+ - name: Test TaskRunner API + run: | + python -m tests.github.test_hello_federation --template torch/mnist --fed_workspace aggregator --col1 col1 --col2 col2 --rounds-to-train 3 --transport-protocol rest --save-model output_model + torch_mnist_eden_compression: # from taskrunner_eden_pipeline.yml - torch/mnist_eden_compression if: github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'eden_compression') runs-on: windows-latest diff --git a/docs/about/features_index/workflowinterface.rst b/docs/about/features_index/workflowinterface.rst index 7f51b94ba2..7b0b5eb419 100644 --- a/docs/about/features_index/workflowinterface.rst +++ b/docs/about/features_index/workflowinterface.rst @@ -236,6 +236,9 @@ Some important points to remember while creating callback function and private a - If no Callback Function or private attributes is specified then the Participant shall not have any *private attributes* - In above example multiple collaborators have the same callback function or private attributes. Depending on the Federated Learning requirements, user can specify unique callback function or private attributes for each Participant - *Private attributes* needs to be set after instantiating the participant. + - **Known Limitations**: When using a `callable` to initialize *private attributes* that are **not serializable**, users should be aware of following limitations: + * `checkpoint` should not be enabled with `LocalRuntime`. Users should ensure that default (disabled) setting of checkpoint is used or it is explicitly disabled :code:`flow = FederatedFlow( ..., checkpoint = false)` + * filtering of attributes (via `include` or `exclude`) cannot be used during the transition from aggregator step to collaborator steps. This limitation applies to **all attributes** if any non-serializable private attribute is present in aggregator. The flow logic must be updated to avoid filtering in steps that transition control from aggregator to collaborators Now let's see how the runtime for a flow is assigned, and the flow gets run: @@ -558,6 +561,7 @@ In a distributed environment consisting of Director, Envoys and User Node (where **IMPORTANT**: While this information is useful for debugging, depending on your workflow it may require significant disk space. For this reason, checkpoint is disabled by default. + Future Plans ============== Following functionalities are planned for inclusion in future releases of the Workflow Interface: diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/__init__.py b/openfl-workspace/federated_analytics/smokers_health/.workspace similarity index 100% rename from openfl-workspace/flower-app-pytorch/src/grpc/__init__.py rename to openfl-workspace/federated_analytics/smokers_health/.workspace diff --git a/openfl-workspace/federated_analytics/smokers_health/README.md b/openfl-workspace/federated_analytics/smokers_health/README.md new file mode 100644 index 0000000000..2f8fd892ee --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/README.md @@ -0,0 +1,94 @@ +# Federated Analytics: Smokers Health Example + +This workspace demonstrates how to use OpenFL for privacy-preserving analytics on the Smokers Health dataset. The setup enables distributed computation of health statistics (such as heart rate, cholesterol, and blood pressure) across multiple collaborators, without sharing raw data. 
+ +## Instantiating a Workspace from Smokers Health Template +To instantiate a workspace from the `federated_analytics/smokers_health` template, use the `fx workspace create` command. This will set up a new workspace with the required configuration and code. + +1. **Install dependencies:** +```bash +pip install virtualenv +mkdir ~/openfl-smokers-health +virtualenv ~/openfl-smokers-health/venv +source ~/openfl-smokers-health/venv/bin/activate +pip install openfl +``` + +2. **Create the Workspace Folder:** +```bash +cd ~/openfl-smokers-health +fx workspace create --template federated_analytics/smokers_health --prefix fl_workspace +cd ~/openfl-smokers-health/fl_workspace +``` + +## Directory Structure +The workspace has the following structure: +``` +smokers_health +├── requirements.txt +├── .workspace +├── plan +│ ├── plan.yaml +│ ├── cols.yaml +│ ├── data.yaml +│ └── defaults/ +├── src +│ ├── __init__.py +│ ├── dataloader.py +│ ├── taskrunner.py +│ └── aggregate_health.py +├── data/ +└── save/ +``` + +### Directory Breakdown +- **requirements.txt**: Lists all Python dependencies for the workspace. +- **plan/**: Contains configuration files for the federation: + - `plan.yaml`: Main plan declaration. + - `cols.yaml`: List of authorized collaborators. + - `data.yaml`: Data path for each collaborator. + - `defaults/`: Default configuration values. +- **src/**: Python modules for federated analytics: + - `dataloader.py`: Loads and shards the Smokers Health dataset, supports SQL queries. + - `taskrunner.py`: Groups data and computes mean health metrics by age, sex, and smoking status. + - `aggregatehealth.py`: Aggregates results from all collaborators. +- **data/**: Place to store the downloaded and unzipped dataset. +- **save/**: Stores aggregated results and analytics outputs. + +## Data Preparation +The data loader will automatically download the Smokers Health dataset from Kaggle or a specified source. Make sure you have the required access or download the dataset manually if needed. + +## Defining the Data Loader +The data loader supports SQL-like queries and can load data from CSV or other sources as configured. It shards the dataset among collaborators and provides query functionality for analytics tasks. + +## Defining the Task Runner +The task runner groups the data by `age`, `sex`, and `current_smoker`, and computes the mean of `heart_rate`, `chol`, and `blood pressure (systolic/diastolic)`. The results are returned as numpy arrays for aggregation. + +## Running the Federation +1. **Initialize the plan:** +```bash +fx plan initialize +``` +2. **Set up the aggregator and collaborators:** +```bash +fx workspace certify +fx aggregator generate-cert-request +fx aggregator certify --silent + +fx collaborator create -n collaborator1 -d 1 +fx collaborator generate-cert-request -n collaborator1 +fx collaborator certify -n collaborator1 --silent + +fx collaborator create -n collaborator2 -d 2 +fx collaborator generate-cert-request -n collaborator2 +fx collaborator certify -n collaborator2 --silent +``` +3. **Start the federation:** +```bash +fx aggregator start & +fx collaborator start -n collaborator1 & +fx collaborator start -n collaborator2 & +``` + +## License +This project is licensed under the Apache License 2.0. See the LICENSE file for details. 
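
To make the "Defining the Task Runner" description above concrete, here is a minimal, self-contained sketch of the per-collaborator computation using made-up sample rows. It is not the workspace code itself (the actual `src/taskrunner.py` also parses the `systolic/diastolic` blood-pressure strings and builds richer result keys), but it shows the same `groupby`-then-mean pattern.

```python
import numpy as np
import pandas as pd

# Hypothetical sample shard; column names mirror the workspace's dataset schema.
shard = pd.DataFrame({
    "age": [45, 45, 60],
    "sex": ["male", "male", "female"],
    "current_smoker": ["yes", "yes", "no"],
    "heart_rate": [78, 82, 70],
    "chol": [210.0, 190.0, 180.0],
})

# Group by the same keys the task runner uses and average the metrics.
grouped = shard.groupby(["age", "sex", "current_smoker"]).agg(
    {"heart_rate": "mean", "chol": "mean"}
)

# Emit one numpy array per group, keyed similarly to how the aggregator expects.
for (age, sex, smoker), row in grouped.iterrows():
    key = f"age_{age}_sex_{sex}_current_smoker_{smoker}"
    print(key, np.array([row["heart_rate"], row["chol"]]))
```

Each collaborator returns only these per-group arrays; `src/aggregate_health.py` then averages them across collaborators on the aggregator side, so raw rows never leave a collaborator.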
\ No newline at end of file diff --git a/openfl-workspace/federated_analytics/smokers_health/plan/cols.yaml b/openfl-workspace/federated_analytics/smokers_health/plan/cols.yaml new file mode 100644 index 0000000000..da14426f1e --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/plan/cols.yaml @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +collaborators: diff --git a/openfl-workspace/federated_analytics/smokers_health/plan/data.yaml b/openfl-workspace/federated_analytics/smokers_health/plan/data.yaml new file mode 100644 index 0000000000..37b1cc2fc8 --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/plan/data.yaml @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +# collaborator_name,data_directory_path diff --git a/openfl-workspace/federated_analytics/smokers_health/plan/plan.yaml b/openfl-workspace/federated_analytics/smokers_health/plan/plan.yaml new file mode 100644 index 0000000000..81056eeede --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/plan/plan.yaml @@ -0,0 +1,45 @@ +aggregator: + defaults: plan/defaults/aggregator.yaml + template: openfl.component.Aggregator + settings: + last_state_path: save/result.json + rounds_to_train: 1 + +collaborator: + defaults: plan/defaults/collaborator.yaml + template: openfl.component.Collaborator + settings: + use_delta_updates: false + opt_treatment: RESET + +data_loader: + defaults: plan/defaults/data_loader.yaml + template: src.dataloader.SmokersHealthDataLoader + settings: + collaborator_count: 2 + data_group_name: smokers_health + batch_size: 150 + +task_runner: + defaults: plan/defaults/task_runner.yaml + template: src.taskrunner.SmokersHealthAnalytics + +network: + defaults: plan/defaults/network.yaml + +assigner: + template: openfl.component.RandomGroupedAssigner + settings: + task_groups: + - name: analytics + percentage: 1.0 + tasks: + - analytics + +tasks: + analytics: + function: analytics + aggregation_type: + template: src.aggregate_health.AggregateHealthMetrics + kwargs: + columns: ['age', 'sex', 'current_smoker', 'heart_rate', 'blood_pressure', 'cigs_per_day', 'chol'] \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/__init__.py b/openfl-workspace/federated_analytics/smokers_health/requirements.txt similarity index 100% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/__init__.py rename to openfl-workspace/federated_analytics/smokers_health/requirements.txt diff --git a/openfl-workspace/federated_analytics/smokers_health/src/aggregate_health.py b/openfl-workspace/federated_analytics/smokers_health/src/aggregate_health.py new file mode 100644 index 0000000000..9700e820ac --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/src/aggregate_health.py @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +from openfl.interface.aggregation_functions.core import AggregationFunction + + +class AggregateHealthMetrics(AggregationFunction): + """Aggregation logic for Smokers Health analytics.""" + + def call(self, local_tensors, *_) -> dict: + """ + Aggregates local tensors which contains mean of local health metrics such as + heart_rate_mean, cholesterol, 
systolic_blood_pressure, and + diastolic_blood_pressure which are grouped by age, sex and if they smoke or not. + Each tensor represents local metrics for these health parameters. + + Args: + local_tensors (list): A list of objects, each containing a `tensor` attribute + that represents local means for the health metrics. + *_: Additional arguments (unused). + Returns: + dict: A dictionary containing the aggregated means for each health metric. + Raises: + ValueError: If the input list `local_tensors` is empty, indicating + that there are no metrics to aggregate. + """ + + if not local_tensors: + raise ValueError("No local metrics to aggregate.") + + agg_histogram = np.zeros_like(local_tensors[0].tensor) + for local_tensor in local_tensors: + agg_histogram += local_tensor.tensor / len(local_tensors) + return agg_histogram diff --git a/openfl-workspace/federated_analytics/smokers_health/src/dataloader.py b/openfl-workspace/federated_analytics/smokers_health/src/dataloader.py new file mode 100644 index 0000000000..dcf431c240 --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/src/dataloader.py @@ -0,0 +1,96 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from openfl.federated.data.loader import DataLoader +import pandas as pd +import os +import subprocess + + +class SmokersHealthDataLoader(DataLoader): + """Data Loader for Smokers Health Dataset.""" + + def __init__(self, batch_size, data_path, **kwargs): + super().__init__(**kwargs) + + # If data_path is None, this is being used for model initialization only + if data_path is None: + return + + # Load actual data if a data path is provided + try: + int(data_path) + except ValueError: + raise ValueError( + f"Expected '{data_path}' to be representable as `int`, " + "as it refers to the data shard number used by the collaborator." + ) + + # Download and prepare data + self._download_raw_data() + self.data_shard = self.load_data_shard( + shard_num=int(data_path), **kwargs + ) + + def _download_raw_data(self): + """ + Downloads and extracts the raw data for the smokers' health dataset. + This method performs the following steps: + 1. Downloads the dataset from the specified Kaggle URL using the `curl` command. + 2. Saves the downloaded file as a ZIP archive in the `./data` directory. + 3. Extracts the contents of the ZIP archive into the `data` directory. + """ + + download_path = os.path.expanduser('./data/smokers_health.zip') + subprocess.run( + [ + 'curl', '-L', '-o', download_path, + 'https://www.kaggle.com/api/v1/datasets/download/jaceprater/smokers-health-data' + ], + check=True + ) + + # Unzip the downloaded file into the data directory + subprocess.run(['unzip', '-o', download_path, '-d', 'data'], check=True) + + def load_data_shard(self, shard_num, **kwargs): + """ + Loads data from a CSV file. + This method reads the data from a CSV file located at './data/smoking_health_data_final.csv' + and returns it as a pandas DataFrame. + Returns: + pd.DataFrame: The data loaded from the CSV file. + """ + file_path = os.path.join('data', 'smoking_health_data_final.csv') + df = pd.read_csv(file_path) + + # Split data into shards + shard_size = len(df) // shard_num + start_idx = shard_size * (shard_num - 1) + end_idx = start_idx + shard_size + + return df.iloc[start_idx:end_idx] + + def query(self, columns, **kwargs): + """ + Query the data shard for the specified columns. + Args: + columns (list): A list of column names to query from the data shard. 
+ **kwargs: Additional keyword arguments (currently not used). + Returns: + DataFrame: A DataFrame containing the data for the specified columns. + Raises: + ValueError: If the columns parameter is not a list. + """ + if not isinstance(columns, list): + raise ValueError("Columns parameter must be a list") + return self.data_shard[columns] + + def get_feature_shape(self): + """ + This function is not required and is kept for compatibility. + + Returns: + None + """ + pass diff --git a/openfl-workspace/federated_analytics/smokers_health/src/runner_fa.py b/openfl-workspace/federated_analytics/smokers_health/src/runner_fa.py new file mode 100644 index 0000000000..4db58582a3 --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/src/runner_fa.py @@ -0,0 +1,115 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +""" +Base classes for Federated Analytics. + +This file can serve as a template for creating your own Federated Analytics experiments. +""" + +from openfl.federated.task.runner import TaskRunner +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts + +import logging +import numpy as np + +logger = logging.getLogger(__name__) + + +class FederatedAnalyticsTaskRunner(TaskRunner): + """The base class for Federated Analytics Task Runner.""" + + def __init__(self, **kwargs): + """Initializes the FederatedAnalyticsTaskRunner instance. + + Args: + **kwargs: Additional parameters to pass to the function + """ + super().__init__(**kwargs) + + # Dummy model initialization. Dummy models and weights are used here as placeholders + # to ensure compatibility with the core OpenFL framework, which currently assumes + # the presence of a model for federated learning tasks. + # + # This approach is necessary to support Federated Analytics use cases, which do not + # involve traditional model training, until OpenFL is refactored to accommodate + # broader use cases beyond learning. + # + # For more details, refer to the discussion at: + # https://github.com/securefederatedai/openfl/discussions/1385#discussioncomment-13009961. + self.model = None + + self.model_tensor_names = [] + self.required_tensorkeys_for_function = {} + + def analytics(self, col_name, round_num, **kwargs): + """ + Return analytics result as tensors. + + Args: + col_name (str): collaborator name. + round_num (int): The current round number. + **kwargs: Additional parameters for analysis. + + Returns: + dict: A dictionary of analysis results. + """ + results = self.analytics_task(**kwargs) + tags = ("analytics",) + origin = col_name + output_metric_dict = { + # TensorKey(metric_name, origin, round_num, False, tags): metric_value + TensorKey(metric_name, origin, round_num, False, tags): np.array(metric_value) if not isinstance(metric_value, np.ndarray) else metric_value + for metric_name, metric_value in results.items() + } + return output_metric_dict, output_metric_dict + + def analytics_task(self, **kwargs): + """ + Perform analytics on the provided data. + This method should be implemented by subclasses to perform specific analysis tasks. + Args: + **kwargs: Arbitrary keyword arguments that can be used for analysis. + Raises: + NotImplementedError: If the method is not implemented by a subclass. + """ + raise NotImplementedError + + def get_tensor_dict(self, with_opt_vars, suffix=""): + """ + Get the model weights as a tensor dictionary. + + Args: + with_opt_vars (bool): If we should include the optimizer's status. 
+ suffix (str): Universally. + + Returns: + model_weights (dict): The tensor dictionary. + """ + return {'dummy_tensor': np.float32(1)} + + def get_required_tensorkeys_for_function(self, func_name, **kwargs): + """Get the required tensors for specified function that could be called + as part of a task. + + By default, this is just all of the layers and optimizer of the dummy model. + + Args: + func_name (str): The function name. + **kwargs: Any function arguments. + + Returns: + list: List of TensorKey objects. + """ + return [] + + def initialize_tensorkeys_for_functions(self, with_opt_vars=False): + """ + This function is not required and is kept for compatibility. + + Returns: + None + """ + pass diff --git a/openfl-workspace/federated_analytics/smokers_health/src/taskrunner.py b/openfl-workspace/federated_analytics/smokers_health/src/taskrunner.py new file mode 100644 index 0000000000..8f42691c2b --- /dev/null +++ b/openfl-workspace/federated_analytics/smokers_health/src/taskrunner.py @@ -0,0 +1,84 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from src.runner_fa import FederatedAnalyticsTaskRunner +import pandas as pd +import numpy as np + + +class SmokersHealthAnalytics(FederatedAnalyticsTaskRunner): + """ + Taskrunner class for performing federated analytics on the Smokers Health dataset. + Methods + ------- + analytics(columns, **kwargs) + Groups data by specified columns and calculates averages for selected metrics. + """ + + def analytics_task(self, columns, **kwargs): + """ + Perform analytics on the specified columns and compute aggregated metrics. + Args: + columns (list): List of column names to group data by. + **kwargs: Additional keyword arguments for customization. + Returns: + dict: A dictionary where keys are formatted strings representing group identifiers, + and values are numpy arrays containing aggregated metrics. + """ + # query data + data = self.data_loader.query(columns) + + grouped = data.groupby(['age', 'sex', 'current_smoker']) + + # Convert mean values to numpy arrays if they are not already + result = grouped.agg({ + 'heart_rate': 'mean', + 'chol': 'mean', + 'blood_pressure': lambda x: self.process_blood_pressure(x).iloc[0] + }) + + # Convert the result into the desired format + formatted_result = {} + + keys = ', heart_rate_mean, chol_mean, systolic_blood_pressure_mean, diastolic_blood_pressure_mean' + for index, row in result.iterrows(): + age, sex, current_smoker = index + heart_rate_mean = row['heart_rate'] + chol_mean = row['chol'] + systolic_mean = row['blood_pressure'][0] + diastolic_mean = row['blood_pressure'][1] + combined_key = f"age_{age}_sex_{sex}_current_smoker_{current_smoker} {keys}" + formatted_result[combined_key] = np.array([ + heart_rate_mean, chol_mean, systolic_mean, diastolic_mean + ]) + return formatted_result + + # Process blood pressure data + def process_blood_pressure(self, bp_series): + """ + Processes a series of blood pressure readings and calculates the mean + systolic and diastolic values. + Args: + bp_series (pd.Series): A pandas Series containing blood pressure + readings in the format "systolic/diastolic" (e.g., "120/80"). + Returns: + pd.DataFrame: A DataFrame with two columns: + - 'systolic_mean': The mean of valid systolic values, or None if no valid values exist. + - 'diastolic_mean': The mean of valid diastolic values, or None if no valid values exist. + Notes: + - Invalid or non-numeric blood pressure readings are ignored. 
+ - If all readings are invalid, the resulting means will be None. + """ + + systolic, diastolic = zip(*bp_series.str.split('/').map( + lambda x: ( + float(x[0]) if x[0].replace('.', '', 1).isdigit() else None, + float(x[1]) if x[1].replace('.', '', 1).isdigit() else None + ) + )) + systolic = [s for s in systolic if s is not None] + diastolic = [d for d in diastolic if d is not None] + return pd.DataFrame({ + 'systolic_mean': [sum(systolic) / len(systolic) if systolic else None], + 'diastolic_mean': [sum(diastolic) / len(diastolic) if diastolic else None] + }) diff --git a/openfl-workspace/flower-app-pytorch/README.md b/openfl-workspace/flower-app-pytorch/README.md index c42ee65a61..a18d6059c8 100644 --- a/openfl-workspace/flower-app-pytorch/README.md +++ b/openfl-workspace/flower-app-pytorch/README.md @@ -28,7 +28,7 @@ Then create a certificate authority (CA) fx workspace certify ``` -This will create a workspace in your current working directory called `./my_workspace` as well as install the Flower app defined in `./app-pytorch.` This will be where the experiment takes place. The CA will be used to sign the certificates of the collaborators. +This will create a workspace in your current working directory called `./my_workspace` as well as install the Flower app defined in `./src/app-pytorch.` This will be where the experiment takes place. The CA will be used to sign the certificates of the collaborators. ### Setup Data We will be using CIFAR10 dataset. You can install an automatically partition it into 2 using the `./src/setup_data.py` script provided. @@ -63,44 +63,52 @@ data/ Notice under `./plan`, you will find the familiar OpenFL YAML files to configure the experiment. `cols.yaml` and `data.yaml` will be populated by the collaborators that will run the Flower client app and the respective data shard or directory they will perform their training and testing on. `plan.yaml` configures the experiment itself. The Open-Flower integration makes a few key changes to the `plan.yaml`: -1. Introduction of a new top-level key (`connector`) to configure a newly introduced component called `ConnectorFlower`. This component is run by the aggregator and is responsible for initializing the Flower `SuperLink` and connecting to the OpenFL server. The `SuperLink` parameters can be configured using `connector.settings.superlink_params`. If nothing is supplied, it will simply run `flower-superlink --insecure` with the command's default settings as dictated by Flower. It also includes the option to run the flwr run command via `connector.settings.flwr_run_params`. If `flwr_run_params` are not provided, the user will be expected to run `flwr run ` from the aggregator machine to initiate the experiment. Additionally, the `ConnectorFlower` has an additional setting `connector.settings.automatic_shutdown` which is default set to `True`. When set to `True`, the task runner will shut the SuperNode at the completion of an experiment, otherwise, it will run continuously. +1. Introduction of a new top-level key (`connector`) to configure a newly introduced component called `ConnectorFlower`. This component is run by the aggregator and is responsible for initializing the Flower `SuperLink` and connecting to the OpenFL server. Under `settings`, you will find the parameters for both the `flower-superlink` and `flower run` commands. All parameters are configuration by the user. By default, the `flower-superlink` will be run in `insecure` mode. 
The default `fleet_api_port` and `exec_api_port` will be automatically assigned, while the `exec_api_port` should be set to match the address configured in `./src/app-pytorch/pyproject.toml`. This is not set dynamically. In addition, since OpenFL handles cross network communication, `superlink_host` is set to a local host by default. For the `flwr run` command, the user should ensure that the `federation_name` and `flwr_app_name` is consistent with what is defined in `./src/` (if different than `app-pytorch`) and `./src//pyproject.toml`. The Flower directory `flwr_dir` is set to save the FAB in `save/.flwr`. Should a user configure this, the save directory must be located inside the workspace. Additionally, the `ConnectorFlower` has a setting `automatic_shutdown` which is default set to `True`. When set to `True`, the task runner will shut the SuperNode at the completion of an experiment, otherwise, it will run continuously. ```yaml connector: defaults: plan/defaults/connector.yaml template: openfl.component.ConnectorFlower settings: - automatic_shutdown: True - superlink_params: - insecure: True - serverappio-api-address: 127.0.0.1:9091 - fleet-api-address: 127.0.0.1:9092 - exec-api-address: 127.0.0.1:9093 - flwr_run_params: - flwr_app_name: "app-pytorch" - federation_name: "local-poc" + automatic_shutdown: true + insecure: true + exec_api_port: 9093 + fleet_api_port: 57085 + serverappio_api_port: 58873 + federation_name: local-poc + flwr_app_name: app-pytorch + flwr_dir: save/.flwr + superlink_host: 127.0.0.1 ``` -2. `FlowerTaskRunner` which will execute the `start_client_adapter` task. This task starts the Flower SuperNode and makes a connection to the OpenFL client. +2. `FlowerTaskRunner` which will execute the `start_client_adapter` task. This task starts the Flower SuperNode and makes a connection to the OpenFL client. In addition, you will notice there are settings for the `flwr_app_name`, `flwr_dir`, and `sgx_enabled`. `flwr_app_name` and `flwr_dir` are for prebuilding and installing the Flower app and should follow the convention as the `Connector` settings. `sgx_enabled` enables secure execution of the Flower `ClientApp` within an Intel® SGX enclave. When set to `True`, the task runner will launch the client app in an isolated process suitable for enclave execution and handle additional setup required for SGX compatibility (see [Running in Intel® SGX Enclave](#running-in-intel®-sgx-enclave) for details). ```yaml task_runner: defaults: plan/defaults/task_runner.yaml - template: openfl.federated.task.runner_flower.FlowerTaskRunner + template: openfl.federated.task.FlowerTaskRunner + settings : + flwr_app_name : app-pytorch + flwr_dir : save/.flwr + sgx_enabled: False ``` 3. `FlowerDataLoader` with similar high-level functionality to other dataloaders. -4. `Task` - we introduce a `tasks_connector.yaml` that will allow the collaborator to connect to Flower framework via the interop server. It also handles the task runner's `start_client_adapter` method, which actually starts the Flower component and interop server. By setting `local_server_port` to 0, the port is dynamically allocated. This is mainly for local experiments to avoid overlapping the ports. +4. `Task` - we introduce a `tasks_connector.yaml` that will allow the collaborator to connect to Flower framework via the interop server. It also handles the task runner's `start_client_adapter` method, which actually starts the Flower component and interop server. 
In the `settings`, the `interop_server` points to the `FlowerInteropServer` module that will establish connect between the OpenFL client and the Flower `SuperNode`. Like the `SuperLink` the host is set to the local host because OpenFL handles cross network communication. The `interop_server_port` and `clientappio_api_port` are automatically allocated by OpenFL. Setting `local_simulation` to `True` will further offest the ports based on the collaborator names in order to avoid overlapping ports. This is not an issue when collaborators are remote. ```yaml tasks: prepare_for_interop: function: start_client_adapter - kwargs: - interop_server_port: 0 + kwargs: {} settings: - interop_server: src.grpc.connector.flower.interop_server + clientappio_api_port: 59731 + interop_server: openfl.transport.grpc.interop.FlowerInteropServer + interop_server_host: 127.0.0.1 + interop_server_port: 51807 + local_simulation: true + ``` 5.`Collaborator` has an additional setting `interop_mode` which will invoke a callback to prepare the interop server that'll eventually be started by the Task Runner diff --git a/openfl-workspace/flower-app-pytorch/plan/plan.yaml b/openfl-workspace/flower-app-pytorch/plan/plan.yaml index f8d9984e66..a0bf1debf2 100644 --- a/openfl-workspace/flower-app-pytorch/plan/plan.yaml +++ b/openfl-workspace/flower-app-pytorch/plan/plan.yaml @@ -10,18 +10,13 @@ aggregator : write_logs : false connector : - defaults : plan/defaults/connector.yaml - template : src.connector_flower.ConnectorFlower + defaults : plan/defaults/connector_flower.yaml + template : openfl.component.ConnectorFlower settings : - automatic_shutdown : True - superlink_params : - insecure : True - serverappio-api-address : 127.0.0.1:9091 - fleet-api-address : 127.0.0.1:9092 - exec-api-address : 127.0.0.1:9093 - flwr_run_params : - flwr_app_name : "app-pytorch" - federation_name : "local-poc" + exec_api_port : 9093 + flwr_app_name : app-pytorch + federation_name : local-poc + flwr_dir : save/.flwr collaborator : defaults : plan/defaults/collaborator.yaml @@ -31,15 +26,14 @@ collaborator : data_loader : defaults : plan/defaults/data_loader.yaml - template : src.loader.FlowerDataLoader - settings : - collaborator_count : 2 + template : openfl.federated.data.FlowerDataLoader task_runner : defaults : plan/defaults/task_runner.yaml - template : src.runner.FlowerTaskRunner + template : openfl.federated.task.FlowerTaskRunner settings : - flwr_app_name: app-pytorch + flwr_app_name : app-pytorch + flwr_dir : save/.flwr sgx_enabled: False network : @@ -56,9 +50,7 @@ assigner : - prepare_for_interop tasks : - defaults : plan/defaults/tasks_connector.yaml - settings : - interop_server: src.grpc.connector.flower.interop_server + defaults : plan/defaults/tasks_flower.yaml compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml \ No newline at end of file diff --git a/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml index 32a51b068d..64061c7542 100644 --- a/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml +++ b/openfl-workspace/flower-app-pytorch/src/app-pytorch/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.0" description = "" license = "Apache-2.0" dependencies = [ - "flwr-nightly", + "flwr-nightly==1.19.0.dev20250513", "flwr-datasets[vision]>=0.5.0", "torch==2.5.1", "torchvision==0.20.1", diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py 
b/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py deleted file mode 100644 index 035174e6f2..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from src.grpc.connector.utils import get_interop_server diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py b/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py deleted file mode 100644 index 0202346ea9..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -import importlib - -def get_interop_server(framework: str = 'Flower') -> object: - if framework == 'Flower': - try: - module = importlib.import_module('src.grpc.connector.flower.interop_server') - return module.FlowerInteropServer - except ImportError: - print("Flower is not installed.") - return None diff --git a/openfl-workspace/flower-app-pytorch/src/util.py b/openfl-workspace/flower-app-pytorch/src/util.py deleted file mode 100644 index 750eff8f2a..0000000000 --- a/openfl-workspace/flower-app-pytorch/src/util.py +++ /dev/null @@ -1,13 +0,0 @@ -import re - -def is_safe_path(path): - """ - Validate the path to ensure it contains only allowed characters. - - Args: - path (str): The path to validate. - - Returns: - bool: True if the path is safe, False otherwise. - """ - return re.match(r'^[\w\-/\.]+$', path) is not None diff --git a/openfl-workspace/workspace/plan/defaults/connector.yaml b/openfl-workspace/workspace/plan/defaults/connector.yaml deleted file mode 100644 index 2b6645d22b..0000000000 --- a/openfl-workspace/workspace/plan/defaults/connector.yaml +++ /dev/null @@ -1 +0,0 @@ -template : openfl.component.Connector \ No newline at end of file diff --git a/openfl-workspace/workspace/plan/defaults/connector_flower.yaml b/openfl-workspace/workspace/plan/defaults/connector_flower.yaml new file mode 100644 index 0000000000..66a7e5e914 --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/connector_flower.yaml @@ -0,0 +1,8 @@ +template : openfl.component.ConnectorFlower +settings : + automatic_shutdown : True + insecure : True + superlink_host : 127.0.0.1 + serverappio_api_port : auto + fleet_api_port : auto + exec_api_port : auto \ No newline at end of file diff --git a/openfl-workspace/workspace/plan/defaults/network.yaml b/openfl-workspace/workspace/plan/defaults/network.yaml index 654667240e..82372e822c 100644 --- a/openfl-workspace/workspace/plan/defaults/network.yaml +++ b/openfl-workspace/workspace/plan/defaults/network.yaml @@ -7,4 +7,5 @@ settings: client_reconnect_interval : 5 require_client_auth : True cert_folder : cert - enable_atomic_connections : False \ No newline at end of file + enable_atomic_connections : False + transport_protocol : grpc diff --git a/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml b/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml deleted file mode 100644 index 71b4db5bda..0000000000 --- a/openfl-workspace/workspace/plan/defaults/tasks_connector.yaml +++ /dev/null @@ -1,4 +0,0 @@ -prepare_for_interop: - function : start_client_adapter - kwargs : - interop_server_port : 0 # interop server port, 0 to dynamically allocate diff --git a/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml b/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml new file mode 100644 index 0000000000..6f0da7841e --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/tasks_flower.yaml @@ -0,0 +1,11 @@ +prepare_for_interop: + function : start_client_adapter 
+ kwargs : + {} + +settings: + interop_server : openfl.transport.grpc.interop.FlowerInteropServer + interop_server_host : 127.0.0.1 + interop_server_port : auto + clientappio_api_port : auto + local_simulation : True diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 5b0a22c487..d9e91981ae 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -3,6 +3,8 @@ """OpenFL Component Module.""" +from importlib import util + from openfl.component.aggregator.aggregator import Aggregator from openfl.component.aggregator.straggler_handling import ( CutoffTimePolicy, @@ -14,3 +16,6 @@ from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner from openfl.component.collaborator.collaborator import Collaborator + +if util.find_spec("flwr") is not None: + from openfl.component.connector.connector_flower import ConnectorFlower diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index bd3b22df34..2269e1fdc9 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -7,6 +7,7 @@ import importlib import logging from enum import Enum +from os.path import splitext from time import sleep from typing import List, Optional @@ -14,7 +15,7 @@ from openfl.databases import TensorDB from openfl.pipelines import NoCompressionPipeline, TensorCodec from openfl.protocols import utils -from openfl.transport.grpc.aggregator_client import AggregatorGRPCClient +from openfl.transport.grpc.aggregator_client import AggregatorClientInterface from openfl.utilities import TensorKey logger = logging.getLogger(__name__) @@ -64,7 +65,7 @@ def __init__( collaborator_name, aggregator_uuid, federation_uuid, - client: AggregatorGRPCClient, + client: AggregatorClientInterface, task_runner, task_config, opt_treatment="RESET", @@ -395,8 +396,10 @@ def prepare_interop_server(self): """ # Initialize the interop server - framework = self.task_config["settings"]["interop_server"] - module = importlib.import_module(framework) + interop_server_template = self.task_config["settings"]["interop_server"] + interop_server_class = splitext(interop_server_template)[1].strip(".") + interop_server_module_path = splitext(interop_server_template)[0] + interop_server_module = importlib.import_module(interop_server_module_path) def receive_message_from_interop(message): """Receive message from interop server.""" @@ -404,5 +407,11 @@ def receive_message_from_interop(message): response = self.client.send_message_to_server(message, self.collaborator_name) return response - interop_server = module.FlowerInteropServer(receive_message_from_interop) + interop_server = getattr(interop_server_module, interop_server_class)( + receive_message_from_interop + ) + # Pass all keys in self.task_config['settings'] through to prepare_for_interop kwargs + self.task_config["prepare_for_interop"]["kwargs"].update( + self.task_config.get("settings", {}) + ) self.task_config["prepare_for_interop"]["kwargs"]["interop_server"] = interop_server diff --git a/openfl/component/connector/__init__.py b/openfl/component/connector/__init__.py new file mode 100644 index 0000000000..482e80e33c --- /dev/null +++ b/openfl/component/connector/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""OpenFL Connector Module.""" + +from importlib import util + +if 
util.find_spec("flwr") is not None: + from openfl.component.connector.connector_flower import ConnectorFlower diff --git a/openfl-workspace/flower-app-pytorch/src/connector_flower.py b/openfl/component/connector/connector_flower.py similarity index 53% rename from openfl-workspace/flower-app-pytorch/src/connector_flower.py rename to openfl/component/connector/connector_flower.py index afe7057201..24a8d35cb2 100644 --- a/openfl-workspace/flower-app-pytorch/src/connector_flower.py +++ b/openfl/component/connector/connector_flower.py @@ -1,22 +1,19 @@ -from logging import getLogger -logger = getLogger(__name__) +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 -import psutil +import os +import signal import subprocess import sys -import signal +from logging import getLogger -from src.grpc.connector.flower.interop_client import FlowerInteropClient -from src.util import is_safe_path +import psutil -import os +from openfl.transport.grpc.interop import FlowerInteropClient +from openfl.utilities.path_check import is_directory_traversal -flwr_home = os.path.join(os.getcwd(), "save/.flwr") -if not is_safe_path(flwr_home): - raise ValueError("Invalid path for FLWR_HOME") +logger = getLogger(__name__) -os.environ["FLWR_HOME"] = flwr_home -os.makedirs(os.environ["FLWR_HOME"], exist_ok=True) class ConnectorFlower: """ @@ -24,30 +21,66 @@ class ConnectorFlower: This class is responsible for constructing and managing the execution of Flower server commands. """ - def __init__(self, - superlink_params: dict, - flwr_run_params: dict = None, - automatic_shutdown: bool = True, - **kwargs): + def __init__( + self, + superlink_host, + fleet_api_port, + exec_api_port, + serverappio_api_port, + insecure=True, + flwr_app_name=None, + federation_name=None, + automatic_shutdown=True, + flwr_dir=None, + **kwargs, + ): """ Initialize the ConnectorFlower instance by setting up the necessary server commands. Args: - superlink_params (dict): Configuration settings for the Flower server. - flwr_run_params (dict, optional): Parameters for running the Flower application. - automatic_shutdown (bool, optional): Flag to enable automatic shutdown of the server. Defaults to True. + superlink_host (str): Host address for the Flower SuperLink. + fleet_api_port (int): Port for the fleet API. + exec_api_port (int): Port for the exec API. + serverappio_api_port (int): Port for the serverappio API. + insecure (bool): Whether to use insecure connections. Defaults to True. + flwr_app_name (str, optional): Name of the Flower application to run. Defaults to None. + federation_name (str, optional): Name of the federation. Defaults to None. + automatic_shutdown (bool, optional): Whether to enable automatic shutdown. + Defaults to True. + flwr_dir (str, optional): Directory for Flower app within the OpenFL workspace. + Plan.yaml configuration defaults to `save/.flwr` **kwargs: Additional keyword arguments. 
""" super().__init__() self._process = None + self.flwr_dir = flwr_dir + if is_directory_traversal(self.flwr_dir): + logger.error("Flower app directory path is out of the OpenFL workspace scope.") + sys.exit(1) + else: + os.makedirs(self.flwr_dir, exist_ok=True) + os.environ["FLWR_HOME"] = self.flwr_dir + self.automatic_shutdown = automatic_shutdown self.signal_shutdown_sent = False - self.superlink_params = superlink_params + self.superlink_params = { + "insecure": insecure, + "exec_api_port": exec_api_port, + "fleet_api_port": fleet_api_port, + "serverappio_api_port": serverappio_api_port, + } + self.superlink_host = superlink_host self.flwr_superlink_command = self._build_flwr_superlink_command() - self.flwr_run_params = flwr_run_params + if flwr_app_name is None or federation_name is None: + self.flwr_run_params = None + else: + self.flwr_run_params = { + "flwr_app_name": flwr_app_name, + "federation_name": federation_name, + } self.flwr_run_command = self._build_flwr_run_command() if self.flwr_run_params else None self.interop_client = None @@ -55,12 +88,14 @@ def __init__(self, def get_interop_client(self): """ - Create and return a LocalGRPCClient instance using the superlink parameters. + Create and return a FlowerInteropClient instance using the superlink parameters. Returns: - LocalGRPCClient: An instance configured with the connector address and server rounds. + FlowerInteropClient: An instance configured with the connector address + and server rounds. """ - connector_address = self.superlink_params.get("fleet-api-address", "0.0.0.0:9092") + connector_port = self.superlink_params.get("fleet_api_port") + connector_address = f"{self.superlink_host}:{connector_port}" self.interop_client = FlowerInteropClient(connector_address, self.automatic_shutdown) return self.interop_client @@ -74,20 +109,20 @@ def _build_flwr_superlink_command(self) -> list[str]: command = ["flower-superlink", "--fleet-api-type", "grpc-adapter"] - if "insecure" in self.superlink_params and self.superlink_params["insecure"]: + if self.superlink_params.get("insecure"): command += ["--insecure"] - if "serverappio-api-address" in self.superlink_params: - command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])] - # flwr default: 0.0.0.0:9091 + serverappio_api_port = self.superlink_params.get("serverappio_api_port") + serverappio_api_address = f"{self.superlink_host}:{serverappio_api_port}" + command += ["--serverappio-api-address", serverappio_api_address] - if "fleet-api-address" in self.superlink_params: - command += ["--fleet-api-address", str(self.superlink_params["fleet-api-address"])] - # flwr default: 0.0.0.0:9092 + fleet_api_port = self.superlink_params.get("fleet_api_port") + fleet_api_address = f"{self.superlink_host}:{fleet_api_port}" + command += ["--fleet-api-address", fleet_api_address] - if "exec-api-address" in self.superlink_params: - command += ["--exec-api-address", str(self.superlink_params["exec-api-address"])] - # flwr default: 0.0.0.0:9093 + exec_api_port = self.superlink_params.get("exec_api_port") + exec_api_address = f"{self.superlink_host}:{exec_api_port}" + command += ["--exec-api-address", exec_api_address] if self.automatic_shutdown: command += ["--isolation", "process"] @@ -105,11 +140,12 @@ def _build_flwr_serverapp_command(self) -> list[str]: """ command = ["flwr-serverapp", "--run-once"] - if "insecure" in self.superlink_params and self.superlink_params["insecure"]: + if self.superlink_params["insecure"]: command += ["--insecure"] - if 
"serverappio-api-address" in self.superlink_params: - command += ["--serverappio-api-address", str(self.superlink_params["serverappio-api-address"])] + serverappio_api_port = self.superlink_params["serverappio_api_port"] + serverappio_api_address = f"{self.superlink_host}:{serverappio_api_port}" + command += ["--serverappio-api-address", serverappio_api_address] return command @@ -120,7 +156,7 @@ def is_flwr_serverapp_running(self): Returns: bool: True if the ServerApp is running, False otherwise. """ - if not hasattr(self, 'flwr_serverapp_subprocess'): + if not hasattr(self, "flwr_serverapp_subprocess"): logger.debug("[OpenFL Connector] ServerApp was never started.") return False @@ -130,13 +166,19 @@ def is_flwr_serverapp_running(self): if not self.signal_shutdown_sent: self.signal_shutdown_sent = True - logger.info("[OpenFL Connector] Experiment has ended. Sending signal to shut down Flower components.") + logger.info( + "[OpenFL Connector] Experiment has ended. Sending signal " + "to shut down Flower components." + ) return False def _stop_flwr_serverapp(self): """Terminate the `flwr_serverapp` subprocess if it is still active.""" - if hasattr(self, 'flwr_serverapp_subprocess') and self.flwr_serverapp_subprocess.poll() is None: + if ( + hasattr(self, "flwr_serverapp_subprocess") + and self.flwr_serverapp_subprocess.poll() is None + ): logger.debug("[OpenFL Connector] ServerApp still running. Stopping...") self.flwr_serverapp_subprocess.terminate() try: @@ -162,20 +204,35 @@ def _build_flwr_run_command(self) -> list[str]: return command def start(self): - """Launch the `flower-superlink` and `flwr run` subprocesses using the constructed commands.""" + """ + Launch the `flower-superlink` and `flwr run` subprocesses + using the constructed commands. + """ if self._process is None: - logger.info(f"[OpenFL Connector] Starting server process: {' '.join(self.flwr_superlink_command)}") + logger.info( + f"[OpenFL Connector] Starting server process: " + f"{' '.join(self.flwr_superlink_command)}" + ) self._process = subprocess.Popen(self.flwr_superlink_command) logger.info(f"[OpenFL Connector] Server process started with PID: {self._process.pid}") else: logger.info("[OpenFL Connector] Server process is already running.") - if hasattr(self, 'flwr_run_command') and self.flwr_run_command: - logger.info(f"[OpenFL Connector] Starting `flwr run` subprocess: {' '.join(self.flwr_run_command)}") + if hasattr(self, "flwr_run_command") and self.flwr_run_command: + logger.info( + f"[OpenFL Connector] Starting `flwr run` " + f"subprocess: {' '.join(self.flwr_run_command)}" + ) subprocess.run(self.flwr_run_command) - if hasattr(self, 'flwr_serverapp_command') and self.flwr_serverapp_command: - self.interop_client.set_is_flwr_serverapp_running_callback(self.is_flwr_serverapp_running) + if hasattr(self, "flwr_serverapp_command") and self.flwr_serverapp_command: + logger.info( + f"[OpenFL Connector] Starting server app subprocess: " + f"{' '.join(self.flwr_serverapp_command)}" + ) + self.interop_client.set_is_flwr_serverapp_running_callback( + self.is_flwr_serverapp_running + ) self.flwr_serverapp_subprocess = subprocess.Popen(self.flwr_serverapp_command) def stop(self): @@ -183,11 +240,18 @@ def stop(self): self._stop_flwr_serverapp() if self._process: try: - logger.info(f"[OpenFL Connector] Stopping server process with PID: {self._process.pid}...") + logger.info( + f"[OpenFL Connector] Stopping server process with PID: {self._process.pid}..." 
+ ) main_process = psutil.Process(self._process.pid) sub_processes = main_process.children(recursive=True) for sub_process in sub_processes: - logger.info(f"[OpenFL Connector] Stopping server subprocess with PID: {sub_process.pid}...") + logger.info( + ( + f"[OpenFL Connector] Stopping server subprocess " + f"with PID: {sub_process.pid}..." + ) + ) sub_process.terminate() _, still_alive = psutil.wait_procs(sub_processes, timeout=1) for p in still_alive: diff --git a/openfl/experimental/workflow/interface/fl_spec.py b/openfl/experimental/workflow/interface/fl_spec.py index f0713a8497..ac25c5a692 100644 --- a/openfl/experimental/workflow/interface/fl_spec.py +++ b/openfl/experimental/workflow/interface/fl_spec.py @@ -170,13 +170,14 @@ def _setup_initial_state(self) -> None: """ self._metaflow_interface = MetaflowInterface(self.__class__, self.runtime.backend) self._run_id = self._metaflow_interface.create_run() - # Initialize aggregator private attributes - self.runtime.initialize_aggregator() - self._foreach_methods = [] FLSpec._reset_clones() FLSpec._create_clones(self, self.runtime.collaborators) - # Initialize collaborator private attributes + + # Initialize participant private attributes + self.runtime.initialize_aggregator() self.runtime.initialize_collaborators() + self._foreach_methods = [] + if self._checkpoint: print(f"Created flow {self.__class__.__name__}") diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py index a8b443c059..e17857d330 100644 --- a/openfl/federated/__init__.py +++ b/openfl/federated/__init__.py @@ -26,6 +26,9 @@ if util.find_spec("xgboost") is not None: from openfl.federated.data import XGBoostDataLoader from openfl.federated.task import XGBoostTaskRunner +if util.find_spec("flwr") is not None: + from openfl.federated.data import FlowerDataLoader + from openfl.federated.task import FlowerTaskRunner __all__ = [ "Plan", diff --git a/openfl/federated/data/__init__.py b/openfl/federated/data/__init__.py index 53e56a7f7d..29667f7b23 100644 --- a/openfl/federated/data/__init__.py +++ b/openfl/federated/data/__init__.py @@ -16,3 +16,6 @@ if util.find_spec("xgboost") is not None: from openfl.federated.data.loader_xgb import XGBoostDataLoader # NOQA + +if util.find_spec("flwr") is not None: + from openfl.federated.data.loader_flower import FlowerDataLoader # NOQA diff --git a/openfl-workspace/flower-app-pytorch/src/loader.py b/openfl/federated/data/loader_flower.py similarity index 96% rename from openfl-workspace/flower-app-pytorch/src/loader.py rename to openfl/federated/data/loader_flower.py index 0b63f60af0..1a4305b198 100644 --- a/openfl-workspace/flower-app-pytorch/src/loader.py +++ b/openfl/federated/data/loader_flower.py @@ -1,11 +1,12 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """FlowerDataLoader module.""" -from openfl.federated.data.loader import DataLoader import os +from openfl.federated.data.loader import DataLoader + class FlowerDataLoader(DataLoader): """Flower Dataloader @@ -25,7 +26,7 @@ def __init__(self, data_path, **kwargs): Raises: FileNotFoundError: If the specified data path does not exist. 
- """ + """ super().__init__(**kwargs) if not os.path.exists(data_path): raise FileNotFoundError(f"The specified data path does not exist: {data_path}") diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 9e0e0a96d3..ec76e34f5e 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -15,8 +15,13 @@ from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage from openfl.interface.cli_helper import WORKSPACE -from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer -from openfl.utilities.utils import getfqdn_env +from openfl.transport import ( + AggregatorGRPCClient, + AggregatorGRPCServer, + AggregatorRESTClient, + AggregatorRESTServer, +) +from openfl.utilities.utils import generate_port, getfqdn_env SETTINGS = "settings" TEMPLATE = "template" @@ -312,9 +317,18 @@ def resolve(self): self.config["network"][SETTINGS]["agg_addr"] = getfqdn_env() if self.config["network"][SETTINGS]["agg_port"] == AUTO: - self.config["network"][SETTINGS]["agg_port"] = ( - int(self.hash[:8], 16) % (60999 - 49152) + 49152 - ) + self.config["network"][SETTINGS]["agg_port"] = generate_port(self.hash) + + if "connector" in self.config: + # automatically generate ports for Flower interoperability components + # if they are set to AUTO + for key, value in self.config["connector"][SETTINGS].items(): + if value == AUTO: + self.config["connector"][SETTINGS][key] = generate_port(self.hash) + + for key, value in self.config["tasks"][SETTINGS].items(): + if value == AUTO: + self.config["tasks"][SETTINGS][key] = generate_port(self.hash) def get_assigner(self): """Get the plan task assigner.""" @@ -537,8 +551,6 @@ def get_collaborator( else: defaults[SETTINGS]["client"] = self.get_client( collaborator_name, - self.aggregator_uuid, - self.federation_uuid, root_certificate, private_key, certificate, @@ -552,13 +564,11 @@ def get_collaborator( def get_client( self, collaborator_name, - aggregator_uuid, - federation_uuid, root_certificate=None, private_key=None, certificate=None, ): - """Get gRPC client for the specified collaborator. + """Get gRPC or REST client for the specified collaborator. Args: collaborator_name (str): Name of the collaborator. @@ -572,8 +582,38 @@ def get_client( Defaults to None. Returns: - AggregatorGRPCClient: gRPC client for the specified collaborator. + AggregatorGRPCClient or AggregatorRESTClient: gRPC or REST client for the collaborator. 
""" + client_args = self.get_client_args( + collaborator_name, + root_certificate, + private_key, + certificate, + ) + network_cfg = self.config["network"][SETTINGS] + protocol = network_cfg.get("transport_protocol", "grpc").lower() + + if self.client_ is None: + self.client_ = self._get_client(protocol, **client_args) + + return self.client_ + + def _get_client(self, protocol, **kwargs): + if protocol == "rest": + client = AggregatorRESTClient(**kwargs) + elif protocol == "grpc": + client = AggregatorGRPCClient(**kwargs) + else: + raise ValueError(f"Unsupported transport_protocol '{protocol}'") + return client + + def get_client_args( + self, + collaborator_name, + root_certificate=None, + private_key=None, + certificate=None, + ): common_name = collaborator_name if not root_certificate or not private_key or not certificate: root_certificate = "cert/cert_chain.crt" @@ -588,14 +628,10 @@ def get_client( client_args["certificate"] = certificate client_args["private_key"] = private_key - client_args["aggregator_uuid"] = aggregator_uuid - client_args["federation_uuid"] = federation_uuid + client_args["aggregator_uuid"] = self.aggregator_uuid + client_args["federation_uuid"] = self.federation_uuid client_args["collaborator_name"] = collaborator_name - - if self.client_ is None: - self.client_ = AggregatorGRPCClient(**client_args) - - return self.client_ + return client_args def get_server( self, @@ -604,7 +640,7 @@ def get_server( certificate=None, **kwargs, ): - """Get gRPC server of the aggregator instance. + """Get gRPC or REST server of the aggregator instance. Args: root_certificate (str, optional): Root certificate for the server. @@ -616,8 +652,29 @@ def get_server( **kwargs: Additional keyword arguments. Returns: - AggregatorGRPCServer: gRPC server of the aggregator instance. + Aggregator Server: returns either gRPC or REST server of the aggregator instance. """ + server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs) + + server_args["aggregator"] = self.get_aggregator() + network_cfg = self.config["network"][SETTINGS] + protocol = network_cfg.get("transport_protocol", "grpc").lower() + + if self.server_ is None: + self.server_ = self._get_server(protocol, **server_args) + + return self.server_ + + def _get_server(self, protocol, **kwargs): + if protocol == "rest": + server = AggregatorRESTServer(**kwargs) + elif protocol == "grpc": + server = AggregatorGRPCServer(**kwargs) + else: + raise ValueError(f"Unsupported transport_protocol '{protocol}'") + return server + + def get_server_args(self, root_certificate, private_key, certificate, kwargs): common_name = self.config["network"][SETTINGS]["agg_addr"].lower() if not root_certificate or not private_key or not certificate: @@ -633,13 +690,7 @@ def get_server( server_args["root_certificate"] = root_certificate server_args["certificate"] = certificate server_args["private_key"] = private_key - - server_args["aggregator"] = self.get_aggregator() - - if self.server_ is None: - self.server_ = AggregatorGRPCServer(**server_args) - - return self.server_ + return server_args def save_model_to_state_file(self, tensor_dict, round_number, output_path): """Save model weights to a protobuf state file. 
diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index 7d1d7dfaeb..1763b3c54d 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -14,3 +14,5 @@ from openfl.federated.task.runner_pt import PyTorchTaskRunner # NOQA if util.find_spec("xgboost") is not None: from openfl.federated.task.runner_xgb import XGBoostTaskRunner # NOQA +if util.find_spec("flwr") is not None: + from openfl.federated.task.runner_flower import FlowerTaskRunner # NOQA diff --git a/openfl-workspace/flower-app-pytorch/src/runner.py b/openfl/federated/task/runner_flower.py similarity index 59% rename from openfl-workspace/flower-app-pytorch/src/runner.py rename to openfl/federated/task/runner_flower.py index 9fdbd8d619..dc4f4a598c 100644 --- a/openfl-workspace/flower-app-pytorch/src/runner.py +++ b/openfl/federated/task/runner_flower.py @@ -1,19 +1,23 @@ -from openfl.federated.task.runner import TaskRunner +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import logging +import os +import socket import subprocess -from logging import getLogger +import sys import time -import os -import numpy as np from pathlib import Path -import socket -from src.util import is_safe_path -flwr_home = os.path.join(os.getcwd(), "save/.flwr") -if not is_safe_path(flwr_home): - raise ValueError("Invalid path for FLWR_HOME") +import numpy as np + +from openfl.federated.task.runner import TaskRunner +from openfl.utilities.path_check import is_directory_traversal +from openfl.utilities.utils import generate_port + +logger = logging.getLogger(__name__) -os.environ["FLWR_HOME"] = flwr_home -os.makedirs(os.environ["FLWR_HOME"], exist_ok=True) class FlowerTaskRunner(TaskRunner): """ @@ -24,6 +28,7 @@ class FlowerTaskRunner(TaskRunner): in a subprocess. It provides options for both manual and automatic shutdown based on subprocess activity. """ + def __init__(self, **kwargs): """ Initialize the FlowerTaskRunner. @@ -33,29 +38,28 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) + self.flwr_dir = kwargs.get("flwr_dir") + if is_directory_traversal(self.flwr_dir): + logger.error("Flower app directory path is out of the OpenFL workspace scope.") + sys.exit(1) + else: + os.makedirs(self.flwr_dir, exist_ok=True) + os.environ["FLWR_HOME"] = self.flwr_dir + if self.data_loader is None: - flwr_app_name = kwargs.get('flwr_app_name') + flwr_app_name = kwargs.get("flwr_app_name") install_flower_FAB(flwr_app_name) return - self.sgx_enabled = kwargs.get('sgx_enabled') + self.sgx_enabled = kwargs.get("sgx_enabled") self.model = None - self.logger = getLogger(__name__) self.data_path = self.data_loader.get_node_configs() - self.client_port = kwargs.get('client_port') - if self.client_port is None: - self.client_port = get_dynamic_port() - self.shutdown_requested = False # Flag to signal shutdown - def start_client_adapter(self, - col_name=None, - round_num=None, - input_tensor_dict=None, - **kwargs): + def start_client_adapter(self, col_name=None, round_num=None, input_tensor_dict=None, **kwargs): """ Start the FlowerInteropServer and the Flower SuperNode. @@ -66,27 +70,43 @@ def start_client_adapter(self, **kwargs: Additional parameters for configuration. includes: interop_server (object): The FlowerInteropServer instance. + interop_server_host (str): The address of the interop server. + clientappio_api_port (int): The port for the clientappio API. + local_simulation (bool): Flag for local simulation to dynamically adjust ports. 
interop_server_port (int): The port for the interop server. """ def message_callback(): self.shutdown_requested = True - interop_server = kwargs.get('interop_server') - interop_server_port = kwargs.get('interop_server_port') - interop_server.set_end_experiment_callback(message_callback) - interop_server.start_server(interop_server_port) + interop_server = kwargs.get("interop_server") + interop_server_host = kwargs.get("interop_server_host") + interop_server_port = kwargs.get("interop_server_port") + clientappio_api_port = kwargs.get("clientappio_api_port") + + if kwargs.get("local_simulation"): + # Dynamically adjust ports for local simulation + logger.info(f"Adjusting ports for local simulation: {col_name}") + + interop_server_port = get_dynamic_port(interop_server_port, col_name) + clientappio_api_port = get_dynamic_port(clientappio_api_port, col_name) - # interop server sets port dynamically - interop_server_port = interop_server.get_port() + logger.info(f"Adjusted interop_server_port: {interop_server_port}") + logger.info(f"Adjusted clientappio_api_port: {clientappio_api_port}") + + interop_server.set_end_experiment_callback(message_callback) + interop_server.start_server(interop_server_host, interop_server_port) command = [ "flower-supernode", "--insecure", "--grpc-adapter", - "--superlink", f"127.0.0.1:{interop_server_port}", - "--clientappio-api-address", f"127.0.0.1:{self.client_port}", - "--node-config", f"data-path='{self.data_path}'" + "--superlink", + f"{interop_server_host}:{interop_server_port}", + "--clientappio-api-address", + f"{interop_server_host}:{clientappio_api_port}", + "--node-config", + f"data-path='{self.data_path}'", ] if self.sgx_enabled: @@ -94,34 +114,35 @@ def message_callback(): flwr_clientapp_command = [ "flwr-clientapp", "--insecure", - "--clientappio-api-address", f"127.0.0.1:{self.client_port}", + "--clientappio-api-address", + f"{interop_server_host}:{clientappio_api_port}", ] - self.logger.info("Starting Flower SuperNode process...") + logger.info("Starting Flower SuperNode process...") supernode_process = subprocess.Popen(command, shell=False) interop_server.handle_signals(supernode_process) if self.sgx_enabled: # Check if port is open before starting the client app - while not is_port_open('127.0.0.1', interop_server_port): + while not is_port_open(interop_server_host, interop_server_port): time.sleep(0.5) - time.sleep(1) # Add a small delay after confirming the port is open + time.sleep(1) # Add a small delay after confirming the port is open - self.logger.info("Starting Flower ClientApp process...") + logger.info("Starting Flower ClientApp process...") flwr_clientapp_process = subprocess.Popen(flwr_clientapp_command, shell=False) interop_server.handle_signals(flwr_clientapp_process) - self.logger.info("Press CTRL+C to stop the server and SuperNode process.") + logger.info("Press CTRL+C to stop the server and SuperNode process.") while not interop_server.termination_event.is_set(): if self.shutdown_requested: if self.sgx_enabled: - self.logger.info("Terminating Flower ClientApp process...") + logger.info("Terminating Flower ClientApp process...") interop_server.terminate_supernode_process(flwr_clientapp_process) flwr_clientapp_process.wait() - self.logger.info("Shutting down the server and SuperNode process...") + logger.info("Shutting down the server and SuperNode process...") interop_server.terminate_supernode_process(supernode_process) interop_server.stop_server() time.sleep(0.1) @@ -133,8 +154,6 @@ def message_callback(): return 
global_output_tensor_dict, local_output_tensor_dict - - def set_tensor_dict(self, tensor_dict, with_opt_vars=False): """ Set the tensor dictionary for the task runner. @@ -169,7 +188,7 @@ def save_native(self, filepath, **kwargs): if isinstance(filepath, Path): filepath = str(filepath) - assert filepath.endswith('.npz'), "Currently, only '.npz' file type is supported." + assert filepath.endswith(".npz"), "Currently, only '.npz' file type is supported." # Save the tensor dictionary to a .npz file np.savez(filepath, **self.tensor_dict) @@ -182,54 +201,43 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): """Get tensor keys for functions. Return empty dict.""" return {} + def install_flower_FAB(flwr_app_name): """ - Build and install the patch for the Flower application. + Build and install Flower application. Args: - flwr_app_name (str): The name of the Flower application to patch. + flwr_app_name (str): The name of the Flower application. """ - flwr_dir = os.environ["FLWR_HOME"] - - # Change the current working directory to the Flower directory - os.chdir(flwr_dir) - # Run the build command - build_command = [ - "flwr", - "build", - "--app", - os.path.join("..", "..", "src", flwr_app_name) - ] + build_command = ["flwr", "build", "--app", os.path.join("src", flwr_app_name)] subprocess.check_call(build_command) # List .fab files after running the build command - fab_files = list(Path(flwr_dir).glob("*.fab")) + fab_files = list(Path.cwd().glob("*.fab")) # Determine the newest .fab file newest_fab_file = max(fab_files, key=os.path.getmtime) # Run the install command using the newest .fab file - subprocess.check_call([ - "flwr", - "install", - str(newest_fab_file) - ]) + install_command = ["flwr", "install", str(newest_fab_file)] + subprocess.check_call(install_command) + os.remove(newest_fab_file) + -def get_dynamic_port(): +def get_dynamic_port(base_port, collaborator_name): """ - Get a dynamically assigned port number. + Get a dynamically assigned port number based on collaborator name and base port. + This is only necessary for local simulation in order to avoid port conflicts. Returns: - int: An available port number assigned by the operating system. + int: The dynamically assigned port number. 
""" - # Create a socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # Bind to port 0 to let the OS assign an available port - s.bind(('127.0.0.1', 0)) - # Get the assigned port number - port = s.getsockname()[1] - return port + combined_string = f"{base_port}--{collaborator_name}" + hash_object = hashlib.md5(combined_string.encode()) + hash_value = hash_object.hexdigest() + return generate_port(hash_value) + def is_port_open(host, port): """Check if a port is open on the given host.""" diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 043216d9f2..16dc48e9bf 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -92,8 +92,8 @@ def start_(plan, authorized_cols, task_group): logger.info(f"Setting aggregator to assign: {task_group} task_group") logger.info("🧿 Starting the Aggregator Service.") - - parsed_plan.get_server().serve() + server = parsed_plan.get_server() + server.serve() @aggregator.command(name="generate-cert-request") diff --git a/openfl/protocols/aggregator_client_interface.py b/openfl/protocols/aggregator_client_interface.py new file mode 100644 index 0000000000..e1b8d02d3c --- /dev/null +++ b/openfl/protocols/aggregator_client_interface.py @@ -0,0 +1,69 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorClientInterface module.""" + +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + + +class AggregatorClientInterface(ABC): + @abstractmethod + def ping(self): + """ + Ping the aggregator to check connectivity. + """ + pass + + @abstractmethod + def get_tasks(self) -> Tuple[List[Any], int, int, bool]: + """ + Retrieves tasks for the given collaborator client. + Returns a tuple: (tasks, round_number, sleep_time, time_to_quit) + """ + pass + + @abstractmethod + def get_aggregated_tensors( + self, + tensor_keys, + require_lossless: bool = True, + ) -> Any: + """ + Retrieves the aggregated tensor. + """ + pass + + @abstractmethod + def send_local_task_results( + self, + round_number: int, + task_name: str, + data_size: int, + named_tensors: List[Any], + ) -> Any: + """ + Sends local task results. + Parameters: + collaborator_name: Name of the collaborator. + round_number: The current round. + task_name: Name of the task. + data_size: Size of the data. + named_tensors: A list of tensors (or named tensor objects). + Returns a SendLocalTaskResultsResponse. + """ + pass + + @abstractmethod + def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any: + """ + Forwards a converted message from the local client to the OpenFL server and returns the + response. + Args: + openfl_message: The converted message to be sent to the OpenFL server (InteropMessage + proto). + collaborator_name: The name of the collaborator. + Returns: + The response from the OpenFL server (InteropMessage proto). 
+ """ + pass diff --git a/openfl/transport/__init__.py b/openfl/transport/__init__.py index 72bc7864c8..b757223351 100644 --- a/openfl/transport/__init__.py +++ b/openfl/transport/__init__.py @@ -1,5 +1,6 @@ -# Copyright 2020-2024 Intel Corporation +# Copyright 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer +from openfl.transport.rest import AggregatorRESTClient, AggregatorRESTServer diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index cb4a1b18f5..8698044eeb 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -11,6 +11,7 @@ import grpc from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, base_pb2, utils +from openfl.protocols.aggregator_client_interface import AggregatorClientInterface from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel logger = logging.getLogger(__name__) @@ -165,9 +166,11 @@ def wrapper(self, *args, **kwargs): return wrapper -class AggregatorGRPCClient: +class AggregatorGRPCClient(AggregatorClientInterface): """Collaborator-side gRPC client that talks to the aggregator. + This class implements a gRPC client for communicating with an aggregator. + Attributes: agg_addr (str): Aggregator address. agg_port (int): Aggregator port. diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index ebe89f0ca6..ccf5ce410d 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -224,12 +224,6 @@ def GetAggregatedTensors(self, request, context): aggregator_pb2.GetAggregatedTensorsResponse: The response to the request, containing the aggregated tensors as list of `NamedTensor`s. 
""" - if self.interop_mode: - context.abort( - grpc.StatusCode.UNIMPLEMENTED, - "This method is not available in framework interoperability mode.", - ) - self.validate_collaborator(request, context) self.check_request(request) diff --git a/openfl/transport/grpc/interop/__init__.py b/openfl/transport/grpc/interop/__init__.py new file mode 100644 index 0000000000..d481116591 --- /dev/null +++ b/openfl/transport/grpc/interop/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from importlib import util + +if util.find_spec("flwr") is not None: + from openfl.transport.grpc.interop.flower.interop_client import FlowerInteropClient + from openfl.transport.grpc.interop.flower.interop_server import FlowerInteropServer diff --git a/openfl/transport/grpc/interop/flower/__init__.py b/openfl/transport/grpc/interop/flower/__init__.py new file mode 100644 index 0000000000..d481116591 --- /dev/null +++ b/openfl/transport/grpc/interop/flower/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from importlib import util + +if util.find_spec("flwr") is not None: + from openfl.transport.grpc.interop.flower.interop_client import FlowerInteropClient + from openfl.transport.grpc.interop.flower.interop_server import FlowerInteropServer diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py b/openfl/transport/grpc/interop/flower/interop_client.py similarity index 82% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py rename to openfl/transport/grpc/interop/flower/interop_client.py index 7159baf9e7..1f21ddbe38 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_client.py +++ b/openfl/transport/grpc/interop/flower/interop_client.py @@ -1,7 +1,14 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import grpc from flwr.proto import grpcadapter_pb2_grpc -from src.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message -from logging import getLogger + +from openfl.transport.grpc.interop.flower.message_conversion import ( + flower_to_openfl_message, + openfl_to_flower_message, +) + class FlowerInteropClient: """ @@ -9,6 +16,7 @@ class FlowerInteropClient: and the OpenFL Server. It converts messages between OpenFL and Flower formats and handles the send-receive communication with the Flower SuperNode using gRPC. """ + def __init__(self, superlink_address, automatic_shutdown=False): """ Initialize. 
@@ -23,8 +31,6 @@ def __init__(self, superlink_address, automatic_shutdown=False): self.end_experiment = False self.is_flwr_serverapp_running_callback = None - self.logger = getLogger(__name__) - def set_is_flwr_serverapp_running_callback(self, is_flwr_serverapp_running_callback): self.is_flwr_serverapp_running_callback = is_flwr_serverapp_running_callback @@ -47,8 +53,8 @@ def send_receive(self, openfl_message, header): # then the experiment has completed self.end_experiment = not self.is_flwr_serverapp_running_callback() - openfl_response = flower_to_openfl_message(flower_response, - header=header, - end_experiment=self.end_experiment) + openfl_response = flower_to_openfl_message( + flower_response, header=header, end_experiment=self.end_experiment + ) return openfl_response diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py b/openfl/transport/grpc/interop/flower/interop_server.py similarity index 84% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py rename to openfl/transport/grpc/interop/flower/interop_server.py index 16f104b576..732c5c8341 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/interop_server.py +++ b/openfl/transport/grpc/interop/flower/interop_server.py @@ -1,14 +1,22 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import logging -import threading import queue -import grpc +import signal +import threading +import time from concurrent.futures import ThreadPoolExecutor -from flwr.proto import grpcadapter_pb2_grpc -from src.grpc.connector.flower.message_conversion import flower_to_openfl_message, openfl_to_flower_message from multiprocessing import cpu_count -import signal + +import grpc import psutil -import time +from flwr.proto import grpcadapter_pb2_grpc + +from openfl.transport.grpc.interop.flower.message_conversion import ( + flower_to_openfl_message, + openfl_to_flower_message, +) logger = logging.getLogger(__name__) @@ -26,7 +34,8 @@ def __init__(self, send_message_to_client): Initialize. Args: - send_message_to_client (Callable): A callable function to send messages to the OpenFL client. + send_message_to_client (Callable): A callable function to send messages + to the OpenFL client. """ self.send_message_to_client = send_message_to_client self.end_experiment_callback = None @@ -41,18 +50,14 @@ def __init__(self, send_message_to_client): def set_end_experiment_callback(self, callback): self.end_experiment_callback = callback - def start_server(self, local_server_port): + def start_server(self, interop_server_host, interop_server_port): """Starts the gRPC server.""" self.server = grpc.server(ThreadPoolExecutor(max_workers=cpu_count())) grpcadapter_pb2_grpc.add_GrpcAdapterServicer_to_server(self, self.server) - self.port = self.server.add_insecure_port(f'[::]:{local_server_port}') + self.port = self.server.add_insecure_port(f"{interop_server_host}:{interop_server_port}") self.server.start() logger.info(f"OpenFL local gRPC server started, listening on port {self.port}.") - def get_port(self): - # Return the port that was assigned - return self.port - def stop_server(self): """Stops the gRPC server.""" if self.server: @@ -62,7 +67,10 @@ def stop_server(self): self.termination_event.set() def SendReceive(self, request, context): - """ Handles incoming gRPC requests by putting them into the request queue and waiting for the response. 
+ """ + Handles incoming gRPC requests by putting them into the request + queue and waiting for the response. + Args: request: The incoming gRPC request. context: The gRPC context. @@ -87,8 +95,8 @@ def process_queue(self): openfl_response = self.send_message_to_client(openfl_request) # Check to end experiment - if hasattr(openfl_response, 'metadata'): - if openfl_response.metadata['end_experiment'] == 'True': + if hasattr(openfl_response, "metadata"): + if openfl_response.metadata["end_experiment"] == "True": self.end_experiment_callback() # Send response to Flower client @@ -98,6 +106,7 @@ def process_queue(self): def handle_signals(self, supernode_process): """Sets up signal handlers for graceful shutdown.""" + def signal_handler(_sig, _frame): self.terminate_supernode_process(supernode_process) self.stop_server() @@ -132,7 +141,10 @@ def terminate_process(self, process, timeout=5): process.terminate() process.wait(timeout=timeout) except psutil.TimeoutExpired: - logger.debug(f"Timeout expired while waiting for process {process.pid} to terminate. Killing the process.") + logger.debug( + f"Timeout expired while waiting for process {process.pid} " + "to terminate. Killing the process." + ) process.kill() except psutil.NoSuchProcess: logger.debug(f"Process {process.pid} does not exist. Skipping.") diff --git a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py b/openfl/transport/grpc/interop/flower/message_conversion.py similarity index 59% rename from openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py rename to openfl/transport/grpc/interop/flower/message_conversion.py index d900b83cf0..d46998f526 100644 --- a/openfl-workspace/flower-app-pytorch/src/grpc/connector/flower/message_conversion.py +++ b/openfl/transport/grpc/interop/flower/message_conversion.py @@ -1,9 +1,18 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import logging + from flwr.proto import grpcadapter_pb2 +from google.protobuf.message import DecodeError + from openfl.protocols import aggregator_pb2 -def flower_to_openfl_message(flower_message, - header=None, - end_experiment=False): +logger = logging.getLogger(__name__) + + +def flower_to_openfl_message(flower_message, header=None, end_experiment=False): """ Convert a Flower MessageContainer to an OpenFL InteropMessage. @@ -25,6 +34,14 @@ def flower_to_openfl_message(flower_message, # If the input is already an OpenFL message, return it as-is return flower_message else: + # Check if the Flower message can be deserialized, log a warning if not + try: + deserialized_message = deserialize_flower_message(flower_message) + if deserialized_message is None: + logger.warning("Failed to introspect Flower message.") + except Exception as e: + logger.warning(f"Exception during Flower message introspection: {e}") + # Create the OpenFL message openfl_message = aggregator_pb2.InteropMessage() # Set the MessageHeader fields based on the provided sender and receiver @@ -40,6 +57,7 @@ def flower_to_openfl_message(flower_message, openfl_message.metadata.update({"end_experiment": str(end_experiment)}) return openfl_message + def openfl_to_flower_message(openfl_message): """ Convert an OpenFL InteropMessage to a Flower MessageContainer. 
@@ -63,3 +81,43 @@ def openfl_to_flower_message(openfl_message): flower_message = grpcadapter_pb2.MessageContainer() flower_message.ParseFromString(openfl_message.message.npbytes) return flower_message + + +def deserialize_flower_message(flower_message): + """ + Deserialize the grpc_message_content of a Flower message using the module and class name + specified in the metadata. + + Args: + flower_message: The Flower message containing the metadata and binary content. + + Returns: + The deserialized message object, or None if deserialization fails. + """ + # Access metadata directly + metadata = flower_message.metadata + module_name = metadata.get("grpc-message-module") + qualname = metadata.get("grpc-message-qualname") + + # Import the module + try: + module = importlib.import_module(module_name) + except ImportError as e: + print(f"Failed to import module: {module_name}. Error: {e}") + return None + + # Get the message class + try: + message_class = getattr(module, qualname) + except AttributeError as e: + print(f"Failed to get message class '{qualname}' from module '{module_name}'. Error: {e}") + return None + + # Deserialize the content + try: + message = message_class.FromString(flower_message.grpc_message_content) + except DecodeError as e: + print(f"Failed to deserialize message content. Error: {e}") + return None + + return message diff --git a/openfl/transport/rest/__init__.py b/openfl/transport/rest/__init__.py new file mode 100644 index 0000000000..633f6dae84 --- /dev/null +++ b/openfl/transport/rest/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from openfl.transport.rest.aggregator_client import AggregatorRESTClient +from openfl.transport.rest.aggregator_server import AggregatorRESTServer diff --git a/openfl/transport/rest/aggregator_client.py b/openfl/transport/rest/aggregator_client.py new file mode 100644 index 0000000000..8f07b928f1 --- /dev/null +++ b/openfl/transport/rest/aggregator_client.py @@ -0,0 +1,667 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorRESTClient module.""" + +# Standard library imports +import logging +import ssl +import struct +import time +from typing import Any, List, Tuple + +# Third-party libraries +import requests +from google.protobuf import json_format +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +# Internal modules +from openfl.protocols import aggregator_pb2, base_pb2 +from openfl.protocols.aggregator_client_interface import AggregatorClientInterface + +logger = logging.getLogger(__name__) + + +class SecurityError(Exception): + """Security-related error.""" + + pass + + +class AggregatorRESTClient(AggregatorClientInterface): + def __init__( + self, + agg_addr, + agg_port, + aggregator_uuid: str, + federation_uuid: str, + collaborator_name: str, + use_tls=True, + require_client_auth=True, + root_certificate=None, + certificate=None, + private_key=None, + single_col_cert_common_name=None, + refetch_server_cert_callback=None, + **kwargs, + ): + """ + Initialize the AggregatorRESTClient with proper security settings. 
+ + Args: + agg_addr: Aggregator address + agg_port: Aggregator port + aggregator_uuid: UUID of the aggregator + federation_uuid: UUID of the federation + collaborator_name: Name of the collaborator + use_tls: Whether to use TLS + require_client_auth: Whether to require client authentication + root_certificate: Path to root certificate + certificate: Path to client certificate + private_key: Path to client private key + single_col_cert_common_name: Common name for single collaborator certificate + refetch_server_cert_callback: Callback to refetch server certificate + """ + self.use_tls = use_tls + self.require_client_auth = require_client_auth + self.root_certificate = root_certificate + self.certificate = certificate + self.private_key = private_key + self.aggregator_uuid = aggregator_uuid + self.federation_uuid = federation_uuid + self.collaborator_name = collaborator_name + self.single_col_cert_common_name = single_col_cert_common_name + self.refetch_server_cert_callback = refetch_server_cert_callback + + # Determine scheme and TLS verification + scheme = "https" if self.use_tls else "http" + + # Configure certificate verification + self.cert_verification = self._configure_cert_verification( + self.use_tls, self.root_certificate + ) + + # Configure client certificates if required + if self.use_tls and self.require_client_auth: + if not self.certificate or not self.private_key: + raise ValueError( + "Both certificate and private key are required for mTLS " + "(client authentication). " + "Please provide both certificate and private key paths." + ) + self.cert = (self.certificate, self.private_key) + else: + self.cert = None + + # Configure session with proper settings + self.session = requests.Session() + + # Set default headers + self.session.headers.update( + { + "Connection": "keep-alive", + "Keep-Alive": "timeout=300", + "Accept": "application/json", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + } + ) + + # Configure timeouts with longer duration for large payloads + self.timeout = (30, 300) # (connect timeout, read timeout) in seconds + + # Configure retries with backoff + retry_strategy = Retry( + total=3, + backoff_factor=1, + status_forcelist=[408, 429, 500, 502, 503, 504], + allowed_methods=["GET", "POST"], + raise_on_status=True, + ) + + # Configure the adapter with the retry strategy + adapter = HTTPAdapter( + max_retries=retry_strategy, pool_connections=10, pool_maxsize=10, pool_block=False + ) + + # Mount the adapter for both HTTP and HTTPS + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + + # Build the base URL + self.base_url = f"{scheme}://{agg_addr}:{agg_port}/experimental/v1" + + # Log warning about experimental API + logger.warning( + "Initializing Aggregator REST Client (EXPERIMENTAL API - Not for production use)" + ) + + # Verify certificates if TLS is enabled + if self.use_tls: + try: + self._verify_certificates() + except Exception as e: + logger.error(f"Certificate verification failed: {e}") + raise + + @classmethod + def _configure_cert_verification( + cls, use_tls: bool, root_certificate: str = None + ) -> bool | str: + """ + Configure certificate verification settings for requests. 
+ + Args: + use_tls: Whether TLS is enabled + root_certificate: Optional path to root certificate file + + Returns: + Union[bool, str]: Either True for system CA bundle, False for no verification, + or path to root certificate file + """ + if not use_tls: + return False + + if root_certificate: + return root_certificate + + return True # Use system's default CA bundle + + def _verify_certificates(self): + """Verify SSL certificates and configuration.""" + import socket + import ssl + + # Try to establish a test connection + try: + hostname = self.base_url.split("://")[1].split(":")[0] + port = int(self.base_url.split(":")[2].split("/")[0]) + + # Create SSL context with specific options + context = ssl.create_default_context() + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # Set secure cipher suites + context.set_ciphers("ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384") + + # Disable older TLS versions + context.options |= ( + ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_TLSv1_2 + | ssl.OP_NO_COMPRESSION + | ssl.OP_NO_TICKET + ) + + if self.root_certificate: + context.load_verify_locations(cafile=self.root_certificate) + + if self.certificate and self.private_key: + context.load_cert_chain(certfile=self.certificate, keyfile=self.private_key) + + # Use context managers for proper resource cleanup + with socket.create_connection((hostname, port)) as sock: + with context.wrap_socket(sock, server_hostname=hostname) as _: + pass # Connection successful if we get here + + except ssl.SSLError as e: + if "CERTIFICATE_UNKNOWN" in str(e): + logger.error( + "Certificate unknown error - this usually means the " + "server's certificate is not trusted" + ) + logger.error("Please verify that:") + logger.error( + "1. The root certificate contains all necessary intermediate certificates" + ) + logger.error("2. The server's certificate is properly signed by a trusted CA") + logger.error("3. 
The hostname matches the certificate's subject") + raise + except Exception as e: + logger.error(f"Connection verification failed: {e}") + raise + + def _build_header(self) -> dict: + """Build and return a header dictionary with security headers.""" + headers = { + "Receiver": self.aggregator_uuid, + "Federation-UUID": self.federation_uuid, + "Single-Col-Cert-CN": self.single_col_cert_common_name or "", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + "Sender": self.collaborator_name, + } + if self.use_tls: + headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" + return headers + + def _make_request( + self, method, url, data=None, params=None, headers=None, stream=False, timeout=None + ): + """Make a request with proper security settings.""" + start_time = time.time() + try: + self._validate_url_scheme(url) + request_headers = self._prepare_headers(headers) + response = self._execute_request( + method, url, request_headers, data, params, stream, timeout + ) + self._validate_response(response) + logger.debug(f"Request completed in {time.time() - start_time:.2f} seconds") + return response + + except requests.exceptions.Timeout: + logger.error(f"Request timed out after {time.time() - start_time:.2f} seconds") + raise + except requests.exceptions.ConnectionError as e: + self._handle_connection_error(e) + raise + except requests.exceptions.RequestException as e: + self._handle_request_error(e) + raise + + def _validate_url_scheme(self, url): + """Validate URL scheme matches TLS setting.""" + if self.use_tls and not url.startswith("https://"): + raise ValueError("TLS required but URL is not HTTPS") + elif not url.startswith("http://") and not url.startswith("https://"): + raise ValueError("URL must use either HTTP or HTTPS scheme") + + def _prepare_headers(self, headers): + """Prepare request headers with security settings.""" + request_headers = self._build_header() + if headers: + request_headers.update(headers) + return request_headers + + def _execute_request(self, method, url, headers, data, params, stream, timeout): + """Execute the HTTP request with retry logic.""" + max_retries = 3 + for attempt in range(max_retries): + try: + session = requests.Session() + if self.use_tls: + # Extract hostname from URL for verification + hostname = url.split("://")[1].split(":")[0].split("/")[0] + + # Create a custom SSL context for this request + context = ssl.create_default_context( + cafile=self.root_certificate if self.root_certificate else None + ) + context.verify_mode = ssl.CERT_REQUIRED + + # Configure session with SSL context and hostname verification + session.verify = self.cert_verification + session.cert = self.cert + + # Configure adapter with proper SSL settings + adapter = HTTPAdapter( + pool_connections=1, + pool_maxsize=1, + max_retries=Retry( + total=3, + backoff_factor=1, + status_forcelist=[408, 429, 500, 502, 503, 504], + allowed_methods=["GET", "POST"], + ), + ) + session.mount("https://", adapter) + + # Build the complete headers with security information + base_headers = self._build_header() + if headers: + # Merge user-provided headers with base headers + base_headers.update(headers) + headers = base_headers + headers["Host"] = hostname + + # Add certificate info to request kwargs + request_kwargs = { + "method": method, + "url": url, + "headers": headers, + "data": data, + "params": params, + "stream": stream, + "verify": self.cert_verification, + "timeout": timeout or self.timeout, + } + + # 
Add client certificate if mTLS is enabled + if self.require_client_auth: + if not self.certificate or not self.private_key: + raise ValueError( + "Both certificate and private key are required for mTLS " + "(client authentication). " + "Please provide both certificate and private key paths." + ) + # Use proper cert format + request_kwargs["cert"] = (self.certificate, self.private_key) + + response = session.request(**request_kwargs) + else: + # For non-TLS requests, still use the security headers + base_headers = self._build_header() + if headers: + base_headers.update(headers) + + response = session.request( + method=method, + url=url, + headers=base_headers, + data=data, + params=params, + stream=stream, + timeout=timeout or self.timeout, + ) + return response + except requests.exceptions.SSLError as e: + self._handle_ssl_error(e, attempt, max_retries) + if attempt == max_retries - 1: + raise + + def _handle_ssl_error(self, e, attempt, max_retries): + """Handle SSL errors with retry logic.""" + if "CERTIFICATE_UNKNOWN" in str(e): + logger.error( + "Certificate unknown error - this usually means the " + "server's certificate is not trusted" + ) + logger.error("Please verify that:") + logger.error("1. The root certificate contains all necessary intermediate certificates") + logger.error("2. The server's certificate is properly signed by a trusted CA") + logger.error("3. The hostname matches the certificate's subject") + if attempt < max_retries - 1 and self.refetch_server_cert_callback: + logger.debug("Attempting to refetch server certificate") + self.root_certificate = self.refetch_server_cert_callback() + # Update the cert_verification with the new root certificate + self.cert_verification = self._configure_cert_verification( + self.use_tls, self.root_certificate + ) + # Re-verify certificates + try: + self._verify_certificates() + except Exception as verify_error: + logger.error(f"Certificate re-verification failed: {verify_error}") + raise + else: + raise + else: + if attempt < max_retries - 1: + logger.warning(f"SSL error (attempt {attempt + 1}/{max_retries}): {str(e)}") + time.sleep(1) + else: + raise + + def _validate_response(self, response): + """Validate response headers and security settings.""" + security_headers = { + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + } + for header, expected_value in security_headers.items(): + if header in response.headers and response.headers[header] != expected_value: + logger.warning(f"Missing or incorrect security header: {header}") + + response.raise_for_status() + + def _handle_connection_error(self, e): + """Handle connection errors.""" + logger.error(f"Connection error: {e}") + if hasattr(e, "args") and len(e.args) > 0: + logger.error(f"Connection error details: {e.args[0]}") + + def _handle_request_error(self, e): + """Handle request errors.""" + logger.error(f"Request failed: {e}") + if hasattr(e, "args") and len(e.args) > 0: + logger.error(f"Request error details: {e.args[0]}") + + def get_tasks(self) -> Tuple[List[Any], int, int, bool]: + """Get tasks from the aggregator with proper security settings.""" + headers = {"Accept": "application/json", "Sender": self.collaborator_name} + params = { + "collaborator_id": self.collaborator_name, + "federation_uuid": self.federation_uuid, + } + url = f"{self.base_url}/tasks" + response = self._make_request("GET", url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + tasks_resp = 
aggregator_pb2.GetTasksResponse() + json_format.ParseDict(data, tasks_resp) + + logger.debug( + f"Received tasks response - Round: {tasks_resp.round_number}, " + f"Tasks: {[t.name for t in tasks_resp.tasks]}, " + f"Sleep: {tasks_resp.sleep_time}, Quit: {tasks_resp.quit}" + ) + return tasks_resp.tasks, tasks_resp.round_number, tasks_resp.sleep_time, tasks_resp.quit + + def get_aggregated_tensors( + self, + tensor_keys, + require_lossless: bool = True, + ) -> List[base_pb2.NamedTensor]: + """ + Get aggregated tensors from the aggregator. + + Args: + tensor_keys (list): A list of tensor keys to fetch from aggregator. + require_lossless (bool): Whether lossless compression is required. + + Returns: + A list of `NamedTensor`s in the same order as requested. + """ + logger.debug(f"Requesting {len(tensor_keys)} aggregated tensors") + + # Build the request payload similar to gRPC implementation + tensor_specs = [] + for k in tensor_keys: + tensor_specs.append( + { + "tensor_name": k.tensor_name, + "round_number": k.round_number, + "report": k.report, + "tags": k.tags, + "require_lossless": require_lossless, + } + ) + + request_data = { + "header": { + "sender": self.collaborator_name, + "receiver": self.aggregator_uuid, + "federation_uuid": self.federation_uuid, + "single_col_cert_common_name": self.single_col_cert_common_name or "", + }, + "tensor_specs": tensor_specs, + } + + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Sender": self.collaborator_name, + } + url = f"{self.base_url}/tensors/aggregated/batch" + extended_timeout = (30, 600) # 30 seconds connect, 10 minutes read timeout + + try: + logger.debug(f"Requesting batch of {len(tensor_keys)} aggregated tensors") + response = self._make_request( + "POST", + url, + data=json_format.MessageToJson( + aggregator_pb2.GetAggregatedTensorsRequest(**request_data) + ), + headers=headers, + timeout=extended_timeout, + ) + data = response.json() + resp = aggregator_pb2.GetAggregatedTensorsResponse() + json_format.ParseDict(data, resp, ignore_unknown_fields=True) + logger.debug(f"Successfully retrieved {len(resp.tensors)} aggregated tensors") + return resp.tensors + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + # This is expected during round 0 or when tensors haven't been aggregated yet + logger.debug("No aggregated tensors found for the requested tensor keys") + return [] + raise + + def send_local_task_results( + self, + round_number: int, + task_name: str, + data_size: int, + named_tensors: List[Any], + ) -> bool: + """Send local task results with proper security settings.""" + logger.debug(f"Sending task results for round {round_number}, task {task_name}") + + # Create the TaskResults message + task_results = aggregator_pb2.TaskResults( + header=aggregator_pb2.MessageHeader( + sender=self.collaborator_name, + receiver=self.aggregator_uuid, + federation_uuid=self.federation_uuid, + single_col_cert_common_name=self.single_col_cert_common_name or "", + ), + round_number=round_number, + task_name=task_name, + data_size=data_size, + tensors=named_tensors, + ) + + # Serialize the TaskResults first + task_results_bytes = task_results.SerializeToString() + logger.debug(f"TaskResults serialized size: {len(task_results_bytes)} bytes") + + # Create a DataStream message containing the TaskResults bytes + data_stream = base_pb2.DataStream(size=len(task_results_bytes), npbytes=task_results_bytes) + + # Create an empty DataStream to signal end of stream + end_stream = 
base_pb2.DataStream(size=0, npbytes=b"") + + # Serialize both messages + data_bytes = data_stream.SerializeToString() + end_bytes = end_stream.SerializeToString() + + # Create length-prefixed stream format + stream_data = ( + struct.pack(">I", len(data_bytes)) # Length prefix for first message + + data_bytes # First message + + struct.pack(">I", len(end_bytes)) # Length prefix for second message + + end_bytes # Second message (empty message signals end) + ) + + url = f"{self.base_url}/tasks/results" + request_headers = self._build_header() + request_headers["Sender"] = self.collaborator_name + request_headers["Content-Type"] = "application/x-protobuf-stream" + request_headers["Content-Length"] = str(len(stream_data)) + + try: + response = self._make_request( + "POST", + url, + data=stream_data, + headers=request_headers, + timeout=(30, 60), # Keep shorter timeout since we're sending all data at once + ) + response.raise_for_status() + logger.debug(f"Successfully sent task results for round {round_number}") + return True + except Exception as e: + logger.error(f"Failed to send task results for round {round_number}: {str(e)}") + logger.error(f"Error type: {type(e).__name__}") + logger.error(f"Request headers were: {request_headers}") + raise + + def ping(self): + """Ping the aggregator to check connectivity.""" + logger.info("Aggregator ping...") + headers = {"Accept": "application/json", "Sender": self.collaborator_name} + params = { + "collaborator_id": self.collaborator_name, + "federation_uuid": self.federation_uuid, + } + url = f"{self.base_url}/ping" + response = self._make_request("GET", url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + + # Validate response header like GRPC client + header = data.get("header", {}) + assert header.get("receiver") == self.collaborator_name, ( + f"Receiver in response header does not match collaborator name. " + f"Expected: {self.collaborator_name}, Actual: {header.get('receiver')}" + ) + assert header.get("sender") == self.aggregator_uuid, ( + f"Sender in response header does not match aggregator UUID. " + f"Expected: {self.aggregator_uuid}, Actual: {header.get('sender')}" + ) + assert header.get("federationUuid") == self.federation_uuid, ( + f"Federation UUID in response header does not match. " + f"Expected: {self.federation_uuid}, Actual: {header.get('federationUuid')}" + ) + assert header.get("singleColCertCommonName", "") == ( + self.single_col_cert_common_name or "" + ), ( + f"Single collaborator certificate common name in response header does not match. " + f"Expected: {self.single_col_cert_common_name}, " + f"Actual: {header.get('singleColCertCommonName')}" + ) + + logger.info("Aggregator pong!") + + def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any: + """ + Forwards a converted message from the local REST client to the OpenFL server and returns + the response. + + Args: + openfl_message: The InteropMessage proto to be sent to the OpenFL server. + collaborator_name: The name of the collaborator. + + Returns: + The response from the OpenFL server (InteropMessage proto). 
+ """ + # Set the header fields + header = aggregator_pb2.MessageHeader( + sender=collaborator_name, + receiver=self.aggregator_uuid, + federation_uuid=self.federation_uuid, + single_col_cert_common_name=self.single_col_cert_common_name or "", + ) + openfl_message.header.CopyFrom(header) + + # Serialize to JSON + json_payload = json_format.MessageToJson(openfl_message) + url = f"{self.base_url}/interop/relay" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Sender": collaborator_name, + } + response = self._make_request( + "POST", + url, + data=json_payload, + headers=headers, + timeout=(30, 300), + ) + response.raise_for_status() + response_json = response.json() + openfl_response = aggregator_pb2.InteropMessage() + json_format.ParseDict(response_json, openfl_response, ignore_unknown_fields=True) + return openfl_response + + def __del__(self): + """Cleanup when the client is destroyed.""" + self.session.close() diff --git a/openfl/transport/rest/aggregator_server.py b/openfl/transport/rest/aggregator_server.py new file mode 100644 index 0000000000..731f2c718f --- /dev/null +++ b/openfl/transport/rest/aggregator_server.py @@ -0,0 +1,923 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorRESTServer module.""" + +import logging +import ssl +import threading +import time +from functools import wraps +from random import random +from time import sleep + +from flask import Flask, abort, jsonify, request +from google.protobuf import json_format +from werkzeug.serving import make_server + +from openfl.protocols import aggregator_pb2, base_pb2 + +logger = logging.getLogger(__name__) + + +def synchronized(func): + """Synchronization decorator.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + with self._lock: + return func(self, *args, **kwargs) + + return wrapper + + +def create_header(sender, receiver, federation_uuid, single_col_cert_common_name=""): + """Create a standard message header with consistent fields.""" + return aggregator_pb2.MessageHeader( + sender=str(sender), + receiver=str(receiver), + federation_uuid=str(federation_uuid), + single_col_cert_common_name=single_col_cert_common_name or "", + ) + + +class AggregatorRESTServer: + """REST server for the aggregator.""" + + def __init__( + self, + aggregator, + agg_addr, + agg_port, + use_tls=True, + require_client_auth=True, + certificate=None, + private_key=None, + root_certificate=None, + **kwargs, + ): + """Initialize REST server with security defaults.""" + # Initialize lock for synchronized methods + self._lock = threading.Lock() + + # Set up base configuration + self.aggregator = aggregator + self.host = agg_addr + self.port = agg_port + + # Set API prefix + self.api_prefix = "experimental/v1" + + # Set security defaults + self.use_tls = use_tls + self.require_client_auth = require_client_auth + self.ssl_context = None + + # Set up server components with security focus + self._setup_server_components(certificate, private_key, root_certificate) + + # Set up routes with synchronized access + self._setup_routes() + + # Build the base URL + scheme = "https" if use_tls else "http" + self.base_url = f"{scheme}://{agg_addr}:{agg_port}/{self.api_prefix}" + + def _setup_ssl_context(self, certificate, private_key, root_certificate): + """Set up SSL context for TLS/mTLS.""" + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + # Set secure cipher suites + ssl_context.set_ciphers("ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384") + + # 
Disable older TLS versions and set security options + ssl_context.options |= ( + ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_TLSv1_2 + | ssl.OP_NO_COMPRESSION + | ssl.OP_NO_TICKET # Disable session tickets + | ssl.OP_CIPHER_SERVER_PREFERENCE # Server chooses cipher + | ssl.OP_SINGLE_DH_USE # Ensure perfect forward secrecy with DHE + | ssl.OP_SINGLE_ECDH_USE # Ensure perfect forward secrecy with ECDHE + ) + + # Set verification flags for strict certificate checking + ssl_context.verify_flags = ( + ssl.VERIFY_X509_STRICT | ssl.VERIFY_CRL_CHECK_CHAIN # Check certificate revocation + ) + + # Configure client certificate verification + if self.require_client_auth: + ssl_context.verify_mode = ssl.CERT_REQUIRED + # Load root CA for client cert verification + if root_certificate: + try: + ssl_context.load_verify_locations(cafile=root_certificate) + except Exception as e: + logger.error(f"Failed to load root CA certificate: {str(e)}") + raise + else: + logger.error("Root certificate is required when client authentication is enabled") + raise ValueError("Root certificate is required for mTLS") + else: + ssl_context.verify_mode = ssl.CERT_NONE + + # Load server certificate and key + try: + ssl_context.load_cert_chain(certfile=certificate, keyfile=private_key) + except Exception as e: + logger.error(f"Failed to load server certificate and key: {str(e)}") + raise + + # Load and trust the root CA certificate + if root_certificate: + try: + ssl_context.load_verify_locations(cafile=root_certificate) + except Exception as e: + logger.error(f"Failed to load root CA certificate: {str(e)}") + raise + # Enable post-handshake authentication for better security + if hasattr(ssl_context, "post_handshake_auth"): + ssl_context.post_handshake_auth = True + + # Set verification purpose + ssl_context.purpose = ssl.Purpose.CLIENT_AUTH + + return ssl_context + + def _setup_flask_app(self): + """Configure Flask application with proper settings for both TLS and non-TLS modes.""" + app = Flask(__name__) + + # Set session and file age defaults + app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 1800 # 30 minutes + app.config["PERMANENT_SESSION_LIFETIME"] = 1800 # 30 minutes + + # Configure logging to be minimal + import logging + + # Disable Flask's default logging + log = logging.getLogger("werkzeug") + log.setLevel(logging.ERROR) + + # Add security headers + @app.after_request + def add_security_headers(response): + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + if self.use_tls: + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) + return response + + return app + + def _validate_client_certificate(self, request_environ, collaborator_name): + """ + Validate client certificate when mTLS is enabled. 
+ Args: + request_environ: The request environment containing SSL information + collaborator_name: The collaborator name from the request (from header.sender) + Returns: + bool: True if validation passes + Raises: + abort: HTTP error if validation fails + """ + if not self.use_tls: + return True + + try: + # Default to collaborator name (like gRPC) + common_name = collaborator_name + + # Get certificate information if client auth is required + if self.require_client_auth: + cert_cn = self._get_certificate_cn(request_environ, collaborator_name) + if not cert_cn: + abort(401, "Client certificate validation failed - certificate not found") + common_name = cert_cn + + # Validate collaborator identity + return self._validate_collaborator(common_name, collaborator_name) + + except Exception as e: + logger.error(f"Certificate validation failed: {str(e)}") + abort(401, str(e)) + + def _get_certificate_cn(self, request_environ, collaborator_name): + """Get certificate CN from environment or headers.""" + # Try to get certificate info from environment + peercert = request_environ.get("SSL_CLIENT_CERT") + cert_cn = request_environ.get("SSL_CLIENT_S_DN_CN") + + # Try to extract CN if we have certificate but no CN + if peercert and not cert_cn: + try: + cert_cn = self._extract_cn_from_cert(peercert) + except Exception as e: + logger.error(f"Failed to extract CN from certificate: {e}") + + # If no certificate found, try fallback methods + if not peercert: + # Try header-based fallback for experimental mode + cert_cn = self._try_header_fallback(collaborator_name) + + return cert_cn + + def _try_header_fallback(self, collaborator_name): + """Try to get CN from headers as fallback in experimental mode.""" + # FALLBACK: In experimental mode, allow using header-based auth + # This should NOT be used in production + try: + from flask import request + + # Use Sender header as fallback + if hasattr(request, "headers") and "Sender" in request.headers: + cert_cn = request.headers.get("Sender") + return cert_cn + except Exception as e: + logger.error(f"Error in header fallback: {e}") + + # THIS SHOULD BE REMOVED POST EXPERIMENTAL MODE + return collaborator_name + + def _validate_collaborator(self, common_name, collaborator_name): + """Validate collaborator identity.""" + if not self.aggregator.valid_collaborator_cn_and_id(common_name, collaborator_name): + # Add timing attack protection + sleep(5 * random()) + logger.error( + f"Invalid collaborator. 
CN: |{common_name}| " + f"collaborator_name: |{collaborator_name}|" + ) + abort(401, "Collaborator validation failed") + + return True + + def _extract_cn_from_cert(self, cert_pem): + """Extract CN from a PEM certificate using standard libraries.""" + import re + + # Try regex approach first (most reliable with PEM format) + cn_match = re.search( + r"CN\s*=\s*([^,/\n]+)", + cert_pem.decode("utf-8") if isinstance(cert_pem, bytes) else cert_pem, + ) + if cn_match: + return cn_match.group(1).strip() + + # Try using cryptography if available + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # Convert PEM to certificate object + cert_data = cert_pem.encode("utf-8") if isinstance(cert_pem, str) else cert_pem + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + + # Extract CN from subject + for attribute in cert.subject: + if attribute.oid._name == "commonName": + return attribute.value + except ImportError: + pass + + # CN could not be extracted from the certificate + return None + + def _setup_interop_client(self): + """Set up inter-federation connector client.""" + try: + return self.aggregator.get_interop_client() + except AttributeError: + return None + + def _is_authorized(self, collaborator_id, federation_id, cert_common_name=None): + """ + Validate collaborator identity with strict checks. + + Args: + collaborator_id (str): The collaborator's ID + federation_id (str): The federation UUID + cert_common_name (str, optional): Certificate CN if using mTLS + + Returns: + bool: True if validation passes + + Raises: + abort: HTTP error if validation fails + """ + is_valid = False + try: + # Validate collaborator identity + if not collaborator_id: + logger.error("Collaborator identity not provided") + abort(400, "Collaborator identity not provided") + + # First check if collaborator is authorized + if collaborator_id not in self.aggregator.authorized_cols: + logger.error(f"Collaborator not in authorized list. Got: {collaborator_id}") + abort(401, "Unauthorized collaborator") + + # Validate collaborator identity + common_name = cert_common_name if cert_common_name is not None else collaborator_id + if not self.aggregator.valid_collaborator_cn_and_id(common_name, collaborator_id): + logger.error( + f"Collaborator validation failed. CN: {common_name}, ID: {collaborator_id}" + ) + abort(401, "Collaborator validation failed") + + # Validate client certificate if mTLS is enabled + if self.use_tls and self.require_client_auth: + self._validate_client_certificate(request.environ, collaborator_id) + + # Verify federation UUID + if federation_id != str(self.aggregator.federation_uuid): + logger.error( + f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, " + f"Got: {federation_id}" + ) + abort(401, "Federation UUID mismatch") + + is_valid = True + return True + + except Exception as e: + logger.error(f"Validation failed: {str(e)}") + abort(401, str(e)) + finally: + # Add timing attack protection for all error cases + if not is_valid: + sleep(5 * random()) + + def _validate_task_headers(self, headers): + """ + Validate task submission headers with timing attack protection. 
+ + Args: + headers (dict): Request headers + + Returns: + str: Validated collaborator name + + Raises: + abort: HTTP error if validation fails + """ + try: + # Get collaborator identity from certificate or headers + collab_name = None + if self.use_tls and self.require_client_auth: + # Try to get from certificate first + collab_name = request.environ.get("SSL_CLIENT_S_DN_CN") + logger.debug(f"Using certificate CN: {collab_name}") + + # If not from certificate, try headers + if not collab_name: + collab_name = headers.get("Sender") + if not collab_name: + sleep(5 * random()) # Add timing attack protection + logger.error("No Sender header provided") + abort(401, "No Sender header provided") + + # Get other required headers + receiver = headers.get("Receiver") + federation_id = headers.get("Federation-UUID") + cert_common_name = headers.get("Single-Col-Cert-CN", "") + + # Validate collaborator identity + if not self.aggregator.valid_collaborator_cn_and_id(collab_name, collab_name): + sleep(5 * random()) # Add timing attack protection + msg = f"CN: {collab_name}, ID: {collab_name}" + logger.error(f"Collaborator validation failed. {msg}") + abort(401, "Collaborator validation failed") + + # Verify all headers with strict validation + assert receiver == str(self.aggregator.uuid), ( + f"Header receiver mismatch. Expected: {self.aggregator.uuid}, Got: {receiver}" + ) + + assert federation_id == str(self.aggregator.federation_uuid), ( + f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, " + f"Got: {federation_id}" + ) + + expected_cn = self.aggregator.single_col_cert_common_name or "" + assert cert_common_name == expected_cn, ( + f"Single col cert CN mismatch. Expected: {expected_cn}, Got: {cert_common_name}" + ) + + return collab_name + except AssertionError as e: + sleep(5 * random()) # Add timing attack protection + logger.error(f"Header validation failed: {str(e)}") + abort(401, str(e)) + + def _parse_protobuf_stream(self, data): + """Parse protobuf stream data.""" + logger.debug(f"Received {len(data)} bytes of protobuf stream data") + + # First message is DataStream containing TaskResults + msg_len = int.from_bytes(data[:4], byteorder="big") + logger.debug(f"First message length: {msg_len}") + data_stream_bytes = data[4 : 4 + msg_len] + data_stream = base_pb2.DataStream() + data_stream.ParseFromString(data_stream_bytes) + logger.debug(f"Parsed DataStream with size: {data_stream.size}") + + # Extract TaskResults from DataStream + task_results = aggregator_pb2.TaskResults() + task_results.ParseFromString(data_stream.npbytes) + + # Log task details + task_info = ( + f"Task: {task_results.task_name}, " + f"Round: {task_results.round_number}, " + f"Size: {task_results.data_size}, " + f"Tensors: {len(task_results.tensors)}" + ) + logger.debug(f"Extracted TaskResults from DataStream - {task_info}") + + # Verify end message + end_msg_offset = 4 + msg_len + end_msg_len = int.from_bytes(data[end_msg_offset : end_msg_offset + 4], byteorder="big") + logger.debug(f"End message length: {end_msg_len}") + + if end_msg_len != 0: + logger.error(f"Invalid end message length: {end_msg_len}") + abort(400, "Invalid stream format - expected empty end message") + + # Verify total length + expected_total_len = 4 + msg_len + 4 + end_msg_len + if len(data) != expected_total_len: + msg = f"Got {len(data)}, expected {expected_total_len}" + logger.error(f"Data length mismatch. 
{msg}") + abort(400, "Invalid stream data length") + + return task_results + + def _build_tasks_response( + self, + tasks_list, + round_number, + sleep_time, + time_to_quit, + collab_id, + ): + """Build GetTasksResponse protobuf.""" + tasks_proto = [] + if tasks_list: + if isinstance(tasks_list[0], str): + # Backward compatibility: list of task names + tasks_proto = [aggregator_pb2.Task(name=t) for t in tasks_list] + else: + tasks_proto = [ + aggregator_pb2.Task( + name=getattr(t, "name", ""), + function_name=getattr(t, "function_name", ""), + task_type=getattr(t, "task_type", ""), + apply_local=getattr(t, "apply_local", False), + ) + for t in tasks_list + ] + + # Create response header + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collab_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + return aggregator_pb2.GetTasksResponse( + header=header, + round_number=round_number, + tasks=tasks_proto, + sleep_time=sleep_time, + quit=time_to_quit, + ) + + def _setup_server_components(self, certificate=None, private_key=None, root_certificate=None): + """Set up server components including SSL, Flask app, and interop client.""" + # Set up SSL if enabled + if self.use_tls: + self.ssl_context = self._setup_ssl_context(certificate, private_key, root_certificate) + else: + self.ssl_context = None # Explicitly set to None when TLS is disabled + + # Set up Flask app + self.app = self._setup_flask_app() + + # Set up interop client + self.interop_client = self._setup_interop_client() + self.use_connector = self.interop_client is not None + + def _setup_routes(self): + """Set up Flask routes.""" + # Register the route handlers + self._setup_ping_route() + self._setup_tasks_route() + self._setup_task_results_route() + self._setup_tensor_route() + self._setup_relay_route() + # Add middleware for client certificate extraction + self._setup_certificate_middleware() + + def _setup_certificate_middleware(self): + """Set up middleware to capture SSL certificates from client connections.""" + + @self.app.before_request + def extract_client_cert(): + """Extract client certificate and add it to request environment.""" + if not (self.use_tls and self.require_client_auth): + return None + + # Get SSL connection information + try: + from flask import request + + # Try to extract certificate from the socket + cert_data = self._extract_certificate_from_socket(request.environ) + if cert_data: + # Process the certificate data + self._process_certificate_data(request.environ, cert_data) + + except Exception as e: + logger.warning(f"Failed to extract client certificate: {e}") + # Continue processing the request even if cert extraction fails + + return None + + def _extract_certificate_from_socket(self, environ): + """Extract the certificate from the socket if available.""" + # Access underlying SSL socket if possible + transport = environ.get("werkzeug.socket") + if not (transport and hasattr(transport, "getpeercert")): + return None + + # Extract certificate from socket + return transport.getpeercert(binary_form=True) + + def _process_certificate_data(self, environ, der_cert): + """Process the DER certificate data and store in environment.""" + if not der_cert: + return False + + # Convert DER to PEM format using built-in libraries + try: + # Try using cryptography if available + cn, pem_cert = self._convert_der_using_cryptography(der_cert) + if pem_cert: + environ["SSL_CLIENT_CERT"] = pem_cert + if cn: 
+ environ["SSL_CLIENT_S_DN_CN"] = cn + logger.info(f"Extracted client certificate CN: {cn}") + return True + except ImportError: + # Fall back to regex method + return self._try_regex_cn_extraction(environ, der_cert) + except Exception as e: + logger.warning(f"Error converting certificate format: {e}") + + return False + + def _convert_der_using_cryptography(self, der_cert): + """Convert DER certificate using cryptography library.""" + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_der_x509_certificate(der_cert, default_backend()) + pem_cert = cert.public_bytes(encoding=serialization.Encoding.PEM) + + # Parse the subject to get CN + cn = None + for attribute in cert.subject: + if attribute.oid._name == "commonName": + cn = attribute.value + break + + return cn, pem_cert + + def _try_regex_cn_extraction(self, environ, der_cert): + """Try to extract CN using regex from binary certificate.""" + try: + import binascii + import re + + # Convert to hex and then look for CN + hex_data = binascii.hexlify(der_cert).decode("ascii") + # Look for common name pattern in hex + # This is a simplified approach and may not work for all certs + cn_pattern = ( + r"(?:3[0-9]|4[0-9]|5[0-9])(?:06|07|08|09|0a|0b|0c|0d|0e|0f)" + r"(?:03|04|05|06)(?:13|14|15|16)(.{2,60})(?:30|31)" + ) + cn_match = re.search(cn_pattern, hex_data) + if cn_match: + # Convert hex to ASCII + cn_hex = cn_match.group(1) + try: + cn = binascii.unhexlify(cn_hex).decode("utf-8") + environ["SSL_CLIENT_S_DN_CN"] = cn + logger.info(f"Extracted client certificate CN using regex: {cn}") + return True + except Exception as e: + logger.warning(f"Failed to decode CN: {e}") + + return False + except Exception as e: + logger.warning(f"Error in regex CN extraction: {e}") + return False + + def _setup_ping_route(self): + """Set up the /ping endpoint.""" + + @self.app.route(f"/{self.api_prefix}/ping", methods=["GET"]) + def ping(): + """Simple ping endpoint to check server connectivity.""" + try: + # Get collaborator identity from certificate or query param + collaborator_id = None + if self.require_client_auth: + collaborator_id = request.environ.get("SSL_CLIENT_S_DN_CN") + if collaborator_id is None: + collaborator_id = request.args.get("collaborator_id") + + federation_id = request.args.get("federation_uuid") + + # Use the consolidated validation method + self._is_authorized(collaborator_id, federation_id) + + # Create response header + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collaborator_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + # Return response in same format as GRPC + return jsonify({"header": json_format.MessageToDict(header)}) + except Exception as e: + logger.error(f"Ping request failed: {str(e)}") + abort(401, str(e)) + + def _setup_tasks_route(self): + """Set up the /tasks endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tasks", methods=["GET"]) + def get_tasks(): + """Endpoint for collaborators to fetch pending tasks.""" + # Get collaborator identity from certificate or query param + collaborator_id = None + if self.require_client_auth: + collaborator_id = request.environ.get("SSL_CLIENT_S_DN_CN") + if collaborator_id is None: + collaborator_id = request.args.get("collaborator_id") + + federation_id = request.args.get("federation_uuid") + + # Use the consolidated validation method 
+ self._is_authorized(collaborator_id, federation_id) + + # Check if connector mode is enabled + if self.use_connector: + abort(501, "GetTasks not supported in connector mode") + + # Fetch tasks from Aggregator core - directly delegate to the aggregator + tasks_list, round_number, sleep_time, time_to_quit = self.aggregator.get_tasks( + collaborator_id + ) + + # Log task assignment + task_names = [getattr(t, "name", t) for t in (tasks_list or [])] + logger.debug( + f"Collaborator {collaborator_id} requested tasks. " + f"Round: {round_number}, Tasks: {task_names}, " + f"Sleep: {sleep_time}, Quit: {time_to_quit}" + ) + + # Build and return response + response_proto = self._build_tasks_response( + tasks_list, round_number, sleep_time, time_to_quit, collaborator_id + ) + return jsonify(json_format.MessageToDict(response_proto)) + + def _setup_task_results_route(self): + """Set up the /tasks/results endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tasks/results", methods=["POST"]) + def post_task_results(): + """Handle task results submission.""" + try: + # Validate headers and get collaborator name + collab_name = self._validate_task_headers(request.headers) + + # Parse protobuf stream data + task_results = self._parse_protobuf_stream(request.data) + + # Direct delegation to the aggregator for task results processing + # This matches the gRPC approach of calling send_local_task_results directly + self.aggregator.send_local_task_results( + collab_name, + task_results.round_number, + task_results.task_name, + task_results.data_size, + task_results.tensors, + ) + + return jsonify({"status": "success"}) + + except Exception as e: + logger.error(f"Error processing task results: {str(e)}") + abort(400, f"Error processing task results: {str(e)}") + + def _setup_tensor_route(self): + """Set up the /tensors/aggregated endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tensors/aggregated/batch", methods=["POST"]) + def get_aggregated_tensors(): + """Endpoint for collaborators to retrieve multiple aggregated tensors.""" + start_time = time.time() + + # Validate that this endpoint is not used in connector mode + if self.use_connector: + abort(501, "GetAggregatedTensors not supported in connector mode") + + try: + # Parse the incoming JSON to a GetAggregatedTensorsRequest protobuf message + request_data = request.get_json() + if not request_data: + abort(400, "Invalid JSON payload") + + tensors_request = aggregator_pb2.GetAggregatedTensorsRequest() + json_format.ParseDict(request_data, tensors_request, ignore_unknown_fields=True) + + # Validate headers and get collaborator identity + collaborator_id = tensors_request.header.sender + federation_id = tensors_request.header.federation_uuid + + # Use the consolidated validation method + self._is_authorized(collaborator_id, federation_id) + + # Validate request header similar to gRPC implementation + assert tensors_request.header.receiver == str(self.aggregator.uuid), ( + f"Header receiver mismatch. Expected: {self.aggregator.uuid}, " + f"Got: {tensors_request.header.receiver}" + ) + + assert tensors_request.header.federation_uuid == str( + self.aggregator.federation_uuid + ), ( + f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, " + f"Got: {tensors_request.header.federation_uuid}" + ) + + expected_cn = self.aggregator.single_col_cert_common_name or "" + assert tensors_request.header.single_col_cert_common_name == expected_cn, ( + f"Single col cert CN mismatch. 
Expected: {expected_cn}, " + f"Got: {tensors_request.header.single_col_cert_common_name}" + ) + + # Get tensors from aggregator - similar to gRPC implementation + logger.debug( + f"Processing batch request for {len(tensors_request.tensor_specs)} tensors" + ) + + named_tensors = [] + for ts in tensors_request.tensor_specs: + named_tensor = self.aggregator.get_aggregated_tensor( + ts.tensor_name, + ts.round_number, + ts.report, + tuple(ts.tags), + ts.require_lossless, + collaborator_id, + ) + # Add tensor to list (None tensors will be handled by the client) + if named_tensor is not None: + named_tensors.append(named_tensor) + else: + # Add empty tensor placeholder to maintain order + named_tensors.append(aggregator_pb2.NamedTensorProto()) + + # Create response header using the standardized method + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collaborator_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + # Create response + response_proto = aggregator_pb2.GetAggregatedTensorsResponse( + header=header, tensors=named_tensors + ) + + logger.debug( + f"Batch tensor retrieval completed in {time.time() - start_time:.2f} seconds. " + f"Returned {len(named_tensors)} tensors" + ) + return jsonify(json_format.MessageToDict(response_proto)) + + except AssertionError as e: + logger.error(f"Header validation failed: {str(e)}") + abort(400, str(e)) + except Exception as e: + logger.error(f"Error processing batch tensor request: {str(e)}") + abort(400, f"Error processing batch tensor request: {str(e)}") + + def _setup_relay_route(self): + """Set up the /interop/relay endpoint.""" + + @self.app.route(f"/{self.api_prefix}/interop/relay", methods=["POST"]) + def relay_message(): + """Endpoint for collaborator-to-aggregator message relay.""" + # This endpoint is optional; only enable if connector mode is configured + if not self.use_connector or self.interop_client is None: + abort(501, "Interop relay is not enabled on this aggregator") + + # Parse the incoming JSON to an InteropRelay protobuf message + try: + relay_req = json_format.Parse( + request.data.decode("utf-8"), aggregator_pb2.InteropRelay() + ) + except Exception as e: + abort(400, f"Invalid InteropRelay payload: {e}") + + # Validate the collaborator via header + collab_name = relay_req.header.sender + self._is_authorized(collab_name, relay_req.header.federation_uuid) + + if relay_req.header.receiver != str(self.aggregator.uuid): + abort(400, "Header receiver mismatch") + + # Forward the request to the configured interop connector and get response + logger.debug( + f"Relaying message from {collab_name} to external federation via connector" + ) + # Create a header for forwarding using the standardized method + forward_header = create_header( + sender=str(self.aggregator.uuid), + receiver=relay_req.header.receiver, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + # Use the aggregator's interop client to send and receive + response_proto = self.interop_client.send_receive(relay_req, header=forward_header) + # Return the response from the remote as JSON + return jsonify(json_format.MessageToDict(response_proto)) + + def serve(self): + """Start the REST server with proper configuration for both TLS and non-TLS modes.""" + # If connector mode is enabled, start the connector service + if self.use_connector: + try: + 
self.aggregator.start_connector() + except AttributeError: + pass + + # Configure server based on TLS mode + if self.use_tls and self.ssl_context: + server = make_server( + self.host, + self.port, + self.app, + ssl_context=self.ssl_context, + threaded=True, # Enable threading for better performance + ) + else: + server = make_server( + self.host, + self.port, + self.app, + threaded=True, # Enable threading for better performance + ) + + # Configure server thread + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + logger.warning( + "Starting Aggregator REST Server (EXPERIMENTAL API - Not for production use)" + ) + thread.start() + + try: + while not self.aggregator.all_quit_jobs_sent(): + sleep(5) + finally: + # Synchronized shutdown + if self.use_connector: + try: + self.aggregator.stop_connector() + except AttributeError: + pass + server.shutdown() + thread.join() + logger.info("Aggregator REST Server stopped.") diff --git a/openfl/utilities/utils.py b/openfl/utilities/utils.py index bab2ccc8c3..4f4e5fc2eb 100644 --- a/openfl/utilities/utils.py +++ b/openfl/utilities/utils.py @@ -263,3 +263,22 @@ def remove_readonly(func, path, _): func(path) return shutil.rmtree(path, ignore_errors=ignore_errors, onerror=remove_readonly) + + +def generate_port(hash, port_range=(49152, 60999)): + """ + Generate a deterministic port number based on a hash and a unique key. + + Args: + hash (str): A string representing the hash of the plan. + port_range (tuple): A tuple containing the minimum and maximum port + numbers (inclusive). The default range is (49152, 60999). + + Returns: + int: A port number within the specified range. + """ + min_port, max_port = port_range + # Use the first 8 characters of the unique hash to ensure deterministic output + hash_segment = hash[:8] + port = int(hash_segment, 16) % (max_port - min_port) + min_port + return port diff --git a/setup.py b/setup.py index 884618e146..6f0d990bea 100644 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ def run(self): 'tensorboardX', 'protobuf>=4.21,<6.0.0', 'grpcio>=1.56.2,<1.66.0', + 'Flask==3.1.1', ], python_requires='>=3.10, <3.13', project_urls={ diff --git a/test-requirements.txt b/test-requirements.txt index 694fa5bf2b..53ba4a017e 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,8 +1,10 @@ docker +Flask==3.1.1 lxml==5.3.1 paramiko pytest==8.3.5 pytest-asyncio==0.26.0 +pytest-cov>=2.10.0 pytest-mock==3.14.0 defusedxml==0.7.1 matplotlib==3.10.1 @@ -14,3 +16,4 @@ boto3>=1.37.19 moto==5.1.1 torchvision==0.22.0 azure-storage-blob==12.25.1 +cryptography>=3.4.0 diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py index be400e7b6c..732cc5ef0d 100644 --- a/tests/end_to_end/conftest.py +++ b/tests/end_to_end/conftest.py @@ -31,7 +31,7 @@ def pytest_addoption(parser): parser.addoption("--num_rounds") parser.addoption("--model_name") parser.addoption("--workflow_backend") - parser.addoption("--tr_rest_api", action="store_true") + parser.addoption("--tr_rest_protocol", action="store_true") parser.addoption("--disable_client_auth", action="store_true") parser.addoption("--disable_tls", action="store_true") parser.addoption("--log_memory_usage", action="store_true") @@ -54,7 +54,7 @@ def pytest_configure(config): config.use_tls = not args.disable_tls config.log_memory_usage = args.log_memory_usage config.secure_agg = args.secure_agg - config.tr_rest_api = args.tr_rest_api + config.tr_rest_protocol = args.tr_rest_protocol config.workflow_backend = args.workflow_backend 
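As a side note on the deterministic port helper added to openfl/utilities/utils.py above, a minimal usage sketch of generate_port (the plan hash below is hypothetical) might look like this:

from openfl.utilities.utils import generate_port

# Hypothetical plan hash; only the first 8 characters are used and they must be hexadecimal.
plan_hash = "a3f91c2e5d7b4412"

port = generate_port(plan_hash)
# Deterministic: the same hash always maps to the same port.
assert port == generate_port(plan_hash)
# Falls inside the default range; note that max_port itself is never produced by the modulo.
assert 49152 <= port < 60999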
config.results_dir = config.getini("results_dir") diff --git a/tests/end_to_end/models/aggregator.py b/tests/end_to_end/models/aggregator.py index 9c0cabc32d..e85b0f6453 100644 --- a/tests/end_to_end/models/aggregator.py +++ b/tests/end_to_end/models/aggregator.py @@ -6,7 +6,7 @@ import tempfile import tests.end_to_end.utils.exceptions as ex -import tests.end_to_end.utils.federation_helper as fh +import tests.end_to_end.utils.helper as helper import tests.end_to_end.utils.ssh_helper as ssh @@ -21,7 +21,7 @@ class Aggregator(): 2. Starting the aggregator """ - def __init__(self, agg_domain_name, workspace_path, eval_scope=False, container_id=None): + def __init__(self, agg_domain_name, workspace_path, transport_protocol, eval_scope=False, container_id=None): """ Initialize the Aggregator class Args: @@ -29,12 +29,14 @@ def __init__(self, agg_domain_name, workspace_path, eval_scope=False, container_ workspace_path (str): Workspace path container_id (str): Container ID eval_scope (bool, optional): Scope of aggregator is evaluation. Default is False. + transport_protocol (str): Transport protocol (default: "gRPC") """ self.name = "aggregator" self.agg_domain_name = agg_domain_name self.workspace_path = workspace_path self.eval_scope = eval_scope self.container_id = container_id + self.transport_protocol = transport_protocol self.tensor_db_file = os.path.join(self.workspace_path, "local_state", "tensor.db") self.res_file = None # Result file to track the logs self.start_process = None # Process associated with the aggregator start command @@ -46,13 +48,13 @@ def generate_sign_request(self): try: cmd = f"fx aggregator generate-cert-request --fqdn {self.agg_domain_name}" error_msg = "Failed to generate the sign request" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, error_msg=error_msg, container_id=self.container_id, workspace_path=self.workspace_path, ) - fh.verify_cmd_output(output, return_code, error, error_msg, f"Generated a sign request for {self.name}") + helper.verify_cmd_output(output, return_code, error, error_msg, f"Generated a sign request for {self.name}") except Exception as e: raise ex.CSRGenerationException(f"Failed to generate sign request for {self.name}: {e}") @@ -87,7 +89,7 @@ def start(self): cmd=command, work_dir=self.workspace_path, redirect_to_file=bg_file, - check_sleep=60, + check_sleep=30, env=env ) diff --git a/tests/end_to_end/models/az_storage.py b/tests/end_to_end/models/az_storage.py new file mode 100644 index 0000000000..df3c32f11f --- /dev/null +++ b/tests/end_to_end/models/az_storage.py @@ -0,0 +1,174 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from azure.storage.blob import BlobServiceClient +from pathlib import Path + +import tests.end_to_end.utils.defaults as defaults +import tests.end_to_end.utils.docker_helper as docker_helper +import tests.end_to_end.utils.exceptions as ex + +# Suppress Azure SDK and urllib3 info/debug logs +logging.getLogger("azure").setLevel(logging.WARNING) +logging.getLogger("azure.core.pipeline.policies._universal").setLevel(logging.ERROR) +logging.getLogger("urllib3").setLevel(logging.WARNING) + +log = logging.getLogger(__name__) + + +class AzureStorage(): + """ + Class to handle Azure Storage + """ + + def __init__(self, host, port, account_name, account_key, endpoints_protocol): + """ + Initialize the AzureStorage class + Args: + host (str): Azure Storage host + port (int): Azure Storage port + account_name (str): Azure 
Storage account name + account_key (str): Azure Storage account key + endpoints_protocol (str): Protocol for the endpoints (http or https) + """ + self.host = host + self.port = port + self.account_name = account_name + self.account_key = account_key + self.endpoints_protocol = endpoints_protocol + self.blob_endpoint = f"{self.endpoints_protocol}://{self.host}:{self.port}/{self.account_name}" + self.blob_service_client = BlobServiceClient( + account_url=self.blob_endpoint, + credential=self.account_key, + ) + self.connection_string = f"DefaultEndpointsProtocol={endpoints_protocol};AccountName={account_name};AccountKey={account_key};BlobEndpoint={self.blob_endpoint};" + + def create_container(self, container_name): + """ + Create a container in Azure Storage + Args: + container_name (str): Name of the container + """ + try: + container_client = self.blob_service_client.create_container(container_name) + log.info(f"Container {container_name} created successfully") + except Exception as e: + log.error(f"Failed to create container: {e}") + raise e + return container_client + + def delete_container(self, container_name): + """ + Delete a container in Azure Storage + Args: + container_name (str): Name of the container + """ + try: + container_client = self.blob_service_client.get_container_client(container_name) + container_client.delete_container() + log.info(f"Container {container_name} deleted successfully") + except Exception as e: + log.error(f"Failed to delete container: {e}") + raise e + return True + + def upload_data_to_container(self, container_name, data_path: Path): + """ + Upload a file to Azure Storage. + Assumption - data_path contains the file to be uploaded. + Args: + container_name (str): Name of the container + file_path (str): Path to the file + """ + try: + # Verify data path + if not data_path.exists() or not data_path.is_dir(): + raise ValueError(f"Expected {data_path} to be a directory, but it does not exist or is not a directory.") + + if not any(data_path.iterdir()): + raise ValueError(f"Directory {data_path} is empty. Nothing to upload.") + + container_client = self.blob_service_client.get_container_client(container_name) + num = 0 + for file_path in data_path.rglob("*"): + if file_path.is_file(): + blob_name = str(file_path.relative_to(data_path)).replace("\\", "/") + with open(file_path, "rb") as data: + container_client.upload_blob(blob_name, data, overwrite=True) + num += 1 + log.info(f"Uploaded {num} files to {container_name}: {blob_name}") + except Exception as e: + log.error(f"Failed to upload file: {e}") + raise e + return True + + +class AzuriteStorage(AzureStorage): + """ + Azurite is an emulator for local Azure Storage development. + This class provides methods to start, stop, and manage the Azurite container. + """ + def __init__( + self, + host=defaults.AZURE_STORAGE_HOST, + port=defaults.AZURE_STORAGE_PORT, + account_name=defaults.AZURE_STORAGE_ACCOUNT_NAME, + account_key=defaults.AZURE_STORAGE_ACCOUNT_KEY, + endpoints_protocol=defaults.AZURE_STORAGE_ENDPOINTS_PROTOCOL, + ): + """ + Initialize the AzuriteStorage class + Args: + account_name (str): Azure Storage account name + account_key (str): Azure Storage account key + """ + super().__init__(host, port, account_name, account_key, endpoints_protocol) + + def start_azurite_container(self): + """ + Start the Azurite container for local testing. 
+ """ + try: + # Stop and remove if already running or remove if exited + is_container_present = self.is_azurite_container_present() + if is_container_present: + log.info("Azurite container is present in either running or exited state. Stopping/removing for a fresh start.") + self.stop_azurite_container() + client = docker_helper.get_docker_client() + container = client.containers.run( + "mcr.microsoft.com/azure-storage/azurite", + detach=True, + ports={"10000/tcp": 10000, "10001/tcp": 10001, "10002/tcp": 10002}, + name="azurite", + ) + log.info(f"Azurite container started with ID: {container.id}") + except Exception as e: + raise ex.DockerException(f"Error starting Azurite container: {e}") + return container + + def stop_azurite_container(self): + """ + Stop the Azurite container. + """ + try: + client = docker_helper.get_docker_client() + container = client.containers.get("azurite") + container.stop() + container.remove() + log.info("Azurite container stopped and removed successfully") + except Exception as e: + raise ex.DockerException(f"Error stopping Azurite container: {e}") + return True + + def is_azurite_container_present(self): + """Check if Azurite container is present.""" + try: + log.info("Checking if Azurite container is present...") + client = docker_helper.get_docker_client() + container = client.containers.get("azurite") + if container.status in ("running", "exited"): + return container + except Exception: + pass + return None diff --git a/tests/end_to_end/models/collaborator.py b/tests/end_to_end/models/collaborator.py index 0729d59222..3eb7a46b94 100644 --- a/tests/end_to_end/models/collaborator.py +++ b/tests/end_to_end/models/collaborator.py @@ -6,7 +6,7 @@ import tempfile import tests.end_to_end.utils.exceptions as ex -import tests.end_to_end.utils.federation_helper as fh +import tests.end_to_end.utils.helper as helper import tests.end_to_end.utils.ssh_helper as ssh log = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class Collaborator(): 4. 
Starting the collaborator """ - def __init__(self, collaborator_name=None, data_directory_path=None, workspace_path=None, container_id=None): + def __init__(self, collaborator_name, transport_protocol, data_directory_path=None, workspace_path=None, container_id=None): """ Initialize the Collaborator class Args: @@ -31,12 +31,14 @@ def __init__(self, collaborator_name=None, data_directory_path=None, workspace_p data_directory_path (str): Data directory path workspace_path (str): Workspace path container_id (str): Container ID + transport_protocol (str): Transport protocol (default: "gRPC") """ self.name = collaborator_name self.collaborator_name = collaborator_name self.data_directory_path = data_directory_path self.workspace_path = workspace_path self.container_id = container_id + self.transport_protocol = transport_protocol self.res_file = None # Result file to track the logs self.start_process = None # Process associated with the aggregator start command @@ -50,13 +52,13 @@ def generate_sign_request(self): log.info(f"Generating a sign request for {self.collaborator_name}") cmd = f"fx collaborator generate-cert-request -n {self.collaborator_name}" error_msg = "Failed to generate the sign request" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, error_msg=error_msg, container_id=self.container_id, workspace_path=self.workspace_path, ) - fh.verify_cmd_output(output, return_code, error, error_msg, f"Generated a sign request for {self.collaborator_name}") + helper.verify_cmd_output(output, return_code, error, error_msg, f"Generated a sign request for {self.collaborator_name}") except Exception as e: log.error(f"{error_msg}: {e}") @@ -72,13 +74,13 @@ def create_collaborator(self): try: cmd = f"fx collaborator create -n {self.collaborator_name} -d {self.data_directory_path}" error_msg = f"Failed to create {self.collaborator_name}" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, error_msg=error_msg, container_id=self.container_id, workspace_path=self.workspace_path, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, error_msg, f"Created {self.collaborator_name} with the data directory {self.data_directory_path}" ) @@ -100,14 +102,14 @@ def import_pki(self, zip_name, with_docker=False): try: cmd = f"fx collaborator certify --import {zip_name}" error_msg = f"Failed to import and certify the CSR for {self.collaborator_name}" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, error_msg=error_msg, container_id=self.container_id, workspace_path=self.workspace_path if not with_docker else "", with_docker=with_docker, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, error_msg, f"Successfully imported and certified the CSR for {self.collaborator_name} with zip {zip_name}" ) @@ -144,7 +146,7 @@ def start(self): cmd=command, work_dir=self.workspace_path, redirect_to_file=bg_file, - check_sleep=60, + check_sleep=30, env=env ) @@ -182,13 +184,13 @@ def import_workspace(self): # Assumption - workspace.zip is present in the collaborator workspace cmd = f"fx workspace import --archive {self.workspace_path}/workspace.zip" error_msg = "Failed to import the workspace" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, error_msg=error_msg, container_id=self.container_id, workspace_path=os.path.join(self.workspace_path, ".."), # Import the 
workspace to the parent directory ) - fh.verify_cmd_output(output, return_code, error, error_msg, f"Imported the workspace for {self.collaborator_name}") + helper.verify_cmd_output(output, return_code, error, error_msg, f"Imported the workspace for {self.collaborator_name}") except Exception as e: log.error(f"{error_msg}: {e}") @@ -236,7 +238,7 @@ def ping_aggregator(self): cmd=command, work_dir=self.workspace_path, redirect_to_file=bg_file, - check_sleep=60, + check_sleep=30, env=env ) log.info( @@ -246,3 +248,26 @@ def ping_aggregator(self): log.error(f"{error_msg}: {e}") raise e return True + + def calculate_hash(self): + """ + Calculate the hash of the data directory and store in hash.txt file + Returns: + bool: True if successful, else False + """ + try: + log.info(f"Calculating hash for {self.collaborator_name}") + cmd = f"fx collaborator calchash --data_path {self.data_directory_path}" + error_msg = "Failed to calculate hash" + return_code, output, error = helper.run_command( + cmd, + error_msg=error_msg, + container_id=self.container_id, + workspace_path=self.workspace_path, + ) + helper.verify_cmd_output(output, return_code, error, error_msg, f"Calculated hash for {self.collaborator_name}") + + except Exception as e: + log.error(f"{error_msg}: {e}") + raise e + return True diff --git a/tests/end_to_end/models/model_owner.py b/tests/end_to_end/models/model_owner.py index 3d2c8d0f9f..a4fa9049ba 100644 --- a/tests/end_to_end/models/model_owner.py +++ b/tests/end_to_end/models/model_owner.py @@ -5,9 +5,9 @@ import yaml import logging -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.exceptions as ex -import tests.end_to_end.utils.federation_helper as fh +import tests.end_to_end.utils.helper as helper import tests.end_to_end.utils.ssh_helper as ssh log = logging.getLogger(__name__) @@ -39,8 +39,8 @@ def __init__(self, model_name, log_memory_usage, container_id=None, workspace_pa self.aggregator = None self.collaborators = [] self.workspace_path = workspace_path - self.num_collaborators = constants.NUM_COLLABORATORS - self.rounds_to_train = constants.NUM_ROUNDS + self.num_collaborators = defaults.NUM_COLLABORATORS + self.rounds_to_train = defaults.NUM_ROUNDS self.log_memory_usage = log_memory_usage self.container_id = container_id @@ -54,13 +54,13 @@ def create_workspace(self): ws_path = self.workspace_path - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( f"fx workspace create --prefix {ws_path} --template {self.model_name}", workspace_path="", # No workspace path required for this command error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, @@ -68,7 +68,7 @@ def create_workspace(self): raise_exception=True ) - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( "pip install -r requirements.txt", workspace_path=ws_path, error_msg="Failed to install the requirements", @@ -109,14 +109,14 @@ def certify_collaborator(self, collaborator_name, zip_name): try: cmd = f"fx collaborator certify --request-pkg {zip_name} -s" error_msg = f"Failed to sign the CSR {zip_name}" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, 
@@ -163,8 +163,8 @@ def modify_plan(self, param_config, plan_path): data["network"]["settings"]["require_client_auth"] = param_config.require_client_auth data["network"]["settings"]["use_tls"] = param_config.use_tls - if param_config.tr_rest_api: - data["task_runner"]["settings"]["transport_protocol"] = "rest" + if param_config.tr_rest_protocol: + data["network"]["settings"]["transport_protocol"] = defaults.TransportProtocol.REST.value if param_config.secure_agg: data["aggregator"]["settings"]["secure_aggregation"] = True with open(plan_file, "w+") as write_file: @@ -207,13 +207,13 @@ def initialize_plan(self, agg_domain_name, extra_args=""): log.info("Initializing the plan. It will take some time to complete..") cmd = f"fx plan initialize -a {agg_domain_name} {extra_args}" error_msg="Failed to initialize the plan" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, @@ -234,13 +234,13 @@ def certify_workspace(self): log.info("Certifying the workspace..") cmd = f"fx workspace certify" error_msg = "Failed to certify the workspace" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output( + helper.verify_cmd_output( output, return_code, error, @@ -259,13 +259,13 @@ def dockerize_workspace(self, image_name): try: cmd = f"fx workspace dockerize --base-image {image_name} --save" error_msg = "Failed to dockerize the workspace" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output(output, return_code, error, error_msg, "Workspace dockerized successfully") + helper.verify_cmd_output(output, return_code, error, error_msg, "Workspace dockerized successfully") except Exception as e: raise ex.WorkspaceDockerizationException(f"{error_msg}: {e}") @@ -329,13 +329,13 @@ def certify_aggregator(self, agg_domain_name): try: cmd = f"fx aggregator certify --silent --fqdn {agg_domain_name}" error_msg = "Failed to certify the aggregator request" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output(output, return_code, error, error_msg, "CA signed the request from aggregator") + helper.verify_cmd_output(output, return_code, error, error_msg, "CA signed the request from aggregator") except Exception as e: raise ex.AggregatorCertificationException(f"{error_msg}: {e}") @@ -347,13 +347,13 @@ def export_workspace(self): try: cmd = "fx workspace export" error_msg = "Failed to export the workspace" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=self.workspace_path, error_msg=error_msg, container_id=self.container_id, ) - fh.verify_cmd_output(output, return_code, error, error_msg, "Workspace exported successfully") + helper.verify_cmd_output(output, return_code, error, error_msg, "Workspace exported successfully") except Exception as e: raise ex.WorkspaceExportException(f"{error_msg}: {e}") diff --git a/tests/end_to_end/models/s3_bucket.py 
b/tests/end_to_end/models/s3_bucket.py new file mode 100644 index 0000000000..8f9c892a6b --- /dev/null +++ b/tests/end_to_end/models/s3_bucket.py @@ -0,0 +1,728 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import time +import signal +import socket +import shutil +import atexit +import boto3 +import logging +from botocore.client import Config +from botocore.exceptions import ClientError +import fnmatch +from pathlib import Path + +import tests.end_to_end.utils.defaults as defaults + +log = logging.getLogger(__name__) + + +class MinioServer(): + """ + A class to manage MinIO server operations. + This class provides methods to start, stop, and check the status of a MinIO server. + """ + def __init__( + self, + access_key=defaults.MINIO_ROOT_USER, + secret_key=defaults.MINIO_ROOT_PASSWORD, + minio_url=defaults.MINIO_URL, + minio_console_url=defaults.MINIO_CONSOLE_URL, + ): + """ + Initialize MinIO server with connection details. + + Args: + access_key: MinIO access key (default: from instance) + secret_key: MinIO secret key (default: from instance) + minio_url: MinIO server URL (default: from instance) + minio_console_url: MinIO console URL (default: from instance) + """ + self.access_key = access_key + self.secret_key = secret_key + self.minio_url = minio_url.split("://")[-1] + self.minio_console_url = minio_console_url.split("://")[-1] + + def is_minio_server_running(self, port=9000): + """ + Check if a MinIO server is running on the specified host and port. + + Args: + port: Port number (default: 9000) + + Returns: + bool: True if MinIO server is running, False otherwise + """ + try: + check_cmd = ['lsof', '-i', f':{port}', '-t'] + output = subprocess.check_output(check_cmd, universal_newlines=True).strip() + if output: + pids = [int(pid) for pid in output.split()] + log.info(f"Port {port} is in use (lsof check), PID(s): {pids}") + return pids + except Exception: + pass + return None + + def start_minio_server(self, data_dir): + """ + Start a MinIO server as a subprocess. + + Args: + data_dir: Directory to store data + + Returns: + subprocess.Popen: The process object for the MinIO server + """ + # Use instance values if not provided + + # Parse address to get host and port + try: + host, port = self.minio_url.split(':') + port = int(port) + except ValueError: + host = 'localhost' + port = 9001 + + # Check if MinIO server is already running + running = self.is_minio_server_running(port) + if running: + log.info("MinIO server already running. Cleaning up for fresh start.") + + if isinstance(running, list): + self._kill_processes(running) + else: + log.warning("MinIO server running but PID not found. Please check manually.") + + # Wait for port to be released + if not self._wait_for_port_release(port, host): + log.error("Port is still in use. Cannot start MinIO server.") + return None + + # Throw error if data_dir is not provided + if data_dir is None: + log.error("Data directory is required to start MinIO server.") + return None + + # Create data directory if it doesn't exist + os.makedirs(data_dir, exist_ok=True) + + # Check if minio is installed + minio_path = shutil.which("minio") + if minio_path is None: + log.error("MinIO server not found. 
Please install MinIO first.") + log.warning("You can download it from: https://min.io/download") + return None + + # Set environment variables for the current process as well as the subprocess + # This is important for MinIO to pick up the access and secret keys + # and for the subprocess to inherit them + env = os.environ.copy() + env["MINIO_ROOT_USER"] = os.environ["MINIO_ROOT_USER"] = self.access_key + env["MINIO_ROOT_PASSWORD"] = os.environ["MINIO_ROOT_PASSWORD"] = self.secret_key + + # Start MinIO server + cmd = [ + minio_path, + "server", + data_dir, + "--address", + self.minio_url, + "--console-address", + self.minio_console_url, + ] + log.info( + "Starting MinIO server with below configurations:" + f"\n - Data Directory: {data_dir}" + f"\n - Address: {self.minio_url}" + f"\n - Console Address: {self.minio_console_url}" + ) + + # Start the process + process = subprocess.Popen( + cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + + # Register a function to stop the server at exit + def stop_server(): + if process.poll() is None: # If process is still running + log.info("Stopping MinIO server...") + process.send_signal(signal.SIGTERM) + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + atexit.register(stop_server) + + # Wait for server to start + time.sleep(2) + + # Check if server started successfully + if process.poll() is not None: + # Process exited already + out, err = process.communicate() + log.error("Failed to start MinIO server:") + log.info(f"STDOUT: {out}") + log.error(f"STDERR: {err}") + return None + + log.info("MinIO server started successfully.") + return process + + def _kill_processes(self, pids): + """Kill processes by PID (SIGTERM, then SIGKILL if needed).""" + for pid in pids: + try: + os.kill(pid, signal.SIGTERM) + log.info(f"Killed MinIO process with PID {pid} (SIGTERM)") + time.sleep(1) + # Check if process is still alive + try: + os.kill(pid, 0) + # Still alive, force kill + os.kill(pid, signal.SIGKILL) + log.info(f"Force killed MinIO process with PID {pid} (SIGKILL)") + except OSError: + # Process is gone + pass + except Exception as e: + log.warning(f"Could not kill PID {pid}: {e}") + time.sleep(2) # Give time for processes to terminate + + def _wait_for_port_release(self, port, host="127.0.0.1", timeout=10): + """Wait until the port is free, or timeout (seconds) is reached.""" + waited = 0 + while waited < timeout: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex((host, port)) != 0: + return True # Port is free + log.info(f"Waiting for port {port} to be released...") + time.sleep(1) + waited += 1 + log.error(f"Port {port} is still in use after waiting {timeout} seconds.") + return False + + +class S3Bucket(): + """ + A class to manage S3 bucket operations using boto3. + This class provides methods to create, delete, upload, download, + and list objects in S3 buckets, as well as manage MinIO server. + """ + + def __init__( + self, + endpoint_url=defaults.MINIO_URL, + access_key=defaults.MINIO_ROOT_USER, + secret_key=defaults.MINIO_ROOT_PASSWORD, + region=None, + ): + """ + Initialize S3Helper with connection details. 
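+
+        A minimal usage sketch (illustrative only; the bucket name and data path below are placeholders):
+
+            s3 = S3Bucket()
+            s3.create_bucket("col1-bucket1")
+            s3.upload_directory("/tmp/col1_data", "col1-bucket1", prefix="1")
+            s3.list_objects("col1-bucket1")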
+ + Args: + endpoint_url: The S3 endpoint URL (default: http://localhost:9000 for MinIO) + access_key: The access key (if None, uses MINIO_ROOT_USER env variable) + secret_key: The secret key (if None, uses MINIO_ROOT_PASSWORD env variable) + region: The region name (default: None, required by boto3 but not used by MinIO or on local server) + """ + self.endpoint_url = endpoint_url + self.access_key = access_key or os.environ.get("MINIO_ROOT_USER", "minioadmin") + self.secret_key = secret_key or os.environ.get( + "MINIO_ROOT_PASSWORD", "minioadmin" + ) + self.region = region + + # Extract host and port from endpoint_url + url_parts = self.endpoint_url.split('://')[-1].split(':') + self.minio_host = url_parts[0] + self.minio_port = int(url_parts[1]) if len(url_parts) > 1 else 9000 + + # Set default URLs + self.minio_url = f"{self.minio_host}:{self.minio_port}" + self.minio_console_url = f"{self.minio_host}:{self.minio_port + 1}" + + # Initialize S3 client + self.client = boto3.client( + "s3", + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version="s3v4"), + region_name=self.region, + ) + + def create_bucket(self, bucket_name): + """ + Create a new bucket if it doesn't exist. + + Args: + bucket_name: Name of the bucket to create + + Returns: + bool: True if bucket was created or already exists, False on error + """ + try: + # Check if bucket already exists + self.client.head_bucket(Bucket=bucket_name) + log.info(f"Bucket {bucket_name} already exists. Deleting all objects in the bucket.") + self.delete_all_objects(bucket_name) + return True + except ClientError as e: + # If bucket doesn't exist, create it + if e.response["Error"]["Code"] == "404": + try: + self.client.create_bucket(Bucket=bucket_name) + log.info(f"Bucket {bucket_name} created successfully.") + return True + except ClientError as create_error: + log.error(f"Error creating bucket: {create_error}") + return False + else: + log.error(f"Error checking bucket: {e}") + return False + + def delete_bucket(self, bucket_name, force=False): + """ + Delete a bucket. + + Args: + bucket_name: Name of the bucket to delete + force: If True, delete all objects in the bucket before deletion + + Returns: + bool: True if bucket was deleted, False on error + """ + try: + if force: + # Delete all objects in the bucket first + self.delete_all_objects(bucket_name) + + # Delete the bucket + self.client.delete_bucket(Bucket=bucket_name) + log.info(f"Bucket {bucket_name} deleted successfully.") + return True + except ClientError as e: + log.error(f"Error deleting bucket {bucket_name}: {e}") + return False + + def list_buckets(self): + """ + List all buckets. + + Returns: + list: List of bucket names + """ + try: + response = self.client.list_buckets() + buckets = [bucket["Name"] for bucket in response.get("Buckets", [])] + log.info(f"Found {len(buckets)} buckets: {', '.join(buckets)}") + return buckets + except ClientError as e: + log.error(f"Error listing buckets: {e}") + return [] + + def upload_file(self, file_path, bucket_name, object_name=None): + """ + Upload a file to a bucket. 
+ + Args: + file_path: Path to the file to upload + bucket_name: Name of the bucket + object_name: S3 object name (if None, uses file_path basename) + + Returns: + bool: True if file was uploaded, False on error + """ + # If object_name was not specified, use file_path basename + if object_name is None: + object_name = Path(file_path).name + + try: + self.client.upload_file(file_path, bucket_name, object_name) + log.debug(f"File {file_path} uploaded to {bucket_name}/{object_name}") + return True + except ClientError as e: + log.error(f"Error uploading file {file_path}: {e}") + return False + + def upload_directory(self, dir_path, bucket_name, prefix=""): + """ + Upload all files from a directory to a bucket. + + Args: + dir_path: Path to the directory to upload + bucket_name: Name of the bucket + prefix: Prefix to add to object names + + Returns: + int: Number of files uploaded + """ + dir_path = Path(dir_path) + count = 0 + + if not dir_path.is_dir(): + log.error(f"Error: {dir_path} is not a directory") + return count + + for root, _, files in os.walk(dir_path): + for file in files: + file_path = Path(root) / file + # Calculate relative path from dir_path + rel_path = file_path.relative_to(dir_path) + # Create object name with prefix + if prefix: + object_name = f"{prefix}/{rel_path}" + else: + object_name = str(rel_path) + + if self.upload_file(str(file_path), bucket_name, object_name): + count += 1 + + log.info(f"Uploaded {count} files to {bucket_name} from {dir_path}") + return count + + def download_file(self, bucket_name, object_name, file_path=None): + """ + Download a file from a bucket. + + Args: + bucket_name: Name of the bucket + object_name: S3 object name + file_path: Local path to save the file (if None, uses object_name basename) + + Returns: + bool: True if file was downloaded, False on error + """ + # If file_path was not specified, use object_name basename + if file_path is None: + file_path = Path(object_name).name + + try: + # Create directory if it doesn't exist + os.makedirs(Path(file_path).parent, exist_ok=True) + + self.client.download_file(bucket_name, object_name, file_path) + log.info(f"Downloaded {bucket_name}/{object_name} to {file_path}") + return True + except ClientError as e: + log.error(f"Error downloading {bucket_name}/{object_name}: {e}") + return False + + def download_directory(self, bucket_name, prefix, local_dir=None): + """ + Download all files with a prefix from a bucket. + + Args: + bucket_name: Name of the bucket + prefix: Prefix of objects to download + local_dir: Local directory to save files (if None, uses current dir) + + Returns: + int: Number of files downloaded + """ + if local_dir is None: + local_dir = "." 
+ + local_dir = Path(local_dir) + os.makedirs(local_dir, exist_ok=True) + + count = 0 + try: + # List all objects with the prefix + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) + + for page in pages: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + object_name = obj["Key"] + + # Calculate relative path from prefix + if prefix and object_name.startswith(prefix): + rel_path = object_name[len(prefix) :] + if rel_path.startswith("/"): + rel_path = rel_path[1:] + else: + rel_path = object_name + + # Create local file path + file_path = local_dir / rel_path + + if self.download_file(bucket_name, object_name, str(file_path)): + count += 1 + + log.info( + f"Downloaded {count} files from {bucket_name}/{prefix} to {local_dir}" + ) + return count + except ClientError as e: + log.error(f"Error downloading from {bucket_name}/{prefix}: {e}") + return count + + def list_objects(self, bucket_name, prefix="", recursive=True, max_items=None, print=True): + """ + List objects in a bucket with an optional prefix. + + Args: + bucket_name: Name of the bucket + prefix: Prefix filter for objects + recursive: If False, emulates directory listing with delimiters + max_items: Maximum number of items to return + print: If True, prints the list of objects + + Returns: + list: List of object keys + """ + try: + paginator = self.client.get_paginator("list_objects_v2") + + # Set up pagination parameters + pagination_config = {} + if max_items: + pagination_config["MaxItems"] = max_items + + # Set up operation parameters + operation_params = {"Bucket": bucket_name, "Prefix": prefix} + + # If not recursive, use delimiter to emulate directory listing + if not recursive: + operation_params["Delimiter"] = "/" + + # Get pages of objects + pages = paginator.paginate( + **operation_params, PaginationConfig=pagination_config + ) + + objects = [] + + for page in pages: + # Add objects + if "Contents" in page: + for obj in page["Contents"]: + objects.append(obj["Key"]) + + # Add common prefixes (folders) if not recursive + if not recursive and "CommonPrefixes" in page: + for prefix in page["CommonPrefixes"]: + objects.append(prefix["Prefix"]) + + if print: + log.info(f"Found {len(objects)} objects in {bucket_name}/{prefix}") + for obj in objects: + log.info(f"- {obj}") + + return objects + except ClientError as e: + log.error(f"Error listing objects in {bucket_name}/{prefix}: {e}") + return [] + + def delete_object(self, bucket_name, object_name): + """ + Delete an object from a bucket. + + Args: + bucket_name: Name of the bucket + object_name: S3 object name to delete + + Returns: + bool: True if object was deleted, False on error + """ + try: + self.client.delete_object(Bucket=bucket_name, Key=object_name) + log.info(f"Deleted {bucket_name}/{object_name}") + return True + except ClientError as e: + log.error(f"Error deleting {bucket_name}/{object_name}: {e}") + return False + + def delete_objects(self, bucket_name, object_names): + """ + Delete multiple objects from a bucket. 
+ + Args: + bucket_name: Name of the bucket + object_names: List of object names to delete + + Returns: + int: Number of objects deleted + """ + if not object_names: + return 0 + + try: + # Create delete request + objects = [{"Key": obj} for obj in object_names] + response = self.client.delete_objects( + Bucket=bucket_name, Delete={"Objects": objects} + ) + + deleted = len(response.get("Deleted", [])) + errors = len(response.get("Errors", [])) + + log.info(f"Deleted {deleted} objects from {bucket_name}") + if errors > 0: + log.error(f"Failed to delete {errors} objects") + + return deleted + except ClientError as e: + log.error(f"Error deleting objects from {bucket_name}: {e}") + return 0 + + def delete_prefix(self, bucket_name, prefix): + """ + Delete all objects with a specific prefix (like a folder). + + Args: + bucket_name: Name of the bucket + prefix: Prefix of objects to delete + + Returns: + int: Number of objects deleted + """ + try: + # List all objects with the prefix + objects = self.list_objects(bucket_name, prefix, print=False) + + # Delete the objects in batches + count = 0 + batch_size = 1000 # S3 limits delete_objects to 1000 at a time + + for i in range(0, len(objects), batch_size): + batch = objects[i : i + batch_size] + count += self.delete_objects(bucket_name, batch) + + log.info(f"Deleted {count} objects from {bucket_name}/{prefix}") + return count + except ClientError as e: + log.error(f"Error deleting prefix {bucket_name}/{prefix}: {e}") + return 0 + + def delete_all_objects(self, bucket_name): + """ + Delete all objects in a bucket. + + Args: + bucket_name: Name of the bucket + + Returns: + int: Number of objects deleted + """ + return self.delete_prefix(bucket_name, "") + + def split_directory_to_buckets( + self, source_path, bucket_name, folder_names, split_folders=None + ): + """ + Split folders from a directory into separate folders in a bucket. 
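+
+        A possible call (illustrative only; the source path, bucket name and folder names are placeholders):
+
+            s3.split_directory_to_buckets(
+                source_path="/tmp/col1_data",
+                bucket_name="col1-bucket1",
+                folder_names=["0", "1", "2", "3"],
+                split_folders={"1": ["0", "1"], "2": ["2", "3"]},
+            )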
+ + Args: + source_path: Path to the directory containing folders to split + bucket_name: Name of the bucket to upload to + folder_names: List of folder names to upload + split_folders: Dictionary mapping folders to destination prefixes, + if None, splits into equal groups + + Returns: + dict: Mapping of destination prefixes to lists of folders uploaded + """ + source_path = Path(source_path) + if not source_path.is_dir(): + log.error(f"Error: {source_path} is not a directory") + return {} + + # Ensure bucket exists + self.create_bucket(bucket_name) + + # Get folders in source directory that match requested folder names + folders = [] + for folder_name in folder_names: + folder_path = source_path / folder_name + if folder_path.is_dir(): + folders.append(folder_name) + else: + log.warning(f"Warning: {folder_path} is not a directory, skipping") + + # If split_folders is None, create equal groups + if split_folders is None: + half = len(folders) // 2 + split_folders = {"1": folders[:half], "2": folders[half:]} + + result = {} + + # Upload each group of folders to the specified prefix + for prefix, group_folders in split_folders.items(): + result[prefix] = [] + + for folder in group_folders: + if folder in folders: + folder_path = source_path / folder + # Upload the folder with the prefix + upload_prefix = f"{prefix}/{folder}" + count = self.upload_directory( + folder_path, bucket_name, upload_prefix + ) + if count > 0: + result[prefix].append(folder) + log.info(f"Uploaded {folder} to {bucket_name}/{upload_prefix}") + + return result + + def copy_object(self, source_bucket, source_key, dest_bucket, dest_key=None): + """ + Copy an object within or between buckets. + + Args: + source_bucket: Source bucket name + source_key: Source object key + dest_bucket: Destination bucket name + dest_key: Destination object key (if None, uses source_key) + + Returns: + bool: True if object was copied, False on error + """ + if dest_key is None: + dest_key = source_key + + try: + copy_source = {"Bucket": source_bucket, "Key": source_key} + + self.client.copy_object( + CopySource=copy_source, Bucket=dest_bucket, Key=dest_key + ) + + log.info(f"Copied {source_bucket}/{source_key} to {dest_bucket}/{dest_key}") + return True + except ClientError as e: + log.error(f"Error copying {source_bucket}/{source_key}: {e}") + return False + + def search_objects(self, bucket_name, pattern, prefix=""): + """ + Search for objects in a bucket using a glob pattern. + + Args: + bucket_name: Name of the bucket + pattern: Glob pattern to match object keys against + prefix: Optional prefix to limit search scope + + Returns: + list: List of matching object keys + """ + objects = self.list_objects(bucket_name, prefix) + matches = [obj for obj in objects if fnmatch.fnmatch(obj, pattern)] + + log.info( + f"Found {len(matches)} objects matching '{pattern}' in {bucket_name}/{prefix}" + ) + for obj in matches: + log.info(f"- {obj}") + + return matches diff --git a/tests/end_to_end/pytest.ini b/tests/end_to_end/pytest.ini index 2e8d4c9d69..f93782c49a 100644 --- a/tests/end_to_end/pytest.ini +++ b/tests/end_to_end/pytest.ini @@ -9,7 +9,12 @@ markers = task_runner_basic: mark a test as a task runner basic test. task_runner_dockerized_ws: mark a test as a task runner dockerized workspace test. task_runner_basic_gandlf: mark a test as a task runner basic for GanDLF test. + task_runner_connectivity: mark a test as a connectivity test. + task_runner_with_s3: mark a test as a task runner with S3 test. 
+ task_runner_with_azure_blob: mark a test as a task runner with Azure Blob test. + task_runner_with_all_ds: mark a test as a task runner with all data sources test. federated_runtime_301_watermarking: mark a test as a federated runtime 301 watermarking test. straggler_tests: mark a test as a straggler test. + task_runner_fed_analytics: mark a test as a task runner analytics test. asyncio_mode=auto asyncio_default_fixture_loop_scope="function" diff --git a/tests/end_to_end/test_suites/memory_logs_tests.py b/tests/end_to_end/test_suites/memory_logs_tests.py index 0bbaa2abb8..3025f0a2ed 100644 --- a/tests/end_to_end/test_suites/memory_logs_tests.py +++ b/tests/end_to_end/test_suites/memory_logs_tests.py @@ -6,7 +6,7 @@ import os from tests.end_to_end.utils.tr_common_fixtures import fx_federation_tr, fx_federation_tr_dws -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults from tests.end_to_end.utils import federation_helper as fed_helper, ssh_helper as ssh from tests.end_to_end.utils.generate_report import generate_memory_report, convert_to_json @@ -66,7 +66,7 @@ def _log_memory_usage(request, fed_obj): ), "Federation completion failed" # Verify the aggregator memory logs - aggregator_memory_usage_file = constants.AGG_MEM_USAGE_LOGFILE.format(fed_obj.workspace_path) + aggregator_memory_usage_file = defaults.AGG_MEM_USAGE_LOGFILE.format(fed_obj.workspace_path) assert os.path.exists( aggregator_memory_usage_file @@ -84,7 +84,7 @@ def _log_memory_usage(request, fed_obj): # check memory usage entries for each collaborator for collaborator in fed_obj.collaborators: - collaborator_memory_usage_file = constants.COL_MEM_USAGE_LOGFILE.format( + collaborator_memory_usage_file = defaults.COL_MEM_USAGE_LOGFILE.format( fed_obj.workspace_path, collaborator.name ) assert os.path.exists( diff --git a/tests/end_to_end/test_suites/sample_tests.py b/tests/end_to_end/test_suites/sample_tests.py index cea7add2ec..28f376f5eb 100644 --- a/tests/end_to_end/test_suites/sample_tests.py +++ b/tests/end_to_end/test_suites/sample_tests.py @@ -15,7 +15,7 @@ # ** IMPORTANT **: This is just an example on how to add a test with below pre-requisites. # Task Runner API Test function for federation run using sample_model # 1. Create OpenFL workspace, if not present for the model and add relevant dataset and its path in plan/data.yaml -# 2. Append the model name to ModelName enum in tests/end_to_end/utils/constants.py +# 2. Append the model name to ModelName enum in tests/end_to_end/utils/defaults.py # 3. a. Use fx_federation_tr fixture for task runner with bare metal or docker approach. # 3. b. Use fx_federation_tr_dws fixture for task runner with dockerized workspace approach. # 4. Fixture will contain - model_owner, aggregator, collaborators, workspace_path, local_bind_path diff --git a/tests/end_to_end/test_suites/task_runner_tests.py b/tests/end_to_end/test_suites/task_runner_tests.py index 83b51f7ca5..52fe7c00f6 100644 --- a/tests/end_to_end/test_suites/task_runner_tests.py +++ b/tests/end_to_end/test_suites/task_runner_tests.py @@ -57,7 +57,7 @@ def test_federation_via_dockerized_workspace(request, fx_federation_tr_dws): log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}") -@pytest.mark.task_runner_basic_connectivity +@pytest.mark.task_runner_connectivity def test_federation_connectivity(request, fx_federation_tr): """ Verify that the collaborator can ping the aggregator. If Ping successful, collaborator can start the training. 
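The new suites added below can be selected locally through the markers registered in pytest.ini above; a minimal sketch of assumed invocations (the analytics model name is a placeholder, torch/histology_s3 is taken from the verifiable-dataset test docstrings further down):

    python -m pytest -s tests/end_to_end/test_suites/tr_fed_analytics_tests.py -m task_runner_fed_analytics --model_name <federated_analytics_model>
    python -m pytest -s tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py -m task_runner_with_s3 --model_name torch/histology_s3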
diff --git a/tests/end_to_end/test_suites/tr_fed_analytics_tests.py b/tests/end_to_end/test_suites/tr_fed_analytics_tests.py new file mode 100644 index 0000000000..72dbc62af3 --- /dev/null +++ b/tests/end_to_end/test_suites/tr_fed_analytics_tests.py @@ -0,0 +1,70 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import logging +import os + + +from tests.end_to_end.utils.tr_common_fixtures import ( + fx_federation_tr, +) +from tests.end_to_end.utils import federation_helper as fed_helper +import json +import tests.end_to_end.utils.defaults as defaults + +log = logging.getLogger(__name__) + +@pytest.fixture(scope="function") +def set_num_rounds(request): + """ + Fixture to set the number of rounds to 1 and skip models that do not support federated analytics. + Args: + request (Fixture): Pytest fixture + """ + # Set the number of rounds to 1 + log.info("Setting number of rounds to 1 for analytics test") + request.config.num_rounds = 1 + if "federated_analytics" not in request.config.model_name: + pytest.skip( + f"Model name {request.config.model_name} is not supported for this test. " + "Please use a different model name." + ) + + +@pytest.mark.task_runner_fed_analytics +def test_federation_analytics(request, set_num_rounds, fx_federation_tr): + """ + Test federated analytics via the native task runner. + Args: + request (Fixture): Pytest fixture + fx_federation_tr (Fixture): Pytest fixture for native task runner + """ + # Start the federation + assert fed_helper.run_federation(fx_federation_tr) + + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion( + fx_federation_tr, + test_env=request.config.test_env, + num_rounds=request.config.num_rounds, + ), "Federation completion failed" + + # Verify that results get saved in save/result.json + result_path = os.path.join( + fx_federation_tr.aggregator.workspace_path, + "save", + "result.json" + ) + assert os.path.exists(result_path), f"Results file {result_path} does not exist" + + with open(result_path, "r") as f: + results = f.read() + try: + json.loads(results) + except json.JSONDecodeError as e: + log.warning("Results file is not valid JSON.
Raw content:\n%s", results) + raise e + + assert results, f"Results file {result_path} is empty" diff --git a/tests/end_to_end/test_suites/tr_security_testssl.py b/tests/end_to_end/test_suites/tr_security_testssl.py index e81b3fce87..f275d69825 100644 --- a/tests/end_to_end/test_suites/tr_security_testssl.py +++ b/tests/end_to_end/test_suites/tr_security_testssl.py @@ -11,7 +11,7 @@ fx_federation_tr, ) from tests.end_to_end.utils import federation_helper as fed_helper -from tests.end_to_end.utils import constants +from tests.end_to_end.utils import defaults log = logging.getLogger(__name__) @@ -28,7 +28,7 @@ def test_federation_via_native(request, fx_federation_tr): assert fed_helper.run_federation(fx_federation_tr) # Get aggregator address and port from plan.yaml - plan_dir = constants.AGG_PLAN_PATH.format(fx_federation_tr.local_bind_path) + plan_dir = defaults.AGG_PLAN_PATH.format(fx_federation_tr.local_bind_path) plan_file = os.path.join(plan_dir, "plan.yaml") aggreagtor_addr, aggregator_port = fed_helper.get_agg_addr_port(plan_file) diff --git a/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py b/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py new file mode 100644 index 0000000000..7d873d1373 --- /dev/null +++ b/tests/end_to_end/test_suites/tr_verifiable_dataset_tests.py @@ -0,0 +1,111 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import logging + +from tests.end_to_end.utils.tr_common_fixtures import \ +( + fx_federation_tr, fx_verifiable_dataset_with_s3, + fx_verifiable_dataset_with_azure_blob, + fx_verifiable_dataset_with_all_ds +) +from tests.end_to_end.utils import federation_helper as fed_helper + +log = logging.getLogger(__name__) + +# IMPORTANT +# Ensure to have minio and minio client installed for S3 and azurite for Azure Blob Storage tests. + +@pytest.mark.task_runner_with_s3 +def test_federation_with_s3_bucket(request, fx_verifiable_dataset_with_s3, fx_federation_tr): + """ + Test federation with S3 bucket. Model name - torch/histology_s3 + Steps: + 1. Start the minio server, create buckets for every collaborator. + 2. Download data using torch/histology dataloader and upload data to the buckets. + 3. Create a datasources.json file for each collaborator which will contain the S3 bucket details. + 4. Calculate hash for each collaborator's data (it generates hash.txt file under the data directory). + 5. Start the federation (internally the hash is verified as well). + 6. Verify the completion of the federation run. + 7. Verify the best aggregated score. + Args: + request (Fixture): Pytest fixture + fx_federation_tr (Fixture): Pytest fixture for native task runner + """ + # Start the federation + assert fed_helper.run_federation(fx_federation_tr) + + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion( + fx_federation_tr, + test_env=request.config.test_env, + num_rounds=request.config.num_rounds, + time_for_each_round=300, + ), "Federation completion failed" + + best_agg_score = fed_helper.get_best_agg_score(fx_federation_tr.aggregator.tensor_db_file) + log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}") + + +@pytest.mark.task_runner_with_azure_blob +def test_federation_with_azure_blob(request, fx_verifiable_dataset_with_azure_blob, fx_federation_tr): + """ + Test federation with Azure Blob Storage. Model name - torch/histology_azure_blob + Steps: + 1. Start azurite emulator, create containers for every collaborator. 
+ 2. Download data using torch/histology dataloader and upload data to the containers. + 3. Create a datasources.json file for each collaborator which will contain the azure blob container details. + 4. Calculate hash for each collaborator's data (it generates hash.txt file under the data directory). + 5. Start the federation (internally the hash is verified as well). + 6. Verify the completion of the federation run. + 7. Verify the best aggregated score. + Args: + request (Fixture): Pytest fixture + fx_federation_tr (Fixture): Pytest fixture for native task runner + """ + # Start the federation + assert fed_helper.run_federation(fx_federation_tr) + + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion( + fx_federation_tr, + test_env=request.config.test_env, + num_rounds=request.config.num_rounds, + time_for_each_round=300, + ), "Federation completion failed" + + best_agg_score = fed_helper.get_best_agg_score(fx_federation_tr.aggregator.tensor_db_file) + log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}") + + +@pytest.mark.task_runner_with_all_ds +def test_federation_with_all(request, fx_verifiable_dataset_with_all_ds, fx_federation_tr): + """ + Test federation with all combinations of S3, Azure Blob Storage and local data. Model name - torch/histology_all + Steps: + 1. Start the minio server, create buckets for every collaborator. + 2. Start azurite emulator, create containers for every collaborator. + 3. Download data using torch/histology dataloader and upload data (without overlapping) to the buckets and containers. + 4. Create a datasources.json file for each collaborator which will contain the S3 bucket, azure container and local datasources. + 5. Calculate hash for each collaborator's data (it generates hash.txt file under the data directory). + 6. Start the federation (internally the hash is verified as well). + 7. Verify the completion of the federation run. + 8. Verify the best aggregated score. 
+ Args: + request (Fixture): Pytest fixture + fx_federation_tr (Fixture): Pytest fixture for native task runner + """ + # Start the federation + assert fed_helper.run_federation(fx_federation_tr) + + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion( + fx_federation_tr, + test_env=request.config.test_env, + num_rounds=request.config.num_rounds, + time_for_each_round=300, + ), "Federation completion failed" + + best_agg_score = fed_helper.get_best_agg_score(fx_federation_tr.aggregator.tensor_db_file) + log.info(f"Model best aggregated score post {request.config.num_rounds} is {best_agg_score}") diff --git a/tests/end_to_end/test_suites/wf_federated_runtime_tests.py b/tests/end_to_end/test_suites/wf_federated_runtime_tests.py index 2e798ea0a4..d7c3613b26 100644 --- a/tests/end_to_end/test_suites/wf_federated_runtime_tests.py +++ b/tests/end_to_end/test_suites/wf_federated_runtime_tests.py @@ -8,6 +8,8 @@ import concurrent.futures import tests.end_to_end.utils.federation_helper as fh +import tests.end_to_end.utils.helper as helper +import tests.end_to_end.utils.wf_helper as wf_helper log = logging.getLogger(__name__) @@ -62,7 +64,7 @@ def test_federated_runtime_301_watermarking(request): # This might not be true for all notebooks, thus keeping it as a separate step os.chdir(nb_workspace_path) - assert fh.run_notebook( + assert wf_helper.run_notebook( notebook_path=notebook_path, output_notebook_path=result_path + "/" + "MNIST_Watermarking_output.ipynb" ), "Notebook run failed" @@ -126,7 +128,7 @@ def test_federated_runtime_secure_aggregation(request): # This might not be true for all notebooks, thus keeping it as a separate step os.chdir(nb_workspace_path) - assert fh.run_notebook( + assert wf_helper.run_notebook( notebook_path=notebook_path, output_notebook_path=result_path + "/" + "MNIST_SecAgg_output.ipynb" ), "Notebook run failed" @@ -150,7 +152,7 @@ def activate_experimental_feature(workspace_path): # Activate the experimental feature cmd = f"fx experimental activate" error_msg = "Failed to activate the experimental feature" - return_code, output, error = fh.run_command( + return_code, output, error = helper.run_command( cmd, workspace_path=workspace_path, error_msg=error_msg, diff --git a/tests/end_to_end/test_suites/wf_local_func_tests.py b/tests/end_to_end/test_suites/wf_local_func_tests.py index 41b02af17d..d8002a6b26 100644 --- a/tests/end_to_end/test_suites/wf_local_func_tests.py +++ b/tests/end_to_end/test_suites/wf_local_func_tests.py @@ -8,21 +8,32 @@ import random from metaflow import Step -from tests.end_to_end.utils.wf_common_fixtures import fx_local_federated_workflow, fx_local_federated_workflow_prvt_attr +from tests.end_to_end.utils.wf_common_fixtures import ( + fx_local_federated_workflow, + fx_local_federated_workflow_prvt_attr, + fx_local_fed_wf_unserializable_pvt_attrs, +) + from tests.end_to_end.workflow.exclude_flow import TestFlowExclude from tests.end_to_end.workflow.include_exclude_flow import TestFlowIncludeExclude from tests.end_to_end.workflow.include_flow import TestFlowInclude from tests.end_to_end.workflow.internal_loop import TestFlowInternalLoop from tests.end_to_end.workflow.reference_flow import TestFlowReference from tests.end_to_end.workflow.subset_flow import TestFlowSubsetCollaborators -from tests.end_to_end.workflow.private_attr_wo_callable import TestFlowPrivateAttributesWoCallable +from tests.end_to_end.workflow.private_attr_wo_callable import ( + TestFlowPrivateAttributesWoCallable, +) from 
tests.end_to_end.workflow.private_attributes_flow import TestFlowPrivateAttributes from tests.end_to_end.workflow.private_attr_both import TestFlowPrivateAttributesBoth +from tests.end_to_end.workflow.unserializable_private_attr import ( + TestFlowUnserializablePrivateAttributes, +) from tests.end_to_end.utils import wf_helper as wf_helper log = logging.getLogger(__name__) + def test_exclude_flow(request, fx_local_federated_workflow): """ Test if variable is excluded, variables not show in next step @@ -73,7 +84,9 @@ def test_internal_loop(request, fx_local_federated_workflow): model = None optimizer = None - flflow = TestFlowInternalLoop(model, optimizer, request.config.num_rounds, checkpoint=True) + flflow = TestFlowInternalLoop( + model, optimizer, request.config.num_rounds, checkpoint=True + ) flflow.runtime = fx_local_federated_workflow.runtime flflow.run() @@ -87,25 +100,37 @@ def test_internal_loop(request, fx_local_federated_workflow): "end", ] - steps_present_in_cli, missing_steps_in_cli, extra_steps_in_cli = wf_helper.validate_flow( - flflow, expected_flow_steps - ) - - assert len(steps_present_in_cli) == len(expected_flow_steps), "Number of steps fetched from Datastore through CLI do not match the Expected steps provided" - assert len(missing_steps_in_cli) == 0, f"Following steps missing from Datastore: {missing_steps_in_cli}" - assert len(extra_steps_in_cli) == 0, f"Following steps are extra in Datastore: {extra_steps_in_cli}" + steps_present_in_cli, missing_steps_in_cli, extra_steps_in_cli = ( + wf_helper.validate_flow(flflow, expected_flow_steps) + ) + + assert len(steps_present_in_cli) == len( + expected_flow_steps + ), "Number of steps fetched from Datastore through CLI do not match the Expected steps provided" + assert ( + len(missing_steps_in_cli) == 0 + ), f"Following steps missing from Datastore: {missing_steps_in_cli}" + assert ( + len(extra_steps_in_cli) == 0 + ), f"Following steps are extra in Datastore: {extra_steps_in_cli}" assert flflow.end_count == 1, "End function called more than one time" - log.info("\n Summary of internal flow testing \n" - "No issues found and below are the tests that ran successfully\n" - "1. Number of training completed is equal to training rounds\n" - "2. CLI steps and Expected steps are matching\n" - "3. Number of tasks are aligned with number of rounds and number of collaborators\n" - "4. End function executed one time") + log.info( + "\n Summary of internal flow testing \n" + "No issues found and below are the tests that ran successfully\n" + "1. Number of training completed is equal to training rounds\n" + "2. CLI steps and Expected steps are matching\n" + "3. Number of tasks are aligned with number of rounds and number of collaborators\n" + "4. 
End function executed one time" + ) log.info("Successfully ended test_internal_loop") -@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborator_private_attr_index", "int", None )], indirect=True) +@pytest.mark.parametrize( + "fx_local_federated_workflow", + [("init_collaborator_private_attr_index", "int", None)], + indirect=True, +) def test_reference_flow(request, fx_local_federated_workflow): """ Test reference variables matched through out the flow @@ -118,7 +143,12 @@ def test_reference_flow(request, fx_local_federated_workflow): flflow.run() log.info("Successfully ended test_reference_flow") -@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborator_private_attr_name", "str", None )], indirect=True) + +@pytest.mark.parametrize( + "fx_local_federated_workflow", + [("init_collaborator_private_attr_name", "str", None)], + indirect=True, +) def test_subset_collaborators(request, fx_local_federated_workflow): """ Test the subset of collaborators in a federated workflow. @@ -158,16 +188,16 @@ def test_subset_collaborators(request, fx_local_federated_workflow): ) assert len(list(step)) == len(subset_collaborators), ( - f"...Flow only ran for {len(list(step))} " - + f"instead of the {len(subset_collaborators)} expected " - + f"collaborators- Testcase Failed." - ) + f"...Flow only ran for {len(list(step))} " + + f"instead of the {len(subset_collaborators)} expected " + + f"collaborators- Testcase Failed." + ) log.info( f"Found {len(list(step))} tasks for each of the " + f"{len(subset_collaborators)} collaborators" ) - log.info(f'subset_collaborators = {subset_collaborators}') - log.info(f'collaborators_ran = {collaborators_ran}') + log.info(f"subset_collaborators = {subset_collaborators}") + log.info(f"collaborators_ran = {collaborators_ran}") for collaborator_name in subset_collaborators: assert collaborator_name in collaborators_ran, ( f"...Flow did not execute for " @@ -177,7 +207,8 @@ def test_subset_collaborators(request, fx_local_federated_workflow): log.info( f"Testing FederatedFlow - Ending test for validating " - + f"the subset of collaborators.") + + f"the subset of collaborators." 
+ ) log.info("Successfully ended test_subset_collaborators") @@ -194,7 +225,11 @@ def test_private_attr_wo_callable(request, fx_local_federated_workflow_prvt_attr log.info("Successfully ended test_private_attr_wo_callable") -@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np" )], indirect=True) +@pytest.mark.parametrize( + "fx_local_federated_workflow", + [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np")], + indirect=True, +) def test_private_attributes(request, fx_local_federated_workflow): """ Set private attribute through callable function @@ -208,7 +243,11 @@ def test_private_attributes(request, fx_local_federated_workflow): log.info("Successfully ended test_private_attributes") -@pytest.mark.parametrize("fx_local_federated_workflow_prvt_attr", [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np" )], indirect=True) +@pytest.mark.parametrize( + "fx_local_federated_workflow_prvt_attr", + [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np")], + indirect=True, +) def test_private_attr_both(request, fx_local_federated_workflow_prvt_attr): """ Set private attribute through callable function and direct assignment @@ -220,3 +259,25 @@ def test_private_attr_both(request, fx_local_federated_workflow_prvt_attr): log.info(f"Starting round {i}...") flflow.run() log.info("Successfully ended test_private_attr_both") + + +@pytest.mark.parametrize( + "fx_local_fed_wf_unserializable_pvt_attrs", + [ + ("callable_to_init_collab_unserializable_pvt_attrs", + "int", + "callable_to_init_agg_unserializable_pvt_attrs") + ], + indirect=True, +) +def test_unserializable_private_attr( + request, fx_local_fed_wf_unserializable_pvt_attrs +): + """ + Validate unserializable objects are accessible as private attributes + """ + log.info("Starting Test for unserializable private attributes") + flflow = TestFlowUnserializablePrivateAttributes(rounds=request.config.num_rounds, checkpoint=False) + flflow.runtime = fx_local_fed_wf_unserializable_pvt_attrs.runtime + flflow.run() + log.info("Successfully ended Test for unserializable private attributes") diff --git a/tests/end_to_end/utils/conftest_helper.py b/tests/end_to_end/utils/conftest_helper.py index 0854b09268..8d0a9e8f12 100644 --- a/tests/end_to_end/utils/conftest_helper.py +++ b/tests/end_to_end/utils/conftest_helper.py @@ -31,7 +31,7 @@ def parse_arguments(): parser.add_argument("--num_rounds", type=int, default=5, help="Number of rounds to train. Default is 5") parser.add_argument("--model_name", type=str, help="Model name. Not required for Workflow APIs") parser.add_argument("--workflow_backend", type=str, help="Workflow backend, e.g - ray") - parser.add_argument("--tr_rest_api", action="store_true", help="Enable rest api protocol in task runner. If not set, grpc is used") + parser.add_argument("--tr_rest_protocol", action="store_true", help="Enable rest protocol in task runner. If not set, gRPC is used") parser.add_argument("--disable_client_auth", action="store_true", help="Disable client authentication. Default is False") parser.add_argument("--disable_tls", action="store_true", help="Disable TLS for communication. Default is False") parser.add_argument("--log_memory_usage", action="store_true", help="Enable Memory leak logs. 
Default is False") diff --git a/tests/end_to_end/utils/data_helper.py b/tests/end_to_end/utils/data_helper.py new file mode 100644 index 0000000000..cef8fa6668 --- /dev/null +++ b/tests/end_to_end/utils/data_helper.py @@ -0,0 +1,556 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import shutil +import subprocess +from glob import glob +import logging +import importlib +from pathlib import Path +import json + +import tests.end_to_end.utils.defaults as defaults +import tests.end_to_end.utils.exceptions as ex +from tests.end_to_end.models import az_storage as az_storage_model ,s3_bucket as s3_model + +log = logging.getLogger(__name__) + + +def setup_collaborator_data(collaborators, model_name, local_bind_path): + """ + Function to setup the data for collaborators. + IMP: This function is specific to the model and should be updated as per the model requirements. + Args: + collaborators (list): List of collaborator objects + model_name (str): Model name + local_bind_path (str): Local bind path + """ + # Check if data already exists, if yes, skip the download part + # This is mainly helpful in case of re-runs + if all(os.path.exists(os.path.join(collaborator.workspace_path, "data", str(index))) for index, collaborator in enumerate(collaborators, start=1)): + log.info("Data already exists for all the collaborators. Skipping the download part..") + return + else: + log.info("Data does not exist for all the collaborators. Proceeding with the download..") + # Below step will also modify the data.yaml file for all the collaborators + if model_name == defaults.ModelName.XGB_HIGGS.value: + download_higgs_data(collaborators, local_bind_path) + elif model_name == defaults.ModelName.FLOWER_APP_PYTORCH.value: + download_flower_data(collaborators, local_bind_path) + + log.info("Data setup is complete for all the collaborators") + + +def download_gandlf_data(aggregator, local_bind_path, num_collaborators, results_path): + """ + Function to download the data for GanDLF segmentation test model and copy to the respective collaborator workspaces + For GanDLF, data download happens at aggregator level, thus we can not call this function from setup_collaborator_data + where download is at collaborator level + Args: + aggregator: Aggregator object + collaborators: List of collaborator objects + local_bind_path: Local bind path + results_path: Result directory (mostly $HOME/results) where GaNDLF csv and config yaml files are present + """ + try: + # Get list of all CSV files in openfl_path + csv_files = glob(os.path.join(results_path, '*.csv')) + + # Get data.yaml file and remove any entry, if present + data_file = os.path.join(aggregator.workspace_path, "plan", "data.yaml") + with open(data_file, "w") as df: + df.write("") + + # Copy the data to the respective workspaces based on the index + for col_index in range(1, num_collaborators+1): + dst_folder = os.path.join(aggregator.workspace_path, "data", str(col_index)) + os.makedirs(dst_folder, exist_ok=True) + for csv_file in csv_files: + shutil.copy(csv_file, dst_folder) + log.info(f"Copied data from {csv_file} to {dst_folder}") + + aggregator.modify_data_file( + defaults.COL_DATA_FILE.format(local_bind_path, "aggregator"), + f"collaborator{col_index}", + col_index, + ) + except Exception as e: + raise ex.DataSetupException(f"Failed to modify the data file: {e}") + + return True + + +def copy_gandlf_data_to_collaborators(aggregator, collaborators, local_bind_path): + """ + Function to copy the GaNDLF data from 
aggregator to respective collaborators + """ + try: + # Copy the data to the respective workspaces based on the index + for index, collaborator in enumerate(collaborators, start=1): + src_folder = os.path.join(aggregator.workspace_path, "data", str(index)) + dst_folder = os.path.join(collaborator.workspace_path, "data", str(index)) + if os.path.exists(src_folder): + shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True) + log.info(f"Copied data from {src_folder} to {dst_folder}") + else: + raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}") + + # Modify the data.yaml file for all the collaborators + collaborator.modify_data_file( + defaults.COL_DATA_FILE.format(local_bind_path, collaborator.name), + index, + ) + except Exception as e: + raise ex.DataSetupException(f"Failed to modify the data file: {e}") + + +def download_flower_data(collaborators, local_bind_path): + """ + Download the data for the model and copy to the respective collaborator workspaces + Also modify the data.yaml file for all the collaborators + Args: + collaborators (list): List of collaborator objects + local_bind_path (str): Local bind path + Returns: + bool: True if successful, else False + """ + common_download_for_higgs_and_flower(collaborators, local_bind_path) + + +def download_higgs_data(collaborators, local_bind_path): + """ + Download the data for the model and copy to the respective collaborator workspaces + Also modify the data.yaml file for all the collaborators + Args: + collaborators (list): List of collaborator objects + local_bind_path (str): Local bind path + Returns: + bool: True if successful, else False + """ + common_download_for_higgs_and_flower(collaborators, local_bind_path) + + +def common_download_for_higgs_and_flower(collaborators, local_bind_path): + """ + Common function to download the data for both Higgs and Flower models. + In future, if the data setup for other models is similar, we can use this function. + Also, if the setup changes for any of the models, we can modify this function to accommodate the changes. + """ + log.info(f"Copying {defaults.DATA_SETUP_FILE} from one of the collaborator workspaces to the local bind path..") + try: + shutil.copyfile( + src=os.path.join(collaborators[0].workspace_path, "src", defaults.DATA_SETUP_FILE), + dst=os.path.join(local_bind_path, defaults.DATA_SETUP_FILE) + ) + except Exception as e: + raise ex.DataSetupException(f"Failed to copy data setup file: {e}") + + log.info("Downloading the data for the model. 
This will take some time to complete based on the data size ..") + try: + command = ["python", defaults.DATA_SETUP_FILE, str(len(collaborators))] + subprocess.run(command, cwd=local_bind_path, check=True) # nosec B603 + except Exception: + raise ex.DataSetupException(f"Failed to download data for given model") + + try: + # Copy the data to the respective workspaces based on the index + for index, collaborator in enumerate(collaborators, start=1): + src_folder = os.path.join(local_bind_path, "data", str(index)) + dst_folder = os.path.join(collaborator.workspace_path, "data", str(index)) + if os.path.exists(src_folder): + shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True) + log.info(f"Copied data from {src_folder} to {dst_folder}") + else: + raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}") + + # Modify the data.yaml file for all the collaborators + collaborator.modify_data_file( + defaults.COL_DATA_FILE.format(local_bind_path, collaborator.name), + index, + ) + except Exception as e: + raise ex.DataSetupException(f"Failed to modify the data file: {e}") + + # XGBoost model uses folder name higgs_data and Flower model uses data to create data folders. + shutil.rmtree(os.path.join(local_bind_path, "higgs_data"), ignore_errors=True) + shutil.rmtree(os.path.join(local_bind_path, "data"), ignore_errors=True) + return True + + +def prepare_verifiable_dataset(request, dataset_type): + """ + Prepare data for S3, Azurite and/or local datasource based on . + Args: + request (object): Pytest request object. + dataset_type (str): Type of dataset to prepare. Valid values - s3, azure_blob, all. + """ + if dataset_type not in ["s3", "azure_blob", "all"]: + raise ValueError(f"Invalid dataset_type: {dataset_type}. Valid values are 's3', 'azure_blob', 'all'.") + + num_collaborators = request.config.num_collaborators + data_path = Path.cwd().absolute() / 'data' + home_dir = Path().home() + results_path = os.path.join(home_dir, request.config.results_dir) + colab_data_mapping = {} + + # Download the histology data and distribute it among collaborators + # The data is downloaded in the current working directory under 'data' subfolder + download_histology_data(data_path) + distribute_data_to_collaborators(num_collaborators, data_path) + + if dataset_type == "all": + colab_data_mapping = handle_all_dataset_type(num_collaborators, data_path, request) + else: + if dataset_type == "s3": + colab_data_mapping = upload_all_to_s3(num_collaborators, data_path, request) + elif dataset_type == "azure_blob": + colab_data_mapping = upload_all_to_azure_blob(num_collaborators, data_path) + + # Create a datasources.json file for each collaborator + write_datasources_json(num_collaborators, colab_data_mapping, results_path) + + +def upload_all_to_s3(num_collaborators, data_path, request): + """Upload all data for each collaborator to S3.""" + colab_data_mapping = {} + minio_obj = s3_model.MinioServer() + + # Start minio server, create S3 buckets and upload the data to S3 + try: + if not minio_obj.start_minio_server( + data_dir=os.path.join(Path().home(), request.config.results_dir, defaults.MINIO_DATA_FOLDER) + ): + raise ex.MinioServerStartException( + "Failed to start minio server. Please check the logs for more details." + ) + except Exception as e: + raise ex.MinioServerStartException( + f"Failed to start minio server. 
Error: {e}" + ) + + s3_obj = s3_model.S3Bucket() + for index in range(1, num_collaborators + 1): + bucket_name = f"col{index}-bucket{index}" + try: + s3_obj.create_bucket(bucket_name=bucket_name) + except Exception as e: + raise ex.S3BucketCreationException( + f"Failed to create bucket {bucket_name} for collaborator{index}. Error: {e}" + ) + + collaborator_name = f"collaborator{index}" + local_dir = data_path / str(index) + s3_obj.upload_directory(dir_path=local_dir, bucket_name=bucket_name) + + s3_data = { + "type": "s3", + "params": { + "access_key_env_name": "MINIO_ROOT_USER", + "endpoint": defaults.MINIO_URL, + "secret_key_env_name": "MINIO_ROOT_PASSWORD", + "secret_name": "vault_secret_name1", + "uri": f"s3://{bucket_name}/" + } + } + if collaborator_name not in colab_data_mapping: + colab_data_mapping[collaborator_name] = {} + colab_data_mapping[collaborator_name]["s3_data"] = s3_data + shutil.rmtree(local_dir) # Remove local data after successful upload + return colab_data_mapping + + +def upload_all_to_azure_blob(num_collaborators, data_path): + """Upload all data for each collaborator to Azure Blob (Azurite).""" + azurite_obj = az_storage_model.AzuriteStorage() + colab_data_mapping = {} + try: + azurite_obj.start_azurite_container() + except Exception as e: + raise ex.AzureBlobContainerCreationException( + f"Failed to start azurite container. Error: {e}" + ) + + # Create container + for index in range(1, num_collaborators + 1): + container_name = f"col{index}-container{index}" + try: + azurite_obj.create_container(container_name) + log.info(f"Created container {container_name}") + except Exception as e: + if "specified container already exists" in str(e): + azurite_obj.delete_container(container_name) + azurite_obj.create_container(container_name) + else: + raise ex.AzureBlobContainerCreationException( + f"Failed to create container {container_name} for collaborator{index}. Error: {e}" + ) + collaborator_name = f"collaborator{index}" + local_dir = data_path / str(index) + # Upload data to the container + azurite_obj.upload_data_to_container( + container_name=container_name, + data_path=local_dir + ) + azure_blob_data = { + "type": "azure_blob", + "params": { + "connection_string": azurite_obj.connection_string, + "container_name": container_name + } + } + if collaborator_name not in colab_data_mapping: + colab_data_mapping[collaborator_name] = {} + colab_data_mapping[collaborator_name]["azure_blob_data"] = azure_blob_data + shutil.rmtree(local_dir) # Remove local data after successful upload + return colab_data_mapping + + +def handle_all_dataset_type(num_collaborators, data_path, request): + """ + For 'all' dataset_type, split the data into 3 non-overlapping parts and assign to S3, Azure Blob, and local. + """ + colab_data_mapping = {} + + # Create objects for minio and azurite + minio_obj = s3_model.MinioServer() + try: + if not minio_obj.start_minio_server( + data_dir=os.path.join(Path().home(), request.config.results_dir, defaults.MINIO_DATA_FOLDER) + ): + raise ex.MinioServerStartException( + "Failed to start minio server. Please check the logs for more details." + ) + except Exception as e: + raise ex.MinioServerStartException( + f"Failed to start minio server. Error: {e}" + ) + + s3_obj = s3_model.S3Bucket() + + azurite_obj = az_storage_model.AzuriteStorage() + try: + azurite_obj.start_azurite_container() + except Exception as e: + raise ex.AzureBlobContainerCreationException( + f"Failed to start azurite container. 
Error: {e}" + ) + + # Upload data to S3, Azure Blob and local for each collaborator + for index in range(1, num_collaborators + 1): + collaborator_name = f"collaborator{index}" + local_dir = data_path / str(index) + all_files = sorted([f for f in local_dir.iterdir() if f.is_dir() or f.is_file()]) + + total = len(all_files) + split_size = total // 3 + splits = [ + all_files[:split_size], + all_files[split_size:2*split_size], + all_files[2*split_size:] + ] + + # Prepare temp dirs for each split + s3_dir = local_dir / "s3_part" + azure_dir = local_dir / "azure_part" + local_part_dir = local_dir / "local_part" + for d in [s3_dir, azure_dir, local_part_dir]: + d.mkdir(parents=True, exist_ok=True) + + # Move files to their respective dirs + for f in splits[0]: + shutil.move(str(f), s3_dir / f.name) + for f in splits[1]: + shutil.move(str(f), azure_dir / f.name) + for f in splits[2]: + shutil.move(str(f), local_part_dir / f.name) + + # Ensure each part has at least one folder (copy from the largest part if needed) + part_dirs = [s3_dir, azure_dir, local_part_dir] + part_counts = [len(list(d.iterdir())) for d in part_dirs] + if any(count == 0 for count in part_counts): + # Find the largest part + largest_idx = part_counts.index(max(part_counts)) + largest_dir = part_dirs[largest_idx] + largest_files = list(largest_dir.iterdir()) + for idx, count in enumerate(part_counts): + if count == 0 and largest_files: + # Copy (not move) the first folder/file from the largest part + src = largest_files[0] + dst = part_dirs[idx] / src.name + if src.is_dir(): + shutil.copytree(src, dst) + else: + shutil.copy2(src, dst) + + # S3 data + bucket_name = f"col{index}-bucket{index}" + s3_obj.create_bucket(bucket_name=bucket_name) + s3_obj.upload_directory(dir_path=s3_dir, bucket_name=bucket_name) + s3_data = { + "type": "s3", + "params": { + "access_key_env_name": "MINIO_ROOT_USER", + "endpoint": defaults.MINIO_URL, + "secret_key_env_name": "MINIO_ROOT_PASSWORD", + "secret_name": "vault_secret_name1", + "uri": f"s3://{bucket_name}/" + } + } + + # Azure Blob data + container_name = f"col{index}-container{index}" + azurite_obj.create_container(container_name) + azurite_obj.upload_data_to_container(container_name=container_name, data_path=azure_dir) + azure_blob_data = { + "type": "azure_blob", + "params": { + "connection_string": azurite_obj.connection_string, + "container_name": container_name + } + } + + # Local data + local_data = { + "type": "local", + "params": { + "path": str(local_part_dir.relative_to(Path.cwd())) + } + } + # Print local data objects count + log.info(f"Retained {len(list(local_part_dir.rglob('*')))} files in local data for {collaborator_name}") + colab_data_mapping[collaborator_name] = { + "s3_data": s3_data, + "azure_blob_data": azure_blob_data, + "local_data": local_data + } + # Clean up temp dirs after upload if needed + # shutil.rmtree(s3_dir) + # shutil.rmtree(azure_dir) + # local_part_dir is kept for local access + + return colab_data_mapping + + +def write_datasources_json(num_collaborators, colab_data_mapping, results_path): + """ + Create a datasources.json file for each collaborator. + Args: + num_collaborators (int): Number of collaborators. + colab_data_mapping (dict): Mapping of collaborator names to their data sources. + results_path (str): Path to the results directory. 
+ """ + for index in range(1, num_collaborators + 1): + collaborator_name = f"collaborator{index}" + col_mapping = colab_data_mapping[collaborator_name] + combined_data = {} + + # Add s3_data as s3_ds1 + if "s3_data" in col_mapping: + combined_data["s3_ds1"] = col_mapping["s3_data"] + + # Add azure_blob_data as azure_ds1 + if "azure_blob_data" in col_mapping: + combined_data["azure_ds1"] = col_mapping["azure_blob_data"] + + # Add local_data as local_ds1 + if "local_data" in col_mapping: + combined_data["local_ds1"] = col_mapping["local_data"] + + ds_file = os.path.join(results_path, "datasources", collaborator_name, "datasources.json") + os.makedirs(os.path.dirname(ds_file), exist_ok=True) + + with open(ds_file, "w") as f: + json.dump(combined_data, f, indent=2) + + +def distribute_data_to_collaborators(num_collaborators, data_path): + """ + Distribute the data among the collaborators uniformly. + Example: Assuming num_collaborators is 3 + If data_path has folder Kather_texture_2016_image_tiles_5000 (torch/histology) which further has 8 subfolders, + then the data will be distributed as: + collaborator1: 1 / first 3 subfolders + collaborator2: 2 / next 3 subfolders + collaborator3: 3 / last 2 subfolders + If data_path itself has multiple folders, say 8, then the data will be distributed as: + collaborator1: 1 / first 3 folders + collaborator2: 2 / next 3 folders + collaborator3: 3 / last 2 folders + Args: + num_collaborators (int): Number of collaborators. + data_path (str): Path to the data directory. + Raises: + Exception: If the data distribution fails. + """ + # Pre-check: skip if all collaborator folders exist and are non-empty + already_distributed = True + for index in range(1, num_collaborators + 1): + collaborator_data_path = data_path / str(index) + if not (collaborator_data_path.exists() and any(collaborator_data_path.iterdir())): + already_distributed = False + break + + # If already distributed, skip redistribution and return + if already_distributed: + log.info("Data already distributed among collaborators. Skipping distribution.") + return + + # If data_path has only one folder, go inside it and use its subfolders + all_entries = [f for f in data_path.iterdir() if f.is_dir()] + if len(all_entries) == 1: + # Use subfolders inside the single folder + all_folders = [f for f in all_entries[0].iterdir() if f.is_dir()] + else: + all_folders = all_entries + all_folders.sort() # For deterministic split + + num_folders = len(all_folders) + folders_per_collab = [num_folders // num_collaborators] * num_collaborators + for i in range(num_folders % num_collaborators): + folders_per_collab[i] += 1 + + start = 0 + for index in range(1, num_collaborators + 1): + collaborator_data_path = data_path / str(index) + collaborator_data_path.mkdir(parents=True, exist_ok=True) + end = start + folders_per_collab[index - 1] + for folder in all_folders[start:end]: + dest = collaborator_data_path / folder.name + if folder.parent != collaborator_data_path: + folder.rename(dest) + start = end + + # Remove all files/folders from 'data' except collaborator folders (1, 2, 3, ...) 
+ for entry in data_path.iterdir(): + if entry.is_dir() and entry.name not in [str(i) for i in range(1, num_collaborators + 1)]: + shutil.rmtree(entry) + log.info(f"Removed folder {entry} from data path") + elif entry.is_file() and entry.name.endswith(".zip"): + os.remove(entry) + log.info(f"Removed zip file {entry} from data path") + + +def download_histology_data(data_path): + """ + Download the histology data using its dataloader module. + The data is downloaded in the current working directory under 'data' subfolder. + """ + # Check if data already exists, if yes delete the folder and download again + if data_path.exists() and any(data_path.iterdir()): + log.info("Data already exists. Deleting the folder and downloading again..") + shutil.rmtree(data_path) + + # Import the dataloader module for torch/histology to download the data + # As the folder name contains hyphen, we need to use importlib to import the module + dataloader_module = importlib.import_module("openfl-workspace.torch.histology.src.dataloader") + + # Download the data for torch/histology in current folder as internally it uses the current folder as data path + try: + log.info(f"Downloading data for {defaults.ModelName.TORCH_HISTOLOGY_S3.value}") + dataloader_module.HistologyDataset() + log.info("Download completed") + except Exception as e: + raise ex.DataDownloadException( + f"Failed to download data for {defaults.ModelName.TORCH_HISTOLOGY_S3.value}. Error: {e}" + ) diff --git a/tests/end_to_end/utils/db_helper.py b/tests/end_to_end/utils/db_helper.py index 9c4186eb23..3eff0a5787 100644 --- a/tests/end_to_end/utils/db_helper.py +++ b/tests/end_to_end/utils/db_helper.py @@ -54,7 +54,7 @@ def read_key_value_store(self): return key_value_dict -def get_key_value_from_db(key, database_file, max_retries=10, sleep_interval=5): +def get_key_value_from_db(key, database_file, max_retries=15, sleep_interval=10): """ Get value by key from the database file Args: diff --git a/tests/end_to_end/utils/constants.py b/tests/end_to_end/utils/defaults.py similarity index 65% rename from tests/end_to_end/utils/constants.py rename to tests/end_to_end/utils/defaults.py index c3e63de45f..0265053085 100644 --- a/tests/end_to_end/utils/constants.py +++ b/tests/end_to_end/utils/defaults.py @@ -14,20 +14,24 @@ class ModelName(Enum): KERAS_MNIST = "keras/mnist" KERAS_TORCH_MNIST = "keras/torch/mnist" TORCH_HISTOLOGY = "torch/histology" + TORCH_HISTOLOGY_S3 = "torch/histology_s3" TORCH_MNIST = "torch/mnist" TORCH_MNIST_EDEN_COMPRESSION = "torch/mnist_eden_compression" TORCH_MNIST_STRAGGLER_CHECK = "torch/mnist_straggler_check" + KERAS_TENSORFLOW_MNIST = "keras/tensorflow/mnist" XGB_HIGGS = "xgb_higgs" GANDLF_SEG_TEST = "gandlf_seg_test" FLOWER_APP_PYTORCH = "flower-app-pytorch" NO_OP = "no-op" + FEDERATED_ANALYTICS_HISTOGRAM = "federated_analytics/histogram" + FEDERATED_ANALYTICS_SMOKERS_HEALTH = "federated_analytics/smokers_health" NUM_COLLABORATORS = 2 NUM_ROUNDS = 5 WORKSPACE_NAME = "my_federation" SUCCESS_MARKER = "✔️ OK" -# Docker specific constants +# Docker specific defaults CREATE_OPENFL_NW = "docker network create" REMOVE_OPENFL_NW = "docker network rm" DOCKER_NETWORK_NAME = "openfl" @@ -60,3 +64,34 @@ class ModelName(Enum): EXCEPTION = "Exception" AGG_METRIC_MODEL_ACCURACY_KEY = "aggregator/aggregated_model_validation/accuracy" COL_TLS_END_MSG = "TLS connection established." + + +class TransportProtocol(Enum): + """ + Enum class to define the transport protocol. 
+ """ + GRPC = "grpc" + REST = "rest" + +AGGREGATOR_REST_CLIENT = "Starting Aggregator REST Server" +AGGREGATOR_gRPC_CLIENT = "Starting Aggregator gRPC Server" + +# For S3 and MinIO +MINIO_ROOT_USER = "minioadmin" +MINIO_ROOT_PASSWORD = "minioadmin" +MINIO_HOST = "localhost" +MINIO_PORT = 9000 +MINIO_CONSOLE_PORT = 9001 +MINIO_URL = f"http://{MINIO_HOST}:{MINIO_PORT}" +MINIO_CONSOLE_URL = f"http://{MINIO_HOST}:{MINIO_CONSOLE_PORT}" +MINIO_DATA_FOLDER = "minio_data" + +# For Azure Blob Storage +AZURE_STORAGE_HOST = "localhost" +AZURE_STORAGE_PORT = 10000 +AZURE_STORAGE_ENDPOINTS_PROTOCOL = "http" +AZURE_STORAGE_ACCOUNT_NAME = "devstoreaccount1" +# IMP: The account key is provided by Azure for local development storage +# and is not a real key. It is used for testing purposes only. +AZURE_STORAGE_ACCOUNT_KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" +AZURE_BLOB_ENDPOINT = f"{AZURE_STORAGE_ENDPOINTS_PROTOCOL}://{AZURE_STORAGE_HOST}:{AZURE_STORAGE_PORT}/{AZURE_STORAGE_ACCOUNT_NAME}" diff --git a/tests/end_to_end/utils/docker_helper.py b/tests/end_to_end/utils/docker_helper.py index e3cab5824f..644d0fed69 100644 --- a/tests/end_to_end/utils/docker_helper.py +++ b/tests/end_to_end/utils/docker_helper.py @@ -6,13 +6,13 @@ import subprocess from functools import lru_cache -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.exceptions as ex log = logging.getLogger(__name__) -def remove_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]): +def remove_docker_network(list_of_networks=[defaults.DOCKER_NETWORK_NAME]): """ Remove docker network. Args: @@ -30,7 +30,7 @@ def remove_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]): log.debug(f"Docker network(s) {list_of_networks} removed successfully") -def create_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]): +def create_docker_network(list_of_networks=[defaults.DOCKER_NETWORK_NAME]): """ Create docker network. Args: @@ -53,18 +53,18 @@ def check_docker_image(): Check if the docker image exists. 
""" client = get_docker_client() - images = client.images.list(name=constants.DEFAULT_OPENFL_IMAGE) + images = client.images.list(name=defaults.DEFAULT_OPENFL_IMAGE) if not images: - log.error(f"Image {constants.DEFAULT_OPENFL_IMAGE} does not exist") - raise Exception(f"Image {constants.DEFAULT_OPENFL_IMAGE} does not exist") - log.debug(f"Image {constants.DEFAULT_OPENFL_IMAGE} exists") + log.error(f"Image {defaults.DEFAULT_OPENFL_IMAGE} does not exist") + raise Exception(f"Image {defaults.DEFAULT_OPENFL_IMAGE} does not exist") + log.debug(f"Image {defaults.DEFAULT_OPENFL_IMAGE} exists") def start_docker_container_with_federation_run( participant, use_tls=True, - image=constants.DEFAULT_OPENFL_IMAGE, - network=constants.DOCKER_NETWORK_NAME, + image=defaults.DEFAULT_OPENFL_IMAGE, + network=defaults.DOCKER_NETWORK_NAME, env_keyval_list=None, security_opt=None, mount_mapping=None, @@ -94,7 +94,7 @@ def start_docker_container_with_federation_run( else: local_participant_path = participant.workspace_path - docker_participant_path = f"/{constants.DFLT_WORKSPACE_NAME}" + docker_participant_path = f"/{defaults.DFLT_WORKSPACE_NAME}" volumes = { local_participant_path: {"bind": docker_participant_path, "mode": "rw"}, @@ -117,15 +117,15 @@ def start_docker_container_with_federation_run( log_file = f"{docker_participant_path}/logs/{participant.name}.log" if participant.name == "aggregator": - start_agg = constants.AGG_START_CMD + start_agg = defaults.AGG_START_CMD # Handle Fed Eval case if participant.eval_scope: start_agg += " --task_group evaluation" command = ["bash", "-c", f"touch {log_file} && {start_agg} > {log_file} 2>&1"] else: - start_collaborator = f"touch {log_file} && {constants.COL_START_CMD.format(participant.name)} > {log_file} 2>&1" + start_collaborator = f"touch {log_file} && {defaults.COL_START_CMD.format(participant.name)} > {log_file} 2>&1" if use_tls: - command = ["bash", "-c", f"{constants.COL_CERTIFY_CMD.format(participant.name)} && {start_collaborator}"] + command = ["bash", "-c", f"{defaults.COL_CERTIFY_CMD.format(participant.name)} && {start_collaborator}"] else: command = ["bash", "-c", start_collaborator] diff --git a/tests/end_to_end/utils/exceptions.py b/tests/end_to_end/utils/exceptions.py index e7c353eaa3..16c1a37b86 100644 --- a/tests/end_to_end/utils/exceptions.py +++ b/tests/end_to_end/utils/exceptions.py @@ -129,3 +129,33 @@ class FlowerAppException(Exception): class ProcessKillException(Exception): """Exception for process kill""" pass + + +class HashCalculationException(Exception): + """Exception for hash calculation of collaborator's data path""" + pass + + +class MinioServerStartException(Exception): + """Exception for minio server start""" + pass + + +class S3BucketCreationException(Exception): + """Exception for S3 bucket creation""" + pass + + +class DataDownloadException(Exception): + """Exception for data download""" + pass + + +class DataUploadToS3Exception(Exception): + """Exception for data upload to S3""" + pass + + +class AzureBlobContainerCreationException(Exception): + """Exception for Azure Blob container creation""" + pass diff --git a/tests/end_to_end/utils/federation_helper.py b/tests/end_to_end/utils/federation_helper.py index 99d1a6c6b5..f5059d3350 100644 --- a/tests/end_to_end/utils/federation_helper.py +++ b/tests/end_to_end/utils/federation_helper.py @@ -6,17 +6,14 @@ import logging import yaml import os -import json -import re import subprocess # nosec B404 from pathlib import Path -import shutil -from glob import glob -import 
tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.db_helper as db_helper import tests.end_to_end.utils.docker_helper as dh import tests.end_to_end.utils.exceptions as ex +import tests.end_to_end.utils.helper as helper import tests.end_to_end.utils.interruption_helper as intr_helper import tests.end_to_end.utils.ssh_helper as ssh from tests.end_to_end.models import collaborator as col_model @@ -38,7 +35,7 @@ def setup_pki_for_collaborators(collaborators, model_owner, local_bind_path): bool: True if successful, else False """ # PKI setup for aggregator is done at fixture level - local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path) + local_agg_ws_path = defaults.AGG_WORKSPACE_PATH.format(local_bind_path) executor = concurrent.futures.ThreadPoolExecutor() @@ -66,7 +63,7 @@ def setup_pki_for_collaborators(collaborators, model_owner, local_bind_path): results = [ executor.submit( copy_file_between_participants, - local_src_path=constants.COL_WORKSPACE_PATH.format( + local_src_path=defaults.COL_WORKSPACE_PATH.format( local_bind_path, collaborator.name ), local_dest_path=local_agg_ws_path, @@ -99,7 +96,7 @@ def setup_pki_for_collaborators(collaborators, model_owner, local_bind_path): executor.submit( copy_file_between_participants, local_src_path=local_agg_ws_path, - local_dest_path=constants.COL_WORKSPACE_PATH.format( + local_dest_path=defaults.COL_WORKSPACE_PATH.format( local_bind_path, collaborator.name ), file_name=f"agg_to_col_{collaborator.name}_signed_cert.zip", @@ -134,7 +131,7 @@ def _create_tarball(collaborator_name, data_file_path, local_bind_path, add_data If TLS is enabled - include client certificates and signed certificates in the tarball If data needs to be added - include the data file in the tarball """ - local_col_ws_path = constants.COL_WORKSPACE_PATH.format( + local_col_ws_path = defaults.COL_WORKSPACE_PATH.format( local_bind_path, collaborator_name ) client_cert_entries = "" @@ -220,6 +217,37 @@ def copy_file_between_participants( return True +def _check_aggregator_protocol_log(aggregator): + """ + Check if the aggregator started with the correct protocol by inspecting its log file. + Args: + aggregator (object): Aggregator object with res_file and transport_protocol attributes. + Raises: + Exception: If the expected protocol start message is not found in the logs. + """ + start_time = time.time() + found = False + while time.time() - start_time < 30: + with open(aggregator.res_file, "r") as file: + lines = [line.strip() for line in file.readlines()] + last_lines = lines[-5:] + if aggregator.transport_protocol == defaults.TransportProtocol.REST.value: + expected_msg = defaults.AGGREGATOR_REST_CLIENT + else: + expected_msg = defaults.AGGREGATOR_gRPC_CLIENT + + msg_received = [line for line in last_lines if expected_msg.lower() in line.lower()] + if msg_received: + found = True + break + time.sleep(10) + if not found: + raise Exception( + f"Aggregator did not start with {aggregator.transport_protocol} protocol. 
Check the logs for more details" + ) + log.info(f"Aggregator started with {aggregator.transport_protocol} protocol") + + def run_federation(fed_obj): """ Start the federation @@ -233,14 +261,15 @@ def run_federation(fed_obj): if "keras" in fed_obj.model_name: _ = set_keras_backend(fed_obj.model_name) - for participant in [fed_obj.aggregator] + fed_obj.collaborators: + # Start the aggregator + start_aggregator(fed_obj) + + for participant in fed_obj.collaborators: try: - # Start the participant participant.start() except Exception as e: log.error(f"Failed to start {participant.name}: {e}") raise e - return True @@ -257,7 +286,7 @@ def run_federation_for_dws(fed_obj, use_tls): try: container = dh.start_docker_container_with_federation_run( participant=participant, - image=constants.DFLT_WORKSPACE_NAME, + image=defaults.DFLT_WORKSPACE_NAME, use_tls=use_tls, env_keyval_list=set_keras_backend(fed_obj.model_name) if "keras" in fed_obj.model_name else None, ) @@ -271,13 +300,14 @@ def run_federation_for_dws(fed_obj, use_tls): return True -def verify_federation_run_completion(fed_obj, test_env, num_rounds): +def verify_federation_run_completion(fed_obj, test_env, num_rounds, time_for_each_round=100): """ Verify the completion of the process for all the participants Args: fed_obj (object): Federation fixture object test_env (str): Test environment num_rounds (int): Number of rounds + time_for_each_round (int): Time for each round (in seconds) Returns: list: List of response (True or False) for all the participants """ @@ -291,6 +321,7 @@ def verify_federation_run_completion(fed_obj, test_env, num_rounds): participant, num_rounds, num_collaborators=len(fed_obj.collaborators), + time_for_each_round=time_for_each_round, ) for participant in fed_obj.collaborators + [fed_obj.aggregator] ] @@ -353,13 +384,13 @@ def _verify_completion_for_participant( log.info(f"Last line in {participant.name} log: {lines[-1:]}") # If in logs Exception is encountered, throw Exception and stop the process - if constants.EXCEPTION in content: + if defaults.EXCEPTION in content: log.error( f"Process {participant.name} is throwing Exception. 
Check the logs for more details" ) raise Exception(f"Process failed for {participant.name}") - msg_received = [line for line in content if constants.AGG_END_MSG in line or constants.COL_END_MSG in line] + msg_received = [line for line in content if defaults.AGG_END_MSG in line or defaults.COL_END_MSG in line] if msg_received: log.info(f"Process completed for {participant.name}") break @@ -403,7 +434,7 @@ def federation_env_setup_and_validate(request, eval_scope=False): test_env = request.config.test_env # Validate the model name and create the workspace name - if not request.config.model_name.replace("/", "_").replace("-", "_").upper() in constants.ModelName._member_names_: + if not request.config.model_name.replace("/", "_").replace("-", "_").upper() in defaults.ModelName._member_names_: raise ValueError(f"Invalid model name: {request.config.model_name}") # Set the workspace path specific to the model and the test case @@ -431,6 +462,7 @@ def federation_env_setup_and_validate(request, eval_scope=False): dh.remove_docker_network() dh.create_docker_network() + request.config.transport_protocol = defaults.TransportProtocol.REST.value if request.config.tr_rest_protocol else defaults.TransportProtocol.GRPC.value log.info( f"Running federation setup using {test_env} API on single machine with below configurations:\n" f"Number of collaborators: {request.config.num_collaborators}\n" @@ -439,6 +471,7 @@ def federation_env_setup_and_validate(request, eval_scope=False): f"Client authentication: {request.config.require_client_auth}\n" f"TLS: {request.config.use_tls}\n" f"Secure Aggregation: {request.config.secure_agg}\n" + f"Transport protocol: {request.config.transport_protocol}\n" f"Memory Logs: {request.config.log_memory_usage}\n" f"Results directory: {request.config.results_dir}\n" f"Workspace path: {workspace_path}" @@ -461,7 +494,7 @@ def create_persistent_store(participant_name, local_bind_path): f"mkdir -p $WORKING_DIRECTORY/{participant_name}/workspace" ) log.debug(f"Creating persistent store") - return_code, output, error = run_command( + return_code, output, error = helper.run_command( cmd_persistent_store, workspace_path=Path().home(), ) @@ -474,102 +507,7 @@ def create_persistent_store(participant_name, local_bind_path): raise ex.PersistentStoreCreationException(f"{error_msg}: {e}") -def run_command( - command, - workspace_path, - error_msg=None, - container_id=None, - run_in_background=False, - bg_file=None, - print_output=False, - with_docker=False, - return_error=False, -): - """ - Run the command - Args: - command (str): Command to run - workspace_path (str): Workspace path - container_id (str): Container ID - run_in_background (bool): Run the command in background - bg_file (str): Background file (with path) - print_output (bool): Print the output - with_docker (bool): Flag specific to dockerized workspace scenario. Default is False. 
- return_error (bool): Return error message - Returns: - tuple: Return code, output and error - """ - return_code, output, error = 0, None, None - error_msg = error_msg or "Failed to run the command" - - if with_docker and container_id: - log.debug("Running command in docker container") - if len(workspace_path): - docker_command = f"docker exec -w {workspace_path} {container_id} sh -c " - else: - # This scenario is mainly for workspace creation where workspace path is not available - docker_command = f"docker exec -i {container_id} sh -c " - - if run_in_background and bg_file: - docker_command += f"'{command} > {bg_file}' &" - else: - docker_command += f"'{command}'" - - command = docker_command - else: - if not run_in_background: - # When the command is run in background, we anyways pass the workspace path - command = f"cd {workspace_path}; {command}" - - if print_output: - log.info(f"Running command: {command}") - - if run_in_background and not with_docker: - if bg_file: - bg_file = open(bg_file, "a", buffering=1) # open file in append mode, so that restarting scenarios can be handled - ssh.run_command_background( - command, - work_dir=workspace_path, - redirect_to_file=bg_file, - check_sleep=60, - ) - else: - return_code, output, error = ssh.run_command(command) - if return_code != 0 and not return_error: - log.error(f"{error_msg}: {error}") - raise Exception(f"{error_msg}: {error}") - - if print_output: - log.info(f"Output: {output}") - log.info(f"Error: {error}") - return return_code, output, error - - -# This functionality is common across multiple participants, thus moved to a common function -def verify_cmd_output( - output, return_code, error, error_msg, success_msg, raise_exception=True -): - """ - Verify the output of fx command run - Assumption - it will have '✔️ OK' in the output if the command is successful - Args: - output (list): Output of the command using run_command() - return_code (int): Return code of the command - error (list): Error message - error_msg (str): Error message - success_msg (str): Success message - """ - msg_received = [line for line in output if constants.SUCCESS_MARKER in line] - log.info(f"Message received: {msg_received}") - if return_code == 0 and len(msg_received): - log.info(success_msg) - else: - log.error(f"{error_msg}: {error}") - if raise_exception: - raise Exception(f"{error_msg}: {error}") - - -def setup_collaborator(index, workspace_path, local_bind_path): +def setup_collaborator(index, workspace_path, local_bind_path, data_path=None, calc_hash=False, colab_bucket_mapping=None, transport_protocol="grpc"): """ Setup the collaborator Includes - creation of collaborator objects, starting docker container, importing workspace, creating collaborator @@ -577,13 +515,25 @@ def setup_collaborator(index, workspace_path, local_bind_path): index (int): Index of the collaborator. Starts with 1. 
workspace_path (str): Workspace path local_bind_path (str): Local bind path + data_path (str): Data path + calc_hash (bool): Flag to indicate if hash calculation is required + colab_bucket_mapping (dict): Mapping of collaborator and its datasources + transport_protocol (str): Transport protocol (default: "gRPC") """ - local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path) + local_agg_ws_path = defaults.AGG_WORKSPACE_PATH.format(local_bind_path) + + # If datasource path exists, it indicates that the collaborator is using a custom data source + # After importing workspace, copy the datasources.json file to the collaborator workspace/data directory + # and set the data_directory_path to "data" + datasource_path = os.path.join(str(Path(local_bind_path).parents[1]), "datasources", f"collaborator{index}") + if not os.path.exists(datasource_path): + datasource_path = None try: collaborator = col_model.Collaborator( collaborator_name=f"collaborator{index}", - data_directory_path=index, + transport_protocol=transport_protocol, + data_directory_path=index if datasource_path is None else "data", workspace_path=f"{workspace_path}/collaborator{index}/workspace", ) create_persistent_store(collaborator.name, local_bind_path) @@ -594,11 +544,11 @@ def setup_collaborator(index, workspace_path, local_bind_path): ) try: - local_col_ws_path = constants.COL_WORKSPACE_PATH.format( + local_col_ws_path = defaults.COL_WORKSPACE_PATH.format( local_bind_path, collaborator.name ) copy_file_between_participants( - local_agg_ws_path, local_col_ws_path, f"{constants.DFLT_WORKSPACE_NAME}.zip" + local_agg_ws_path, local_col_ws_path, f"{defaults.DFLT_WORKSPACE_NAME}.zip" ) collaborator.import_workspace() except Exception as e: @@ -606,224 +556,36 @@ def setup_collaborator(index, workspace_path, local_bind_path): f"Failed to import workspace for {collaborator.name}: {e}" ) - try: - collaborator.create_collaborator() - except Exception as e: - raise ex.CollaboratorCreationException(f"Failed to create collaborator: {e}") - - return collaborator - - -def setup_collaborator_data(collaborators, model_name, local_bind_path): - """ - Function to setup the data for collaborators. - IMP: This function is specific to the model and should be updated as per the model requirements. - Args: - collaborators (list): List of collaborator objects - model_name (str): Model name - local_bind_path (str): Local bind path - """ - # Check if data already exists, if yes, skip the download part - # This is mainly helpful in case of re-runs - if all(os.path.exists(os.path.join(collaborator.workspace_path, "data", str(index))) for index, collaborator in enumerate(collaborators, start=1)): - log.info("Data already exists for all the collaborators. Skipping the download part..") - return - else: - log.info("Data does not exist for all the collaborators. 
Proceeding with the download..") - # Below step will also modify the data.yaml file for all the collaborators - if model_name == constants.ModelName.XGB_HIGGS.value: - download_higgs_data(collaborators, local_bind_path) - elif model_name == constants.ModelName.FLOWER_APP_PYTORCH.value: - download_flower_data(collaborators, local_bind_path) - - log.info("Data setup is complete for all the collaborators") - - -def download_gandlf_data(aggregator, local_bind_path, num_collaborators, results_path): - """ - Function to download the data for GanDLF segmentation test model and copy to the respective collaborator workspaces - For GanDLF, data download happens at aggregator level, thus we can not call this function from setup_collaborator_data - where download is at collaborator level - Args: - aggregator: Aggregator object - collaborators: List of collaborator objects - local_bind_path: Local bind path - results_path: Result directory (mostly $HOME/results) where GaNDLF csv and config yaml files are present - """ - try: - # Get list of all CSV files in openfl_path - csv_files = glob(os.path.join(results_path, '*.csv')) - - # Get data.yaml file and remove any entry, if present - data_file = os.path.join(aggregator.workspace_path, "plan", "data.yaml") - with open(data_file, "w") as df: - df.write("") - - # Copy the data to the respective workspaces based on the index - for col_index in range(1, num_collaborators+1): - dst_folder = os.path.join(aggregator.workspace_path, "data", str(col_index)) - os.makedirs(dst_folder, exist_ok=True) - for csv_file in csv_files: - shutil.copy(csv_file, dst_folder) - log.info(f"Copied data from {csv_file} to {dst_folder}") - - aggregator.modify_data_file( - constants.COL_DATA_FILE.format(local_bind_path, "aggregator"), - f"collaborator{col_index}", - col_index, + # If datasources path exist, copy the data files to the collaborator workspace + if datasource_path: + try: + copy_file_between_participants( + local_src_path=datasource_path, + local_dest_path=os.path.join(collaborator.workspace_path, "data"), + file_name="datasources.json", + run_with_sudo=True, ) - except Exception as e: - raise ex.DataSetupException(f"Failed to modify the data file: {e}") - - return True - - -def copy_gandlf_data_to_collaborators(aggregator, collaborators, local_bind_path): - """ - Function to copy the GaNDLF data from aggregator to respective collaborators - """ - try: - # Copy the data to the respective workspaces based on the index - for index, collaborator in enumerate(collaborators, start=1): - src_folder = os.path.join(aggregator.workspace_path, "data", str(index)) - dst_folder = os.path.join(collaborator.workspace_path, "data", str(index)) - if os.path.exists(src_folder): - shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True) - log.info(f"Copied data from {src_folder} to {dst_folder}") - else: - raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}") - - # Modify the data.yaml file for all the collaborators - collaborator.modify_data_file( - constants.COL_DATA_FILE.format(local_bind_path, collaborator.name), - index, + except Exception as e: + raise ex.DataCopyException( + f"Failed to copy datasources.json for {collaborator.name}: {e}" ) - except Exception as e: - raise ex.DataSetupException(f"Failed to modify the data file: {e}") - -def download_flower_data(collaborators, local_bind_path): - """ - Download the data for the model and copy to the respective collaborator workspaces - Also modify the data.yaml file for all the 
collaborators - Args: - collaborators (list): List of collaborator objects - local_bind_path (str): Local bind path - Returns: - bool: True if successful, else False - """ - common_download_for_higgs_and_flower(collaborators, local_bind_path) - - -def download_higgs_data(collaborators, local_bind_path): - """ - Download the data for the model and copy to the respective collaborator workspaces - Also modify the data.yaml file for all the collaborators - Args: - collaborators (list): List of collaborator objects - local_bind_path (str): Local bind path - Returns: - bool: True if successful, else False - """ - common_download_for_higgs_and_flower(collaborators, local_bind_path) - - -def common_download_for_higgs_and_flower(collaborators, local_bind_path): - """ - Common function to download the data for both Higgs and Flower models. - In future, if the data setup for other models is similar, we can use this function. - Also, if the setup changes for any of the models, we can modify this function to accommodate the changes. - """ - log.info(f"Copying {constants.DATA_SETUP_FILE} from one of the collaborator workspaces to the local bind path..") try: - shutil.copyfile( - src=os.path.join(collaborators[0].workspace_path, "src", constants.DATA_SETUP_FILE), - dst=os.path.join(local_bind_path, constants.DATA_SETUP_FILE) - ) + collaborator.create_collaborator() except Exception as e: - raise ex.DataSetupException(f"Failed to copy data setup file: {e}") - - log.info("Downloading the data for the model. This will take some time to complete based on the data size ..") - try: - command = ["python", constants.DATA_SETUP_FILE, str(len(collaborators))] - subprocess.run(command, cwd=local_bind_path, check=True) # nosec B603 - except Exception: - raise ex.DataSetupException(f"Failed to download data for given model") - - try: - # Copy the data to the respective workspaces based on the index - for index, collaborator in enumerate(collaborators, start=1): - src_folder = os.path.join(local_bind_path, "data", str(index)) - dst_folder = os.path.join(collaborator.workspace_path, "data", str(index)) - if os.path.exists(src_folder): - shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True) - log.info(f"Copied data from {src_folder} to {dst_folder}") - else: - raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}") + raise ex.CollaboratorCreationException(f"Failed to create collaborator: {e}") - # Modify the data.yaml file for all the collaborators - collaborator.modify_data_file( - constants.COL_DATA_FILE.format(local_bind_path, collaborator.name), - index, + # Calculate the hash of collaborator datasource (specific to torch/histology_s3 model). + if datasource_path: + try: + # Calculate hash for the collaborator + collaborator.calculate_hash() + except Exception as e: + raise ex.HashCalculationException( + f"Failed to calculate hash for {collaborator.name}: {e}" ) - except Exception as e: - raise ex.DataSetupException(f"Failed to modify the data file: {e}") - - # XGBoost model uses folder name higgs_data and Flower model uses data to create data folders. - shutil.rmtree(os.path.join(local_bind_path, "higgs_data"), ignore_errors=True) - shutil.rmtree(os.path.join(local_bind_path, "data"), ignore_errors=True) - return True - - -def extract_memory_usage(log_file): - """ - Extracts memory usage data from a log file. 
- This function reads the content of the specified log file, searches for memory usage data - using a regular expression pattern, and returns the extracted data as a dictionary. - Args: - log_file (str): The path to the log file from which to extract memory usage data. - Returns: - dict: A dictionary containing the memory usage data. - Raises: - json.JSONDecodeError: If there is an error decoding the JSON data. - Exception: If memory usage data is not found in the log file. - """ - try: - with open(log_file, "r") as file: - content = file.read() - - pattern = r"Publish memory usage: (\[.*?\])" - match = re.search(pattern, content, re.DOTALL) - - if match: - memory_usage_data = match.group(1) - memory_usage_data = re.sub(r"\S+\.py:\d+", "", memory_usage_data) - memory_usage_data = memory_usage_data.replace("\n", "").replace(" ", "") - memory_usage_data = memory_usage_data.replace("'", '"') - memory_usage_dict = json.loads(memory_usage_data) - return memory_usage_dict - else: - log.error("Memory usage data not found in the log file") - raise Exception("Memory usage data not found in the log file") - except Exception as e: - log.error(f"An error occurred while extracting memory usage: {e}") - raise e - -def write_memory_usage_to_file(memory_usage_dict, output_file): - """ - Writes memory usage data to a file. - This function writes the specified memory usage data to the specified output file. - Args: - memory_usage_dict (dict): A dictionary containing the memory usage data. - output_file (str): The path to the output file to which to write the memory usage data. - """ - try: - with open(output_file, "w") as file: - json.dump(memory_usage_dict, file, indent=4) - except Exception as e: - log.error(f"An error occurred while writing memory usage data to file: {e}") - raise e + return collaborator def start_director(workspace_path, dir_res_file): @@ -837,7 +599,7 @@ def start_director(workspace_path, dir_res_file): """ try: error_msg = "Failed to start the director" - return_code, output, error = run_command( + return_code, output, error = helper.run_command( "./start_director.sh", error_msg=error_msg, workspace_path=os.path.join(workspace_path, "director"), @@ -867,7 +629,7 @@ def start_envoy(envoy_name, workspace_path, res_file): """ try: error_msg = f"Failed to start {envoy_name} envoy" - return_code, output, error = run_command( + return_code, output, error = helper.run_command( f"./start_envoy.sh {envoy_name} {envoy_name}_config.yaml", error_msg=error_msg, workspace_path=os.path.join(workspace_path, envoy_name), @@ -950,35 +712,6 @@ def check_envoys_director_conn_federated_runtime( return False -def run_notebook(notebook_path, output_notebook_path): - """ - Function to run the notebook. 
- Args: - notebook_path (str): Path to the notebook - participant_res_files (dict): Dictionary containing participant names and their result log files - Returns: - bool: True if successful, else False - """ - import papermill as pm - try: - log.info(f"Running the notebook: {notebook_path} with output notebook path: {output_notebook_path}") - output = pm.execute_notebook( - input_path=notebook_path, - output_path=output_notebook_path, - request_save_on_cell_execute=True, - autosave_cell_every=5, # autosave every 5 seconds - log_output=True, - ) - except pm.exceptions.PapermillExecutionError as e: - log.error(f"PapermillExecutionError: {e}") - raise e - - except ex.NotebookRunException as e: - log.error(f"Failed to run the notebook: {e}") - raise e - return True - - def verify_federated_runtime_experiment_completion(participant_res_files): """ Verify the completion of the experiment using the participant logs. @@ -1039,7 +772,7 @@ def get_best_agg_score(database_file=None, agg_metric_file=None, max_retries=10, return db_helper.get_key_value_from_db("best_score", database_file, max_retries=max_retries, sleep_interval=sleep_interval) else: json_file = convert_to_json(agg_metric_file) - best_score = json_file[-1].get(constants.AGG_METRIC_MODEL_ACCURACY_KEY) + best_score = json_file[-1].get(defaults.AGG_METRIC_MODEL_ACCURACY_KEY) if best_score: return float(best_score) else: @@ -1104,30 +837,6 @@ def set_keras_backend(model_name): return [f"KERAS_BACKEND={backend}"] -def remove_stale_processes(aggregator=None, collaborators=[], director=None, envoys=[]): - """ - Remove stale processes - Args: - aggregator (object): Aggregator object - collaborators (list): List of collaborator objects - director (object): Director object - envoys (list): List of envoy objects - """ - if aggregator: - intr_helper.kill_processes(aggregator.name) - - for collaborators in collaborators: - intr_helper.kill_processes(collaborators.name) - - if director: - intr_helper.kill_processes("director") - - for envoy in envoys: - intr_helper.kill_processes(envoy) - - log.info("Stale processes (if any) removed successfully") - - def remove_workspace(path): """ Recursively delete given workspace and its contents, including symbolic links. 
@@ -1173,7 +882,7 @@ def start_aggregator(fed_obj): except Exception as e: log.error(f"Failed to start aggregator: {e}") raise e - + _check_aggregator_protocol_log(fed_obj.aggregator) return True @@ -1195,7 +904,7 @@ def ping_from_collaborator(collaborator): lines = [line.strip() for line in file.readlines()] # print last line log.info(f"Last line: {lines[-1]}") - if any(constants.COL_TLS_END_MSG in line for line in lines[-7:]): + if any(defaults.COL_TLS_END_MSG in line for line in lines[-7:]): log.info(f"Aggregator is reachable from {collaborator.name}") return True else: diff --git a/tests/end_to_end/utils/helper.py b/tests/end_to_end/utils/helper.py new file mode 100644 index 0000000000..1612682e1b --- /dev/null +++ b/tests/end_to_end/utils/helper.py @@ -0,0 +1,183 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import json +import re + +import tests.end_to_end.utils.defaults as defaults +import tests.end_to_end.utils.interruption_helper as intr_helper +import tests.end_to_end.utils.ssh_helper as ssh + +log = logging.getLogger(__name__) + +def remove_stale_processes(aggregator=None, collaborators=[], director=None, envoys=[]): + """ + Remove stale processes + Args: + aggregator (object): Aggregator object + collaborators (list): List of collaborator objects + director (object): Director object + envoys (list): List of envoy objects + """ + if aggregator: + intr_helper.kill_processes(aggregator.name) + + for collaborator in collaborators: + intr_helper.kill_processes(collaborator.name) + + if director: + intr_helper.kill_processes("director") + + for envoy in envoys: + intr_helper.kill_processes(envoy) + + log.info("Stale processes (if any) removed successfully") + + +def run_command( + command, + workspace_path, + error_msg=None, + container_id=None, + run_in_background=False, + bg_file=None, + print_output=False, + with_docker=False, + return_error=False, +): + """ + Run the command + Args: + command (str): Command to run + workspace_path (str): Workspace path + container_id (str): Container ID + run_in_background (bool): Run the command in background + bg_file (str): Background file (with path) + print_output (bool): Print the output + with_docker (bool): Flag specific to dockerized workspace scenario. Default is False. 
+ return_error (bool): Return error message + Returns: + tuple: Return code, output and error + """ + return_code, output, error = 0, None, None + error_msg = error_msg or "Failed to run the command" + + if with_docker and container_id: + log.debug("Running command in docker container") + if len(workspace_path): + docker_command = f"docker exec -w {workspace_path} {container_id} sh -c " + else: + # This scenario is mainly for workspace creation where workspace path is not available + docker_command = f"docker exec -i {container_id} sh -c " + + if run_in_background and bg_file: + docker_command += f"'{command} > {bg_file}' &" + else: + docker_command += f"'{command}'" + + command = docker_command + else: + if not run_in_background: + # When the command is run in background, we anyways pass the workspace path + command = f"cd {workspace_path}; {command}" + + if print_output: + log.info(f"Running command: {command}") + + if run_in_background and not with_docker: + if bg_file: + bg_file = open(bg_file, "a", buffering=1) # open file in append mode, so that restarting scenarios can be handled + ssh.run_command_background( + command, + work_dir=workspace_path, + redirect_to_file=bg_file, + check_sleep=60, + ) + else: + return_code, output, error = ssh.run_command(command) + if return_code != 0 and not return_error: + log.error(f"{error_msg}: {error}") + raise Exception(f"{error_msg}: {error}") + + if print_output: + log.info(f"Output: {output}") + log.info(f"Error: {error}") + return return_code, output, error + + +# This functionality is common across multiple participants, thus moved to a common function +def verify_cmd_output( + output, return_code, error, error_msg, success_msg, raise_exception=True +): + """ + Verify the output of fx command run + Assumption - it will have '✔️ OK' in the output if the command is successful + Args: + output (list): Output of the command using run_command() + return_code (int): Return code of the command + error (list): Error message + error_msg (str): Error message + success_msg (str): Success message + """ + msg_received = [line for line in output if defaults.SUCCESS_MARKER in line] + log.info(f"Message received: {msg_received}") + if return_code == 0 and len(msg_received): + log.info(success_msg) + else: + log.error(f"{error_msg}: {error}") + if raise_exception: + raise Exception(f"{error_msg}: {error}") + + +# TODO - remove if not needed in the near future +def extract_memory_usage(log_file): + """ + Extracts memory usage data from a log file. + This function reads the content of the specified log file, searches for memory usage data + using a regular expression pattern, and returns the extracted data as a dictionary. + Args: + log_file (str): The path to the log file from which to extract memory usage data. + Returns: + dict: A dictionary containing the memory usage data. + Raises: + json.JSONDecodeError: If there is an error decoding the JSON data. + Exception: If memory usage data is not found in the log file. 
+ """ + try: + with open(log_file, "r") as file: + content = file.read() + + pattern = r"Publish memory usage: (\[.*?\])" + match = re.search(pattern, content, re.DOTALL) + + if match: + memory_usage_data = match.group(1) + memory_usage_data = re.sub(r"\S+\.py:\d+", "", memory_usage_data) + memory_usage_data = memory_usage_data.replace("\n", "").replace(" ", "") + memory_usage_data = memory_usage_data.replace("'", '"') + memory_usage_dict = json.loads(memory_usage_data) + return memory_usage_dict + else: + log.error("Memory usage data not found in the log file") + raise Exception("Memory usage data not found in the log file") + except Exception as e: + log.error(f"An error occurred while extracting memory usage: {e}") + raise e + + +# TODO - remove if not needed in the near future +def write_memory_usage_to_file(memory_usage_dict, output_file): + """ + Writes memory usage data to a file. + This function writes the specified memory usage data to the specified output file. + Args: + memory_usage_dict (dict): A dictionary containing the memory usage data. + output_file (str): The path to the output file to which to write the memory usage data. + """ + try: + with open(output_file, "w") as file: + json.dump(memory_usage_dict, file, indent=4) + except Exception as e: + log.error(f"An error occurred while writing memory usage data to file: {e}") + raise e diff --git a/tests/end_to_end/utils/interruption_helper.py b/tests/end_to_end/utils/interruption_helper.py index 3f875b630d..94bd9191ed 100644 --- a/tests/end_to_end/utils/interruption_helper.py +++ b/tests/end_to_end/utils/interruption_helper.py @@ -7,7 +7,7 @@ import psutil import subprocess # nosec B404 -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.docker_helper as docker_helper import tests.end_to_end.utils.exceptions as ex diff --git a/tests/end_to_end/utils/summary_helper.py b/tests/end_to_end/utils/summary_helper.py index 742e0b68cc..ca0bf5b0c3 100644 --- a/tests/end_to_end/utils/summary_helper.py +++ b/tests/end_to_end/utils/summary_helper.py @@ -8,7 +8,7 @@ import re from pathlib import Path -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.exceptions as ex from tests.end_to_end.utils import federation_helper as fed_helper @@ -119,7 +119,7 @@ def print_task_runner_score(): summary_file = _get_summary_file() # Validate the model name and create the workspace name - if not model_name.upper() in constants.ModelName._member_names_: + if not model_name.upper() in defaults.ModelName._member_names_: print( f"Invalid model name: {model_name}. Skipping writing to GitHub step summary" ) diff --git a/tests/end_to_end/utils/tr_common_fixtures.py b/tests/end_to_end/utils/tr_common_fixtures.py index 992fa6c605..42b05dad5d 100644 --- a/tests/end_to_end/utils/tr_common_fixtures.py +++ b/tests/end_to_end/utils/tr_common_fixtures.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import shutil +from pathlib import Path +import tests.end_to_end.utils.data_helper as data_helper from tests.end_to_end.utils.tr_workspace import create_tr_workspace, create_tr_dws_workspace @@ -38,3 +41,42 @@ def fx_federation_tr_dws(request): """ request.config.test_env = "task_runner_dockerized_ws" return create_tr_dws_workspace(request) + + +def cleanup_data_dir(): + """ + Cleanup function to remove the data directory after test completion. 
+ This function is called in the finalizer of the fixture.""" + data_dir = Path.cwd() / "data" + if data_dir.exists(): + shutil.rmtree(data_dir) + + +@pytest.fixture(scope="function") +def fx_verifiable_dataset_with_s3(request): + """ + Fixture for verifiable dataset with S3 bucket. + Note: As this is a function level fixture, thus no import is required at test level. + """ + request.addfinalizer(cleanup_data_dir) + data_helper.prepare_verifiable_dataset(request, dataset_type="s3") + + +@pytest.fixture(scope="function") +def fx_verifiable_dataset_with_azure_blob(request): + """ + Fixture for verifiable dataset with Azure Blob Storage. + Note: As this is a function level fixture, thus no import is required at test level. + """ + request.addfinalizer(cleanup_data_dir) + data_helper.prepare_verifiable_dataset(request, dataset_type="azure_blob") + + +@pytest.fixture(scope="function") +def fx_verifiable_dataset_with_all_ds(request): + """ + Fixture for verifiable dataset with all combinations of S3, Azure Blob Storage and local data. + Note: As this is a function level fixture, thus no import is required at test level. + """ + request.addfinalizer(cleanup_data_dir) + data_helper.prepare_verifiable_dataset(request, dataset_type="all") diff --git a/tests/end_to_end/utils/tr_workspace.py b/tests/end_to_end/utils/tr_workspace.py index 10715fa276..b794547ca8 100644 --- a/tests/end_to_end/utils/tr_workspace.py +++ b/tests/end_to_end/utils/tr_workspace.py @@ -7,9 +7,11 @@ import os from pathlib import Path -import tests.end_to_end.utils.constants as constants +import tests.end_to_end.utils.data_helper as data_helper +import tests.end_to_end.utils.defaults as defaults import tests.end_to_end.utils.exceptions as ex import tests.end_to_end.utils.federation_helper as fh +import tests.end_to_end.utils.helper as helper import tests.end_to_end.utils.ssh_helper as ssh from tests.end_to_end.models import aggregator as agg_model, model_owner as mo_model import tests.end_to_end.utils.docker_helper as dh @@ -42,10 +44,10 @@ def common_workspace_creation(request, eval_scope=False): fh.federation_env_setup_and_validate(request, eval_scope) ) - agg_workspace_path = constants.AGG_WORKSPACE_PATH.format(workspace_path) + agg_workspace_path = defaults.AGG_WORKSPACE_PATH.format(workspace_path) # For Flower App Pytorch, num of rounds must be 1 - if request.config.model_name.lower() == constants.ModelName.FLOWER_APP_PYTORCH.value: + if request.config.model_name.lower() == defaults.ModelName.FLOWER_APP_PYTORCH.value: if request.config.num_rounds != 1: raise ex.FlowerAppException( "Flower app with PyTorch only supports 1 round of training." @@ -63,7 +65,7 @@ def common_workspace_creation(request, eval_scope=False): model_owner.create_workspace() # Modify the plan - plan_path = constants.AGG_PLAN_PATH.format(local_bind_path) + plan_path = defaults.AGG_PLAN_PATH.format(local_bind_path) param_config = request.config initial_model_path = None @@ -97,6 +99,9 @@ def create_tr_workspace(request, eval_scope=False): tuple : A named tuple containing the objects for model owner, aggregator, and collaborators. 
""" + if not request.config.model_name: + raise ex.ModelNameException("Model name is not set in the request") + # get details of model owner, collaborators, and aggregator from common # workspace creation function workspace_path, local_bind_path, agg_domain_name, model_owner, plan_path, agg_workspace_path, initial_model_path = common_workspace_creation(request, eval_scope) @@ -119,6 +124,7 @@ def create_tr_workspace(request, eval_scope=False): aggregator = agg_model.Aggregator( agg_domain_name=agg_domain_name, workspace_path=agg_workspace_path, + transport_protocol=request.config.transport_protocol, eval_scope=eval_scope, container_id=model_owner.container_id, # None in case of native environment ) @@ -135,27 +141,32 @@ def create_tr_workspace(request, eval_scope=False): collaborators = [] executor = concurrent.futures.ThreadPoolExecutor() + + # In case of torch/histology_s3, we need to pass the data path, flag to calculate hash + # and bucket mapping to the setup_collaborator function futures = [ executor.submit( fh.setup_collaborator, index, workspace_path=workspace_path, local_bind_path=local_bind_path, + transport_protocol=request.config.transport_protocol, ) for index in range(1, request.config.num_collaborators+1) ] + collaborators = [f.result() for f in futures] # Data setup requires total no of collaborators, thus keeping the function call # outside of the loop - if request.config.model_name.lower() in [constants.ModelName.XGB_HIGGS.value, constants.ModelName.FLOWER_APP_PYTORCH.value]: - fh.setup_collaborator_data(collaborators, request.config.model_name, local_bind_path) + if request.config.model_name.lower() in [defaults.ModelName.XGB_HIGGS.value, defaults.ModelName.FLOWER_APP_PYTORCH.value]: + data_helper.setup_collaborator_data(collaborators, request.config.model_name, local_bind_path) if request.config.use_tls: fh.setup_pki_for_collaborators(collaborators, model_owner, local_bind_path) fh.import_pki_for_collaborators(collaborators) - fh.remove_stale_processes(aggregator, collaborators) + helper.remove_stale_processes(aggregator, collaborators) # Return the federation fixture return federation_details( @@ -211,6 +222,7 @@ def create_tr_workspace_gandlf(request, eval_scope=False): aggregator = agg_model.Aggregator( agg_domain_name=agg_domain_name, workspace_path=agg_workspace_path, + transport_protocol=request.config.transport_protocol, eval_scope=eval_scope, container_id=model_owner.container_id, # None in case of native environment ) @@ -218,7 +230,7 @@ def create_tr_workspace_gandlf(request, eval_scope=False): # Currently plan initialization internally checks data path in data.yaml # So we need to have data and modified data.yaml file in place before initializing the plan # Issue - https://github.com/securefederatedai/openfl/issues/73 - fh.download_gandlf_data(aggregator, local_bind_path, request.config.num_collaborators, results_path) + data_helper.download_gandlf_data(aggregator, local_bind_path, request.config.num_collaborators, results_path) # Initialize the plan extra_args = f"--gandlf_config {gandlf_seg_file}" @@ -249,13 +261,14 @@ def create_tr_workspace_gandlf(request, eval_scope=False): fh.setup_collaborator, index, workspace_path=workspace_path, - local_bind_path=local_bind_path + local_bind_path=local_bind_path, + transport_protocol=request.config.transport_protocol, ) for index in range(1, request.config.num_collaborators+1) ] collaborators = [f.result() for f in futures] - fh.copy_gandlf_data_to_collaborators(aggregator, collaborators, local_bind_path) + 
data_helper.copy_gandlf_data_to_collaborators(aggregator, collaborators, local_bind_path) if request.config.use_tls: fh.setup_pki_for_collaborators(collaborators, model_owner, local_bind_path) @@ -295,11 +308,11 @@ def create_tr_dws_workspace(request, eval_scope=False): ) # Create openfl image - dh.build_docker_image(constants.DEFAULT_OPENFL_IMAGE, constants.DEFAULT_OPENFL_DOCKERFILE) + dh.build_docker_image(defaults.DEFAULT_OPENFL_IMAGE, defaults.DEFAULT_OPENFL_DOCKERFILE) # Command 'fx workspace dockerize --save ..' will use the workspace name for # image name which is 'workspace' in this case. - model_owner.dockerize_workspace(constants.DEFAULT_OPENFL_IMAGE) + model_owner.dockerize_workspace(defaults.DEFAULT_OPENFL_IMAGE) # Certify the workspace in case of TLS if request.config.use_tls: @@ -314,6 +327,7 @@ def create_tr_dws_workspace(request, eval_scope=False): aggregator = agg_model.Aggregator( agg_domain_name=agg_domain_name, workspace_path=agg_workspace_path, + transport_protocol=request.config.transport_protocol, eval_scope=eval_scope, container_id=model_owner.container_id, # None in case of native environment ) @@ -326,6 +340,7 @@ def create_tr_dws_workspace(request, eval_scope=False): index, workspace_path=workspace_path, local_bind_path=local_bind_path, + transport_protocol=request.config.transport_protocol, ) for index in range(1, request.config.num_collaborators + 1) ] @@ -336,14 +351,14 @@ def create_tr_dws_workspace(request, eval_scope=False): # Data setup requires total no of collaborators, thus keeping the function call # outside of the loop - if request.config.model_name.lower() in [constants.ModelName.XGB_HIGGS.value, constants.ModelName.FLOWER_APP_PYTORCH.value]: - fh.setup_collaborator_data(collaborators, request.config.model_name, local_bind_path) + if request.config.model_name.lower() in [defaults.ModelName.XGB_HIGGS.value, defaults.ModelName.FLOWER_APP_PYTORCH.value]: + data_helper.setup_collaborator_data(collaborators, request.config.model_name, local_bind_path) # Note: In case of multiple machines setup, scp the created tar for collaborators # to the other machine(s) fh.create_tarball_for_collaborators( collaborators, local_bind_path, use_tls=request.config.use_tls, - add_data=True if request.config.model_name.lower() in [constants.ModelName.XGB_HIGGS.value, constants.ModelName.FLOWER_APP_PYTORCH.value] else False + add_data=True if request.config.model_name.lower() in [defaults.ModelName.XGB_HIGGS.value, defaults.ModelName.FLOWER_APP_PYTORCH.value] else False ) # Generate the sign request and certify the aggregator in case of TLS @@ -351,7 +366,7 @@ def create_tr_dws_workspace(request, eval_scope=False): aggregator.generate_sign_request() model_owner.certify_aggregator(agg_domain_name) - local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path) + local_agg_ws_path = defaults.AGG_WORKSPACE_PATH.format(local_bind_path) # Note: In case of multiple machines setup, scp this tar to the other machine(s) return_code, output, error = ssh.run_command( @@ -362,7 +377,7 @@ def create_tr_dws_workspace(request, eval_scope=False): # Note: In case of multiple machines setup, scp this workspace tar # to the other machine(s) so that docker load can load the image. 
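# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this diff: the helpers above pass
# request.config.transport_protocol into the Aggregator and collaborator
# setup. One plausible way such a value ends up on the pytest config object is
# a conftest.py hook; the option name "--transport_protocol" here is an
# assumption for illustration only.
def pytest_addoption(parser):
    parser.addoption(
        "--transport_protocol",
        action="store",
        default="grpc",
        choices=["grpc", "rest"],
        help="Transport protocol used by the aggregator and collaborators",
    )


def pytest_configure(config):
    # Mirror the CLI option onto the config object so fixtures and helpers can
    # read request.config.transport_protocol directly.
    config.transport_protocol = config.getoption("--transport_protocol")
# ---------------------------------------------------------------------------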
- model_owner.load_workspace(workspace_tar_name=f"{constants.DFLT_WORKSPACE_NAME}.tar") + model_owner.load_workspace(workspace_tar_name=f"{defaults.DFLT_WORKSPACE_NAME}.tar") # Return the federation fixture return federation_details( diff --git a/tests/end_to_end/utils/wf_common_fixtures.py b/tests/end_to_end/utils/wf_common_fixtures.py index a6e242fe05..58998127f6 100644 --- a/tests/end_to_end/utils/wf_common_fixtures.py +++ b/tests/end_to_end/utils/wf_common_fixtures.py @@ -148,3 +148,62 @@ def fx_local_federated_workflow_prvt_attr(request): collaborators=collaborators_list, runtime=local_runtime, ) + + +@pytest.fixture(scope="function") +def fx_local_fed_wf_unserializable_pvt_attrs(request): + """ + Fixture to set up a local federated workflow for testing. + This fixture initializes an `Aggregator` and sets up a list of collaborators + based on the number specified in the test configuration. It also configures + a `LocalRuntime` with the aggregator, collaborators, and an optional backend + if specified in the test configuration. + Args: + request (FixtureRequest): The pytest request object that provides access + to the test configuration. + Yields: + LocalRuntime: An instance of `LocalRuntime` configured with the aggregator, + collaborators, and backend. + """ + # Inline import + from tests.end_to_end.utils.wf_helper import ( + callable_to_init_agg_unserializable_pvt_attrs, + callable_to_init_collab_unserializable_pvt_attrs + ) + collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None + collab_value = request.param[1] if hasattr(request, 'param') and request.param else None + agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None + + # Get the callback functions from the locals using string + collab_callback_func_name = locals()[collab_callback_func] if collab_callback_func else None + agg_callback_func_name = locals()[agg_callback_func] if agg_callback_func else None + collaborators_list = [] + + # Setup aggregator + if agg_callback_func_name: + aggregator = Aggregator(name="agg", + private_attributes_callable=agg_callback_func_name) + else: + aggregator = Aggregator() + + # Setup collaborators + for i in range(request.config.num_collaborators): + func_var = i if collab_value == "int" else f"collaborator{i}" if collab_value == "str" else None + collab = Collaborator( + name=f"collaborator{i}", + private_attributes_callable=collab_callback_func_name + ) + collaborators_list.append(collab) + + workflow_backend = request.config.workflow_backend if hasattr(request.config, 'workflow_backend') else None + if workflow_backend: + local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators_list, backend=workflow_backend) + else: + local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators_list) + + # Return the federation fixture + return workflow_local_fixture( + aggregator=aggregator, + collaborators=collaborators_list, + runtime=local_runtime, + ) diff --git a/tests/end_to_end/utils/wf_helper.py b/tests/end_to_end/utils/wf_helper.py index fcde1118d0..3d1403aae9 100644 --- a/tests/end_to_end/utils/wf_helper.py +++ b/tests/end_to_end/utils/wf_helper.py @@ -4,6 +4,10 @@ from metaflow import Flow import logging import numpy as np +from openfl.databases import TensorDB +from openfl.utilities import TensorKey + +import tests.end_to_end.utils.exceptions as ex log = logging.getLogger(__name__) @@ -112,3 +116,46 @@ def init_agg_pvt_attr_np(): of a NumPy array of shape (10, 28, 28) 
filled with random values. """ return {"test_loader": np.random.rand(10, 28, 28)} + + +def callable_to_init_collab_unserializable_pvt_attrs(): + """ + Create and return a TensorDB + """ + return {"col_tensor_db": TensorDB()} + + +def callable_to_init_agg_unserializable_pvt_attrs(): + """ + Create and return a TensorDB + """ + return {"agg_tensor_db": TensorDB()} + + +def run_notebook(notebook_path, output_notebook_path): + """ + Function to run the notebook. + Args: + notebook_path (str): Path to the notebook + participant_res_files (dict): Dictionary containing participant names and their result log files + Returns: + bool: True if successful, else False + """ + import papermill as pm + try: + log.info(f"Running the notebook: {notebook_path} with output notebook path: {output_notebook_path}") + output = pm.execute_notebook( + input_path=notebook_path, + output_path=output_notebook_path, + request_save_on_cell_execute=True, + autosave_cell_every=5, # autosave every 5 seconds + log_output=True, + ) + except pm.exceptions.PapermillExecutionError as e: + log.error(f"PapermillExecutionError: {e}") + raise e + + except ex.NotebookRunException as e: + log.error(f"Failed to run the notebook: {e}") + raise e + return True diff --git a/tests/end_to_end/workflow/unserializable_private_attr.py b/tests/end_to_end/workflow/unserializable_private_attr.py new file mode 100644 index 0000000000..7fc3a09ed5 --- /dev/null +++ b/tests/end_to_end/workflow/unserializable_private_attr.py @@ -0,0 +1,136 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from openfl.experimental.workflow.interface import FLSpec, Aggregator, Collaborator +from openfl.experimental.workflow.placement import aggregator, collaborator +from openfl.databases import TensorDB +from openfl.utilities import TensorKey +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger(__name__) + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(28 * 28, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = torch.relu(self.fc1(x)) + x = self.fc2(x) + return x + +def FedAvg(models, weights=None): + new_model = models[0] + state_dicts = [model.state_dict() for model in models] + state_dict = new_model.state_dict() + for key in models[1].state_dict(): + state_dict[key] = torch.from_numpy( + np.average( + [state[key].numpy() for state in state_dicts], axis=0, weights=weights + ) + ) + new_model.load_state_dict(state_dict) + return new_model + + +class TestFlowUnserializablePrivateAttributes(FLSpec): + """ + Testflow to validate handling of unserializable private attributes. 
+ """ + __test__ = False # to prevent pytest from trying to discover tests in the class + + def __init__(self, rounds=5, **kwargs): + super().__init__(**kwargs) + self.model = Net() + self.current_round = 0 + self.n_rounds = rounds + + @aggregator + def start(self): + self.collaborators = self.runtime.collaborators + self.next(self.aggregated_model_validation, foreach="collaborators") + + @collaborator + def aggregated_model_validation(self): + log.info(f"Performing aggregated model validation for collaborator {self.input}") + self.next(self.train) + + @collaborator + def train(self): + # Save trained models to Collaborator's Tensor DB + self.save_model_to_tensordb( + model=self.model, + tensordb=self.col_tensor_db, + origin=self.input, + round=self.current_round, + report=False, + tags="Trained_Tensor", + ) + log.info(self.col_tensor_db) + self.next(self.local_model_validation) + + @collaborator + def local_model_validation(self): + self.next(self.join) + + @aggregator + def join(self, inputs): + + # Update agg_tensor_db with each collaborator's model weights + for input in inputs: + # Save model to Aggregator's Tensor DB + self.save_model_to_tensordb( + model=input.model, + tensordb=self.agg_tensor_db, + origin=input.input, + round=self.current_round, + report=False, + tags="Trained", + ) + + self.model = FedAvg([input.model for input in inputs]) + + # Save model to Aggregator's Tensor DB + self.save_model_to_tensordb( + model=self.model, + tensordb=self.agg_tensor_db, + origin="Agg", + round=self.current_round, + report=False, + tags="Agg_Tensor", + ) + print(self.agg_tensor_db) + + self.current_round += 1 + if self.current_round < self.n_rounds: + self.next(self.aggregated_model_validation, foreach="collaborators") + else: + self.next(self.end) + + @aggregator + def end(self): + print(f"This is the end of the flow") + + def save_model_to_tensordb( + self, model=None, tensordb=None, origin=None, round=0, report=False, tags=("") + ): + # Update tensor_db + tensor_key_dict = {} + for name, param in model.named_parameters(): + tensor_key = TensorKey( + tensor_name=name, + origin=origin, + round_number=round, + report=False, + tags=tags, + ) + tensor_key_dict[tensor_key] = param.detach().cpu().numpy() + tensordb.cache_tensor(tensor_key_dict) diff --git a/tests/github/test_hello_federation.py b/tests/github/test_hello_federation.py index fbd6da87e6..819d4443c7 100644 --- a/tests/github/test_hello_federation.py +++ b/tests/github/test_hello_federation.py @@ -32,6 +32,7 @@ def main(): parser.add_argument('--col1-data-path', default='1') parser.add_argument('--col2-data-path', default='2') parser.add_argument('--save-model') + parser.add_argument('--transport-protocol', default='grpc', help='Transport protocol for communication') origin_dir = Path.cwd().resolve() args = parser.parse_args() @@ -49,11 +50,14 @@ def main(): col1, col2 = args.col1, args.col2 col1_data_path, col2_data_path = args.col1_data_path, args.col2_data_path save_model = args.save_model + transport_protocol = args.transport_protocol + if transport_protocol not in ['grpc', 'rest']: # Updated to include 'rest' as a valid option + raise ValueError(f"Invalid transport protocol: {transport_protocol}. Use 'grpc' or 'rest'.") # START # ===== # Make sure you are in a Python virtual environment with the FL package installed. 
- create_certified_workspace(fed_workspace, template, fqdn, rounds_to_train) + create_certified_workspace(fed_workspace, template, fqdn, rounds_to_train, transport_protocol) certify_aggregator(fqdn) workspace_root = Path().resolve() # Get the absolute directory path for the workspace diff --git a/tests/github/utils.py b/tests/github/utils.py index 34aab84b42..ce263a7d33 100644 --- a/tests/github/utils.py +++ b/tests/github/utils.py @@ -4,8 +4,8 @@ from subprocess import check_call import os from pathlib import Path -import re import tarfile +import yaml def create_collaborator(col, workspace_root, data_path, archive_name, fed_workspace): @@ -46,7 +46,7 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp ) -def create_certified_workspace(path, template, fqdn, rounds_to_train): +def create_certified_workspace(path, template, fqdn, rounds_to_train, transport_protocol='grpc'): shutil.rmtree(path, ignore_errors=True) check_call(['fx', 'workspace', 'create', '--prefix', path, '--template', template]) os.chdir(path) @@ -55,17 +55,26 @@ def create_certified_workspace(path, template, fqdn, rounds_to_train): # Initialize FL plan check_call(['fx', 'plan', 'initialize', '-a', fqdn]) plan_path = Path('plan/plan.yaml') + + # Read the plan.yaml file + with open(plan_path, 'r', encoding='utf-8') as file: + plan_config = yaml.safe_load(file) + + # Update rounds_to_train and transport_protocol values try: - rounds_to_train = int(rounds_to_train) - with open(plan_path, "r", encoding='utf-8') as sources: - lines = sources.readlines() - with open(plan_path, "w", encoding='utf-8') as sources: - for line in lines: - sources.write( - re.sub(r'rounds_to_train.*', f'rounds_to_train: {rounds_to_train}', line) - ) - except (ValueError, TypeError): - pass + # Update rounds_to_train if provided + if rounds_to_train is not None: + plan_config['aggregator']['settings']['rounds_to_train'] = int(rounds_to_train) + + # Update transport_protocol + plan_config['network']['settings']['transport_protocol'] = transport_protocol + + # Write the updated config back to the file + with open(plan_path, 'w', encoding='utf-8') as file: + yaml.safe_dump(plan_config, file, default_flow_style=False) + except (ValueError, TypeError, KeyError) as e: + print(f"Warning: Could not update plan.yaml: {e}") + # Create certificate authority for workspace check_call(['fx', 'workspace', 'certify']) diff --git a/tests/openfl/federated/plan/test_plan.py b/tests/openfl/federated/plan/test_plan.py index dcc1f9707a..a4d63339b8 100644 --- a/tests/openfl/federated/plan/test_plan.py +++ b/tests/openfl/federated/plan/test_plan.py @@ -10,6 +10,8 @@ from openfl.federated.plan.plan import Plan from openfl.component.assigner import RandomGroupedAssigner from openfl.component.aggregator import Aggregator +from openfl.transport.rest.aggregator_server import AggregatorRESTServer +from openfl.transport.grpc.aggregator_server import AggregatorGRPCServer @pytest.fixture @@ -47,3 +49,27 @@ def test_get_aggregator(mocker, plan): mocker.patch('openfl.protocols.utils.load_proto', mock.Mock()) Aggregator._load_initial_tensors = mock.Mock() assert isinstance(plan.get_aggregator(), Aggregator) + +def test_get_server_rest(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + mock_setup_ssl = mocker.patch('openfl.transport.rest.aggregator_server.AggregatorRESTServer._setup_ssl_context', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'rest' + server = 
plan.get_server() + assert isinstance(server, AggregatorRESTServer) + +def test_get_server_grpc(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'grpc' + server = plan.get_server() + assert isinstance(server, AggregatorGRPCServer) + +def test_get_server_default_certificates(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + server = plan.get_server() + assert isinstance(server, AggregatorGRPCServer) # Default to gRPC + +def test_get_server_invalid_protocol(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'invalid_protocol' + with pytest.raises(ValueError): + plan.get_server() diff --git a/tests/openfl/transport/rest/__init__.py b/tests/openfl/transport/rest/__init__.py new file mode 100644 index 0000000000..ba2df757ff --- /dev/null +++ b/tests/openfl/transport/rest/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Transport tests package.""" diff --git a/tests/openfl/transport/rest/conftest.py b/tests/openfl/transport/rest/conftest.py new file mode 100644 index 0000000000..f1572cd34a --- /dev/null +++ b/tests/openfl/transport/rest/conftest.py @@ -0,0 +1,39 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Shared test configurations for transport tests.""" + +import pytest +import logging +from pathlib import Path + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Configure logging for tests.""" + logging.basicConfig(level=logging.DEBUG) + yield + + +@pytest.fixture(autouse=True) +def mock_environment(monkeypatch): + """Mock environment variables and system settings.""" + monkeypatch.setenv('PYTHONPATH', '') # Clear PYTHONPATH to avoid interference + yield + + +@pytest.fixture +def test_data_dir(): + """Get the test data directory.""" + return Path(__file__).parent / 'data' + + +@pytest.fixture(autouse=True) +def setup_test_data(test_data_dir): + """Set up test data directory.""" + test_data_dir.mkdir(exist_ok=True) + yield + # Cleanup after tests if needed + if test_data_dir.exists(): + for file in test_data_dir.glob('*'): + if file.is_file(): + file.unlink() diff --git a/tests/openfl/transport/rest/test_rest_server.py b/tests/openfl/transport/rest/test_rest_server.py new file mode 100644 index 0000000000..089235466d --- /dev/null +++ b/tests/openfl/transport/rest/test_rest_server.py @@ -0,0 +1,449 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""REST server tests module.""" + +import pytest +import ssl +from unittest import mock +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from datetime import datetime, timedelta + +from openfl.transport.rest.aggregator_server import AggregatorRESTServer +from openfl.protocols import aggregator_pb2, base_pb2 + +# The .proto file has GetAggregatedTensorsRequest/Response and TensorSpec, +# but the Python protobuf files are out of sync and missing these classes. +# Add them to the module using existing protobuf infrastructure. 
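# For reference (illustrative, not part of this diff): real generated protobuf
# messages round-trip through google.protobuf.json_format without any help,
# e.g. using the existing MessageHeader:
#
#     from google.protobuf import json_format
#     from openfl.protocols import aggregator_pb2
#
#     header = json_format.ParseDict(
#         {"sender": "col1", "receiver": "agg"}, aggregator_pb2.MessageHeader()
#     )
#     assert json_format.MessageToDict(header)["sender"] == "col1"
#
# The stand-in classes defined below are plain Python objects, so ParseDict and
# MessageToDict are patched further down to handle them explicitly.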
+ +# Create TensorSpec class using the existing protobuf pattern +class TensorSpec: + def __init__(self): + self.tensor_name = "" + self.round_number = 0 + self.report = False + self.tags = [] + self.require_lossless = False + +# Create GetAggregatedTensorsRequest using existing MessageHeader +class GetAggregatedTensorsRequest: + def __init__(self): + self.header = aggregator_pb2.MessageHeader() + self.tensor_specs = [] + +# Create GetAggregatedTensorsResponse with protobuf compatibility +class GetAggregatedTensorsResponse: + def __init__(self, header=None, tensors=None): + self.header = header or aggregator_pb2.MessageHeader() + self.tensors = tensors or [] + self.DESCRIPTOR = None + +# Add the missing classes to aggregator_pb2 module so the server can find them +aggregator_pb2.TensorSpec = TensorSpec +aggregator_pb2.GetAggregatedTensorsRequest = GetAggregatedTensorsRequest +aggregator_pb2.GetAggregatedTensorsResponse = GetAggregatedTensorsResponse +aggregator_pb2.NamedTensorProto = base_pb2.NamedTensor + +# Patch json_format to handle the new classes since they're not "real" protobuf messages +from google.protobuf import json_format + +original_parse_dict = json_format.ParseDict +original_message_to_dict = json_format.MessageToDict + +def patched_parse_dict(js_dict, message, **kwargs): + """Custom ParseDict to handle the missing protobuf classes.""" + if isinstance(message, GetAggregatedTensorsRequest): + # Parse header manually + if 'header' in js_dict: + header_data = js_dict['header'] + message.header.sender = header_data.get('sender', '') + message.header.receiver = header_data.get('receiver', '') + message.header.federation_uuid = header_data.get('federation_uuid', '') + message.header.single_col_cert_common_name = header_data.get('single_col_cert_common_name', '') + + # Parse tensor specs manually + if 'tensor_specs' in js_dict: + message.tensor_specs = [] + for spec_data in js_dict['tensor_specs']: + spec = TensorSpec() + spec.tensor_name = spec_data.get('tensor_name', '') + spec.round_number = spec_data.get('round_number', 0) + spec.report = spec_data.get('report', False) + spec.tags = spec_data.get('tags', []) + spec.require_lossless = spec_data.get('require_lossless', False) + message.tensor_specs.append(spec) + return message + else: + return original_parse_dict(js_dict, message, **kwargs) + +def patched_message_to_dict(message, **kwargs): + """Custom MessageToDict to handle our custom protobuf classes.""" + if isinstance(message, GetAggregatedTensorsResponse): + # Manually convert our custom response to dict + result = { + "header": original_message_to_dict(message.header, **kwargs) if message.header else {}, + "tensors": [] + } + # Convert tensors to dict format + for tensor in message.tensors: + if tensor: + result["tensors"].append(original_message_to_dict(tensor, **kwargs)) + return result + else: + return original_message_to_dict(message, **kwargs) + +json_format.ParseDict = patched_parse_dict +json_format.MessageToDict = patched_message_to_dict + +def generate_test_certificates(cert_path, key_path, root_cert_path): + """Generate self-signed certificates for testing.""" + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048 + ) + + # Generate self-signed certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"test.example.com"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Test Organization"), + ]) + + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + 
issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.utcnow() + ).not_valid_after( + datetime.utcnow() + timedelta(days=1) + ).sign(private_key, hashes.SHA256()) + + # Write private key + with open(key_path, "wb") as f: + f.write(private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + )) + + # Write certificate + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + # For testing, use the same cert as root CA + with open(root_cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + +@pytest.fixture +def mock_aggregator(): + """Create a mock aggregator for testing.""" + aggregator = mock.Mock() + aggregator.uuid = "test-uuid" + aggregator.federation_uuid = "fed-uuid" + aggregator.authorized_cols = ["test-collaborator"] + aggregator.single_col_cert_common_name = "test-cert-cn" + aggregator.valid_collaborator_cn_and_id = mock.Mock(return_value=True) + aggregator.get_tasks = mock.Mock(return_value=(["task1", "task2"], 1, 5, False)) + aggregator.get_aggregated_tensor = mock.Mock() + aggregator.send_local_task_results = mock.Mock() + # Disable connector mode by default + aggregator.get_interop_client = mock.Mock(return_value=None) + # Add mock for task completion tracking + aggregator._collaborator_task_completed = mock.Mock(return_value=True) + # Add mock assigner + mock_assigner = mock.Mock() + mock_assigner.get_tasks_for_collaborator = mock.Mock(return_value=[]) + aggregator.assigner = mock_assigner + # Add collaborators_done list + aggregator.collaborators_done = [] + return aggregator + + +@pytest.fixture +def ssl_certs(tmp_path): + """Create temporary SSL certificate files for testing.""" + cert_path = tmp_path / "test_cert.pem" + key_path = tmp_path / "test_key.pem" + root_path = tmp_path / "test_root.pem" + + generate_test_certificates(cert_path, key_path, root_path) + + return { + 'cert': str(cert_path), + 'key': str(key_path), + 'root': str(root_path) + } + + +@pytest.fixture +def rest_server(mock_aggregator, ssl_certs): + """Create REST server instance for testing.""" + server = AggregatorRESTServer( + aggregator=mock_aggregator, + agg_addr="localhost", + agg_port=8080, + use_tls=True, + require_client_auth=True, + certificate=ssl_certs['cert'], + private_key=ssl_certs['key'], + root_certificate=ssl_certs['root'] + ) + return server + + +class TestAggregatorRESTServer: + """Test cases for AggregatorRESTServer.""" + + def test_ssl_context_setup(self, rest_server, ssl_certs): + """Test SSL context configuration.""" + with mock.patch('ssl.SSLContext') as mock_ssl_context: + mock_context = mock.Mock() + mock_ssl_context.return_value = mock_context + mock_context.options = 0 + + rest_server._setup_ssl_context( + certificate=ssl_certs['cert'], + private_key=ssl_certs['key'], + root_certificate=ssl_certs['root'] + ) + + mock_ssl_context.assert_called_once_with(ssl.PROTOCOL_TLS_SERVER) + mock_context.load_cert_chain.assert_called_once_with( + certfile=ssl_certs['cert'], + keyfile=ssl_certs['key'] + ) + + assert mock_context.load_verify_locations.call_count == 2 + assert all( + call == mock.call(cafile=ssl_certs['root']) + for call in mock_context.load_verify_locations.call_args_list + ) + + assert mock_context.verify_mode == ssl.CERT_REQUIRED + + def test_get_tasks_valid_request(self, rest_server, mock_aggregator): + """Test successful task 
retrieval.""" + mock_tasks = [ + aggregator_pb2.Task(name="task1", function_name="func1", task_type="train"), + aggregator_pb2.Task(name="task2", function_name="func2", task_type="validate") + ] + mock_aggregator.get_tasks.return_value = (mock_tasks, 1, 5, True) + + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + + assert response.status_code == 200 + data = response.get_json() + assert data["roundNumber"] == 1 + assert len(data["tasks"]) == 2 + assert data["sleepTime"] == 5 + assert "quit" in data + assert data["quit"] + + mock_aggregator.get_tasks.return_value = (mock_tasks, 1, 5, False) + + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + + assert response.status_code == 200 + data = response.get_json() + assert not data.get("quit", False) + + def test_get_tasks_unauthorized(self, rest_server): + """Test task retrieval with unauthorized collaborator.""" + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "unauthorized-collaborator", + "federation_uuid": "fed-uuid" + }) + assert response.status_code == 401 + + def test_post_task_results(self, rest_server, mock_aggregator): + """Test task results submission.""" + task_results = aggregator_pb2.TaskResults() + task_results.task_name = "test_task" + task_results.round_number = 1 + task_results.data_size = 100 + + task_results.header.sender = "test-collaborator" + task_results.header.receiver = str(mock_aggregator.uuid) + task_results.header.federation_uuid = str(mock_aggregator.federation_uuid) + task_results.header.single_col_cert_common_name = "test-cert-cn" + + tensor = base_pb2.NamedTensor() + tensor.name = "test_tensor" + task_results.tensors.append(tensor) + + data_stream = base_pb2.DataStream() + data_stream.npbytes = task_results.SerializeToString() + data_stream.size = len(data_stream.npbytes) + + request_data = ( + len(data_stream.SerializeToString()).to_bytes(4, byteorder='big') + + data_stream.SerializeToString() + + (0).to_bytes(4, byteorder='big') + ) + + mock_aggregator.assigner.get_tasks_for_collaborator.return_value = [ + aggregator_pb2.Task(name="test_task") + ] + + with rest_server.app.test_client() as client: + response = client.post( + 'experimental/v1/tasks/results', + data=request_data, + headers={ + "Sender": "test-collaborator", + "Receiver": str(mock_aggregator.uuid), + "Federation-UUID": str(mock_aggregator.federation_uuid), + "Single-Col-Cert-CN": "test-cert-cn" + } + ) + + assert response.status_code == 200 + mock_aggregator.send_local_task_results.assert_called_once() + + def test_get_aggregated_tensor(self, rest_server, mock_aggregator): + """Test aggregated tensor retrieval.""" + mock_tensor = base_pb2.NamedTensor() + mock_tensor.name = "test_tensor" + + def mock_get_aggregated_tensor(tensor_name, round_number, report=False, tags=(), require_lossless=False, requested_by=None): + return mock_tensor + + mock_aggregator.get_aggregated_tensor.side_effect = mock_get_aggregated_tensor + + request_payload = { + "header": { + "sender": "test-collaborator", + "receiver": str(mock_aggregator.uuid), + "federation_uuid": str(mock_aggregator.federation_uuid), + "single_col_cert_common_name": "test-cert-cn" + }, + "tensor_specs": [{ + "tensor_name": "test_tensor", 
+ "round_number": 1, + "report": False, + "tags": [], + "require_lossless": False + }] + } + + with rest_server.app.test_client() as client: + response = client.post('/experimental/v1/tensors/aggregated/batch', + json=request_payload, + headers={ + "Sender": "test-collaborator", + "Receiver": str(mock_aggregator.uuid), + "Federation-UUID": str(mock_aggregator.federation_uuid), + "Single-Col-Cert-CN": "test-cert-cn" + }) + + if response.status_code != 200: + print(f"Response status: {response.status_code}") + print(f"Response data: {response.get_data(as_text=True)}") + + assert response.status_code == 200 + data = response.get_json() + assert "header" in data + assert "tensors" in data + assert len(data["tensors"]) == 1 + + mock_aggregator.get_aggregated_tensor.assert_called_once_with( + "test_tensor", + 1, + False, + (), + False, + "test-collaborator" + ) + + def test_relay_message_not_enabled(self, rest_server): + """Test relay endpoint when not enabled.""" + relay_msg = aggregator_pb2.InteropMessage() + relay_msg.header.sender = "test-collaborator" + relay_msg.header.receiver = str(rest_server.aggregator.uuid) + relay_msg.header.federation_uuid = str(rest_server.aggregator.federation_uuid) + + with rest_server.app.test_client() as client: + response = client.post( + '/experimental/v1/interop/relay', + json={"header": {"sender": "test-collaborator"}} + ) + assert response.status_code == 501 + + def test_invalid_federation_uuid(self, rest_server): + """Test request with invalid federation UUID.""" + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "invalid-uuid" + }) + assert response.status_code == 401 + + def test_malformed_task_results(self, rest_server): + """Test submission of malformed task results.""" + with rest_server.app.test_client() as client: + response = client.post( + 'experimental/v1/tasks/results', + data=b"invalid data", + headers={ + "Sender": "test-collaborator", + "Receiver": str(rest_server.aggregator.uuid), + "Federation-UUID": str(rest_server.aggregator.federation_uuid) + } + ) + assert response.status_code == 400 + + def test_connector_mode_tasks(self, rest_server): + """Test task retrieval in connector mode.""" + rest_server.use_connector = True + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + assert response.status_code == 501 + + def test_invalid_round_number(self, rest_server): + """Test tensor retrieval with invalid round number.""" + request_payload = { + "header": { + "sender": "test-collaborator", + "receiver": str(rest_server.aggregator.uuid), + "federation_uuid": str(rest_server.aggregator.federation_uuid), + "single_col_cert_common_name": "test-cert-cn" + }, + "tensor_specs": [{ + "tensor_name": "test_tensor", + "round_number": "invalid", + "report": False, + "tags": [], + "require_lossless": False + }] + } + + with rest_server.app.test_client() as client: + response = client.post('/experimental/v1/tensors/aggregated/batch', + json=request_payload, + headers={ + "Sender": "test-collaborator", + "Receiver": str(rest_server.aggregator.uuid), + "Federation-UUID": str(rest_server.aggregator.federation_uuid), + "Single-Col-Cert-CN": "test-cert-cn" + }) + assert response.status_code == 400