From ed818c2161a649ad37a31689154514e840efe795 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:00:18 +0000 Subject: [PATCH 1/7] potential fix --- .../graph_store/shared_dist_sampling_producer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 0f7461196..f7852c6f6 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -103,6 +103,7 @@ SamplerRuntime, create_dist_sampler, ) +from gigl.utils.share_memory import share_memory logger = Logger() @@ -871,7 +872,13 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) + # Move degree tensors to shared memory before workers are spawned so + # each worker maps the same allocation instead of pickling a private copy. + # In colocated mode this is handled by DistDataset.to_ipc_handle(); here + # the tensors arrive via RPC from the storage server and are not yet in + # shared memory, causing num_workers copies without this call. self._degree_tensors = degree_tensors + share_memory(self._degree_tensors) def init_backend(self) -> None: """Initialize worker processes once for this backend. From abb8e569dcc537566b817dc4d521c6a20eadc571 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:22:04 +0000 Subject: [PATCH 2/7] Update --- gigl/distributed/dist_ppr_sampler.py | 156 +++++++++++++----- .../shared_dist_sampling_producer.py | 44 ++++- 2 files changed, 149 insertions(+), 51 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 402e381c1..9aaefbfa1 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -37,6 +37,92 @@ ) +def build_ppr_node_type_to_edge_types( + is_homogeneous: bool, + edge_types: list[EdgeType], + edge_dir: str, +) -> dict[NodeType, list[EdgeType]]: + """Build the node-type → edge-types mapping used by the PPR forward-push kernel. + + For homogeneous graphs returns the singleton sentinel mapping. For + heterogeneous graphs, groups non-label edge types by their anchor node type + (destination for ``edge_dir="in"``, source for ``edge_dir="out"``). + + Args: + is_homogeneous: True if the graph has a single node/edge type. + edge_types: All edge types present in the graph (ignored when homogeneous). + edge_dir: Sampling direction — ``"in"`` or ``"out"``. + + Returns: + Dict mapping each anchor NodeType to the list of EdgeTypes traversable + from it during a PPR walk. + """ + if is_homogeneous: + return {_PPR_HOMOGENEOUS_NODE_TYPE: [_PPR_HOMOGENEOUS_EDGE_TYPE]} + + node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(list) + for etype in edge_types: + if is_label_edge_type(etype): + continue + anchor_type = etype[-1] if edge_dir == "in" else etype[0] + node_type_to_edge_types[anchor_type].append(etype) + return dict(node_type_to_edge_types) + + +def build_ppr_total_degree_tensors( + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, + node_type_to_edge_types: dict[NodeType, list[EdgeType]], +) -> dict[NodeType, torch.Tensor]: + """Pre-compute total-degree tensors for the PPR forward-push kernel. + + For homogeneous graphs converts the single degree tensor to ``dtype``. + For heterogeneous graphs sums per-edge-type degrees into a per-node-type + total, padding shorter tensors with zeros where node counts differ. + + This function is intentionally standalone so it can be called once in the + parent process (and the result shared across workers) rather than redundantly + inside each worker's ``DistPPRNeighborSampler.__init__``. + + Args: + degree_tensors: Per-edge-type degree tensors (homogeneous: single + ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). + dtype: Target dtype for the output tensors. + node_type_to_edge_types: Mapping from anchor NodeType to the list of + EdgeTypes traversable from it, as returned by + :func:`build_ppr_node_type_to_edge_types`. + + Returns: + Dict mapping NodeType to a 1-D total-degree tensor of shape + ``[num_nodes_of_that_type]`` with dtype ``dtype``. + + Raises: + ValueError: If a required edge type is missing from ``degree_tensors``. + """ + result: dict[NodeType, torch.Tensor] = {} + + if isinstance(degree_tensors, torch.Tensor): + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + dtype_max = torch.iinfo(dtype).max + for node_type, edge_types in node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + summed = torch.zeros(max_len, dtype=torch.int64) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(torch.int64) + result[node_type] = summed.clamp(max=dtype_max).to(dtype) + + return result + + class DistPPRNeighborSampler(BaseDistNeighborSampler): """Personalized PageRank (PPR) based distributed neighbor sampler. @@ -134,14 +220,26 @@ def __init__( # edge types traversable from that node type. This is a graph-level # property used on every PPR iteration, so computing it once at init # avoids per-node summation and cache lookups in the hot loop. - # TODO (mkolodner-sc): This trades memory for throughput — we - # materialize a tensor per node type to avoid recomputing total degree - # on every neighbor during sampling. Computing it here (rather than in - # the dataset) also keeps the door open for edge-specific degree - # strategies. If memory becomes a bottleneck, revisit this. - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( - self._build_total_degree_tensors(degree_tensors, total_degree_dtype) - ) + # + # In graph-store mode, SharedDistSamplingProducer pre-computes the + # total-degree dict once in the parent process, moves it to shared + # memory, and passes it here as degree_tensors (keys are NodeType + # strings). In colocated mode degree_tensors arrives as raw + # per-edge-type tensors (keys are EdgeType tuples, or a bare Tensor + # for homogeneous graphs) and we compute the total here. + if ( + isinstance(degree_tensors, dict) + and degree_tensors + and not isinstance(next(iter(degree_tensors)), tuple) + ): + # Already the pre-computed total (NodeType string keys). + self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( + degree_tensors + ) + else: + self._node_type_to_total_degree = self._build_total_degree_tensors( + degree_tensors, total_degree_dtype + ) # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -198,9 +296,7 @@ def _build_total_degree_tensors( ) -> dict[NodeType, torch.Tensor]: """Build total-degree tensors by summing per-edge-type degrees for each node type. - For homogeneous graphs, the total degree is just the single degree tensor. - For heterogeneous graphs, it sums degree tensors across all edge types - traversable from each node type, padding shorter tensors with zeros. + Delegates to the module-level :func:`build_ppr_total_degree_tensors`. Args: degree_tensors: Per-edge-type degree tensors from the dataset. @@ -209,39 +305,11 @@ def _build_total_degree_tensors( Returns: Dict mapping node type to a 1-D tensor of total degrees. """ - result: dict[NodeType, torch.Tensor] = {} - - if self._is_homogeneous: - assert isinstance(degree_tensors, torch.Tensor) - # Single edge type: degree values fit directly in the target dtype. - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) - else: - assert isinstance(degree_tensors, dict) - dtype_max = torch.iinfo(dtype).max - for node_type, edge_types in self._node_type_to_edge_types.items(): - max_len = 0 - for et in edge_types: - if et not in degree_tensors: - raise ValueError( - f"Edge type {et} not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - max_len = max(max_len, len(degree_tensors[et])) - - # Each degree tensor is indexed by node ID (derived from CSR - # indptr), so index i in every edge type's tensor refers to - # the same node. Element-wise summation gives the total degree - # per node across all edge types. Shorter tensors are padded - # implicitly (only the first len(et_degrees) entries are added). - # Sum in int64: aggregate degrees are bounded by partition size - # and fit comfortably within int64 range in practice. - summed = torch.zeros(max_len, dtype=torch.int64) - for et in edge_types: - et_degrees = degree_tensors[et] - summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=dtype_max).to(dtype) - - return result + return build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + dtype=dtype, + node_type_to_edge_types=self._node_type_to_edge_types, + ) def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index f7852c6f6..b7838c02c 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -97,7 +97,11 @@ from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.dist_ppr_sampler import ( + build_ppr_node_type_to_edge_types, + build_ppr_total_degree_tensors, +) +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.dist_sampler import ( SamplerInput, SamplerRuntime, @@ -872,12 +876,38 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # Move degree tensors to shared memory before workers are spawned so - # each worker maps the same allocation instead of pickling a private copy. - # In colocated mode this is handled by DistDataset.to_ipc_handle(); here - # the tensors arrive via RPC from the storage server and are not yet in - # shared memory, causing num_workers copies without this call. - self._degree_tensors = degree_tensors + # For PPR sampling, pre-compute the total-degree dict (summed across edge + # types, converted to the target dtype) once here in the parent process. + # Workers receive the result directly as degree_tensors and skip the + # per-worker summation in DistPPRNeighborSampler._build_total_degree_tensors. + # + # Then move to shared memory so all spawned workers map the same + # allocation instead of each pickling a private copy. In colocated mode + # DistDataset.to_ipc_handle() handles shared memory; here the tensors + # arrive via RPC and are plain heap allocations without this call. + if ( + isinstance(sampler_options, PPRSamplerOptions) + and degree_tensors is not None + ): + assert data.graph is not None, ( + "DistDataset.graph must be set for PPR sampling" + ) + is_homogeneous = not isinstance(data.graph, dict) + edge_types = list(data.graph.keys()) if isinstance(data.graph, dict) else [] + node_type_to_edge_types = build_ppr_node_type_to_edge_types( + is_homogeneous=is_homogeneous, + edge_types=edge_types, + edge_dir=data.edge_dir, + ) + self._degree_tensors: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + dtype=sampler_options.total_degree_dtype, + node_type_to_edge_types=node_type_to_edge_types, + ) + else: + self._degree_tensors = degree_tensors share_memory(self._degree_tensors) def init_backend(self) -> None: From a0e84fab04f6811353c8f5737a3560743134c883 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 01:38:25 +0000 Subject: [PATCH 3/7] Update --- gigl/distributed/base_dist_loader.py | 36 ++++++--- gigl/distributed/dist_ppr_sampler.py | 81 +++++-------------- gigl/distributed/dist_sampling_producer.py | 8 +- .../shared_dist_sampling_producer.py | 17 ++-- gigl/distributed/sampler_options.py | 5 -- gigl/distributed/utils/dist_sampler.py | 5 +- 6 files changed, 58 insertions(+), 94 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 203c8520d..4e39273c5 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -39,6 +39,10 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_ppr_sampler import ( + build_ppr_node_type_to_edge_types, + build_ppr_total_degree_tensors, +) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer @@ -425,17 +429,27 @@ def create_mp_producer( """ channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): - degree_tensors = dataset.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - f"Pre-computed degree tensors for PPR sampling across " - f"{len(degree_tensors)} edge types." - ) - else: - logger.info( - f"Pre-computed degree tensor for PPR sampling with " - f"{degree_tensors.size(0)} nodes." - ) + assert dataset.graph is not None, ( + "DistDataset.graph must be set for PPR sampling" + ) + raw_degree_tensors = dataset.degree_tensor + is_homogeneous = not isinstance(dataset.graph, dict) + edge_types = ( + list(dataset.graph.keys()) if isinstance(dataset.graph, dict) else [] + ) + node_type_to_edge_types = build_ppr_node_type_to_edge_types( + is_homogeneous=is_homogeneous, + edge_types=edge_types, + edge_dir=dataset.edge_dir, + ) + degree_tensors = build_ppr_total_degree_tensors( + degree_tensors=raw_degree_tensors, + node_type_to_edge_types=node_type_to_edge_types, + ) + logger.info( + f"Pre-computed total degree tensors for PPR sampling across " + f"{len(degree_tensors)} node types." + ) else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 9aaefbfa1..c6120cffa 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -71,14 +71,14 @@ def build_ppr_node_type_to_edge_types( def build_ppr_total_degree_tensors( degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - dtype: torch.dtype, node_type_to_edge_types: dict[NodeType, list[EdgeType]], ) -> dict[NodeType, torch.Tensor]: """Pre-compute total-degree tensors for the PPR forward-push kernel. - For homogeneous graphs converts the single degree tensor to ``dtype``. + For homogeneous graphs converts the single degree tensor to int16. For heterogeneous graphs sums per-edge-type degrees into a per-node-type - total, padding shorter tensors with zeros where node counts differ. + total (capped at int16 max), padding shorter tensors with zeros where node + counts differ. This function is intentionally standalone so it can be called once in the parent process (and the result shared across workers) rather than redundantly @@ -87,24 +87,24 @@ def build_ppr_total_degree_tensors( Args: degree_tensors: Per-edge-type degree tensors (homogeneous: single ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). - dtype: Target dtype for the output tensors. node_type_to_edge_types: Mapping from anchor NodeType to the list of EdgeTypes traversable from it, as returned by :func:`build_ppr_node_type_to_edge_types`. Returns: Dict mapping NodeType to a 1-D total-degree tensor of shape - ``[num_nodes_of_that_type]`` with dtype ``dtype``. + ``[num_nodes_of_that_type]`` with dtype ``torch.int16``, capped at + ``torch.iinfo(torch.int16).max``. Raises: ValueError: If a required edge type is missing from ``degree_tensors``. """ + _INT16_MAX = torch.iinfo(torch.int16).max result: dict[NodeType, torch.Tensor] = {} if isinstance(degree_tensors, torch.Tensor): - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(torch.int16) else: - dtype_max = torch.iinfo(dtype).max for node_type, edge_types in node_type_to_edge_types.items(): max_len = 0 for et in edge_types: @@ -118,7 +118,7 @@ def build_ppr_total_degree_tensors( for et in edge_types: et_degrees = degree_tensors[et] summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=dtype_max).to(dtype) + result[node_type] = summed.clamp(max=_INT16_MAX).to(torch.int16) return result @@ -160,10 +160,10 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults - to ``torch.int32``. Use a larger dtype if nodes have exceptionally high - aggregate degrees. - degree_tensors: Pre-computed degree tensors from the dataset. + degree_tensors: Pre-computed total-degree tensors (int16, capped at + int16 max), keyed by NodeType. Must be pre-computed by the caller + (e.g. via :func:`build_ppr_total_degree_tensors`) so that workers + share a single allocation rather than recomputing per-worker. """ def __init__( @@ -173,8 +173,7 @@ def __init__( eps: float = 1e-4, max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, - total_degree_dtype: torch.dtype = torch.int32, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + degree_tensors: dict[NodeType, torch.Tensor], max_fetch_iterations: Optional[int] = None, **kwargs, ): @@ -216,30 +215,12 @@ def __init__( ] self._is_homogeneous = True - # Precompute total degree per node type: the sum of degrees across all - # edge types traversable from that node type. This is a graph-level - # property used on every PPR iteration, so computing it once at init - # avoids per-node summation and cache lookups in the hot loop. - # - # In graph-store mode, SharedDistSamplingProducer pre-computes the - # total-degree dict once in the parent process, moves it to shared - # memory, and passes it here as degree_tensors (keys are NodeType - # strings). In colocated mode degree_tensors arrives as raw - # per-edge-type tensors (keys are EdgeType tuples, or a bare Tensor - # for homogeneous graphs) and we compute the total here. - if ( - isinstance(degree_tensors, dict) - and degree_tensors - and not isinstance(next(iter(degree_tensors)), tuple) - ): - # Already the pre-computed total (NodeType string keys). - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( - degree_tensors - ) - else: - self._node_type_to_total_degree = self._build_total_degree_tensors( - degree_tensors, total_degree_dtype - ) + # Total-degree tensors keyed by NodeType, pre-computed by the caller. + # Callers (create_mp_producer for colocated, SharedDistSamplingBackend + # for graph-store) run build_ppr_total_degree_tensors once in the parent + # process and place the result in shared memory so all worker processes + # map the same allocation. + self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = degree_tensors # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -285,32 +266,10 @@ def __init__( # Degree tensors indexed by ntype_id. Destination-only types get an empty # tensor; the C++ kernel returns 0 for those, matching _get_total_degree. self._degree_tensors_for_cpp: list[torch.Tensor] = [ - self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int32)) + self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int16)) for nt in all_node_types ] - def _build_total_degree_tensors( - self, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - dtype: torch.dtype, - ) -> dict[NodeType, torch.Tensor]: - """Build total-degree tensors by summing per-edge-type degrees for each node type. - - Delegates to the module-level :func:`build_ppr_total_degree_tensors`. - - Args: - degree_tensors: Per-edge-type degree tensors from the dataset. - dtype: Dtype for the output tensors. - - Returns: - Dict mapping node type to a 1-D tensor of total degrees. - """ - return build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - dtype=dtype, - node_type_to_edge_types=self._node_type_to_edge_types, - ) - def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 3a51715e2..15d29a48c 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -30,7 +30,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import NodeType from graphlearn_torch.utils import seed_everything from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader @@ -55,7 +55,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ): dist_sampler = None try: @@ -180,9 +180,7 @@ def __init__( worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = None, + degree_tensors: Optional[dict[NodeType, torch.Tensor]] = None, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index b7838c02c..6712ac850 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -93,7 +93,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import EdgeType, NodeType from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger @@ -343,7 +343,7 @@ def _shared_sampling_worker_loop( event_queue: mp.Queue, mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ) -> None: """Run one shared graph-store worker that schedules many input channels. @@ -899,15 +899,14 @@ def __init__( edge_types=edge_types, edge_dir=data.edge_dir, ) - self._degree_tensors: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - dtype=sampler_options.total_degree_dtype, - node_type_to_edge_types=node_type_to_edge_types, + self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = ( + build_ppr_total_degree_tensors( + degree_tensors=degree_tensors, + node_type_to_edge_types=node_type_to_edge_types, + ) ) else: - self._degree_tensors = degree_tensors + self._degree_tensors = None share_memory(self._degree_tensors) def init_backend(self) -> None: diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index fccd7a3ba..08cd27352 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -10,7 +10,6 @@ from dataclasses import dataclass from typing import Optional, Union -import torch from graphlearn_torch.typing import EdgeType from gigl.common.logger import Logger @@ -58,9 +57,6 @@ class PPRSamplerOptions: hub nodes receive diminishing residual per neighbor, so capping the fetch has little effect on PPR accuracy while keeping per-hop RPC cost bounded. Set large to approximate fetching all neighbors. - total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults - to ``torch.int32``, which supports total degrees up to ~2 billion. - Use a larger dtype if nodes have exceptionally high aggregate degrees. max_fetch_iterations: Maximum number of iterations that issue RPC neighbor fetches. After this many fetch iterations, subsequent iterations push residuals using only already-cached neighbor lists (no new RPCs). @@ -73,7 +69,6 @@ class PPRSamplerOptions: eps: float = 1e-4 max_ppr_nodes: int = 50 num_neighbors_per_hop: int = 1_000 - total_degree_dtype: torch.dtype = torch.int32 max_fetch_iterations: Optional[int] = None diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index 0333f4138..db5dba1af 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -10,7 +10,7 @@ RemoteDistSamplingWorkerOptions, ) from graphlearn_torch.sampler import EdgeSamplerInput, NodeSamplerInput, SamplingConfig -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import NodeType from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler @@ -35,7 +35,7 @@ def create_dist_sampler( worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], current_device: torch.device, ) -> SamplerRuntime: """Create a GiGL sampler runtime for one channel on one worker. @@ -84,7 +84,6 @@ def create_dist_sampler( max_ppr_nodes=sampler_options.max_ppr_nodes, max_fetch_iterations=sampler_options.max_fetch_iterations, num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, - total_degree_dtype=sampler_options.total_degree_dtype, degree_tensors=degree_tensors, ) else: From 088fe1bfc5a93d98b25f51ffb3380feb2bd8ee48 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 19 May 2026 03:10:49 +0000 Subject: [PATCH 4/7] Improvements --- gigl/distributed/base_dist_loader.py | 26 +--- gigl/distributed/dist_dataset.py | 34 ++--- gigl/distributed/dist_ppr_sampler.py | 112 ++------------ .../shared_dist_sampling_producer.py | 44 +----- gigl/distributed/utils/degree.py | 139 +++++++++--------- 5 files changed, 103 insertions(+), 252 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 4e39273c5..496b32381 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -39,10 +39,6 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_ppr_sampler import ( - build_ppr_node_type_to_edge_types, - build_ppr_total_degree_tensors, -) from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer @@ -429,27 +425,7 @@ def create_mp_producer( """ channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): - assert dataset.graph is not None, ( - "DistDataset.graph must be set for PPR sampling" - ) - raw_degree_tensors = dataset.degree_tensor - is_homogeneous = not isinstance(dataset.graph, dict) - edge_types = ( - list(dataset.graph.keys()) if isinstance(dataset.graph, dict) else [] - ) - node_type_to_edge_types = build_ppr_node_type_to_edge_types( - is_homogeneous=is_homogeneous, - edge_types=edge_types, - edge_dir=dataset.edge_dir, - ) - degree_tensors = build_ppr_total_degree_tensors( - degree_tensors=raw_degree_tensors, - node_type_to_edge_types=node_type_to_edge_types, - ) - logger.info( - f"Pre-computed total degree tensors for PPR sampling across " - f"{len(degree_tensors)} node types." - ) + degree_tensors = dataset.degree_tensor else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index cd38c5653..c0cf6f207 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -80,9 +80,7 @@ def __init__( edge_feature_info: Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, - degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = None, + degree_tensor: Optional[dict[NodeType, torch.Tensor]] = None, max_labels_per_anchor_node: Optional[int] = None, ) -> None: """ @@ -108,7 +106,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[dict[NodeType, torch.Tensor]]: Pre-computed degree tensor keyed by node type. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. """ @@ -146,9 +144,7 @@ def __init__( self._node_feature_info = node_feature_info self._edge_feature_info = edge_feature_info - self._degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] - ] = degree_tensor + self._degree_tensor: Optional[dict[NodeType, torch.Tensor]] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear @@ -307,13 +303,15 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + ) -> dict[NodeType, torch.Tensor]: """ - Lazily compute and return the degree tensor for the graph. + Lazily compute and return the total degree tensor per node type. On first access, computes node degrees from the graph partition and uses - all-reduce to aggregate across all machines. Requires torch.distributed - to be initialized. + all-reduce to aggregate across all machines. Degrees are summed across + all incident edge types per anchor node type before the all-reduce, so + the per-edge-type tensor is never stored. Requires torch.distributed to + be initialized. Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. @@ -321,9 +319,9 @@ def degree_tensor( The result is cached for subsequent accesses. Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor. - - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. + dict[NodeType, torch.Tensor]: Total degree tensors keyed by node type. + For homogeneous graphs the single entry uses + ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Raises: RuntimeError: If torch.distributed is not initialized. @@ -333,7 +331,9 @@ def degree_tensor( if self.graph is None: raise ValueError("Dataset graph is None. Cannot compute degrees.") - self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) + self._degree_tensor = compute_and_broadcast_degree_tensor( + self.graph, self._edge_dir + ) return self._degree_tensor @property @@ -902,7 +902,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[dict[NodeType, torch.Tensor]]: Degree tensors keyed by node type Optional[int]: Optional per-anchor label cap for ABLP label fetching """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -1188,7 +1188,7 @@ def _rebuild_distributed_dataset( Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors + Optional[dict[NodeType, torch.Tensor]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching ], ): diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index c6120cffa..69ea230f5 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -17,7 +17,7 @@ from graphlearn_torch.utils import merge_dict from gigl.distributed.base_sampler import BaseDistNeighborSampler -from gigl.types.graph import is_label_edge_type +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type # Trailing "." is an intentional separator. These constants are used both to # write metadata keys (f"{KEY}{repr(edge_type)}" → e.g. "ppr_edge_index.('user', 'to', 'story')") @@ -26,103 +26,17 @@ PPR_EDGE_INDEX_METADATA_KEY = "ppr_edge_index." PPR_WEIGHT_METADATA_KEY = "ppr_weight." -# Sentinel type names for homogeneous graphs. The PPR algorithm uses -# dict[NodeType, ...] internally for both homo and hetero graphs; these -# sentinels let the homogeneous path reuse the same dict-based code. -_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" +# Sentinel edge type for homogeneous graphs. The PPR algorithm uses +# dict[NodeType, ...] internally for both homo and hetero graphs; the +# DEFAULT_HOMOGENEOUS_NODE_TYPE sentinel lets the homogeneous path reuse +# the same dict-based code. _PPR_HOMOGENEOUS_EDGE_TYPE = ( - _PPR_HOMOGENEOUS_NODE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, "to", - _PPR_HOMOGENEOUS_NODE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, ) -def build_ppr_node_type_to_edge_types( - is_homogeneous: bool, - edge_types: list[EdgeType], - edge_dir: str, -) -> dict[NodeType, list[EdgeType]]: - """Build the node-type → edge-types mapping used by the PPR forward-push kernel. - - For homogeneous graphs returns the singleton sentinel mapping. For - heterogeneous graphs, groups non-label edge types by their anchor node type - (destination for ``edge_dir="in"``, source for ``edge_dir="out"``). - - Args: - is_homogeneous: True if the graph has a single node/edge type. - edge_types: All edge types present in the graph (ignored when homogeneous). - edge_dir: Sampling direction — ``"in"`` or ``"out"``. - - Returns: - Dict mapping each anchor NodeType to the list of EdgeTypes traversable - from it during a PPR walk. - """ - if is_homogeneous: - return {_PPR_HOMOGENEOUS_NODE_TYPE: [_PPR_HOMOGENEOUS_EDGE_TYPE]} - - node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(list) - for etype in edge_types: - if is_label_edge_type(etype): - continue - anchor_type = etype[-1] if edge_dir == "in" else etype[0] - node_type_to_edge_types[anchor_type].append(etype) - return dict(node_type_to_edge_types) - - -def build_ppr_total_degree_tensors( - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - node_type_to_edge_types: dict[NodeType, list[EdgeType]], -) -> dict[NodeType, torch.Tensor]: - """Pre-compute total-degree tensors for the PPR forward-push kernel. - - For homogeneous graphs converts the single degree tensor to int16. - For heterogeneous graphs sums per-edge-type degrees into a per-node-type - total (capped at int16 max), padding shorter tensors with zeros where node - counts differ. - - This function is intentionally standalone so it can be called once in the - parent process (and the result shared across workers) rather than redundantly - inside each worker's ``DistPPRNeighborSampler.__init__``. - - Args: - degree_tensors: Per-edge-type degree tensors (homogeneous: single - ``torch.Tensor``; heterogeneous: ``dict[EdgeType, torch.Tensor]``). - node_type_to_edge_types: Mapping from anchor NodeType to the list of - EdgeTypes traversable from it, as returned by - :func:`build_ppr_node_type_to_edge_types`. - - Returns: - Dict mapping NodeType to a 1-D total-degree tensor of shape - ``[num_nodes_of_that_type]`` with dtype ``torch.int16``, capped at - ``torch.iinfo(torch.int16).max``. - - Raises: - ValueError: If a required edge type is missing from ``degree_tensors``. - """ - _INT16_MAX = torch.iinfo(torch.int16).max - result: dict[NodeType, torch.Tensor] = {} - - if isinstance(degree_tensors, torch.Tensor): - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(torch.int16) - else: - for node_type, edge_types in node_type_to_edge_types.items(): - max_len = 0 - for et in edge_types: - if et not in degree_tensors: - raise ValueError( - f"Edge type {et} not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - max_len = max(max_len, len(degree_tensors[et])) - summed = torch.zeros(max_len, dtype=torch.int64) - for et in edge_types: - et_degrees = degree_tensors[et] - summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=_INT16_MAX).to(torch.int16) - - return result - - class DistPPRNeighborSampler(BaseDistNeighborSampler): """Personalized PageRank (PPR) based distributed neighbor sampler. @@ -210,7 +124,7 @@ def __init__( self._node_type_to_edge_types[anchor_type].append(etype) else: - self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [ + self._node_type_to_edge_types[DEFAULT_HOMOGENEOUS_NODE_TYPE] = [ _PPR_HOMOGENEOUS_EDGE_TYPE ] self._is_homogeneous = True @@ -389,7 +303,7 @@ async def _compute_ppr_scores( valid_counts = tensor([1, 3, 2, 0]) """ if seed_node_type is None: - seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE + seed_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE device = seed_nodes.device ppr_state = PPRForwardPush( @@ -449,12 +363,12 @@ async def _compute_ppr_scores( if self._is_homogeneous: assert ( len(ntype_to_flat_ids) == 1 - and _PPR_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids + and DEFAULT_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids ) return ( - ntype_to_flat_ids[_PPR_HOMOGENEOUS_NODE_TYPE], - ntype_to_flat_weights[_PPR_HOMOGENEOUS_NODE_TYPE], - ntype_to_valid_counts[_PPR_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_weights[DEFAULT_HOMOGENEOUS_NODE_TYPE], + ntype_to_valid_counts[DEFAULT_HOMOGENEOUS_NODE_TYPE], ) else: return ( diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 6712ac850..b45f8deae 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -93,15 +93,11 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType, NodeType +from graphlearn_torch.typing import NodeType from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger -from gigl.distributed.dist_ppr_sampler import ( - build_ppr_node_type_to_edge_types, - build_ppr_total_degree_tensors, -) -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.dist_sampler import ( SamplerInput, SamplerRuntime, @@ -840,7 +836,7 @@ def __init__( worker_options: RemoteDistSamplingWorkerOptions, sampling_config: SamplingConfig, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[dict[NodeType, torch.Tensor]], ) -> None: """Initialize the shared sampling backend. @@ -876,37 +872,9 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # For PPR sampling, pre-compute the total-degree dict (summed across edge - # types, converted to the target dtype) once here in the parent process. - # Workers receive the result directly as degree_tensors and skip the - # per-worker summation in DistPPRNeighborSampler._build_total_degree_tensors. - # - # Then move to shared memory so all spawned workers map the same - # allocation instead of each pickling a private copy. In colocated mode - # DistDataset.to_ipc_handle() handles shared memory; here the tensors - # arrive via RPC and are plain heap allocations without this call. - if ( - isinstance(sampler_options, PPRSamplerOptions) - and degree_tensors is not None - ): - assert data.graph is not None, ( - "DistDataset.graph must be set for PPR sampling" - ) - is_homogeneous = not isinstance(data.graph, dict) - edge_types = list(data.graph.keys()) if isinstance(data.graph, dict) else [] - node_type_to_edge_types = build_ppr_node_type_to_edge_types( - is_homogeneous=is_homogeneous, - edge_types=edge_types, - edge_dir=data.edge_dir, - ) - self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = ( - build_ppr_total_degree_tensors( - degree_tensors=degree_tensors, - node_type_to_edge_types=node_type_to_edge_types, - ) - ) - else: - self._degree_tensors = None + # Move degree tensors to shared memory so all spawned workers map the + # same allocation instead of each pickling a private copy. + self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = degree_tensors share_memory(self._degree_tensors) def init_backend(self) -> None: diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 7374f53ed..eab3e7ec3 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -5,8 +5,9 @@ and aggregate them across distributed machines. Degrees are computed from the CSR (Compressed Sparse Row) topology stored in GraphLearn-Torch Graph objects. -Note: Degree tensors are not moved to shared memory and may be duplicated across -processes on the same machine. +Degrees are accumulated per anchor node type (summing across all edge types +incident to that node type) before the distributed all-reduce, so callers +receive ``dict[NodeType, torch.Tensor]`` directly with no further conversion. Requirements ============ @@ -27,24 +28,28 @@ import torch from graphlearn_torch.data import Graph +from graphlearn_torch.typing import NodeType from torch_geometric.typing import EdgeType from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks -from gigl.types.graph import is_label_edge_type +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type logger = Logger() def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: - """ - Compute node degrees from a graph and aggregate across all machines. + edge_dir: str, +) -> dict[NodeType, torch.Tensor]: + """Compute node degrees from a graph and aggregate across all machines. - Computes degrees from the CSR row pointers (indptr) and performs all-reduce - to aggregate across ranks. + For each non-label edge type, degrees are derived from the CSR row pointers + (indptr). For heterogeneous graphs, degrees are summed across all edge types + incident to each anchor node type **locally** before the all-reduce, so the + per-edge-type tensor is only a transient intermediate and is never stored, + returned, or transmitted over RPC. Over-counting correction (for processes sharing the same data) is handled automatically by detecting the distributed topology. @@ -52,13 +57,17 @@ def compute_and_broadcast_degree_tensor( Args: graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous). For heterogeneous graphs, label edge types are automatically excluded - from the computation — they are supervision edges and should not - contribute to node degree for graph traversal algorithms like PPR. + — they are supervision edges and should not contribute to node degree + for graph traversal algorithms like PPR. + edge_dir: Sampling direction — ``"in"`` or ``"out"``. Determines which + end of each edge is the anchor node type for degree accumulation. Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensors. - - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping non-label EdgeType to degree tensors. + dict[NodeType, torch.Tensor]: Aggregated degree tensors keyed by node + type. For homogeneous graphs the single entry uses + ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int16 + tensors of shape ``[num_nodes_of_that_type]``, capped at + ``torch.iinfo(torch.int16).max``. Raises: RuntimeError: If torch.distributed is not initialized. @@ -69,52 +78,51 @@ def compute_and_broadcast_degree_tensor( "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized." ) - # Compute local degrees from graph topology + local_dict: dict[NodeType, torch.Tensor] = {} + if isinstance(graph, Graph): topo = graph.topo if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") - local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] = ( - _compute_degrees_from_indptr(topo.indptr) + local_dict[DEFAULT_HOMOGENEOUS_NODE_TYPE] = _compute_degrees_from_indptr( + topo.indptr ) else: - local_dict: dict[EdgeType, torch.Tensor] = {} for edge_type, edge_graph in graph.items(): - # Label edge types are supervision edges and should not contribute - # to node degree for graph traversal algorithms like PPR. if is_label_edge_type(edge_type): continue + anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] topo = edge_graph.topo if topo is None or topo.indptr is None: logger.warning( f"Topology/indptr not available for edge type {edge_type}, using empty tensor." ) - local_dict[edge_type] = torch.empty(0, dtype=torch.int16) + degrees = torch.empty(0, dtype=torch.int16) + else: + degrees = _compute_degrees_from_indptr(topo.indptr) + + if anchor_type in local_dict: + # Accumulate in int64 to avoid overflow, clamp back to int16 + existing = local_dict[anchor_type] + max_len = max(len(existing), len(degrees)) + summed = _pad_to_size(existing, max_len).to(torch.int64) + summed[: len(degrees)] += degrees.to(torch.int64) + local_dict[anchor_type] = _clamp_to_int16(summed) else: - local_dict[edge_type] = _compute_degrees_from_indptr(topo.indptr) - local_degrees = local_dict + local_dict[anchor_type] = degrees - # All-reduce across ranks (over-counting correction handled internally) - result = _all_reduce_degrees(local_degrees) + result = _all_reduce_degrees(local_dict) - # Log results - if isinstance(result, torch.Tensor): - if result.numel() > 0: + for node_type, degrees in result.items(): + if degrees.numel() > 0: logger.info( - f"{result.size(0)} nodes, max={result.max().item()}, min={result.min().item()}" + f"{node_type}: {degrees.size(0)} nodes, " + f"max={degrees.max().item()}, min={degrees.min().item()}" ) else: - logger.info("Graph contained 0 nodes when computing degrees") - else: - for edge_type, degrees in result.items(): - if degrees.numel() > 0: - logger.info( - f"{edge_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" - ) - else: - logger.info( - f"Graph contained 0 nodes for edge type {edge_type} when computing degrees" - ) + logger.info( + f"Graph contained 0 nodes for node type {node_type} when computing degrees" + ) return result @@ -143,21 +151,19 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: def _all_reduce_degrees( - local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: - """All-reduce degree tensors across ranks, handling both homogeneous and heterogeneous cases. + local_degrees: dict[NodeType, torch.Tensor], +) -> dict[NodeType, torch.Tensor]: + """All-reduce degree tensors across ranks. - For heterogeneous graphs, iterates over the edge types in local_degrees. All partitions - are expected to have entries for all edge types (even if some have empty tensors). - - Moves tensors to GPU for the all-reduce if using NCCL backend (which requires CUDA), - otherwise keeps tensors on CPU (for Gloo backend). + Moves tensors to GPU for the all-reduce if using NCCL backend (which + requires CUDA), otherwise keeps tensors on CPU (for Gloo backend). Over-counting correction: - In distributed training, multiple processes on the same machine often share the - same graph partition data (via shared memory). When we all-reduce degrees, each - process contributes its "local" degrees - but if 4 processes on one machine all - read the same partition, that partition's degrees get summed 4 times instead of 1. + In distributed training, multiple processes on the same machine often + share the same graph partition data (via shared memory). When we + all-reduce degrees, each process contributes its "local" degrees — but + if 4 processes on one machine all read the same partition, that + partition's degrees get summed 4 times instead of 1. Example: Machine A has 2 processes sharing partition with degrees [3, 5, 2]. Machine B has 2 processes sharing partition with degrees [1, 4, 6]. @@ -168,16 +174,16 @@ def _all_reduce_degrees( With correction: divide by local_world_size (2 per machine) = [4, 9, 8] (correct: [3+1, 5+4, 2+6]) - This function detects how many processes share the same machine by comparing - IP addresses, then divides by that count to correct the over-counting. + This function detects how many processes share the same machine by + comparing IP addresses, then divides by that count to correct the + over-counting. Args: - local_degrees: Either a single tensor (homogeneous) or dict mapping EdgeType - to tensors (heterogeneous). For heterogeneous graphs, all partitions must - have entries for all edge types. + local_degrees: Dict mapping NodeType to local degree tensors. + All partitions must have entries for all node types. Returns: - Aggregated degree tensors in the same format as input. + Aggregated degree tensors keyed by NodeType. Raises: RuntimeError: If torch.distributed is not initialized. @@ -187,38 +193,25 @@ def _all_reduce_degrees( "_all_reduce_degrees requires torch.distributed to be initialized." ) - # Compute local_world_size: number of processes on the same machine sharing data all_ips = get_internal_ip_from_all_ranks() my_rank = torch.distributed.get_rank() my_ip = all_ips[my_rank] local_world_size = Counter(all_ips)[my_ip] - # NCCL backend requires CUDA tensors; Gloo works with CPU device = get_device_from_process_group() def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: """All-reduce a single tensor with size sync and over-counting correction.""" - # Synchronize max size across all ranks local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) - # Pad, convert to int64 (all_reduce doesn't support int16), move to device padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - # Correct for over-counting, move back to CPU, and clamp to int16 - # TODO (mkolodner-sc): Potentially want to paramaterize this in the future if we want degrees higher than the int16 max. return _clamp_to_int16((padded // local_world_size).cpu()) - # Homogeneous case - if isinstance(local_degrees, torch.Tensor): - return reduce_tensor(local_degrees) - - # Heterogeneous case: all-reduce each edge type - # Sort edge types for deterministic ordering across ranks - result: dict[EdgeType, torch.Tensor] = {} - for edge_type in sorted(local_degrees.keys()): - result[edge_type] = reduce_tensor(local_degrees[edge_type]) - + result: dict[NodeType, torch.Tensor] = {} + for node_type in sorted(local_degrees.keys()): + result[node_type] = reduce_tensor(local_degrees[node_type]) return result From 4b04e90171b24ae1efe109189584d5c759c0ab4b Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 18 May 2026 22:20:46 -0700 Subject: [PATCH 5/7] readout mode --- gigl/nn/graph_transformer.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f9d8b345a..f44f06e6d 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -366,8 +366,8 @@ class GraphTransformerEncoder(nn.Module): Converts heterogeneous graph data into fixed-length sequences via ``heterodata_to_graph_transformer_input``, processes through pre-norm - transformer encoder layers, and produces per-node embeddings via - attention-weighted neighbor readout (from RelGT's LocalModule). + transformer encoder layers, and produces per-node embeddings via a + configurable readout over the anchor token and its sequence context. Conforms to the same forward interface as ``HGT`` and ``SimpleHGN``, making it a drop-in encoder for ``LinkPredictionGNN``. @@ -435,6 +435,10 @@ class GraphTransformerEncoder(nn.Module): feature_embedding_layer_dict: Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None) + readout_mode: Readout applied after the transformer encoder stack. + ``"anchor_neighbor_attention"`` preserves the current RelGT-style + anchor-plus-neighbor attention pooling. ``"anchor_only"`` returns + the normalized anchor token directly. pe_integration_mode: How to fuse positional encodings into the model input. ``"concat"`` preserves the current behavior by concatenating node-level PE to token features. ``"add"`` uses node-level additive @@ -496,6 +500,9 @@ def __init__( anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, feature_embedding_layer_dict: Optional[nn.ModuleDict] = None, + readout_mode: Literal["anchor_neighbor_attention", "anchor_only"] = ( + "anchor_neighbor_attention" + ), pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, @@ -540,6 +547,12 @@ def __init__( "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) + if readout_mode not in {"anchor_neighbor_attention", "anchor_only"}: + raise ValueError( + "readout_mode must be one of " + "{'anchor_neighbor_attention', 'anchor_only'}, " + f"got '{readout_mode}'" + ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -569,6 +582,7 @@ def __init__( self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names self._feature_embedding_layer_dict = feature_embedding_layer_dict + self._readout_mode = readout_mode self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads anchor_input_embedding_attr_names = ( @@ -671,7 +685,8 @@ def __init__( self._final_norm = nn.LayerNorm(hid_dim) - # Readout attention: projects concatenated (anchor, neighbor) to score + # Always instantiate the neighbor readout head so checkpoints can move + # between readout modes without changing parameter shapes. self._readout_attention = nn.Linear(2 * hid_dim, 1) # Output projection: hid_dim -> out_dim @@ -1037,7 +1052,7 @@ def _encode_and_readout( valid_mask: Tensor, attn_bias: Optional[Tensor] = None, ) -> Tensor: - """Process sequences through transformer layers and attention readout. + """Process sequences through transformer layers and configured readout. Args: sequences: Input tensor of shape ``(batch_size, max_seq_len, hid_dim)``. @@ -1056,7 +1071,11 @@ def _encode_and_readout( x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) - # Readout: anchor (position 0) + attention-weighted neighbor aggregation + if self._readout_mode == "anchor_only": + return x[:, 0, :] + + # RelGT-style readout: anchor (position 0) + attention-weighted + # neighbor aggregation. anchor = x[:, 0, :].unsqueeze(1) # (batch, 1, hid_dim) neighbors = x[:, 1:, :] # (batch, seq-1, hid_dim) neighbor_valid_mask = valid_mask[:, 1:] From 977db41544676216a7cfac41e457dc46c0a5bbc6 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 20 May 2026 00:53:43 -0700 Subject: [PATCH 6/7] updates --- gigl/distributed/base_dist_loader.py | 10 ++ gigl/distributed/dist_dataset.py | 55 ++++--- gigl/distributed/dist_ppr_sampler.py | 119 +++++++++++---- gigl/distributed/dist_sampling_producer.py | 8 +- .../shared_dist_sampling_producer.py | 12 +- gigl/distributed/sampler_options.py | 5 + gigl/distributed/utils/degree.py | 139 +++++++++--------- gigl/distributed/utils/dist_sampler.py | 5 +- gigl/nn/graph_transformer.py | 17 +++ gigl/transforms/graph_transformer.py | 104 ++++++++++--- tests/unit/nn/graph_transformer_test.py | 48 +++++- .../unit/transforms/graph_transformer_test.py | 44 +++++- 12 files changed, 403 insertions(+), 163 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 496b32381..203c8520d 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -426,6 +426,16 @@ def create_mp_producer( channel = BaseDistLoader.create_colocated_channel(worker_options) if isinstance(sampler_options, PPRSamplerOptions): degree_tensors = dataset.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) else: degree_tensors = None return DistSamplingProducer( diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index c0cf6f207..b40f2969a 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -80,7 +80,9 @@ def __init__( edge_feature_info: Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, - degree_tensor: Optional[dict[NodeType, torch.Tensor]] = None, + degree_tensor: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = None, max_labels_per_anchor_node: Optional[int] = None, ) -> None: """ @@ -106,7 +108,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[dict[NodeType, torch.Tensor]]: Pre-computed degree tensor keyed by node type. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. max_labels_per_anchor_node (Optional[int]): Optional cap for how many labels to materialize per anchor node for ABLP label fetching. """ @@ -144,7 +146,9 @@ def __init__( self._node_feature_info = node_feature_info self._edge_feature_info = edge_feature_info - self._degree_tensor: Optional[dict[NodeType, torch.Tensor]] = degree_tensor + self._degree_tensor: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = degree_tensor self._max_labels_per_anchor_node = max_labels_per_anchor_node # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear @@ -303,15 +307,13 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> dict[NodeType, torch.Tensor]: + ) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: """ - Lazily compute and return the total degree tensor per node type. + Lazily compute and return the degree tensor for the graph. On first access, computes node degrees from the graph partition and uses - all-reduce to aggregate across all machines. Degrees are summed across - all incident edge types per anchor node type before the all-reduce, so - the per-edge-type tensor is never stored. Requires torch.distributed to - be initialized. + all-reduce to aggregate across all machines. Requires torch.distributed + to be initialized. Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. @@ -319,9 +321,9 @@ def degree_tensor( The result is cached for subsequent accesses. Returns: - dict[NodeType, torch.Tensor]: Total degree tensors keyed by node type. - For homogeneous graphs the single entry uses - ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. + Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor. + - For homogeneous graphs: A tensor of shape [num_nodes]. + - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. Raises: RuntimeError: If torch.distributed is not initialized. @@ -331,9 +333,7 @@ def degree_tensor( if self.graph is None: raise ValueError("Dataset graph is None. Cannot compute degrees.") - self._degree_tensor = compute_and_broadcast_degree_tensor( - self.graph, self._edge_dir - ) + self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor @property @@ -499,16 +499,11 @@ def _initialize_node_ids( ) else: train_nodes, val_nodes, test_nodes = splits - self._num_train = ( - train_nodes.numel() # ty: ignore[unresolved-attribute] - ) - self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute] - self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute] + self._num_train = train_nodes.numel() + self._num_val = val_nodes.numel() + self._num_test = test_nodes.numel() self._node_ids = _append_non_split_node_ids( - train_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - val_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - test_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - node_ids_on_machine, + train_nodes, val_nodes, test_nodes, node_ids_on_machine ) else: logger.info( @@ -642,8 +637,8 @@ def _initialize_node_features( # if it is not an edge type, since it must be one of the two. assert not isinstance(node_type, EdgeType) self._node_feature_info[node_type] = FeatureInfo( - dim=node_features_per_node_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - dtype=node_features_per_node_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dim=node_features_per_node_type.size(1), + dtype=node_features_per_node_type.dtype, ) logger.info( f"Initialized node features for heterogeneous graph to dataset with node types: {node_features.keys()}" @@ -725,8 +720,8 @@ def _initialize_edge_features( for edge_type, edge_features_per_edge_type in edge_features.items(): assert isinstance(edge_type, EdgeType) self._edge_feature_info[edge_type] = FeatureInfo( - dim=edge_features_per_edge_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - dtype=edge_features_per_edge_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dim=edge_features_per_edge_type.size(1), + dtype=edge_features_per_edge_type.dtype, ) logger.info( f"Initialized edge features for heterogeneous graph to dataset with edge types: {edge_features.keys()}" @@ -902,7 +897,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[dict[NodeType, torch.Tensor]]: Degree tensors keyed by node type + Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous Optional[int]: Optional per-anchor label cap for ABLP label fetching """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -1188,7 +1183,7 @@ def _rebuild_distributed_dataset( Optional[ Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type - Optional[dict[NodeType, torch.Tensor]], # Degree tensors + Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors Optional[int], # Optional per-anchor label cap for ABLP label fetching ], ): diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 69ea230f5..83369d8c2 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -17,7 +17,7 @@ from graphlearn_torch.utils import merge_dict from gigl.distributed.base_sampler import BaseDistNeighborSampler -from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type +from gigl.types.graph import is_label_edge_type # Trailing "." is an intentional separator. These constants are used both to # write metadata keys (f"{KEY}{repr(edge_type)}" → e.g. "ppr_edge_index.('user', 'to', 'story')") @@ -26,14 +26,14 @@ PPR_EDGE_INDEX_METADATA_KEY = "ppr_edge_index." PPR_WEIGHT_METADATA_KEY = "ppr_weight." -# Sentinel edge type for homogeneous graphs. The PPR algorithm uses -# dict[NodeType, ...] internally for both homo and hetero graphs; the -# DEFAULT_HOMOGENEOUS_NODE_TYPE sentinel lets the homogeneous path reuse -# the same dict-based code. +# Sentinel type names for homogeneous graphs. The PPR algorithm uses +# dict[NodeType, ...] internally for both homo and hetero graphs; these +# sentinels let the homogeneous path reuse the same dict-based code. +_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" _PPR_HOMOGENEOUS_EDGE_TYPE = ( - DEFAULT_HOMOGENEOUS_NODE_TYPE, + _PPR_HOMOGENEOUS_NODE_TYPE, "to", - DEFAULT_HOMOGENEOUS_NODE_TYPE, + _PPR_HOMOGENEOUS_NODE_TYPE, ) @@ -74,10 +74,10 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - degree_tensors: Pre-computed total-degree tensors (int16, capped at - int16 max), keyed by NodeType. Must be pre-computed by the caller - (e.g. via :func:`build_ppr_total_degree_tensors`) so that workers - share a single allocation rather than recomputing per-worker. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults + to ``torch.int32``. Use a larger dtype if nodes have exceptionally high + aggregate degrees. + degree_tensors: Pre-computed degree tensors from the dataset. """ def __init__( @@ -87,7 +87,8 @@ def __init__( eps: float = 1e-4, max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, - degree_tensors: dict[NodeType, torch.Tensor], + total_degree_dtype: torch.dtype = torch.int32, + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], max_fetch_iterations: Optional[int] = None, **kwargs, ): @@ -124,17 +125,23 @@ def __init__( self._node_type_to_edge_types[anchor_type].append(etype) else: - self._node_type_to_edge_types[DEFAULT_HOMOGENEOUS_NODE_TYPE] = [ + self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [ _PPR_HOMOGENEOUS_EDGE_TYPE ] self._is_homogeneous = True - # Total-degree tensors keyed by NodeType, pre-computed by the caller. - # Callers (create_mp_producer for colocated, SharedDistSamplingBackend - # for graph-store) run build_ppr_total_degree_tensors once in the parent - # process and place the result in shared memory so all worker processes - # map the same allocation. - self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = degree_tensors + # Precompute total degree per node type: the sum of degrees across all + # edge types traversable from that node type. This is a graph-level + # property used on every PPR iteration, so computing it once at init + # avoids per-node summation and cache lookups in the hot loop. + # TODO (mkolodner-sc): This trades memory for throughput — we + # materialize a tensor per node type to avoid recomputing total degree + # on every neighbor during sampling. Computing it here (rather than in + # the dataset) also keeps the door open for edge-specific degree + # strategies. If memory becomes a bottleneck, revisit this. + self._node_type_to_total_degree: dict[NodeType, torch.Tensor] = ( + self._build_total_degree_tensors(degree_tensors, total_degree_dtype) + ) # Build integer ID mappings for the C++ forward-push kernel. String # NodeType / EdgeType keys are only used at the Python boundary @@ -180,10 +187,62 @@ def __init__( # Degree tensors indexed by ntype_id. Destination-only types get an empty # tensor; the C++ kernel returns 0 for those, matching _get_total_degree. self._degree_tensors_for_cpp: list[torch.Tensor] = [ - self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int16)) + self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int32)) for nt in all_node_types ] + def _build_total_degree_tensors( + self, + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, + ) -> dict[NodeType, torch.Tensor]: + """Build total-degree tensors by summing per-edge-type degrees for each node type. + + For homogeneous graphs, the total degree is just the single degree tensor. + For heterogeneous graphs, it sums degree tensors across all edge types + traversable from each node type, padding shorter tensors with zeros. + + Args: + degree_tensors: Per-edge-type degree tensors from the dataset. + dtype: Dtype for the output tensors. + + Returns: + Dict mapping node type to a 1-D tensor of total degrees. + """ + result: dict[NodeType, torch.Tensor] = {} + + if self._is_homogeneous: + assert isinstance(degree_tensors, torch.Tensor) + # Single edge type: degree values fit directly in the target dtype. + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + assert isinstance(degree_tensors, dict) + dtype_max = torch.iinfo(dtype).max + for node_type, edge_types in self._node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + + # Each degree tensor is indexed by node ID (derived from CSR + # indptr), so index i in every edge type's tensor refers to + # the same node. Element-wise summation gives the total degree + # per node across all edge types. Shorter tensors are padded + # implicitly (only the first len(et_degrees) entries are added). + # Sum in int64: aggregate degrees are bounded by partition size + # and fit comfortably within int64 range in practice. + summed = torch.zeros(max_len, dtype=torch.int64) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(torch.int64) + result[node_type] = summed.clamp(max=dtype_max).to(dtype) + + return result + def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] @@ -303,7 +362,7 @@ async def _compute_ppr_scores( valid_counts = tensor([1, 3, 2, 0]) """ if seed_node_type is None: - seed_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE device = seed_nodes.device ppr_state = PPRForwardPush( @@ -363,12 +422,12 @@ async def _compute_ppr_scores( if self._is_homogeneous: assert ( len(ntype_to_flat_ids) == 1 - and DEFAULT_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids + and _PPR_HOMOGENEOUS_NODE_TYPE in ntype_to_flat_ids ) return ( - ntype_to_flat_ids[DEFAULT_HOMOGENEOUS_NODE_TYPE], - ntype_to_flat_weights[DEFAULT_HOMOGENEOUS_NODE_TYPE], - ntype_to_valid_counts[DEFAULT_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_ids[_PPR_HOMOGENEOUS_NODE_TYPE], + ntype_to_flat_weights[_PPR_HOMOGENEOUS_NODE_TYPE], + ntype_to_valid_counts[_PPR_HOMOGENEOUS_NODE_TYPE], ) else: return ( @@ -478,7 +537,7 @@ async def _sample_from_nodes( seed_types = list(nodes_to_sample.keys()) ppr_results = await asyncio.gather( *[ - self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) for seed_type in seed_types ] ) @@ -497,20 +556,20 @@ async def _sample_from_nodes( for ntype, flat_ids in ntype_to_flat_ids.items(): ppr_edge_type: EdgeType = (seed_type, "ppr", ntype) - valid_counts = ntype_to_valid_counts[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + valid_counts = ntype_to_valid_counts[ntype] ppr_edge_type_to_flat_weights[ppr_edge_type] = ( - ntype_to_flat_weights[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + ntype_to_flat_weights[ntype] ) # Skip empty pairs; induce_next handles deduplication across # seed types so a neighbor reachable from multiple seed types # gets one consistent local index in node_dict[ntype]. - if flat_ids.numel() > 0: # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + if flat_ids.numel() > 0: nbr_dict[ppr_edge_type] = [ src_dict[seed_type], flat_ids, valid_counts, - ] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. + ] # induce_next processes all PPR edge types in nbr_dict in one # pass, assigning local indices to neighbors not yet registered and diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 15d29a48c..3a51715e2 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -30,7 +30,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import NodeType +from graphlearn_torch.typing import EdgeType from graphlearn_torch.utils import seed_everything from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader @@ -55,7 +55,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ): dist_sampler = None try: @@ -180,7 +180,9 @@ def __init__( worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]] = None, + degree_tensors: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = None, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index b45f8deae..0f7461196 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -93,7 +93,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import NodeType +from graphlearn_torch.typing import EdgeType from torch._C import _set_worker_signal_handlers from gigl.common.logger import Logger @@ -103,7 +103,6 @@ SamplerRuntime, create_dist_sampler, ) -from gigl.utils.share_memory import share_memory logger = Logger() @@ -339,7 +338,7 @@ def _shared_sampling_worker_loop( event_queue: mp.Queue, mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ) -> None: """Run one shared graph-store worker that schedules many input channels. @@ -836,7 +835,7 @@ def __init__( worker_options: RemoteDistSamplingWorkerOptions, sampling_config: SamplingConfig, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ) -> None: """Initialize the shared sampling backend. @@ -872,10 +871,7 @@ def __init__( self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict( set ) - # Move degree tensors to shared memory so all spawned workers map the - # same allocation instead of each pickling a private copy. - self._degree_tensors: Optional[dict[NodeType, torch.Tensor]] = degree_tensors - share_memory(self._degree_tensors) + self._degree_tensors = degree_tensors def init_backend(self) -> None: """Initialize worker processes once for this backend. diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 08cd27352..fccd7a3ba 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Optional, Union +import torch from graphlearn_torch.typing import EdgeType from gigl.common.logger import Logger @@ -57,6 +58,9 @@ class PPRSamplerOptions: hub nodes receive diminishing residual per neighbor, so capping the fetch has little effect on PPR accuracy while keeping per-hop RPC cost bounded. Set large to approximate fetching all neighbors. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults + to ``torch.int32``, which supports total degrees up to ~2 billion. + Use a larger dtype if nodes have exceptionally high aggregate degrees. max_fetch_iterations: Maximum number of iterations that issue RPC neighbor fetches. After this many fetch iterations, subsequent iterations push residuals using only already-cached neighbor lists (no new RPCs). @@ -69,6 +73,7 @@ class PPRSamplerOptions: eps: float = 1e-4 max_ppr_nodes: int = 50 num_neighbors_per_hop: int = 1_000 + total_degree_dtype: torch.dtype = torch.int32 max_fetch_iterations: Optional[int] = None diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index eab3e7ec3..7374f53ed 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -5,9 +5,8 @@ and aggregate them across distributed machines. Degrees are computed from the CSR (Compressed Sparse Row) topology stored in GraphLearn-Torch Graph objects. -Degrees are accumulated per anchor node type (summing across all edge types -incident to that node type) before the distributed all-reduce, so callers -receive ``dict[NodeType, torch.Tensor]`` directly with no further conversion. +Note: Degree tensors are not moved to shared memory and may be duplicated across +processes on the same machine. Requirements ============ @@ -28,28 +27,24 @@ import torch from graphlearn_torch.data import Graph -from graphlearn_torch.typing import NodeType from torch_geometric.typing import EdgeType from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks -from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type +from gigl.types.graph import is_label_edge_type logger = Logger() def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], - edge_dir: str, -) -> dict[NodeType, torch.Tensor]: - """Compute node degrees from a graph and aggregate across all machines. +) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + """ + Compute node degrees from a graph and aggregate across all machines. - For each non-label edge type, degrees are derived from the CSR row pointers - (indptr). For heterogeneous graphs, degrees are summed across all edge types - incident to each anchor node type **locally** before the all-reduce, so the - per-edge-type tensor is only a transient intermediate and is never stored, - returned, or transmitted over RPC. + Computes degrees from the CSR row pointers (indptr) and performs all-reduce + to aggregate across ranks. Over-counting correction (for processes sharing the same data) is handled automatically by detecting the distributed topology. @@ -57,17 +52,13 @@ def compute_and_broadcast_degree_tensor( Args: graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous). For heterogeneous graphs, label edge types are automatically excluded - — they are supervision edges and should not contribute to node degree - for graph traversal algorithms like PPR. - edge_dir: Sampling direction — ``"in"`` or ``"out"``. Determines which - end of each edge is the anchor node type for degree accumulation. + from the computation — they are supervision edges and should not + contribute to node degree for graph traversal algorithms like PPR. Returns: - dict[NodeType, torch.Tensor]: Aggregated degree tensors keyed by node - type. For homogeneous graphs the single entry uses - ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` as its key. Values are int16 - tensors of shape ``[num_nodes_of_that_type]``, capped at - ``torch.iinfo(torch.int16).max``. + Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensors. + - For homogeneous graphs: A tensor of shape [num_nodes]. + - For heterogeneous graphs: A dict mapping non-label EdgeType to degree tensors. Raises: RuntimeError: If torch.distributed is not initialized. @@ -78,51 +69,52 @@ def compute_and_broadcast_degree_tensor( "compute_and_broadcast_degree_tensor requires torch.distributed to be initialized." ) - local_dict: dict[NodeType, torch.Tensor] = {} - + # Compute local degrees from graph topology if isinstance(graph, Graph): topo = graph.topo if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") - local_dict[DEFAULT_HOMOGENEOUS_NODE_TYPE] = _compute_degrees_from_indptr( - topo.indptr + local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]] = ( + _compute_degrees_from_indptr(topo.indptr) ) else: + local_dict: dict[EdgeType, torch.Tensor] = {} for edge_type, edge_graph in graph.items(): + # Label edge types are supervision edges and should not contribute + # to node degree for graph traversal algorithms like PPR. if is_label_edge_type(edge_type): continue - anchor_type: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] topo = edge_graph.topo if topo is None or topo.indptr is None: logger.warning( f"Topology/indptr not available for edge type {edge_type}, using empty tensor." ) - degrees = torch.empty(0, dtype=torch.int16) - else: - degrees = _compute_degrees_from_indptr(topo.indptr) - - if anchor_type in local_dict: - # Accumulate in int64 to avoid overflow, clamp back to int16 - existing = local_dict[anchor_type] - max_len = max(len(existing), len(degrees)) - summed = _pad_to_size(existing, max_len).to(torch.int64) - summed[: len(degrees)] += degrees.to(torch.int64) - local_dict[anchor_type] = _clamp_to_int16(summed) + local_dict[edge_type] = torch.empty(0, dtype=torch.int16) else: - local_dict[anchor_type] = degrees + local_dict[edge_type] = _compute_degrees_from_indptr(topo.indptr) + local_degrees = local_dict - result = _all_reduce_degrees(local_dict) + # All-reduce across ranks (over-counting correction handled internally) + result = _all_reduce_degrees(local_degrees) - for node_type, degrees in result.items(): - if degrees.numel() > 0: + # Log results + if isinstance(result, torch.Tensor): + if result.numel() > 0: logger.info( - f"{node_type}: {degrees.size(0)} nodes, " - f"max={degrees.max().item()}, min={degrees.min().item()}" + f"{result.size(0)} nodes, max={result.max().item()}, min={result.min().item()}" ) else: - logger.info( - f"Graph contained 0 nodes for node type {node_type} when computing degrees" - ) + logger.info("Graph contained 0 nodes when computing degrees") + else: + for edge_type, degrees in result.items(): + if degrees.numel() > 0: + logger.info( + f"{edge_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" + ) + else: + logger.info( + f"Graph contained 0 nodes for edge type {edge_type} when computing degrees" + ) return result @@ -151,19 +143,21 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: def _all_reduce_degrees( - local_degrees: dict[NodeType, torch.Tensor], -) -> dict[NodeType, torch.Tensor]: - """All-reduce degree tensors across ranks. + local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], +) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + """All-reduce degree tensors across ranks, handling both homogeneous and heterogeneous cases. - Moves tensors to GPU for the all-reduce if using NCCL backend (which - requires CUDA), otherwise keeps tensors on CPU (for Gloo backend). + For heterogeneous graphs, iterates over the edge types in local_degrees. All partitions + are expected to have entries for all edge types (even if some have empty tensors). + + Moves tensors to GPU for the all-reduce if using NCCL backend (which requires CUDA), + otherwise keeps tensors on CPU (for Gloo backend). Over-counting correction: - In distributed training, multiple processes on the same machine often - share the same graph partition data (via shared memory). When we - all-reduce degrees, each process contributes its "local" degrees — but - if 4 processes on one machine all read the same partition, that - partition's degrees get summed 4 times instead of 1. + In distributed training, multiple processes on the same machine often share the + same graph partition data (via shared memory). When we all-reduce degrees, each + process contributes its "local" degrees - but if 4 processes on one machine all + read the same partition, that partition's degrees get summed 4 times instead of 1. Example: Machine A has 2 processes sharing partition with degrees [3, 5, 2]. Machine B has 2 processes sharing partition with degrees [1, 4, 6]. @@ -174,16 +168,16 @@ def _all_reduce_degrees( With correction: divide by local_world_size (2 per machine) = [4, 9, 8] (correct: [3+1, 5+4, 2+6]) - This function detects how many processes share the same machine by - comparing IP addresses, then divides by that count to correct the - over-counting. + This function detects how many processes share the same machine by comparing + IP addresses, then divides by that count to correct the over-counting. Args: - local_degrees: Dict mapping NodeType to local degree tensors. - All partitions must have entries for all node types. + local_degrees: Either a single tensor (homogeneous) or dict mapping EdgeType + to tensors (heterogeneous). For heterogeneous graphs, all partitions must + have entries for all edge types. Returns: - Aggregated degree tensors keyed by NodeType. + Aggregated degree tensors in the same format as input. Raises: RuntimeError: If torch.distributed is not initialized. @@ -193,25 +187,38 @@ def _all_reduce_degrees( "_all_reduce_degrees requires torch.distributed to be initialized." ) + # Compute local_world_size: number of processes on the same machine sharing data all_ips = get_internal_ip_from_all_ranks() my_rank = torch.distributed.get_rank() my_ip = all_ips[my_rank] local_world_size = Counter(all_ips)[my_ip] + # NCCL backend requires CUDA tensors; Gloo works with CPU device = get_device_from_process_group() def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: """All-reduce a single tensor with size sync and over-counting correction.""" + # Synchronize max size across all ranks local_size = torch.tensor([tensor.size(0)], dtype=torch.long, device=device) torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) + # Pad, convert to int64 (all_reduce doesn't support int16), move to device padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) + # Correct for over-counting, move back to CPU, and clamp to int16 + # TODO (mkolodner-sc): Potentially want to paramaterize this in the future if we want degrees higher than the int16 max. return _clamp_to_int16((padded // local_world_size).cpu()) - result: dict[NodeType, torch.Tensor] = {} - for node_type in sorted(local_degrees.keys()): - result[node_type] = reduce_tensor(local_degrees[node_type]) + # Homogeneous case + if isinstance(local_degrees, torch.Tensor): + return reduce_tensor(local_degrees) + + # Heterogeneous case: all-reduce each edge type + # Sort edge types for deterministic ordering across ranks + result: dict[EdgeType, torch.Tensor] = {} + for edge_type in sorted(local_degrees.keys()): + result[edge_type] = reduce_tensor(local_degrees[edge_type]) + return result diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index db5dba1af..0333f4138 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -10,7 +10,7 @@ RemoteDistSamplingWorkerOptions, ) from graphlearn_torch.sampler import EdgeSamplerInput, NodeSamplerInput, SamplingConfig -from graphlearn_torch.typing import NodeType +from graphlearn_torch.typing import EdgeType from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler @@ -35,7 +35,7 @@ def create_dist_sampler( worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], channel: ChannelBase, sampler_options: SamplerOptions, - degree_tensors: Optional[dict[NodeType, torch.Tensor]], + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], current_device: torch.device, ) -> SamplerRuntime: """Create a GiGL sampler runtime for one channel on one worker. @@ -84,6 +84,7 @@ def create_dist_sampler( max_ppr_nodes=sampler_options.max_ppr_nodes, max_fetch_iterations=sampler_options.max_fetch_iterations, num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, + total_degree_dtype=sampler_options.total_degree_dtype, degree_tensors=degree_tensors, ) else: diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f44f06e6d..3313d22ac 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -660,6 +660,11 @@ def __init__( num_heads, bias=False, ) + self._pairwise_nonmissing_attention_bias = nn.Parameter( + torch.zeros(num_heads) + ) + else: + self.register_parameter("_pairwise_nonmissing_attention_bias", None) # Transformer encoder layers # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU @@ -981,6 +986,7 @@ def _build_attention_bias( attention_bias_data: Dictionary containing optional PE tensors: - "anchor_bias": (batch, seq, num_anchor_attrs) or None - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None + - "pairwise_nonmissing_mask": (batch, seq, seq) or None Returns: Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) @@ -1044,6 +1050,17 @@ def _build_attention_bias( ) # (batch, num_heads, seq, seq) attn_bias = attn_bias + pairwise_bias + pairwise_nonmissing_mask = attention_bias_data.get("pairwise_nonmissing_mask") + if pairwise_nonmissing_mask is not None: + if self._pairwise_nonmissing_attention_bias is None: + raise ValueError( + "Pairwise nonmissing attention bias is not initialized." + ) + attn_bias = attn_bias + ( + pairwise_nonmissing_mask.to(dtype).unsqueeze(1) + * self._pairwise_nonmissing_attention_bias.view(1, -1, 1, 1) + ) + return attn_bias def _encode_and_readout( diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 602f95bde..fe3745d9f 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -71,6 +71,7 @@ class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] + pairwise_nonmissing_mask: Optional[Tensor] token_input: Optional[TokenInputData] @@ -143,6 +144,7 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"pairwise_nonmissing_mask"`` shaped ``(batch, seq, seq)`` or None ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -306,11 +308,14 @@ def heterodata_to_graph_transformer_input( device=device, ) - pairwise_feature_sequences = _lookup_pairwise_relative_features( - node_index_sequences=node_index_sequences, - valid_mask=valid_mask, - csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, - device=device, + pairwise_feature_sequences, pairwise_nonmissing_mask = ( + _lookup_pairwise_relative_features( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, + attr_names=pairwise_bias_attr_names, + device=device, + ) ) anchor_bias_features = _compose_anchor_feature_tensor( @@ -332,6 +337,7 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "pairwise_nonmissing_mask": pairwise_nonmissing_mask, "token_input": token_input_features, }, ) @@ -798,8 +804,9 @@ def _lookup_pairwise_relative_features( node_index_sequences: Tensor, valid_mask: Tensor, csr_matrices: Optional[list[Tensor]], + attr_names: Optional[list[str]], device: torch.device, -) -> Optional[Tensor]: +) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Look up pairwise sparse values for each valid token pair in the sequence. @@ -815,13 +822,19 @@ def _lookup_pairwise_relative_features( node_index_sequences: (batch_size, max_seq_len) node indices for each sequence position valid_mask: (batch_size, max_seq_len) bool tensor indicating valid positions csr_matrices: List of sparse CSR matrices, each (num_nodes, num_nodes) + attr_names: Optional names for the pairwise attributes. Used only to + produce clearer error messages when multiple attrs disagree on + sparse support. device: Device for output tensor Returns: features: (batch_size, max_seq_len, max_seq_len, num_attrs) tensor where features[b, i, j, k] = csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]] for valid (i, j) pairs, 0.0 for padding positions. - Returns None if csr_matrices is empty. + nonmissing_mask: (batch_size, max_seq_len, max_seq_len) bool tensor + that is True for valid diagonal self pairs and valid sparse + entries, and False for missing non-self pairs and padding. + Returns (None, None) if csr_matrices is empty. Example: # batch_size=2, max_seq_len=3, num_attrs=1 (e.g., random_walk_se) @@ -844,7 +857,7 @@ def _lookup_pairwise_relative_features( # (pad) [0.0, 0.0, 0.0] """ if not csr_matrices: - return None + return None, None batch_size, max_seq_len = node_index_sequences.shape num_attrs = len(csr_matrices) @@ -853,10 +866,21 @@ def _lookup_pairwise_relative_features( dtype=torch.float, device=device, ) + nonmissing_mask = torch.zeros( + (batch_size, max_seq_len, max_seq_len), + dtype=torch.bool, + device=device, + ) pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) if not pair_valid_mask.any(): - return features + return features, nonmissing_mask + + self_pair_mask = pair_valid_mask & torch.eye( + max_seq_len, + dtype=torch.bool, + device=device, + ).unsqueeze(0) row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) @@ -864,15 +888,32 @@ def _lookup_pairwise_relative_features( valid_row_indices = row_indices[pair_valid_mask] valid_col_indices = col_indices[pair_valid_mask] + first_attr_name = attr_names[0] if attr_names else "attr_0" for attr_idx, pe_matrix in enumerate(csr_matrices): - pe_values = _lookup_csr_values( + pe_values, found_mask = _lookup_csr_values_and_found( csr_matrix=pe_matrix, row_indices=valid_row_indices, col_indices=valid_col_indices, ) features[..., attr_idx][pair_valid_mask] = pe_values + attr_nonmissing_mask = torch.zeros_like(nonmissing_mask) + attr_nonmissing_mask[pair_valid_mask] = found_mask + attr_nonmissing_mask |= self_pair_mask + if attr_idx == 0: + nonmissing_mask = attr_nonmissing_mask + continue + if not torch.equal(nonmissing_mask, attr_nonmissing_mask): + attr_name = ( + attr_names[attr_idx] if attr_names else f"attr_{attr_idx}" + ) + raise ValueError( + "Pairwise attention bias attributes must share identical " + "nonmissing support after treating valid diagonal self pairs " + f"as nonmissing, but '{first_attr_name}' and '{attr_name}' " + "differ." + ) - return features + return features, nonmissing_mask def _get_k_hop_neighbors_sparse( @@ -964,47 +1005,66 @@ def _lookup_csr_values( Returns: (n,) values from csr_matrix[row, col], or default_value if not present """ + values, _ = _lookup_csr_values_and_found( + csr_matrix=csr_matrix, + row_indices=row_indices, + col_indices=col_indices, + default_value=default_value, + ) + return values + + +def _lookup_csr_values_and_found( + csr_matrix: Tensor, + row_indices: Tensor, + col_indices: Tensor, + default_value: float = 0.0, +) -> tuple[Tensor, Tensor]: + """ + Look up values in a CSR sparse matrix and report which entries were present. + + Returns both the looked-up values and a boolean found-mask so callers can + distinguish missing sparse entries from explicit zero-valued entries. + """ n = row_indices.size(0) device = row_indices.device if n == 0: - return torch.zeros(0, device=device, dtype=torch.float) + return ( + torch.zeros(0, device=device, dtype=torch.float), + torch.zeros(0, device=device, dtype=torch.bool), + ) crow_indices = csr_matrix.crow_indices() col_indices_csr = csr_matrix.col_indices() values_csr = csr_matrix.values() - # Get row start/end pointers row_starts = crow_indices[row_indices] row_ends = crow_indices[row_indices + 1] row_lengths = row_ends - row_starts max_row_len = row_lengths.max().item() if max_row_len == 0: - return torch.full((n,), default_value, device=device, dtype=torch.float) + return ( + torch.full((n,), default_value, device=device, dtype=torch.float), + torch.zeros((n,), device=device, dtype=torch.bool), + ) - # Build offset matrix: (n, max_row_len) offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device) valid_mask = offsets < row_ends.unsqueeze(1) - # Safe indexing with clamping nnz = col_indices_csr.size(0) offsets_clamped = offsets.clamp(max=max(nnz - 1, 0)) - # Get columns at offsets and find matches cols_at_offsets = col_indices_csr[offsets_clamped] col_matches = (cols_at_offsets == col_indices.unsqueeze(1)) & valid_mask - - # Find which queries have matches found = col_matches.any(dim=1) - # Initialize output result = torch.full((n,), default_value, device=device, dtype=torch.float) if found.any(): - # Get match positions and retrieve values match_offsets = col_matches.float().argmax(dim=1) value_indices = row_starts[found] + match_offsets[found] result[found] = values_csr[value_indices].float() - return result + return result, found diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index d0fce10c3..091e69713 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -14,7 +14,11 @@ GraphTransformerEncoderLayer, ) from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation -from tests.test_assets.test_case import TestCase + +try: + from tests.test_assets.test_case import TestCase +except ModuleNotFoundError: # pragma: no cover - optional test harness deps + TestCase = absltest.TestCase def _create_simple_hetero_data() -> HeteroData: @@ -388,6 +392,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: assert encoder._anchor_pe_attention_bias_projection is not None assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None with torch.no_grad(): encoder._anchor_pe_attention_bias_projection.weight.copy_( @@ -411,6 +416,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), + "pairwise_nonmissing_mask": None, "token_input": None, }, ) @@ -447,6 +453,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, + "pairwise_nonmissing_mask": None, "token_input": None, }, ) @@ -457,6 +464,45 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( self.assertEqual(attn_bias[0, 0, 0, 2].item(), 4.25) self.assertEqual(attn_bias[0, 1, 0, 2].item(), 8.5) + def test_pairwise_nonmissing_mask_adds_head_specific_bias(self) -> None: + encoder = self._create_encoder( + pairwise_attention_bias_attr_names=["pairwise_distance"], + ) + + assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None + + with torch.no_grad(): + encoder._pairwise_pe_attention_bias_projection.weight.zero_() + encoder._pairwise_nonmissing_attention_bias.copy_( + torch.tensor([0.5, 1.5]) + ) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.ones((1, 3), dtype=torch.bool), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": None, + "pairwise_bias": torch.zeros((1, 3, 3, 1), dtype=torch.float), + "pairwise_nonmissing_mask": torch.tensor( + [ + [ + [True, True, False], + [True, True, False], + [False, False, True], + ] + ] + ), + "token_input": None, + }, + ) + + self.assertEqual(attn_bias.shape, (1, 2, 3, 3)) + self.assertEqual(attn_bias[0, 0, 0, 0].item(), 0.5) + self.assertEqual(attn_bias[0, 1, 0, 0].item(), 1.5) + self.assertEqual(attn_bias[0, 0, 0, 2].item(), 0.0) + self.assertEqual(attn_bias[0, 1, 1, 2].item(), 0.0) + def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None: encoder = self._create_encoder( sequence_construction_method="ppr", diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 18551014b..0cb69150f 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -11,7 +11,11 @@ _get_k_hop_neighbors_sparse, heterodata_to_graph_transformer_input, ) -from tests.test_assets.test_case import TestCase + +try: + from tests.test_assets.test_case import TestCase +except ModuleNotFoundError: # pragma: no cover - optional test harness deps + TestCase = absltest.TestCase def create_simple_hetero_data() -> HeteroData: @@ -252,6 +256,7 @@ def test_basic_transform(self): self.assertIsInstance(attention_bias_data, dict) self.assertIn("anchor_bias", attention_bias_data) self.assertIn("pairwise_bias", attention_bias_data) + self.assertIn("pairwise_nonmissing_mask", attention_bias_data) def test_attention_mask_validity(self): """Test that attention mask correctly identifies valid positions.""" @@ -792,6 +797,7 @@ def test_transform_returns_base_sequences_and_anchor_relative_bias(self) -> None assert anchor_bias is not None self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertIsNone(attention_bias_data["pairwise_bias"]) + self.assertIsNone(attention_bias_data["pairwise_nonmissing_mask"]) self.assertTrue(valid_mask[0, 0].item()) def test_attention_bias_outputs_include_valid_mask_and_relative_features( @@ -817,17 +823,53 @@ def test_attention_bias_outputs_include_valid_mask_and_relative_features( self.assertEqual(valid_mask.shape, (1, 4)) anchor_bias = attention_bias_data["anchor_bias"] pairwise_bias = attention_bias_data["pairwise_bias"] + pairwise_nonmissing_mask = attention_bias_data["pairwise_nonmissing_mask"] assert anchor_bias is not None assert pairwise_bias is not None + assert pairwise_nonmissing_mask is not None self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertEqual(pairwise_bias.shape, (1, 4, 4, 1)) + self.assertEqual(pairwise_nonmissing_mask.shape, (1, 4, 4)) self.assertAlmostEqual(anchor_bias[0, 0, 0].item(), 0.0, places=5) self.assertAlmostEqual(anchor_bias[0, 1, 0].item(), 1.0, places=5) self.assertAlmostEqual(anchor_bias[0, 2, 0].item(), 3.0, places=5) self.assertAlmostEqual(pairwise_bias[0, 0, 0, 0].item(), 0.1, places=5) + self.assertTrue(torch.all(pairwise_nonmissing_mask[0, :3, :3])) invalid_pair_mask = ~(valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)) self.assertTrue(torch.all(pairwise_bias[..., 0][invalid_pair_mask] == 0)) + self.assertTrue(torch.all(~pairwise_nonmissing_mask[invalid_pair_mask])) + + def test_pairwise_attention_bias_attr_support_mismatch_raises(self) -> None: + data = _create_hetero_data_with_relative_pe() + pairwise_distance_sparse_mismatch = torch.tensor( + [ + [0.1, 0.0, 0.3, 0.4, 0.5], + [0.6, 0.7, 0.0, 0.9, 1.0], + [1.1, 1.2, 1.3, 1.4, 1.5], + [1.6, 1.7, 1.8, 1.9, 2.0], + [2.1, 2.2, 2.3, 2.4, 2.5], + ] + ) + data.pairwise_distance_sparse_mismatch = ( + pairwise_distance_sparse_mismatch.to_sparse_csr() + ) + + with self.assertRaisesRegex( + ValueError, + "Pairwise attention bias attributes must share identical nonmissing support", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_attention_bias_attr_names=[ + "pairwise_distance", + "pairwise_distance_sparse_mismatch", + ], + ) if __name__ == "__main__": From c03d5ed4b5e65b71a00de554d51ee3518e74b72a Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 27 May 2026 11:45:29 -0700 Subject: [PATCH 7/7] Add sparse relation-aware graph transformer signals --- gigl/nn/graph_transformer.py | 590 +++++++++++++++++- gigl/transforms/graph_transformer.py | 584 +++++++++++++++-- tests/unit/nn/graph_transformer_test.py | 547 +++++++++++++++- .../unit/transforms/graph_transformer_test.py | 183 +++++- 4 files changed, 1807 insertions(+), 97 deletions(-) diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index 3313d22ac..d826bedce 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -239,6 +239,18 @@ class GraphTransformerEncoderLayer(nn.Module): activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". + relation_attention_mode: Optional relation-aware augmentation strategy + for attention scores. ``"none"`` preserves the default shared + self-attention path. ``"edge_type_bilinear"`` adds a learned + per-edge-type bilinear term for sampled directed graph edges. This + changes attention weights, not value/message content. + relation_value_mode: Optional relation-aware value augmentation strategy. + ``"sparse_residual_gate"`` adds a zero-initialized sparse residual + message path from relation-indexed source values to target queries. + This changes relation-specific message content without replacing + the main SDPA attention implementation. + num_relations: Number of relation channels expected in + ``pairwise_relation_indices`` when a relation-aware mode is enabled. Raises: ValueError: If model_dim is not divisible by num_heads. @@ -252,16 +264,41 @@ def __init__( dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", + relation_attention_mode: Literal["none", "edge_type_bilinear"] = "none", + relation_value_mode: Literal["none", "sparse_residual_gate"] = "none", + num_relations: int = 0, ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" ) + if relation_attention_mode not in {"none", "edge_type_bilinear"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_bilinear'}, " + f"got '{relation_attention_mode}'" + ) + if relation_value_mode not in {"none", "sparse_residual_gate"}: + raise ValueError( + "relation_value_mode must be one of " + "{'none', 'sparse_residual_gate'}, " + f"got '{relation_value_mode}'" + ) + if ( + relation_attention_mode == "edge_type_bilinear" + or relation_value_mode == "sparse_residual_gate" + ) and num_relations <= 0: + raise ValueError( + "Relation-aware attention/value modes require num_relations > 0." + ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate + self._relation_attention_mode = relation_attention_mode + self._relation_value_mode = relation_value_mode + self._num_relations = num_relations self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) @@ -269,6 +306,23 @@ def __init__( self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) + self._relation_attention_matrices: Optional[nn.Parameter] = None + if relation_attention_mode == "edge_type_bilinear": + # Relation-specific bilinear logit term: + # score(target, source, relation) += q_target^T W_relation k_source + # Zero init keeps startup behavior identical to shared attention. + self._relation_attention_matrices = nn.Parameter( + torch.zeros(num_relations, num_heads, self._head_dim, self._head_dim) + ) + self._relation_value_gates: Optional[nn.Parameter] = None + if relation_value_mode == "sparse_residual_gate": + # Lightweight relation-specific value/message path: + # message(target) += gate_relation * value_source + # This is a sparse residual on relation edges because PyTorch SDPA + # only accepts one shared value tensor, not per-edge transformed values. + self._relation_value_gates = nn.Parameter( + torch.zeros(num_relations, num_heads, self._head_dim) + ) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( @@ -287,6 +341,10 @@ def reset_parameters(self) -> None: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) + if self._relation_attention_matrices is not None: + nn.init.zeros_(self._relation_attention_matrices) + if self._relation_value_gates is not None: + nn.init.zeros_(self._relation_value_gates) self._ffn_norm.reset_parameters() self._ffn.reset_parameters() @@ -294,6 +352,7 @@ def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_indices: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. @@ -303,6 +362,9 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. + pairwise_relation_indices: Optional long tensor of shape + ``(num_relation_edges, 4)`` with sparse + ``(batch_idx, query_pos, key_pos, relation_idx)`` coordinates. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. @@ -330,15 +392,24 @@ def forward( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, + attention_output = self._run_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, ) + if self._relation_value_mode == "sparse_residual_gate": + # Relation bilinear attention decides how strongly to attend along + # relation edges. This residual separately lets relation type + # change the content passed from source values to target queries. + attention_output = self._add_relation_value_residual( + attention_output=attention_output, + value=value, + pairwise_relation_indices=pairwise_relation_indices, + ) + # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( batch_size, seq_len, model_dim @@ -360,6 +431,283 @@ def forward( return x + def _run_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + if self._relation_attention_mode == "edge_type_bilinear": + return self._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, + ) + + # Keep the main path on PyTorch SDPA. Depending on device/dtype/mask, + # PyTorch may dispatch this to FlashAttention or another SDPA backend. + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _run_relation_aware_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + # The relation-aware logit path still uses SDPA. We only add an + # additive per-relation bias before attention; value vectors remain + # shared unless relation_value_mode adds a sparse residual afterward. + attn_bias = self._add_relation_attention_bias( + attn_bias=attn_bias, + query=query, + key=key, + pairwise_relation_indices=pairwise_relation_indices, + ) + + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _add_relation_value_residual( + self, + attention_output: Tensor, + value: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + if pairwise_relation_indices is None: + raise ValueError( + "pairwise_relation_indices is required when " + "relation_value_mode='sparse_residual_gate'." + ) + if self._relation_value_gates is None: + raise ValueError("Relation value gates are not initialized.") + if pairwise_relation_indices.numel() == 0: + return attention_output + if ( + pairwise_relation_indices.dim() != 2 + or pairwise_relation_indices.size(-1) != 4 + ): + raise ValueError( + "pairwise_relation_indices must have shape (num_relation_edges, 4)." + ) + + pairwise_relation_indices = pairwise_relation_indices.to( + device=value.device, + dtype=torch.long, + ) + batch_indices = pairwise_relation_indices[:, 0] + query_indices = pairwise_relation_indices[:, 1] + key_indices = pairwise_relation_indices[:, 2] + relation_indices = pairwise_relation_indices[:, 3] + if ( + relation_indices.min().item() < 0 + or relation_indices.max().item() >= self._num_relations + ): + raise ValueError( + "pairwise_relation_indices contains relation ids outside " + f"[0, {self._num_relations})." + ) + + batch_size, _, seq_len, _ = value.shape + value_by_position = value.transpose(1, 2) + selected_values = value_by_position[batch_indices, key_indices] + selected_gates = self._relation_value_gates.to(dtype=value.dtype)[ + relation_indices + ] + messages = selected_values * selected_gates + + residual_by_position = value.new_zeros( + (batch_size, seq_len, self._num_heads, self._head_dim) + ) + residual_by_position.index_put_( + (batch_indices, query_indices), + messages, + accumulate=True, + ) + + # Multiple relation edges can target the same token. Average the sparse + # residual by target degree so this auxiliary path does not grow just + # because a sampled sequence has more relation edges. + counts = value.new_zeros((batch_size, seq_len)) + counts.index_put_( + (batch_indices, query_indices), + torch.ones( + pairwise_relation_indices.size(0), + dtype=value.dtype, + device=value.device, + ), + accumulate=True, + ) + residual_by_position = residual_by_position / counts.clamp_min(1).view( + batch_size, + seq_len, + 1, + 1, + ) + + return attention_output + residual_by_position.transpose(1, 2).to( + dtype=attention_output.dtype + ) + + def _build_relation_attention_bias( + self, + query: Tensor, + key: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_indices is not None and pairwise_relation_indices.numel() == 0: + return None + + batch_size, _, seq_len, _ = query.shape + empty_bias = query.new_zeros( + (batch_size, self._num_heads, seq_len, seq_len) + ) + return self._add_relation_attention_bias( + attn_bias=empty_bias, + query=query, + key=key, + pairwise_relation_indices=pairwise_relation_indices, + ) + + def _add_relation_attention_bias( + self, + attn_bias: Optional[Tensor], + query: Tensor, + key: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_indices is None: + raise ValueError( + "pairwise_relation_indices is required when " + "relation_attention_mode='edge_type_bilinear'." + ) + if self._relation_attention_matrices is None: + raise ValueError("Relation attention matrices are not initialized.") + if pairwise_relation_indices.numel() == 0: + return attn_bias + if ( + pairwise_relation_indices.dim() != 2 + or pairwise_relation_indices.size(-1) != 4 + ): + raise ValueError( + "pairwise_relation_indices must have shape (num_relation_edges, 4)." + ) + + pairwise_relation_indices = pairwise_relation_indices.to( + device=query.device, + dtype=torch.long, + ) + batch_indices = pairwise_relation_indices[:, 0] + query_indices = pairwise_relation_indices[:, 1] + key_indices = pairwise_relation_indices[:, 2] + relation_indices = pairwise_relation_indices[:, 3] + if ( + relation_indices.min().item() < 0 + or relation_indices.max().item() >= self._num_relations + ): + raise ValueError( + "pairwise_relation_indices contains relation ids outside " + f"[0, {self._num_relations})." + ) + + batch_size, _, seq_len, _ = query.shape + if attn_bias is None: + attn_bias = query.new_zeros( + (batch_size, self._num_heads, seq_len, seq_len) + ) + elif attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + else: + # The same base attention bias is reused across encoder layers, so + # relation-aware logits must be added to a per-layer copy. + attn_bias = attn_bias.clone() + + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + query_by_position = query.transpose(1, 2) + key_by_position = key.transpose(1, 2) + relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) + + # Relation indices are emitted grouped by relation in the transform path. + # Sort only if callers provide an unsorted tensor, avoiding repeated + # full boolean masks over all relation edges. + if relation_indices.numel() > 1 and not torch.all( + relation_indices[1:] >= relation_indices[:-1] + ).item(): + relation_sort_perm = torch.argsort(relation_indices) + relation_indices = relation_indices[relation_sort_perm] + batch_indices = batch_indices[relation_sort_perm] + query_indices = query_indices[relation_sort_perm] + key_indices = key_indices[relation_sort_perm] + + unique_relation_indices, relation_counts = torch.unique_consecutive( + relation_indices, + return_counts=True, + ) + relation_start = 0 + for relation_idx_tensor, relation_count_tensor in zip( + unique_relation_indices, + relation_counts, + ): + relation_idx = int(relation_idx_tensor.item()) + relation_end = relation_start + int(relation_count_tensor.item()) + relation_batch_indices = batch_indices[relation_start:relation_end] + relation_query_indices = query_indices[relation_start:relation_end] + relation_key_indices = key_indices[relation_start:relation_end] + + selected_query = query_by_position[ + relation_batch_indices, + relation_query_indices, + ] + # This term changes the attention score for relation edge + # source -> target, but the SDPA value content is still value_source. + transformed_query = torch.einsum( + "nhd,hde->nhe", + selected_query, + relation_matrices[relation_idx], + ) + selected_key = key_by_position[ + relation_batch_indices, + relation_key_indices, + ] + relation_scores = (transformed_query * selected_key).sum(dim=-1) + attn_bias_by_position.index_put_( + ( + relation_batch_indices, + relation_query_indices, + relation_key_indices, + ), + (relation_scores / math.sqrt(self._head_dim)).to( + dtype=attn_bias.dtype + ), + accumulate=True, + ) + relation_start = relation_end + + return attn_bias + class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. @@ -376,9 +724,8 @@ class GraphTransformerEncoder(nn.Module): node_type_to_feat_dim_map: Dictionary mapping node types to their input feature dimensions. edge_type_to_feat_dim_map: Dictionary mapping edge types to their - feature dimensions. Accepted for interface conformance with - ``HGT``/``SimpleHGN``; edge features are not used by the - graph transformer. + feature dimensions. Used by optional relation-aware and sparse + edge-attribute attention-bias paths. hid_dim: Hidden dimension for transformer layers. All node types are projected to this dimension before processing. out_dim: Output embedding dimension. @@ -454,6 +801,18 @@ class GraphTransformerEncoder(nn.Module): uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. + relation_attention_mode: Optional relation-aware augmentation for + attention scores. ``"none"`` preserves the current transformer path. + ``"edge_type_bilinear"`` adds a learned per-edge-type bilinear score + term for sampled directed edges. + relation_value_mode: Optional relation-aware value augmentation. + ``"none"`` preserves the current transformer path. + ``"sparse_residual_gate"`` adds a zero-initialized sparse residual + value path on sampled directed relation edges. + edge_attr_attention_bias_mode: Optional edge-attribute logit-bias path. + ``"none"`` preserves the current behavior. ``"sparse_linear"`` adds + a zero-initialized per-edge-type linear projection from sampled + edge attributes to per-head attention logits. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap @@ -506,6 +865,9 @@ def __init__( pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, + relation_attention_mode: Literal["none", "edge_type_bilinear"] = "none", + relation_value_mode: Literal["none", "sparse_residual_gate"] = "none", + edge_attr_attention_bias_mode: Literal["none", "sparse_linear"] = "none", **kwargs: object, ) -> None: super().__init__() @@ -553,6 +915,24 @@ def __init__( "{'anchor_neighbor_attention', 'anchor_only'}, " f"got '{readout_mode}'" ) + if relation_attention_mode not in {"none", "edge_type_bilinear"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_bilinear'}, " + f"got '{relation_attention_mode}'" + ) + if relation_value_mode not in {"none", "sparse_residual_gate"}: + raise ValueError( + "relation_value_mode must be one of " + "{'none', 'sparse_residual_gate'}, " + f"got '{relation_value_mode}'" + ) + if edge_attr_attention_bias_mode not in {"none", "sparse_linear"}: + raise ValueError( + "edge_attr_attention_bias_mode must be one of " + "{'none', 'sparse_linear'}, " + f"got '{edge_attr_attention_bias_mode}'" + ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -585,6 +965,24 @@ def __init__( self._readout_mode = readout_mode self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + self._relation_attention_mode = relation_attention_mode + self._relation_value_mode = relation_value_mode + self._edge_attr_attention_bias_mode = edge_attr_attention_bias_mode + self._edge_type_to_feat_dim_map = { + edge_type: edge_type_to_feat_dim_map[edge_type] + for edge_type in sorted(edge_type_to_feat_dim_map.keys()) + } + self._relation_attention_edge_types = ( + list(self._edge_type_to_feat_dim_map.keys()) + if relation_attention_mode == "edge_type_bilinear" + or relation_value_mode == "sparse_residual_gate" + else [] + ) + self._edge_attr_attention_bias_edge_types = ( + list(self._edge_type_to_feat_dim_map.keys()) + if edge_attr_attention_bias_mode == "sparse_linear" + else [] + ) anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None @@ -652,6 +1050,8 @@ def __init__( num_heads, bias=False, ) + # Start structural logit bias neutral; training can turn it on if useful. + nn.init.zeros_(self._anchor_pe_attention_bias_projection.weight) self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None if self._pairwise_attention_bias_attr_names: @@ -660,12 +1060,27 @@ def __init__( num_heads, bias=False, ) + nn.init.zeros_(self._pairwise_pe_attention_bias_projection.weight) self._pairwise_nonmissing_attention_bias = nn.Parameter( torch.zeros(num_heads) ) else: self.register_parameter("_pairwise_nonmissing_attention_bias", None) + self._edge_attr_attention_bias_projection_dict = nn.ModuleDict() + if self._edge_attr_attention_bias_mode == "sparse_linear": + for relation_idx, edge_type in enumerate( + self._edge_attr_attention_bias_edge_types + ): + edge_attr_dim = int(self._edge_type_to_feat_dim_map[edge_type]) + if edge_attr_dim <= 0: + continue + projection = nn.Linear(edge_attr_dim, num_heads, bias=False) + nn.init.zeros_(projection.weight) + self._edge_attr_attention_bias_projection_dict[str(relation_idx)] = ( + projection + ) + # Transformer encoder layers # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU # XGLU's gating mechanism doubles effective parameters, so smaller ratio @@ -683,6 +1098,9 @@ def __init__( dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, + relation_attention_mode=relation_attention_mode, + relation_value_mode=relation_value_mode, + num_relations=len(self._relation_attention_edge_types), ) for _ in range(num_layers) ] @@ -821,6 +1239,20 @@ def forward( anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, + relation_edge_types=( + self._relation_attention_edge_types + if self._relation_attention_mode == "edge_type_bilinear" + or self._relation_value_mode == "sparse_residual_gate" + else None + ), + edge_attr_edge_type_to_feat_dim_map=( + { + edge_type: self._edge_type_to_feat_dim_map[edge_type] + for edge_type in self._edge_attr_attention_bias_edge_types + } + if self._edge_attr_attention_bias_mode == "sparse_linear" + else None + ), ) # Free memory after sequences are built @@ -857,6 +1289,9 @@ def forward( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, + pairwise_relation_indices=sequence_auxiliary_data.get( + "pairwise_relation_indices" + ), ) embeddings = self._output_projection(embeddings) @@ -986,7 +1421,9 @@ def _build_attention_bias( attention_bias_data: Dictionary containing optional PE tensors: - "anchor_bias": (batch, seq, num_anchor_attrs) or None - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None - - "pairwise_nonmissing_mask": (batch, seq, seq) or None + - "pairwise_nonmissing_indices": (num_pairs, 3) or None + - "pairwise_edge_attr_indices": dict[int, (num_edges, 3)] or None + - "pairwise_edge_attr_values": dict[int, (num_edges, edge_dim)] or None Returns: Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) @@ -1006,14 +1443,14 @@ def _build_attention_bias( device = sequences.device negative_inf = torch.finfo(dtype).min - # Step 1: Initialize with padding mask bias - # Shape: (batch, 1, 1, seq) - broadcasts to mask invalid keys for all queries/heads + # Step 1: Initialize with padding mask bias. + # Shape: (batch, 1, 1, seq) broadcasts to mask invalid keys. attn_bias = torch.zeros( (batch_size, 1, 1, seq_len), dtype=dtype, device=device, ) - attn_bias = attn_bias.masked_fill( + attn_bias.masked_fill_( ~valid_mask.unsqueeze(1).unsqueeze(2), # (batch, 1, 1, seq) negative_inf, ) @@ -1050,16 +1487,119 @@ def _build_attention_bias( ) # (batch, num_heads, seq, seq) attn_bias = attn_bias + pairwise_bias - pairwise_nonmissing_mask = attention_bias_data.get("pairwise_nonmissing_mask") - if pairwise_nonmissing_mask is not None: + pairwise_nonmissing_indices = attention_bias_data.get( + "pairwise_nonmissing_indices" + ) + if pairwise_nonmissing_indices is not None: if self._pairwise_nonmissing_attention_bias is None: raise ValueError( "Pairwise nonmissing attention bias is not initialized." ) - attn_bias = attn_bias + ( - pairwise_nonmissing_mask.to(dtype).unsqueeze(1) - * self._pairwise_nonmissing_attention_bias.view(1, -1, 1, 1) - ) + if pairwise_nonmissing_indices.numel() > 0: + if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + pairwise_nonmissing_indices = pairwise_nonmissing_indices.to( + device=device, + dtype=torch.long, + ) + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + nonmissing_bias = self._pairwise_nonmissing_attention_bias.to( + dtype=attn_bias.dtype + ).view(1, -1) + attn_bias_by_position.index_put_( + ( + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ), + nonmissing_bias.expand(pairwise_nonmissing_indices.size(0), -1), + accumulate=True, + ) + + pairwise_edge_attr_indices = attention_bias_data.get( + "pairwise_edge_attr_indices" + ) + pairwise_edge_attr_values = attention_bias_data.get("pairwise_edge_attr_values") + if pairwise_edge_attr_indices is not None or pairwise_edge_attr_values is not None: + if self._edge_attr_attention_bias_mode != "sparse_linear": + raise ValueError( + "Sparse edge-attribute attention-bias payloads require " + "edge_attr_attention_bias_mode='sparse_linear'." + ) + if pairwise_edge_attr_indices is None or pairwise_edge_attr_values is None: + raise ValueError( + "pairwise_edge_attr_indices and pairwise_edge_attr_values " + "must be provided together." + ) + if set(pairwise_edge_attr_indices.keys()) != set( + pairwise_edge_attr_values.keys() + ): + raise ValueError( + "pairwise_edge_attr_indices and pairwise_edge_attr_values " + "must have identical relation-index keys." + ) + + attn_bias_by_position: Optional[Tensor] = None + for relation_idx in sorted(pairwise_edge_attr_indices.keys()): + edge_attr_indices = pairwise_edge_attr_indices[relation_idx] + edge_attr_values = pairwise_edge_attr_values[relation_idx] + if edge_attr_indices.numel() == 0: + continue + if edge_attr_indices.dim() != 2 or edge_attr_indices.size(-1) != 3: + raise ValueError( + "pairwise_edge_attr_indices entries must have shape " + "(num_edges, 3)." + ) + if ( + edge_attr_values.dim() != 2 + or edge_attr_values.size(0) != edge_attr_indices.size(0) + ): + raise ValueError( + "pairwise_edge_attr_values entries must have shape " + "(num_edges, edge_attr_dim) with the same num_edges " + "as their index tensor." + ) + + projection_key = str(relation_idx) + if projection_key not in self._edge_attr_attention_bias_projection_dict: + raise ValueError( + "No edge-attribute attention-bias projection is " + f"initialized for relation index {relation_idx}." + ) + if attn_bias_by_position is None: + if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + + edge_attr_projection = ( + self._edge_attr_attention_bias_projection_dict[projection_key] + ) + edge_attr_indices = edge_attr_indices.to( + device=device, + dtype=torch.long, + ) + edge_attr_bias = edge_attr_projection( + edge_attr_values.to(device=device, dtype=dtype) + ) + attn_bias_by_position.index_put_( + ( + edge_attr_indices[:, 0], + edge_attr_indices[:, 1], + edge_attr_indices[:, 2], + ), + edge_attr_bias.to(dtype=attn_bias.dtype), + accumulate=True, + ) return attn_bias @@ -1068,6 +1608,7 @@ def _encode_and_readout( sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_indices: Optional[Tensor] = None, ) -> Tensor: """Process sequences through transformer layers and configured readout. @@ -1076,6 +1617,8 @@ def _encode_and_readout( valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. + pairwise_relation_indices: Optional sparse relation coordinates shaped + ``(num_relation_edges, 4)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. @@ -1083,7 +1626,12 @@ def _encode_and_readout( x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: - x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) + x = encoder_layer( + x, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, + valid_mask=valid_mask, + ) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index fe3745d9f..06fadf016 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -57,7 +57,7 @@ >>> # attention_bias_data['anchor_bias']: (batch_size, max_seq_len, 1) """ -from typing import Literal, Optional, TypedDict +from typing import Literal, NamedTuple, Optional, TypedDict import torch from torch import Tensor @@ -65,19 +65,34 @@ from torch_geometric.typing import NodeType from torch_geometric.utils import to_torch_sparse_tensor +from gigl.src.common.types.graph_data import EdgeType as GiGLEdgeType + TokenInputData = dict[str, Tensor] class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] - pairwise_nonmissing_mask: Optional[Tensor] + pairwise_nonmissing_indices: Optional[Tensor] + pairwise_relation_indices: Optional[Tensor] + pairwise_edge_attr_indices: Optional[dict[int, Tensor]] + pairwise_edge_attr_values: Optional[dict[int, Tensor]] token_input: Optional[TokenInputData] PPR_WEIGHT_FEATURE_NAME = "ppr_weight" +class _TokenOccurrenceIndex(NamedTuple): + batch_indices: Tensor + positions: Tensor + node_indices: Tensor + sorted_node_indices: Tensor + node_sort_perm: Tensor + sorted_batch_node_keys: Tensor + batch_node_sort_perm: Tensor + + def heterodata_to_graph_transformer_input( data: HeteroData, batch_size: int, @@ -91,6 +106,8 @@ def heterodata_to_graph_transformer_input( anchor_based_attention_bias_attr_names: Optional[list[str]] = None, anchor_based_input_attr_names: Optional[list[str]] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, + relation_edge_types: Optional[list[GiGLEdgeType]] = None, + edge_attr_edge_type_to_feat_dim_map: Optional[dict[GiGLEdgeType, int]] = None, ) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]: """ Transform a HeteroData object to Graph Transformer sequence input. @@ -132,6 +149,16 @@ def heterodata_to_graph_transformer_input( pairwise_attention_bias_attr_names: List of pairwise feature names used as attention bias. These must correspond to sparse graph-level attributes on ``data``. Example: ['pairwise_distance']. + relation_edge_types: Optional ordered edge types used to materialize sparse + relation coordinates. Each output relation index corresponds to one + edge type in this list. Directed edges are stored as + ``(batch_idx, query_pos=dst_token, key_pos=src_token, relation_idx)``. + edge_attr_edge_type_to_feat_dim_map: Optional ordered-by-sorted-key edge + feature dimensions used to materialize sparse edge-attribute + attention-bias payloads. Only edge types with positive feature dim + contribute. Directed edges are stored as + ``(batch_idx, query_pos=dst_token, key_pos=src_token)`` under the + same relation index as the sorted edge-type order. Returns: (sequences, valid_mask, attention_bias_data), where: @@ -144,7 +171,17 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None - ``"pairwise_nonmissing_mask"`` shaped ``(batch, seq, seq)`` or None + ``"pairwise_nonmissing_indices"`` shaped ``(num_pairs, 3)`` or None, + storing ``(batch_idx, row_pos, col_pos)`` coordinates for + nonmissing pairwise entries + ``"pairwise_relation_indices"`` shaped + ``(num_relation_edges, 4)`` or None, storing + ``(batch_idx, query_pos, key_pos, relation_idx)`` coordinates + ``"pairwise_edge_attr_indices"`` as a dict mapping relation index + to ``(num_edges, 3)`` sparse ``(batch_idx, query_pos, key_pos)`` + coordinates, or None + ``"pairwise_edge_attr_values"`` as a dict mapping relation index + to ``(num_edges, edge_attr_dim)`` edge-attribute values, or None ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -206,8 +243,12 @@ def heterodata_to_graph_transformer_input( device = data[anchor_node_type].x.device - # Convert to homogeneous for easier neighborhood extraction - homo_data = data.to_homogeneous() + # Convert to homogeneous for easier neighborhood extraction. In khop mode + # edge attributes stay on the original hetero edge stores because different + # relations may have different feature dimensions. + homo_data = data.to_homogeneous( + edge_attrs=[] if sequence_construction_method == "khop" else None + ) homo_x = homo_data.x # (total_nodes, feature_dim) num_nodes = homo_data.num_nodes @@ -308,7 +349,21 @@ def heterodata_to_graph_transformer_input( device=device, ) - pairwise_feature_sequences, pairwise_nonmissing_mask = ( + needs_token_occurrence_index = bool(relation_edge_types) or bool( + edge_attr_edge_type_to_feat_dim_map + ) + token_occurrences = ( + _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if needs_token_occurrence_index + else None + ) + + pairwise_feature_sequences, pairwise_nonmissing_indices = ( _lookup_pairwise_relative_features( node_index_sequences=node_index_sequences, valid_mask=valid_mask, @@ -317,6 +372,28 @@ def heterodata_to_graph_transformer_input( device=device, ) ) + pairwise_relation_indices = _lookup_pairwise_relation_indices( + data=data, + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + token_occurrences=token_occurrences, + ) + pairwise_edge_attr_indices, pairwise_edge_attr_values = ( + _lookup_pairwise_edge_attr_payloads( + data=data, + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + edge_attr_edge_type_to_feat_dim_map=edge_attr_edge_type_to_feat_dim_map, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + token_occurrences=token_occurrences, + ) + ) anchor_bias_features = _compose_anchor_feature_tensor( anchor_relative_feature_sequences=anchor_relative_feature_sequences, @@ -337,7 +414,10 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, - "pairwise_nonmissing_mask": pairwise_nonmissing_mask, + "pairwise_nonmissing_indices": pairwise_nonmissing_indices, + "pairwise_relation_indices": pairwise_relation_indices, + "pairwise_edge_attr_indices": pairwise_edge_attr_indices, + "pairwise_edge_attr_values": pairwise_edge_attr_values, "token_input": token_input_features, }, ) @@ -831,9 +911,10 @@ def _lookup_pairwise_relative_features( features: (batch_size, max_seq_len, max_seq_len, num_attrs) tensor where features[b, i, j, k] = csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]] for valid (i, j) pairs, 0.0 for padding positions. - nonmissing_mask: (batch_size, max_seq_len, max_seq_len) bool tensor - that is True for valid diagonal self pairs and valid sparse - entries, and False for missing non-self pairs and padding. + nonmissing_indices: (num_nonmissing_pairs, 3) long tensor containing + ``(batch_idx, row_pos, col_pos)`` coordinates for valid diagonal + self pairs and valid sparse entries. Missing non-self pairs and + padding are omitted. Returns (None, None) if csr_matrices is empty. Example: @@ -866,46 +947,45 @@ def _lookup_pairwise_relative_features( dtype=torch.float, device=device, ) - nonmissing_mask = torch.zeros( - (batch_size, max_seq_len, max_seq_len), - dtype=torch.bool, + ( + valid_batch_indices, + valid_row_positions, + valid_col_positions, + valid_row_indices, + valid_col_indices, + ) = _build_flat_valid_pair_layout( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, device=device, ) + if valid_batch_indices.numel() == 0: + return features, torch.zeros((0, 3), dtype=torch.long, device=device) - pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) - if not pair_valid_mask.any(): - return features, nonmissing_mask - - self_pair_mask = pair_valid_mask & torch.eye( - max_seq_len, - dtype=torch.bool, - device=device, - ).unsqueeze(0) - - row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) - col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) - - valid_row_indices = row_indices[pair_valid_mask] - valid_col_indices = col_indices[pair_valid_mask] + self_pair_mask = valid_row_positions == valid_col_positions first_attr_name = attr_names[0] if attr_names else "attr_0" + nonmissing_support: Optional[Tensor] = None for attr_idx, pe_matrix in enumerate(csr_matrices): pe_values, found_mask = _lookup_csr_values_and_found( csr_matrix=pe_matrix, row_indices=valid_row_indices, col_indices=valid_col_indices, ) - features[..., attr_idx][pair_valid_mask] = pe_values - attr_nonmissing_mask = torch.zeros_like(nonmissing_mask) - attr_nonmissing_mask[pair_valid_mask] = found_mask - attr_nonmissing_mask |= self_pair_mask + features[ + valid_batch_indices, + valid_row_positions, + valid_col_positions, + attr_idx, + ] = pe_values + attr_nonmissing_support = found_mask | self_pair_mask if attr_idx == 0: - nonmissing_mask = attr_nonmissing_mask + nonmissing_support = attr_nonmissing_support continue - if not torch.equal(nonmissing_mask, attr_nonmissing_mask): - attr_name = ( - attr_names[attr_idx] if attr_names else f"attr_{attr_idx}" - ) + if nonmissing_support is None or not torch.equal( + nonmissing_support, + attr_nonmissing_support, + ): + attr_name = attr_names[attr_idx] if attr_names else f"attr_{attr_idx}" raise ValueError( "Pairwise attention bias attributes must share identical " "nonmissing support after treating valid diagonal self pairs " @@ -913,7 +993,394 @@ def _lookup_pairwise_relative_features( "differ." ) - return features, nonmissing_mask + assert nonmissing_support is not None + pairwise_nonmissing_indices = torch.stack( + [ + valid_batch_indices[nonmissing_support], + valid_row_positions[nonmissing_support], + valid_col_positions[nonmissing_support], + ], + dim=1, + ) + return features, pairwise_nonmissing_indices + + +def _lookup_pairwise_relation_indices( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + relation_edge_types: Optional[list[GiGLEdgeType]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, + token_occurrences: Optional[_TokenOccurrenceIndex] = None, +) -> Optional[Tensor]: + """Build sparse relation coordinates for valid token pairs. + + For a directed edge ``source -> target``, attention uses + ``query=target`` and ``key=source`` so relation-aware attention follows + message-passing orientation. + """ + if not relation_edge_types: + return None + + if token_occurrences is None: + token_occurrences = _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if token_occurrences.batch_indices.numel() == 0: + return torch.zeros((0, 4), dtype=torch.long, device=device) + + relation_index_parts: list[Tensor] = [] + for relation_idx, edge_type in enumerate(relation_edge_types): + edge_type_tuple = edge_type.tuple_repr() + if edge_type_tuple not in data.edge_types: + continue + + edge_index = data[edge_type_tuple].edge_index.to( + device=device, dtype=torch.long + ) + if edge_index.numel() == 0: + continue + + src_offset = int(node_type_offsets[edge_type.src_node_type]) + dst_offset = int(node_type_offsets[edge_type.dst_node_type]) + source_indices = edge_index[0] + src_offset + target_indices = edge_index[1] + dst_offset + ( + relation_batch_indices, + relation_query_positions, + relation_key_positions, + _, + ) = _match_directed_edges_to_token_pairs( + source_indices=source_indices, + target_indices=target_indices, + token_occurrences=token_occurrences, + num_nodes=num_nodes, + device=device, + ) + if relation_batch_indices.numel() == 0: + continue + + relation_indices = torch.stack( + [ + relation_batch_indices, + relation_query_positions, + relation_key_positions, + torch.full( + (relation_batch_indices.size(0),), + relation_idx, + dtype=torch.long, + device=device, + ), + ], + dim=1, + ) + relation_index_parts.append(torch.unique(relation_indices, dim=0)) + + if not relation_index_parts: + return torch.zeros((0, 4), dtype=torch.long, device=device) + return torch.cat(relation_index_parts, dim=0) + + +def _lookup_pairwise_edge_attr_payloads( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + edge_attr_edge_type_to_feat_dim_map: Optional[dict[GiGLEdgeType, int]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, + token_occurrences: Optional[_TokenOccurrenceIndex] = None, +) -> tuple[Optional[dict[int, Tensor]], Optional[dict[int, Tensor]]]: + """Build sparse edge-attribute payloads for valid token pairs. + + For a directed edge ``source -> target``, attention uses + ``query=target`` and ``key=source`` so edge-attribute bias follows the same + message-passing orientation as GAT. + """ + if not edge_attr_edge_type_to_feat_dim_map: + return None, None + + edge_attr_indices_by_relation: dict[int, Tensor] = {} + edge_attr_values_by_relation: dict[int, Tensor] = {} + if token_occurrences is None: + token_occurrences = _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if token_occurrences.batch_indices.numel() == 0: + return edge_attr_indices_by_relation, edge_attr_values_by_relation + + for relation_idx, edge_type in enumerate( + sorted(edge_attr_edge_type_to_feat_dim_map.keys()) + ): + edge_attr_dim = int(edge_attr_edge_type_to_feat_dim_map[edge_type]) + if edge_attr_dim <= 0: + continue + + edge_type_tuple = edge_type.tuple_repr() + if edge_type_tuple not in data.edge_types: + continue + + edge_store = data[edge_type_tuple] + edge_index = edge_store.edge_index.to(device=device, dtype=torch.long) + if edge_index.numel() == 0: + continue + + if not hasattr(edge_store, "edge_attr") or edge_store.edge_attr is None: + raise ValueError( + "edge_attr_attention_bias_mode='sparse_linear' requires " + f"edge_attr for edge type {edge_type_tuple} because its " + f"configured feature dim is {edge_attr_dim}." + ) + edge_attr = edge_store.edge_attr.to(device=device) + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + if edge_attr.dim() != 2: + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} must be 1D or 2D, " + f"got shape {tuple(edge_attr.shape)}." + ) + if edge_attr.size(0) != edge_index.size(1): + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} has " + f"{edge_attr.size(0)} rows but edge_index has " + f"{edge_index.size(1)} edges." + ) + if edge_attr.size(1) != edge_attr_dim: + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} has dim " + f"{edge_attr.size(1)} but configured dim is {edge_attr_dim}." + ) + + src_offset = int(node_type_offsets[edge_type.src_node_type]) + dst_offset = int(node_type_offsets[edge_type.dst_node_type]) + source_indices = edge_index[0] + src_offset + target_indices = edge_index[1] + dst_offset + ( + edge_batch_indices, + edge_query_positions, + edge_key_positions, + matched_edge_indices, + ) = _match_directed_edges_to_token_pairs( + source_indices=source_indices, + target_indices=target_indices, + token_occurrences=token_occurrences, + num_nodes=num_nodes, + device=device, + ) + if edge_batch_indices.numel() == 0: + continue + + edge_attr_indices_by_relation[relation_idx] = torch.stack( + [ + edge_batch_indices, + edge_query_positions, + edge_key_positions, + ], + dim=1, + ) + edge_attr_values_by_relation[relation_idx] = edge_attr[matched_edge_indices] + + return edge_attr_indices_by_relation, edge_attr_values_by_relation + + +def _build_token_occurrence_index( + node_index_sequences: Tensor, + valid_mask: Tensor, + num_nodes: int, + device: torch.device, +) -> _TokenOccurrenceIndex: + """Index valid sequence tokens for sparse directed-edge to token matching.""" + token_batch_indices, token_positions = torch.nonzero( + valid_mask, + as_tuple=True, + ) + token_batch_indices = token_batch_indices.to(device=device, dtype=torch.long) + token_positions = token_positions.to(device=device, dtype=torch.long) + token_node_indices = node_index_sequences[token_batch_indices, token_positions].to( + device=device, + dtype=torch.long, + ) + + sorted_token_node_indices, node_sort_perm = torch.sort(token_node_indices) + token_batch_node_keys = token_batch_indices * num_nodes + token_node_indices + sorted_token_batch_node_keys, batch_node_sort_perm = torch.sort( + token_batch_node_keys + ) + + return _TokenOccurrenceIndex( + batch_indices=token_batch_indices, + positions=token_positions, + node_indices=token_node_indices, + sorted_node_indices=sorted_token_node_indices, + node_sort_perm=node_sort_perm, + sorted_batch_node_keys=sorted_token_batch_node_keys, + batch_node_sort_perm=batch_node_sort_perm, + ) + + +def _match_directed_edges_to_token_pairs( + source_indices: Tensor, + target_indices: Tensor, + token_occurrences: _TokenOccurrenceIndex, + num_nodes: int, + device: torch.device, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Map ``source -> target`` graph edges onto valid sequence coordinates. + + The returned coordinates follow attention orientation: + ``query_pos=target_token`` and ``key_pos=source_token``. The final tensor + contains source edge-row ids, repeated when an edge is present in multiple + anchor sequences. + """ + empty = torch.zeros((0,), dtype=torch.long, device=device) + if source_indices.numel() == 0 or token_occurrences.batch_indices.numel() == 0: + return empty, empty, empty, empty + + source_indices = source_indices.to(device=device, dtype=torch.long) + target_indices = target_indices.to(device=device, dtype=torch.long) + + target_lower_bounds = torch.searchsorted( + token_occurrences.sorted_node_indices, + target_indices, + right=False, + ) + target_upper_bounds = torch.searchsorted( + token_occurrences.sorted_node_indices, + target_indices, + right=True, + ) + target_match_counts = target_upper_bounds - target_lower_bounds + matched_edge_mask = target_match_counts > 0 + if not matched_edge_mask.any(): + return empty, empty, empty, empty + + matched_edge_indices = torch.nonzero(matched_edge_mask, as_tuple=True)[0] + matched_target_counts = target_match_counts[matched_edge_indices] + total_target_matches = int(matched_target_counts.sum().item()) + repeated_target_edge_indices = torch.repeat_interleave( + matched_edge_indices, + matched_target_counts, + ) + repeated_target_lower_bounds = torch.repeat_interleave( + target_lower_bounds[matched_edge_indices], + matched_target_counts, + ) + target_group_start_offsets = torch.repeat_interleave( + torch.cumsum(matched_target_counts, dim=0) - matched_target_counts, + matched_target_counts, + ) + target_sorted_positions = ( + repeated_target_lower_bounds + + torch.arange(total_target_matches, device=device, dtype=torch.long) + - target_group_start_offsets + ) + target_token_indices = token_occurrences.node_sort_perm[target_sorted_positions] + target_batch_indices = token_occurrences.batch_indices[target_token_indices] + target_query_positions = token_occurrences.positions[target_token_indices] + + source_query_keys = ( + target_batch_indices * num_nodes + source_indices[repeated_target_edge_indices] + ) + source_lower_bounds = torch.searchsorted( + token_occurrences.sorted_batch_node_keys, + source_query_keys, + right=False, + ) + source_upper_bounds = torch.searchsorted( + token_occurrences.sorted_batch_node_keys, + source_query_keys, + right=True, + ) + source_match_counts = source_upper_bounds - source_lower_bounds + matched_target_mask = source_match_counts > 0 + if not matched_target_mask.any(): + return empty, empty, empty, empty + + matched_target_indices = torch.nonzero(matched_target_mask, as_tuple=True)[0] + matched_source_counts = source_match_counts[matched_target_indices] + total_source_matches = int(matched_source_counts.sum().item()) + repeated_target_indices = torch.repeat_interleave( + matched_target_indices, + matched_source_counts, + ) + repeated_source_lower_bounds = torch.repeat_interleave( + source_lower_bounds[matched_target_indices], + matched_source_counts, + ) + source_group_start_offsets = torch.repeat_interleave( + torch.cumsum(matched_source_counts, dim=0) - matched_source_counts, + matched_source_counts, + ) + source_sorted_positions = ( + repeated_source_lower_bounds + + torch.arange(total_source_matches, device=device, dtype=torch.long) + - source_group_start_offsets + ) + source_token_indices = token_occurrences.batch_node_sort_perm[ + source_sorted_positions + ] + + return ( + target_batch_indices[repeated_target_indices], + target_query_positions[repeated_target_indices], + token_occurrences.positions[source_token_indices], + repeated_target_edge_indices[repeated_target_indices], + ) + + +def _build_flat_valid_pair_layout( + node_index_sequences: Tensor, + valid_mask: Tensor, + device: torch.device, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Enumerate valid sequence pairs without building dense pairwise masks.""" + batch_indices_parts: list[Tensor] = [] + row_positions_parts: list[Tensor] = [] + col_positions_parts: list[Tensor] = [] + row_node_indices_parts: list[Tensor] = [] + col_node_indices_parts: list[Tensor] = [] + + for batch_idx in range(valid_mask.size(0)): + valid_positions = torch.nonzero(valid_mask[batch_idx], as_tuple=True)[0] + num_valid = valid_positions.numel() + if num_valid == 0: + continue + + valid_node_indices = node_index_sequences[batch_idx, valid_positions] + pair_count = num_valid * num_valid + + batch_indices_parts.append( + torch.full( + (pair_count,), + batch_idx, + dtype=torch.long, + device=device, + ) + ) + row_positions_parts.append(valid_positions.repeat_interleave(num_valid)) + col_positions_parts.append(valid_positions.repeat(num_valid)) + row_node_indices_parts.append(valid_node_indices.repeat_interleave(num_valid)) + col_node_indices_parts.append(valid_node_indices.repeat(num_valid)) + + if not batch_indices_parts: + empty = torch.zeros((0,), dtype=torch.long, device=device) + return empty, empty, empty, empty, empty + + return ( + torch.cat(batch_indices_parts, dim=0), + torch.cat(row_positions_parts, dim=0), + torch.cat(col_positions_parts, dim=0), + torch.cat(row_node_indices_parts, dim=0), + torch.cat(col_node_indices_parts, dim=0), + ) def _get_k_hop_neighbors_sparse( @@ -1039,32 +1506,35 @@ def _lookup_csr_values_and_found( col_indices_csr = csr_matrix.col_indices() values_csr = csr_matrix.values() - row_starts = crow_indices[row_indices] - row_ends = crow_indices[row_indices + 1] - row_lengths = row_ends - row_starts - max_row_len = row_lengths.max().item() - - if max_row_len == 0: + if col_indices_csr.numel() == 0: return ( torch.full((n,), default_value, device=device, dtype=torch.float), torch.zeros((n,), device=device, dtype=torch.bool), ) - offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device) - valid_mask = offsets < row_ends.unsqueeze(1) - - nnz = col_indices_csr.size(0) - offsets_clamped = offsets.clamp(max=max(nnz - 1, 0)) - - cols_at_offsets = col_indices_csr[offsets_clamped] - col_matches = (cols_at_offsets == col_indices.unsqueeze(1)) & valid_mask - found = col_matches.any(dim=1) + num_rows, num_cols = csr_matrix.size() + row_counts = crow_indices[1:] - crow_indices[:-1] + csr_row_indices = torch.repeat_interleave( + torch.arange(num_rows, device=device), + row_counts, + ) + # CSR stores entries grouped by row, and sparse graph features are emitted + # with sorted column indices per row, so linearized row-major keys remain + # globally sorted for searchsorted. + csr_keys = csr_row_indices * num_cols + col_indices_csr + query_keys = row_indices * num_cols + col_indices + match_positions = torch.searchsorted(csr_keys, query_keys) + + candidate_mask = match_positions < csr_keys.numel() + found = torch.zeros((n,), device=device, dtype=torch.bool) + if candidate_mask.any(): + valid_match_positions = match_positions[candidate_mask] + found[candidate_mask] = ( + csr_keys[valid_match_positions] == query_keys[candidate_mask] + ) result = torch.full((n,), default_value, device=device, dtype=torch.float) - if found.any(): - match_offsets = col_matches.float().argmax(dim=1) - value_indices = row_starts[found] + match_offsets[found] - result[found] = values_csr[value_indices].float() + result[found] = values_csr[match_positions[found]].float() return result, found diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index 091e69713..05670939c 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -1,5 +1,7 @@ """Tests for GraphTransformerEncoder.""" +import sys +import types from typing import cast import torch @@ -8,6 +10,117 @@ from torch import Tensor from torch_geometric.data import HeteroData + +def _install_torchrec_stub() -> None: + if "torchrec" in sys.modules: + return + + torchrec_module = types.ModuleType("torchrec") + distributed_module = types.ModuleType("torchrec.distributed") + distributed_types_module = types.ModuleType("torchrec.distributed.types") + modules_module = types.ModuleType("torchrec.modules") + embedding_configs_module = types.ModuleType("torchrec.modules.embedding_configs") + embedding_modules_module = types.ModuleType("torchrec.modules.embedding_modules") + sparse_module = types.ModuleType("torchrec.sparse") + jagged_tensor_module = types.ModuleType("torchrec.sparse.jagged_tensor") + + class Awaitable: # pragma: no cover - import compatibility shim + pass + + class EmbeddingBagConfig: # pragma: no cover - import compatibility shim + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + + class EmbeddingBagCollection: # pragma: no cover - import compatibility shim + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + + class KeyedJaggedTensor: # pragma: no cover - import compatibility shim + pass + + distributed_types_module.Awaitable = Awaitable + embedding_configs_module.EmbeddingBagConfig = EmbeddingBagConfig + embedding_modules_module.EmbeddingBagCollection = EmbeddingBagCollection + jagged_tensor_module.KeyedJaggedTensor = KeyedJaggedTensor + + torchrec_module.distributed = distributed_module + torchrec_module.modules = modules_module + torchrec_module.sparse = sparse_module + distributed_module.types = distributed_types_module + modules_module.embedding_configs = embedding_configs_module + modules_module.embedding_modules = embedding_modules_module + sparse_module.jagged_tensor = jagged_tensor_module + + sys.modules["torchrec"] = torchrec_module + sys.modules["torchrec.distributed"] = distributed_module + sys.modules["torchrec.distributed.types"] = distributed_types_module + sys.modules["torchrec.modules"] = modules_module + sys.modules["torchrec.modules.embedding_configs"] = embedding_configs_module + sys.modules["torchrec.modules.embedding_modules"] = embedding_modules_module + sys.modules["torchrec.sparse"] = sparse_module + sys.modules["torchrec.sparse.jagged_tensor"] = jagged_tensor_module + + +def _install_graphlearn_torch_stub() -> None: + if "graphlearn_torch.partition" in sys.modules: + return + + graphlearn_torch_module = types.ModuleType("graphlearn_torch") + partition_module = types.ModuleType("graphlearn_torch.partition") + + class PartitionBook: # pragma: no cover - import compatibility shim + pass + + partition_module.PartitionBook = PartitionBook + graphlearn_torch_module.partition = partition_module + + sys.modules["graphlearn_torch"] = graphlearn_torch_module + sys.modules["graphlearn_torch.partition"] = partition_module + + +def _install_tensorflow_metadata_stub() -> None: + if "tensorflow_metadata.proto.v0.schema_pb2" in sys.modules: + return + + tensorflow_metadata_module = types.ModuleType("tensorflow_metadata") + proto_module = types.ModuleType("tensorflow_metadata.proto") + v0_module = types.ModuleType("tensorflow_metadata.proto.v0") + schema_pb2_module = types.ModuleType("tensorflow_metadata.proto.v0.schema_pb2") + + class Feature: # pragma: no cover - import compatibility shim + pass + + class Schema: # pragma: no cover - import compatibility shim + pass + + schema_pb2_module.Feature = Feature + schema_pb2_module.Schema = Schema + tensorflow_metadata_module.proto = proto_module + proto_module.v0 = v0_module + v0_module.schema_pb2 = schema_pb2_module + + sys.modules["tensorflow_metadata"] = tensorflow_metadata_module + sys.modules["tensorflow_metadata.proto"] = proto_module + sys.modules["tensorflow_metadata.proto.v0"] = v0_module + sys.modules["tensorflow_metadata.proto.v0.schema_pb2"] = schema_pb2_module + + +def _install_tensorflow_transform_stub() -> None: + if "tensorflow_transform" in sys.modules: + return + + tensorflow_transform_module = types.ModuleType("tensorflow_transform") + common_types = types.SimpleNamespace(FeatureSpecType=object, TensorType=object) + + tensorflow_transform_module.common_types = common_types + sys.modules["tensorflow_transform"] = tensorflow_transform_module + + +_install_tensorflow_metadata_stub() +_install_tensorflow_transform_stub() +_install_torchrec_stub() +_install_graphlearn_torch_stub() + from gigl.nn.graph_transformer import ( FeedForwardNetwork, GraphTransformerEncoder, @@ -278,6 +391,18 @@ def _create_user_graph_with_ppr_edges() -> HeteroData: return data +def _pairwise_nonmissing_indices( + coords: list[tuple[int, int, int]], +) -> torch.Tensor: + return torch.tensor(coords, dtype=torch.long) + + +def _pairwise_relation_indices( + coords: list[tuple[int, int, int, int]], +) -> torch.Tensor: + return torch.tensor(coords, dtype=torch.long) + + class TestGraphTransformerEncoderPEModes(TestCase): def setUp(self) -> None: self._node_type = NodeType("user") @@ -416,7 +541,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), - "pairwise_nonmissing_mask": None, + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -453,7 +578,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, - "pairwise_nonmissing_mask": None, + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -464,7 +589,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( self.assertEqual(attn_bias[0, 0, 0, 2].item(), 4.25) self.assertEqual(attn_bias[0, 1, 0, 2].item(), 8.5) - def test_pairwise_nonmissing_mask_adds_head_specific_bias(self) -> None: + def test_pairwise_nonmissing_indices_add_head_specific_bias(self) -> None: encoder = self._create_encoder( pairwise_attention_bias_attr_names=["pairwise_distance"], ) @@ -474,9 +599,7 @@ def test_pairwise_nonmissing_mask_adds_head_specific_bias(self) -> None: with torch.no_grad(): encoder._pairwise_pe_attention_bias_projection.weight.zero_() - encoder._pairwise_nonmissing_attention_bias.copy_( - torch.tensor([0.5, 1.5]) - ) + encoder._pairwise_nonmissing_attention_bias.copy_(torch.tensor([0.5, 1.5])) attn_bias = encoder._build_attention_bias( valid_mask=torch.ones((1, 3), dtype=torch.bool), @@ -484,14 +607,8 @@ def test_pairwise_nonmissing_mask_adds_head_specific_bias(self) -> None: attention_bias_data={ "anchor_bias": None, "pairwise_bias": torch.zeros((1, 3, 3, 1), dtype=torch.float), - "pairwise_nonmissing_mask": torch.tensor( - [ - [ - [True, True, False], - [True, True, False], - [False, False, True], - ] - ] + "pairwise_nonmissing_indices": _pairwise_nonmissing_indices( + [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (0, 2, 2)] ), "token_input": None, }, @@ -503,6 +620,408 @@ def test_pairwise_nonmissing_mask_adds_head_specific_bias(self) -> None: self.assertEqual(attn_bias[0, 0, 0, 2].item(), 0.0) self.assertEqual(attn_bias[0, 1, 1, 2].item(), 0.0) + def test_relation_attention_zero_init_matches_plain_layer(self) -> None: + torch.manual_seed(0) + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=2, + ) + relation_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + + with torch.no_grad(): + assert relation_layer._relation_attention_matrices is not None + self.assertTrue( + torch.equal( + relation_layer._relation_attention_matrices, + torch.zeros_like(relation_layer._relation_attention_matrices), + ) + ) + base_output = base_layer(x, valid_mask=valid_mask) + relation_output = relation_layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (1, 2, 1, 1)] + ), + valid_mask=valid_mask, + ) + + self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_relation_value_mode_none_ignores_relation_indices(self) -> None: + torch.manual_seed(0) + layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + layer.eval() + x = torch.randn(1, 3, 8) + + with torch.no_grad(): + base_output = layer(x, valid_mask=torch.ones((1, 3), dtype=torch.bool)) + relation_index_output = layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (0, 2, 1, 0)] + ), + valid_mask=torch.ones((1, 3), dtype=torch.bool), + ) + + self.assertTrue(torch.allclose(base_output, relation_index_output, atol=1e-6)) + + def test_relation_value_zero_init_matches_plain_layer(self) -> None: + torch.manual_seed(0) + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_value_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=2, + ) + relation_value_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_value_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + + with torch.no_grad(): + assert relation_value_layer._relation_value_gates is not None + self.assertTrue( + torch.equal( + relation_value_layer._relation_value_gates, + torch.zeros_like(relation_value_layer._relation_value_gates), + ) + ) + base_output = base_layer(x, valid_mask=valid_mask) + relation_value_output = relation_value_layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (1, 2, 1, 1)] + ), + valid_mask=valid_mask, + ) + + self.assertTrue(torch.allclose(base_output, relation_value_output, atol=1e-6)) + + def test_relation_value_residual_affects_indexed_queries_and_normalizes( + self, + ) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=4, + num_heads=2, + feedforward_dim=8, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=2, + ) + attention_output = torch.zeros((1, 2, 3, 2)) + value = torch.tensor( + [ + [ + [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], + [[4.0, 40.0], [5.0, 50.0], [6.0, 60.0]], + ] + ] + ) + + with torch.no_grad(): + assert layer._relation_value_gates is not None + layer._relation_value_gates[0] = torch.tensor( + [[1.0, 0.5], [0.25, 0.0]] + ) + layer._relation_value_gates[1] = torch.tensor( + [[2.0, 0.0], [0.0, 3.0]] + ) + actual = layer._add_relation_value_residual( + attention_output=attention_output, + value=value, + pairwise_relation_indices=_pairwise_relation_indices( + [ + (0, 1, 0, 0), + (0, 1, 0, 0), + (0, 1, 2, 1), + (0, 2, 1, 0), + ] + ), + ) + + expected = torch.zeros_like(attention_output) + expected[0, 0, 1] = torch.tensor([8.0 / 3.0, 10.0 / 3.0]) + expected[0, 1, 1] = torch.tensor([2.0 / 3.0, 60.0]) + expected[0, 0, 2] = torch.tensor([2.0, 10.0]) + expected[0, 1, 2] = torch.tensor([1.25, 0.0]) + self.assertTrue(torch.allclose(actual, expected, atol=1e-6)) + + def test_relation_value_residual_rejects_invalid_relation_ids(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=4, + num_heads=2, + feedforward_dim=8, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=1, + ) + + with self.assertRaisesRegex(ValueError, "relation ids outside"): + layer._add_relation_value_residual( + attention_output=torch.zeros((1, 2, 2, 2)), + value=torch.zeros((1, 2, 2, 2)), + pairwise_relation_indices=_pairwise_relation_indices([(0, 1, 0, 1)]), + ) + + def test_relation_attention_nonzero_bias_only_indexed_pairs(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=1, + ) + query = torch.tensor([[[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]]) + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices[0, 0] = torch.eye(2) + relation_bias = layer._build_relation_attention_bias( + query=query, + key=key, + pairwise_relation_indices=_pairwise_relation_indices([(0, 2, 1, 0)]), + ) + + assert relation_bias is not None + expected = torch.zeros((1, 1, 3, 3)) + expected[0, 0, 2, 1] = 1.0 / torch.sqrt(torch.tensor(2.0)).item() + self.assertTrue(torch.allclose(relation_bias, expected, atol=1e-6)) + + def test_relation_attention_respects_existing_negative_bias(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=1, + ) + query = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + negative_inf = torch.finfo(torch.float).min + attn_bias = torch.zeros((1, 1, 2, 2), dtype=torch.float) + attn_bias[0, 0, 0, 0] = negative_inf + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices[0, 0] = torch.eye(2) + relation_bias = layer._build_relation_attention_bias( + query=query, + key=key, + pairwise_relation_indices=_pairwise_relation_indices([(0, 0, 0, 0)]), + ) + + assert relation_bias is not None + self.assertGreater(relation_bias[0, 0, 0, 0].item(), 0.0) + self.assertEqual((attn_bias + relation_bias)[0, 0, 0, 0].item(), negative_inf) + + def test_relation_attention_supports_ppr_sequence_construction(self) -> None: + data = _create_user_graph_with_ppr_edges() + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + ) + relation_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + relation_attention_mode="edge_type_bilinear", + ) + relation_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + relation_encoder.eval() + + with torch.no_grad(): + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + relation_output = relation_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_edge_attr_attention_bias_zero_init_matches_baseline(self) -> None: + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + ) + data["user"].batch_size = 1 + data[self._edge_type.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[self._edge_type.tuple_repr()].edge_attr = torch.tensor([[3.0]]) + + edge_type_to_feat_dim_map = {self._edge_type: 1} + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map=edge_type_to_feat_dim_map, + ) + edge_attr_encoder = self._create_encoder( + edge_type_to_feat_dim_map=edge_type_to_feat_dim_map, + edge_attr_attention_bias_mode="sparse_linear", + ) + edge_attr_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + edge_attr_encoder.eval() + + with torch.no_grad(): + edge_attr_projection = ( + edge_attr_encoder._edge_attr_attention_bias_projection_dict["0"] + ) + self.assertTrue( + torch.equal( + edge_attr_projection.weight, + torch.zeros_like(edge_attr_projection.weight), + ) + ) + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + edge_attr_output = edge_attr_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, edge_attr_output, atol=1e-6)) + + def test_edge_attr_attention_bias_adds_only_indexed_pairs_and_accumulates( + self, + ) -> None: + encoder = self._create_encoder( + edge_type_to_feat_dim_map={self._edge_type: 2}, + edge_attr_attention_bias_mode="sparse_linear", + ) + projection = encoder._edge_attr_attention_bias_projection_dict["0"] + with torch.no_grad(): + projection.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.tensor([[True, True, True]]), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": None, + "pairwise_bias": None, + "pairwise_nonmissing_indices": None, + "pairwise_edge_attr_indices": { + 0: torch.tensor( + [ + (0, 1, 0), + (0, 1, 0), + (0, 2, 1), + ], + dtype=torch.long, + ) + }, + "pairwise_edge_attr_values": { + 0: torch.tensor( + [ + [1.0, 1.0], + [2.0, 0.0], + [0.0, 1.0], + ] + ) + }, + "token_input": None, + }, + ) + + expected = torch.zeros((1, 2, 3, 3), dtype=torch.float) + expected[0, :, 1, 0] = torch.tensor([5.0, 13.0]) + expected[0, :, 2, 1] = torch.tensor([2.0, 4.0]) + self.assertTrue(torch.allclose(attn_bias, expected, atol=1e-6)) + + def test_edge_attr_attention_bias_supports_ppr_sequence_construction( + self, + ) -> None: + data = _create_user_graph_with_ppr_edges() + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 1}, + sequence_construction_method="ppr", + ) + edge_attr_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 1}, + sequence_construction_method="ppr", + edge_attr_attention_bias_mode="sparse_linear", + ) + edge_attr_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + edge_attr_encoder.eval() + + with torch.no_grad(): + edge_attr_projection = ( + edge_attr_encoder._edge_attr_attention_bias_projection_dict["0"] + ) + self.assertTrue( + torch.equal( + edge_attr_projection.weight, + torch.zeros_like(edge_attr_projection.weight), + ) + ) + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + edge_attr_output = edge_attr_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, edge_attr_output, atol=1e-6)) + def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None: encoder = self._create_encoder( sequence_construction_method="ppr", diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 0cb69150f..69cab0cd5 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -7,6 +7,7 @@ from absl.testing import absltest from torch_geometric.data import HeteroData +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.transforms.graph_transformer import ( _get_k_hop_neighbors_sparse, heterodata_to_graph_transformer_input, @@ -126,6 +127,27 @@ def create_ppr_sequence_hetero_data() -> HeteroData: return data +def _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices: torch.Tensor | None, + batch_size: int, + seq_len: int, + device: torch.device, +) -> torch.Tensor: + dense_mask = torch.zeros( + (batch_size, seq_len, seq_len), + dtype=torch.bool, + device=device, + ) + if pairwise_nonmissing_indices is None or pairwise_nonmissing_indices.numel() == 0: + return dense_mask + dense_mask[ + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ] = True + return dense_mask + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -256,7 +278,10 @@ def test_basic_transform(self): self.assertIsInstance(attention_bias_data, dict) self.assertIn("anchor_bias", attention_bias_data) self.assertIn("pairwise_bias", attention_bias_data) - self.assertIn("pairwise_nonmissing_mask", attention_bias_data) + self.assertIn("pairwise_nonmissing_indices", attention_bias_data) + self.assertIn("pairwise_relation_indices", attention_bias_data) + self.assertIn("pairwise_edge_attr_indices", attention_bias_data) + self.assertIn("pairwise_edge_attr_values", attention_bias_data) def test_attention_mask_validity(self): """Test that attention mask correctly identifies valid positions.""" @@ -310,6 +335,148 @@ def test_anchor_first(self): # First position should be anchor node self.assertTrue(torch.allclose(sequences[0, 0], anchor_feature)) + def test_pairwise_relation_indices_follow_order_direction_and_padding(self): + """Sparse relation indices preserve edge-type labels before homogenization.""" + user = NodeType("user") + likes = EdgeType(user, Relation("likes"), user) + follows = EdgeType(user, Relation("follows"), user) + missing = EdgeType(user, Relation("missing"), user) + + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + data["user"].batch_size = 1 + data[likes.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[follows.tuple_repr()].edge_index = torch.tensor([[0, 1], [1, 2]]) + + _, valid_mask, auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + relation_edge_types=[likes, follows, missing], + ) + + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + pairwise_relation_indices = auxiliary_data["pairwise_relation_indices"] + self.assertIsNotNone(pairwise_relation_indices) + assert pairwise_relation_indices is not None + self.assertEqual(pairwise_relation_indices.shape[1], 4) + self.assertEqual( + {tuple(coord) for coord in pairwise_relation_indices.tolist()}, + { + (0, 1, 0, 0), # likes: source 0 -> target 1 + (0, 1, 0, 1), # follows: source 0 -> target 1 + (0, 2, 1, 1), # follows: source 1 -> target 2 + }, + ) + self.assertFalse((pairwise_relation_indices[:, 1:3] == 3).any().item()) + self.assertFalse((pairwise_relation_indices[:, 3] == 2).any().item()) + + def test_pairwise_edge_attr_payloads_follow_order_direction_and_padding(self): + """Sparse edge-attr payloads preserve relation labels and GAT direction.""" + user = NodeType("user") + likes = EdgeType(user, Relation("likes"), user) + follows = EdgeType(user, Relation("follows"), user) + missing = EdgeType(user, Relation("missing"), user) + + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + data["user"].batch_size = 1 + data[likes.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[likes.tuple_repr()].edge_attr = torch.tensor([[0.25, 1.25]]) + data[follows.tuple_repr()].edge_index = torch.tensor([[0, 1], [1, 2]]) + data[follows.tuple_repr()].edge_attr = torch.tensor([[2.0], [3.0]]) + + edge_attr_dim_map = { + likes: 2, + follows: 1, + missing: 1, + } + _, valid_mask, auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + edge_attr_edge_type_to_feat_dim_map=edge_attr_dim_map, + ) + + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + edge_attr_indices = auxiliary_data["pairwise_edge_attr_indices"] + edge_attr_values = auxiliary_data["pairwise_edge_attr_values"] + self.assertIsNotNone(edge_attr_indices) + self.assertIsNotNone(edge_attr_values) + assert edge_attr_indices is not None + assert edge_attr_values is not None + + sorted_edge_types = sorted(edge_attr_dim_map.keys()) + likes_idx = sorted_edge_types.index(likes) + follows_idx = sorted_edge_types.index(follows) + missing_idx = sorted_edge_types.index(missing) + + self.assertEqual( + {tuple(coord) for coord in edge_attr_indices[likes_idx].tolist()}, + {(0, 1, 0)}, # likes: source 0 -> target 1 + ) + self.assertTrue( + torch.allclose( + edge_attr_values[likes_idx], + torch.tensor([[0.25, 1.25]]), + ) + ) + self.assertEqual( + {tuple(coord) for coord in edge_attr_indices[follows_idx].tolist()}, + { + (0, 1, 0), # follows: source 0 -> target 1 + (0, 2, 1), # follows: source 1 -> target 2 + }, + ) + self.assertTrue( + torch.allclose(edge_attr_values[follows_idx], torch.tensor([[2.0], [3.0]])) + ) + self.assertNotIn(missing_idx, edge_attr_indices) + self.assertFalse((edge_attr_indices[likes_idx] == 3).any().item()) + self.assertFalse((edge_attr_indices[follows_idx] == 3).any().item()) + + def test_pairwise_edge_attr_payloads_missing_edge_attr_raises(self): + user = NodeType("user") + follows = EdgeType(user, Relation("follows"), user) + + data = HeteroData() + data["user"].x = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + data["user"].batch_size = 1 + data[follows.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + + with self.assertRaisesRegex( + ValueError, + "requires edge_attr for edge type", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=2, + anchor_node_type="user", + hop_distance=1, + edge_attr_edge_type_to_feat_dim_map={follows: 1}, + ) + def test_different_anchor_types(self): """Test with different anchor node types.""" data = create_simple_hetero_data() @@ -797,7 +964,7 @@ def test_transform_returns_base_sequences_and_anchor_relative_bias(self) -> None assert anchor_bias is not None self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertIsNone(attention_bias_data["pairwise_bias"]) - self.assertIsNone(attention_bias_data["pairwise_nonmissing_mask"]) + self.assertIsNone(attention_bias_data["pairwise_nonmissing_indices"]) self.assertTrue(valid_mask[0, 0].item()) def test_attention_bias_outputs_include_valid_mask_and_relative_features( @@ -823,13 +990,19 @@ def test_attention_bias_outputs_include_valid_mask_and_relative_features( self.assertEqual(valid_mask.shape, (1, 4)) anchor_bias = attention_bias_data["anchor_bias"] pairwise_bias = attention_bias_data["pairwise_bias"] - pairwise_nonmissing_mask = attention_bias_data["pairwise_nonmissing_mask"] + pairwise_nonmissing_indices = attention_bias_data["pairwise_nonmissing_indices"] assert anchor_bias is not None assert pairwise_bias is not None - assert pairwise_nonmissing_mask is not None + assert pairwise_nonmissing_indices is not None + pairwise_nonmissing_mask = _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices=pairwise_nonmissing_indices, + batch_size=1, + seq_len=4, + device=pairwise_bias.device, + ) self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertEqual(pairwise_bias.shape, (1, 4, 4, 1)) - self.assertEqual(pairwise_nonmissing_mask.shape, (1, 4, 4)) + self.assertEqual(pairwise_nonmissing_indices.shape[1], 3) self.assertAlmostEqual(anchor_bias[0, 0, 0].item(), 0.0, places=5) self.assertAlmostEqual(anchor_bias[0, 1, 0].item(), 1.0, places=5) self.assertAlmostEqual(anchor_bias[0, 2, 0].item(), 3.0, places=5)