"""
LLM Router Pipeline — v3.0

AI-powered prompt classifier + aggressive web search with full page fetching.
Uses qwen2.5:7b as zero-shot classifier backed by robust heuristic overrides.

Routing:
- coding → qwen2.5-coder:14b
- diagram → qwen2.5-coder:14b
- reasoning (FI) → gpt-oss:120b
- reasoning (EN) → gpt-oss:120b
- image_generation → gpt-oss:120b (prompt refinement) → AUTOMATIC1111 (Stable Diffusion)
- vision → llama3.2-vision:11b
- general → gpt-oss:120b

Web search is triggered aggressively — any question that might benefit from
current information will trigger a Brave Search + page content fetch.
"""

import base64
import html
import json
import os
import re
from datetime import date, datetime
from io import BytesIO
from typing import Iterator, List, Optional

import requests
from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
# Base URL of the Ollama HTTP API.
OLLAMA_URL = "http://ollama:11434"
# Model used for zero-shot prompt classification.
CLASSIFIER_MODEL = "qwen2.5:7b"
# Brave Search credentials/endpoint; the key is read from the environment and
# defaults to an empty string when unset.
BRAVE_API_KEY = os.environ.get("BRAVE_API_KEY", "")
BRAVE_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"

# Routing table: resolved category key -> Ollama model tag.
# "image_generation" names the model that refines the prompt before it is
# handed to Stable Diffusion.
MODELS = {
    "coding": "qwen2.5-coder:14b",
    "diagram": "qwen2.5-coder:14b",
    "reasoning_en": "gpt-oss:120b",
    "reasoning_fi": "gpt-oss:120b",
    "vision": "llama3.2-vision:11b",
    "image_generation": "gpt-oss:120b",
    "general": "gpt-oss:120b",
}

# Base URL of the AUTOMATIC1111 Stable Diffusion WebUI API.
SD_URL = "http://172.18.0.1:7860"
# System prompt per resolved category key (same keys as MODELS).
SYSTEM_PROMPTS = {
    "image_generation": (
        "You are a Stable Diffusion prompt engineer. The user wants an image generated. "
        "Your ONLY job is to convert their request into an optimized Stable Diffusion prompt. "
        "Output ONLY the prompt text, nothing else — no explanation, no markdown, no quotes. "
        "Use comma-separated tags and descriptors. Include quality boosters like: "
        "masterpiece, best quality, highly detailed, sharp focus, professional. "
        "Include style, lighting, and composition details. "
        "If the user writes in Finnish, still output the SD prompt in English."
    ),
    "coding": (
        "You are an expert programmer and DevOps engineer. "
        "Provide clean, well-commented code. Use best practices. "
        "If the user writes in Finnish, respond in Finnish but keep code and comments in English."
    ),
    "diagram": (
        "You are a technical architect who communicates through diagrams. "
        "Always respond with a Mermaid diagram inside a ```mermaid code block. "
        "After the diagram, provide a short explanation. "
        "If the user writes in Finnish, the explanation should be in Finnish. "
        "Use flowchart TD, sequenceDiagram, classDiagram, erDiagram or gantt as appropriate."
    ),
    "reasoning_en": (
        "Reasoning: high\n"
        "You are an analytical expert. Think step by step, consider multiple perspectives, "
        "and provide well-structured, thorough responses."
    ),
    "reasoning_fi": (
        "Olet analyyttinen asiantuntija. Ajattele vaihe vaiheelta, harkitse useita näkökulmia "
        "ja anna hyvin jäsennelty, perusteellinen vastaus suomen kielellä."
    ),
    "vision": (
        "You are a visual analysis assistant. Describe and analyze images in detail. "
        "If the user writes in Finnish, respond in Finnish."
    ),
    "general": (
        "You are a helpful, friendly assistant. You speak both Finnish and English fluently — "
        "always respond in the same language the user is using. "
        "Be concise but thorough."
    ),
}
# ---------------------------------------------------------------------------
# Finnish language detection
# ---------------------------------------------------------------------------
FINNISH_INDICATORS = [
    "ä", "ö", "analysoi", "vertaile", "arvioi", "selitä", "miksi", "miten",
    "mitä", "kuinka", "kerro", "anna", "tee", "kirjoita", "auta", "voitko",
    "pitäisi", "kannattaa", "parempi", "huonompi", "ero", "hyödyt", "haitat",
    "onko", "missä", "milloin", "paljonko", "kuka", "mikä",
]

# Single-character indicators ("ä", "ö") are checked anywhere in the text.
_FINNISH_CHAR_INDICATORS = tuple(ind for ind in FINNISH_INDICATORS if len(ind) == 1)
# Word indicators must begin a word: suffix-inflected forms still match, but
# English words that merely contain the fragment ("zero" -> "ero",
# "guarantee" -> "tee") no longer count.
_FINNISH_WORD_PATTERNS = tuple(
    re.compile(rf"\b{re.escape(ind)}") for ind in FINNISH_INDICATORS if len(ind) > 1
)


def detect_finnish(text: str) -> bool:
    """Return True when *text* shows at least two distinct Finnish indicators.

    Each entry of FINNISH_INDICATORS counts at most once, no matter how many
    times it occurs. Matching is case-insensitive.
    """
    lowered = text.lower()
    hits = sum(1 for ch in _FINNISH_CHAR_INDICATORS if ch in lowered)
    hits += sum(1 for pattern in _FINNISH_WORD_PATTERNS if pattern.search(lowered))
    return hits >= 2
def resolve_reasoning_key(prompt: str) -> str:
    """Return the language-specific reasoning key for *prompt*.

    Yields "reasoning_fi" when detect_finnish() flags the prompt as Finnish,
    otherwise "reasoning_en".
    """
    if detect_finnish(prompt):
        return "reasoning_fi"
    return "reasoning_en"
# ---------------------------------------------------------------------------
# Search-need heuristics (the "safety net" that can override the AI classifier)
#
# The philosophy: if in doubt, search. A redundant search costs ~200 ms.
# A missed search costs the user a hallucinated or stale answer.
# ---------------------------------------------------------------------------

