-
Notifications
You must be signed in to change notification settings - Fork 72
[ci] add PPU unittest workflow #522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/v1.1
Are you sure you want to change the base?
Changes from all commits
8828322
c2e9bd0
edaba50
da2b6c7
9d229e5
b8a20b1
26313f4
8b9bfbc
0561082
154419d
f9b3df5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| name: Unit Test PPU CI | ||
|
|
||
| on: | ||
| pull_request: | ||
| types: [opened, reopened, synchronize] | ||
| branches: ['release/**'] | ||
| workflow_dispatch: | ||
|
|
||
| concurrency: | ||
| group: unittest-ppu-ci-${{ github.event.pull_request.number }} | ||
| cancel-in-progress: true | ||
|
|
||
| jobs: | ||
| ci-test: | ||
| # `container:` would inject -v /var/run/docker.sock which DSW authZ blocks here. | ||
| runs-on: tzrec-ppu-runner | ||
| timeout-minutes: 1440 | ||
| steps: | ||
| - name: FetchCommit | ||
| uses: actions/checkout@v4 | ||
| with: | ||
| path: run_${{ github.run_id }} | ||
|
Comment on lines
+19
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor drift from |
||
| - name: RunUnitTestPPUCI | ||
| id: run_unittest_ppu_ci | ||
| run: | | ||
| WORK_ROOT="$(dirname "$RUNNER_WORKSPACE")" | ||
| EXTERNALS_ROOT="$(dirname "$WORK_ROOT")/externals" | ||
| docker run --rm \ | ||
| --workdir /__w/TorchEasyRec/TorchEasyRec \ | ||
| --device=/dev/alixpu_ppu0 --device=/dev/alixpu_ppu1 --device=/dev/alixpu --device=/dev/alixpu_ctl \ | ||
| --shm-size=256g --ulimit memlock=-1 \ | ||
| -e HOME=/github/home \ | ||
| -e GITHUB_ACTIONS=true \ | ||
| -e CI=true \ | ||
| -e CI_HYPOTHESIS=true \ | ||
| -e TORCHINDUCTOR_CACHE_DIR=/github/home/.torchinductor \ | ||
| -v "$WORK_ROOT":/__w \ | ||
| -v "$EXTERNALS_ROOT":/__e:ro \ | ||
| -v "$WORK_ROOT/_temp":/__w/_temp \ | ||
| -v "$WORK_ROOT/_actions":/__w/_actions \ | ||
| -v "$WORK_ROOT/_tool":/__w/_tool \ | ||
| -v "$WORK_ROOT/_temp/_github_home":/github/home \ | ||
| -v "$WORK_ROOT/_temp/_github_workflow":/github/workflow \ | ||
| mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/tzrec-test:1.1-ppu \ | ||
| bash -c 'cd run_${{ github.run_id }} && bash scripts/ci/ci_test_ppu.sh' | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| #!/usr/bin/env bash | ||
|
|
||
| bash scripts/gen_proto.sh | ||
| bash scripts/ci/ci_data.sh | ||
|
|
||
| MKL_THREADING_LAYER=GNU PYTHONPATH=. python tzrec/tests/run.py |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,9 @@ | |
| # limitations under the License. | ||
|
|
||
| from enum import Enum, unique | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| @unique | ||
|
|
@@ -19,3 +22,20 @@ class Kernel(Enum): | |
| TRITON = "TRITON" | ||
| PYTORCH = "PYTORCH" | ||
| CUTLASS = "CUTLASS" | ||
|
|
||
|
|
||
| _is_ppu_arch_cached: Optional[bool] = None | ||
|
|
||
|
|
||
| def is_ppu_arch() -> bool: | ||
| """Return True if a CUDA device is an Alibaba PPU (alixpu) accelerator.""" | ||
| global _is_ppu_arch_cached | ||
| if _is_ppu_arch_cached is None: | ||
| try: | ||
| _is_ppu_arch_cached = torch.cuda.is_available() and any( | ||
| "PPU" in torch.cuda.get_device_name(i) | ||
| for i in range(torch.cuda.device_count()) | ||
| ) | ||
| except Exception: | ||
| _is_ppu_arch_cached = False | ||
| return _is_ppu_arch_cached | ||
|
Comment on lines
+30
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two notes on the implementation vs. the docstring:
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,42 @@ | |
| "CUDA/HIP is not available or no GPUs detected", | ||
| ) | ||
|
|
||
| try: | ||
| import hstu_attn_2_cuda # noqa: F401 | ||
|
|
||
| _has_hstu_attn_2_cuda = True | ||
| except ImportError: | ||
| _has_hstu_attn_2_cuda = False | ||
|
|
||
| cutlass_hstu_unavailable: Tuple[bool, str] = ( | ||
| not _has_hstu_attn_2_cuda, | ||
| "hstu_attn_2_cuda wheel is not installed", | ||
| ) | ||
|
|
||
| try: | ||
| import torch_fx_tool # noqa: F401 | ||
|
|
||
| _has_torch_fx_tool = True | ||
| except ImportError: | ||
| _has_torch_fx_tool = False | ||
|
|
||
| torch_fx_tool_unavailable: Tuple[bool, str] = ( | ||
| not _has_torch_fx_tool, | ||
| "torch_fx_tool wheel is not installed (required for RTP export)", | ||
| ) | ||
|
|
||
|
|
||
| def get_compare_tolerance( | ||
| dtype: torch.dtype, | ||
| ) -> Tuple[Optional[float], Optional[float]]: | ||
| """Return (atol, rtol) for Triton-vs-PyTorch comparisons; widen fp32 on PPU.""" | ||
| from tzrec.ops import is_ppu_arch | ||
|
|
||
| if is_ppu_arch() and dtype == torch.float32: | ||
| return (3e-5, 2e-5) | ||
| return (None, None) | ||
|
Comment on lines
+67
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
At minimum, the docstring should call out the |
||
|
|
||
|
|
||
| _settings.register_profile( | ||
| "default", _settings(_settings.get_profile("default"), print_blob=True) | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
github.event.pull_request.numberis empty forworkflow_dispatch, so all manual invocations collapse into the group keyunittest-ppu-ci-andcancel-in-progress: truewill cancel each other. Suggest: