Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/unittest_ppu_ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Unit Test PPU CI

on:
pull_request:
types: [opened, reopened, synchronize]
branches: ['release/**']
workflow_dispatch:

concurrency:
group: unittest-ppu-ci-${{ github.event.pull_request.number }}
cancel-in-progress: true
Comment on lines +9 to +11
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

github.event.pull_request.number is empty for workflow_dispatch, so all manual invocations collapse into the group key unittest-ppu-ci- and cancel-in-progress: true will cancel each other. Suggest:

Suggested change
concurrency:
group: unittest-ppu-ci-${{ github.event.pull_request.number }}
cancel-in-progress: true
concurrency:
group: unittest-ppu-ci-${{ github.event.pull_request.number || github.run_id }}
cancel-in-progress: true


jobs:
ci-test:
# `container:` would inject -v /var/run/docker.sock which DSW authZ blocks here.
runs-on: tzrec-ppu-runner
timeout-minutes: 1440
steps:
- name: FetchCommit
uses: actions/checkout@v4
with:
path: run_${{ github.run_id }}
Comment on lines +19 to +22
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor drift from unittest_h20_ci.yml: the H20 lane pins ref: ${{ github.event.pull_request.head.sha }}, this one doesn't. With pull_request, the default checkout is the merge ref — slightly different semantics (catches conflicts, but won't pin if the PR is rebased mid-run). If the intent was parity with H20, add the ref: line; if intentional, worth a one-line comment noting why.

- name: RunUnitTestPPUCI
id: run_unittest_ppu_ci
run: |
WORK_ROOT="$(dirname "$RUNNER_WORKSPACE")"
EXTERNALS_ROOT="$(dirname "$WORK_ROOT")/externals"
docker run --rm \
--workdir /__w/TorchEasyRec/TorchEasyRec \
--device=/dev/alixpu_ppu0 --device=/dev/alixpu_ppu1 --device=/dev/alixpu --device=/dev/alixpu_ctl \
--shm-size=256g --ulimit memlock=-1 \
-e HOME=/github/home \
-e GITHUB_ACTIONS=true \
-e CI=true \
-e CI_HYPOTHESIS=true \
-e TORCHINDUCTOR_CACHE_DIR=/github/home/.torchinductor \
-v "$WORK_ROOT":/__w \
-v "$EXTERNALS_ROOT":/__e:ro \
-v "$WORK_ROOT/_temp":/__w/_temp \
-v "$WORK_ROOT/_actions":/__w/_actions \
-v "$WORK_ROOT/_tool":/__w/_tool \
-v "$WORK_ROOT/_temp/_github_home":/github/home \
-v "$WORK_ROOT/_temp/_github_workflow":/github/workflow \
mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/tzrec-test:1.1-ppu \
bash -c 'cd run_${{ github.run_id }} && bash scripts/ci/ci_test_ppu.sh'
6 changes: 6 additions & 0 deletions scripts/ci/ci_test_ppu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

bash scripts/gen_proto.sh
bash scripts/ci/ci_data.sh

MKL_THREADING_LAYER=GNU PYTHONPATH=. python tzrec/tests/run.py
20 changes: 20 additions & 0 deletions tzrec/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
# limitations under the License.

from enum import Enum, unique
from typing import Optional

import torch


@unique
Expand All @@ -19,3 +22,20 @@ class Kernel(Enum):
TRITON = "TRITON"
PYTORCH = "PYTORCH"
CUTLASS = "CUTLASS"


_is_ppu_arch_cached: Optional[bool] = None


def is_ppu_arch() -> bool:
"""Return True if a CUDA device is an Alibaba PPU (alixpu) accelerator."""
global _is_ppu_arch_cached
if _is_ppu_arch_cached is None:
try:
_is_ppu_arch_cached = torch.cuda.is_available() and any(
"PPU" in torch.cuda.get_device_name(i)
for i in range(torch.cuda.device_count())
)
except Exception:
_is_ppu_arch_cached = False
return _is_ppu_arch_cached
Comment on lines +30 to +41
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two notes on the implementation vs. the docstring:

  1. Scope is broader than the docstring suggests. The function returns True if any visible CUDA device's name contains the substring "PPU", and the result is cached for the process lifetime. The docstring reads as if it tests "a CUDA device" — worth tightening to e.g. "Return True if any visible CUDA device name contains 'PPU'. Cached after first call." This also surfaces the mixed-fleet edge case (PPU + non-PPU on one host both report True).
  2. except Exception is wider than needed. The only realistic failure modes are RuntimeError from a broken driver and AssertionError from a stale device index. The broad clause will silently swallow future typos / API changes and quietly disable PPU code paths. Narrow to the specific exceptions, or at least leave a one-line comment naming the failure mode being defended against — otherwise the next reader has to guess.

