commit 6e852871a66e724706ed67f3f3fc2e7f37f36a1b Author: samsamfin Date: Sun Apr 5 05:20:44 2026 +0000 Upload files to "/" diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..12eb648 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,58 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is an **Open WebUI Pipeline** (`llm_router_v3.py`) that acts as an intelligent LLM router. It classifies user prompts and routes them to different Ollama models based on intent, with integrated web search and image generation. + +## Architecture + +Single-file pipeline (`llm_router_v3.py`) that runs inside Open WebUI's pipelines container. The flow is: + +1. **Task detection** — Open WebUI internal requests (title/tag generation) bypass routing and go to qwen2.5:7b directly +2. **Vision detection** — checks if the latest user message contains an uploaded image +3. **AI classification** — qwen2.5:7b classifies prompts into: coding, diagram, reasoning, image_generation, vision, general +4. **Heuristic safety net** — keyword/pattern-based overrides can force search=true even if AI said no +5. **Web search** — Brave Search API with full page content fetching for top 3 results +6. **Image generation** — AUTOMATIC1111/Forge API via Stable Diffusion XL, with LLM-refined prompts +7. **VRAM management** — automatically unloads Ollama models before SD generation and unloads SD checkpoint after, plus drops page cache to free RAM +8. **Streaming response** — streams model output including thinking/reasoning tokens in collapsible blocks + +### Model Routing + +| Category | Model | Notes | +|---|---|---| +| coding | qwen2.5-coder:14b | | +| diagram | qwen2.5-coder:14b | Mermaid output | +| reasoning (FI/EN) | gpt-oss:120b | Finnish detection via keyword scoring | +| image_generation | gpt-oss:120b → SDXL | LLM refines prompt, then calls A1111 API | +| vision | llama3.2-vision:11b | Only when latest user message has image | +| general | gpt-oss:120b | | + +### Key Design Decisions + +- **Finnish/English bilingual** — Finnish detected by scoring FINNISH_INDICATORS (threshold ≥ 2 matches). Reasoning routes to language-specific system prompts. +- **Search is aggressive** — heuristic layer ensures search triggers for questions with named entities, freshness keywords, time-sensitive topics, even if AI classifier says no. +- **Year injection** — search queries have wrong years replaced with current year to counter LLM hallucination. +- **Image generation VRAM dance** — RTX 2000 Ada 16GB can't hold both gpt-oss:120b and SDXL simultaneously. Pipeline unloads Ollama before SD, unloads SD after, and drops Linux page cache. +- **Chunked image streaming** — base64 images are compressed PNG→JPEG and yielded in 4KB chunks to avoid Open WebUI "chunk too big" errors. + +## Deployment + +- **Open WebUI**: Docker container on `ai-stack_default` network +- **Ollama**: Native on host (not Docker), reached via `http://ollama:11434` from containers +- **AUTOMATIC1111 Forge**: Native on host, systemd service `stable-diffusion`, reached via `http://172.18.0.1:7860` (Docker bridge gateway) +- **Server**: Ubuntu 22.04 LTS, NVIDIA RTX 2000 Ada 16GB + +Pipeline is deployed by copying `llm_router_v3.py` to `~/ai-stack/pipelines/` on the server and restarting the pipelines container. + +## Setup Scripts + +- `setup-sd.sh` — installs AUTOMATIC1111 Forge + downloads SDXL model (Ubuntu 22.04 specific) +- `setup-sd-service.sh` — creates systemd service for Forge (run after setup-sd.sh) + +## Configuration + +All runtime settings are exposed as **Valves** in Open WebUI's pipeline settings UI: +`ollama_url`, `sd_url`, `sd_width/height/steps/cfg_scale`, `brave_api_key`, `brave_max_results`, `use_ai_classifier`, `show_routing_info`, `search_context_max_chars` diff --git a/README.md b/README.md new file mode 100644 index 0000000..84157f1 --- /dev/null +++ b/README.md @@ -0,0 +1,238 @@ +# LLM Router Pipeline for Open WebUI + +An intelligent prompt classification and routing pipeline for [Open WebUI](https://github.com/open-webui/open-webui). Classifies user prompts using AI (qwen2.5:7b) and routes them to specialized Ollama models, with integrated Brave web search, image generation via Stable Diffusion, and full Finnish/English bilingual support. + +## Features + +- **AI-powered prompt classification** with keyword-based fallback +- **Model routing** — coding, diagram, reasoning, vision, image generation, and general categories +- **Brave web search** with full page content fetching (top 3 results scraped) +- **Heuristic search overrides** — safety net that forces search for time-sensitive or factual questions +- **Image generation** via AUTOMATIC1111/Forge (Stable Diffusion XL) with LLM-refined prompts +- **VRAM management** — automatically juggles GPU memory between Ollama and Stable Diffusion +- **Bilingual** — detects Finnish and forces responses in the correct language +- **Thinking/reasoning display** — streams model thinking tokens in collapsible blocks +- **Real-time search status** — shows which URLs are being fetched as search runs + +## Model Routing + +| Category | Model (120B) | Model (20B) | Trigger | +|---|---|---|---| +| coding | qwen2.5-coder:14b | qwen2.5-coder:14b | User asks to write/fix/debug code | +| diagram | qwen2.5-coder:14b | qwen2.5-coder:14b | Mermaid, flowchart, UML requests | +| reasoning (FI) | gpt-oss:120b | gpt-oss:20b | Analysis, comparison, strategy (Finnish) | +| reasoning (EN) | gpt-oss:120b | gpt-oss:20b | Analysis, comparison, strategy (English) | +| image generation | gpt-oss:120b + SDXL | gpt-oss:20b + SDXL | "generate an image", "luo kuva" | +| vision | llama3.2-vision:11b | llama3.2-vision:11b | User uploads an image | +| general | gpt-oss:120b | gpt-oss:20b | Everything else | + +Two pipeline variants are provided: +- **`llm_router_v3.py`** — uses gpt-oss:120b (higher quality, more VRAM/RAM) +- **`llm_router-20b.py`** — uses gpt-oss:20b (lighter, better for constrained hardware) + +## Prerequisites + +- **Ubuntu 22.04 LTS** (tested) +- **NVIDIA GPU** with 16GB+ VRAM (tested on RTX 2000 Ada) +- **Open WebUI** running in Docker with pipelines enabled +- **Ollama** installed natively with models pulled: + ```bash + ollama pull qwen2.5:7b + ollama pull qwen2.5-coder:14b + ollama pull gpt-oss:120b # or gpt-oss:20b for the lighter variant + ollama pull llama3.2-vision:11b + ``` +- **Brave Search API key** (free tier: https://brave.com/search/api/) + +## Setup + +### 1. Deploy the Pipeline + +Copy your chosen pipeline file to the Open WebUI pipelines directory: + +```bash +cp llm_router_v3.py ~/ai-stack/pipelines/ +# or for the 20B variant: +cp llm_router-20b.py ~/ai-stack/pipelines/ +``` + +Restart the pipelines container: + +```bash +docker restart pipelines +``` + +### 2. Configure Valves in Open WebUI + +Go to **Admin Panel > Pipelines** in Open WebUI and configure: + +| Setting | Description | Default | +|---|---|---| +| `ollama_url` | Ollama API URL | `http://ollama:11434` | +| `sd_url` | Stable Diffusion API URL | `http://172.18.0.1:7860` | +| `brave_api_key` | Brave Search API key | (from env `BRAVE_API_KEY`) | +| `sd_width` / `sd_height` | Generated image dimensions | 1024 x 1024 | +| `sd_steps` | Sampling steps | 25 | +| `sd_cfg_scale` | CFG scale | 7.0 | +| `brave_max_results` | Number of search results | 6 | +| `use_ai_classifier` | Use AI vs keyword-only classification | true | +| `show_routing_info` | Show routing banner in responses | true | +| `search_context_max_chars` | Max search context size | 12000 | + +### 3. Set Up Stable Diffusion (Image Generation) + +> Skip this section if you don't need image generation. + +#### Install Forge (AUTOMATIC1111 fork) + +```bash +# Install system dependencies +sudo apt-get update +sudo apt-get install -y git wget python3-venv python3-pip \ + libgl1 libglib2.0-0 libsm6 libxrender1 libxext6 libffi-dev libssl-dev + +# Clone Forge +git clone https://github.com/lllyasviel/stable-diffusion-webui-forge.git ~/stable-diffusion-webui +cd ~/stable-diffusion-webui + +# Download SDXL model (~6.9GB) +mkdir -p models/Stable-diffusion +wget -O models/Stable-diffusion/sd_xl_base_1.0.safetensors \ + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors" +``` + +#### Fix Python 3.10 build issues (Ubuntu 22.04) + +Before the first launch, pre-install CLIP dependencies to avoid build failures: + +```bash +cd ~/stable-diffusion-webui +# First launch creates the venv — run it once, let it fail, then fix: +./webui.sh --api --listen --xformers --no-half-vae || true + +# Fix CLIP build issue +venv/bin/pip install "setuptools<70" wheel +venv/bin/pip install --no-build-isolation \ + https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip + +# Launch again +./webui.sh --api --listen --xformers --no-half-vae +``` + +#### Select SDXL model + +Once the UI is running, open it in a browser and select `sd_xl_base_1.0` from the checkpoint dropdown. Or via API: + +```bash +curl -X POST http://localhost:7860/sdapi/v1/options \ + -H "Content-Type: application/json" \ + -d '{"sd_model_checkpoint": "sd_xl_base_1.0.safetensors"}' +``` + +#### Create a systemd service + +```bash +chmod +x setup-sd-service.sh +sudo ./setup-sd-service.sh +``` + +Or manually: + +```bash +sudo tee /etc/systemd/system/stable-diffusion.service > /dev/null < --format '{{range .IPAM.Config}}{{.Gateway}}{{end}}' +``` + +Verify connectivity from inside the container: + +```bash +docker exec open-webui curl -s http://172.18.0.1:7860/sdapi/v1/sd-models +``` + +## VRAM Management + +On a single 16GB GPU, gpt-oss:120b and SDXL cannot be loaded simultaneously. The pipeline handles this automatically: + +1. **Before image generation**: unloads all Ollama models from VRAM +2. **After image generation**: unloads SD checkpoint from VRAM and drops Linux page cache +3. Ollama reloads the model on the next chat request (~10-15s warm-up) + +If Ollama fails to load after image generation with a memory error, clear the page cache: + +```bash +sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches' +``` + +## Architecture + +``` +User Message + │ + ├─ Image uploaded? ──────────────── → llama3.2-vision:11b + │ + ├─ AI Classifier (qwen2.5:7b) + │ │ + │ ├─ coding ──────────────── → qwen2.5-coder:14b + │ ├─ diagram ─────────────── → qwen2.5-coder:14b (Mermaid) + │ ├─ reasoning ───────────── → gpt-oss:120b (FI/EN system prompt) + │ ├─ image_generation ────── → gpt-oss:120b (refine) → SDXL (generate) + │ └─ general ─────────────── → gpt-oss:120b + │ + ├─ Heuristic Search Override + │ │ + │ └─ Brave Search + page fetch (if needed) + │ + └─ Stream response (with thinking tokens) +``` + +## Files + +| File | Description | +|---|---| +| `llm_router_v3.py` | Main pipeline (gpt-oss:120b) | +| `llm_router-20b.py` | Lighter pipeline variant (gpt-oss:20b) | +| `setup-sd.sh` | Stable Diffusion Forge install script | +| `setup-sd-service.sh` | systemd service creation script | + +## License + +MIT diff --git a/llm_router-20b.py b/llm_router-20b.py new file mode 100644 index 0000000..c1086e9 --- /dev/null +++ b/llm_router-20b.py @@ -0,0 +1,1070 @@ +""" +LLM Router Pipeline — v3.0 + +AI-powered prompt classifier + aggressive web search with full page fetching. +Uses qwen2.5:7b as zero-shot classifier backed by robust heuristic overrides. + +Routing: + - coding → qwen2.5-coder:14b + - diagram → qwen2.5-coder:14b + - reasoning (FI) → gpt-oss:20b + - reasoning (EN) → gpt-oss:20b + - image_generation → gpt-oss:20b (prompt refinement) → AUTOMATIC1111 (Stable Diffusion) + - vision → llama3.2-vision:11b + - general → gpt-oss:20b + +Web search is triggered aggressively — any question that might benefit from +current information will trigger a Brave Search + page content fetch. +""" + +from typing import List, Iterator, Optional +from pydantic import BaseModel, Field +from datetime import date, datetime +import requests +import json +import re +import os +import base64 +from io import BytesIO + +# --------------------------------------------------------------------------- +# Configuration defaults +# --------------------------------------------------------------------------- +OLLAMA_URL = "http://ollama:11434" +CLASSIFIER_MODEL = "qwen2.5:7b" +BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "") +BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search" + +MODELS = { + "coding": "qwen2.5-coder:14b", + "diagram": "qwen2.5-coder:14b", + "reasoning_en": "gpt-oss:20b", + "reasoning_fi": "gpt-oss:20b", + "vision": "llama3.2-vision:11b", + "image_generation": "gpt-oss:20b", + "general": "gpt-oss:20b", +} + +SD_URL = "http://172.18.0.1:7860" + +SYSTEM_PROMPTS = { + "image_generation": ( + "You are a Stable Diffusion prompt engineer. The user wants an image generated. " + "Your ONLY job is to convert their request into an optimized Stable Diffusion prompt. " + "Output ONLY the prompt text, nothing else — no explanation, no markdown, no quotes. " + "Use comma-separated tags and descriptors. Include quality boosters like: " + "masterpiece, best quality, highly detailed, sharp focus, professional. " + "Include style, lighting, and composition details. " + "If the user writes in Finnish, still output the SD prompt in English." + ), + "coding": ( + "You are an expert programmer and DevOps engineer. " + "Provide clean, well-commented code. Use best practices. " + "If the user writes in Finnish, respond in Finnish but keep code and comments in English." + ), + "diagram": ( + "You are a technical architect who communicates through diagrams. " + "Always respond with a Mermaid diagram inside a ```mermaid code block. " + "After the diagram, provide a short explanation. " + "If the user writes in Finnish, the explanation should be in Finnish. " + "Use flowchart TD, sequenceDiagram, classDiagram, erDiagram or gantt as appropriate." + ), + "reasoning_en": ( + "Reasoning: high\n" + "You are an analytical expert. Think step by step, consider multiple perspectives, " + "and provide well-structured, thorough responses." + ), + "reasoning_fi": ( + "Olet analyyttinen asiantuntija. Ajattele vaihe vaiheelta, harkitse useita näkökulmia " + "ja anna hyvin jäsennelty, perusteellinen vastaus suomen kielellä." + ), + "vision": ( + "You are a visual analysis assistant. Describe and analyze images in detail. " + "If the user writes in Finnish, respond in Finnish." + ), + "general": ( + "You are a helpful, friendly assistant. You speak both Finnish and English fluently — " + "always respond in the same language the user is using. " + "Be concise but thorough." + ), +} + +# --------------------------------------------------------------------------- +# Finnish language detection +# --------------------------------------------------------------------------- +FINNISH_INDICATORS = [ + "ä", "ö", "analysoi", "vertaile", "arvioi", "selitä", "miksi", "miten", + "mitä", "kuinka", "kerro", "anna", "tee", "kirjoita", "auta", "voitko", + "pitäisi", "kannattaa", "parempi", "huonompi", "ero", "hyödyt", "haitat", + "onko", "missä", "milloin", "paljonko", "kuka", "mikä", +] + + +def detect_finnish(text: str) -> bool: + t = text.lower() + return sum(1 for ind in FINNISH_INDICATORS if ind in t) >= 2 + + +def resolve_reasoning_key(prompt: str) -> str: + return "reasoning_fi" if detect_finnish(prompt) else "reasoning_en" + + +# --------------------------------------------------------------------------- +# Search-need heuristics (the "safety net" that fires BEFORE the AI classifier) +# +# The philosophy: if in doubt, search. A redundant search costs ~200 ms. +# A missed search costs the user a hallucinated or stale answer. +# --------------------------------------------------------------------------- + +# Question-word patterns (EN + FI) +_QUESTION_WORDS_EN = ( + r"\b(?:what|which|who|whom|whose|where|when|why|how|is|are|was|were|do|does|did" + r"|can|could|will|would|should|has|have|had)\b" +) +_QUESTION_WORDS_FI = ( + r"\b(?:mikä|mitä|kuka|kenen|missä|minne|mistä|milloin|miksi|miten|kuinka" + r"|onko|ovatko|oliko|voiko|pitääkö|saako|paljonko|kumpi|montako)\b" +) + +# Time-sensitive / factual trigger words +_FRESHNESS_KEYWORDS = { + # English + "latest", "newest", "current", "recent", "update", "updated", "today", + "yesterday", "this week", "this month", "this year", "2024", "2025", "2026", + "now", "right now", "currently", "as of", "still", "anymore", + "release", "released", "version", "price", "cost", "stock", "weather", + "score", "election", "news", "announcement", "launched", "launches", + "available", "discontinued", "deprecated", "roadmap", "deadline", + "schedule", "status", "outage", "incident", + # Finnish + "uusin", "viimeisin", "nykyinen", "tuorein", "päivitetty", "tänään", + "eilen", "tällä viikolla", "tässä kuussa", "tänä vuonna", + "hinta", "sää", "tulos", "uutinen", "julkaistu", "julkaisu", + "versio", "saatavilla", "poistunut", "aikataulu", "tilanne", +} + +# Patterns that almost always need search +_SEARCH_PATTERNS = [ + # "What is the latest X", "Who is the CEO of Y", etc. + re.compile(r"(?:what|which|who|where|when)\s+(?:is|are|was|were)\s+(?:the\s+)?(?:latest|current|new)", re.I), + # "How much does X cost" + re.compile(r"how\s+much\s+(?:does|do|is|are|will)", re.I), + # URLs or domain names in the prompt → user is asking about a specific site/service + re.compile(r"https?://|www\.", re.I), + # Explicit search requests + re.compile(r"\b(?:search|google|look\s*up|find\s+out|hae|etsi|googlaa)\b", re.I), + # "Is X still Y" / "Does X still Y" + re.compile(r"\b(?:is|does|are|do)\s+\w+\s+still\b", re.I), + # Finnish question forms ending with -ko/-kö + re.compile(r"\b\w+(?:ko|kö)\b", re.I), +] + +# Topics that are inherently time-sensitive +_SEARCH_TOPIC_PATTERNS = [ + re.compile(r"\b(?:CVE|vulnerability|exploit|zero[- ]?day|security\s+(?:patch|update|advisory))\b", re.I), + re.compile(r"\b(?:LTS|EOL|end[- ]of[- ]life|support\s+(?:lifecycle|end))\b", re.I), + re.compile(r"\b(?:Windows\s+1[12]|Ubuntu\s+\d|macOS\s+\w+|iOS\s+\d|Android\s+\d)\b", re.I), + re.compile(r"\b(?:GPT|Claude|Gemini|Llama|Mistral|Qwen|DeepSeek)\s*[-.]?\s*\d", re.I), + re.compile(r"\b(?:Azure|AWS|GCP)\s+\w+", re.I), +] + + +def heuristic_needs_search(text: str) -> bool: + """Return True if heuristics say this message almost certainly needs a web search.""" + t_lower = text.lower() + + # 1. Any freshness keyword present? + for kw in _FRESHNESS_KEYWORDS: + if kw in t_lower: + return True + + # 2. Regex patterns + for pat in _SEARCH_PATTERNS: + if pat.search(text): + return True + + # 3. Time-sensitive topics + for pat in _SEARCH_TOPIC_PATTERNS: + if pat.search(text): + return True + + # 4. Question with a named entity (capitalised word after question word) → likely factual + if re.search(_QUESTION_WORDS_EN, t_lower) or re.search(_QUESTION_WORDS_FI, t_lower): + # If the message contains a proper noun (capitalized word that isn't sentence-start) + words = text.split() + if len(words) > 2: + interior_caps = [w for w in words[1:] if w[0:1].isupper() and not w.isupper()] + if interior_caps: + return True + + # 5. If the message is a question (ends with ?) and is not purely about coding + if text.rstrip().endswith("?"): + coding_signals = {"code", "script", "function", "error", "bug", "koodi", "skripti"} + if not any(sig in t_lower for sig in coding_signals): + return True + + return False + + +# --------------------------------------------------------------------------- +# Category classification +# --------------------------------------------------------------------------- + +CLASSIFIER_SYSTEM_TEMPLATE = """You are a prompt classifier. Output ONLY a JSON object, no other text. + +TODAY'S DATE: {today} + +Categories: +- coding: the user is explicitly asking you to WRITE, FIX, DEBUG, or REVIEW actual code, scripts, config files, or CLI commands. Mentioning IT, software, or technology in conversation is NOT coding — the user must be requesting code output. +- diagram: the user explicitly asks for a diagram, flowchart, architecture drawing, Mermaid, UML, ER diagram, or visual representation +- reasoning: deep analysis, comparisons, strategy, pros/cons, evaluations, "explain in detail", workplace situations, decision-making, persuasion, argumentation +- image_generation: user wants to CREATE/GENERATE/DRAW an image, picture, photo, illustration, artwork, logo +- vision: when the user references or asks about an EXISTING image/picture (already uploaded) +- general: everything else — simple questions, facts, casual conversation, advice, how-to, storytelling, venting about work situations + +IMPORTANT: When the user is discussing work situations, meetings, processes, or asking for advice about IT management/strategy, that is "general" or "reasoning" — NOT "coding". Only classify as "coding" when the user literally wants code written. + +Search decision — set search=true if ANY of these apply: +- The answer might change over time (versions, prices, dates, availability, people in roles) +- The question asks about a specific product, service, tool, or technology by name +- The question asks about events, news, announcements +- The question is about something that happened or will happen at a specific time +- The user asks "what is X" about something that isn't a timeless concept +- You are not 100% sure of the correct answer from general knowledge alone +Default to search=true when uncertain. + +Search query rules: +- Write a concise, effective web search query in the same language as the user. +- When the user says "yesterday", "today", "this week" etc., convert to actual dates based on TODAY'S DATE above. +- The current year is {year}. NEVER use any other year unless the user explicitly mentions one. +- If the topic involves versions or releases, include {year} in the query. + +Output format (ONLY this JSON, nothing else): +{{"category": "general", "search": true, "query": "search terms here"}} + +Examples: +{{"category": "coding", "search": false, "query": ""}} +{{"category": "coding", "search": true, "query": "Python 3.13 new features {year}"}} +{{"category": "general", "search": true, "query": "latest Ubuntu LTS release {year}"}} +{{"category": "reasoning", "search": false, "query": ""}} +{{"category": "general", "search": true, "query": "Microsoft Intune new features {year}"}} +{{"category": "diagram", "search": false, "query": ""}} +{{"category": "reasoning", "search": true, "query": "Azure Landing Zone vs hub-spoke comparison"}} +{{"category": "image_generation", "search": false, "query": ""}}""" + + +def _build_classifier_prompt() -> str: + today = date.today() + return CLASSIFIER_SYSTEM_TEMPLATE.format( + today=today.strftime("%Y-%m-%d"), + year=today.year, + ) + + +def _keyword_classify(text: str) -> str: + """Fast keyword-based fallback classifier.""" + t = text.lower() + + # Image generation — check first because it's the most specific + image_gen_kw = [ + "generate an image", "generate image", "create an image", "create image", + "draw me", "draw a picture", "make an image", "make a picture", + "generate a photo", "create a photo", "make a photo", + "generate art", "create art", "make art", + "generate a picture", "create a picture", + "luo kuva", "generoi kuva", "tee kuva", "piirrä kuva", + "luo valokuva", "tee valokuva", + ] + for kw in image_gen_kw: + if kw in t: + return "image_generation" + + # Diagram + diagram_kw = [ + "kaavio", "piirrä", "diagram", "flowchart", "sequence diagram", + "arkkitehtuurikuva", "mermaid", "vuokaavio", "uml", "er-kaavio", + "gantt", "mindmap", "visualisoi", "visualize", "draw a diagram", + "draw a chart", "draw a flow", + ] + for kw in diagram_kw: + if kw in t: + return "diagram" + + # Coding — be strict to avoid false positives + coding_kw = [ + "write code", "write a script", "kirjoita koodi", "kirjoita skripti", + "koodaa", "debug", "refactor", "fix this code", "fix this script", + "powershell", "bash script", "python script", "dockerfile", + "terraform", "ansible", "yaml config", "json schema", + "function that", "class that", "regex for", "kirjoita funktio", + ] + for kw in coding_kw: + if kw in t: + return "coding" + + # Strong reasoning signals (multi-word to reduce false positives) + reasoning_kw = [ + "analysoi yksityiskohtaisesti", "vertaile näitä", "arvioi", + "analyze in detail", "compare these", "pros and cons", + "hyvät ja huonot puolet", "explain in detail", "deep dive", + "selitä yksityiskohtaisesti", "strategia", "strategy for", + ] + for kw in reasoning_kw: + if kw in t: + return "reasoning" + + return "general" + + +def _ai_classify(prompt: str, ollama_url: str) -> dict: + """ + Call the classifier LLM. Returns dict with keys: category, search, query, method. + On failure returns a keyword-based fallback. + """ + fallback = { + "category": _keyword_classify(prompt), + "search": heuristic_needs_search(prompt), + "query": "", + "method": "fallback", + } + + try: + payload = { + "model": CLASSIFIER_MODEL, + "messages": [ + {"role": "system", "content": _build_classifier_prompt()}, + {"role": "user", "content": prompt[:1500]}, + ], + "stream": False, + "options": {"temperature": 0, "num_ctx": 2048}, + } + + resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=15) + resp.raise_for_status() + + raw = resp.json()["message"]["content"].strip() + # Strip markdown fences if the model wraps its output + raw = re.sub(r"^```(?:json)?\s*", "", raw) + raw = re.sub(r"\s*```$", "", raw) + + parsed = json.loads(raw) + category = parsed.get("category", "general").lower().strip() + search = bool(parsed.get("search", False)) + query = parsed.get("query", "").strip() + + if category not in ("coding", "diagram", "reasoning", "vision", "image_generation", "general"): + print(f"[Router] Unknown category '{category}', falling back to keyword") + category = _keyword_classify(prompt) + + return { + "category": category, + "search": search, + "query": query, + "method": "ai", + } + + except json.JSONDecodeError as e: + print(f"[Router] Classifier returned invalid JSON: {e}") + return fallback + except requests.RequestException as e: + print(f"[Router] Classifier request failed: {e}") + return fallback + except Exception as e: + print(f"[Router] Classifier error: {e}") + return fallback + + +def classify(prompt: str, ollama_url: str, use_ai: bool) -> dict: + """ + Main classification entry point. + Returns dict: category (resolved to model key), search, query, method. + + The heuristic search layer acts as a safety net: even if the AI says + search=false, heuristics can override to search=true. + """ + if use_ai: + result = _ai_classify(prompt, ollama_url) + else: + result = { + "category": _keyword_classify(prompt), + "search": False, + "query": "", + "method": "keyword", + } + + # --- Heuristic overrides --- + + # 1. Force search if heuristics say so (even if AI said no) + if not result["search"] and heuristic_needs_search(prompt): + result["search"] = True + result["method"] += "+heuristic_search" + print("[Router] Heuristic override: forcing search=true") + + # 2. Generate a search query if search is needed but query is empty + if result["search"] and not result["query"]: + result["query"] = _generate_search_query(prompt) + + # 3. Inject year into version/release queries + if result["search"] and result["query"]: + result["query"] = _inject_year(result["query"]) + + # 4. Resolve reasoning → reasoning_fi / reasoning_en + if result["category"] == "reasoning": + result["category"] = resolve_reasoning_key(prompt) + + return result + + +def _generate_search_query(prompt: str) -> str: + """Extract a reasonable search query from the user prompt.""" + # Take the first sentence or first 120 chars, whichever is shorter + first_sentence = re.split(r"[.!?\n]", prompt)[0].strip() + query = first_sentence[:120] + # Remove filler words for a tighter query + for filler in ["please", "can you", "could you", "tell me", "I want to know", + "voitko", "kerro", "haluaisin tietää", "ole hyvä"]: + query = re.sub(re.escape(filler), "", query, flags=re.IGNORECASE) + return query.strip() or prompt[:80] + + +_VERSION_KEYWORDS = [ + "latest", "newest", "current", "recent", "release", "version", + "uusin", "viimeisin", "nykyinen", "tuorein", "versio", "julkaisu", +] + + +def _inject_year(query: str) -> str: + """ + Fix year issues in search queries: + 1. Replace any hallucinated wrong year (e.g. 2023, 2024) with the current year + 2. Append the current year if version/release keywords are present but no year is + """ + current_year = date.today().year + current_str = str(current_year) + + # Step 1: Replace wrong years that the classifier may have hallucinated. + # We consider any 4-digit year from 2020 to current_year-1 as potentially wrong, + # UNLESS the user's original query explicitly contained that year (we can't check + # that here, but the classifier prompt now tells it the correct year, so a wrong + # year in the query is almost certainly hallucinated). + wrong_year_pattern = re.compile(r"\b(20(?:2[0-9]|3[0-9]))\b") + def _replace_year(m: re.Match) -> str: + y = int(m.group(1)) + if y != current_year and y < current_year: + return current_str + return m.group(0) + + fixed = wrong_year_pattern.sub(_replace_year, query) + + # Step 2: If no year present at all and version keywords exist, append current year + if current_str not in fixed: + if any(kw in fixed.lower() for kw in _VERSION_KEYWORDS): + fixed = f"{fixed} {current_str}" + + return fixed + + +# --------------------------------------------------------------------------- +# Brave Search with page content fetching +# --------------------------------------------------------------------------- + +def brave_search(query: str, api_key: str, max_results: int = 6, status_callback=None) -> str: + """ + Search Brave and fetch top page contents for richer context. + Returns a formatted string suitable for injection into the system prompt. + status_callback: optional function(str) called with real-time status updates. + """ + def _status(msg: str): + if status_callback: + status_callback(msg) + + if not api_key: + return "⚠️ Brave API key not configured." + + try: + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": api_key, + } + params = { + "q": query, + "count": max_results, + "text_decorations": False, + "search_lang": "fi" if detect_finnish(query) else "en", + } + + _status(f"Searching: *{query}*") + resp = requests.get(BRAVE_SEARCH_URL, headers=headers, params=params, timeout=10) + resp.raise_for_status() + + data = resp.json() + web_results = data.get("web", {}).get("results", []) + + if not web_results: + return f"No web results found for: {query}" + + _status(f"Found {len(web_results)} results") + + # Also grab any infobox / knowledge graph snippet + infobox_text = "" + infobox = data.get("infobox", {}) + if isinstance(infobox, dict): + long_desc = infobox.get("long_desc", "") + if long_desc: + infobox_text = f"Knowledge panel: {long_desc}\n\n" + + # Build results with page content fetching for top 3 + sections = [] + if infobox_text: + sections.append(infobox_text) + + for i, r in enumerate(web_results): + title = r.get("title", "") + url = r.get("url", "") + desc = r.get("description", "") + age = r.get("age", "") + age_str = f" ({age})" if age else "" + + # Fetch full page content for the top 3 results + page_content = "" + if i < 3: + _status(f"Reading [{i+1}/{min(3, len(web_results))}]: [{title}]({url})") + page_content = _fetch_page_content(url) + + section = f"[{i+1}] {title}{age_str}\nURL: {url}\nSnippet: {desc}" + if page_content: + section += f"\nContent:\n{page_content}" + sections.append(section) + + _status("Search complete") + return f"Web search results for: {query}\n{'='*60}\n\n" + "\n\n---\n\n".join(sections) + + except Exception as e: + print(f"[Router] Brave Search failed: {e}") + return f"Web search failed: {e}" + + +def _fetch_page_content(url: str, max_chars: int = 3000) -> str: + """Fetch and extract readable text from a URL. Returns truncated plain text.""" + try: + headers = { + "User-Agent": "Mozilla/5.0 (compatible; LLMRouter/3.0)", + "Accept": "text/html,application/xhtml+xml", + } + resp = requests.get(url, headers=headers, timeout=8, allow_redirects=True) + resp.raise_for_status() + + content_type = resp.headers.get("Content-Type", "") + if "html" not in content_type and "text" not in content_type: + return "" + + html = resp.text + + # Lightweight HTML → text extraction (no BeautifulSoup dependency) + # Remove script, style, nav, header, footer tags and their contents + for tag in ["script", "style", "nav", "header", "footer", "aside", "noscript"]: + html = re.sub(rf"<{tag}[^>]*>.*?", " ", html, flags=re.S | re.I) + + # Remove all remaining HTML tags + text = re.sub(r"<[^>]+>", " ", html) + + # Decode HTML entities + text = text.replace("&", "&").replace("<", "<").replace(">", ">") + text = text.replace(""", '"').replace("'", "'").replace(" ", " ") + text = re.sub(r"&#\d+;", " ", text) + text = re.sub(r"&\w+;", " ", text) + + # Collapse whitespace + text = re.sub(r"\s+", " ", text).strip() + + if len(text) < 50: + return "" + + return text[:max_chars] + + except Exception: + return "" + + +# --------------------------------------------------------------------------- +# Stable Diffusion image generation +# --------------------------------------------------------------------------- + +def _refine_sd_prompt(user_message: str, ollama_url: str, messages: List[dict] = None) -> str: + """Use the LLM to convert a user request into an optimized SD prompt. + Includes conversation history so the model understands context like 'generate an image of that'. + """ + try: + # Build context from recent conversation history + context_messages = [{"role": "system", "content": SYSTEM_PROMPTS["image_generation"]}] + if messages: + # Include last few exchanges for context (trim to avoid blowing up the context) + recent = [m for m in messages if m.get("role") in ("user", "assistant") and m.get("content")] + for msg in recent[-6:]: # Last 3 exchanges + content = msg["content"] + if isinstance(content, list): + content = " ".join(p.get("text", "") for p in content if isinstance(p, dict)) + context_messages.append({"role": msg["role"], "content": content[:500]}) + else: + context_messages.append({"role": "user", "content": user_message[:500]}) + + payload = { + "model": MODELS["image_generation"], + "messages": context_messages, + "stream": False, + "options": {"temperature": 0.7, "num_ctx": 4096}, + } + resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=30) + resp.raise_for_status() + refined = resp.json()["message"]["content"].strip() + # Strip any accidental markdown or quotes the model might add + refined = refined.strip('"\'`') + refined = re.sub(r"^```\w*\s*", "", refined) + refined = re.sub(r"\s*```$", "", refined) + return refined + except Exception as e: + print(f"[Router] SD prompt refinement failed: {e}") + # Fallback: use the user message directly + return user_message + + +def _negative_prompt() -> str: + """Standard negative prompt for SD.""" + return ( + "lowres, bad anatomy, bad hands, text, error, missing fingers, " + "extra digit, fewer digits, cropped, worst quality, low quality, " + "normal quality, jpeg artifacts, signature, watermark, username, blurry, " + "deformed, distorted, disfigured, mutation, mutated, ugly" + ) + + +def _compress_image(b64_png: str, quality: int = 80) -> str: + """Convert a base64 PNG from SD to a smaller base64 JPEG.""" + try: + from PIL import Image + img_data = base64.b64decode(b64_png) + img = Image.open(BytesIO(img_data)) + if img.mode == "RGBA": + img = img.convert("RGB") + buf = BytesIO() + img.save(buf, format="JPEG", quality=quality, optimize=True) + return base64.b64encode(buf.getvalue()).decode("utf-8") + except Exception as e: + print(f"[Router] Image compression failed: {e}, using original") + return b64_png + + +def _unload_ollama_models(ollama_url: str): + """Unload all Ollama models from VRAM to make room for image generation.""" + try: + # List running models + resp = requests.get(f"{ollama_url}/api/ps", timeout=5) + if resp.ok: + models = resp.json().get("models", []) + for model in models: + name = model.get("name", "") + if name: + # Setting keep_alive to 0 unloads the model immediately + requests.post( + f"{ollama_url}/api/generate", + json={"model": name, "keep_alive": 0}, + timeout=10, + ) + print(f"[Router] Unloaded Ollama model: {name}") + except Exception as e: + print(f"[Router] Failed to unload Ollama models: {e}") + + +def _cleanup_after_generation(sd_url: str): + """Free VRAM and RAM after image generation so Ollama can load models.""" + # 1. Unload SD checkpoint from VRAM + try: + requests.post(f"{sd_url}/sdapi/v1/unload-checkpoint", timeout=5) + print("[Router] SD checkpoint unloaded from VRAM") + except Exception: + pass + + # 2. Drop Linux page cache to free RAM + try: + os.system("sync; echo 3 > /proc/sys/vm/drop_caches 2>/dev/null") + print("[Router] Page cache dropped") + except Exception: + pass + + +def generate_image( + user_message: str, + ollama_url: str, + sd_url: str, + width: int = 512, + height: int = 512, + steps: int = 30, + cfg_scale: float = 7.0, + messages: List[dict] = None, +) -> tuple: + """ + Generate an image via AUTOMATIC1111 API. + Returns (base64_image, refined_prompt) on success, or (None, error_message) on failure. + """ + # Step 1: Refine the prompt using the LLM FIRST (while Ollama is still loaded) + refined_prompt = _refine_sd_prompt(user_message, ollama_url, messages) + + # Step 2: Unload Ollama models from VRAM to make room for SDXL + _unload_ollama_models(ollama_url) + print(f"[Router] SD prompt: {refined_prompt[:120]}") + + # Step 2: Call AUTOMATIC1111 + try: + payload = { + "prompt": refined_prompt, + "negative_prompt": _negative_prompt(), + "width": width, + "height": height, + "steps": steps, + "cfg_scale": cfg_scale, + "sampler_name": "DPM++ 2M Karras", + "batch_size": 1, + "n_iter": 1, + } + resp = requests.post( + f"{sd_url}/sdapi/v1/txt2img", + json=payload, + timeout=120, + ) + resp.raise_for_status() + + data = resp.json() + images = data.get("images", []) + if not images: + return None, "Stable Diffusion returned no images." + + # Compress PNG→JPEG to reduce base64 size for streaming + compressed = _compress_image(images[0]) + + # Free VRAM and RAM so Ollama can load models again + _cleanup_after_generation(sd_url) + + return compressed, refined_prompt + + except requests.exceptions.ConnectionError: + return None, f"Cannot connect to Stable Diffusion at {sd_url}. Is it running?" + except requests.exceptions.Timeout: + return None, "Image generation timed out (>120s)." + except Exception as e: + return None, f"Image generation failed: {e}" + + +# --------------------------------------------------------------------------- +# Image handling (unchanged from v2) +# --------------------------------------------------------------------------- + +def extract_images_from_messages(messages: List[dict]) -> tuple: + """Separate base64 images from message content.""" + images = [] + clean_messages = [] + + for msg in messages: + content = msg.get("content", "") + if isinstance(content, list): + text_parts = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + text_parts.append(part.get("text", "")) + elif part.get("type") == "image_url": + url = part.get("image_url", {}).get("url", "") + if url.startswith("data:"): + match = re.match(r"data:[^;]+;base64,(.+)", url) + if match: + images.append(match.group(1)) + clean_messages.append({ + "role": msg["role"], + "content": " ".join(text_parts).strip(), + }) + else: + clean_messages.append(msg) + + return images, clean_messages + + +def has_image_content(messages: List[dict]) -> bool: + """Check if the latest user message contains an uploaded image. + Only checks user messages (not assistant responses which may contain generated images). + """ + # Find the last user message + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + return True + elif isinstance(content, str) and "data:image" in content: + return True + return False # Last user message found but has no image + return False + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + +class Pipeline: + class Valves(BaseModel): + ollama_url: str = Field(default=OLLAMA_URL, description="Ollama API base URL") + sd_url: str = Field(default=SD_URL, description="AUTOMATIC1111 Stable Diffusion WebUI URL") + sd_width: int = Field(default=1024, description="Generated image width") + sd_height: int = Field(default=1024, description="Generated image height") + sd_steps: int = Field(default=25, description="Stable Diffusion sampling steps") + sd_cfg_scale: float = Field(default=7.0, description="Stable Diffusion CFG scale") + brave_api_key: str = Field(default=BRAVE_API_KEY, description="Brave Search API key") + brave_max_results: int = Field(default=6, description="Number of Brave search results to fetch") + brave_fetch_pages: int = Field(default=3, description="Number of top results to fetch full page content for") + use_ai_classifier: bool = Field(default=True, description="Use AI classifier (vs keyword-only)") + show_routing_info: bool = Field(default=True, description="Show routing banner in responses") + search_context_max_chars: int = Field(default=12000, description="Max chars of search context to inject") + + def __init__(self): + self.id = "llm-router-20b" + self.name = "LLM Router v3 (20B)" + self.valves = self.Valves() + + async def on_startup(self): + print(f"[Router] LLM Router v3.0 starting — Ollama: {self.valves.ollama_url}") + print(f"[Router] Classifier: {CLASSIFIER_MODEL} | AI: {self.valves.use_ai_classifier}") + print(f"[Router] Brave Search: {'configured' if self.valves.brave_api_key else 'NO API KEY'}") + + async def on_shutdown(self): + print("[Router] LLM Router v3.0 shutting down") + + def pipe( + self, + user_message: str, + model_id: str, + messages: List[dict], + body: dict, + ) -> Iterator[str]: + + # --- Step 1: Vision override --- + if has_image_content(messages): + category = "vision" + needs_search = False + search_query = "" + method = "vision_detect" + else: + # --- Step 2: Classify --- + result = classify( + user_message, + self.valves.ollama_url, + self.valves.use_ai_classifier, + ) + category = result["category"] + needs_search = result["search"] + search_query = result["query"] + method = result["method"] + + target_model = MODELS.get(category, MODELS["general"]) + system_prompt = SYSTEM_PROMPTS.get(category, SYSTEM_PROMPTS["general"]) + + # Inject language instruction — always respond in the user's language + if detect_finnish(user_message) and category not in ("reasoning_fi", "image_generation"): + system_prompt = ( + "TÄRKEÄ: Käyttäjä kirjoittaa suomeksi. Vastaa AINA suomeksi.\n\n" + + system_prompt + ) + + print(f"[Router] {method} → {category} → {target_model} | search={needs_search} query='{search_query}'") + + # --- Step 3: Routing info banner --- + if self.valves.show_routing_info: + display_cat = category.replace("_en", " 🇬🇧").replace("_fi", " 🇫🇮") + search_label = f" | 🌐 `{search_query}`" if needs_search else "" + yield f"> 🔀 **Router** `[{method}]` → `{target_model}` *(category: {display_cat}){search_label}*\n\n" + + # --- Step 4: Image generation (early return) --- + if category == "image_generation": + yield "> 🎨 Generating image…\n\n" + base64_img, refined_prompt = generate_image( + user_message, + self.valves.ollama_url, + self.valves.sd_url, + width=self.valves.sd_width, + height=self.valves.sd_height, + steps=self.valves.sd_steps, + cfg_scale=self.valves.sd_cfg_scale, + messages=messages, + ) + if base64_img: + # Yield the image in chunks to avoid "chunk too big" errors + img_tag = f"![Generated image](data:image/jpeg;base64,{base64_img})" + chunk_size = 4096 + for i in range(0, len(img_tag), chunk_size): + yield img_tag[i:i + chunk_size] + yield "\n\n" + yield f"*Prompt used: {refined_prompt}*\n" + else: + yield f"\n\n❌ {refined_prompt}\n" + return + + # --- Step 5: Web search --- + search_context = "" + search_status_lines = [] + if needs_search and search_query and self.valves.brave_api_key: + # Collect status updates via callback, yield them in real time + def _on_status(msg: str): + search_status_lines.append(msg) + + yield "> 🔍 Searching the web…\n\n" + + # We need to yield status in real-time, so we run search in a thread + import threading + search_result = [None] + def _run_search(): + search_result[0] = brave_search( + search_query, + self.valves.brave_api_key, + max_results=self.valves.brave_max_results, + status_callback=_on_status, + ) + + t = threading.Thread(target=_run_search) + t.start() + + last_count = 0 + while t.is_alive(): + t.join(timeout=0.3) + # Yield any new status lines + while last_count < len(search_status_lines): + yield f"> {search_status_lines[last_count]}\n>\n" + last_count += 1 + # Yield any remaining status lines + while last_count < len(search_status_lines): + yield f"> {search_status_lines[last_count]}\n>\n" + last_count += 1 + + yield "\n\n" + search_context = search_result[0] or "" + # Truncate if too large + if len(search_context) > self.valves.search_context_max_chars: + search_context = search_context[:self.valves.search_context_max_chars] + "\n\n[...truncated]" + + # Detect failed search and warn the user + if (search_context.startswith("⚠️") + or search_context.startswith("Web search failed") + or search_context.startswith("No web results found")): + yield "> ⚠️ Web search failed — answering from model knowledge.\n\n" + print(f"[Router] Search failed: {search_context[:120]}") + else: + print(f"[Router] Search complete: {len(search_context)} chars") + + elif needs_search and not self.valves.brave_api_key: + yield "> ⚠️ Web search not available (no API key) — answering from model knowledge.\n\n" + print("[Router] Search needed but no Brave API key configured!") + + # --- Step 6: Build messages --- + images, clean_messages = extract_images_from_messages(messages) + + # Check if search actually returned usable results (not just an error) + search_ok = bool( + search_context + and not search_context.startswith("⚠️") + and not search_context.startswith("Web search failed") + and not search_context.startswith("No web results found") + ) + + if search_ok: + today = date.today().strftime("%Y-%m-%d") + full_system = ( + f"{system_prompt}\n\n" + f"Today's date: {today}\n\n" + f"## Web Search Results\n" + f"The following are fresh web search results. Use them as your PRIMARY source of truth.\n" + f"Your training data may be outdated — always prefer information from these results.\n" + f"When results conflict, prefer the most recent one (check dates/ages).\n" + f"Cite the source URL when stating specific facts.\n" + f"If the search results don't contain enough information to fully answer, " + f"say so honestly rather than guessing.\n\n" + f"{search_context}" + ) + elif needs_search: + # Search was requested but failed — tell the model so it can be honest + today = date.today().strftime("%Y-%m-%d") + full_system = ( + f"{system_prompt}\n\n" + f"Today's date: {today}\n\n" + f"NOTE: A web search was attempted for this question but failed or returned no results. " + f"Answer as best you can from your training data, but clearly state that you could not " + f"verify the information with a live web search and the answer may be outdated." + ) + else: + full_system = system_prompt + + ollama_messages = [{"role": "system", "content": full_system}] + for msg in clean_messages: + if msg.get("role") in ("user", "assistant") and msg.get("content"): + ollama_messages.append(msg) + + # Attach images to the last user message + if images: + for i in range(len(ollama_messages) - 1, -1, -1): + if ollama_messages[i]["role"] == "user": + ollama_messages[i]["images"] = images + break + + # --- Step 7: Call the target model --- + payload = { + "model": target_model, + "messages": ollama_messages, + "stream": True, + "options": { + "temperature": body.get("temperature", 0.7), + "num_ctx": 8192, + }, + } + + try: + resp = requests.post( + f"{self.valves.ollama_url}/api/chat", + json=payload, + stream=True, + timeout=180, + ) + resp.raise_for_status() + + in_thinking = False + for line in resp.iter_lines(): + if line: + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + + msg = chunk.get("message", {}) + + # Handle thinking/reasoning tokens (displayed in a collapsible block) + thinking_content = msg.get("thinking", "") + if thinking_content: + if not in_thinking: + yield "
\n💭 Thinking…\n\n" + in_thinking = True + yield thinking_content + + # Handle regular content + if msg.get("content"): + if in_thinking: + yield "\n
\n\n" + in_thinking = False + yield msg["content"] + + if chunk.get("done"): + if in_thinking: + yield "\n\n\n" + break + + except requests.exceptions.ConnectionError: + yield f"\n\n❌ Connection error to Ollama ({self.valves.ollama_url}). Is the service running?" + except requests.exceptions.Timeout: + yield "\n\n❌ Timeout — the model is responding too slowly." + except Exception as e: + yield f"\n\n❌ Error: {str(e)}" diff --git a/llm_router_v3.py b/llm_router_v3.py new file mode 100644 index 0000000..7f31cc5 --- /dev/null +++ b/llm_router_v3.py @@ -0,0 +1,1070 @@ +""" +LLM Router Pipeline — v3.0 + +AI-powered prompt classifier + aggressive web search with full page fetching. +Uses qwen2.5:7b as zero-shot classifier backed by robust heuristic overrides. + +Routing: + - coding → qwen2.5-coder:14b + - diagram → qwen2.5-coder:14b + - reasoning (FI) → gpt-oss:120b + - reasoning (EN) → gpt-oss:120b + - image_generation → gpt-oss:120b (prompt refinement) → AUTOMATIC1111 (Stable Diffusion) + - vision → llama3.2-vision:11b + - general → gpt-oss:120b + +Web search is triggered aggressively — any question that might benefit from +current information will trigger a Brave Search + page content fetch. +""" + +from typing import List, Iterator, Optional +from pydantic import BaseModel, Field +from datetime import date, datetime +import requests +import json +import re +import os +import base64 +from io import BytesIO + +# --------------------------------------------------------------------------- +# Configuration defaults +# --------------------------------------------------------------------------- +OLLAMA_URL = "http://ollama:11434" +CLASSIFIER_MODEL = "qwen2.5:7b" +BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "") +BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search" + +MODELS = { + "coding": "qwen2.5-coder:14b", + "diagram": "qwen2.5-coder:14b", + "reasoning_en": "gpt-oss:120b", + "reasoning_fi": "gpt-oss:120b", + "vision": "llama3.2-vision:11b", + "image_generation": "gpt-oss:120b", + "general": "gpt-oss:120b", +} + +SD_URL = "http://172.18.0.1:7860" + +SYSTEM_PROMPTS = { + "image_generation": ( + "You are a Stable Diffusion prompt engineer. The user wants an image generated. " + "Your ONLY job is to convert their request into an optimized Stable Diffusion prompt. " + "Output ONLY the prompt text, nothing else — no explanation, no markdown, no quotes. " + "Use comma-separated tags and descriptors. Include quality boosters like: " + "masterpiece, best quality, highly detailed, sharp focus, professional. " + "Include style, lighting, and composition details. " + "If the user writes in Finnish, still output the SD prompt in English." + ), + "coding": ( + "You are an expert programmer and DevOps engineer. " + "Provide clean, well-commented code. Use best practices. " + "If the user writes in Finnish, respond in Finnish but keep code and comments in English." + ), + "diagram": ( + "You are a technical architect who communicates through diagrams. " + "Always respond with a Mermaid diagram inside a ```mermaid code block. " + "After the diagram, provide a short explanation. " + "If the user writes in Finnish, the explanation should be in Finnish. " + "Use flowchart TD, sequenceDiagram, classDiagram, erDiagram or gantt as appropriate." + ), + "reasoning_en": ( + "Reasoning: high\n" + "You are an analytical expert. Think step by step, consider multiple perspectives, " + "and provide well-structured, thorough responses." + ), + "reasoning_fi": ( + "Olet analyyttinen asiantuntija. Ajattele vaihe vaiheelta, harkitse useita näkökulmia " + "ja anna hyvin jäsennelty, perusteellinen vastaus suomen kielellä." + ), + "vision": ( + "You are a visual analysis assistant. Describe and analyze images in detail. " + "If the user writes in Finnish, respond in Finnish." + ), + "general": ( + "You are a helpful, friendly assistant. You speak both Finnish and English fluently — " + "always respond in the same language the user is using. " + "Be concise but thorough." + ), +} + +# --------------------------------------------------------------------------- +# Finnish language detection +# --------------------------------------------------------------------------- +FINNISH_INDICATORS = [ + "ä", "ö", "analysoi", "vertaile", "arvioi", "selitä", "miksi", "miten", + "mitä", "kuinka", "kerro", "anna", "tee", "kirjoita", "auta", "voitko", + "pitäisi", "kannattaa", "parempi", "huonompi", "ero", "hyödyt", "haitat", + "onko", "missä", "milloin", "paljonko", "kuka", "mikä", +] + + +def detect_finnish(text: str) -> bool: + t = text.lower() + return sum(1 for ind in FINNISH_INDICATORS if ind in t) >= 2 + + +def resolve_reasoning_key(prompt: str) -> str: + return "reasoning_fi" if detect_finnish(prompt) else "reasoning_en" + + +# --------------------------------------------------------------------------- +# Search-need heuristics (the "safety net" that fires BEFORE the AI classifier) +# +# The philosophy: if in doubt, search. A redundant search costs ~200 ms. +# A missed search costs the user a hallucinated or stale answer. +# --------------------------------------------------------------------------- + +# Question-word patterns (EN + FI) +_QUESTION_WORDS_EN = ( + r"\b(?:what|which|who|whom|whose|where|when|why|how|is|are|was|were|do|does|did" + r"|can|could|will|would|should|has|have|had)\b" +) +_QUESTION_WORDS_FI = ( + r"\b(?:mikä|mitä|kuka|kenen|missä|minne|mistä|milloin|miksi|miten|kuinka" + r"|onko|ovatko|oliko|voiko|pitääkö|saako|paljonko|kumpi|montako)\b" +) + +# Time-sensitive / factual trigger words +_FRESHNESS_KEYWORDS = { + # English + "latest", "newest", "current", "recent", "update", "updated", "today", + "yesterday", "this week", "this month", "this year", "2024", "2025", "2026", + "now", "right now", "currently", "as of", "still", "anymore", + "release", "released", "version", "price", "cost", "stock", "weather", + "score", "election", "news", "announcement", "launched", "launches", + "available", "discontinued", "deprecated", "roadmap", "deadline", + "schedule", "status", "outage", "incident", + # Finnish + "uusin", "viimeisin", "nykyinen", "tuorein", "päivitetty", "tänään", + "eilen", "tällä viikolla", "tässä kuussa", "tänä vuonna", + "hinta", "sää", "tulos", "uutinen", "julkaistu", "julkaisu", + "versio", "saatavilla", "poistunut", "aikataulu", "tilanne", +} + +# Patterns that almost always need search +_SEARCH_PATTERNS = [ + # "What is the latest X", "Who is the CEO of Y", etc. + re.compile(r"(?:what|which|who|where|when)\s+(?:is|are|was|were)\s+(?:the\s+)?(?:latest|current|new)", re.I), + # "How much does X cost" + re.compile(r"how\s+much\s+(?:does|do|is|are|will)", re.I), + # URLs or domain names in the prompt → user is asking about a specific site/service + re.compile(r"https?://|www\.", re.I), + # Explicit search requests + re.compile(r"\b(?:search|google|look\s*up|find\s+out|hae|etsi|googlaa)\b", re.I), + # "Is X still Y" / "Does X still Y" + re.compile(r"\b(?:is|does|are|do)\s+\w+\s+still\b", re.I), + # Finnish question forms ending with -ko/-kö + re.compile(r"\b\w+(?:ko|kö)\b", re.I), +] + +# Topics that are inherently time-sensitive +_SEARCH_TOPIC_PATTERNS = [ + re.compile(r"\b(?:CVE|vulnerability|exploit|zero[- ]?day|security\s+(?:patch|update|advisory))\b", re.I), + re.compile(r"\b(?:LTS|EOL|end[- ]of[- ]life|support\s+(?:lifecycle|end))\b", re.I), + re.compile(r"\b(?:Windows\s+1[12]|Ubuntu\s+\d|macOS\s+\w+|iOS\s+\d|Android\s+\d)\b", re.I), + re.compile(r"\b(?:GPT|Claude|Gemini|Llama|Mistral|Qwen|DeepSeek)\s*[-.]?\s*\d", re.I), + re.compile(r"\b(?:Azure|AWS|GCP)\s+\w+", re.I), +] + + +def heuristic_needs_search(text: str) -> bool: + """Return True if heuristics say this message almost certainly needs a web search.""" + t_lower = text.lower() + + # 1. Any freshness keyword present? + for kw in _FRESHNESS_KEYWORDS: + if kw in t_lower: + return True + + # 2. Regex patterns + for pat in _SEARCH_PATTERNS: + if pat.search(text): + return True + + # 3. Time-sensitive topics + for pat in _SEARCH_TOPIC_PATTERNS: + if pat.search(text): + return True + + # 4. Question with a named entity (capitalised word after question word) → likely factual + if re.search(_QUESTION_WORDS_EN, t_lower) or re.search(_QUESTION_WORDS_FI, t_lower): + # If the message contains a proper noun (capitalized word that isn't sentence-start) + words = text.split() + if len(words) > 2: + interior_caps = [w for w in words[1:] if w[0:1].isupper() and not w.isupper()] + if interior_caps: + return True + + # 5. If the message is a question (ends with ?) and is not purely about coding + if text.rstrip().endswith("?"): + coding_signals = {"code", "script", "function", "error", "bug", "koodi", "skripti"} + if not any(sig in t_lower for sig in coding_signals): + return True + + return False + + +# --------------------------------------------------------------------------- +# Category classification +# --------------------------------------------------------------------------- + +CLASSIFIER_SYSTEM_TEMPLATE = """You are a prompt classifier. Output ONLY a JSON object, no other text. + +TODAY'S DATE: {today} + +Categories: +- coding: the user is explicitly asking you to WRITE, FIX, DEBUG, or REVIEW actual code, scripts, config files, or CLI commands. Mentioning IT, software, or technology in conversation is NOT coding — the user must be requesting code output. +- diagram: the user explicitly asks for a diagram, flowchart, architecture drawing, Mermaid, UML, ER diagram, or visual representation +- reasoning: deep analysis, comparisons, strategy, pros/cons, evaluations, "explain in detail", workplace situations, decision-making, persuasion, argumentation +- image_generation: user wants to CREATE/GENERATE/DRAW an image, picture, photo, illustration, artwork, logo +- vision: when the user references or asks about an EXISTING image/picture (already uploaded) +- general: everything else — simple questions, facts, casual conversation, advice, how-to, storytelling, venting about work situations + +IMPORTANT: When the user is discussing work situations, meetings, processes, or asking for advice about IT management/strategy, that is "general" or "reasoning" — NOT "coding". Only classify as "coding" when the user literally wants code written. + +Search decision — set search=true if ANY of these apply: +- The answer might change over time (versions, prices, dates, availability, people in roles) +- The question asks about a specific product, service, tool, or technology by name +- The question asks about events, news, announcements +- The question is about something that happened or will happen at a specific time +- The user asks "what is X" about something that isn't a timeless concept +- You are not 100% sure of the correct answer from general knowledge alone +Default to search=true when uncertain. + +Search query rules: +- Write a concise, effective web search query in the same language as the user. +- When the user says "yesterday", "today", "this week" etc., convert to actual dates based on TODAY'S DATE above. +- The current year is {year}. NEVER use any other year unless the user explicitly mentions one. +- If the topic involves versions or releases, include {year} in the query. + +Output format (ONLY this JSON, nothing else): +{{"category": "general", "search": true, "query": "search terms here"}} + +Examples: +{{"category": "coding", "search": false, "query": ""}} +{{"category": "coding", "search": true, "query": "Python 3.13 new features {year}"}} +{{"category": "general", "search": true, "query": "latest Ubuntu LTS release {year}"}} +{{"category": "reasoning", "search": false, "query": ""}} +{{"category": "general", "search": true, "query": "Microsoft Intune new features {year}"}} +{{"category": "diagram", "search": false, "query": ""}} +{{"category": "reasoning", "search": true, "query": "Azure Landing Zone vs hub-spoke comparison"}} +{{"category": "image_generation", "search": false, "query": ""}}""" + + +def _build_classifier_prompt() -> str: + today = date.today() + return CLASSIFIER_SYSTEM_TEMPLATE.format( + today=today.strftime("%Y-%m-%d"), + year=today.year, + ) + + +def _keyword_classify(text: str) -> str: + """Fast keyword-based fallback classifier.""" + t = text.lower() + + # Image generation — check first because it's the most specific + image_gen_kw = [ + "generate an image", "generate image", "create an image", "create image", + "draw me", "draw a picture", "make an image", "make a picture", + "generate a photo", "create a photo", "make a photo", + "generate art", "create art", "make art", + "generate a picture", "create a picture", + "luo kuva", "generoi kuva", "tee kuva", "piirrä kuva", + "luo valokuva", "tee valokuva", + ] + for kw in image_gen_kw: + if kw in t: + return "image_generation" + + # Diagram + diagram_kw = [ + "kaavio", "piirrä", "diagram", "flowchart", "sequence diagram", + "arkkitehtuurikuva", "mermaid", "vuokaavio", "uml", "er-kaavio", + "gantt", "mindmap", "visualisoi", "visualize", "draw a diagram", + "draw a chart", "draw a flow", + ] + for kw in diagram_kw: + if kw in t: + return "diagram" + + # Coding — be strict to avoid false positives + coding_kw = [ + "write code", "write a script", "kirjoita koodi", "kirjoita skripti", + "koodaa", "debug", "refactor", "fix this code", "fix this script", + "powershell", "bash script", "python script", "dockerfile", + "terraform", "ansible", "yaml config", "json schema", + "function that", "class that", "regex for", "kirjoita funktio", + ] + for kw in coding_kw: + if kw in t: + return "coding" + + # Strong reasoning signals (multi-word to reduce false positives) + reasoning_kw = [ + "analysoi yksityiskohtaisesti", "vertaile näitä", "arvioi", + "analyze in detail", "compare these", "pros and cons", + "hyvät ja huonot puolet", "explain in detail", "deep dive", + "selitä yksityiskohtaisesti", "strategia", "strategy for", + ] + for kw in reasoning_kw: + if kw in t: + return "reasoning" + + return "general" + + +def _ai_classify(prompt: str, ollama_url: str) -> dict: + """ + Call the classifier LLM. Returns dict with keys: category, search, query, method. + On failure returns a keyword-based fallback. + """ + fallback = { + "category": _keyword_classify(prompt), + "search": heuristic_needs_search(prompt), + "query": "", + "method": "fallback", + } + + try: + payload = { + "model": CLASSIFIER_MODEL, + "messages": [ + {"role": "system", "content": _build_classifier_prompt()}, + {"role": "user", "content": prompt[:1500]}, + ], + "stream": False, + "options": {"temperature": 0, "num_ctx": 2048}, + } + + resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=15) + resp.raise_for_status() + + raw = resp.json()["message"]["content"].strip() + # Strip markdown fences if the model wraps its output + raw = re.sub(r"^```(?:json)?\s*", "", raw) + raw = re.sub(r"\s*```$", "", raw) + + parsed = json.loads(raw) + category = parsed.get("category", "general").lower().strip() + search = bool(parsed.get("search", False)) + query = parsed.get("query", "").strip() + + if category not in ("coding", "diagram", "reasoning", "vision", "image_generation", "general"): + print(f"[Router] Unknown category '{category}', falling back to keyword") + category = _keyword_classify(prompt) + + return { + "category": category, + "search": search, + "query": query, + "method": "ai", + } + + except json.JSONDecodeError as e: + print(f"[Router] Classifier returned invalid JSON: {e}") + return fallback + except requests.RequestException as e: + print(f"[Router] Classifier request failed: {e}") + return fallback + except Exception as e: + print(f"[Router] Classifier error: {e}") + return fallback + + +def classify(prompt: str, ollama_url: str, use_ai: bool) -> dict: + """ + Main classification entry point. + Returns dict: category (resolved to model key), search, query, method. + + The heuristic search layer acts as a safety net: even if the AI says + search=false, heuristics can override to search=true. + """ + if use_ai: + result = _ai_classify(prompt, ollama_url) + else: + result = { + "category": _keyword_classify(prompt), + "search": False, + "query": "", + "method": "keyword", + } + + # --- Heuristic overrides --- + + # 1. Force search if heuristics say so (even if AI said no) + if not result["search"] and heuristic_needs_search(prompt): + result["search"] = True + result["method"] += "+heuristic_search" + print("[Router] Heuristic override: forcing search=true") + + # 2. Generate a search query if search is needed but query is empty + if result["search"] and not result["query"]: + result["query"] = _generate_search_query(prompt) + + # 3. Inject year into version/release queries + if result["search"] and result["query"]: + result["query"] = _inject_year(result["query"]) + + # 4. Resolve reasoning → reasoning_fi / reasoning_en + if result["category"] == "reasoning": + result["category"] = resolve_reasoning_key(prompt) + + return result + + +def _generate_search_query(prompt: str) -> str: + """Extract a reasonable search query from the user prompt.""" + # Take the first sentence or first 120 chars, whichever is shorter + first_sentence = re.split(r"[.!?\n]", prompt)[0].strip() + query = first_sentence[:120] + # Remove filler words for a tighter query + for filler in ["please", "can you", "could you", "tell me", "I want to know", + "voitko", "kerro", "haluaisin tietää", "ole hyvä"]: + query = re.sub(re.escape(filler), "", query, flags=re.IGNORECASE) + return query.strip() or prompt[:80] + + +_VERSION_KEYWORDS = [ + "latest", "newest", "current", "recent", "release", "version", + "uusin", "viimeisin", "nykyinen", "tuorein", "versio", "julkaisu", +] + + +def _inject_year(query: str) -> str: + """ + Fix year issues in search queries: + 1. Replace any hallucinated wrong year (e.g. 2023, 2024) with the current year + 2. Append the current year if version/release keywords are present but no year is + """ + current_year = date.today().year + current_str = str(current_year) + + # Step 1: Replace wrong years that the classifier may have hallucinated. + # We consider any 4-digit year from 2020 to current_year-1 as potentially wrong, + # UNLESS the user's original query explicitly contained that year (we can't check + # that here, but the classifier prompt now tells it the correct year, so a wrong + # year in the query is almost certainly hallucinated). + wrong_year_pattern = re.compile(r"\b(20(?:2[0-9]|3[0-9]))\b") + def _replace_year(m: re.Match) -> str: + y = int(m.group(1)) + if y != current_year and y < current_year: + return current_str + return m.group(0) + + fixed = wrong_year_pattern.sub(_replace_year, query) + + # Step 2: If no year present at all and version keywords exist, append current year + if current_str not in fixed: + if any(kw in fixed.lower() for kw in _VERSION_KEYWORDS): + fixed = f"{fixed} {current_str}" + + return fixed + + +# --------------------------------------------------------------------------- +# Brave Search with page content fetching +# --------------------------------------------------------------------------- + +def brave_search(query: str, api_key: str, max_results: int = 6, status_callback=None) -> str: + """ + Search Brave and fetch top page contents for richer context. + Returns a formatted string suitable for injection into the system prompt. + status_callback: optional function(str) called with real-time status updates. + """ + def _status(msg: str): + if status_callback: + status_callback(msg) + + if not api_key: + return "⚠️ Brave API key not configured." + + try: + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": api_key, + } + params = { + "q": query, + "count": max_results, + "text_decorations": False, + "search_lang": "fi" if detect_finnish(query) else "en", + } + + _status(f"Searching: *{query}*") + resp = requests.get(BRAVE_SEARCH_URL, headers=headers, params=params, timeout=10) + resp.raise_for_status() + + data = resp.json() + web_results = data.get("web", {}).get("results", []) + + if not web_results: + return f"No web results found for: {query}" + + _status(f"Found {len(web_results)} results") + + # Also grab any infobox / knowledge graph snippet + infobox_text = "" + infobox = data.get("infobox", {}) + if isinstance(infobox, dict): + long_desc = infobox.get("long_desc", "") + if long_desc: + infobox_text = f"Knowledge panel: {long_desc}\n\n" + + # Build results with page content fetching for top 3 + sections = [] + if infobox_text: + sections.append(infobox_text) + + for i, r in enumerate(web_results): + title = r.get("title", "") + url = r.get("url", "") + desc = r.get("description", "") + age = r.get("age", "") + age_str = f" ({age})" if age else "" + + # Fetch full page content for the top 3 results + page_content = "" + if i < 3: + _status(f"Reading [{i+1}/{min(3, len(web_results))}]: [{title}]({url})") + page_content = _fetch_page_content(url) + + section = f"[{i+1}] {title}{age_str}\nURL: {url}\nSnippet: {desc}" + if page_content: + section += f"\nContent:\n{page_content}" + sections.append(section) + + _status("Search complete") + return f"Web search results for: {query}\n{'='*60}\n\n" + "\n\n---\n\n".join(sections) + + except Exception as e: + print(f"[Router] Brave Search failed: {e}") + return f"Web search failed: {e}" + + +def _fetch_page_content(url: str, max_chars: int = 3000) -> str: + """Fetch and extract readable text from a URL. Returns truncated plain text.""" + try: + headers = { + "User-Agent": "Mozilla/5.0 (compatible; LLMRouter/3.0)", + "Accept": "text/html,application/xhtml+xml", + } + resp = requests.get(url, headers=headers, timeout=8, allow_redirects=True) + resp.raise_for_status() + + content_type = resp.headers.get("Content-Type", "") + if "html" not in content_type and "text" not in content_type: + return "" + + html = resp.text + + # Lightweight HTML → text extraction (no BeautifulSoup dependency) + # Remove script, style, nav, header, footer tags and their contents + for tag in ["script", "style", "nav", "header", "footer", "aside", "noscript"]: + html = re.sub(rf"<{tag}[^>]*>.*?", " ", html, flags=re.S | re.I) + + # Remove all remaining HTML tags + text = re.sub(r"<[^>]+>", " ", html) + + # Decode HTML entities + text = text.replace("&", "&").replace("<", "<").replace(">", ">") + text = text.replace(""", '"').replace("'", "'").replace(" ", " ") + text = re.sub(r"&#\d+;", " ", text) + text = re.sub(r"&\w+;", " ", text) + + # Collapse whitespace + text = re.sub(r"\s+", " ", text).strip() + + if len(text) < 50: + return "" + + return text[:max_chars] + + except Exception: + return "" + + +# --------------------------------------------------------------------------- +# Stable Diffusion image generation +# --------------------------------------------------------------------------- + +def _refine_sd_prompt(user_message: str, ollama_url: str, messages: List[dict] = None) -> str: + """Use the LLM to convert a user request into an optimized SD prompt. + Includes conversation history so the model understands context like 'generate an image of that'. + """ + try: + # Build context from recent conversation history + context_messages = [{"role": "system", "content": SYSTEM_PROMPTS["image_generation"]}] + if messages: + # Include last few exchanges for context (trim to avoid blowing up the context) + recent = [m for m in messages if m.get("role") in ("user", "assistant") and m.get("content")] + for msg in recent[-6:]: # Last 3 exchanges + content = msg["content"] + if isinstance(content, list): + content = " ".join(p.get("text", "") for p in content if isinstance(p, dict)) + context_messages.append({"role": msg["role"], "content": content[:500]}) + else: + context_messages.append({"role": "user", "content": user_message[:500]}) + + payload = { + "model": MODELS["image_generation"], + "messages": context_messages, + "stream": False, + "options": {"temperature": 0.7, "num_ctx": 4096}, + } + resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=30) + resp.raise_for_status() + refined = resp.json()["message"]["content"].strip() + # Strip any accidental markdown or quotes the model might add + refined = refined.strip('"\'`') + refined = re.sub(r"^```\w*\s*", "", refined) + refined = re.sub(r"\s*```$", "", refined) + return refined + except Exception as e: + print(f"[Router] SD prompt refinement failed: {e}") + # Fallback: use the user message directly + return user_message + + +def _negative_prompt() -> str: + """Standard negative prompt for SD.""" + return ( + "lowres, bad anatomy, bad hands, text, error, missing fingers, " + "extra digit, fewer digits, cropped, worst quality, low quality, " + "normal quality, jpeg artifacts, signature, watermark, username, blurry, " + "deformed, distorted, disfigured, mutation, mutated, ugly" + ) + + +def _compress_image(b64_png: str, quality: int = 80) -> str: + """Convert a base64 PNG from SD to a smaller base64 JPEG.""" + try: + from PIL import Image + img_data = base64.b64decode(b64_png) + img = Image.open(BytesIO(img_data)) + if img.mode == "RGBA": + img = img.convert("RGB") + buf = BytesIO() + img.save(buf, format="JPEG", quality=quality, optimize=True) + return base64.b64encode(buf.getvalue()).decode("utf-8") + except Exception as e: + print(f"[Router] Image compression failed: {e}, using original") + return b64_png + + +def _unload_ollama_models(ollama_url: str): + """Unload all Ollama models from VRAM to make room for image generation.""" + try: + # List running models + resp = requests.get(f"{ollama_url}/api/ps", timeout=5) + if resp.ok: + models = resp.json().get("models", []) + for model in models: + name = model.get("name", "") + if name: + # Setting keep_alive to 0 unloads the model immediately + requests.post( + f"{ollama_url}/api/generate", + json={"model": name, "keep_alive": 0}, + timeout=10, + ) + print(f"[Router] Unloaded Ollama model: {name}") + except Exception as e: + print(f"[Router] Failed to unload Ollama models: {e}") + + +def _cleanup_after_generation(sd_url: str): + """Free VRAM and RAM after image generation so Ollama can load models.""" + # 1. Unload SD checkpoint from VRAM + try: + requests.post(f"{sd_url}/sdapi/v1/unload-checkpoint", timeout=5) + print("[Router] SD checkpoint unloaded from VRAM") + except Exception: + pass + + # 2. Drop Linux page cache to free RAM + try: + os.system("sync; echo 3 > /proc/sys/vm/drop_caches 2>/dev/null") + print("[Router] Page cache dropped") + except Exception: + pass + + +def generate_image( + user_message: str, + ollama_url: str, + sd_url: str, + width: int = 512, + height: int = 512, + steps: int = 30, + cfg_scale: float = 7.0, + messages: List[dict] = None, +) -> tuple: + """ + Generate an image via AUTOMATIC1111 API. + Returns (base64_image, refined_prompt) on success, or (None, error_message) on failure. + """ + # Step 1: Refine the prompt using the LLM FIRST (while Ollama is still loaded) + refined_prompt = _refine_sd_prompt(user_message, ollama_url, messages) + + # Step 2: Unload Ollama models from VRAM to make room for SDXL + _unload_ollama_models(ollama_url) + print(f"[Router] SD prompt: {refined_prompt[:120]}") + + # Step 2: Call AUTOMATIC1111 + try: + payload = { + "prompt": refined_prompt, + "negative_prompt": _negative_prompt(), + "width": width, + "height": height, + "steps": steps, + "cfg_scale": cfg_scale, + "sampler_name": "DPM++ 2M Karras", + "batch_size": 1, + "n_iter": 1, + } + resp = requests.post( + f"{sd_url}/sdapi/v1/txt2img", + json=payload, + timeout=120, + ) + resp.raise_for_status() + + data = resp.json() + images = data.get("images", []) + if not images: + return None, "Stable Diffusion returned no images." + + # Compress PNG→JPEG to reduce base64 size for streaming + compressed = _compress_image(images[0]) + + # Free VRAM and RAM so Ollama can load models again + _cleanup_after_generation(sd_url) + + return compressed, refined_prompt + + except requests.exceptions.ConnectionError: + return None, f"Cannot connect to Stable Diffusion at {sd_url}. Is it running?" + except requests.exceptions.Timeout: + return None, "Image generation timed out (>120s)." + except Exception as e: + return None, f"Image generation failed: {e}" + + +# --------------------------------------------------------------------------- +# Image handling (unchanged from v2) +# --------------------------------------------------------------------------- + +def extract_images_from_messages(messages: List[dict]) -> tuple: + """Separate base64 images from message content.""" + images = [] + clean_messages = [] + + for msg in messages: + content = msg.get("content", "") + if isinstance(content, list): + text_parts = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + text_parts.append(part.get("text", "")) + elif part.get("type") == "image_url": + url = part.get("image_url", {}).get("url", "") + if url.startswith("data:"): + match = re.match(r"data:[^;]+;base64,(.+)", url) + if match: + images.append(match.group(1)) + clean_messages.append({ + "role": msg["role"], + "content": " ".join(text_parts).strip(), + }) + else: + clean_messages.append(msg) + + return images, clean_messages + + +def has_image_content(messages: List[dict]) -> bool: + """Check if the latest user message contains an uploaded image. + Only checks user messages (not assistant responses which may contain generated images). + """ + # Find the last user message + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + return True + elif isinstance(content, str) and "data:image" in content: + return True + return False # Last user message found but has no image + return False + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + +class Pipeline: + class Valves(BaseModel): + ollama_url: str = Field(default=OLLAMA_URL, description="Ollama API base URL") + sd_url: str = Field(default=SD_URL, description="AUTOMATIC1111 Stable Diffusion WebUI URL") + sd_width: int = Field(default=1024, description="Generated image width") + sd_height: int = Field(default=1024, description="Generated image height") + sd_steps: int = Field(default=25, description="Stable Diffusion sampling steps") + sd_cfg_scale: float = Field(default=7.0, description="Stable Diffusion CFG scale") + brave_api_key: str = Field(default=BRAVE_API_KEY, description="Brave Search API key") + brave_max_results: int = Field(default=6, description="Number of Brave search results to fetch") + brave_fetch_pages: int = Field(default=3, description="Number of top results to fetch full page content for") + use_ai_classifier: bool = Field(default=True, description="Use AI classifier (vs keyword-only)") + show_routing_info: bool = Field(default=True, description="Show routing banner in responses") + search_context_max_chars: int = Field(default=12000, description="Max chars of search context to inject") + + def __init__(self): + self.id = "llm-router" + self.name = "LLM Router v3" + self.valves = self.Valves() + + async def on_startup(self): + print(f"[Router] LLM Router v3.0 starting — Ollama: {self.valves.ollama_url}") + print(f"[Router] Classifier: {CLASSIFIER_MODEL} | AI: {self.valves.use_ai_classifier}") + print(f"[Router] Brave Search: {'configured' if self.valves.brave_api_key else 'NO API KEY'}") + + async def on_shutdown(self): + print("[Router] LLM Router v3.0 shutting down") + + def pipe( + self, + user_message: str, + model_id: str, + messages: List[dict], + body: dict, + ) -> Iterator[str]: + + # --- Step 1: Vision override --- + if has_image_content(messages): + category = "vision" + needs_search = False + search_query = "" + method = "vision_detect" + else: + # --- Step 2: Classify --- + result = classify( + user_message, + self.valves.ollama_url, + self.valves.use_ai_classifier, + ) + category = result["category"] + needs_search = result["search"] + search_query = result["query"] + method = result["method"] + + target_model = MODELS.get(category, MODELS["general"]) + system_prompt = SYSTEM_PROMPTS.get(category, SYSTEM_PROMPTS["general"]) + + # Inject language instruction — always respond in the user's language + if detect_finnish(user_message) and category not in ("reasoning_fi", "image_generation"): + system_prompt = ( + "TÄRKEÄ: Käyttäjä kirjoittaa suomeksi. Vastaa AINA suomeksi.\n\n" + + system_prompt + ) + + print(f"[Router] {method} → {category} → {target_model} | search={needs_search} query='{search_query}'") + + # --- Step 3: Routing info banner --- + if self.valves.show_routing_info: + display_cat = category.replace("_en", " 🇬🇧").replace("_fi", " 🇫🇮") + search_label = f" | 🌐 `{search_query}`" if needs_search else "" + yield f"> 🔀 **Router** `[{method}]` → `{target_model}` *(category: {display_cat}){search_label}*\n\n" + + # --- Step 4: Image generation (early return) --- + if category == "image_generation": + yield "> 🎨 Generating image…\n\n" + base64_img, refined_prompt = generate_image( + user_message, + self.valves.ollama_url, + self.valves.sd_url, + width=self.valves.sd_width, + height=self.valves.sd_height, + steps=self.valves.sd_steps, + cfg_scale=self.valves.sd_cfg_scale, + messages=messages, + ) + if base64_img: + # Yield the image in chunks to avoid "chunk too big" errors + img_tag = f"![Generated image](data:image/jpeg;base64,{base64_img})" + chunk_size = 4096 + for i in range(0, len(img_tag), chunk_size): + yield img_tag[i:i + chunk_size] + yield "\n\n" + yield f"*Prompt used: {refined_prompt}*\n" + else: + yield f"\n\n❌ {refined_prompt}\n" + return + + # --- Step 5: Web search --- + search_context = "" + search_status_lines = [] + if needs_search and search_query and self.valves.brave_api_key: + # Collect status updates via callback, yield them in real time + def _on_status(msg: str): + search_status_lines.append(msg) + + yield "> 🔍 Searching the web…\n\n" + + # We need to yield status in real-time, so we run search in a thread + import threading + search_result = [None] + def _run_search(): + search_result[0] = brave_search( + search_query, + self.valves.brave_api_key, + max_results=self.valves.brave_max_results, + status_callback=_on_status, + ) + + t = threading.Thread(target=_run_search) + t.start() + + last_count = 0 + while t.is_alive(): + t.join(timeout=0.3) + # Yield any new status lines + while last_count < len(search_status_lines): + yield f"> {search_status_lines[last_count]}\n>\n" + last_count += 1 + # Yield any remaining status lines + while last_count < len(search_status_lines): + yield f"> {search_status_lines[last_count]}\n>\n" + last_count += 1 + + yield "\n\n" + search_context = search_result[0] or "" + # Truncate if too large + if len(search_context) > self.valves.search_context_max_chars: + search_context = search_context[:self.valves.search_context_max_chars] + "\n\n[...truncated]" + + # Detect failed search and warn the user + if (search_context.startswith("⚠️") + or search_context.startswith("Web search failed") + or search_context.startswith("No web results found")): + yield "> ⚠️ Web search failed — answering from model knowledge.\n\n" + print(f"[Router] Search failed: {search_context[:120]}") + else: + print(f"[Router] Search complete: {len(search_context)} chars") + + elif needs_search and not self.valves.brave_api_key: + yield "> ⚠️ Web search not available (no API key) — answering from model knowledge.\n\n" + print("[Router] Search needed but no Brave API key configured!") + + # --- Step 6: Build messages --- + images, clean_messages = extract_images_from_messages(messages) + + # Check if search actually returned usable results (not just an error) + search_ok = bool( + search_context + and not search_context.startswith("⚠️") + and not search_context.startswith("Web search failed") + and not search_context.startswith("No web results found") + ) + + if search_ok: + today = date.today().strftime("%Y-%m-%d") + full_system = ( + f"{system_prompt}\n\n" + f"Today's date: {today}\n\n" + f"## Web Search Results\n" + f"The following are fresh web search results. Use them as your PRIMARY source of truth.\n" + f"Your training data may be outdated — always prefer information from these results.\n" + f"When results conflict, prefer the most recent one (check dates/ages).\n" + f"Cite the source URL when stating specific facts.\n" + f"If the search results don't contain enough information to fully answer, " + f"say so honestly rather than guessing.\n\n" + f"{search_context}" + ) + elif needs_search: + # Search was requested but failed — tell the model so it can be honest + today = date.today().strftime("%Y-%m-%d") + full_system = ( + f"{system_prompt}\n\n" + f"Today's date: {today}\n\n" + f"NOTE: A web search was attempted for this question but failed or returned no results. " + f"Answer as best you can from your training data, but clearly state that you could not " + f"verify the information with a live web search and the answer may be outdated." + ) + else: + full_system = system_prompt + + ollama_messages = [{"role": "system", "content": full_system}] + for msg in clean_messages: + if msg.get("role") in ("user", "assistant") and msg.get("content"): + ollama_messages.append(msg) + + # Attach images to the last user message + if images: + for i in range(len(ollama_messages) - 1, -1, -1): + if ollama_messages[i]["role"] == "user": + ollama_messages[i]["images"] = images + break + + # --- Step 7: Call the target model --- + payload = { + "model": target_model, + "messages": ollama_messages, + "stream": True, + "options": { + "temperature": body.get("temperature", 0.7), + "num_ctx": 8192, + }, + } + + try: + resp = requests.post( + f"{self.valves.ollama_url}/api/chat", + json=payload, + stream=True, + timeout=180, + ) + resp.raise_for_status() + + in_thinking = False + for line in resp.iter_lines(): + if line: + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + + msg = chunk.get("message", {}) + + # Handle thinking/reasoning tokens (displayed in a collapsible block) + thinking_content = msg.get("thinking", "") + if thinking_content: + if not in_thinking: + yield "
\n💭 Thinking…\n\n" + in_thinking = True + yield thinking_content + + # Handle regular content + if msg.get("content"): + if in_thinking: + yield "\n
\n\n" + in_thinking = False + yield msg["content"] + + if chunk.get("done"): + if in_thinking: + yield "\n\n\n" + break + + except requests.exceptions.ConnectionError: + yield f"\n\n❌ Connection error to Ollama ({self.valves.ollama_url}). Is the service running?" + except requests.exceptions.Timeout: + yield "\n\n❌ Timeout — the model is responding too slowly." + except Exception as e: + yield f"\n\n❌ Error: {str(e)}" diff --git a/setup-sd-service.sh b/setup-sd-service.sh new file mode 100644 index 0000000..59c8387 --- /dev/null +++ b/setup-sd-service.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Create a systemd service for AUTOMATIC1111 so it starts on boot +# Run this AFTER setup-sd.sh has completed successfully + +set -e + +SD_DIR="$HOME/stable-diffusion-webui" +SERVICE_FILE="/etc/systemd/system/stable-diffusion.service" +CURRENT_USER=$(whoami) + +echo "Creating systemd service for Stable Diffusion WebUI..." + +sudo tee "$SERVICE_FILE" > /dev/null <