Skip to content

使用 FastChat 框架时报错：ValueError: min() arg is an empty sequence #199

@thunderbolt-fire

Description

@thunderbolt-fire

报错如下

(flashrag) (base) root@5e17ee1d7b71:~/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/RADIO# python run_rag.py  --dataset=nq
Loading test dataset from: dataset/nq/test.jsonl...
[2025-07-09 18:44:39,495] [WARNING] [real_accelerator.py:209:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.
[2025-07-09 18:44:39,497] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cpu (auto detect)
[2025-07-09 18:44:41,762] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Traceback (most recent call last):
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/RADIO/run_rag.py", line 103, in <module>
    main(args)
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/RADIO/run_rag.py", line 42, in main
    generator = get_generator(args.config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/flashrag/utils/utils.py", line 62, in get_generator
    return getattr(importlib.import_module("flashrag.generator"), "FastChatGenerator")(config, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/flashrag/generator/generator.py", line 527, in __init__
    super().__init__(config)
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/flashrag/generator/generator.py", line 280, in __init__
    self.model, self.tokenizer = self._load_model(model=model)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/flashrag/generator/generator.py", line 558, in _load_model
    max_gpu_memory = str(int(min(available_gpu_memory) * gpu_memory_utilization)) + "GiB"
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: min() arg is an empty sequence

文件如下

import argparse
import os
from flashrag.config import Config
from flashrag.utils import get_dataset
from flashrag.pipeline import SequentialPipeline
from flashrag.prompt import PromptTemplate
from flashrag.utils import get_generator
from transformers import AutoTokenizer
from utils.classes import CustomSequentialPipeline, mmluPipeline

def main(args):
    if args.dataset == 'mmlu':
        print('dataset is mmlu')
        all_split = get_dataset(args.config)
        test_data = all_split["test"]
        prompt_templete = PromptTemplate(
            args.config,
            system_prompt="Answer the question based on the given document. \
                            Only give me the option (A/B/C/D) and do not output any other words. \
                            \nThe following are given documents.\n\n{reference}",
            user_prompt="Question: {question}\nAnswer:",
        )
        generator = get_generator(args.config)

        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        tokenizer.pad_token = "<lendoftext|>"
        tokenizer.pad_token_id = 128009
        tokenizer.padding_side = "left"
        generator.tokenizer = tokenizer
        pipeline = mmluPipeline(args.config, prompt_template=prompt_templete, generator=generator)
        output_dataset = pipeline.run(test_data, do_eval=True, pred_process_fun=None)
    else:
        all_split = get_dataset(args.config)
        test_data = all_split["test"]
        prompt_templete = PromptTemplate(
            args.config,
            system_prompt="Answer the question based on the given document. \
                            Only give me the answer and do not output any other words. \
                            \nThe following are given documents.\n\n{reference}",
            user_prompt="Question: {question}\nAnswer:",
        )
        generator = get_generator(args.config)

        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        tokenizer.pad_token = "<lendoftext|>"
        tokenizer.pad_token_id = 128009
        tokenizer.padding_side = "left"
        generator.tokenizer = tokenizer

        pipeline = SequentialPipeline(args.config, prompt_template=prompt_templete, generator=generator)

        output_dataset = pipeline.run(test_data, do_eval=True)
        # output_dataset.save(os.path.join(args.config['save_dir'], 'intermediate_data.json'))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/models/Meta-Llama-3-8B-Instruct')
    parser.add_argument("--retriever_path", type=str, default='/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/FlashRAG/models/e5-base-v2')
    parser.add_argument("--rerank_model_name", type=str, default="bge-reranker-base")
    parser.add_argument("--rerank_model_path", type=str, default='/root/.cache/huggingface/hub/models--BAAI--bge-reranker-base/snapshots/2cfc18c9415c912f9d8155881c133215df768a70')
    parser.add_argument("--gpu_id", type=int, default=3)
    parser.add_argument("--max_tokens", type=int, default=512)
    parser.add_argument("--dataset", type=str, default="nq")
    parser.add_argument("--framework", type=str, default="fschat")
    parser.add_argument("--openai_model", type=str, default="gpt-4o-mini")
    args = parser.parse_args()

    config_dict = {
        "data_dir": "dataset/",
        "index_path": "indexes/e5_Flat.index",
        "corpus_path": "indexes/retrieval-corpus/wiki-18.jsonl",
        "model2path": {
            # "e5": args.retriever_path, 
            # "llama3-8B-instruct": args.model_path
            },
        # "generator_model": "llama3-8B-instruct",
        "framework": args.framework,
        "generator_model": "llama3-8B-instruct",
        "openai_setting": {
            "api_key": "sk-xxx",
            "base_url": "https://xxx.xxx.xxx"
        },
        # "retrieval_method": "e5",
        "retrieval_method": args.retriever_path,
        "metrics": ["em", "f1"],
        "retrieval_topk": 20,
        "save_intermediate_data": True,
        "dataset_name": args.dataset,
        # "test_sample_num": 100,
        "gpu_id": args.gpu_id,
        "generation_params": {
            "max_tokens": args.max_tokens,
        },
        "use_reranker": True,
        "rerank_model_name": "bge-reranker-base",
        "rerank_model_path": args.rerank_model_path,
        "rerank_topk": 5,
        "rerank_batch_size": 128,
        "save_retrieval_cache": False,
    }

    args.config = Config(config_file_path='/root/siton-data-0553377b2d664236bad5b5d0ba8aa419/workspace/RADIO/utils/basic_config.yaml',config_dict=config_dict)
    main(args)

generator = get_generator(args.config)

报错发生在这一行（即 traceback 中 `run_rag.py` 第 42 行，位于 `main()` 的 `else` 分支）。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions