"""
title: Obsidian RAG Pipeline
author: Daniel Henry
version: 0.17
description: Updated for llama-swap with llama.cpp (OpenAI-compatible API)
"""
import asyncio
import json
import time
import math
import urllib.parse
from typing import AsyncGenerator
import aiohttp
from pydantic import BaseModel, Field
class Pipe:
class Valves(BaseModel):
# Endpoints
llamacpp_url: str = Field(default="http://ollama.internal.henryhosted.com:9292")
qdrant_url: str = Field(default="http://app-01.internal.henryhosted.com:6333")
# Qdrant
collection_name: str = Field(default="obsidian_vault")
retrieve_count: int = Field(
default=50, description="Candidates to fetch from Qdrant"
)
qdrant_score_threshold: float = Field(
default=0.3, description="Minimum similarity score"
)
# Reranker
rerank_enabled: bool = Field(
default=True, description="Set to False to skip reranking"
)
rerank_logit: bool = Field(
default=False, description="Enable if reranker outputs logits"
)
rerank_debug: bool = Field(
default=False, description="Output all rerank values into think"
)
rerank_model: str = Field(
default="bge-reranker-v2-m3-q8_0",
description="Reranker model name",
)
rerank_timeout: float = Field(default=60.0)
min_rerank_score: float = Field(
default=0.01, description="Minimum rerank score to keep"
)
final_top_k: int = Field(
default=10, description="Chunks to keep after reranking"
)
# LLM
embedding_model: str = Field(
default="nomic-embed-text-v1.5.f16",
description="Embedding model name",
)
llm_model: str = Field(
default="qwen2.5-3b-instruct-q4_k_m",
description="LLM model name",
)
llm_max_tokens: int = Field(
default=2048, description="Max tokens for LLM response"
)
llm_timeout: float = Field(default=300.0)
query_rewrite_model: str = Field(
default="",
description="Model for query rewriting. Leave empty to use llm_model.",
)
# Obsidian
vault_name: str = Field(
default="Main", description="For generating obsidian:// links"
)
# Display
show_thinking: bool = Field(default=True)
show_sources: bool = Field(default=True)
show_stats: bool = Field(default=True)
token_warning_threshold: int = Field(
default=6000, description="Warn if context exceeds this"
)
def __init__(self):
self.valves = self.Valves()
def _estimate_tokens(self, text: str) -> int:
"""Rough token estimate: ~4 chars per token for English text."""
return len(text) // 4
async def pipe(self, body: dict) -> AsyncGenerator[str, None]:
messages = body.get("messages", [])
if not messages:
yield "No messages provided."
return
query = messages[-1].get("content", "").strip()
if not query:
yield "Empty query."
return
async with aiohttp.ClientSession() as session:
async for chunk in self._execute(session, query, messages):
yield chunk
async def _execute(
self,
session: aiohttp.ClientSession,
query: str,
messages: list[dict],
) -> AsyncGenerator[str, None]:
think = self.valves.show_thinking
total_prompt_tokens = 0
# Start thinking block
if think:
yield "\n"
yield f"**LLM Model:** {self.valves.llm_model}\n"
yield f"**Query:** {query}\n\n"
# ─────────────────────────────────────────────
# Step 1: Rewrite query with conversation context
# ─────────────────────────────────────────────
if think:
yield "**Step 1: Query Rewriting**\n"
t0 = time.time()
rewrite_model = self.valves.query_rewrite_model or self.valves.llm_model
current_question = messages[-1].get("content", "")
# Build conversation context for rewriting (only if there's prior conversation)
conversation_for_rewrite = []
for m in messages[:-1]:
role = m.get("role", "")
content = m.get("content", "")
if role == "user":
conversation_for_rewrite.append(f"User: {content}")
elif role == "assistant":
truncated = content[:500] + "..." if len(content) > 500 else content
conversation_for_rewrite.append(f"Assistant: {truncated}")
if conversation_for_rewrite:
rewrite_prompt = f"""Do not interpret or answer the question. Simply add enough context from the conversation so the question makes sense on its own.
Conversation:
{chr(10).join(conversation_for_rewrite)}
Latest question: {current_question}
Rewrite the question to be standalone (respond with ONLY the rewritten question, nothing else):"""
try:
async with session.post(
f"{self.valves.llamacpp_url}/v1/chat/completions",
json={
"model": rewrite_model,
"messages": [{"role": "user", "content": rewrite_prompt}],
"stream": False,
"max_tokens": 256,
},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
if resp.status == 200:
data = await resp.json()
rewritten = (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
)
if rewritten and len(rewritten) < 1000:
search_query = rewritten
else:
search_query = current_question
else:
search_query = current_question
except Exception as e:
if think:
yield f" ⚠ Rewrite failed: {e}, using original query\n"
search_query = current_question
else:
search_query = current_question
if think:
yield f" Model: {rewrite_model}\n"
yield f" Original: {current_question}\n"
yield f" Search query: {search_query}\n"
yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"
# ─────────────────────────────────────────────
# Step 2: Embed
# ─────────────────────────────────────────────
if think:
yield "**Step 2: Embedding**\n"
t0 = time.time()
try:
async with session.post(
f"{self.valves.llamacpp_url}/v1/embeddings",
json={
"model": self.valves.embedding_model,
"input": search_query,
},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
if resp.status != 200:
error_text = await resp.text()
if think:
yield f" ✗ HTTP {resp.status}: {error_text}\n\n\n"
yield f"Embedding failed: HTTP {resp.status}"
return
data = await resp.json()
embedding = data.get("data", [{}])[0].get("embedding")
if not embedding:
if think:
yield " ✗ No embedding in response\n\n\n"
yield "Embedding failed: No embedding returned"
return
except Exception as e:
if think:
yield f" ✗ {e}\n\n\n"
yield f"Embedding failed: {e}"
return
if think:
yield f" ✓ Done ({time.time() - t0:.2f}s)\n\n"
# ─────────────────────────────────────────────
# Step 3: Search Qdrant
# ─────────────────────────────────────────────
if think:
yield "**Step 3: Qdrant Search**\n"
t0 = time.time()
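# Qdrant REST vector search: POST /collections/<name>/points/search with the
# query embedding; score_threshold filters out low-similarity hits server-side.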
try:
async with session.post(
f"{self.valves.qdrant_url}/collections/{self.valves.collection_name}/points/search",
json={
"vector": embedding,
"limit": self.valves.retrieve_count,
"with_payload": True,
"score_threshold": self.valves.qdrant_score_threshold,
},
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
if think:
yield f" ✗ HTTP {resp.status}\n\n\n"
yield f"Qdrant search failed: HTTP {resp.status}"
return
qdrant_results = (await resp.json()).get("result", [])
except Exception as e:
if think:
yield f" ✗ {e}\n\n\n"
yield f"Qdrant search failed: {e}"
return
if think:
yield f" ✓ Found {len(qdrant_results)} chunks ({time.time() - t0:.2f}s)\n"
if not qdrant_results:
if think:
yield " ✗ No results\n\n\n"
yield "No relevant notes found for this query."
return
if think:
yield " Top 5:\n"
for i, r in enumerate(qdrant_results[:5]):
name = r.get("payload", {}).get("fileName", "?")
score = r.get("score", 0)
yield f" {i+1}. [{score:.4f}] {name}\n"
yield "\n"
# ─────────────────────────────────────────────
# Step 4: Rerank (optional)
# ─────────────────────────────────────────────
if self.valves.rerank_enabled:
if think:
yield "**Step 4: Reranking**\n"
yield f"**Rerank Model:** {self.valves.rerank_model}\n"
t0 = time.time()
docs_for_rerank = [
r.get("payload", {}).get("content", "") for r in qdrant_results
]
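# llama.cpp's server exposes a /v1/rerank endpoint that scores each candidate
# document against the query and returns {index, relevance_score} pairs.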
try:
async with session.post(
f"{self.valves.llamacpp_url}/v1/rerank",
json={
"model": self.valves.rerank_model,
"query": search_query,
"documents": docs_for_rerank,
},
timeout=aiohttp.ClientTimeout(total=self.valves.rerank_timeout),
) as resp:
if resp.status != 200:
error_text = await resp.text()
if think:
yield f" ⚠ Reranker failed: HTTP {resp.status} - {error_text}, using Qdrant order\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
else:
rerank_data = await resp.json()
rerank_results = rerank_data.get("results", [])
scored = []
for item in rerank_results:
idx = item["index"]
score = item["relevance_score"]
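# Some reranker builds emit raw logits rather than probabilities; squash them
# through a sigmoid so the min_rerank_score cutoff applies on a 0-1 scale.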
if self.valves.rerank_logit:
score = 1 / (1 + math.exp(-item["relevance_score"]))
if think and self.valves.rerank_debug:
yield f" • Debug: Doc {idx} score: {score}\n"
if score >= self.valves.min_rerank_score:
chunk = qdrant_results[idx].copy()
chunk["rerank_score"] = score
scored.append(chunk)
scored.sort(key=lambda x: x["rerank_score"], reverse=True)
chunks = scored[: self.valves.final_top_k]
if think:
yield f" ✓ Kept {len(chunks)} chunks ({time.time() - t0:.2f}s)\n"
if chunks:
yield " Top 5 after rerank:\n"
for i, c in enumerate(chunks[:5]):
name = c.get("payload", {}).get("fileName", "?")
score = c.get("rerank_score", 0)
yield f" {i+1}. [{score:.4f}] {name}\n"
yield "\n"
except Exception as e:
if think:
yield f" ⚠ Reranker error: {e}, using Qdrant order\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
else:
if think:
yield "**Step 4: Reranking** (disabled)\n\n"
chunks = qdrant_results[: self.valves.final_top_k]
if not chunks:
if think:
yield " ✗ No chunks after filtering\n\n\n"
yield "No relevant notes passed the relevance threshold."
return
# ─────────────────────────────────────────────
# Step 5: Build context
# ─────────────────────────────────────────────
if think:
yield "**Step 5: Build Context**\n"
context_parts = []
for i, chunk in enumerate(chunks, 1):
payload = chunk.get("payload", {})
file_name = payload.get("fileName", "Unknown")
content = payload.get("content", "").strip()
source = payload.get("source", "")
# Bracketed ID header so the LLM can cite chunks as [1], [2], ...
part = f"[{i}] File: {file_name}\n"
if source:
part += f"Source: {source}\n"
part += f"\n{content}"
context_parts.append(part)
context = "\n\n---\n\n".join(context_parts)
context_tokens = self._estimate_tokens(context)
if think:
yield f" ✓ {len(chunks)} chunks, ~{context_tokens:,} tokens\n"
if context_tokens > self.valves.token_warning_threshold:
yield f" ⚠ Warning: large context may affect quality\n"
yield "\n"
# ─────────────────────────────────────────────
# Step 6: Build prompt and call LLM
# ─────────────────────────────────────────────
if think:
yield "**Step 6: Generate Response**\n"
yield "\n\n"
system_prompt = f"""You are a helpful assistant. Use the provided notes to answer the user's question.
RULES:
1. Use the <notes> as your source of truth.
2. Cite facts using the bracketed ID number [1].
3. SYNTHESIS: You are encouraged to draw connections between different notes to form a complete answer.
4. INFERENCE: If the answer is not explicitly written but can be logically inferred from the notes, you may answer, but please use phrases like "The notes imply..." or "Based on [1], it suggests..."
5. If the answer is completely absent, say "I couldn't find that in your notes."
<notes>
{context}
</notes>
"""
# Build conversation, stripping previous sources from assistant messages
conversation = []
for m in messages:
role = m.get("role")
if role not in ("user", "assistant"):
continue
msg = {"role": role, "content": m.get("content", "")}
if role == "assistant" and "**Sources:**" in msg["content"]:
msg["content"] = msg["content"].split("**Sources:**")[0].strip()
conversation.append(msg)
llm_messages = [{"role": "system", "content": system_prompt}] + conversation
# Estimate prompt tokens
prompt_text = system_prompt + "".join(m["content"] for m in conversation)
total_prompt_tokens = self._estimate_tokens(prompt_text)
llm_payload = {
"model": self.valves.llm_model,
"messages": llm_messages,
"stream": True,
"max_tokens": self.valves.llm_max_tokens,
}
completion_tokens = 0
completion_text = ""
try:
async with session.post(
f"{self.valves.llamacpp_url}/v1/chat/completions",
json=llm_payload,
timeout=aiohttp.ClientTimeout(total=self.valves.llm_timeout),
) as resp:
if resp.status != 200:
error_text = await resp.text()
yield f"LLM error: HTTP {resp.status} - {error_text}"
return
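# llama.cpp streams OpenAI-style SSE: each event is a "data: {json}" line,
# lines starting with ":" are keep-alive comments, and "data: [DONE]" ends the stream.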
async for line in resp.content:
if not line:
continue
line_str = line.decode("utf-8").strip()
if not line_str or line_str.startswith(":"):
continue
if line_str.startswith("data: "):
line_str = line_str[6:]
if line_str == "[DONE]":
break
try:
data = json.loads(line_str)
delta = data.get("choices", [{}])[0].get("delta", {})
if content := delta.get("content"):
yield content
completion_text += content
except json.JSONDecodeError:
continue
except asyncio.TimeoutError:
yield "\n\n⚠️ LLM timed out"
return
except Exception as e:
yield f"\n\nLLM error: {e}"
return
# Estimate completion tokens
completion_tokens = self._estimate_tokens(completion_text)
# ─────────────────────────────────────────────
# Sources
# ─────────────────────────────────────────────
if self.valves.show_sources:
# Group retrieved chunks by source file, tracking how many chunks each file
# contributed and which bracketed IDs they were cited under
source_counts: dict[str, dict] = {}
for i, chunk in enumerate(chunks, 1):
payload = chunk.get("payload", {})
path = payload.get("filePath", "")
name = payload.get("fileName", "Unknown")
if path in source_counts:
source_counts[path]["count"] += 1
source_counts[path]["indices"].append(i)
else:
source_counts[path] = {
"name": name,
"path": path,
"count": 1,
"indices": [i],
}
yield "\n\n---\n**Sources:**\n"
for src in source_counts.values():
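# Build an obsidian://open deep link; the vault name and file path must be URL-encoded.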
vault = urllib.parse.quote(self.valves.vault_name)
path = urllib.parse.quote(src["path"])
uri = f"obsidian://open?vault={vault}&file={path}"
# Comma-separated chunk indices, e.g. "1, 2, 5"
indices_str = ", ".join(map(str, src["indices"]))
yield f"- [{src['name']}]({uri}) (Chunks: {indices_str})\n"
# ─────────────────────────────────────────────
# Stats
# ─────────────────────────────────────────────
if self.valves.show_stats:
yield f"\n*~{total_prompt_tokens:,} in / ~{completion_tokens:,} out (estimated)*"