42 changes: 40 additions & 2 deletions .ci/scripts/test_riscv_qemu.sh
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

# CI wrapper: install RISC-V cross-compile + qemu-user tooling, then run the
# RISC-V Phase 1 smoke test (export, cross-compile, qemu-user execution) via
# RISC-V smoke test (export, cross-compile, qemu-user execution) via
# examples/riscv/run.sh. The bundled-IO comparison and Test_result: PASS
# check are done by run.sh.

@@ -14,5 +14,43 @@ set -eu
script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")")
et_root_dir=$(realpath "${script_dir}/../..")

model="add"
xnnpack=false
quantize=false
verbose=false

usage() {
cat <<EOF
Usage: $(basename "$0") [options]
Options:
--model=<NAME> Which model to export and run (default: add)
--xnnpack Enable the XNNPACK backend (AOT partitioner + runtime)
--quantize Produce an 8-bit quantized model
--verbose Enable verbose output
-h, --help Show this help
EOF
}
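
# Example invocations (illustrative; the flags are parsed below):
#   bash .ci/scripts/test_riscv_qemu.sh                            # defaults: --model=add
#   bash .ci/scripts/test_riscv_qemu.sh --model=mv2 --xnnpack
#   bash .ci/scripts/test_riscv_qemu.sh --model=mv2 --xnnpack --quantize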

for arg in "$@"; do
case $arg in
--model=*) model="${arg#*=}" ;;
--xnnpack) xnnpack=true ;;
--quantize) quantize=true ;;
--verbose) verbose=true ;;
-h|--help) usage; exit 0 ;;
*) echo "Unknown option: $arg" >&2; usage; exit 1 ;;
esac
done

run_extra_args=()
if ${xnnpack}; then
run_extra_args+=(--xnnpack)
fi
if ${quantize}; then
run_extra_args+=(--quantize)
fi
if ${verbose}; then
run_extra_args+=(--verbose)
fi

bash "${et_root_dir}/examples/riscv/setup.sh"
bash "${et_root_dir}/examples/riscv/run.sh"
bash "${et_root_dir}/examples/riscv/run.sh" --model="${model}" "${run_extra_args[@]}"
28 changes: 26 additions & 2 deletions .github/workflows/_test_riscv.yml
@@ -12,13 +12,36 @@ on:
required: false
type: number
default: 30
model:
description: 'Which model to run. Possible values are: add, mv2 (mobilenetv2), mobilebert, llama2, resnet18'
required: false
type: string
default: 'add'
xnnpack:
description: 'Whether to enable XNNPACK'
required: false
type: boolean
default: false
quantize:
description: 'Produce an 8-bit quantized model'
required: false
type: boolean
default: false
gcc-version:
description: 'The version of GCC to use'
required: false
type: number
docker-image:
description: 'The docker image to use for this job'
required: false
type: string
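
# Example caller (illustrative; riscv64.yml in this PR passes these inputs):
#   uses: ./.github/workflows/_test_riscv.yml
#   with:
#     model: mv2
#     xnnpack: true
#     quantize: true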

jobs:
run:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.2xlarge
docker-image: ci-image:executorch-ubuntu-22.04-gcc11
docker-image: ${{ inputs.docker-image || 'ci-image:executorch-ubuntu-22.04-gcc11' }}
submodules: 'recursive'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: ${{ inputs.timeout }}
@@ -29,4 +52,5 @@ jobs:
source .ci/scripts/utils.sh
install_executorch "--use-pt-pinned-commit"
bash .ci/scripts/test_riscv_qemu.sh
export GCC_VERSION=${{ inputs.gcc-version }}
bash .ci/scripts/test_riscv_qemu.sh --model="${{ inputs.model }}" ${{ inputs.xnnpack && '--xnnpack' || '' }} ${{ inputs.quantize && '--quantize' || '' }}
24 changes: 24 additions & 0 deletions .github/workflows/riscv64.yml
@@ -25,6 +25,30 @@ jobs:
test-riscv:
name: test-riscv
uses: ./.github/workflows/_test_riscv.yml
strategy:
matrix:
include:
- { model: add, xnnpack: false, quantize: false }
- { model: add, xnnpack: true, quantize: false }
- { model: mv2, xnnpack: false, quantize: false }
- { model: mv2, xnnpack: true, quantize: false }
- { model: mv2, xnnpack: true, quantize: true }
- { model: mobilebert, xnnpack: false, quantize: false }
- { model: mobilebert, xnnpack: true, quantize: false }
- { model: mobilebert, xnnpack: true, quantize: true }
- { model: llama2, xnnpack: false, quantize: false }
- { model: llama2, xnnpack: true, quantize: false }
- { model: llama2, xnnpack: true, quantize: true }
- { model: resnet18, xnnpack: false, quantize: false }
- { model: resnet18, xnnpack: true, quantize: false }
- { model: resnet18, xnnpack: true, quantize: true }
permissions:
id-token: write
contents: read
with:
model: ${{ matrix.model }}
xnnpack: ${{ matrix.xnnpack }}
quantize: ${{ matrix.quantize }}
# XNNPACK requires GCC 14+
gcc-version: ${{ matrix.xnnpack && 14 }}
docker-image: ${{ matrix.xnnpack && 'ci-image:executorch-ubuntu-24.04-gcc14' || '' }}
189 changes: 175 additions & 14 deletions examples/riscv/aot_riscv.py
@@ -3,14 +3,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""AOT export for the RISC-V Phase 1.0 smoke test.
"""AOT export for the RISC-V smoke test.

Exports a trivial ``torch.add`` module to a BundledProgram (.bpte) that the
portable executor_runner can load on a riscv64 target and verify against the
embedded reference output, emitting ``Test_result: PASS`` on success.
Exports a small model to a BundledProgram (.bpte) that the portable
executor_runner can load on a riscv64 target and verify against the embedded
reference output, emitting ``Test_result: PASS`` on success.
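
Example (illustrative; the flags are defined in main() below):

    python examples/riscv/aot_riscv.py --model=mv2 --xnnpack --quantize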
"""

import argparse
import logging
from pathlib import Path

import torch
@@ -28,26 +29,186 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y


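# Each build_* helper returns (model, example_inputs, test_inputs, strict),
# where strict is forwarded to torch.export's strict-mode flag in main().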
def build_add():
model = AddModule().eval()
example_inputs = (torch.ones(1, 4), torch.full((1, 4), 2.0))
test_inputs = [
(torch.ones(1, 4), torch.full((1, 4), 2.0)),
(torch.full((1, 4), 3.0), torch.full((1, 4), 4.0)),
]
return model, example_inputs, test_inputs, True


def build_mv2():
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
torch.manual_seed(0)
example_inputs = (torch.randn(1, 3, 224, 224),)
test_inputs = [example_inputs]
return model, example_inputs, test_inputs, False


def build_mobilebert():
from transformers import MobileBertConfig, MobileBertModel

config = MobileBertConfig(
vocab_size=1024,
hidden_size=128,
embedding_size=64,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=128,
intra_bottleneck_size=32,
)

class Wrapper(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = MobileBertModel(config).eval()

def forward(self, input_ids):
return self.model(input_ids).last_hidden_state

model = Wrapper().eval()
example_inputs = (torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]),)
test_inputs = [example_inputs]
return model, example_inputs, test_inputs, False


def build_llama2():
# Use the executorch native Transformer (matches MODEL_NAME_TO_MODEL["llama2"]
# in examples/models/__init__.py). Unlike HF LlamaModel, RoPE freqs are
# precomputed buffers and just sliced at forward time, so no
# torch.arange()/Long causal mask is built per forward — which is what
# the PT2E XNNPACK quantizer trips over on HF Llama.
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs

seq_len = 8
args = ModelArgs(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=2, # GQA: kv_heads < n_heads exercises the GQA path
vocab_size=1024,
hidden_dim=256, # SwiGLU FFN: gate + up projections at this width
max_seq_len=seq_len,
max_context_len=seq_len,
rope_theta=10000.0,
)
torch.manual_seed(0)
model = construct_transformer(args).eval()
example_inputs = (torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long),)
test_inputs = [example_inputs]
return model, example_inputs, test_inputs, False


def build_resnet18():
from torchvision.models import resnet18, ResNet18_Weights

model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
torch.manual_seed(0)
example_inputs = (torch.randn(1, 3, 224, 224),)
test_inputs = [example_inputs]
return model, example_inputs, test_inputs, False


MODELS = {
"add": build_add,
"mv2": build_mv2,
"mobilebert": build_mobilebert,
"llama2": build_llama2,
"resnet18": build_resnet18,
}


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model",
choices=sorted(MODELS),
default="add",
help="Which model to export",
)
parser.add_argument(
"--output",
type=Path,
default=Path("add_riscv.bpte"),
help="Output .bpte path",
default=None,
help="Output .bpte path (default: <model>_riscv.bpte)",
)
parser.add_argument(
"--xnnpack",
action="store_true",
help="Lower through the XNNPACK partitioner",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Produce an 8-bit quantized model",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable XNNPACK partitioner DEBUG logging and dump the lowered graph",
)
args = parser.parse_args()

model = AddModule().eval()
example_inputs = (torch.ones(1, 4), torch.full((1, 4), 2.0))
if args.verbose:
logging.basicConfig(level=logging.DEBUG)

exported = export(model, example_inputs)
et_program = to_edge_transform_and_lower(exported).to_executorch()
if args.output is None:
args.output = Path(f"{args.model}_riscv.bpte")

model, example_inputs, test_inputs, strict = MODELS[args.model]()

if args.quantize:
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
from executorch.examples.xnnpack.quantization.utils import quantize

if args.model not in MODEL_NAME_TO_OPTIONS:
parser.error(f"No XNNPACK quantization recipe for model {args.model!r}")
quant_type = MODEL_NAME_TO_OPTIONS[args.model].quantization
if quant_type == QuantType.NONE:
parser.error(f"Quantization recipe for {args.model!r} is NONE")
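# Assumed PT2E flow behind the xnnpack example's quantize() helper: export
# to an ATen graph first, quantize it, then re-export the quantized module
# below for lowering.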
ep = export(model, example_inputs, strict=strict)
model = quantize(ep.module(), example_inputs, quant_type)

exported = export(model, example_inputs, strict=strict)
partitioners = []
if args.xnnpack:
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackPartitioner,
)

partitioners.append(XnnpackPartitioner(verbose=args.verbose))

compile_config = None
if args.quantize:
from executorch.exir import EdgeCompileConfig

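# Assumption, mirroring the executorch xnnpack examples: quantized graphs
# keep quantize/dequantize ops that the strict Edge IR verifier rejects,
# so relax the validity check.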
compile_config = EdgeCompileConfig(_check_ir_validity=False)

edge = to_edge_transform_and_lower(
exported, partitioner=partitioners, compile_config=compile_config
)
delegated = sum(
1
for n in edge.exported_program().graph.nodes
if n.op == "call_function" and "call_delegate" in str(n.target)
)
print(
f"[aot_riscv] model={args.model} xnnpack={args.xnnpack} "
f"quantize={args.quantize} delegated_nodes={delegated}"
)

if args.verbose:
from executorch.exir.backend.utils import print_delegated_graph

print_delegated_graph(edge.exported_program().graph_module)

et_program = edge.to_executorch()

test_inputs = [
(torch.ones(1, 4), torch.full((1, 4), 2.0)),
(torch.full((1, 4), 3.0), torch.full((1, 4), 4.0)),
]
test_suite = MethodTestSuite(
method_name="forward",
test_cases=[
2 changes: 2 additions & 0 deletions examples/riscv/requirements.txt
@@ -0,0 +1,2 @@
torchvision
transformers