import abc
import logging
from pathlib import Path
from typing import Union

import pandas as pd

from autorag.schema import BaseModule
from autorag.utils import validate_qa_dataset

logger = logging.getLogger("AutoRAG")


class BasePassageReranker(BaseModule, metaclass=abc.ABCMeta):
    """Abstract base for passage reranker modules.

    Logs when a reranker is created, run and destroyed. Provides
    ``cast_to_run``, which pulls the query / retrieved-passage columns out of
    the incoming result DataFrame for concrete subclasses to rerank.
    """

    def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
        # NOTE(review): project_dir and extra args are accepted but not used
        # here; presumably consumed by subclasses — confirm.
        logger.info(
            f"Initialize passage reranker node - {self.__class__.__name__} module..."
        )

    def __del__(self):
        logger.info(
            f"Deleting passage reranker node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Extract the reranker inputs from ``previous_result``.

        :param previous_result: DataFrame that must contain the columns
            ``query``, ``retrieved_contents``, ``retrieve_scores`` and
            ``retrieved_ids``. It is also passed through
            ``validate_qa_dataset`` (whose exact checks are defined elsewhere).
        :return: A tuple ``(queries, contents, scores, ids)``. Each element is
            the corresponding column converted to a Python list.
        :raises AssertionError: If any required column is missing.
        """
        logger.info(
            f"Running passage reranker node - {self.__class__.__name__} module..."
        )
        validate_qa_dataset(previous_result)

        # Explicit raise instead of `assert`, so the checks still run under
        # `python -O`. AssertionError is kept for backward compatibility with
        # callers/tests that catch it. The order matches the original checks.
        required_columns = (
            "query",
            "retrieved_contents",
            "retrieve_scores",
            "retrieved_ids",
        )
        for column in required_columns:
            if column not in previous_result.columns:
                raise AssertionError(f"previous_result must have {column} column.")

        queries = previous_result["query"].tolist()
        contents = previous_result["retrieved_contents"].tolist()
        scores = previous_result["retrieve_scores"].tolist()
        ids = previous_result["retrieved_ids"].tolist()
        return queries, contents, scores, ids