Skip to content

[Codegen] Fail to capture nested dict for pipeline parallelism #49

@zyeric

Description

@zyeric

Reproduce snippet

#  Copyright (c) Microsoft Corporation.
#  Licensed under the MIT License.

import torch
import torch.nn as nn
import tempfile
import shutil
import contextlib
import pytest
from pathlib import Path


import nnscaler
import nnscaler.graph.function.function as F
from nnscaler.ir.tensor import IRFullTensor
from nnscaler.graph import IRGraph
from nnscaler.ir.adapter import IRAdapter
from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer
from nnscaler.ir.operator import IRFwOperation, IRDataOperation
from nnscaler.graph.segment import IRSegment
from nnscaler.graph.schedule.predefined import PredefinedSched
from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause
from tests.launch_torchrun import torchrun
from tests.parallel_module.test_gencode import _gencode_contains, print_gencode


class Layer(torch.nn.Module):
    """Bias-free 4 -> 4 linear projection plus an additive bias read from a dict.

    The additive bias is not a parameter of this module; it is looked up in the
    ``context`` mapping on every call.
    """

    def __init__(self):
        super(Layer, self).__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=4, bias=False)

    def forward(self, x, context):
        """Return ``self.linear(x) + context['bias']``.

        Args:
            x: input tensor; its last dimension must be 4 to match ``self.linear``.
            context: mapping that provides the ``'bias'`` entry added to the
                projected tensor.
        """
        projected = self.linear(x)
        bias = context['bias']
        return projected + bias


class Model(torch.nn.Module):
    """Stack of ``Layer`` modules that all read the same nested context dict.

    Args:
        num_layers: number of ``Layer`` instances to stack.
    """

    def __init__(self, num_layers):
        super().__init__()
        # Use nn.ModuleList rather than a plain Python list so that each Layer is
        # registered as a submodule. Only then do its Linear weights appear in
        # parameters()/state_dict(), move with .to()/.cuda(), and follow
        # train()/eval(). A plain list hides the layers from torch.nn.Module.
        self.layers = torch.nn.ModuleList(Layer() for _ in range(num_layers))

    def forward(self, x, content):
        """Run ``x`` through every layer, then reduce the result to a scalar.

        Args:
            x: input tensor fed to the first layer.
            content: dict whose ``'content'`` entry is itself a dict (a nested
                dict). That inner dict is passed unchanged to every layer, which
                reads its ``'bias'`` entry.

        Returns:
            The sum of all elements of the last layer's output.
        """
        context = content['content']
        for layer in self.layers:
            x = layer(x, context)
        return x.sum()


def pas(graph, cfg):
    """Policy: split the graph into a two-stage pipeline at the two linear layers.

    The dataloader node is replicated onto every planned device. All nodes of
    the first forward segment are assigned to device 0, and all nodes of the
    second forward segment to device 1.

    Args:
        graph: the IRGraph to transform in place.
        cfg: compute config; only ``cfg.plan_ngpus`` is read here.

    Returns:
        The same ``graph`` object, after staging and assignment.
    """
    print('graph', graph.nodes())
    # NOTE(review): this positional unpacking assumes the traced graph begins
    # with exactly this node layout (dataloader, linear, getitem, add, linear,
    # getitem, add, sum). Confirm against the printed graph if the model changes.
    dataloader, fc1, _, _, fc2, _, _, _ = graph.nodes()[:8]

    # Cut the graph into segments starting at each linear layer.
    graph.staging([fc1, fc2])
    forward_stages = [
        segment
        for segment in graph.select(ntype=IRSegment, flatten=False)
        if segment.isfw()
    ]

    # Every device receives its own copy of the dataloader.
    replicas = graph.replicate(dataloader, cfg.plan_ngpus)
    for device_id, replica in enumerate(replicas):
        graph.assign(replica, device_id)

    print('stage 0', forward_stages[0].nodes())
    print('stage 1', forward_stages[1].nodes())
    for device_id, stage in enumerate(forward_stages[:2]):
        for node in stage.nodes():
            graph.assign(node, device_id)
    return graph


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason='inputs are allocated on the current CUDA device',
)
def test_non_tensor_pp():
    """Reproduce issue #49: codegen for pipeline parallelism with a nested dict input.

    Only generates code (``load_module=False``) and prints the generated
    modules for ranks 0 and 1 so they can be inspected.
    """
    # Seed *before* building the model so the Linear weights are initialised
    # deterministically. The original seeded after construction, which left
    # the initialisation random.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    m = Model(2)
    m.train()

    # Previously current_device() was called even without CUDA and crashed;
    # the skipif above now guarantees CUDA is present.
    device = torch.cuda.current_device()
    x = torch.randn([2, 4], dtype=torch.float32, device=device)
    content = {'content': {'bias': torch.randn([4], dtype=torch.float32, device=device)}}

    with tempfile.TemporaryDirectory() as tempdir:
        parallelize(
            m,
            {'x': x, 'content': content},
            pas,
            # NOTE(review): positional args presumably mean plan_ngpus=2 and
            # runtime_ngpus=2; `pas` reads cfg.plan_ngpus.
            ComputeConfig(2, 2, use_end2end=True),
            reuse='override',
            gen_savedir=tempdir,
            load_module=False,
        )
        print_gencode(tempdir, Model, 0)
        print_gencode(tempdir, Model, 1)

Description

In the code generated for rank 1, the adapter for the variable `getitem_28` is not generated correctly. `getitem_28` is an `IRObject` with a dynamic size.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions