import csv
import os

import streamlit as st
from dotenv import load_dotenv
from kiwipiepy import Kiwi
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ChatMessage
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

# The stock reranker returns bare Documents; the custom one also returns scores.
# from langchain.retrievers.document_compressors import CrossEncoderReranker
from custom_CrossEncoderReranker import CrossEncoderReranker

# Raise the CSV field size limit so CSVLoader can handle very large cells.
csv.field_size_limit(10000000)
load_dotenv()

EMBEDDING_MODEL = "nlpai-lab/KoE5"
RERANK_MODEL = "upskyy/ko-reranker-8k"
LOAD_DOCS = "data/fake_rag"
SAVE_VD = "fake_all_KCDS"


def embedding():
    # Korean embedding model, loaded onto the GPU.
    model_kwargs = {"device": "cuda"}
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL, model_kwargs=model_kwargs
    )
    return embeddings


def load_and_process_documents(directory_path):
    # One DirectoryLoader per supported file type.
    loaders = [
        DirectoryLoader(
            directory_path,
            glob="*.html",
            loader_cls=UnstructuredHTMLLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.csv",
            loader_cls=CSVLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader),
        DirectoryLoader(
            directory_path,
            glob="*.md",
            loader_cls=UnstructuredMarkdownLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
        DirectoryLoader(
            directory_path,
            glob="*.txt",
            loader_cls=TextLoader,
            loader_kwargs={"encoding": "utf-8"},
        ),
    ]

    all_documents = []
    for loader in loaders:
        all_documents.extend(loader.load())

    # Record the source file name (without path or extension) in the metadata.
    for doc in all_documents:
        filename = os.path.basename(doc.metadata["source"])
        filename_without_ext = os.path.splitext(filename)[0]
        doc.metadata = {"source": filename_without_ext}

    # Split by embedding similarity rather than fixed character counts.
    text_splitter = SemanticChunker(embeddings=embedding())
    texts = text_splitter.split_documents(all_documents)
    return texts


def save_faiss_index(faiss_index, path="faiss_index"):
    faiss_index.save_local(path)
    print("-------------------------------------------")
    print("save_faiss_index")
    print("-------------------------------------------")


def load_faiss_index(path="faiss_index"):
    embeddings = embedding()
    print("-------------------------------------------")
    print("load_faiss_index")
    print("-------------------------------------------")
    return FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)


kiwi = Kiwi()


@st.cache_resource
def initialize_processing():
    # Reuse a saved FAISS index when one exists; otherwise build it from scratch.
    try:
        faiss_index = load_faiss_index(SAVE_VD)
        docs = list(faiss_index.docstore._dict.values())
    except Exception:
        print("no load_faiss_index")
        docs = load_and_process_documents(LOAD_DOCS)
        embeddings = embedding()
        faiss_index = FAISS.from_documents(docs, embeddings)
        save_faiss_index(faiss_index, SAVE_VD)

    faiss_retriever = faiss_index.as_retriever(search_kwargs={"k": 3})

    # Lexical retriever: BM25 over Kiwi morpheme tokens, which works far better
    # for Korean than whitespace tokenization.
    kiwi_bm25 = BM25Retriever.from_documents(
        docs, preprocess_func=lambda text: [token.form for token in kiwi.tokenize(text)]
    )
    kiwi_bm25.k = 3

    st.session_state["retrievers"] = {
        "kiwi_bm25": kiwi_bm25,
        "faiss": faiss_retriever,
    }
    # Hybrid retrieval: EnsembleRetriever fuses the BM25 and dense results with
    # reciprocal rank fusion. (It has no search_type parameter.)
    st.session_state["Ensembleretriever"] = EnsembleRetriever(
        retrievers=[kiwi_bm25, faiss_retriever],
        weights=[0.6, 0.4],
    )

    # Rerank the fused candidates with a cross-encoder and keep the top 3.
    rerank_model = HuggingFaceCrossEncoder(model_name=RERANK_MODEL)
    compressor = CrossEncoderReranker(model=rerank_model, top_n=3)
    st.session_state["compression_retriever"] = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=st.session_state["Ensembleretriever"]
    )
    return st.session_state["retrievers"], st.session_state["compression_retriever"]


# GPT model setup
gpt_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

st.set_page_config(page_title="RAG 테스트", page_icon="💬")
st.title("RAG 테스트")

if "initialize_processing" not in st.session_state:
    st.session_state["retrievers"], st.session_state["compression_retriever"] = (
        initialize_processing()
    )
    # Mark initialization as done so this guard actually skips on reruns.
    st.session_state["initialize_processing"] = True

retrievers = st.session_state["retrievers"]
reranker = st.session_state["compression_retriever"]

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        ChatMessage(role="assistant", content="무엇을 도와드릴까요?")
    ]


# Replay the stored conversation.
def print_history(session_state):
    for msg in session_state.messages:
        st.chat_message(msg.role).write(msg.content)


# Append a turn to the stored conversation.
def add_history(session_state, role, content):
    session_state.messages.append(ChatMessage(role=role, content=content))


print_history(st.session_state)

if user_input := st.chat_input():
    add_history(st.session_state, "user", user_input)
    st.chat_message("user").write(user_input)

    with st.chat_message("assistant"):
        with st.spinner("문서 검색 및 답변 생성 중..."):
            # The custom reranker yields (Document, score) pairs, not Documents.
            docs = reranker.invoke(user_input)
            context_texts = "\n\n".join(
                [
                    f"- {doc.metadata['source']} (유사도: {score}): {doc.page_content}"
                    for doc, score in docs
                ]
            )
            prompt = f"""사용자의 질문: {user_input}

다음은 참고할 수 있는 문서입니다:
{context_texts}

위 문서의 내용을 기반으로 사용자의 질문에 대한 답변을 생성해주세요.
문서에서 관련 내용을 찾을 수 없다면, 일반적인 정보를 제공해주세요.
"""
            gpt_response = gpt_model.invoke(prompt)

    with st.chat_message("assistant"):
        add_history(st.session_state, "assistant", gpt_response.content)
        st.markdown(f"**Chat 응답:**\n\n{gpt_response.content}")
        st.markdown("---")

    # Show the retrieved documents and their rerank scores.
    for idx, (doc, score) in enumerate(docs):
        with st.chat_message("assistant"):
            st.markdown(
                f"**{idx + 1}번 문서**: {doc.metadata['source']} (유사도: {score})"
            )
            with st.expander("문서 내용 보기"):
                st.markdown(doc.page_content)
            st.markdown("---")
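
# ---------------------------------------------------------------------------
# Reference sketch (assumption): custom_CrossEncoderReranker.py is not shown
# here, and the real module may differ. The stock langchain
# CrossEncoderReranker.compress_documents returns bare Documents, while the
# app above unpacks (Document, score) pairs, so the custom compressor
# presumably keeps the cross-encoder scores, roughly as below. It is named
# _SketchCrossEncoderReranker to avoid shadowing the real import; the actual
# module would export it as CrossEncoderReranker. Note that returning pairs
# deliberately bends the BaseDocumentCompressor contract: at runtime,
# ContextualCompressionRetriever passes the pairs through unchanged, which is
# why reranker.invoke() yields tuples.
from typing import Optional, Sequence, Tuple

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document


class _SketchCrossEncoderReranker(BaseDocumentCompressor):
    """Rerank documents with a cross-encoder, keeping (Document, score) pairs."""

    model: HuggingFaceCrossEncoder
    top_n: int = 3

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Tuple[Document, float]]:
        # Score every (query, document) pair, sort best-first, keep the top_n.
        scores = self.model.score([(query, doc.page_content) for doc in documents])
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return ranked[: self.top_n]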