# Question-word patterns (EN + FI). These are plain regex source strings
# (not compiled) and are written in lowercase, so they are meant to be
# applied to lowercased text.
_QUESTION_WORDS_EN = (
    r"\b(?:what|which|who|whom|whose|where|when|why|how|is|are|was|were|do|does|did"
    r"|can|could|will|would|should|has|have|had)\b"
)
_QUESTION_WORDS_FI = (
    r"\b(?:mikä|mitä|kuka|kenen|missä|minne|mistä|milloin|miksi|miten|kuinka"
    r"|onko|ovatko|oliko|voiko|pitääkö|saako|paljonko|kumpi|montako)\b"
)
# Time-sensitive / factual trigger words
_FRESHNESS_KEYWORDS = {
    # English
    "latest", "newest", "current", "recent", "update", "updated", "today",
    "yesterday", "this week", "this month", "this year", "2024", "2025", "2026",
    "now", "right now", "currently", "as of", "still", "anymore",
    "release", "released", "version", "price", "cost", "stock", "weather",
    "score", "election", "news", "announcement", "launched", "launches",
    "available", "discontinued", "deprecated", "roadmap", "deadline",
    "schedule", "status", "outage", "incident",
    # Finnish
    "uusin", "viimeisin", "nykyinen", "tuorein", "päivitetty", "tänään",
    "eilen", "tällä viikolla", "tässä kuussa", "tänä vuonna",
    "hinta", "sää", "tulos", "uutinen", "julkaistu", "julkaisu",
    "versio", "saatavilla", "poistunut", "aikataulu", "tilanne",
}
# The literal years above go stale; also treat last/this/next year (as of
# module import) as freshness signals so the set keeps working over time.
_FRESHNESS_KEYWORDS |= {str(date.today().year + offset) for offset in (-1, 0, 1)}
# Patterns that almost always need search
_SEARCH_PATTERNS = [
    # "What is the latest X", "Who is the current Y", etc.
    re.compile(r"(?:what|which|who|where|when)\s+(?:is|are|was|were)\s+(?:the\s+)?(?:latest|current|new)", re.I),
    # "How much does X cost"
    re.compile(r"how\s+much\s+(?:does|do|is|are|will)", re.I),
    # URLs or domain names in the prompt → user is asking about a specific site/service
    re.compile(r"https?://|www\.", re.I),
    # Explicit search requests
    re.compile(r"\b(?:search|google|look\s*up|find\s+out|hae|etsi|googlaa)\b", re.I),
    # "Is X still Y" / "Does X still Y"
    re.compile(r"\b(?:is|does|are|do)\s+\w+\s+still\b", re.I),
    # Finnish question forms ending with -ko/-kö.
    # NOTE(review): this matches ANY word ending in -ko/-kö, not only
    # interrogative forms — deliberately aggressive, confirm that is intended.
    re.compile(r"\b\w+(?:ko|kö)\b", re.I),
]
# Topics that are inherently time-sensitive
_SEARCH_TOPIC_PATTERNS = [
    # Security advisories
    re.compile(r"\b(?:CVE|vulnerability|exploit|zero[- ]?day|security\s+(?:patch|update|advisory))\b", re.I),
    # Support lifecycle
    re.compile(r"\b(?:LTS|EOL|end[- ]of[- ]life|support\s+(?:lifecycle|end))\b", re.I),
    # Named OS releases. The version number is matched with \d+ — a single \d
    # followed by \b can never match multi-digit versions such as
    # "Ubuntu 24.04", "iOS 17" or "Android 14".
    re.compile(r"\b(?:Windows\s+1[12]|Ubuntu\s+\d+|macOS\s+\w+|iOS\s+\d+|Android\s+\d+)\b", re.I),
    # Versioned LLM families ("GPT-4", "Llama 3", ...)
    re.compile(r"\b(?:GPT|Claude|Gemini|Llama|Mistral|Qwen|DeepSeek)\s*[-.]?\s*\d", re.I),
    # Cloud provider followed by a service name
    re.compile(r"\b(?:Azure|AWS|GCP)\s+\w+", re.I),
]
def heuristic_needs_search(text: str) -> bool:
    """Return True if heuristics say this message almost certainly needs a web search.

    The checks run in order and the first hit wins:

    1. a freshness keyword (from _FRESHNESS_KEYWORDS) starts a word,
    2. one of _SEARCH_PATTERNS or _SEARCH_TOPIC_PATTERNS matches,
    3. the text contains a question word plus a capitalised interior word,
    4. the text ends with "?" and carries no coding signal.
    """
    t_lower = text.lower()

    # 1. Freshness keywords. Anchored at a word start so that e.g. "now" does
    #    not fire inside "know"/"snow"; suffixes are still allowed so that
    #    inflected and compound forms keep matching.
    if any(re.search(rf"\b{re.escape(kw)}", t_lower) for kw in _FRESHNESS_KEYWORDS):
        return True

    # 2. Regex patterns, then inherently time-sensitive topics.
    if any(pat.search(text) for pat in (*_SEARCH_PATTERNS, *_SEARCH_TOPIC_PATTERNS)):
        return True

    # 3. Question with a named entity → likely factual.
    if re.search(_QUESTION_WORDS_EN, t_lower) or re.search(_QUESTION_WORDS_FI, t_lower):
        words = text.split()
        # A "proper noun" here is any word after the first that starts with an
        # uppercase letter but is not fully uppercase (so "I" and acronyms
        # are ignored).
        if len(words) > 2 and any(
            w[:1].isupper() and not w.isupper() for w in words[1:]
        ):
            return True

    # 4. A question (ends with "?") that is not about coding.
    if text.rstrip().endswith("?"):
        coding_signals = ("code", "script", "function", "error", "bug", "koodi", "skripti")
        if not any(sig in t_lower for sig in coding_signals):
            return True

    return False
# ---------------------------------------------------------------------------
# Category classification
# ---------------------------------------------------------------------------

