commit 9f8c654019f675c6302869f62cc6506873f8f96d
Author: Daniel Henry
Date:   Tue Jan 20 13:23:32 2026 -0600

    Initial Commit

    Signed-off-by: Daniel Henry

diff --git a/ObsidianRAGPipe.py b/ObsidianRAGPipe.py
new file mode 100644
index 0000000..1743aed
--- /dev/null
+++ b/ObsidianRAGPipe.py
@@ -0,0 +1,449 @@
+"""
+title: Obsidian RAG
+author: Daniel
+version: 6.0
+required_open_webui_version: 0.3.9
+"""
+
+import asyncio
+import json
+import time
+import urllib.parse
+from typing import AsyncGenerator
+
+import aiohttp
+from pydantic import BaseModel, Field
+
+
+class Pipe:
+
+    class Valves(BaseModel):
+        # Endpoints
+        ollama_url: str = Field(default="http://ollama.internal.henryhosted.com:11434")
+        qdrant_url: str = Field(default="http://app-01.internal.henryhosted.com:6333")
+        rerank_url: str = Field(default="http://ollama.internal.henryhosted.com:7997")
+
+        # Qdrant
+        collection_name: str = Field(default="obsidian_vault")
+        retrieve_count: int = Field(
+            default=50, description="Candidates to fetch from Qdrant"
+        )
+        qdrant_score_threshold: float = Field(
+            default=0.3, description="Minimum similarity score"
+        )
+
+        # Reranker
+        rerank_enabled: bool = Field(
+            default=True, description="Set to False to skip reranking"
+        )
+        rerank_timeout: float = Field(default=60.0)
+        min_rerank_score: float = Field(
+            default=0.01, description="Minimum rerank score to keep"
+        )
+        final_top_k: int = Field(
+            default=10, description="Chunks to keep after reranking"
+        )
+
+        # LLM
+        embedding_model: str = Field(default="nomic-embed-text")
+        llm_model: str = Field(default="llama3.2:3b")
+        llm_context_size: int = Field(default=8192)
+        llm_timeout: float = Field(default=300.0)
+        query_rewrite_model: str = Field(
+            default="",
+            description="Model for query rewriting. Leave empty to use llm_model.",
+        )
+
+        # Obsidian
+        vault_name: str = Field(
+            default="Main", description="For generating obsidian:// links"
+        )
+
+        # Display
+        show_thinking: bool = Field(default=True)
+        show_sources: bool = Field(default=True)
+        show_stats: bool = Field(default=True)
+        token_warning_threshold: int = Field(
+            default=6000, description="Warn if context exceeds this"
+        )
+
+    def __init__(self):
+        self.valves = self.Valves()
+
+    async def pipe(self, body: dict) -> AsyncGenerator[str, None]:
+        messages = body.get("messages", [])
+        if not messages:
+            yield "No messages provided."
+            return
+
+        query = messages[-1].get("content", "").strip()
+        if not query:
+            yield "Empty query."
+            return
+
+        async with aiohttp.ClientSession() as session:
+            async for chunk in self._execute(session, query, messages):
+                yield chunk
+
+    async def _execute(
+        self,
+        session: aiohttp.ClientSession,
+        query: str,
+        messages: list[dict],
+    ) -> AsyncGenerator[str, None]:
+
+        think = self.valves.show_thinking
+
+        # Start thinking block immediately
+        if think:
+            yield "\n"
+            yield f"**Query:** {query}\n\n"
+
+        # ─────────────────────────────────────────────
+        # Step 1: Rewrite query with conversation context
+        # ─────────────────────────────────────────────
+        if think:
+            yield "**Step 1: Query Rewriting**\n"
+
+        t0 = time.time()
+        rewrite_model = self.valves.query_rewrite_model or self.valves.llm_model
+
+        # Build conversation context for rewriting
+        conversation_for_rewrite = []
+        for m in messages[:-1]:  # All messages except the last one
+            role = m.get("role", "")
+            content = m.get("content", "")
+            if role == "user":
+                conversation_for_rewrite.append(f"User: {content}")
+            elif role == "assistant":
+                # Truncate assistant responses to avoid bloat
+                truncated = content[:500] + "..." if len(content) > 500 else content
+                conversation_for_rewrite.append(f"Assistant: {truncated}")
+
+        current_question = messages[-1].get("content", "")
+
+        # If there's prior conversation, rewrite the query
+        if conversation_for_rewrite:
+            rewrite_prompt = f"""Do not interpret or answer the question. Simply add enough context from the conversation so the question makes sense on its own.
+
+Conversation:
+{chr(10).join(conversation_for_rewrite)}
+
+Latest question: {current_question}
+
+Rewrite the question to be standalone (respond with ONLY the rewritten question, nothing else):"""
+
+            try:
+                async with session.post(
+                    f"{self.valves.ollama_url}/api/generate",
+                    json={
+                        "model": rewrite_model,
+                        "prompt": rewrite_prompt,
+                        "stream": False,
+                    },
+                    timeout=aiohttp.ClientTimeout(total=30),
+                ) as resp:
+                    if resp.status == 200:
+                        data = await resp.json()
+                        rewritten = data.get("response", "").strip()
+                        # Sanity check - if rewrite is empty or way too long, use original
+                        if rewritten and len(rewritten) < 1000:
+                            search_query = rewritten
+                        else:
+                            search_query = current_question
+                    else:
+                        search_query = current_question
+            except Exception as e:
+                if think:
+                    yield f"  ⚠ Rewrite failed: {e}, using original query\n"
+                search_query = current_question
+        else:
+            # No prior conversation, use the question as-is
+            search_query = current_question
+
+        if think:
+            yield f"  Model: {rewrite_model}\n"
+            yield f"  Original: {current_question}\n"
+            yield f"  Search query: {search_query}\n"
+            yield f"  ✓ Done ({time.time() - t0:.2f}s)\n\n"
+
+        # ─────────────────────────────────────────────
+        # Step 2: Embed
+        # ─────────────────────────────────────────────
+        if think:
+            yield "**Step 2: Embedding**\n"
+        t0 = time.time()
+
+        try:
+            async with session.post(
+                f"{self.valves.ollama_url}/api/embeddings",
+                json={"model": self.valves.embedding_model, "prompt": search_query},
+                timeout=aiohttp.ClientTimeout(total=15),
+            ) as resp:
+                if resp.status != 200:
+                    if think:
+                        yield f"  ✗ HTTP {resp.status}\n\n\n"
+                    yield f"Embedding failed: HTTP {resp.status}"
+                    return
+                embedding = (await resp.json()).get("embedding")
+        except Exception as e:
+            if think:
+                yield f"  ✗ {e}\n\n\n"
+            yield f"Embedding failed: {e}"
+            return
+
+        if think:
+            yield f"  ✓ Done ({time.time() - t0:.2f}s)\n\n"
+
+        # ─────────────────────────────────────────────
+        # Step 3: Search Qdrant
+        # ─────────────────────────────────────────────
+        if think:
+            yield "**Step 3: Qdrant Search**\n"
+        t0 = time.time()
+
+        try:
+            async with session.post(
+                f"{self.valves.qdrant_url}/collections/{self.valves.collection_name}/points/search",
+                json={
+                    "vector": embedding,
+                    "limit": self.valves.retrieve_count,
+                    "with_payload": True,
+                    "score_threshold": self.valves.qdrant_score_threshold,
+                },
+                timeout=aiohttp.ClientTimeout(total=15),
+            ) as resp:
+                if resp.status != 200:
+                    if think:
+                        yield f"  ✗ HTTP {resp.status}\n\n\n"
+                    yield f"Qdrant search failed: HTTP {resp.status}"
+                    return
+                qdrant_results = (await resp.json()).get("result", [])
+        except Exception as e:
+            if think:
+                yield f"  ✗ {e}\n\n\n"
+            yield f"Qdrant search failed: {e}"
+            return
+
+        if think:
+            yield f"  ✓ Found {len(qdrant_results)} chunks ({time.time() - t0:.2f}s)\n"
+
+        if not qdrant_results:
+            if think:
+                yield "  ✗ No results\n\n\n"
+            yield "No relevant notes found for this query."
+            return
+
+        # Show top 5
+        if think:
+            yield "  Top 5:\n"
+            for i, r in enumerate(qdrant_results[:5]):
+                name = r.get("payload", {}).get("fileName", "?")
+                score = r.get("score", 0)
+                yield f"    {i+1}. [{score:.4f}] {name}\n"
+            yield "\n"
+
+        # ─────────────────────────────────────────────
+        # Step 4: Rerank (optional)
+        # ─────────────────────────────────────────────
+        if self.valves.rerank_enabled:
+            if think:
+                yield "**Step 4: Reranking**\n"
+            t0 = time.time()
+
+            docs_for_rerank = [
+                r.get("payload", {}).get("content", "") for r in qdrant_results
+            ]
+
+            try:
+                async with session.post(
+                    f"{self.valves.rerank_url}/rerank",
+                    json={
+                        "query": search_query,
+                        "documents": docs_for_rerank,
+                        "return_documents": False,
+                    },
+                    timeout=aiohttp.ClientTimeout(total=self.valves.rerank_timeout),
+                ) as resp:
+                    if resp.status != 200:
+                        if think:
+                            yield f"  ⚠ Reranker failed: HTTP {resp.status}, using Qdrant order\n\n"
+                        chunks = qdrant_results[: self.valves.final_top_k]
+                    else:
+                        rerank_results = (await resp.json()).get("results", [])
+
+                        # Apply rerank scores and filter
+                        scored = []
+                        for item in rerank_results:
+                            idx = item["index"]
+                            score = item["relevance_score"]
+                            if score >= self.valves.min_rerank_score:
+                                chunk = qdrant_results[idx].copy()
+                                chunk["rerank_score"] = score
+                                scored.append(chunk)
+
+                        scored.sort(key=lambda x: x["rerank_score"], reverse=True)
+                        chunks = scored[: self.valves.final_top_k]
+
+                        if think:
+                            yield f"  ✓ Kept {len(chunks)} chunks ({time.time() - t0:.2f}s)\n"
+
+                            if chunks:
+                                yield "  Top 5 after rerank:\n"
+                                for i, c in enumerate(chunks[:5]):
+                                    name = c.get("payload", {}).get("fileName", "?")
+                                    score = c.get("rerank_score", 0)
+                                    yield f"    {i+1}. [{score:.4f}] {name}\n"
+                                yield "\n"
+
+            except Exception as e:
+                if think:
+                    yield f"  ⚠ Reranker error: {e}, using Qdrant order\n\n"
+                chunks = qdrant_results[: self.valves.final_top_k]
+
+        else:
+            if think:
+                yield "**Step 4: Reranking** (disabled)\n\n"
+            chunks = qdrant_results[: self.valves.final_top_k]
+
+        if not chunks:
+            if think:
+                yield "  ✗ No chunks after filtering\n\n\n"
+            yield "No relevant notes passed the relevance threshold."
+            return
+
+        # ─────────────────────────────────────────────
+        # Step 5: Build context
+        # ─────────────────────────────────────────────
+        if think:
+            yield "**Step 5: Build Context**\n"
+
+        context_parts = []
+        for i, chunk in enumerate(chunks, 1):
+            payload = chunk.get("payload", {})
+            file_name = payload.get("fileName", "Unknown")
+            content = payload.get("content", "").strip()
+            source = payload.get("source", "")
+
+            part = f"### Note {i}: {file_name}\n"
+            if source:
+                part += f"Original source: {source}\n"
+            part += f"\n{content}"
+            context_parts.append(part)
+
+        context = "\n\n---\n\n".join(context_parts)
+        context_chars = len(context)
+        estimated_tokens = context_chars // 4
+
+        if think:
+            yield f"  ✓ {len(chunks)} chunks, {context_chars:,} chars (~{estimated_tokens:,} tokens)\n"
+
+            if estimated_tokens > self.valves.token_warning_threshold:
+                yield f"  ⚠ Warning: approaching context limit ({self.valves.llm_context_size})\n"
+
+            yield "\n"
+
+        # ─────────────────────────────────────────────
+        # Step 6: Build prompt and call LLM
+        # ─────────────────────────────────────────────
+        if think:
+            yield "**Step 6: Generate Response**\n"
+            yield "\n\n"
+
+        system_prompt = f"""You are a specialized Research Assistant. Your goal is to synthesize information from the provided user notes.
+
+        ### INSTRUCTIONS
+        1. **Primary Source:** Answer the user's question using strictly the content found in the SOURCE DATA section below.
+        2. **Citation:** Every claim you make must be immediately followed by a citation in this format: [Note Name].
+        3. **Missing Info:** If the notes do not contain the answer, explicitly state: "Your notes don't cover this." Do not attempt to guess or hallucinate an answer.
+        4. **Exception Handling (Outside Knowledge):**
+           - You are generally FORBIDDEN from using outside knowledge.
+           - **HOWEVER**, if the user explicitly asks for external context (e.g., "What am I missing?", "Add outside context"), you may provide it.
+           - If you trigger this exception, you must prefix that specific part of the response with: "**Outside Context:**".
+
+        ### FORMATTING
+        - Be concise and direct.
+        - Use bullet points for lists.
+
+        ### SOURCE DATA
+
+        {context}
+
+        """
+
+        # Only keep user/assistant messages
+        conversation = [m for m in messages if m.get("role") in ("user", "assistant")]
+
+        llm_payload = {
+            "model": self.valves.llm_model,
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                *conversation,
+            ],
+            "stream": True,
+            "options": {"num_ctx": self.valves.llm_context_size},
+        }
+
+        # Stream LLM response
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        try:
+            async with session.post(
+                f"{self.valves.ollama_url}/api/chat",
+                json=llm_payload,
+                timeout=aiohttp.ClientTimeout(total=self.valves.llm_timeout),
+            ) as resp:
+                if resp.status != 200:
+                    yield f"LLM error: HTTP {resp.status}"
+                    return
+
+                async for line in resp.content:
+                    if not line:
+                        continue
+                    try:
+                        data = json.loads(line)
+                        if text := data.get("message", {}).get("content"):
+                            yield text
+                        if data.get("done"):
+                            prompt_tokens = data.get("prompt_eval_count", 0)
+                            completion_tokens = data.get("eval_count", 0)
+                    except json.JSONDecodeError:
+                        continue
+
+        except asyncio.TimeoutError:
+            yield "\n\n⚠️ LLM timed out"
+            return
+        except Exception as e:
+            yield f"\n\nLLM error: {e}"
+            return
+
+        # ─────────────────────────────────────────────
+        # Sources
+        # ─────────────────────────────────────────────
+        if self.valves.show_sources:
+            # Dedupe by file path, count chunks
+            source_counts: dict[str, dict] = {}
+            for chunk in chunks:
+                payload = chunk.get("payload", {})
+                path = payload.get("filePath", "")
+                name = payload.get("fileName", "Unknown")
+                if path in source_counts:
+                    source_counts[path]["count"] += 1
+                else:
+                    source_counts[path] = {"name": name, "path": path, "count": 1}
+
+            yield "\n\n---\n**Sources:**\n"
+            for src in source_counts.values():
+                vault = urllib.parse.quote(self.valves.vault_name)
+                path = urllib.parse.quote(src["path"])
+                uri = f"obsidian://open?vault={vault}&file={path}"
+                count_str = f" ({src['count']} chunks)" if src["count"] > 1 else ""
+                yield f"- [{src['name']}]({uri}){count_str}\n"
+
+        # ─────────────────────────────────────────────
+        # Stats
+        # ─────────────────────────────────────────────
+        if self.valves.show_stats:
+            yield f"\n*{prompt_tokens:,} in / {completion_tokens:,} out*"
+
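
Usage sketch: the snippet below shows one way the pipe could be driven outside Open WebUI for local testing, assuming the Ollama, Qdrant, and reranker endpoints configured in the valves are reachable. The file name local_test.py and the request body are hypothetical stand-ins for what Open WebUI would normally construct; only the Pipe class, its valves, and the pipe(body) signature come from the diff above.

# local_test.py - hypothetical local harness (assumes live Ollama/Qdrant/reranker backends)
import asyncio

from ObsidianRAGPipe import Pipe


async def main() -> None:
    pipe = Pipe()
    # Valves can be overridden before the call, e.g. to skip the reranker.
    pipe.valves.rerank_enabled = False

    # Hypothetical body mimicking what Open WebUI passes to pipe().
    body = {
        "messages": [
            {"role": "user", "content": "What do my notes say about backups?"},
        ]
    }

    # Stream the pipe output (thinking lines, answer, sources, stats) to stdout.
    async for chunk in pipe.pipe(body):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())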