6 changes: 5 additions & 1 deletion tzrec/ops/_triton/triton_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,13 @@ def _weighted_layer_norm_bwd_dx(


def _get_bwd_dwdb_configs() -> List[triton.Config]:
# PPU Triton mis-compiles num_warps=32 for this reduction kernel.
from tzrec.ops import is_ppu_arch

skip_32_warps = torch.ops.hip or is_ppu_arch()
configs = []
for BLOCK_N in [32, 64, 128, 256]:
for num_warps in [8, 16] + ([] if torch.ops.hip else [32]):
for num_warps in [8, 16] + ([] if skip_32_warps else [32]):
configs.append(
triton.Config(
{"BLOCK_N": BLOCK_N},
Expand Down
3 changes: 3 additions & 0 deletions tzrec/ops/hstu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Kernel,
)
from tzrec.utils.test_util import (
cutlass_hstu_unavailable,
generate_sparse_seq_len,
get_test_dtypes,
get_test_enable_tma,
Expand Down Expand Up @@ -491,6 +492,7 @@ def test_cache(
real_delta_out,
)

@unittest.skipIf(*cutlass_hstu_unavailable)
@unittest.skipIf(*gpu_unavailable)
# pyre-ignore
@given(
Expand Down Expand Up @@ -528,6 +530,7 @@ def test_attn_cutlass(self, *args, **kwargs) -> None:
# CUTLASS implementation and falls back to Triton internally. The
# delta/cached path is already covered by ``test_delta_attn_triton``.

@unittest.skipIf(*cutlass_hstu_unavailable)
@unittest.skipIf(*gpu_unavailable)
@given(
batch_size=st.sampled_from([1, 4]),
Expand Down
4 changes: 4 additions & 0 deletions tzrec/ops/jagged_tensors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ def _test_jagged_dense_bmm_broadcast_add(
from tzrec.ops.jagged_tensors import (
jagged_dense_bmm_broadcast_add,
)
from tzrec.utils.test_util import get_compare_tolerance

if atol is None and rtol is None:
atol, rtol = get_compare_tolerance(dtype)

if sparsity > 0.0:
lengths = generate_sparse_seq_len(
Expand Down
4 changes: 4 additions & 0 deletions tzrec/ops/mm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def _test_addmm(
rtol: Optional[float] = None,
) -> None:
from tzrec.ops.mm import addmm
from tzrec.utils.test_util import get_compare_tolerance

if atol is None and rtol is None:
atol, rtol = get_compare_tolerance(dtype)

# to enable more deterministic results.
torch.manual_seed(0)
Expand Down
15 changes: 14 additions & 1 deletion tzrec/tests/rank_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from tzrec.main import _create_features
from tzrec.tests import utils
from tzrec.utils import checkpoint_util, config_util, dynamicemb_util
from tzrec.utils.test_util import dfs_are_close, gpu_unavailable
from tzrec.utils.test_util import (
cutlass_hstu_unavailable,
dfs_are_close,
gpu_unavailable,
torch_fx_tool_unavailable,
)


class RankIntegrationTest(unittest.TestCase):
Expand Down Expand Up @@ -1036,6 +1041,7 @@ def test_rank_dlrm_hstu_train_eval_export_unified_aot(self):
acc_cfg = json.load(f)
self.assertEqual(acc_cfg.get("UNIFIED_AOT"), "1")

@unittest.skipIf(*cutlass_hstu_unavailable)
@unittest.skipIf(*gpu_unavailable)
def test_rank_dlrm_hstu_cutlass_train_eval_export(self):
self.success = utils.test_train_eval(
Expand Down Expand Up @@ -1064,6 +1070,7 @@ def test_rank_dlrm_hstu_cutlass_train_eval_export(self):
)
self.assertTrue(self.success)

@unittest.skipIf(*cutlass_hstu_unavailable)
@unittest.skipIf(*gpu_unavailable)
def test_rank_ultra_hstu_cutlass_train_eval_export(self):
self.success = utils.test_train_eval(
Expand Down Expand Up @@ -1132,6 +1139,11 @@ def test_multi_tower_zch_with_fg_train_eval_export_trt(self):
predict_columns=["user_id", "item_id", "clk", "probs"],
)

@unittest.skipIf(*torch_fx_tool_unavailable)
@unittest.skipIf(
not dynamicemb_util.has_dynamicemb,
"dynamicemb not available (config sets `dynamicemb { }` on features).",
)
@unittest.skipIf(*gpu_unavailable)
def test_multi_tower_din_rtp_train_export(self):
# set USE_RTP env here for gen correct rtp style train/eval data
Expand Down Expand Up @@ -1164,6 +1176,7 @@ def test_multi_tower_din_rtp_train_export(self):
in weight_json
)

@unittest.skipIf(*torch_fx_tool_unavailable)
@unittest.skipIf(*gpu_unavailable)
def test_dlrm_hstu_rtp_train_export(self):
self.success = utils.test_train_eval(
Expand Down
5 changes: 5 additions & 0 deletions tzrec/utils/dynamicemb_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ def build_dynamicemb_constraints(
dynamicemb_cfg: feature_pb2.DynamicEmbedding, emb_config: BaseEmbeddingConfig
) -> ParameterConstraints:
"""Build ParameterConstraints for DynamicEmbedding."""
if not has_dynamicemb:
raise RuntimeError(
"dynamicemb is not installed; required by features with "
"`dynamicemb { }` set."
)
embedding_dim = emb_config.embedding_dim
num_embeddings = emb_config.num_embeddings

Expand Down
36 changes: 36 additions & 0 deletions tzrec/utils/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@
"CUDA/HIP is not available or no GPUs detected",
)

try:
import hstu_attn_2_cuda # noqa: F401

_has_hstu_attn_2_cuda = True
except ImportError:
_has_hstu_attn_2_cuda = False

cutlass_hstu_unavailable: Tuple[bool, str] = (
not _has_hstu_attn_2_cuda,
"hstu_attn_2_cuda wheel is not installed",
)

try:
import torch_fx_tool # noqa: F401

_has_torch_fx_tool = True
except ImportError:
_has_torch_fx_tool = False

torch_fx_tool_unavailable: Tuple[bool, str] = (
not _has_torch_fx_tool,
"torch_fx_tool wheel is not installed (required for RTP export)",
)


def get_compare_tolerance(
dtype: torch.dtype,
) -> Tuple[Optional[float], Optional[float]]:
"""Return (atol, rtol) for Triton-vs-PyTorch comparisons; widen fp32 on PPU."""
from tzrec.ops import is_ppu_arch

if is_ppu_arch() and dtype == torch.float32:
return (3e-5, 2e-5)
return (None, None)
Comment on lines +67 to +75
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The (None, None) return path makes the contract awkward — every caller has to special-case it, and the docstring doesn't mention it. Two concrete issues this invites:

  1. Partial-override blind spot at call sites. Both callsites added in this PR use if atol is None and rtol is None: (see jagged_tensors_test.py:389, mm_test.py:79). A future caller that passes atol=1e-6, rtol=None silently skips the widening and rtol falls back to torch's dtype defaults — which on PPU could mask or surface false regressions depending on direction. Widening each independently would be safer:
    ppu_atol, ppu_rtol = get_compare_tolerance(dtype)
    atol = atol if atol is not None else ppu_atol
    rtol = rtol if rtol is not None else ppu_rtol
  2. Returning (None, None) to mean "use defaults" is unusual. Consider returning concrete defaults for non-PPU/non-fp32 (e.g., the dtype's standard torch.testing.assert_close defaults), or restructure as get_compare_tolerance(dtype, atol, rtol) so the helper owns the precedence rule and callsites stay one line.

At minimum, the docstring should call out the (None, None) return path.



_settings.register_profile(
"default", _settings(_settings.get_profile("default"), print_blob=True)
)
Expand Down
Loading