# System prompt for the classifier model. Rendered with str.format():
# {today} and {year} are substituted, literal JSON braces are doubled.
CLASSIFIER_SYSTEM_TEMPLATE = """You are a prompt classifier. Output ONLY a JSON object, no other text.

TODAY'S DATE: {today}

Categories:
- coding: the user is explicitly asking you to WRITE, FIX, DEBUG, or REVIEW actual code, scripts, config files, or CLI commands. Mentioning IT, software, or technology in conversation is NOT coding — the user must be requesting code output.
- diagram: the user explicitly asks for a diagram, flowchart, architecture drawing, Mermaid, UML, ER diagram, or visual representation
- reasoning: deep analysis, comparisons, strategy, pros/cons, evaluations, "explain in detail", workplace situations, decision-making, persuasion, argumentation
- image_generation: user wants to CREATE/GENERATE/DRAW an image, picture, photo, illustration, artwork, logo
- vision: when the user references or asks about an EXISTING image/picture (already uploaded)
- general: everything else — simple questions, facts, casual conversation, advice, how-to, storytelling, venting about work situations

IMPORTANT: When the user is discussing work situations, meetings, processes, or asking for advice about IT management/strategy, that is "general" or "reasoning" — NOT "coding". Only classify as "coding" when the user literally wants code written.

Search decision — set search=true if ANY of these apply:
- The answer might change over time (versions, prices, dates, availability, people in roles)
- The question asks about a specific product, service, tool, or technology by name
- The question asks about events, news, announcements
- The question is about something that happened or will happen at a specific time
- The user asks "what is X" about something that isn't a timeless concept
- You are not 100% sure of the correct answer from general knowledge alone
Default to search=true when uncertain.

Search query rules:
- Write a concise, effective web search query in the same language as the user.
- When the user says "yesterday", "today", "this week" etc., convert to actual dates based on TODAY'S DATE above.
- The current year is {year}. NEVER use any other year unless the user explicitly mentions one.
- If the topic involves versions or releases, include {year} in the query.

Output format (ONLY this JSON, nothing else):
{{"category": "general", "search": true, "query": "search terms here"}}

Examples:
{{"category": "coding", "search": false, "query": ""}}
{{"category": "coding", "search": true, "query": "Python 3.13 new features {year}"}}
{{"category": "general", "search": true, "query": "latest Ubuntu LTS release {year}"}}
{{"category": "reasoning", "search": false, "query": ""}}
{{"category": "general", "search": true, "query": "Microsoft Intune new features {year}"}}
{{"category": "diagram", "search": false, "query": ""}}
{{"category": "reasoning", "search": true, "query": "Azure Landing Zone vs hub-spoke comparison"}}
{{"category": "image_generation", "search": false, "query": ""}}"""
def _build_classifier_prompt() -> str:
    """Render CLASSIFIER_SYSTEM_TEMPLATE with today's date (YYYY-MM-DD) and year."""
    now = date.today()
    return CLASSIFIER_SYSTEM_TEMPLATE.format(
        today=now.strftime("%Y-%m-%d"),
        year=now.year,
    )
# Ordered keyword rules for _keyword_classify: the first category with a
# keyword contained in the lowercased text wins, so order matters.
_KEYWORD_RULES = (
    # Image generation — checked first because it is the most specific
    # (e.g. "piirrä kuva" must win over the diagram keyword "piirrä").
    ("image_generation", (
        "generate an image", "generate image", "create an image", "create image",
        "draw me", "draw a picture", "make an image", "make a picture",
        "generate a photo", "create a photo", "make a photo",
        "generate art", "create art", "make art",
        "generate a picture", "create a picture",
        "luo kuva", "generoi kuva", "tee kuva", "piirrä kuva",
        "luo valokuva", "tee valokuva",
    )),
    ("diagram", (
        "kaavio", "piirrä", "diagram", "flowchart", "sequence diagram",
        "arkkitehtuurikuva", "mermaid", "vuokaavio", "uml", "er-kaavio",
        "gantt", "mindmap", "visualisoi", "visualize", "draw a diagram",
        "draw a chart", "draw a flow",
    )),
    # Coding — strict phrases to avoid false positives
    ("coding", (
        "write code", "write a script", "kirjoita koodi", "kirjoita skripti",
        "koodaa", "debug", "refactor", "fix this code", "fix this script",
        "powershell", "bash script", "python script", "dockerfile",
        "terraform", "ansible", "yaml config", "json schema",
        "function that", "class that", "regex for", "kirjoita funktio",
    )),
    # Strong reasoning signals (mostly multi-word to reduce false positives)
    ("reasoning", (
        "analysoi yksityiskohtaisesti", "vertaile näitä", "arvioi",
        "analyze in detail", "compare these", "pros and cons",
        "hyvät ja huonot puolet", "explain in detail", "deep dive",
        "selitä yksityiskohtaisesti", "strategia", "strategy for",
    )),
)


def _keyword_classify(text: str) -> str:
    """Fast keyword-based fallback classifier.

    Returns "image_generation", "diagram", "coding" or "reasoning" for the
    first rule whose keyword occurs (as a substring, case-insensitively) in
    *text*, and "general" when nothing matches.
    """
    lowered = text.lower()
    for category, keywords in _KEYWORD_RULES:
        if any(kw in lowered for kw in keywords):
            return category
    return "general"
# Categories the classifier model is allowed to return.
_VALID_CATEGORIES = (
    "coding", "diagram", "reasoning", "vision", "image_generation", "general",
)


def _ai_classify(prompt: str, ollama_url: str) -> dict:
    """
    Call the classifier LLM. Returns dict with keys: category, search, query, method.

    On any failure (HTTP error, invalid JSON, unexpected payload shape) a
    keyword/heuristic-based fallback with method="fallback" is returned
    instead; this function does not raise.
    """
    def _fallback() -> dict:
        # Built lazily so the heuristics only run when actually needed.
        return {
            "category": _keyword_classify(prompt),
            "search": heuristic_needs_search(prompt),
            "query": "",
            "method": "fallback",
        }

    try:
        payload = {
            "model": CLASSIFIER_MODEL,
            "messages": [
                {"role": "system", "content": _build_classifier_prompt()},
                # Only the first 1500 characters are sent to the classifier.
                {"role": "user", "content": prompt[:1500]},
            ],
            "stream": False,
            "options": {"temperature": 0, "num_ctx": 2048},
        }

        resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=15)
        resp.raise_for_status()

        raw = resp.json()["message"]["content"].strip()
        # Strip markdown fences if the model wraps its output
        raw = re.sub(r"^```(?:json)?\s*", "", raw)
        raw = re.sub(r"\s*```$", "", raw)
        # Tolerate chatter around the object by keeping only the outermost {...}
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            raw = raw[start:end + 1]

        parsed = json.loads(raw)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")

        # `or` guards against explicit nulls, which .get() defaults do not cover.
        category = str(parsed.get("category") or "general").lower().strip()
        query = str(parsed.get("query") or "").strip()

        search_raw = parsed.get("search", False)
        if isinstance(search_raw, str):
            # bool("false") would be True — interpret string flags explicitly.
            search = search_raw.strip().lower() in ("true", "1", "yes")
        else:
            search = bool(search_raw)

        if category not in _VALID_CATEGORIES:
            print(f"[Router] Unknown category '{category}', falling back to keyword")
            category = _keyword_classify(prompt)

        return {
            "category": category,
            "search": search,
            "query": query,
            "method": "ai",
        }

    except json.JSONDecodeError as e:
        print(f"[Router] Classifier returned invalid JSON: {e}")
        return _fallback()
    except requests.RequestException as e:
        print(f"[Router] Classifier request failed: {e}")
        return _fallback()
    except Exception as e:
        # Best-effort boundary: classification must never break the pipeline.
        print(f"[Router] Classifier error: {e}")
        return _fallback()
