"""
|
||||||
|
title: Obsidian RAG
|
||||||
|
author: Daniel
|
||||||
|
version: 6.0
|
||||||
|
required_open_webui_version: 0.3.9
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Pipe:
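    """Retrieval-augmented chat over an Obsidian vault.

    Embeds the user's question with Ollama, retrieves candidate note chunks
    from Qdrant, optionally reranks them, and streams a grounded answer with
    obsidian:// source links.
    """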

    class Valves(BaseModel):
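        """Pipeline settings, surfaced as editable valves in the Open WebUI admin UI."""
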
        # Endpoints
        ollama_url: str = Field(default="http://ollama.internal.henryhosted.com:11434")
        qdrant_url: str = Field(default="http://app-01.internal.henryhosted.com:6333")
        rerank_url: str = Field(default="http://ollama.internal.henryhosted.com:7997")

        # Qdrant
        collection_name: str = Field(default="obsidian_vault")
        retrieve_count: int = Field(
            default=50, description="Candidates to fetch from Qdrant"
        )
        qdrant_score_threshold: float = Field(
            default=0.3, description="Minimum similarity score"
        )

        # Reranker
        rerank_enabled: bool = Field(
            default=True, description="Set to False to skip reranking"
        )
        rerank_timeout: float = Field(default=60.0)
        min_rerank_score: float = Field(
            default=0.01, description="Minimum rerank score to keep"
        )
        final_top_k: int = Field(
            default=10, description="Chunks to keep after reranking"
        )

        # LLM
        embedding_model: str = Field(default="nomic-embed-text")
        llm_model: str = Field(default="llama3.2:3b")
        llm_context_size: int = Field(default=8192)
        llm_timeout: float = Field(default=300.0)
        query_rewrite_model: str = Field(
            default="",
            description="Model for query rewriting. Leave empty to use llm_model.",
        )

        # Obsidian
        vault_name: str = Field(
            default="Main", description="For generating obsidian:// links"
        )

        # Display
        show_thinking: bool = Field(default=True)
        show_sources: bool = Field(default=True)
        show_stats: bool = Field(default=True)
        token_warning_threshold: int = Field(
            default=6000, description="Warn if context exceeds this"
        )

    def __init__(self):
        self.valves = self.Valves()

    async def pipe(self, body: dict) -> AsyncGenerator[str, None]:
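        """Open WebUI entry point: validate the request, then stream the RAG pipeline."""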
        messages = body.get("messages", [])
        if not messages:
            yield "No messages provided."
            return

        query = messages[-1].get("content", "").strip()
        if not query:
            yield "Empty query."
            return

        async with aiohttp.ClientSession() as session:
            async for chunk in self._execute(session, query, messages):
                yield chunk

    async def _execute(
        self,
        session: aiohttp.ClientSession,
        query: str,
        messages: list[dict],
    ) -> AsyncGenerator[str, None]:
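        """Run the full pipeline: rewrite, embed, search, rerank, build context, generate."""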

        think = self.valves.show_thinking

        # Start thinking block immediately
        if think:
            yield "<think>\n"
            yield f"**Query:** {query}\n\n"

        # ─────────────────────────────────────────────
        # Step 1: Rewrite query with conversation context
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 1: Query Rewriting**\n"

        t0 = time.time()
        rewrite_model = self.valves.query_rewrite_model or self.valves.llm_model

        # Build conversation context for rewriting
        conversation_for_rewrite = []
        for m in messages[:-1]:  # All messages except the last one
            role = m.get("role", "")
            content = m.get("content", "")
            if role == "user":
                conversation_for_rewrite.append(f"User: {content}")
            elif role == "assistant":
                # Truncate assistant responses to avoid bloat
                truncated = content[:500] + "..." if len(content) > 500 else content
                conversation_for_rewrite.append(f"Assistant: {truncated}")

        current_question = messages[-1].get("content", "")

        # If there's prior conversation, rewrite the query
        if conversation_for_rewrite:
            rewrite_prompt = f"""Do not interpret or answer the question. Simply add enough context from the conversation so the question makes sense on its own.

Conversation:
{chr(10).join(conversation_for_rewrite)}

Latest question: {current_question}

Rewrite the question to be standalone (respond with ONLY the rewritten question, nothing else):"""

            try:
                async with session.post(
                    f"{self.valves.ollama_url}/api/generate",
                    json={
                        "model": rewrite_model,
                        "prompt": rewrite_prompt,
                        "stream": False,
                    },
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        rewritten = data.get("response", "").strip()
                        # Sanity check - if rewrite is empty or way too long, use original
                        if rewritten and len(rewritten) < 1000:
                            search_query = rewritten
                        else:
                            search_query = current_question
                    else:
                        search_query = current_question
            except Exception as e:
                if think:
                    yield f" ⚠ Rewrite failed: {e}, using original query\n"
                search_query = current_question
        else:
            # No prior conversation, use the question as-is
            search_query = current_question

        if think:
            yield f" Model: {rewrite_model}\n"
            yield f" Original: {current_question}\n"
            yield f" Search query: {search_query}\n"
            yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"

        # ─────────────────────────────────────────────
        # Step 2: Embed
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 2: Embedding**\n"
        t0 = time.time()
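
        # Ollama's /api/embeddings endpoint takes a single prompt and returns
        # {"embedding": [...]} for the configured embedding model.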
        try:
            async with session.post(
                f"{self.valves.ollama_url}/api/embeddings",
                json={"model": self.valves.embedding_model, "prompt": search_query},
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    if think:
                        yield f" ✗ HTTP {resp.status}\n</think>\n\n"
                    yield f"Embedding failed: HTTP {resp.status}"
                    return
                embedding = (await resp.json()).get("embedding")
        except Exception as e:
            if think:
                yield f" ✗ {e}\n</think>\n\n"
            yield f"Embedding failed: {e}"
            return

        if think:
            yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"

        # ─────────────────────────────────────────────
        # Step 3: Search Qdrant
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 3: Qdrant Search**\n"
        t0 = time.time()
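
        # Qdrant's REST search endpoint (POST /collections/<name>/points/search)
        # returns scored points under "result"; each payload carries the note text.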
        try:
            async with session.post(
                f"{self.valves.qdrant_url}/collections/{self.valves.collection_name}/points/search",
                json={
                    "vector": embedding,
                    "limit": self.valves.retrieve_count,
                    "with_payload": True,
                    "score_threshold": self.valves.qdrant_score_threshold,
                },
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    if think:
                        yield f" ✗ HTTP {resp.status}\n</think>\n\n"
                    yield f"Qdrant search failed: HTTP {resp.status}"
                    return
                qdrant_results = (await resp.json()).get("result", [])
        except Exception as e:
            if think:
                yield f" ✗ {e}\n</think>\n\n"
            yield f"Qdrant search failed: {e}"
            return

        if think:
            yield f" ✓ Found {len(qdrant_results)} chunks ({time.time() - t0:.2f}s)\n"

        if not qdrant_results:
            if think:
                yield " ✗ No results\n</think>\n\n"
            yield "No relevant notes found for this query."
            return

        # Show top 5
        if think:
            yield " Top 5:\n"
            for i, r in enumerate(qdrant_results[:5]):
                name = r.get("payload", {}).get("fileName", "?")
                score = r.get("score", 0)
                yield f" {i+1}. [{score:.4f}] {name}\n"
            yield "\n"

        # ─────────────────────────────────────────────
        # Step 4: Rerank (optional)
        # ─────────────────────────────────────────────
        if self.valves.rerank_enabled:
            if think:
                yield "**Step 4: Reranking**\n"
            t0 = time.time()

            docs_for_rerank = [
                r.get("payload", {}).get("content", "") for r in qdrant_results
            ]
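
            # Assumed rerank API shape (matches the request/response handling below):
            # the service accepts {"query", "documents", "return_documents"} and
            # replies with {"results": [{"index": ..., "relevance_score": ...}, ...]}.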
            try:
                async with session.post(
                    f"{self.valves.rerank_url}/rerank",
                    json={
                        "query": search_query,
                        "documents": docs_for_rerank,
                        "return_documents": False,
                    },
                    timeout=aiohttp.ClientTimeout(total=self.valves.rerank_timeout),
                ) as resp:
                    if resp.status != 200:
                        if think:
                            yield f" ⚠ Reranker failed: HTTP {resp.status}, using Qdrant order\n\n"
                        chunks = qdrant_results[: self.valves.final_top_k]
                    else:
                        rerank_results = (await resp.json()).get("results", [])

                        # Apply rerank scores and filter
                        scored = []
                        for item in rerank_results:
                            idx = item["index"]
                            score = item["relevance_score"]
                            if score >= self.valves.min_rerank_score:
                                chunk = qdrant_results[idx].copy()
                                chunk["rerank_score"] = score
                                scored.append(chunk)

                        scored.sort(key=lambda x: x["rerank_score"], reverse=True)
                        chunks = scored[: self.valves.final_top_k]

                        if think:
                            yield f" ✓ Kept {len(chunks)} chunks ({time.time() - t0:.2f}s)\n"

                            if chunks:
                                yield " Top 5 after rerank:\n"
                                for i, c in enumerate(chunks[:5]):
                                    name = c.get("payload", {}).get("fileName", "?")
                                    score = c.get("rerank_score", 0)
                                    yield f" {i+1}. [{score:.4f}] {name}\n"
                                yield "\n"

            except Exception as e:
                if think:
                    yield f" ⚠ Reranker error: {e}, using Qdrant order\n\n"
                chunks = qdrant_results[: self.valves.final_top_k]

        else:
            if think:
                yield "**Step 4: Reranking** (disabled)\n\n"
            chunks = qdrant_results[: self.valves.final_top_k]

        if not chunks:
            if think:
                yield " ✗ No chunks after filtering\n</think>\n\n"
            yield "No relevant notes passed the relevance threshold."
            return

        # ─────────────────────────────────────────────
        # Step 5: Build context
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 5: Build Context**\n"

        context_parts = []
        for i, chunk in enumerate(chunks, 1):
            payload = chunk.get("payload", {})
            file_name = payload.get("fileName", "Unknown")
            content = payload.get("content", "").strip()
            source = payload.get("source", "")

            part = f"### Note {i}: {file_name}\n"
            if source:
                part += f"Original source: {source}\n"
            part += f"\n{content}"
            context_parts.append(part)

        context = "\n\n---\n\n".join(context_parts)
        context_chars = len(context)
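        # Rough heuristic: ~4 characters per token, used only for the context-size warning.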
        estimated_tokens = context_chars // 4

        if think:
            yield f" ✓ {len(chunks)} chunks, {context_chars:,} chars (~{estimated_tokens:,} tokens)\n"

            if estimated_tokens > self.valves.token_warning_threshold:
                yield f" ⚠ Warning: approaching context limit ({self.valves.llm_context_size})\n"

            yield "\n"

        # ─────────────────────────────────────────────
        # Step 6: Build prompt and call LLM
        # ─────────────────────────────────────────────
        if think:
            yield "**Step 6: Generate Response**\n"
            yield "</think>\n\n"

        system_prompt = f"""You are a specialized Research Assistant. Your goal is to synthesize information from the provided user notes.

### INSTRUCTIONS
1. **Primary Source:** Answer the user's question using strictly the content found within the <notes> section below.
2. **Citation:** Every claim you make must be immediately followed by a citation in this format: [Note Name].
3. **Missing Info:** If the <notes> do not contain the answer, explicitly state: "Your notes don't cover this." Do not attempt to guess or hallucinate an answer.
4. **Exception Handling (Outside Knowledge):**
   - You are generally FORBIDDEN from using outside knowledge.
   - **HOWEVER**, if the user explicitly asks for external context (e.g., "What am I missing?", "Add outside context"), you may provide it.
   - If you trigger this exception, you must prefix that specific part of the response with: "**Outside Context:**".

### FORMATTING
- Be concise and direct.
- Use bullet points for lists.

### SOURCE DATA
<notes>
{context}
</notes>
"""

        # Only keep user/assistant messages
        conversation = [m for m in messages if m.get("role") in ("user", "assistant")]

        llm_payload = {
            "model": self.valves.llm_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                *conversation,
            ],
            "stream": True,
            "options": {"num_ctx": self.valves.llm_context_size},
        }

        # Stream LLM response
        prompt_tokens = 0
        completion_tokens = 0
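
        # Ollama's /api/chat streams newline-delimited JSON: intermediate lines carry
        # message.content fragments; the final line ("done": true) reports token counts.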
        try:
            async with session.post(
                f"{self.valves.ollama_url}/api/chat",
                json=llm_payload,
                timeout=aiohttp.ClientTimeout(total=self.valves.llm_timeout),
            ) as resp:
                if resp.status != 200:
                    yield f"LLM error: HTTP {resp.status}"
                    return

                async for line in resp.content:
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        if text := data.get("message", {}).get("content"):
                            yield text
                        if data.get("done"):
                            prompt_tokens = data.get("prompt_eval_count", 0)
                            completion_tokens = data.get("eval_count", 0)
                    except json.JSONDecodeError:
                        continue

        except asyncio.TimeoutError:
            yield "\n\n⚠️ LLM timed out"
            return
        except Exception as e:
            yield f"\n\nLLM error: {e}"
            return

        # ─────────────────────────────────────────────
        # Sources
        # ─────────────────────────────────────────────
        if self.valves.show_sources:
            # Dedupe by file path, count chunks
            source_counts: dict[str, dict] = {}
            for chunk in chunks:
                payload = chunk.get("payload", {})
                path = payload.get("filePath", "")
                name = payload.get("fileName", "Unknown")
                if path in source_counts:
                    source_counts[path]["count"] += 1
                else:
                    source_counts[path] = {"name": name, "path": path, "count": 1}

            yield "\n\n---\n**Sources:**\n"
            for src in source_counts.values():
                vault = urllib.parse.quote(self.valves.vault_name)
                path = urllib.parse.quote(src["path"])
                uri = f"obsidian://open?vault={vault}&file={path}"
                count_str = f" ({src['count']} chunks)" if src["count"] > 1 else ""
                yield f"- [{src['name']}]({uri}){count_str}\n"

        # ─────────────────────────────────────────────
        # Stats
        # ─────────────────────────────────────────────
        if self.valves.show_stats:
            yield f"\n*{prompt_tokens:,} in / {completion_tokens:,} out*"