first commit
main.py (200 lines, new file)
@@ -0,0 +1,200 @@
import csv
import os

import streamlit as st
from dotenv import load_dotenv
from kiwipiepy import Kiwi
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ChatMessage
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings

# ChatOpenAI lives in the langchain_openai package; the old
# langchain.chat_models re-export is deprecated and removed in recent releases.
from langchain_openai import ChatOpenAI

# Allow very large CSV fields (the stdlib default is 128 KB).
csv.field_size_limit(10000000)
load_dotenv()

EMBEDDING_MODEL = "nlpai-lab/KoE5"
RERANK_MODEL = "upskyy/ko-reranker-8k"
LOAD_DOCS = "data/fake_rag"
SAVE_VD = "fake_all"
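# NOTE: both model IDs are Hugging Face checkpoints, a Korean embedding model
# (KoE5) and a Korean cross-encoder reranker (ko-reranker-8k); they are
# downloaded on first use, so the first run needs network access.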


def embedding():
    model_kwargs = {"device": "cuda"}
    model_name = EMBEDDING_MODEL
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    return embeddings
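
# A CPU fallback sketch, assuming torch is importable (it ships with
# sentence-transformers, which HuggingFaceEmbeddings wraps); "cuda" alone
# crashes on machines without a GPU:
#   import torch
#   model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}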


def load_and_process_documents(directory_path):
    loaders = [
        DirectoryLoader(
            directory_path,
            glob="*.html",
            loader_cls=UnstructuredHTMLLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.csv",
            loader_cls=CSVLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader),
        DirectoryLoader(
            directory_path,
            glob="*.md",
            loader_cls=UnstructuredMarkdownLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
    ]

    all_documents = []
    for loader in loaders:
        loaded_docs = loader.load()
        all_documents.extend(loaded_docs)

    # Record the source filename (without extension) in the metadata.
    for doc in all_documents:
        filename = os.path.basename(doc.metadata["source"])
        filename_without_ext = os.path.splitext(filename)[0]
        doc.metadata = {"source": filename_without_ext}

    text_splitter = SemanticChunker(embeddings=embedding())
    texts = text_splitter.split_documents(all_documents)
    return texts
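
# NOTE: SemanticChunker (langchain_experimental) picks split points from
# embedding-similarity breakpoints between sentences rather than a fixed
# character count, so the whole corpus is embedded during chunking; this is
# the slow step at startup.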


def save_faiss_index(faiss_index, path="faiss_index"):
    faiss_index.save_local(path)


def load_faiss_index(path="faiss_index"):
    embeddings = embedding()
    return FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
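
# FAISS.load_local unpickles the stored docstore, hence
# allow_dangerous_deserialization=True; only load index directories written by
# save_faiss_index above.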


kiwi = Kiwi()


@st.cache_resource
def initialize_processing():
    processed_texts = load_and_process_documents(LOAD_DOCS)
    docs = processed_texts
    try:
        faiss_index = load_faiss_index(SAVE_VD)
    except Exception:
        embeddings = embedding()
        faiss_index = FAISS.from_documents(docs, embeddings)
        save_faiss_index(faiss_index, SAVE_VD)
    faiss = faiss_index.as_retriever(search_kwargs={"k": 3})
    # BM25 over Kiwi morphemes: whitespace tokenization works poorly for
    # agglutinative Korean, so each text is split into morphological tokens.
    kiwi_bm25 = BM25Retriever.from_documents(
        docs, preprocess_func=lambda text: [token.form for token in kiwi.tokenize(text)]
    )
    kiwi_bm25.k = 3
    st.session_state["retrievers"] = {
        "kiwi_bm25": kiwi_bm25,
        "faiss": faiss,
    }
    # EnsembleRetriever fuses the two ranked lists with weighted reciprocal
    # rank fusion; it has no search_type parameter.
    st.session_state["ensemble_retriever"] = EnsembleRetriever(
        retrievers=[kiwi_bm25, faiss],
        weights=[0.6, 0.4],
    )
    rerank_model = HuggingFaceCrossEncoder(model_name=RERANK_MODEL)
    compressor = CrossEncoderReranker(model=rerank_model, top_n=3)
    st.session_state["compression_retriever"] = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=st.session_state["ensemble_retriever"],
    )
    return st.session_state["retrievers"], st.session_state["compression_retriever"]
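
# Retrieval pipeline: kiwi_bm25 (lexical) and faiss (dense) each propose 3
# candidates, EnsembleRetriever merges them (0.6 lexical / 0.4 dense), and the
# ko-reranker-8k cross-encoder rescores the merged list down to the top 3.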


# Configure the GPT model.
gpt_model = ChatOpenAI(model_name="gpt-4", temperature=0)

st.set_page_config(page_title="RAG 테스트", page_icon="💬")
st.title("RAG 테스트")

if "initialize_processing" not in st.session_state:
    st.session_state["retrievers"], st.session_state["compression_retriever"] = (
        initialize_processing()
    )
    st.session_state["initialize_processing"] = True
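# @st.cache_resource already memoizes initialize_processing() across reruns;
# the session_state flag only skips re-running the assignments above.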

retrievers = st.session_state["retrievers"]
reranker = st.session_state["compression_retriever"]

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        ChatMessage(role="assistant", content="무엇을 도와드릴까요?")
    ]


# Print the chat history.
def print_history(st, session_state):
    for msg in session_state.messages:
        st.chat_message(msg.role).write(msg.content)


# Append a message to the chat history.
def add_history(session_state, role, content):
    session_state.messages.append(ChatMessage(role=role, content=content))


print_history(st, st.session_state)

if user_input := st.chat_input():
    add_history(st.session_state, "user", user_input)
    st.chat_message("user").write(user_input)

    with st.chat_message("assistant"):
        chat_container = st.empty()  # placeholder reserved for streamed output (currently unused)

        with st.spinner("문서 검색 및 답변 생성 중..."):
            # ContextualCompressionRetriever.invoke returns a plain list of
            # Documents; CrossEncoderReranker keeps only its top_n and attaches
            # no relevance scores.
            docs = reranker.invoke(user_input)

            context_texts = "\n\n".join(
                [f"- {doc.metadata['source']}: {doc.page_content}" for doc in docs]
            )

            prompt = f"""사용자의 질문: {user_input}
다음은 참고할 수 있는 문서입니다:
{context_texts}
위 문서의 내용을 기반으로 사용자의 질문에 대한 답변을 생성해주세요.
문서에서 관련 내용을 찾을 수 없다면, 일반적인 정보를 제공해주세요.
"""
            gpt_response = gpt_model.invoke(prompt)
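
            # If per-document scores are wanted, one sketch (an assumption, not
            # part of this commit) is to score pairs with the cross-encoder
            # directly, e.g. HuggingFaceCrossEncoder.score(), which maps a list
            # of (query, passage) pairs to one float each.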

    with st.chat_message("assistant"):
        add_history(st.session_state, "assistant", gpt_response.content)
        st.markdown(f"**Chat 응답:**\n\n{gpt_response.content}")
        st.markdown("---")

    # Display the retrieved documents.
    for idx, doc in enumerate(docs):
        with st.chat_message("assistant"):
            st.markdown(f"**{idx + 1}번 문서**: {doc.metadata['source']}")
            with st.expander("문서 내용 보기"):
                st.markdown(doc.page_content)
            st.markdown("---")
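
# To run (assumes OPENAI_API_KEY is set in the .env file read by load_dotenv):
#   streamlit run main.py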