# Four-digit years in the 2020-2039 range.
_USER_YEAR_RE = re.compile(r"\b20[23][0-9]\b")


def classify(prompt: str, ollama_url: str, use_ai: bool) -> dict:
    """
    Main classification entry point.

    Returns dict: category (resolved to model key), search, query, method.

    The heuristic search layer acts as a safety net: even if the AI says
    search=false, heuristics can override to search=true. When that happens
    "+heuristic_search" is appended to ``method``.
    """
    if use_ai:
        result = _ai_classify(prompt, ollama_url)
    else:
        result = {
            "category": _keyword_classify(prompt),
            "search": False,
            "query": "",
            "method": "keyword",
        }

    # --- Heuristic overrides ---

    # 1. Force search if heuristics say so (even if AI said no)
    if not result["search"] and heuristic_needs_search(prompt):
        result["search"] = True
        result["method"] += "+heuristic_search"
        print("[Router] Heuristic override: forcing search=true")

    # 2. Generate a search query if search is needed but query is empty
    if result["search"] and not result["query"]:
        result["query"] = _generate_search_query(prompt)

    # 3. Inject year into version/release queries — but leave the query alone
    #    when it carries a year the user typed themselves, so an explicit
    #    "what happened in 2022" is not rewritten to the current year.
    if result["search"] and result["query"]:
        user_years = set(_USER_YEAR_RE.findall(prompt))
        query_years = set(_USER_YEAR_RE.findall(result["query"]))
        if not (user_years & query_years):
            result["query"] = _inject_year(result["query"])

    # 4. Resolve reasoning → reasoning_fi / reasoning_en
    if result["category"] == "reasoning":
        result["category"] = resolve_reasoning_key(prompt)

    return result
# Sentence terminators only count when followed by whitespace or end of text,
# so "Python 3.13" and "example.com" are not cut in half.
_SENTENCE_BREAK_RE = re.compile(r"[.!?]+(?:\s+|$)|\n")
_FILLER_PHRASES = (
    "please", "can you", "could you", "tell me", "I want to know",
    "voitko", "kerro", "haluaisin tietää", "ole hyvä",
)
# Whole-phrase match only, so "kerro" is not carved out of "Kerrostalo".
_FILLER_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(p) for p in _FILLER_PHRASES) + r")\b",
    re.IGNORECASE,
)


def _generate_search_query(prompt: str) -> str:
    """Extract a reasonable search query from the user prompt.

    Takes the first sentence (capped at 120 characters), removes polite
    filler phrases and collapses whitespace. If nothing is left, the first
    80 characters of the prompt are returned instead.
    """
    first_sentence = _SENTENCE_BREAK_RE.split(prompt, maxsplit=1)[0].strip()
    query = _FILLER_RE.sub(" ", first_sentence[:120])
    query = re.sub(r"\s+", " ", query).strip()
    return query or prompt[:80]
_VERSION_KEYWORDS = [
    "latest", "newest", "current", "recent", "release", "version",
    "uusin", "viimeisin", "nykyinen", "tuorein", "versio", "julkaisu",
]

# Candidate years: any standalone four-digit number from 2020 to 2039.
_QUERY_YEAR_RE = re.compile(r"\b(20(?:2[0-9]|3[0-9]))\b")


def _inject_year(query: str) -> str:
    """
    Fix year issues in search queries:

    1. Replace any year from 2020 up to last year with the current year
       (such a year is assumed to be hallucinated by the classifier).
    2. Append the current year if version/release keywords are present but
       the current year does not appear in the query.

    This function only sees the query, so it cannot tell a hallucinated year
    from one the user asked for explicitly; current and future years are
    left untouched.
    """
    current_year = date.today().year
    current_str = str(current_year)

    def _replace_year(m: "re.Match") -> str:
        # Only past years are rewritten.
        return current_str if int(m.group(1)) < current_year else m.group(0)

    fixed = _QUERY_YEAR_RE.sub(_replace_year, query)

    if current_str not in fixed:
        lowered = fixed.lower()
        if any(kw in lowered for kw in _VERSION_KEYWORDS):
            fixed = f"{fixed} {current_str}"

    return fixed
# ---------------------------------------------------------------------------
# Brave Search with page content fetching
# ---------------------------------------------------------------------------

