From df5e2350d6c4430b87ab9522e5e23e8c0ecfe80c Mon Sep 17 00:00:00 2001 From: chan Date: Mon, 10 Mar 2025 10:55:46 +0900 Subject: [PATCH] =?UTF-8?q?=EC=9C=A0=EC=82=AC=EB=8F=84=20=EC=A0=90?= =?UTF-8?q?=EC=88=98=20=EC=B6=9C=EB=A0=A5=20=EA=B8=B0=EB=8A=A5=20=EC=B6=94?= =?UTF-8?q?=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom_CrossEncoderReranker.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 custom_CrossEncoderReranker.py diff --git a/custom_CrossEncoderReranker.py b/custom_CrossEncoderReranker.py new file mode 100644 index 0000000..3af0873 --- /dev/null +++ b/custom_CrossEncoderReranker.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import operator +from typing import Optional, Sequence + +from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from pydantic import ConfigDict + + +class CrossEncoderReranker(BaseDocumentCompressor): + """Document compressor that uses CrossEncoder for reranking.""" + + model: BaseCrossEncoder + """CrossEncoder model to use for scoring similarity + between the query and documents.""" + top_n: int = 3 + """Number of documents to return.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """ + Rerank documents using CrossEncoder. + + Args: + documents: A sequence of documents to compress. + query: The query to use for compressing the documents. + callbacks: Callbacks to run during the compression process. + + Returns: + A sequence of compressed documents. + """ + scores = self.model.score([(query, doc.page_content) for doc in documents]) + docs_with_scores = list(zip(documents, scores)) + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) + return [[doc, score] for doc, score in result[: self.top_n]]