유사도 점수 출력 기능 추가
This commit is contained in:
46
custom_CrossEncoderReranker.py
Normal file
46
custom_CrossEncoderReranker.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
|
||||||
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEncoderReranker(BaseDocumentCompressor):
|
||||||
|
"""Document compressor that uses CrossEncoder for reranking."""
|
||||||
|
|
||||||
|
model: BaseCrossEncoder
|
||||||
|
"""CrossEncoder model to use for scoring similarity
|
||||||
|
between the query and documents."""
|
||||||
|
top_n: int = 3
|
||||||
|
"""Number of documents to return."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
arbitrary_types_allowed=True,
|
||||||
|
extra="forbid",
|
||||||
|
)
|
||||||
|
|
||||||
|
def compress_documents(
|
||||||
|
self,
|
||||||
|
documents: Sequence[Document],
|
||||||
|
query: str,
|
||||||
|
callbacks: Optional[Callbacks] = None,
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
"""
|
||||||
|
Rerank documents using CrossEncoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: A sequence of documents to compress.
|
||||||
|
query: The query to use for compressing the documents.
|
||||||
|
callbacks: Callbacks to run during the compression process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of compressed documents.
|
||||||
|
"""
|
||||||
|
scores = self.model.score([(query, doc.page_content) for doc in documents])
|
||||||
|
docs_with_scores = list(zip(documents, scores))
|
||||||
|
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||||
|
return [[doc, score] for doc, score in result[: self.top_n]]
|
||||||
Reference in New Issue
Block a user