def brave_search(
    query: str,
    api_key: str,
    max_results: int = 6,
    status_callback=None,
    fetch_pages: int = 3,
) -> str:
    """
    Search Brave and fetch top page contents for richer context.

    Returns a formatted string suitable for injection into the system prompt.
    Failures are reported in the returned string rather than raised.

    max_results: number of results requested from Brave.
    status_callback: optional function(str) called with real-time status updates.
    fetch_pages: how many of the top results get their full page content
        fetched (default 3, the previously hard-coded value).
    """
    def _status(msg: str):
        if status_callback:
            status_callback(msg)

    if not api_key:
        return "⚠️ Brave API key not configured."

    try:
        headers = {
            "Accept": "application/json",
            "Accept-Encoding": "gzip",
            "X-Subscription-Token": api_key,
        }
        params = {
            "q": query,
            "count": max_results,
            "text_decorations": False,
            "search_lang": "fi" if detect_finnish(query) else "en",
        }

        _status(f"Searching: *{query}*")
        resp = requests.get(BRAVE_SEARCH_URL, headers=headers, params=params, timeout=10)
        resp.raise_for_status()

        data = resp.json()
        web_results = data.get("web", {}).get("results", [])

        if not web_results:
            return f"No web results found for: {query}"

        _status(f"Found {len(web_results)} results")

        sections = []

        # Also grab any infobox / knowledge graph snippet
        infobox = data.get("infobox", {})
        if isinstance(infobox, dict):
            long_desc = infobox.get("long_desc", "")
            if long_desc:
                sections.append(f"Knowledge panel: {long_desc}\n\n")

        # Build results, fetching full page content for the top `fetch_pages`
        pages_to_fetch = min(max(fetch_pages, 0), len(web_results))
        for i, r in enumerate(web_results):
            title = r.get("title", "")
            url = r.get("url", "")
            desc = r.get("description", "")
            age = r.get("age", "")
            age_str = f" ({age})" if age else ""

            page_content = ""
            if i < pages_to_fetch and url:
                _status(f"Reading [{i+1}/{pages_to_fetch}]: [{title}]({url})")
                page_content = _fetch_page_content(url)

            section = f"[{i+1}] {title}{age_str}\nURL: {url}\nSnippet: {desc}"
            if page_content:
                section += f"\nContent:\n{page_content}"
            sections.append(section)

        _status("Search complete")
        return f"Web search results for: {query}\n{'='*60}\n\n" + "\n\n---\n\n".join(sections)

    except Exception as e:
        # Best-effort boundary: a failed search must not break the response.
        print(f"[Router] Brave Search failed: {e}")
        return f"Web search failed: {e}"
# Elements whose entire content is dropped before text extraction.
_STRIPPED_TAGS = ("script", "style", "nav", "header", "footer", "aside", "noscript")


def _html_to_text(markup: str, max_chars: int = 3000) -> str:
    """Reduce an HTML document to plain text, truncated to *max_chars*.

    Returns "" when fewer than 50 characters of text remain.
    """
    # Lightweight HTML → text extraction (no BeautifulSoup dependency).
    # \b keeps e.g. "nav" from also matching a longer tag name.
    for tag in _STRIPPED_TAGS:
        markup = re.sub(rf"<{tag}\b[^>]*>.*?</{tag}>", " ", markup, flags=re.S | re.I)

    # Remove all remaining HTML tags
    text = re.sub(r"<[^>]+>", " ", markup)

    # Decode named and numeric HTML entities (&amp;, &nbsp;, &#8212;, ...)
    text = html.unescape(text)

    # Collapse whitespace (this also normalises non-breaking spaces)
    text = re.sub(r"\s+", " ", text).strip()

    if len(text) < 50:
        return ""
    return text[:max_chars]


def _fetch_page_content(url: str, max_chars: int = 3000) -> str:
    """Fetch and extract readable text from a URL. Returns truncated plain text.

    Best effort: any failure (network error, HTTP error status, non-text
    content type) yields an empty string instead of an exception.
    """
    try:
        headers = {
            "User-Agent": "Mozilla/5.0 (compatible; LLMRouter/3.0)",
            "Accept": "text/html,application/xhtml+xml",
        }
        resp = requests.get(url, headers=headers, timeout=8, allow_redirects=True)
        resp.raise_for_status()

        content_type = resp.headers.get("Content-Type", "")
        if "html" not in content_type and "text" not in content_type:
            return ""

        return _html_to_text(resp.text, max_chars)

    except Exception:
        # Page content is optional enrichment; the caller still has the snippet.
        return ""
# ---------------------------------------------------------------------------
# Stable Diffusion image generation
# ---------------------------------------------------------------------------

def _clean_sd_prompt(raw: str) -> str:
    """Strip markdown fences, wrapping quotes/backticks and outer whitespace."""
    cleaned = raw.strip()
    # Fences must go first: stripping backticks first would leave the fence's
    # language tag (e.g. "text") glued to the front of the prompt.
    cleaned = re.sub(r"^```\w*\s*", "", cleaned)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    return cleaned.strip().strip("\"'`").strip()


def _refine_sd_prompt(user_message: str, ollama_url: str, messages: Optional[List[dict]] = None) -> str:
    """Use the LLM to convert a user request into an optimized SD prompt.

    Includes conversation history so the model understands context like
    'generate an image of that'. Up to the last six user/assistant messages
    are sent, each trimmed to 500 characters. If the call fails or yields an
    empty prompt, *user_message* is returned unchanged.
    """
    try:
        context_messages = [{"role": "system", "content": SYSTEM_PROMPTS["image_generation"]}]

        # Recent conversation turns that actually carry content
        recent = [
            m for m in (messages or [])
            if m.get("role") in ("user", "assistant") and m.get("content")
        ]
        for msg in recent[-6:]:  # Last 3 exchanges
            content = msg["content"]
            if isinstance(content, list):
                # Multi-part content: keep only the text parts
                content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
            context_messages.append({"role": msg["role"], "content": content[:500]})

        if len(context_messages) == 1:
            # No usable history — make sure the model still gets the request.
            context_messages.append({"role": "user", "content": user_message[:500]})

        payload = {
            "model": MODELS["image_generation"],
            "messages": context_messages,
            "stream": False,
            "options": {"temperature": 0.7, "num_ctx": 4096},
        }
        resp = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=30)
        resp.raise_for_status()

        refined = _clean_sd_prompt(resp.json()["message"]["content"])
        return refined or user_message
    except Exception as e:
        print(f"[Router] SD prompt refinement failed: {e}")
        # Fallback: use the user message directly
        return user_message
def _negative_prompt() -> str:
    """Standard negative prompt for SD (comma-separated tags to steer away from)."""
    return (
        "lowres, bad anatomy, bad hands, text, error, missing fingers, "
        "extra digit, fewer digits, cropped, worst quality, low quality, "
        "normal quality, jpeg artifacts, signature, watermark, username, blurry, "
        "deformed, distorted, disfigured, mutation, mutated, ugly"
    )
