# Source code for promptguard.cache

"""Caching system for PromptGuard to avoid re-analysing identical prompts."""

import hashlib
import logging
import time
from collections import OrderedDict
from datetime import datetime
from typing import Any, Dict, Optional

from .schemas import RiskLevel, RiskScore

# Module-level logger named after this module; PromptCache below reports
# initialisation/clearing at INFO and hits, expiries and evictions at DEBUG.
logger = logging.getLogger(__name__)


class PromptCache:
    """
    In-memory LRU cache for prompt analysis results.

    Uses the SHA-256 digest of the prompt text as the cache key. Entries are
    evicted in least-recently-used order when the cache reaches
    ``max_size``. An optional TTL causes stale entries to be discarded on
    read.

    Args:
        max_size: Maximum number of entries to hold before evicting the LRU
            entry. Must be a positive integer. Defaults to 10 000.
        ttl_seconds: Optional time-to-live in seconds. Entries older than
            this value are treated as cache misses and silently removed.
            Defaults to 3600 (one hour); pass ``None`` to disable expiry.

    Raises:
        ValueError: If ``max_size`` is less than 1.
    """

    def __init__(
        self,
        max_size: int = 10_000,
        ttl_seconds: Optional[int] = 3600,
    ) -> None:
        # Enforce the documented contract: a cache with no capacity cannot
        # hold anything (and previously made the first ``set`` raise KeyError).
        if max_size < 1:
            raise ValueError(
                f"max_size must be a positive integer, got {max_size!r}"
            )
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        # OrderedDict preserves insertion/access order for O(1) LRU operations.
        # Each value is {"result": dict, "timestamp": float}. The timestamp is
        # taken from time.monotonic(), so TTL ages are unaffected by wall-clock
        # changes (DST transitions, NTP steps) that skew datetime.now().
        self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        logger.info(
            "Cache initialised (max_size=%s, ttl=%ss)",
            max_size,
            ttl_seconds,
        )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def get(self, prompt: str) -> Optional[RiskScore]:
        """Return the cached ``RiskScore`` for *prompt*, or ``None`` on a miss.

        A miss is returned when:

        * the prompt has never been cached, or
        * the cached entry is older than ``ttl_seconds`` (it is also removed).

        A hit marks the entry as most recently used. The result is rebuilt
        from the stored dict, so every hit returns a new ``RiskScore`` object.
        """
        key = self._hash_prompt(prompt)
        entry = self._cache.get(key)
        if entry is None:
            return None

        # Check TTL expiry
        if self.ttl_seconds is not None:
            age = time.monotonic() - entry["timestamp"]
            if age > self.ttl_seconds:
                logger.debug("Cache entry expired (age=%.1fs)", age)
                del self._cache[key]
                return None

        # Move to end to mark as most recently used (O(1)).
        self._cache.move_to_end(key)
        logger.debug("Cache hit")
        return self._deserialize_result(entry["result"])

    def set(self, prompt: str, result: RiskScore) -> None:
        """Store *result* in the cache keyed by *prompt*.

        Re-storing an already-cached prompt replaces its result, refreshes its
        timestamp and marks it most recently used. For a new prompt, the
        least-recently-used entries are evicted first if the cache is at
        capacity.
        """
        key = self._hash_prompt(prompt)
        # Serialise before touching the cache so a failing to_dict() leaves
        # the cache unchanged.
        entry = {
            "result": self._serialize_result(result),
            "timestamp": time.monotonic(),
        }
        if key in self._cache:
            # Assigning to an existing key keeps its position, so move it to
            # the MRU end explicitly.
            self._cache.move_to_end(key)
        else:
            # Evict LRU entries (popitem(last=False) is O(1)). A loop rather
            # than a single pop keeps the bound correct even if ``max_size``
            # was lowered after construction.
            while self._cache and len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)
                logger.debug("Cache eviction (LRU)")
        self._cache[key] = entry
        logger.debug("Cached result (total entries: %s)", len(self._cache))

    def clear(self) -> None:
        """Remove all cached entries."""
        self._cache.clear()
        logger.info("Cache cleared")

    def size(self) -> int:
        """Return the number of stored entries.

        Expired entries are only purged when read via ``get``, so this count
        may include entries that have already exceeded ``ttl_seconds``.
        """
        return len(self._cache)

    def stats(self) -> Dict[str, Any]:
        """Return cache statistics.

        Keys: ``size``, ``max_size``, ``ttl_seconds`` and
        ``oldest_entry_age`` (seconds, or ``None`` when the cache is empty).
        """
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "ttl_seconds": self.ttl_seconds,
            "oldest_entry_age": self._get_oldest_age(),
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _hash_prompt(prompt: str) -> str:
        """Return the SHA-256 hex digest of *prompt* (used as cache key).

        MD5 is deliberately avoided: practical MD5 collisions would let a
        crafted prompt share a key with a different, already-cached prompt and
        be served that prompt's verdict. ``surrogatepass`` lets prompts that
        contain lone surrogates (e.g. from decoded JSON) be hashed instead of
        raising ``UnicodeEncodeError``; ordinary text encodes identically.
        """
        data = prompt.encode("utf-8", errors="surrogatepass")
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def _serialize_result(result: RiskScore) -> Dict[str, Any]:
        """Convert a ``RiskScore`` to a plain dict for storage."""
        return result.to_dict()

    @staticmethod
    def _deserialize_result(data: Dict[str, Any]) -> RiskScore:
        """Reconstruct a ``RiskScore`` from a dict produced by ``to_dict``."""
        return RiskScore(
            is_malicious=data["is_malicious"],
            probability=data["probability"],
            risk_level=RiskLevel(data["risk_level"]),
            confidence=data["confidence"],
            explanation=data["explanation"],
            metadata=data.get("metadata", {}),
        )

    def _get_oldest_age(self) -> Optional[float]:
        """Return the age in seconds of the oldest stored entry, or ``None``.

        ``get`` reorders entries without refreshing their timestamps, so LRU
        order is not age order and a full scan is required.
        """
        if not self._cache:
            return None
        oldest = min(entry["timestamp"] for entry in self._cache.values())
        return time.monotonic() - oldest