Files
RAG/main.py

64 lines
2.0 KiB
Python

import os
import glob
from typing import List
import json
from src.impl.datastore import Datastore, DataItem
from src.impl.indexer import Indexer
from src.impl.retriever import Retriever
from src.impl.response_generator import ResponseGenerator
from src.impl.evaluator import Evaluator
from src.RAG_pipeline import RAGpipeline
from create_parser import create_parser
# Documents to ingest when the CLI does not supply a (non-empty) path argument.
DEFAULT_SOURCE_PATH: str = "data/source/"
# JSON file of evaluation questions used when the CLI does not supply an eval file.
DEFAULT_EVAL_PATH: str = "data/eval/sample_questions.json"
def create_pipeline() -> RAGpipeline:
    """Assemble a RAGpipeline from freshly constructed default components.

    The retriever and the pipeline are handed the very same ``Datastore``
    instance. Components are constructed in this order: indexer, datastore,
    retriever, response generator, evaluator.
    """
    indexer = Indexer()
    store = Datastore()
    # Keyword arguments are evaluated left to right, so the remaining
    # components are built in the order listed here.
    return RAGpipeline(
        indexer=indexer,
        datastore=store,
        retriever=Retriever(datastore=store),
        response_generator=ResponseGenerator(),
        evaluator=Evaluator(),
    )
def _load_json(path: str):
    """Parse and return the JSON document stored at *path*.

    The file is read as UTF-8 (the encoding RFC 8259 requires for JSON
    exchanged between systems) and is closed as soon as parsing finishes.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def main() -> None:
    """Parse the command line and run the selected pipeline command.

    The command name is read from ``args.commands`` and must be one of
    ``run``, ``reset``, ``add``, ``evaluate`` or ``query``. Document paths
    come from ``args.path`` (falling back to DEFAULT_SOURCE_PATH). Evaluation
    questions come from ``args.eval_file`` (falling back to DEFAULT_EVAL_PATH).
    The ``query`` command prints the answer for ``args.prompt``.

    Exits with status 2 when the command is unknown or missing. If the command
    raises, the error and its traceback are printed and the process exits with
    status 1.
    """
    parser = create_parser()
    args = parser.parse_args()
    pipeline = create_pipeline()

    source_path = getattr(args, "path", DEFAULT_SOURCE_PATH) or DEFAULT_SOURCE_PATH
    documents_path = get_files_in_directory(source_path=source_path)
    eval_path = getattr(args, "eval_file", None) or DEFAULT_EVAL_PATH

    commands = {
        "run": lambda: pipeline.run(documents_path=documents_path),
        "reset": lambda: pipeline.reset(),
        "add": lambda: pipeline.add_documents(documents_path=documents_path),
        # Load the questions only when evaluating. Other commands must not fail
        # because the eval file is absent, and a load error is reported below
        # like any other command failure.
        "evaluate": lambda: pipeline.evaluate(
            sample_questions=_load_json(eval_path)
        ),
        "query": lambda: print(pipeline.process_query(args.prompt)),
    }

    command_name = getattr(args, "commands", None)
    command = commands.get(command_name)
    if command is None:
        print(
            f"❌ ERREUR: unknown command {command_name!r} "
            f"(expected one of: {', '.join(commands)})"
        )
        raise SystemExit(2)

    try:
        command()
    except Exception as e:
        import traceback  # only needed on the error path

        print(f"❌ ERREUR: {e}")
        traceback.print_exc()
        # Report failure to the calling shell/CI instead of exiting with status 0.
        raise SystemExit(1)
def get_files_in_directory(source_path: str) -> List[str]:
    """Return the document files to ingest from *source_path*.

    If *source_path* is a file, it is returned as a one-element list.
    Otherwise it is treated as a directory, and its direct (non-recursive)
    regular files are returned in sorted order. Sub-directories and hidden
    (dot-prefixed) entries are skipped. A non-existent directory yields ``[]``.

    The directory part is passed through ``glob.escape``, so names that contain
    glob metacharacters such as ``[`` or ``*`` are matched literally.
    """
    if os.path.isfile(source_path):
        return [source_path]
    pattern = os.path.join(glob.escape(source_path), "*")
    # glob.glob returns entries in filesystem order and includes directories.
    # Keep only regular files and sort them so ingestion is reproducible.
    return sorted(path for path in glob.glob(pattern) if os.path.isfile(path))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()