Tune system prompt to avoid asking questions at the end of the output. Update debug thinking to include the LLM model used for that chat.
"""
|
|
title: Obsidian RAG Pipeline
|
|
author: Daniel Henry
|
|
version: 0.15
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import urllib.parse
|
|
from typing import AsyncGenerator
|
|
|
|
import aiohttp
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class Pipe:

    class Valves(BaseModel):
        # Endpoints
        ollama_url: str = Field(default="http://ollama.internal.henryhosted.com:11434")
        qdrant_url: str = Field(default="http://app-01.internal.henryhosted.com:6333")
        rerank_url: str = Field(default="http://ollama.internal.henryhosted.com:7997")

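        # NOTE: ollama_url serves /api/generate, /api/embeddings, and /api/chat;
        # qdrant_url serves the Qdrant REST API; rerank_url is a separate
        # reranking service on port 7997 (presumably an Infinity-style server)
        # despite sharing the "ollama" hostname.
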
        # Qdrant
        collection_name: str = Field(default="obsidian_vault")
        retrieve_count: int = Field(
            default=50, description="Candidates to fetch from Qdrant"
        )
        qdrant_score_threshold: float = Field(
            default=0.3, description="Minimum similarity score"
        )

        # Reranker
        rerank_enabled: bool = Field(
            default=True, description="Set to False to skip reranking"
        )
        rerank_timeout: float = Field(default=60.0)
        min_rerank_score: float = Field(
            default=0.01, description="Minimum rerank score to keep"
        )
        final_top_k: int = Field(
            default=10, description="Chunks to keep after reranking"
        )

        # LLM
        embedding_model: str = Field(default="nomic-embed-text:latest")
        llm_model: str = Field(default="llama3.2:3b")
        llm_context_size: int = Field(default=8192)
        llm_timeout: float = Field(default=300.0)
        query_rewrite_model: str = Field(
            default="",
            description="Model for query rewriting. Leave empty to use llm_model.",
        )

        # Obsidian
        vault_name: str = Field(
            default="Main", description="For generating obsidian:// links"
        )

        # Display
        show_thinking: bool = Field(default=True)
        show_sources: bool = Field(default=True)
        show_stats: bool = Field(default=True)
        token_warning_threshold: int = Field(
            default=6000, description="Warn if context exceeds this"
        )

    def __init__(self):
        self.valves = self.Valves()

    async def pipe(self, body: dict) -> AsyncGenerator[str, None]:
        messages = body.get("messages", [])
        if not messages:
            yield "No messages provided."
            return

        query = messages[-1].get("content", "").strip()
        if not query:
            yield "Empty query."
            return

        async with aiohttp.ClientSession() as session:
            async for chunk in self._execute(session, query, messages):
                yield chunk

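    # _execute drives the full RAG flow for one user turn:
    # rewrite -> embed -> Qdrant search -> optional rerank -> build context ->
    # stream the LLM answer, then append sources and token stats.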
    async def _execute(
        self,
        session: aiohttp.ClientSession,
        query: str,
        messages: list[dict],
    ) -> AsyncGenerator[str, None]:

        think = self.valves.show_thinking

        # Start thinking block immediately
        if think:
            yield "<think>\n"
            yield f"**LLM Model:** {self.valves.llm_model}\n\n"
            yield f"**Query:** {query}\n\n"

        # ─────────────────────────────────────────────
        # Step 1: Rewrite query with conversation context
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 1: Query Rewriting**\n"

        t0 = time.time()
        rewrite_model = self.valves.query_rewrite_model or self.valves.llm_model

        # Build conversation context for rewriting
        conversation_for_rewrite = []
        for m in messages[:-1]:  # All messages except the last one
            role = m.get("role", "")
            content = m.get("content", "")
            if role == "user":
                conversation_for_rewrite.append(f"User: {content}")
            elif role == "assistant":
                # Truncate assistant responses to avoid bloat
                truncated = content[:500] + "..." if len(content) > 500 else content
                conversation_for_rewrite.append(f"Assistant: {truncated}")

        current_question = messages[-1].get("content", "")

        # If there's prior conversation, rewrite the query
        if conversation_for_rewrite:
            rewrite_prompt = f"""Do not interpret or answer the question. Simply add enough context from the conversation so the question makes sense on its own.

Conversation:
{chr(10).join(conversation_for_rewrite)}

Latest question: {current_question}

Rewrite the question to be standalone (respond with ONLY the rewritten question, nothing else):"""

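            # With "stream": False, Ollama's /api/generate returns a single JSON
            # object; the rewritten question is read from its "response" field.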
            try:
                async with session.post(
                    f"{self.valves.ollama_url}/api/generate",
                    json={
                        "model": rewrite_model,
                        "prompt": rewrite_prompt,
                        "stream": False,
                    },
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        rewritten = data.get("response", "").strip()
                        # Sanity check - if rewrite is empty or way too long, use original
                        if rewritten and len(rewritten) < 1000:
                            search_query = rewritten
                        else:
                            search_query = current_question
                    else:
                        search_query = current_question
            except Exception as e:
                if think:
                    yield f" ⚠ Rewrite failed: {e}, using original query\n"
                search_query = current_question
        else:
            # No prior conversation, use the question as-is
            search_query = current_question

        if think:
            yield f" Model: {rewrite_model}\n"
            yield f" Original: {current_question}\n"
            yield f" Search query: {search_query}\n"
            yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"

        # ─────────────────────────────────────────────
        # Step 2: Embed
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 2: Embedding**\n"
        t0 = time.time()

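        # Ollama's /api/embeddings returns {"embedding": [...]}; the vector's
        # dimensionality must match the Qdrant collection searched below.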
        try:
            async with session.post(
                f"{self.valves.ollama_url}/api/embeddings",
                json={
                    "model": self.valves.embedding_model,
                    "prompt": search_query,
                    "options": {"num_ctx": 8192},
                },
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    if think:
                        yield f" ✗ HTTP {resp.status}\n</think>\n\n"
                    yield f"Embedding failed: HTTP {resp.status}"
                    return
                embedding = (await resp.json()).get("embedding")
        except Exception as e:
            if think:
                yield f" ✗ {e}\n</think>\n\n"
            yield f"Embedding failed: {e}"
            return

        if think:
            yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"

        # ─────────────────────────────────────────────
        # Step 3: Search Qdrant
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 3: Qdrant Search**\n"
        t0 = time.time()

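        # Qdrant's REST search returns {"result": [...]}, where each hit carries a
        # "score" and a "payload"; payloads are expected to hold the fileName /
        # filePath / content / source keys written by the vault indexer.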
        try:
            async with session.post(
                f"{self.valves.qdrant_url}/collections/{self.valves.collection_name}/points/search",
                json={
                    "vector": embedding,
                    "limit": self.valves.retrieve_count,
                    "with_payload": True,
                    "score_threshold": self.valves.qdrant_score_threshold,
                },
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    if think:
                        yield f" ✗ HTTP {resp.status}\n</think>\n\n"
                    yield f"Qdrant search failed: HTTP {resp.status}"
                    return
                qdrant_results = (await resp.json()).get("result", [])
        except Exception as e:
            if think:
                yield f" ✗ {e}\n</think>\n\n"
            yield f"Qdrant search failed: {e}"
            return

        if think:
            yield f" ✓ Found {len(qdrant_results)} chunks ({time.time() - t0:.2f}s)\n"

        if not qdrant_results:
            if think:
                yield " ✗ No results\n</think>\n\n"
            yield "No relevant notes found for this query."
            return

        # Show top 5
        if think:
            yield " Top 5:\n"
            for i, r in enumerate(qdrant_results[:5]):
                name = r.get("payload", {}).get("fileName", "?")
                score = r.get("score", 0)
                yield f" {i+1}. [{score:.4f}] {name}\n"
            yield "\n"

        # ─────────────────────────────────────────────
        # Step 4: Rerank (optional)
        # ─────────────────────────────────────────────
        if self.valves.rerank_enabled:
            if think:
                yield "**Step 4: Reranking**\n"
            t0 = time.time()

            docs_for_rerank = [
                r.get("payload", {}).get("content", "") for r in qdrant_results
            ]

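            # The reranking service is assumed to expose an Infinity-style
            # POST /rerank taking {"query", "documents", "return_documents"} and
            # answering with {"results": [{"index": ..., "relevance_score": ...}]},
            # which is how the response is parsed below.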
            try:
                async with session.post(
                    f"{self.valves.rerank_url}/rerank",
                    json={
                        "query": search_query,
                        "documents": docs_for_rerank,
                        "return_documents": False,
                    },
                    timeout=aiohttp.ClientTimeout(total=self.valves.rerank_timeout),
                ) as resp:
                    if resp.status != 200:
                        if think:
                            yield f" ⚠ Reranker failed: HTTP {resp.status}, using Qdrant order\n\n"
                        chunks = qdrant_results[: self.valves.final_top_k]
                    else:
                        rerank_results = (await resp.json()).get("results", [])

                        # Apply rerank scores and filter
                        scored = []
                        for item in rerank_results:
                            idx = item["index"]
                            score = item["relevance_score"]
                            if score >= self.valves.min_rerank_score:
                                chunk = qdrant_results[idx].copy()
                                chunk["rerank_score"] = score
                                scored.append(chunk)

                        scored.sort(key=lambda x: x["rerank_score"], reverse=True)
                        chunks = scored[: self.valves.final_top_k]

                        if think:
                            yield f" ✓ Kept {len(chunks)} chunks ({time.time() - t0:.2f}s)\n"

                            if chunks:
                                yield " Top 5 after rerank:\n"
                                for i, c in enumerate(chunks[:5]):
                                    name = c.get("payload", {}).get("fileName", "?")
                                    score = c.get("rerank_score", 0)
                                    yield f" {i+1}. [{score:.4f}] {name}\n"
                            yield "\n"

            except Exception as e:
                if think:
                    yield f" ⚠ Reranker error: {e}, using Qdrant order\n\n"
                chunks = qdrant_results[: self.valves.final_top_k]

        else:
            if think:
                yield "**Step 4: Reranking** (disabled)\n\n"
            chunks = qdrant_results[: self.valves.final_top_k]

        if not chunks:
            if think:
                yield " ✗ No chunks after filtering\n</think>\n\n"
            yield "No relevant notes passed the relevance threshold."
            return

        # ─────────────────────────────────────────────
        # Step 5: Build context
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 5: Build Context**\n"

        context_parts = []
        for i, chunk in enumerate(chunks, 1):
            payload = chunk.get("payload", {})
            file_name = payload.get("fileName", "Unknown")
            content = payload.get("content", "").strip()
            source = payload.get("source", "")

            part = f"### Note Name {i}: {file_name}\n"
            if source:
                part += f"Original source: {source}\n"
            part += f"\n{content}"
            context_parts.append(part)

        context = "\n\n---\n\n".join(context_parts)
        context_chars = len(context)
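        # Rough estimate: ~4 characters per token for English prose.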
        estimated_tokens = context_chars // 4

        if think:
            yield f" ✓ {len(chunks)} chunks, {context_chars:,} chars (~{estimated_tokens:,} tokens)\n"

            if estimated_tokens > self.valves.token_warning_threshold:
                yield f" ⚠ Warning: approaching context limit ({self.valves.llm_context_size})\n"

            yield "\n"

        # ─────────────────────────────────────────────
        # Step 6: Build prompt and call LLM
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 6: Generate Response**\n"
            yield "</think>\n\n"

        system_prompt = f"""
### ROLE
You are the user's "Knowledge Partner." You are warm, enthusiastic, and helpful. You love the user's notes and want to help them connect ideas.

### THE GOLDEN RULE (HARD WALL)
Your knowledge is strictly limited to the provided <notes>.
- IF the answer is in the notes: Synthesize it warmly and cite it.
- IF the answer is NOT in the notes: You must admit it. Say: "I checked your notes, but I couldn't find info on that."
- Be honest with the user. The user does not want blind support. You are a friendly research assistant, not an overly supportive friend.
- **EXCEPTION:** ONLY if the user explicitly types the trigger phrase "System: Add Context" are you allowed to use outside knowledge.

### INSTRUCTIONS
1. **Search First:** Look through the <notes> to find the answer.
2. **Synthesize:** You may combine facts from different notes to build a complete answer.
3. **Cite Everything:** Every single statement of fact must end with a citation in this format: `[Note Name]`.
4. **Tone:** Be conversational but professional. Avoid robotic phrases like "According to the provided text." Instead, say "Your note on [Topic] mentions..."
5. **Additional:** Avoid asking follow-up questions at the end of your output.

### EXAMPLES (Follow this pattern)

**User:** "What did I write about the project deadline?"
**You:** "I looked through your project logs! It seems you set the final submission date for October 15th [Project_Alpha_Log]. You also noted that the design phase needs to wrap up by the 1st [Design_Team_Meeting]."

**User:** "Who is the president of France?" (Note: This is NOT in your notes)
**You:** "I checked your notes, but I don't see any mention of the current president of France. Would you like me to use outside knowledge? If so, just say 'System: Add Context'."

### SOURCE NOTES
<notes>
{context}
</notes>
"""

        # Only keep user/assistant messages, and robustly strip any previously
        # rendered "Sources" section so the model doesn't pattern-match on it.
        conversation = []
        for m in messages:
            if m.get("role") not in ("user", "assistant"):
                continue

            msg = m.copy()
            if msg["role"] == "assistant":
                content = msg.get("content", "")
                # Split on "**Sources:**" which is the visible header.
                # This catches it even if the newlines/separators are slightly different.
                if "**Sources:**" in content:
                    msg["content"] = content.split("**Sources:**")[0].strip()

            conversation.append(msg)

        llm_payload = {
            "model": self.valves.llm_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                *conversation,
            ],
            "stream": True,
            "options": {"num_ctx": self.valves.llm_context_size},
        }

        # Stream LLM response
        prompt_tokens = 0
        completion_tokens = 0

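        # /api/chat with "stream": True yields newline-delimited JSON objects: each
        # carries a partial assistant message, and the final one (done=true) reports
        # prompt_eval_count / eval_count, which feed the stats line below.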
        try:
            async with session.post(
                f"{self.valves.ollama_url}/api/chat",
                json=llm_payload,
                timeout=aiohttp.ClientTimeout(total=self.valves.llm_timeout),
            ) as resp:
                if resp.status != 200:
                    yield f"LLM error: HTTP {resp.status}"
                    return

                async for line in resp.content:
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        if text := data.get("message", {}).get("content"):
                            yield text
                        if data.get("done"):
                            prompt_tokens = data.get("prompt_eval_count", 0)
                            completion_tokens = data.get("eval_count", 0)
                    except json.JSONDecodeError:
                        continue

        except asyncio.TimeoutError:
            yield "\n\n⚠️ LLM timed out"
            return
        except Exception as e:
            yield f"\n\nLLM error: {e}"
            return

        # ─────────────────────────────────────────────
        # Sources
        # ─────────────────────────────────────────────
        if self.valves.show_sources:
            # Dedupe by file path, count chunks
            source_counts: dict[str, dict] = {}
            for chunk in chunks:
                payload = chunk.get("payload", {})
                path = payload.get("filePath", "")
                name = payload.get("fileName", "Unknown")
                if path in source_counts:
                    source_counts[path]["count"] += 1
                else:
                    source_counts[path] = {"name": name, "path": path, "count": 1}

            yield "\n\n---\n**Sources:**\n"
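            # Render each source as an obsidian://open?vault=...&file=... link so it
            # opens directly in the configured vault.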
            for src in source_counts.values():
                vault = urllib.parse.quote(self.valves.vault_name)
                path = urllib.parse.quote(src["path"])
                uri = f"obsidian://open?vault={vault}&file={path}"
                count_str = f" ({src['count']} chunks)" if src["count"] > 1 else ""
                yield f"- [{src['name']}]({uri}){count_str}\n"

        # ─────────────────────────────────────────────
        # Stats
        # ─────────────────────────────────────────────
        if self.valves.show_stats:
            yield f"\n*{prompt_tokens:,} in / {completion_tokens:,} out*"
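

# ─────────────────────────────────────────────
# Optional local smoke test (illustrative only). Open WebUI normally
# instantiates Pipe and calls pipe() itself; this block is just a sketch for
# exercising the pipeline from the command line, assuming the Ollama, Qdrant,
# and reranker endpoints configured in Valves are reachable.
# ─────────────────────────────────────────────
if __name__ == "__main__":

    async def _demo() -> None:
        pipe = Pipe()
        body = {
            "messages": [
                {"role": "user", "content": "What did I write about the project deadline?"}
            ]
        }
        async for chunk in pipe.pipe(body):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())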
|