
Customization & Extensions

Apollo RAG is designed with extensibility at its core. This guide shows you how to customize components, integrate new models, and extend the system for your specific needs.

Customization Overview

Apollo provides multiple extension points across its architecture:

  • Model Layer: Swap LLMs, embeddings, and rerankers
  • Retrieval Layer: Add custom strategies and transformations
  • API Layer: Extend endpoints and add middleware
  • Frontend: Theme customization and component extensions

These extension points are designed so that customizations preserve backward compatibility with the existing API surface.

Custom Model Integration

Adding a New LLM Backend

Apollo supports multiple LLM backends through the factory pattern. To add a new backend:

1. Create LLM Engine Class (backend/_src/llm_engine_custom.py):

from typing import AsyncIterator
from _src.llm_base import BaseLLMEngine
# Placeholder: swap in your provider's SDK client class
from custom_provider_sdk import CustomClient
 
class CustomLLMEngine(BaseLLMEngine):
    """Custom LLM implementation."""
 
    def __init__(self, model_name: str, api_key: str, **kwargs):
        self.model_name = model_name
        self.api_key = api_key
        # Initialize your custom client
        self.client = CustomClient(api_key=api_key)
 
    async def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        **kwargs
    ) -> str:
        """Generate non-streaming response."""
        response = await self.client.complete(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )
        return response.text
 
    async def generate_stream(
        self,
        prompt: str,
        max_tokens: int = 1024,
        **kwargs
    ) -> AsyncIterator[str]:
        """Stream token generation."""
        async for token in self.client.stream(prompt=prompt, max_tokens=max_tokens):
            yield token
 
    def cleanup(self):
        """Clean up resources."""
        if self.client:
            self.client.close()

2. Register in Factory (backend/_src/llm_factory.py):

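from _src.llm_base import BaseLLMEngine
from _src.llm_engine_custom import CustomLLMEngine
# LlamaCppEngine and OllamaEngine are imported here as before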
 
def create_llm_engine(backend: str, config: dict) -> BaseLLMEngine:
    """Factory for LLM engine creation."""
    if backend == "llamacpp":
        return LlamaCppEngine(**config)
    elif backend == "ollama":
        return OllamaEngine(**config)
    elif backend == "custom":
        return CustomLLMEngine(**config)
    else:
        raise ValueError(f"Unknown backend: {backend}")

3. Configure Model Profile (backend/config.yml):

model_profiles:
  custom-gpt4:
    backend: custom
    model_name: gpt-4-turbo
    api_key: ${OPENAI_API_KEY}
    max_tokens: 4096
    temperature: 0.0
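The profile above uses ${OPENAI_API_KEY}-style placeholders. This guide doesn't say how Apollo resolves them; if you load profiles in your own extension code, one option is Python's os.path.expandvars applied recursively, as in this sketch. expand_env is a hypothetical helper, not an Apollo API. Because expandvars leaves unset variables untouched, the sketch also checks for leftover ${...} references so a missing key fails loudly instead of being sent as a literal string.

```python
import os
import re
from typing import Any

# Any ${NAME} still present after expansion means the variable was unset
_UNRESOLVED = re.compile(r"\$\{(\w+)\}")


def expand_env(value: Any) -> Any:
    """Recursively expand ${VAR} references in strings inside dicts/lists."""
    if isinstance(value, str):
        expanded = os.path.expandvars(value)
        missing = _UNRESOLVED.findall(expanded)
        if missing:
            raise KeyError(f"Unset environment variable(s): {', '.join(missing)}")
        return expanded
    if isinstance(value, dict):
        return {k: expand_env(v) for k, v in value.items()}
    if isinstance(value, list):
        return [expand_env(v) for v in value]
    # Numbers, booleans, None pass through unchanged
    return value
```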

New backends work with model hot-swapping via ModelManager.switch_model() without further changes.
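The if/elif chain in step 2 grows with every backend. One alternative design, sketched here with hypothetical names (ENGINE_REGISTRY, register_engine), is a decorator-populated registry so each backend registers itself in its own module. This is not Apollo's existing factory, just one option; note that a class only registers when its module is imported, so the factory module still has to import each engine module.

```python
from typing import Callable, Dict

# Hypothetical registry mapping backend names to engine classes
ENGINE_REGISTRY: Dict[str, type] = {}


def register_engine(name: str) -> Callable[[type], type]:
    """Class decorator that records an engine class under a backend name."""
    def decorator(cls: type) -> type:
        if name in ENGINE_REGISTRY:
            raise ValueError(f"Backend already registered: {name}")
        ENGINE_REGISTRY[name] = cls
        return cls
    return decorator


def create_llm_engine(backend: str, config: dict):
    """Look up the backend's class and construct it from the profile config."""
    try:
        engine_cls = ENGINE_REGISTRY[backend]
    except KeyError:
        known = ", ".join(sorted(ENGINE_REGISTRY)) or "none"
        raise ValueError(f"Unknown backend: {backend} (known: {known})") from None
    return engine_cls(**config)


# Stand-in for the real CustomLLMEngine, for illustration only
@register_engine("custom")
class CustomLLMEngine:
    def __init__(self, model_name: str, api_key: str, **kwargs):
        self.model_name = model_name
        self.api_key = api_key
```

Unknown backends still raise ValueError, but the message now lists what is registered, which helps when a profile name is misspelled.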

Custom Embedding Models

Replace the default BGE embeddings with your own:

from langchain_community.embeddings import HuggingFaceEmbeddings
 
class CustomEmbeddings:
    """Custom embedding implementation."""
 
    def __init__(self, model_name: str):
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"device": "cpu"},  # or cuda
            encode_kwargs={"normalize_embeddings": True}
        )
 
    def embed_query(self, text: str) -> list[float]:
        """Embed single query."""
        return self.embeddings.embed_query(text)
 
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Batch embed documents."""
        return self.embeddings.embed_documents(texts)

Update backend/app/core/rag_engine.py:

# Replace default embeddings
self.embeddings = CustomEmbeddings(
    model_name="thenlper/gte-large"  # Example alternative
)
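Any object exposing embed_query and embed_documents fits the wrapper shape above. For unit tests that shouldn't download a model, a deterministic stand-in such as the hypothetical HashEmbeddings below can be useful. It hashes tokens into a fixed-size vector and L2-normalizes it, mirroring normalize_embeddings=True; the vectors are stable across runs but carry no semantic meaning.

```python
import hashlib
import math


class HashEmbeddings:
    """Hypothetical test double with the same interface as CustomEmbeddings."""

    def __init__(self, dim: int = 64):
        self.dim = dim

    def embed_query(self, text: str) -> list[float]:
        vec = [0.0] * self.dim
        for token in text.lower().split():
            # Stable hash (unlike built-in hash(), which is salted per process)
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            vec[int.from_bytes(digest[:4], "big") % self.dim] += 1.0
        norm = math.sqrt(sum(x * x for x in vec))
        # L2-normalize so a dot product equals cosine similarity
        return [x / norm for x in vec] if norm else vec

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(t) for t in texts]
```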

Custom Rerankers

Implement a custom reranking strategy:

from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
 
class CustomReranker:
    """Custom document reranker (bi-encoder cosine similarity)."""
 
    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)
 
    def rerank(
        self,
        query: str,
        documents: list[Document],
        top_k: int = 5
    ) -> list[Document]:
        """Rerank documents by relevance."""
        if not documents:
            return []
 
        # Embed the query and every candidate document
        query_emb = self.model.encode(query)
        doc_embs = self.model.encode([d.page_content for d in documents])
 
        # Cosine similarity of the query against each document
        scores = cosine_similarity([query_emb], doc_embs)[0]
 
        # Sort descending and return top-k
        ranked_indices = scores.argsort()[::-1][:top_k]
        return [documents[i] for i in ranked_indices]
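The scoring step in rerank boils down to cosine similarity followed by a descending top-k sort. The dependency-free sketch below (cosine and top_k_by_cosine are hypothetical helpers) isolates that logic so it can be tested without loading a model. This is bi-encoder similarity; a cross-encoder, which scores each query-document pair jointly, is a common alternative for reranking and is often more accurate at a higher per-pair cost.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def top_k_by_cosine(query_emb: Sequence[float],
                    doc_embs: Sequence[Sequence[float]],
                    k: int = 5) -> list[int]:
    """Return indices of the k most similar documents, best first."""
    scores = [cosine(query_emb, d) for d in doc_embs]
    # sorted() stays stable with reverse=True, so ties keep document order
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
```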

Custom Retrieval Strategies

Adding a Retrieval Mode

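Add a new strategy method to the existing AdaptiveRetriever class:

# In backend/_src/adaptive_retrieval.py
# _custom_transform, _custom_fusion and _custom_rerank are helpers you implement;
# _dense_search and _sparse_search stand for the retriever's dense and sparse search paths.
 
class AdaptiveRetriever:
    """Existing retriever, extended with a custom strategy."""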
 
    async def retrieve_custom(
        self,
        query: str,
        k: int = 10,
        **kwargs
    ) -> RetrievalResult:
        """Custom retrieval strategy."""
        # 1. Transform query
        enhanced_query = self._custom_transform(query)
 
        # 2. Multi-stage search
        dense_results = await self._dense_search(enhanced_query, k=k)
        sparse_results = await self._sparse_search(query, k=k)
 
        # 3. Fusion with custom weights
        fused = self._custom_fusion(dense_results, sparse_results)
 
        # 4. Advanced reranking
        reranked = await self._custom_rerank(query, fused)
 
        return RetrievalResult(
            documents=reranked[:k],
            strategy="custom",
            metadata={"stages": ["transform", "search", "fusion", "rerank"]}
        )
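The fusion step (_custom_fusion) is left to you. One common choice is weighted Reciprocal Rank Fusion (RRF), which combines ranked lists by position only, so dense and sparse scores on different scales need no calibration. This sketch assumes each result list is ordered best-first and that documents can be keyed by a stable ID; weighted_rrf is a hypothetical helper, k=60 is the constant commonly used with RRF, and the weights in the example are illustrative.

```python
from collections import defaultdict
from typing import Hashable, Sequence


def weighted_rrf(
    ranked_lists: Sequence[Sequence[Hashable]],
    weights: Sequence[float],
    k: int = 60,
) -> list[Hashable]:
    """Fuse best-first ranked lists of document IDs with weighted RRF.

    score(d) = sum_i weights[i] / (k + rank_i(d)), with ranks starting at 1.
    Documents missing from a list simply get no contribution from it.
    """
    if len(ranked_lists) != len(weights):
        raise ValueError("need one weight per ranked list")
    scores: dict = defaultdict(float)
    for results, weight in zip(ranked_lists, weights):
        for rank, doc_id in enumerate(results, start=1):
            scores[doc_id] += weight / (k + rank)
    # Highest fused score first; ties keep first-seen order
    return sorted(scores, key=scores.__getitem__, reverse=True)
```

For example, with dense results ["a", "b", "c"], sparse results ["c", "a", "d"] and weights [0.7, 0.3], "a" ranks first because it places well in both lists.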

Register in query router:

# In backend/app/api/query.py
 
@router.post("/query")
async def query_endpoint(request: QueryRequest):
    if request.mode == "custom":
        result = await rag_engine.retrieval_engine.retrieve_custom(
            query=request.question,
            k=request.top_k or 10
        )
        return result
    # ... existing handling for the built-in modes continues here

Custom Query Transformations

Add domain-specific query enhancement:

from _src.llm_base import BaseLLMEngine
 
class CustomQueryTransformer:
    """Domain-aware query transformation."""
 
    def __init__(self, llm: BaseLLMEngine):
        self.llm = llm
 
    async def transform_legal_query(self, query: str) -> list[str]:
        """Transform legal queries into case law search terms."""
        prompt = f"""Given this legal query: "{query}"
        Generate 3 search variants using:
        1. Legal terminology
        2. Case law references
        3. Statute citations
 
        Output format: one variant per line."""
 
        response = await self.llm.generate(prompt, max_tokens=200)
        # Drop blank lines the model may emit between variants
        return [line.strip() for line in response.splitlines() if line.strip()]
 
    async def transform_medical_query(self, query: str) -> list[str]:
        """Transform medical queries with ICD-10 codes."""
        # Similar pattern for the medical domain; not implemented in this example
        raise NotImplementedError
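LLMs often number or bullet their output even when asked for one variant per line. A small parser like the hypothetical parse_variants below could stand in for the plain line split in transform_legal_query: it strips common list markers and surrounding quotes, and drops empty or duplicate lines. Requiring whitespace after a marker keeps inputs such as "12.5% rule" intact.

```python
import re

# Leading list markers such as "1. ", "2) ", "- ", "* ", or "• "
_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*•])\s+")


def parse_variants(response: str, limit: int = 3) -> list[str]:
    """Turn raw LLM output into at most `limit` distinct query variants."""
    variants: list[str] = []
    for line in response.splitlines():
        cleaned = _MARKER.sub("", line).strip().strip('"')
        if cleaned and cleaned not in variants:
            variants.append(cleaned)
    return variants[:limit]
```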

API Extensions

Adding Custom Endpoints

Extend the FastAPI application:

# In backend/app/api/custom_routes.py
 
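from fastapi import APIRouter, Depends, Request
from app.core.rag_engine import RAGEngine
 
router = APIRouter(prefix="/api/custom", tags=["custom"])
 
def get_rag_engine(request: Request) -> RAGEngine:
    """Dependency provider (hypothetical): assumes the shared engine is stored
    on app.state.rag_engine at startup. A bare Depends() would instead build a
    new RAGEngine from its constructor on every request."""
    return request.app.state.rag_engine
 
@router.post("/specialized-search")
async def specialized_search(
    query: str,
    domain: str,
    rag_engine: RAGEngine = Depends(get_rag_engine)
):
    """Domain-specific search endpoint."""
    # specialized_query is a method you add to RAGEngine for this extension
    results = await rag_engine.specialized_query(
        query=query,
        domain=domain
    )
    return {"results": results}
 
@router.get("/analytics/query-patterns")
async def query_analytics():
    """Custom analytics endpoint."""
    # analyze_query_patterns is your own helper, e.g. over stored query logs
    patterns = analyze_query_patterns()
    return {"patterns": patterns}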

Register in backend/app/main.py:

from app.api import custom_routes
 
app.include_router(custom_routes.router)

Custom Middleware

Add request preprocessing:

# In backend/app/middleware/custom.py
 
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
 
class DomainDetectionMiddleware(BaseHTTPMiddleware):
    """Automatically detect query domain."""
 
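    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and request.url.path == "/api/query":
            # Reading the body here relies on Starlette caching it for the endpoint;
            # older Starlette releases could hang the route, so check your version.
            try:
                body = await request.json()
            except ValueError:  # malformed JSON: let the endpoint report the error
                body = {}
            domain = self._detect_domain(body.get("question", ""))
            request.state.detected_domain = domain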
 
        response = await call_next(request)
        return response
 
    def _detect_domain(self, text: str) -> str:
        """Classify query domain."""
        # ML-based or rule-based classification
        return "general"
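_detect_domain above is a stub. A keyword-based classifier, sketched below as a hypothetical detect_domain function with illustrative keyword lists (assumptions, not anything Apollo ships), is a reasonable starting point before reaching for an ML model. It returns the domain with the most keyword hits and falls back to "general".

```python
import re

# Illustrative keyword lists; tune these for your own corpus
DOMAIN_KEYWORDS: dict[str, set[str]] = {
    "legal": {"statute", "plaintiff", "defendant", "contract", "tort", "court"},
    "medical": {"diagnosis", "symptom", "patient", "dosage", "icd", "treatment"},
    "technical": {"api", "error", "deploy", "latency", "kubernetes", "python"},
}


def detect_domain(text: str) -> str:
    """Return the domain whose keywords appear most often, else 'general'."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    best, best_hits = "general", 0
    for domain, keywords in DOMAIN_KEYWORDS.items():
        hits = sum(1 for w in words if w in keywords)
        if hits > best_hits:  # strict '>' keeps the earlier domain on ties
            best, best_hits = domain, hits
    return best
```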

Register the middleware in backend/app/main.py:

from app.middleware.custom import DomainDetectionMiddleware
 
app.add_middleware(DomainDetectionMiddleware)

Frontend Customization

Theme Customization

Modify src/theme.ts:

export const customTheme = {
  colors: {
    primary: '#3b82f6',    // Custom blue
    secondary: '#8b5cf6',  // Custom purple
    background: '#0f172a', // Dark slate
    surface: '#1e293b',
    text: '#f1f5f9',
    border: '#334155'
  },
  fonts: {
    sans: 'Inter, system-ui, sans-serif',
    mono: 'JetBrains Mono, monospace'
  },
  spacing: {
    xs: '0.25rem',
    sm: '0.5rem',
    md: '1rem',
    lg: '1.5rem',
    xl: '2rem'
  }
}

Custom Components

Create specialized UI components:

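// src/components/Custom/SpecializedChat.tsx
 
import { useState } from 'react'
// Hypothetical: your own domain-specific components and helper, not part of Apollo
import { DomainSelector, SpecializedInput, getFiltersForDomain } from './domainUtils'
 
interface SpecializedChatProps {
  domain: 'legal' | 'medical' | 'technical'
  onQuery: (query: string, metadata: Record<string, unknown>) => void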
}
 
export function SpecializedChat({ domain, onQuery }: SpecializedChatProps) {
  const [query, setQuery] = useState('')
 
  const handleSubmit = () => {
    // Add domain-specific metadata
    onQuery(query, {
      domain,
      specialFilters: getFiltersForDomain(domain)
    })
  }
 
  return (
    <div className="specialized-chat">
      {/* Domain-specific UI elements */}
      <DomainSelector value={domain} />
      <SpecializedInput domain={domain} value={query} onChange={setQuery} />
      <button onClick={handleSubmit}>Ask {domain} Question</button>
    </div>
  )
}

Example: Custom Search Provider

Complete example of adding a web search provider:

Backend Integration:

# backend/_src/web_search_provider.py
 
import aiohttp
from langchain_core.documents import Document
 
class WebSearchProvider:
    """Integrate external search API."""
 
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.searchprovider.com/v1"  # placeholder endpoint
 
    async def search(self, query: str, limit: int = 5) -> list[dict]:
        """Search external web sources."""
        async with aiohttp.ClientSession() as session:
            params = {
                "q": query,
                "limit": limit,
                "api_key": self.api_key
            }
            async with session.get(f"{self.base_url}/search", params=params) as resp:
                resp.raise_for_status()  # surface HTTP errors instead of a KeyError below
                data = await resp.json()
                return data["results"]
 
    async def augment_retrieval(
        self,
        query: str,
        local_results: list[Document],
        web_limit: int = 3
    ) -> list[Document]:
        """Combine local and web results."""
        web_results = await self.search(query, limit=web_limit)
 
        # Convert to Document format
        web_docs = [
            Document(
                page_content=r["snippet"],
                metadata={
                    "source": "web",
                    "url": r["url"],
                    "title": r["title"]
                }
            )
            for r in web_results
        ]
 
        # Merge with local results
        return local_results + web_docs
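augment_retrieval concatenates local and web results as-is, so a web hit that repeats a local source, or one with an empty snippet, passes straight through. A slightly more defensive merge might look like the sketch below. merge_results is a hypothetical helper, and it uses plain dicts with page_content/metadata keys in place of Document to stay dependency-free.

```python
from typing import Any


def merge_results(
    local: list[dict[str, Any]],
    web: list[dict[str, Any]],
    max_web: int = 3,
) -> list[dict[str, Any]]:
    """Keep all local results, then append up to max_web new web results.

    A web result is skipped if its snippet is blank or its URL was already
    seen, either among the local results or earlier in the web list.
    """
    seen_urls = {d["metadata"].get("url") for d in local} - {None}
    merged = list(local)
    added = 0
    for doc in web:
        if added >= max_web:
            break
        url = doc["metadata"].get("url")
        if not doc["page_content"].strip() or url in seen_urls:
            continue
        merged.append(doc)
        added += 1
        if url:
            seen_urls.add(url)
    return merged
```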

Enable in Config:

# backend/config.yml
 
web_search:
  enabled: true
  provider: custom_search
  api_key: ${SEARCH_API_KEY}
  augment_local: true
  max_web_results: 3
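On the Python side, the web_search block could be loaded into a typed, validated object, as in this sketch. WebSearchConfig and from_dict are hypothetical names; the fields mirror the YAML keys above, the defaults are assumptions, and parsing the YAML file itself is left out. Rejecting unknown keys catches typos such as "enabeld" that would otherwise be silently ignored.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class WebSearchConfig:
    """Typed view of the web_search section of config.yml (hypothetical)."""
    enabled: bool = False
    provider: str = "custom_search"
    api_key: str = ""
    augment_local: bool = True
    max_web_results: int = 3

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> "WebSearchConfig":
        known = {f.name for f in fields(cls)}
        unknown = set(raw) - known
        if unknown:
            raise ValueError(f"Unknown web_search keys: {sorted(unknown)}")
        cfg = cls(**raw)
        if cfg.max_web_results < 0:
            raise ValueError("max_web_results must be >= 0")
        if cfg.enabled and not cfg.api_key:
            raise ValueError("web_search is enabled but api_key is empty")
        return cfg
```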

Best Practices

Version Compatibility: Test extensions against the target Apollo version. Extension APIs maintain backward compatibility within major versions.

Extension Guidelines

  • Isolation: Keep custom code in separate modules (e.g., _src/custom/)
  • Configuration: Use config.yml for all customization parameters
  • Logging: Add structured logging for debugging
  • Error Handling: Gracefully degrade if custom components fail
  • Documentation: Document extension points and breaking changes
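The logging and graceful-degradation guidelines can be combined in a small wrapper: run the custom component with a timeout, log structured context if it fails, and fall back to the default path. with_fallback is a hypothetical helper, not an Apollo API, and the five-second default timeout is an arbitrary choice.

```python
import asyncio
import logging
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger("apollo.extensions")


async def with_fallback(
    primary: Callable[[], Awaitable[T]],
    fallback: Callable[[], Awaitable[T]],
    component: str,
    timeout: float = 5.0,
) -> T:
    """Run primary(); on error or timeout, log and return fallback() instead."""
    try:
        return await asyncio.wait_for(primary(), timeout=timeout)
    except Exception as exc:  # degrade rather than fail the whole request
        logger.warning(
            "custom component failed, using fallback",
            extra={"component": component, "error": repr(exc)},
        )
        return await fallback()
```

Catching Exception (not BaseException) means task cancellation still propagates normally.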

Performance Considerations

  • Caching: Integrate with EmbeddingCache for custom embeddings
  • Async: Use async/await for I/O-bound operations
  • Batching: Process documents in batches for efficiency
  • Monitoring: Add metrics for custom components
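EmbeddingCache's interface isn't described in this guide, so the sketch below illustrates the general caching-plus-batching pattern with a hypothetical CachedEmbeddings wrapper instead: vectors are cached by a hash of the text, and only cache misses are sent to the underlying model, deduplicated and in fixed-size batches. It wraps any object exposing embed_documents, such as the CustomEmbeddings class above.

```python
import hashlib
from typing import Protocol


class Embedder(Protocol):
    def embed_documents(self, texts: list[str]) -> list[list[float]]: ...


class CachedEmbeddings:
    """Hypothetical in-memory cache plus batching around any embedder."""

    def __init__(self, inner: Embedder, batch_size: int = 32):
        self.inner = inner
        self.batch_size = batch_size
        self._cache: dict[str, list[float]] = {}

    @staticmethod
    def _key(text: str) -> str:
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # Unique texts not yet cached, in first-seen order
        misses = list(dict.fromkeys(t for t in texts if self._key(t) not in self._cache))
        for start in range(0, len(misses), self.batch_size):
            batch = misses[start:start + self.batch_size]
            for text, vec in zip(batch, self.inner.embed_documents(batch)):
                self._cache[self._key(text)] = vec
        return [self._cache[self._key(t)] for t in texts]

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]
```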

Testing Extensions

# backend/tests/test_custom.py
# @pytest.mark.asyncio requires the pytest-asyncio plugin
 
import pytest
from _src.custom_retriever import CustomRetriever
 
@pytest.mark.asyncio
async def test_custom_retrieval():
    retriever = CustomRetriever()
    result = await retriever.retrieve_custom("test query")
 
    # retrieve_custom returns a RetrievalResult, not a bare list of documents
    assert len(result.documents) > 0
    assert result.strategy == "custom"
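Tests for LLM-backed components, such as CustomQueryTransformer, run faster and deterministically against a stub engine. This sketch assumes BaseLLMEngine is an abstract base class with async generate, async generate_stream, and cleanup methods; since its definition isn't shown in this guide, a stand-in ABC is declared locally, and EchoLLMEngine is hypothetical.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterator


class BaseLLMEngine(ABC):
    """Local stand-in for _src.llm_base.BaseLLMEngine (assumed interface)."""

    @abstractmethod
    async def generate(self, prompt: str, max_tokens: int = 1024,
                       temperature: float = 0.0, **kwargs) -> str: ...

    @abstractmethod
    def generate_stream(self, prompt: str, max_tokens: int = 1024,
                        **kwargs) -> AsyncIterator[str]: ...

    @abstractmethod
    def cleanup(self) -> None: ...


class EchoLLMEngine(BaseLLMEngine):
    """Hypothetical offline engine that echoes the prompt back word by word."""

    def __init__(self, model_name: str = "echo", **kwargs):
        self.model_name = model_name
        self.closed = False

    async def generate(self, prompt: str, max_tokens: int = 1024,
                       temperature: float = 0.0, **kwargs) -> str:
        # Treat each whitespace-separated word as one "token"
        return " ".join(prompt.split()[:max_tokens])

    async def generate_stream(self, prompt: str, max_tokens: int = 1024,
                              **kwargs) -> AsyncIterator[str]:
        for word in prompt.split()[:max_tokens]:
            yield word

    def cleanup(self) -> None:
        self.closed = True


async def collect(engine: BaseLLMEngine, prompt: str) -> list[str]:
    """Drain a streaming response into a list of tokens."""
    return [tok async for tok in engine.generate_stream(prompt)]
```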

Next Steps

Need help with customization? Check the GitHub Discussions for community support.