diff --git a/main.py b/main.py index 6737fc4..5e0e420 100644 --- a/main.py +++ b/main.py @@ -6,15 +6,17 @@ TEST_PATH = "data/source" def main(): query_graphiste = "Quel est le salaire brut mensuel du graphiste ?" query_graphiste_en = "What is the monthly gross salary of the graphist designer ?" + query_tshirt = "Quel est le prix du Tshirt rouge avec le motif1 ?" print("Testing indexer") indexer = Indexer() - items_from_indexer = indexer.index(["data/source/database.pdf", - "data/source/bilan_comptable_2024.pdf", - "data/source/employes.pdf", + items_from_indexer = indexer.index(["data/source/informations_entreprise.pdf", + "data/source/bilan_comptable_2024.csv", + "data/source/employes.csv", "data/source/facture_14_03_2025.pdf", - "data/source/fournisseurs.pdf", - "data/source/historique_commandes.pdf", - "data/source/planning_production_mars_2025.pdf", + "data/source/fournisseurs.csv", + "data/source/historique_commandes.csv", + "data/source/planning_production_mars_2025.csv", + "data/source/stock_tshirt.csv" ]) @@ -36,11 +38,11 @@ def main(): print("Testing retriever") retriever = Retriever(datastore= datastore) - print(retriever.search_retriever(query_graphiste)) + #print(retriever.search_retriever(query_graphiste)) print("Testing Response generator") - #response_generator = ResponseGenerator() - #print(response_generator.generate_response(query_graphiste, retriever.search_retriever(query_graphiste))) + response_generator = ResponseGenerator() + print(response_generator.generate_response(query_graphiste, retriever.search_retriever(query_graphiste))) print("fin") exit diff --git a/requirements.txt b/requirements.txt index 9b4727b..1c25d83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ pydantic>=2.0.0 # For data validation lancedb==0.6.13 docling==2.31.0 cohere==5.15.0 - +requests>=2.31.0 diff --git a/src/impl/response_generator.py b/src/impl/response_generator.py index 204d06f..7b6ce97 100644 --- a/src/impl/response_generator.py +++ b/src/impl/response_generator.py @@ -1,7 +1,8 @@ -from typing import List +from typing import List, Optional from ..interface.base_response_generator import BaseResponseGenerator -import requests -import json +from groq import Groq +import os + SYSTEM_PROMPT = """Tu es un assistant intelligent qui répond aux questions en te basant sur le contexte fourni. @@ -14,13 +15,32 @@ Règles importantes: class ResponseGenerator(BaseResponseGenerator): - - def __init__(self, model_name: str = "llama3.2:3b", base_url: str = "http://localhost:11434"): - self.model_name = model_name - self.base_url = base_url + + def __init__(self, api_key: Optional[str] = None): + try : + self.api_key = api_key or os.getenv("GROQ_API_KEY") + except Exception as e: + raise ValueError(f"erreur avec la clé API: {e}") - def generate_response(self, query: str, context: List[str]) -> str: - """Génère une réponse basée sur la requête et le contexte.""" + try: + self.client = Groq(api_key=self.api_key) + self.model = "llama-3.1-8b-instant" # Rapide et gratuit + print("✅ Générateur Groq initialisé avec succès") + except Exception as e: + raise ValueError(f"❌ Erreur lors de l'initialisation de Groq: {e}") + + def generate_response(self, query: str, context: List[str], max_tokens: int = 512) -> str: + """ + Génère une réponse basée sur la requête et le contexte. + + Args: + query: Question de l'utilisateur + context: Liste de documents pertinents + max_tokens: Longueur maximale de la réponse + + Returns: + Réponse générée + """ # Formater le contexte formatted_context = "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(context)]) @@ -36,48 +56,31 @@ class ResponseGenerator(BaseResponseGenerator): # Appeler Ollama via l'API try: - response = requests.post( - f"{self.base_url}/api/generate", - json={ - "model": self.model_name, - "prompt": prompt, - "stream": False, - "options": { - "temperature": 0.7, - "top_p": 0.9, - } - } + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=max_tokens, ) - - # Vérifier le statut de la réponse - response.raise_for_status() - # Parser le JSON - result = response.json() + answer = response.choices[0].message.content.strip() - # DEBUG: Afficher la structure de la réponse - print(f"DEBUG - Structure de la réponse: {result.keys()}") + if not answer: + return "⚠️ Le modèle n'a pas pu générer de réponse." - # Vérifier les différentes clés possibles - if "response" in result: - return result["response"] - elif "message" in result: - return result["message"] - elif "content" in result: - return result["content"] - else: - # Si aucune clé attendue n'est trouvée - print(f"DEBUG - Réponse complète: {result}") - return f"Erreur: Format de réponse inattendu. Clés disponibles: {list(result.keys())}" - - except requests.exceptions.ConnectionError: - return "❌ Impossible de se connecter au serveur Ollama. Vérifiez qu'Ollama est en cours d'exécution avec: ollama serve" - - except requests.exceptions.Timeout: - return "⚠️ La génération a pris trop de temps. Essayez avec un modèle plus petit." - - except requests.exceptions.HTTPError as e: - return f"❌ Erreur HTTP {response.status_code}: {e}" + return answer except Exception as e: - return f"❌ Erreur lors de la génération: {str(e)}" \ No newline at end of file + error_msg = str(e).lower() + + # Erreurs spécifiques + if "rate" in error_msg or "limit" in error_msg: + return "⚠️ Limite de requêtes atteinte. Attendez 1 minute et réessayez." + elif "authentication" in error_msg or "api" in error_msg: + return "❌ Erreur d'authentification. Vérifiez votre clé API dans le fichier .env" + else: + return f"❌ Erreur lors de la génération: {str(e)}" + \ No newline at end of file