def _compress_image(b64_png: str, quality: int = 80) -> str:
    """Convert a base64 PNG from SD to a smaller base64 JPEG.

    Returns the JPEG as base64 text. If anything goes wrong (Pillow missing,
    undecodable input, save error) the original *b64_png* is returned.
    """
    try:
        # Imported lazily so the module still loads without Pillow.
        from PIL import Image

        img = Image.open(BytesIO(base64.b64decode(b64_png)))
        # Convert anything that is not plain RGB or greyscale (RGBA, LA,
        # palette, ...) to RGB before saving as JPEG.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")

        buf = BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"[Router] Image compression failed: {e}, using original")
        return b64_png
def _unload_ollama_models(ollama_url: str):
    """Unload all Ollama models from VRAM to make room for image generation.

    Best effort: failures are logged, never raised, and one model failing to
    unload does not stop the others from being tried.
    """
    try:
        # List running models
        resp = requests.get(f"{ollama_url}/api/ps", timeout=5)
        if not resp.ok:
            print(f"[Router] Could not list Ollama models: HTTP {resp.status_code}")
            return
        models = resp.json().get("models", [])
    except Exception as e:
        print(f"[Router] Failed to unload Ollama models: {e}")
        return

    for model in models:
        name = model.get("name", "")
        if not name:
            continue
        try:
            # Setting keep_alive to 0 unloads the model immediately
            unload = requests.post(
                f"{ollama_url}/api/generate",
                json={"model": name, "keep_alive": 0},
                timeout=10,
            )
            if unload.ok:
                print(f"[Router] Unloaded Ollama model: {name}")
            else:
                print(f"[Router] Unloading {name} returned HTTP {unload.status_code}")
        except Exception as e:
            print(f"[Router] Failed to unload Ollama model {name}: {e}")
def _cleanup_after_generation(sd_url: str):
    """Free VRAM and RAM after image generation so Ollama can load models.

    Both steps are best effort: failures are logged and never raised.
    """
    # 1. Unload SD checkpoint from VRAM
    try:
        resp = requests.post(f"{sd_url}/sdapi/v1/unload-checkpoint", timeout=5)
        if resp.ok:
            print("[Router] SD checkpoint unloaded from VRAM")
        else:
            print(f"[Router] SD checkpoint unload returned HTTP {resp.status_code}")
    except Exception as e:
        print(f"[Router] SD checkpoint unload failed: {e}")

    # 2. Drop Linux page cache to free RAM. Done from Python instead of a
    #    shell command so that a failure is actually noticed and success is
    #    only logged when the write went through.
    try:
        os.sync()
        with open("/proc/sys/vm/drop_caches", "w") as fh:
            fh.write("3\n")
        print("[Router] Page cache dropped")
    except (OSError, AttributeError) as e:
        # OSError: file missing or not writable.
        # AttributeError: os.sync() is unavailable on this platform.
        print(f"[Router] Page cache not dropped: {e}")
def generate_image(
    user_message: str,
    ollama_url: str,
    sd_url: str,
    width: int = 512,
    height: int = 512,
    steps: int = 30,
    cfg_scale: float = 7.0,
    messages: Optional[List[dict]] = None,
) -> tuple:
    """
    Generate an image via AUTOMATIC1111 API.

    Returns (base64_image, refined_prompt) on success, or (None, error_message) on failure.

    ``messages`` is the optional conversation history handed to the prompt
    refinement step.
    """
    # Step 1: Refine the prompt using the LLM FIRST (while Ollama is still loaded)
    refined_prompt = _refine_sd_prompt(user_message, ollama_url, messages)

    # Step 2: Unload Ollama models from VRAM to make room for Stable Diffusion
    _unload_ollama_models(ollama_url)
    print(f"[Router] SD prompt: {refined_prompt[:120]}")

    # Step 3: Call AUTOMATIC1111
    payload = {
        "prompt": refined_prompt,
        "negative_prompt": _negative_prompt(),
        "width": width,
        "height": height,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "sampler_name": "DPM++ 2M Karras",
        "batch_size": 1,
        "n_iter": 1,
    }
    try:
        resp = requests.post(
            f"{sd_url}/sdapi/v1/txt2img",
            json=payload,
            timeout=120,
        )
        try:
            resp.raise_for_status()

            images = resp.json().get("images", [])
            if not images:
                return None, "Stable Diffusion returned no images."

            # Compress PNG→JPEG to reduce base64 size for streaming
            return _compress_image(images[0]), refined_prompt
        finally:
            # SD answered, so free VRAM and RAM for Ollama — also when the
            # answer was an HTTP error or contained no images.
            _cleanup_after_generation(sd_url)

    except requests.exceptions.ConnectionError:
        return None, f"Cannot connect to Stable Diffusion at {sd_url}. Is it running?"
    except requests.exceptions.Timeout:
        return None, "Image generation timed out (>120s)."
    except Exception as e:
        return None, f"Image generation failed: {e}"
# ---------------------------------------------------------------------------
# Image handling (unchanged from v2)
# ---------------------------------------------------------------------------

def extract_images_from_messages(messages: List[dict]) -> tuple:
    """Separate base64 images from message content.

    Returns ``(images, clean_messages)``:

    - ``images``: base64 payloads of every ``data:...;base64,`` image URL
      found in list-style message content, in order of appearance. Images
      referenced by any other kind of URL are not collected.
    - ``clean_messages``: the messages with list-style content flattened to
      the space-joined text parts; messages whose content is not a list are
      passed through unchanged.
    """
    images = []
    clean_messages = []

    for msg in messages:
        content = msg.get("content", "")
        if not isinstance(content, list):
            clean_messages.append(msg)
            continue

        text_parts = []
        for part in content:
            if not isinstance(part, dict):
                continue
            kind = part.get("type")
            if kind == "text":
                text_parts.append(part.get("text") or "")
            elif kind == "image_url":
                image_ref = part.get("image_url", {})
                # Accept both {"url": "..."} and a bare URL string.
                url = image_ref.get("url", "") if isinstance(image_ref, dict) else image_ref
                if isinstance(url, str) and url.startswith("data:"):
                    match = re.match(r"data:[^;]+;base64,(.+)", url, flags=re.S)
                    if match:
                        images.append(match.group(1))

        clean_messages.append({
            "role": msg["role"],
            "content": " ".join(text_parts).strip(),
        })

    return images, clean_messages
