update eval scripts
Browse files
- eval/cmteb_eval.py +24 -0
- eval/cmteb_eval.sh +1 -0
- eval/retrieval_eval.py +106 -0
- eval/retrieval_eval.sh +17 -0
eval/cmteb_eval.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from mteb import MTEB
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("main")

# Task names, one list per task category.
CLASSIFICATION_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
]
STS_LIST = [
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
]
PAIRCLASSIFICATION_LIST = [
    "Ocnli",
    "Cmnli",
]
RERANKING_LIST = [
    "T2Reranking",
    "MmarcoReranking",
    "CMedQAv1",
    "CMedQAv2",
]
CLUSTERING_LIST = [
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
]

# TASK_LIST and names are parallel: names[i] is the results sub-folder used
# for the tasks in TASK_LIST[i].
TASK_LIST = [
    CLASSIFICATION_LIST,
    STS_LIST,
    PAIRCLASSIFICATION_LIST,
    RERANKING_LIST,
    CLUSTERING_LIST,
]
names = [
    'Classification',
    'STS',
    'Pairclassification',
    'Reranking',
    'Clustering',
]

model = SentenceTransformer('piccolo-base-zh')

# Each task is run through its own MTEB instance; all tasks of one category
# share the output folder results/<category name>.
for name, task_list in zip(names, TASK_LIST):
    output_folder = f"results/{name}"
    for task in task_list:
        logger.info(f"Running task: {task}")
        evaluation = MTEB(tasks=[task])
        evaluation.run(model, output_folder=output_folder)
