
GPT-OSS-20B does not work on AMD GPUs #45237

@tanreinama

Description

System Info

GPT-OSS-20B does not work on Radeon GPUs.
I tested it both in a native environment and in the Docker container rocm/pytorch:rocm7.2.1_ubuntu24.04_py3.12_pytorch_release_2.9.1.
Updating Triton did not help: I tried triton-rocm 3.6.0, 3.5.1+rocm (bundled with rocm/pytorch), the 3.6.0 nightly, and the 3.7.0 nightly, and all of them failed with the same error.

@ivarflakstad

command log:

$ pip install -U transformers kernels accelerate
$ python
>>> from transformers import pipeline
>>> import torch
>>> model_id = "openai/gpt-oss-20b"
>>> pipe = pipeline(
...     "text-generation",
...     model=model_id,
...     torch_dtype="auto",
...     device_map="auto",
... )
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 42 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:02<00:00, 16.72it/s]
Download complete: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:02<00:00, 39.0B/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 411/411 [01:16<00:00,  5.40it/s]
>>> messages = [
...     {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
... ]
>>> outputs = pipe(
...     messages,
...     max_new_tokens=256,
... )
Passing `generation_config` together with generation-related arguments=({'max_new_tokens'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.
Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Traceback (most recent call last):
  File "<python-input-5>", line 1, in <module>
    outputs = pipe(
        messages,
        max_new_tokens=256,
    )
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/pipelines/text_generation.py", line 299, in __call__
    return super().__call__(text_inputs, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/pipelines/base.py", line 1264, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/pipelines/base.py", line 1271, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/pipelines/base.py", line 1163, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/pipelines/text_generation.py", line 403, in _forward
    output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/generation/utils.py", line 2543, in generate
    result = decoding_method(
        self,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/generation/utils.py", line 2736, in _sample
    outputs = self._prefill(
        input_ids,
    ...<2 lines>...
        is_first_iteration=not generation_config.is_assistant,
    )
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/generation/utils.py", line 3768, in _prefill
    return self(**model_inputs, return_dict=True)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/utils/generic.py", line 876, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 649, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ~~~~~~~~~~^
        input_ids=input_ids,
        ^^^^^^^^^^^^^^^^^^^^
    ...<6 lines>...
        **kwargs,
        ^^^^^^^^^
    )
    ^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 490, in forward
    hidden_states = decoder_layer(
        hidden_states,
    ...<5 lines>...
        **kwargs,
    )
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 384, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/integrations/mxfp4.py", line 508, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx=scatter_idx)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nama/.pyenv/versions/rocm/lib/python3.13/site-packages/transformers/integrations/mxfp4.py", line 411, in forward
    intermediate_cache3 = matmul_ogs(
        intermediate_cache1,
    ...<5 lines>...
        gammas=routing_data.gate_scal,
    )
  File "/home/nama/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
                                        num_indx, precision_config, routing_data,
                                        postprocessing_features, memory, fused_postprocess_activation, epilogue)
  File "/home/nama/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ~~~~~~~~~~~~^^^^
  File "/home/nama/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
              ~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
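The failure is inside the ROCm build of the kernels-community gpt-oss Triton kernels: `target_info.num_sms()` evidently returns `None` for this GPU, and `compute_grid` then multiplies `None` by an int. A minimal sketch of that failure mode (the argument values below are illustrative, not taken from the kernel):

```python
def compute_grid(num_sms, warps_per_sm, num_warps):
    # Mirrors the failing line in matmul_ogs.py:
    #   num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
    return num_sms * (warps_per_sm // num_warps)

print(compute_grid(64, 32, 4))   # recognized GPU: 64 * 8 = 512

try:
    compute_grid(None, 32, 4)    # num_sms() -> None on an unrecognized target
except TypeError as e:
    print(e)  # unsupported operand type(s) for *: 'NoneType' and 'int'
```

So the TypeError is a symptom: the kernel's target detection does not recognize this Radeon GPU, rather than the arithmetic itself being wrong.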

environment:

$ pip list
Package           Version
----------------- --------------
accelerate        1.13.0
annotated-doc     0.0.4
anyio             4.13.0
certifi           2026.2.25
click             8.3.2
filelock          3.25.2
fsspec            2026.2.0
h11               0.16.0
hf-xet            1.4.3
httpcore          1.0.9
httpx             0.28.1
huggingface_hub   1.9.0
idna              3.11
Jinja2            3.1.6
kernels           0.12.3
markdown-it-py    4.0.0
MarkupSafe        3.0.3
mdurl             0.1.2
mpmath            1.3.0
networkx          3.6.1
numpy             2.4.3
packaging         26.0
pillow            12.1.1
pip               25.3
psutil            7.2.2
Pygments          2.20.0
PyYAML            6.0.3
regex             2026.4.4
rich              14.3.3
safetensors       0.7.0
setuptools        70.2.0
shellingham       1.5.4
sympy             1.14.0
tokenizers        0.22.2
torch             2.11.0+rocm7.2
torchvision       0.26.0+rocm7.2
tqdm              4.67.3
transformers      5.5.0
triton-rocm       3.6.0
typer             0.24.1
typing_extensions 4.15.0

Triton versions tested:

$ pip list|grep triton
triton             3.5.1+rocm7.2.1.gita272dfa8
$ pip install -U https://download-r2.pytorch.org/whl/nightly/triton_rocm-3.6.0%2Bgit6213a0e8-cp312-cp312-linux_x86_64.whl
$ pip install -U https://download-r2.pytorch.org/whl/nightly/triton_rocm-3.7.0%2Bgit9c288bc5-cp312-cp312-linux_x86_64.whl
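A possible workaround, assuming the `Mxfp4Config(dequantize=True)` option documented for transformers' gpt-oss integration (untested on this hardware), is to dequantize the MXFP4 expert weights at load time, which should bypass the Triton `matmul_ogs` kernel path entirely at the cost of higher GPU memory use:

```python
from transformers import Mxfp4Config, pipeline

# Dequantize the MXFP4 expert weights on load so the model runs through the
# standard PyTorch path instead of the kernels-community Triton MoE kernels.
pipe = pipeline(
    "text-generation",
    model="openai/gpt-oss-20b",
    model_kwargs={"quantization_config": Mxfp4Config(dequantize=True)},
    dtype="auto",
    device_map="auto",
)
```

This only sidesteps the crash; the underlying ROCm target-detection failure in the Triton kernels would remain.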

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

$ docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined rocm/pytorch:rocm7.2.1_ubuntu24.04_py3.12_pytorch_release_2.9.1
# pip install -U transformers kernels accelerate
# python
>>> from transformers import pipeline
>>> import torch
>>> model_id = "openai/gpt-oss-20b"
>>> pipe = pipeline(
...     "text-generation",
...     model=model_id,
...     torch_dtype="auto",
...     device_map="auto",
... )
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
config.json: 1.81kB [00:00, 1.49MB/s]
`torch_dtype` is deprecated! Use `dtype` instead!
model.safetensors.index.json: 36.4kB [00:00, 62.0MB/s]
Fetching 3 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:00<00:00, 60.12s/it]
Download complete: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 13.8G/13.8G [03:00<00:00, 76.3MB/s]
Fetching 42 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:01<00:00, 33.86it/s]
Download complete: : 249kB [00:01, 193kB/s]              ████████████████████████████████████████████████████████████████████▌  | 41/42 [00:01<00:00, 40.35it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 411/411 [00:14<00:00, 29.26it/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 177/177 [00:00<00:00, 1.55MB/s]
tokenizer_config.json: 4.20kB [00:00, 9.92MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 27.9M/27.9M [00:01<00:00, 19.6MB/s]
special_tokens_map.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 98.0/98.0 [00:00<00:00, 317kB/s]
chat_template.jinja: 16.7kB [00:00, 28.5MB/s]
>>> messages = [
...     {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
... ]
>>> outputs = pipe(
...     messages,
...     max_new_tokens=256,
... )
Passing `generation_config` together with generation-related arguments=({'max_new_tokens'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.
Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/venv/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 299, in __call__
    return super().__call__(text_inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1264, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1271, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1163, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 403, in _forward
    output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2543, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2736, in _sample
    outputs = self._prefill(
              ^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3768, in _prefill
    return self(**model_inputs, return_dict=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 876, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 649, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 952, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/utils/output_capturing.py", line 248, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 490, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 93, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 384, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 508, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx=scatter_idx)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 411, in forward
    intermediate_cache3 = matmul_ogs(
                          ^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--gpt-oss-triton-kernels/snapshots/76c23fb9a6607cd5c62c1e6b8e7f436ec5385517/build/torch-rocm/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
              ~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'

Expected behavior

The model should load and generate text without errors.