def has_image_content(messages: List[dict]) -> bool:
    """Check if the latest user message contains an uploaded image.

    Only the most recent message with role "user" is inspected (assistant
    responses, which may contain generated images, are ignored). An image is
    either a list-content part of type "image_url" or the substring
    "data:image" in string content.
    """
    last_user = next(
        (msg for msg in reversed(messages) if msg.get("role") == "user"),
        None,
    )
    if last_user is None:
        return False

    content = last_user.get("content", "")
    if isinstance(content, list):
        return any(
            isinstance(part, dict) and part.get("type") == "image_url"
            for part in content
        )
    return isinstance(content, str) and "data:image" in content
# ---------------------------------------------------------------------------
|
|
# Pipeline
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class Pipeline:
    """Router pipeline.

    Classifies the incoming prompt, optionally generates an image or runs a
    web search, then streams the reply of the selected Ollama model.
    """

    # A search result starting with one of these prefixes is treated as a
    # failed search rather than usable context.
    _SEARCH_FAILURE_PREFIXES = ("⚠️", "Web search failed", "No web results found")

    class Valves(BaseModel):
        """User-tunable settings for the router."""

        ollama_url: str = Field(default=OLLAMA_URL, description="Ollama API base URL")
        sd_url: str = Field(default=SD_URL, description="AUTOMATIC1111 Stable Diffusion WebUI URL")
        sd_width: int = Field(default=1024, description="Generated image width")
        sd_height: int = Field(default=1024, description="Generated image height")
        sd_steps: int = Field(default=25, description="Stable Diffusion sampling steps")
        sd_cfg_scale: float = Field(default=7.0, description="Stable Diffusion CFG scale")
        brave_api_key: str = Field(default=BRAVE_API_KEY, description="Brave Search API key")
        brave_max_results: int = Field(default=6, description="Number of Brave search results to fetch")
        brave_fetch_pages: int = Field(default=3, description="Number of top results to fetch full page content for")
        use_ai_classifier: bool = Field(default=True, description="Use AI classifier (vs keyword-only)")
        show_routing_info: bool = Field(default=True, description="Show routing banner in responses")
        search_context_max_chars: int = Field(default=12000, description="Max chars of search context to inject")

    def __init__(self):
        """Set the pipeline identity and create the default valves."""
        self.id = "llm-router"
        self.name = "LLM Router v3"
        self.valves = self.Valves()

    async def on_startup(self):
        """Log the effective configuration when the pipeline starts."""
        print(f"[Router] LLM Router v3.0 starting — Ollama: {self.valves.ollama_url}")
        print(f"[Router] Classifier: {CLASSIFIER_MODEL} | AI: {self.valves.use_ai_classifier}")
        print(f"[Router] Brave Search: {'configured' if self.valves.brave_api_key else 'NO API KEY'}")

    async def on_shutdown(self):
        """Log that the pipeline is shutting down."""
        print("[Router] LLM Router v3.0 shutting down")

    @classmethod
    def _search_failed(cls, search_context: str) -> bool:
        """Return True when *search_context* is empty or is an error message."""
        return not search_context or search_context.startswith(cls._SEARCH_FAILURE_PREFIXES)

    def _search_with_status(self, search_query: str) -> Iterator[str]:
        """Run ``brave_search`` in a thread, yielding its status lines live.

        Yields each status message as a markdown quote line while the search
        is running. The generator's return value (available through
        ``yield from``) is the search result text, or a message starting with
        ``"Web search failed"`` if the search raised.
        """
        import threading

        status_lines: List[str] = []
        outcome: List[Optional[str]] = [None]

        def _worker():
            try:
                outcome[0] = brave_search(
                    search_query,
                    self.valves.brave_api_key,
                    max_results=self.valves.brave_max_results,
                    status_callback=status_lines.append,
                )
            except Exception as exc:
                # An exception in the thread would otherwise be lost and the
                # empty result mistaken for success; report it as a failure.
                print(f"[Router] Search raised: {exc}")
                outcome[0] = f"Web search failed: {exc}"

        worker = threading.Thread(target=_worker)
        worker.start()

        emitted = 0
        while True:
            # Wake up regularly so status lines reach the client promptly.
            worker.join(timeout=0.3)
            still_running = worker.is_alive()
            while emitted < len(status_lines):
                yield f"> {status_lines[emitted]}\n>\n"
                emitted += 1
            if not still_running:
                break

        return outcome[0] or ""

    def _build_system_prompt(self, system_prompt: str, search_context: str, needs_search: bool) -> str:
        """Combine the base system prompt with search results or a failure note."""
        if not self._search_failed(search_context):
            today = date.today().strftime("%Y-%m-%d")
            return (
                f"{system_prompt}\n\n"
                f"Today's date: {today}\n\n"
                f"## Web Search Results\n"
                f"The following are fresh web search results. Use them as your PRIMARY source of truth.\n"
                f"Your training data may be outdated — always prefer information from these results.\n"
                f"When results conflict, prefer the most recent one (check dates/ages).\n"
                f"Cite the source URL when stating specific facts.\n"
                f"If the search results don't contain enough information to fully answer, "
                f"say so honestly rather than guessing.\n\n"
                f"{search_context}"
            )

        if needs_search:
            # Search was requested but failed — tell the model so it can be honest.
            today = date.today().strftime("%Y-%m-%d")
            return (
                f"{system_prompt}\n\n"
                f"Today's date: {today}\n\n"
                f"NOTE: A web search was attempted for this question but failed or returned no results. "
                f"Answer as best you can from your training data, but clearly state that you could not "
                f"verify the information with a live web search and the answer may be outdated."
            )

        return system_prompt

    def _stream_chat(self, payload: dict) -> Iterator[str]:
        """POST *payload* to Ollama's ``/api/chat`` and yield the streamed reply.

        Thinking tokens are wrapped in a collapsible ``<details>`` block.
        Connection, timeout and other errors are yielded as a final error
        line instead of being raised.
        """
        in_thinking = False
        error_text = ""

        try:
            # The context manager releases the connection even when we break
            # out early or the consumer abandons this generator.
            with requests.post(
                f"{self.valves.ollama_url}/api/chat",
                json=payload,
                stream=True,
                timeout=180,
            ) as resp:
                resp.raise_for_status()

                for line in resp.iter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue

                    msg = chunk.get("message") or {}

                    # Thinking/reasoning tokens go into a collapsible block.
                    thinking_content = msg.get("thinking", "")
                    if thinking_content:
                        if not in_thinking:
                            yield "<details>\n<summary>💭 Thinking…</summary>\n\n"
                            in_thinking = True
                        yield thinking_content

                    # Regular content closes any open thinking block first.
                    if msg.get("content"):
                        if in_thinking:
                            yield "\n</details>\n\n"
                            in_thinking = False
                        yield msg["content"]

                    if chunk.get("done"):
                        break

        except requests.exceptions.ConnectionError:
            error_text = f"\n\n❌ Connection error to Ollama ({self.valves.ollama_url}). Is the service running?"
        except requests.exceptions.Timeout:
            error_text = "\n\n❌ Timeout — the model is responding too slowly."
        except Exception as e:
            error_text = f"\n\n❌ Error: {str(e)}"

        # Always close the thinking block, including when the stream ended
        # without a "done" chunk or an error interrupted it.
        if in_thinking:
            yield "\n</details>\n\n"
        if error_text:
            yield error_text

    def pipe(
        self,
        user_message: str,
        model_id: str,
        messages: List[dict],
        body: dict,
    ) -> Iterator[str]:
        """Route one chat request and stream the response as markdown text.

        Args:
            user_message: Text of the latest user message; drives
                classification, language detection and image prompts.
            model_id: Not used by the router; kept for interface compatibility.
            messages: Full chat history as message dicts.
            body: Request body; only ``temperature`` is read from it.

        Yields:
            Markdown fragments: an optional routing banner, search/image
            status lines, then the model output (or an error line).
        """
        # --- Step 1: Vision override ---
        if has_image_content(messages):
            category = "vision"
            needs_search = False
            search_query = ""
            method = "vision_detect"
        else:
            # --- Step 2: Classify ---
            result = classify(
                user_message,
                self.valves.ollama_url,
                self.valves.use_ai_classifier,
            )
            category = result["category"]
            needs_search = result["search"]
            search_query = result["query"]
            method = result["method"]

        target_model = MODELS.get(category, MODELS["general"])
        system_prompt = SYSTEM_PROMPTS.get(category, SYSTEM_PROMPTS["general"])

        # Inject language instruction — always respond in the user's language
        if detect_finnish(user_message) and category not in ("reasoning_fi", "image_generation"):
            system_prompt = (
                "TÄRKEÄ: Käyttäjä kirjoittaa suomeksi. Vastaa AINA suomeksi.\n\n"
                + system_prompt
            )

        print(f"[Router] {method} → {category} → {target_model} | search={needs_search} query='{search_query}'")

        # --- Step 3: Routing info banner ---
        if self.valves.show_routing_info:
            display_cat = category.replace("_en", " 🇬🇧").replace("_fi", " 🇫🇮")
            search_label = f" | 🌐 `{search_query}`" if needs_search else ""
            yield f"> 🔀 **Router** `[{method}]` → `{target_model}` *(category: {display_cat}){search_label}*\n\n"

        # --- Step 4: Image generation (early return) ---
        if category == "image_generation":
            yield "> 🎨 Generating image…\n\n"
            base64_img, refined_prompt = generate_image(
                user_message,
                self.valves.ollama_url,
                self.valves.sd_url,
                width=self.valves.sd_width,
                height=self.valves.sd_height,
                steps=self.valves.sd_steps,
                cfg_scale=self.valves.sd_cfg_scale,
                messages=messages,
            )
            if base64_img:
                # NOTE(review): assumes generate_image returns base64 text of a
                # PNG (or a ready-made data URI) — confirm against generate_image.
                if base64_img.startswith("data:"):
                    data_uri = base64_img
                else:
                    data_uri = f"data:image/png;base64,{base64_img}"
                img_tag = f"![Generated image]({data_uri})"

                # Yield the image in chunks to avoid "chunk too big" errors
                chunk_size = 4096
                for start in range(0, len(img_tag), chunk_size):
                    yield img_tag[start:start + chunk_size]
                yield "\n\n"
                yield f"*Prompt used: {refined_prompt}*\n"
            else:
                # On failure the second element carries the error text.
                yield f"\n\n❌ {refined_prompt}\n"
            return

        # --- Step 5: Web search ---
        search_context = ""
        if needs_search and search_query and self.valves.brave_api_key:
            yield "> 🔍 Searching the web…\n\n"
            search_context = yield from self._search_with_status(search_query)
            yield "\n\n"

            # Truncate if too large
            max_chars = self.valves.search_context_max_chars
            if len(search_context) > max_chars:
                search_context = search_context[:max_chars] + "\n\n[...truncated]"

            # Detect failed (or empty) search and warn the user
            if self._search_failed(search_context):
                yield "> ⚠️ Web search failed — answering from model knowledge.\n\n"
                print(f"[Router] Search failed: {search_context[:120]}")
            else:
                print(f"[Router] Search complete: {len(search_context)} chars")

        elif needs_search and not self.valves.brave_api_key:
            yield "> ⚠️ Web search not available (no API key) — answering from model knowledge.\n\n"
            print("[Router] Search needed but no Brave API key configured!")

        # --- Step 6: Build messages ---
        images, clean_messages = extract_images_from_messages(messages)
        full_system = self._build_system_prompt(system_prompt, search_context, needs_search)

        ollama_messages = [{"role": "system", "content": full_system}]
        ollama_messages.extend(
            msg
            for msg in clean_messages
            if msg.get("role") in ("user", "assistant") and msg.get("content")
        )

        # Attach images to the last user message
        if images:
            for msg in reversed(ollama_messages):
                if msg["role"] == "user":
                    msg["images"] = images
                    break

        # --- Step 7: Call the target model ---
        temperature = body.get("temperature")
        payload = {
            "model": target_model,
            "messages": ollama_messages,
            "stream": True,
            "options": {
                # A client may send an explicit null; fall back to the default.
                "temperature": 0.7 if temperature is None else temperature,
                "num_ctx": 8192,
            },
        }

        yield from self._stream_chat(payload)
|