|
eval/cmteb_eval.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Run the CMTEB evaluation; cmteb_eval.py is resolved relative to the current directory.
python cmteb_eval.py
|
eval/retrieval_eval.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''this eval code is borrowed from E5'''
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import tqdm
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
from datasets import Dataset
|
| 10 |
+
from typing import List, Dict
|
| 11 |
+
from functools import partial
|
| 12 |
+
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from mteb import MTEB, AbsTaskRetrieval, DRESModel
|
| 16 |
+
|
| 17 |
+
from utils import pool, logger, move_to_cuda
|
| 18 |
+
|
| 19 |
+
# Command-line options for the retrieval evaluation. They are parsed once at
# import time and read through the module-level ``args`` by the code below.
parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='bert-base-uncased',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--pool-type', default='avg', help='pool type')
# type=int keeps a value given on the command line the same type as the
# default (previously a CLI value stayed a str while the default was an int).
parser.add_argument('--max-length', default=512, type=int, help='max length')

args = parser.parse_args()
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
# Validate explicitly rather than with ``assert`` so the checks still run
# under ``python -O``; parser.error prints the message and exits.
if args.pool_type not in ['cls', 'avg']:
    parser.error('pool_type should be cls or avg')
if not args.output_dir:
    parser.error('output_dir should be set')
os.makedirs(args.output_dir, exist_ok=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize the ``contents`` entry of a batch of examples.

    Args:
        tokenizer: Callable tokenizer applied to the texts.
        examples: Batch mapping; only ``examples['contents']`` is read.

    Returns:
        Whatever ``tokenizer`` returns for the texts, truncated to the
        module-level ``args.max_length`` and padded, without token type ids.
    """
    tokenize_options = {
        # args.max_length may arrive as a string from the CLI, hence int().
        'max_length': int(args.max_length),
        'padding': True,
        'return_token_type_ids': False,
        'truncation': True,
    }
    return tokenizer(examples['contents'], **tokenize_options)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RetrievalModel(DRESModel):
    """Dense retrieval encoder exposing ``encode_queries`` / ``encode_corpus``.

    The model and tokenizer are loaded from the module-level
    ``args.model_name_or_path``. See DRESModel for the methods to overwrite.
    """

    def __init__(self, **kwargs):
        """Load encoder and tokenizer, move the encoder to CUDA, set eval mode.

        ``kwargs`` is accepted but not used.
        """
        model_path = args.model_name_or_path
        encoder = AutoModel.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.gpu_count = torch.cuda.device_count()
        # Wrap in DataParallel only when more than one CUDA device is reported.
        if self.gpu_count > 1:
            encoder = torch.nn.DataParallel(encoder)
        self.encoder = encoder

        self.encoder.cuda()
        self.encoder.eval()

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """Encode queries, each prefixed with the query marker string."""
        prefixed = list(map('查询: {}'.format, queries))
        return self._do_encode(prefixed)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        """Encode documents built from optional ``title`` plus ``text``.

        Title and text are joined with a space and stripped, then the
        document marker string is prepended.
        """
        prefixed = []
        for doc in corpus:
            merged = '{} {}'.format(doc.get('title', ''), doc['text']).strip()
            prefixed.append('结果: {}'.format(merged))
        return self._do_encode(prefixed)

    @torch.no_grad()
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        """Embed ``input_texts`` in batches and return one stacked array.

        Texts are tokenized lazily through ``_transform_func``, batched with
        128 items per reported CUDA device, run through the encoder under
        CUDA autocast, pooled with the ``pool`` helper according to
        ``args.pool_type``, and the per-batch arrays are concatenated along
        axis 0.
        """
        text_dataset: Dataset = Dataset.from_dict({'contents': input_texts})
        text_dataset.set_transform(partial(_transform_func, self.tokenizer))

        loader = DataLoader(
            text_dataset,
            batch_size=128 * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            collate_fn=DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8),
            pin_memory=True)

        chunks = []
        for batch in tqdm.tqdm(loader, desc='encoding', mininterval=10):
            # move_to_cuda comes from utils; presumably it moves the batch
            # tensors to the GPU -- confirm against that helper.
            batch = move_to_cuda(batch)

            with torch.cuda.amp.autocast():
                model_out: BaseModelOutput = self.encoder(**batch)
                pooled = pool(model_out.last_hidden_state, batch['attention_mask'], args.pool_type)
                chunks.append(pooled.cpu().numpy())

        return np.concatenate(chunks, axis=0)
|
| 90 |
+
|
| 91 |
+
# Names of the retrieval tasks that main() hands to MTEB.
TASKS = ["T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval", "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval"]
|
| 92 |
+
def main():
    """Run every task in ``TASKS`` with a ``RetrievalModel``.

    Task names are resolved through MTEB first, then each task is evaluated
    in its own MTEB run writing to ``args.output_dir`` without overwriting
    existing results.
    """
    assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel)
    model = RetrievalModel()

    task_names = []
    for resolved_task in MTEB(tasks=TASKS).tasks:
        task_names.append(resolved_task.description["name"])
    logger.info('Tasks: {}'.format(task_names))

    for task in task_names:
        logger.info('Processing task: {}'.format(task))
        MTEB(tasks=[task]).run(model, output_folder=args.output_dir, overwrite_results=False)


if __name__ == '__main__':
    main()
|
eval/retrieval_eval.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run retrieval_eval.py for the configured model with average pooling.
# Any extra command-line arguments are forwarded to the Python script.

set -x
set -e

# Parent of the directory containing this script. It is only reported here;
# the script does not change into it.
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
echo "working directory: ${DIR}"

# Directory containing this script. retrieval_eval.py lives next to it, so
# invoking it through this path lets the wrapper be started from any directory
# instead of only from inside eval/.
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"

MODEL_NAME_OR_PATH="piccolo-base-zh"
OUTPUT_DIR='Retrieval'

# The output directory stays relative to the caller's current directory.
mkdir -p "${OUTPUT_DIR}"

python -u "${SCRIPT_DIR}/retrieval_eval.py" \
  --model-name-or-path "${MODEL_NAME_OR_PATH}" \
  --pool-type avg \
  --output-dir "${OUTPUT_DIR}" "$@"
|