"""
title: Obsidian RAG Pipeline
author: Daniel Henry
version: 0.15
"""
import asyncio
import json
import time
import urllib.parse
from typing import AsyncGenerator
import aiohttp
from pydantic import BaseModel, Field
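# Pipeline flow: rewrite the latest question using conversation context, embed it,
# search Qdrant for candidate chunks, optionally rerank them, build a context block,
# then stream the LLM answer followed by source links and token stats.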
class Pipe:
class Valves(BaseModel):
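        # Runtime-tunable settings (Open WebUI-style "valves") for endpoints,
        # retrieval, reranking, LLM, and display behavior.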
# Endpoints
ollama_url: str = Field(default="http://ollama.internal.henryhosted.com:11434")
qdrant_url: str = Field(default="http://app-01.internal.henryhosted.com:6333")
rerank_url: str = Field(default="http://ollama.internal.henryhosted.com:7997")
# Qdrant
collection_name: str = Field(default="obsidian_vault")
retrieve_count: int = Field(
default=50, description="Candidates to fetch from Qdrant"
)
qdrant_score_threshold: float = Field(
default=0.3, description="Minimum similarity score"
)
# Reranker
rerank_enabled: bool = Field(
default=True, description="Set to False to skip reranking"
)
rerank_timeout: float = Field(default=60.0)
min_rerank_score: float = Field(
default=0.01, description="Minimum rerank score to keep"
)
final_top_k: int = Field(
default=10, description="Chunks to keep after reranking"
)
# LLM
embedding_model: str = Field(default="nomic-embed-text:latest")
llm_model: str = Field(default="llama3.2:3b")
llm_context_size: int = Field(default=8192)
llm_timeout: float = Field(default=300.0)
query_rewrite_model: str = Field(
default="",
description="Model for query rewriting. Leave empty to use llm_model.",
)
# Obsidian
vault_name: str = Field(
default="Main", description="For generating obsidian:// links"
)
# Display
show_thinking: bool = Field(default=True)
show_sources: bool = Field(default=True)
show_stats: bool = Field(default=True)
token_warning_threshold: int = Field(
default=6000, description="Warn if context exceeds this"
)
def __init__(self):
self.valves = self.Valves()
async def pipe(self, body: dict) -> AsyncGenerator[str, None]:
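        """Streaming entry point: validate the incoming request, then run the RAG flow."""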
messages = body.get("messages", [])
if not messages:
yield "No messages provided."
return
query = messages[-1].get("content", "").strip()
if not query:
yield "Empty query."
return
async with aiohttp.ClientSession() as session:
async for chunk in self._execute(session, query, messages):
yield chunk
async def _execute(
self,
session: aiohttp.ClientSession,
query: str,
messages: list[dict],
) -> AsyncGenerator[str, None]:
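        """Run the full RAG flow for one query, yielding output chunks as they are produced."""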
think = self.valves.show_thinking
# Start thinking block immediately
if think:
yield "<think>\n"
yield f"**LLM Model: ** {self.valves.llm_model}\n\n"
yield f"**Query:** {query}\n\n"
# ─────────────────────────────────────────────
# Step 1: Rewrite query with conversation context
# ─────────────────────────────────────────────
if think:
yield "**Step 1: Query Rewriting**\n"
t0 = time.time()
rewrite_model = self.valves.query_rewrite_model or self.valves.llm_model
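        # Follow-up questions often lack the context needed for a good vector search,
        # so the latest question is rewritten into a standalone query whenever there
        # is prior conversation; otherwise it is used as-is.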
# Build conversation context for rewriting
conversation_for_rewrite = []
for m in messages[:-1]: # All messages except the last one
role = m.get("role", "")
content = m.get("content", "")
if role == "user":
conversation_for_rewrite.append(f"User: {content}")
elif role == "assistant":
# Truncate assistant responses to avoid bloat
truncated = content[:500] + "..." if len(content) > 500 else content
conversation_for_rewrite.append(f"Assistant: {truncated}")
current_question = messages[-1].get("content", "")
# If there's prior conversation, rewrite the query
if conversation_for_rewrite:
rewrite_prompt = f"""Do not interpret or answer the question. Simply add enough context from the conversation so the question makes sense on its own.
Conversation:
{chr(10).join(conversation_for_rewrite)}
Latest question: {current_question}
Rewrite the question to be standalone (respond with ONLY the rewritten question, nothing else):"""
try:
async with session.post(
f"{self.valves.ollama_url}/api/generate",
json={
"model": rewrite_model,
"prompt": rewrite_prompt,
"stream": False,
},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
if resp.status == 200:
data = await resp.json()
rewritten = data.get("response", "").strip()
# Sanity check - if rewrite is empty or way too long, use original
if rewritten and len(rewritten) < 1000:
search_query = rewritten
else:
search_query = current_question
else:
search_query = current_question
except Exception as e:
if think:
yield f" ⚠ Rewrite failed: {e}, using original query\n"
search_query = current_question
else:
# No prior conversation, use the question as-is
search_query = current_question
if think:
yield f" Model: {rewrite_model}\n"
yield f" Original: {current_question}\n"
yield f" Search query: {search_query}\n"
yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"
# ─────────────────────────────────────────────
# Step 2: Embed
# ─────────────────────────────────────────────
if think:
yield "**Step 2: Embedding**\n"
t0 = time.time()
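        # Ollama's /api/embeddings endpoint returns {"embedding": [...]} for the given
        # prompt; num_ctx is raised to 8192 so longer rewritten queries are not truncated.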
try:
async with session.post(
f"{self.valves.ollama_url}/api/embeddings",
json={
"model": self.valves.embedding_model,
"prompt": search_query,
"options": {"num_ctx": 8192},
},
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
if think:
yield f" ✗ HTTP {resp.status}\n</think>\n\n"
yield f"Embedding failed: HTTP {resp.status}"
return
embedding = (await resp.json()).get("embedding")
except Exception as e:
if think:
yield f"{e}\n</think>\n\n"
yield f"Embedding failed: {e}"
return
if think:
yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"
# ─────────────────────────────────────────────
# Step 3: Search Qdrant
# ─────────────────────────────────────────────
if think:
yield "**Step 3: Qdrant Search**\n"
t0 = time.time()
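        # Qdrant's search endpoint responds with {"result": [{"id", "score", "payload", ...}]};
        # with_payload=True asks it to return the stored fileName/content/source fields.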
try:
async with session.post(
f"{self.valves.qdrant_url}/collections/{self.valves.collection_name}/points/search",
json={
"vector": embedding,
"limit": self.valves.retrieve_count,
"with_payload": True,
"score_threshold": self.valves.qdrant_score_threshold,
},
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
if think:
yield f" ✗ HTTP {resp.status}\n</think>\n\n"
yield f"Qdrant search failed: HTTP {resp.status}"
return
qdrant_results = (await resp.json()).get("result", [])
except Exception as e:
if think:
yield f"{e}\n</think>\n\n"
yield f"Qdrant search failed: {e}"
return
if think:
yield f" ✓ Found {len(qdrant_results)} chunks ({time.time() - t0:.2f}s)\n"
if not qdrant_results:
if think:
yield " ✗ No results\n</think>\n\n"
yield "No relevant notes found for this query."
return
# Show top 5
if think:
yield " Top 5:\n"
for i, r in enumerate(qdrant_results[:5]):
name = r.get("payload", {}).get("fileName", "?")
score = r.get("score", 0)
yield f" {i+1}. [{score:.4f}] {name}\n"
yield "\n"
# ─────────────────────────────────────────────
# Step 4: Rerank (optional)
# ─────────────────────────────────────────────
if self.valves.rerank_enabled:
if think:
yield "**Step 4: Reranking**\n"
t0 = time.time()
docs_for_rerank = [
r.get("payload", {}).get("content", "") for r in qdrant_results
]
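            # Assumes a Cohere-style rerank API: the request sends {query, documents} and the
            # response body looks like {"results": [{"index": ..., "relevance_score": ...}, ...]}.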
try:
async with session.post(
f"{self.valves.rerank_url}/rerank",
json={
"query": search_query,
"documents": docs_for_rerank,
"return_documents": False,
},
timeout=aiohttp.ClientTimeout(total=self.valves.rerank_timeout),
) as resp:
if resp.status != 200:
if think:
yield f" ⚠ Reranker failed: HTTP {resp.status}, using Qdrant order\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
else:
rerank_results = (await resp.json()).get("results", [])
# Apply rerank scores and filter
scored = []
for item in rerank_results:
idx = item["index"]
score = item["relevance_score"]
if score >= self.valves.min_rerank_score:
chunk = qdrant_results[idx].copy()
chunk["rerank_score"] = score
scored.append(chunk)
scored.sort(key=lambda x: x["rerank_score"], reverse=True)
chunks = scored[: self.valves.final_top_k]
if think:
yield f" ✓ Kept {len(chunks)} chunks ({time.time() - t0:.2f}s)\n"
if chunks:
yield " Top 5 after rerank:\n"
for i, c in enumerate(chunks[:5]):
name = c.get("payload", {}).get("fileName", "?")
score = c.get("rerank_score", 0)
yield f" {i+1}. [{score:.4f}] {name}\n"
yield "\n"
except Exception as e:
if think:
yield f" ⚠ Reranker error: {e}, using Qdrant order\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
else:
if think:
yield "**Step 4: Reranking** (disabled)\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
if not chunks:
if think:
yield " ✗ No chunks after filtering\n</think>\n\n"
yield "No relevant notes passed the relevance threshold."
return
# ─────────────────────────────────────────────
# Step 5: Build context
# ─────────────────────────────────────────────
if think:
yield "**Step 5: Build Context**\n"
context_parts = []
for i, chunk in enumerate(chunks, 1):
payload = chunk.get("payload", {})
file_name = payload.get("fileName", "Unknown")
content = payload.get("content", "").strip()
source = payload.get("source", "")
part = f"### Note Name {i}: {file_name}\n"
if source:
part += f"Original source: {source}\n"
part += f"\n{content}"
context_parts.append(part)
context = "\n\n---\n\n".join(context_parts)
context_chars = len(context)
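        # Rough heuristic: about 4 characters per token for English text; used only to
        # decide whether to show the context-size warning below.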
estimated_tokens = context_chars // 4
if think:
yield f"{len(chunks)} chunks, {context_chars:,} chars (~{estimated_tokens:,} tokens)\n"
if estimated_tokens > self.valves.token_warning_threshold:
yield f" ⚠ Warning: approaching context limit ({self.valves.llm_context_size})\n"
yield "\n"
# ─────────────────────────────────────────────
# Step 6: Build prompt and call LLM
# ─────────────────────────────────────────────
if think:
yield "**Step 6: Generate Response**\n"
yield "</think>\n\n"
system_prompt = f"""
### ROLE
You are the user's "Knowledge Partner." You are warm, enthusiastic, and helpful. You love the user's notes and want to help them connect ideas.
### THE GOLDEN RULE (HARD WALL)
Your knowledge is strictly limited to the provided <notes>.
- IF the answer is in the notes: Synthesize it warmly and cite it.
- IF the answer is NOT in the notes: You must admit it. Say: "I checked your notes, but I couldn't find info on that."
- Be honest with the user. The user does not want blind support. You are a friendly research assistant, not an overly supportive friend.
- **EXCEPTION:** ONLY if the user explicitly types the trigger phrase "System: Add Context" are you allowed to use outside knowledge.
### INSTRUCTIONS
1. **Search First:** Look through the <notes> to find the answer.
2. **Synthesize:** You may combine facts from different notes to build a complete answer.
3. **Cite Everything:** Every single statement of fact must end with a citation in this format: `[Note Name]`.
4. **Tone:** Be conversational but professional. Avoid robotic phrases like "According to the provided text." Instead, say "Your note on [Topic] mentions..."
5. **Additional:** Avoid asking follow-up questions at the end of your output.
### EXAMPLES (Follow this pattern)
**User:** "What did I write about the project deadline?"
**You:** "I looked through your project logs! It seems you set the final submission date for October 15th [Project_Alpha_Log]. You also noted that the design phase needs to wrap up by the 1st [Design_Team_Meeting]."
**User:** "Who is the president of France?" (Note: This is NOT in your notes)
**You:** "I checked your notes, but I don't see any mention of the current president of France. Would you like me to use outside knowledge? If so, just say 'System: Add Context'."
### SOURCE NOTES
<notes>
{context}
</notes>
"""
        # Keep only user/assistant messages and strip any previous "Sources" footer
        # from assistant turns so the model does not imitate that pattern.
        conversation = []
for m in messages:
if m.get("role") not in ("user", "assistant"):
continue
msg = m.copy()
if msg["role"] == "assistant":
content = msg.get("content", "")
# Split on "**Sources:**" which is the visible header.
# This catches it even if the newlines/separators are slightly different.
if "**Sources:**" in content:
msg["content"] = content.split("**Sources:**")[0].strip()
conversation.append(msg)
llm_payload = {
"model": self.valves.llm_model,
"messages": [
{"role": "system", "content": system_prompt},
*conversation,
],
"stream": True,
"options": {"num_ctx": self.valves.llm_context_size},
}
# Stream LLM response
prompt_tokens = 0
completion_tokens = 0
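        # /api/chat streams newline-delimited JSON: each line carries a partial assistant
        # message, and the final line has done=true plus prompt/completion token counts.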
try:
async with session.post(
f"{self.valves.ollama_url}/api/chat",
json=llm_payload,
timeout=aiohttp.ClientTimeout(total=self.valves.llm_timeout),
) as resp:
if resp.status != 200:
yield f"LLM error: HTTP {resp.status}"
return
async for line in resp.content:
if not line:
continue
try:
data = json.loads(line)
if text := data.get("message", {}).get("content"):
yield text
if data.get("done"):
prompt_tokens = data.get("prompt_eval_count", 0)
completion_tokens = data.get("eval_count", 0)
except json.JSONDecodeError:
continue
except asyncio.TimeoutError:
yield "\n\n⚠️ LLM timed out"
return
except Exception as e:
yield f"\n\nLLM error: {e}"
return
# ─────────────────────────────────────────────
# Sources
# ─────────────────────────────────────────────
if self.valves.show_sources:
# Dedupe by file path, count chunks
source_counts: dict[str, dict] = {}
for chunk in chunks:
payload = chunk.get("payload", {})
path = payload.get("filePath", "")
name = payload.get("fileName", "Unknown")
if path in source_counts:
source_counts[path]["count"] += 1
else:
source_counts[path] = {"name": name, "path": path, "count": 1}
yield "\n\n---\n**Sources:**\n"
for src in source_counts.values():
vault = urllib.parse.quote(self.valves.vault_name)
path = urllib.parse.quote(src["path"])
uri = f"obsidian://open?vault={vault}&file={path}"
count_str = f" ({src['count']} chunks)" if src["count"] > 1 else ""
yield f"- [{src['name']}]({uri}){count_str}\n"
# ─────────────────────────────────────────────
# Stats
# ─────────────────────────────────────────────
if self.valves.show_stats:
yield f"\n*{prompt_tokens:,} in / {completion_tokens:,} out*"