Customization & Extensions
Apollo RAG is designed with extensibility at its core. This guide shows you how to customize components, integrate new models, and extend the system for your specific needs.
Customization Overview
Apollo provides multiple extension points across its architecture:
- Model Layer: Swap LLMs, embeddings, and rerankers
- Retrieval Layer: Add custom strategies and transformations
- API Layer: Extend endpoints and add middleware
- Frontend: Theme customization and component extensions
All customization maintains backward compatibility with the existing API surface.
Custom Model Integration
Adding a New LLM Backend
Apollo supports multiple LLM backends through the factory pattern. To add a new backend:
1. Create LLM Engine Class (backend/_src/llm_engine_custom.py):
from typing import AsyncIterator, Optional
from _src.llm_base import BaseLLMEngine
class CustomLLMEngine(BaseLLMEngine):
"""Custom LLM implementation."""
def __init__(self, model_name: str, api_key: str, **kwargs):
self.model_name = model_name
self.api_key = api_key
# Initialize your custom client
self.client = CustomClient(api_key=api_key)
async def generate(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 0.0,
**kwargs
) -> str:
"""Generate non-streaming response."""
response = await self.client.complete(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature
)
return response.text
async def generate_stream(
self,
prompt: str,
max_tokens: int = 1024,
**kwargs
) -> AsyncIterator[str]:
"""Stream token generation."""
async for token in self.client.stream(prompt=prompt):
yield token
def cleanup(self):
"""Clean up resources."""
if self.client:
self.client.close()
2. Register in Factory (backend/_src/llm_factory.py):
from _src.llm_engine_custom import CustomLLMEngine
def create_llm_engine(backend: str, config: dict) -> BaseLLMEngine:
"""Factory for LLM engine creation."""
if backend == "llamacpp":
return LlamaCppEngine(**config)
elif backend == "ollama":
return OllamaEngine(**config)
elif backend == "custom":
return CustomLLMEngine(**config)
else:
raise ValueError(f"Unknown backend: {backend}")
3. Configure Model Profile (backend/config.yml):
model_profiles:
custom-gpt4:
backend: custom
model_name: gpt-4-turbo
api_key: ${OPENAI_API_KEY}
max_tokens: 4096
temperature: 0.0
Model hot-swapping works automatically with new backends via ModelManager.switch_model().
Custom Embedding Models
Replace the default BGE embeddings with your own:
from langchain_community.embeddings import HuggingFaceEmbeddings
class CustomEmbeddings:
"""Custom embedding implementation."""
def __init__(self, model_name: str):
self.embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={"device": "cpu"}, # or cuda
encode_kwargs={"normalize_embeddings": True}
)
def embed_query(self, text: str) -> list[float]:
"""Embed single query."""
return self.embeddings.embed_query(text)
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Batch embed documents."""
return self.embeddings.embed_documents(texts)
Update backend/app/core/rag_engine.py:
# Replace default embeddings
self.embeddings = CustomEmbeddings(
model_name="thenlper/gte-large" # Example alternative
)
Custom Rerankers
Implement a custom reranking strategy:
class CustomReranker:
"""Custom document reranker."""
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name)
def rerank(
self,
query: str,
documents: list[Document],
top_k: int = 5
) -> list[Document]:
"""Rerank documents by relevance."""
# Compute relevance scores
query_emb = self.model.encode(query)
doc_embs = self.model.encode([d.page_content for d in documents])
# Calculate cosine similarity
scores = cosine_similarity([query_emb], doc_embs)[0]
# Sort and return top-k
ranked_indices = scores.argsort()[::-1][:top_k]
return [documents[i] for i in ranked_indices]
Custom Retrieval Strategies
Adding a Retrieval Mode
Extend AdaptiveRetriever with a new strategy:
# In backend/_src/adaptive_retrieval.py
class AdaptiveRetriever:
"""Extended with custom strategy."""
async def retrieve_custom(
self,
query: str,
k: int = 10,
**kwargs
) -> RetrievalResult:
"""Custom retrieval strategy."""
# 1. Transform query
enhanced_query = self._custom_transform(query)
# 2. Multi-stage search
dense_results = await self._dense_search(enhanced_query, k=k)
sparse_results = await self._sparse_search(query, k=k)
# 3. Fusion with custom weights
fused = self._custom_fusion(dense_results, sparse_results)
# 4. Advanced reranking
reranked = await self._custom_rerank(query, fused)
return RetrievalResult(
documents=reranked[:k],
strategy="custom",
metadata={"stages": ["transform", "search", "fusion", "rerank"]}
)
Register in query router:
# In backend/app/api/query.py
@router.post("/query")
async def query_endpoint(request: QueryRequest):
if request.mode == "custom":
result = await rag_engine.retrieval_engine.retrieve_custom(
query=request.question,
k=request.top_k or 10
)
Custom Query Transformations
Add domain-specific query enhancement:
class CustomQueryTransformer:
"""Domain-aware query transformation."""
def __init__(self, llm: BaseLLMEngine):
self.llm = llm
async def transform_legal_query(self, query: str) -> list[str]:
"""Transform legal queries into case law search terms."""
prompt = f"""Given this legal query: "{query}"
Generate 3 search variants using:
1. Legal terminology
2. Case law references
3. Statute citations
Output format: one variant per line."""
response = await self.llm.generate(prompt, max_tokens=200)
return response.strip().split("\n")
async def transform_medical_query(self, query: str) -> list[str]:
"""Transform medical queries with ICD-10 codes."""
# Similar pattern for medical domain
pass
API Extensions
Adding Custom Endpoints
Extend the FastAPI application:
# In backend/app/api/custom_routes.py
from fastapi import APIRouter, Depends
from app.core.rag_engine import RAGEngine
router = APIRouter(prefix="/api/custom", tags=["custom"])
@router.post("/specialized-search")
async def specialized_search(
query: str,
domain: str,
rag_engine: RAGEngine = Depends()
):
"""Domain-specific search endpoint."""
# Custom logic here
results = await rag_engine.specialized_query(
query=query,
domain=domain
)
return {"results": results}
@router.get("/analytics/query-patterns")
async def query_analytics():
"""Custom analytics endpoint."""
patterns = analyze_query_patterns()
return {"patterns": patterns}
Register in backend/app/main.py:
from app.api import custom_routes
app.include_router(custom_routes.router)
Custom Middleware
Add request preprocessing:
# In backend/app/middleware/custom.py
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
class DomainDetectionMiddleware(BaseHTTPMiddleware):
"""Automatically detect query domain."""
async def dispatch(self, request: Request, call_next):
if request.url.path == "/api/query":
body = await request.json()
domain = self._detect_domain(body.get("question", ""))
request.state.detected_domain = domain
response = await call_next(request)
return response
def _detect_domain(self, text: str) -> str:
"""Classify query domain."""
# ML-based or rule-based classification
return "general"
Register middleware:
app.add_middleware(DomainDetectionMiddleware)
Frontend Customization
Theme Customization
Modify src/theme.ts:
export const customTheme = {
colors: {
primary: '#3b82f6', // Custom blue
secondary: '#8b5cf6', // Custom purple
background: '#0f172a', // Dark slate
surface: '#1e293b',
text: '#f1f5f9',
border: '#334155'
},
fonts: {
sans: 'Inter, system-ui, sans-serif',
mono: 'JetBrains Mono, monospace'
},
spacing: {
xs: '0.25rem',
sm: '0.5rem',
md: '1rem',
lg: '1.5rem',
xl: '2rem'
}
}
Custom Components
Create specialized UI components:
// src/components/Custom/SpecializedChat.tsx
interface SpecializedChatProps {
domain: 'legal' | 'medical' | 'technical'
onQuery: (query: string, metadata: object) => void
}
export function SpecializedChat({ domain, onQuery }: SpecializedChatProps) {
const [query, setQuery] = useState('')
const handleSubmit = () => {
// Add domain-specific metadata
onQuery(query, {
domain,
specialFilters: getFiltersForDomain(domain)
})
}
return (
<div className="specialized-chat">
{/* Domain-specific UI elements */}
<DomainSelector value={domain} />
<SpecializedInput domain={domain} value={query} />
<button onClick={handleSubmit}>Ask {domain} Question</button>
</div>
)
}
Example: Custom Search Provider
Complete example of adding a web search provider:
Backend Integration:
# backend/_src/web_search_provider.py
import aiohttp
# list[...] and dict[...] are built-in generics (Python 3.9+); no typing import is needed
class WebSearchProvider:
"""Integrate external search API."""
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://api.searchprovider.com/v1"
async def search(self, query: str, limit: int = 5) -> list[dict]:
"""Search external web sources."""
async with aiohttp.ClientSession() as session:
params = {
"q": query,
"limit": limit,
"api_key": self.api_key
}
async with session.get(f"{self.base_url}/search", params=params) as resp:
data = await resp.json()
return data["results"]
async def augment_retrieval(
self,
query: str,
local_results: list[Document],
web_limit: int = 3
) -> list[Document]:
"""Combine local and web results."""
web_results = await self.search(query, limit=web_limit)
# Convert to Document format
web_docs = [
Document(
page_content=r["snippet"],
metadata={
"source": "web",
"url": r["url"],
"title": r["title"]
}
)
for r in web_results
]
# Merge with local results
return local_results + web_docs
Enable in Config:
# backend/config.yml
web_search:
enabled: true
provider: custom_search
api_key: ${SEARCH_API_KEY}
augment_local: true
max_web_results: 3
Best Practices
Version Compatibility: Test extensions against the target Apollo version. Extension APIs maintain backward compatibility within major versions.
Extension Guidelines
- Isolation: Keep custom code in separate modules (e.g., _src/custom/)
- Configuration: Use config.yml for all customization parameters
- Logging: Add structured logging for debugging
- Error Handling: Gracefully degrade if custom components fail
- Documentation: Document extension points and breaking changes
Performance Considerations
- Caching: Integrate with EmbeddingCache for custom embeddings
- Async: Use async/await for I/O-bound operations
- Batching: Process documents in batches for efficiency
- Monitoring: Add metrics for custom components
Testing Extensions
# backend/tests/test_custom.py
import pytest
from _src.custom_retriever import CustomRetriever
@pytest.mark.asyncio
async def test_custom_retrieval():
retriever = CustomRetriever()
results = await retriever.retrieve_custom("test query")
assert len(results) > 0
assert results[0].metadata["strategy"] == "custom"
Next Steps
- Troubleshooting - Debug custom extensions
- API Reference - Full API documentation
- Architecture - Deep dive into system design
Need help with customization? Check the GitHub Discussions for community support.