64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
import os
|
|
import glob
|
|
from typing import List
|
|
import json
|
|
|
|
from src.impl.datastore import Datastore, DataItem
|
|
from src.impl.indexer import Indexer
|
|
from src.impl.retriever import Retriever
|
|
from src.impl.response_generator import ResponseGenerator
|
|
from src.impl.evaluator import Evaluator
|
|
|
|
from src.RAG_pipeline import RAGpipeline
|
|
from create_parser import create_parser
|
|
|
|
# Fallback locations used by main() when the command line does not provide
# a source path or an evaluation file.
DEFAULT_SOURCE_PATH = "data/source/"
DEFAULT_EVAL_PATH = "data/eval/sample_questions.json"
|
|
|
|
def create_pipeline() -> RAGpipeline:
    """Assemble a RAGpipeline from default-constructed components.

    Returns:
        A RAGpipeline wired with a new Indexer, Datastore, Retriever,
        ResponseGenerator and Evaluator.
    """
    indexer = Indexer()
    # One Datastore instance is handed to both the Retriever and the
    # pipeline, so they operate on the same object.
    shared_store = Datastore()
    return RAGpipeline(
        indexer=indexer,
        datastore=shared_store,
        retriever=Retriever(datastore=shared_store),
        response_generator=ResponseGenerator(),
        evaluator=Evaluator(),
    )
|
|
|
|
def main():
    """Parse the command line and run the selected pipeline command.

    The parser's ``commands`` attribute selects one of ``run``, ``reset``,
    ``add``, ``evaluate`` or ``query``. The source path and the evaluation
    file fall back to DEFAULT_SOURCE_PATH and DEFAULT_EVAL_PATH when the
    corresponding argument is absent or empty.

    Any exception raised while dispatching the command (including an
    unknown command name) is reported on stdout together with its
    traceback instead of propagating.
    """
    parser = create_parser()
    args = parser.parse_args()
    pipeline = create_pipeline()

    source_path = getattr(args, "path", DEFAULT_SOURCE_PATH) or DEFAULT_SOURCE_PATH
    documents_path = get_files_in_directory(source_path=source_path)

    # getattr mirrors the handling of "path": not every subcommand is
    # guaranteed to define an eval_file option.
    eval_path = getattr(args, "eval_file", None) or DEFAULT_EVAL_PATH

    def load_sample_questions():
        # Read lazily so that only "evaluate" needs the file to exist, and
        # close the handle deterministically. JSON is UTF-8 by specification,
        # so do not depend on the platform's default encoding.
        with open(eval_path, "r", encoding="utf-8") as eval_file:
            return json.load(eval_file)

    commands = {
        "run": lambda: pipeline.run(documents_path=documents_path),
        "reset": lambda: pipeline.reset(),
        "add": lambda: pipeline.add_documents(documents_path=documents_path),
        "evaluate": lambda: pipeline.evaluate(
            sample_questions=load_sample_questions()
        ),
        "query": lambda: print(pipeline.process_query(args.prompt)),
    }

    try:
        commands[args.commands]()
    except Exception as e:
        # Top-level boundary: report the failure rather than crash the CLI.
        print(f"❌ ERREUR: {e}")
        import traceback

        traceback.print_exc()
|
|
|
|
def get_files_in_directory(source_path: str) -> List[str]:
    """Return the regular files designated by ``source_path``.

    Args:
        source_path: Path to a single file or to a directory.

    Returns:
        ``[source_path]`` when it is an existing file. Otherwise the
        non-hidden regular files directly inside ``source_path``, sorted by
        path; subdirectories are skipped. An empty list is returned when
        the directory does not exist or contains no files.
    """
    if os.path.isfile(source_path):
        return [source_path]

    # Escape the directory part so characters such as "[" in the path are
    # matched literally instead of being interpreted as glob patterns.
    pattern = os.path.join(glob.escape(source_path), "*")

    # Sorting makes the processing order independent of the filesystem.
    return sorted(entry for entry in glob.glob(pattern) if os.path.isfile(entry))
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|