Source code for promptguard.core

"""Core PromptGuard classifier."""
import logging
from typing import Any, Dict, List, Optional, Tuple
from tqdm import tqdm
import torch
from .analyzers import (
    SentimentAnalyzer,
    IntentClassifier,
    KeywordExtractor,
    AttackPatternDetector,
)
from .sanitizers import PromptSanitizer
from .models import ModelLoader
from .schemas import (
    RiskScore,
    RiskLevel,
    Intent,
    SanitizationStrategy,
    SanitizeResponse,
)
from .exceptions import ValidationError, InferenceError
from .cache import PromptCache
from .config import PromptGuardConfig

logger = logging.getLogger(__name__)


class PromptGuard:
    """Main PromptGuard classifier for detecting malicious prompts."""
[docs] def __init__( self, model_name: str = "arkaean/promptguard-distilbert", threshold: float = 0.5, device: Optional[str] = "auto", use_cache: bool = True, cache_size: int = 10000, cache_ttl: Optional[int] = 3600, enable_analysis: bool = True, enable_sanitization: bool = True, **kwargs ): """Initialise the PromptGuard classifier. Args: model_name: HuggingFace Hub model identifier. threshold: Malicious classification threshold in ``[0.0, 1.0]``. Prompts with a model probability at or above this value are classified as malicious. device: Device for model inference — ``"cuda"``, ``"cpu"``, or ``"auto"`` (selects CUDA when available). use_cache: Enable in-memory LRU caching of analysis results. cache_size: Maximum number of entries held in the cache. cache_ttl: Cache entry time-to-live in seconds. Pass ``None`` to disable expiry. enable_analysis: When ``True`` (default), enables supplementary sentiment, intent, keyword, and attack-pattern analysis that enriches :class:`~promptguard.schemas.RiskScore` metadata. enable_sanitization: When ``True`` (default), enables the :meth:`sanitize` and :meth:`sanitize_if_malicious` methods. **kwargs: Additional options forwarded to :class:`~promptguard.config.PromptGuardConfig`. 
""" # Create configuration self.config = PromptGuardConfig( model_name=model_name, threshold=threshold, device=device, **kwargs ) # Initialize cache self.use_cache = use_cache if use_cache: self.cache = PromptCache( max_size=cache_size, ttl_seconds=cache_ttl) logger.info("Caching enabled") else: self.cache = None logger.info("Caching disabled") # Initialize analyzers self.enable_analysis = enable_analysis if enable_analysis: self.sentiment_analyzer = SentimentAnalyzer() self.intent_classifier = IntentClassifier() self.keyword_extractor = KeywordExtractor() self.attack_detector = AttackPatternDetector() logger.info("Analysis features enabled") else: logger.info("Analysis features disabled") self.enable_sanitization = enable_sanitization if enable_sanitization: self.sanitizer = PromptSanitizer() logger.info("Sanitization features enabled") else: self.sanitizer = None logger.info("Sanitization features disabled") # Initialize model loader self.model_loader = ModelLoader(self.config) # Load model and tokenizer self.model, self.tokenizer = self.model_loader.load() logger.info("PromptGuard initialized with model: %s", model_name)
[docs] def sanitize( self, prompt: str, strategy: SanitizationStrategy = SanitizationStrategy.BALANCED, analyze_after: bool = True, ) -> SanitizeResponse: """Sanitise a potentially malicious prompt. Args: prompt: Prompt to sanitise. strategy: Sanitisation strategy to apply. analyze_after: When ``True`` (default), the sanitised prompt is re-analysed and the result stored in :attr:`SanitizeResponse.sanitized_analysis`. Returns: A :class:`~promptguard.schemas.SanitizeResponse` with the sanitisation outcome and before/after risk scores. Raises: ValueError: When sanitisation is not enabled on this instance. """ if not self.enable_sanitization or self.sanitizer is None: raise ValueError("Sanitization is not enabled") original_analysis = self.analyze(prompt) sanitization = self.sanitizer.sanitize(prompt, strategy) sanitized_analysis = None if analyze_after and sanitization.was_modified: sanitized_analysis = self.analyze(sanitization.sanitized) risk_after = sanitized_analysis.probability if sanitized_analysis else None risk_reduction = original_analysis.probability - ( risk_after if risk_after is not None else original_analysis.probability ) return SanitizeResponse( sanitization=sanitization, original_analysis=original_analysis, sanitized_analysis=sanitized_analysis, risk_before=original_analysis.probability, risk_after=risk_after, risk_reduction=risk_reduction, )
[docs] def sanitize_if_malicious( self, prompt: str, strategy: SanitizationStrategy = SanitizationStrategy.BALANCED ) -> Tuple[str, bool]: """ Sanitize prompt only if it's detected as malicious. Args: prompt: Prompt to check and potentially sanitize strategy: Sanitization strategy if needed Returns: Tuple of (potentially_sanitized_prompt, was_sanitized) """ if not self.enable_sanitization or self.sanitizer is None: raise ValueError("Sanitization is not enabled") # Check if malicious analysis = self.analyze(prompt) if analysis.is_malicious: # Sanitize sanitization = self.sanitizer.sanitize(prompt, strategy) return sanitization.sanitized, True else: # Return unchanged return prompt, False
def _perform_analysis( self, prompt: str, probability: float, is_batch: bool = False ) -> RiskScore: """ Perform analysis on a single prompt (shared by analyze and analyze_batch). Args: prompt: The prompt text probability: Malicious probability from model is_batch: Whether this is from batch processing Returns: Complete RiskScore with all analysis """ # Classify is_malicious = probability >= self.config.threshold # Determine risk level risk_level = self._get_risk_level(probability) # Calculate confidence confidence = abs(probability - 0.5) * 2 # Perform additional analysis if enabled metadata = { "model": self.config.model_name, "threshold": self.config.threshold, "prompt_length": len(prompt) } if is_batch: metadata['batch_processed'] = True if self.enable_analysis: # Sentiment analysis sentiment_result = self.sentiment_analyzer.analyze(prompt) metadata['sentiment'] = sentiment_result # Intent classification intent_result = self.intent_classifier.classify(prompt) metadata['intent'] = intent_result # Extract keywords if malicious or suspicious if is_malicious or probability > 0.3: keywords = self.keyword_extractor.extract(prompt) metadata['keywords'] = keywords # Detect attack patterns attack_patterns = self.attack_detector.detect(prompt) metadata['attack_patterns'] = attack_patterns # Generate enhanced explanation explanation = self._generate_explanation( prompt, probability, is_malicious, metadata if self.enable_analysis else None ) return RiskScore( is_malicious=is_malicious, probability=float(probability), risk_level=risk_level, confidence=float(confidence), explanation=explanation, metadata=metadata )
[docs] def analyze(self, prompt: str) -> RiskScore: """ Analyze a single prompt for malicious content. """ # Validate input if not isinstance(prompt, str) or not prompt.strip(): raise ValidationError("Prompt must be a non-empty string") # Check cache first if self.use_cache and self.cache is not None: cached_result = self.cache.get(prompt) if cached_result is not None: logger.debug("Returning cached result") return cached_result try: # Get probability probability = self._predict_single(prompt) # Perform analysis (now uses shared method) result = self._perform_analysis( prompt, probability, is_batch=False) # Cache result if self.use_cache and self.cache is not None: self.cache.set(prompt, result) return result except Exception as e: if isinstance(e, (ValidationError, InferenceError)): raise error_msg = f"Failed to analyze prompt: {str(e)}" logger.error(error_msg) raise InferenceError(error_msg) from e
[docs] def clear_cache(self) -> None: """ Clear the analysis cache """ if self.cache is not None: self.cache.clear() logger.info("Cache cleared") else: logger.warning("Caching is not enabled")
[docs] def cache_stats(self) -> Optional[Dict[str, Any]]: """ Get cache statistics """ if self.cache is not None: return self.cache.stats() return None
def _predict_single(self, prompt: str) -> float: """ Get probability for a single prompt. """ # Tokenize inputs = self.tokenizer( prompt, truncation=True, max_length=self.config.max_length, padding=True, return_tensors="pt" ) # Move to device inputs = {k: v.to(self.model_loader.device) for k, v in inputs.items()} # Predict with torch.no_grad(): outputs = self.model(**inputs) probabilities = torch.softmax(outputs.logits, dim=1) malicious_prob = probabilities[0, 1].item() return malicious_prob def _get_risk_level(self, probability: float) -> RiskLevel: """ Determine risk level based on probability. """ if probability < 0.3: return RiskLevel.LOW elif probability < 0.7: return RiskLevel.MEDIUM else: return RiskLevel.HIGH def _generate_explanation( self, prompt: str, probability: float, is_malicious: bool, metadata: Optional[Dict[str, Any]] = None ) -> str: """ Generate human-readable explanation with evidence. Args: prompt: The analyzed prompt probability: Malicious probability is_malicious: Classification result metadata: Additional analysis metadata Returns: Explanation string """ if is_malicious: explanation_parts = [] # Base explanation if probability > 0.9: explanation_parts.append( f"This prompt is highly likely to be malicious " f"({probability:.1%} confidence)." ) elif probability > 0.7: explanation_parts.append( f"This prompt appears to be malicious " f"({probability:.1%} confidence)." ) else: explanation_parts.append( f"This prompt is classified as malicious " f"({probability:.1%} confidence)." 
) # Add evidence from metadata if metadata: evidence = [] # Intent evidence if 'intent' in metadata: intent_data = metadata['intent'] if intent_data['intent'] in [Intent.JAILBREAK, Intent.INJECTION]: evidence.append( f"Detected {intent_data['intent'].value} attempt") # Attack patterns if 'attack_patterns' in metadata: attack_data = metadata['attack_patterns'] if attack_data['has_attack_patterns']: attack_types = ', '.join(attack_data['attack_types']) evidence.append(f"Attack patterns: {attack_types}") # Keywords if 'keywords' in metadata and metadata['keywords']: keywords_str = ', '.join( f"'{kw}'" for kw in metadata['keywords'][:3]) evidence.append(f"Suspicious keywords: {keywords_str}") # Sentiment if 'sentiment' in metadata: sentiment_data = metadata['sentiment'] if sentiment_data['is_aggressive']: evidence.append("Aggressive tone detected") if evidence: explanation_parts.append( " Evidence: " + "; ".join(evidence) + ".") return " ".join(explanation_parts) else: # Benign explanation explanation = ( f"This prompt appears benign " f"({(1-probability):.1%} confidence). " f"No significant security concerns detected." ) # Add intent info if available if metadata and 'intent' in metadata: intent_data = metadata['intent'] explanation += f" Intent: {intent_data['description']}" return explanation
[docs] def classify(self, prompt: str, threshold: Optional[float] = None) -> bool: """ Simple binary classification. """ result = self.analyze(prompt) if threshold is not None: return result.probability >= threshold return result.is_malicious
[docs] def analyze_batch( self, prompts: List[str], batch_size: Optional[int] = None, show_progress: bool = True ) -> List[Optional[RiskScore]]: """ Analyze multiple prompts efficiently in batches, using cache when enabled. """ if not prompts: raise ValidationError("Prompts list cannot be empty") if not isinstance(prompts, list): raise ValidationError("Prompts must be a list of strings") batch_size = batch_size or self.config.batch_size final_results: List[Optional[RiskScore]] = [None] * len(prompts) # Track prompts that need inference uncached_prompts = [] uncached_indices = [] for i, prompt in enumerate(prompts): if not isinstance(prompt, str) or len(prompt.strip()) == 0: logger.warning("Skipping invalid prompt at index %i", i) continue # Check cache if self.use_cache and self.cache is not None: cached = self.cache.get(prompt) if cached is not None: final_results[i] = cached continue # Needs prediction uncached_prompts.append(prompt) uncached_indices.append(i) if not uncached_prompts: return final_results # Create progress iterator if show_progress: batches = tqdm( range(0, len(uncached_prompts), batch_size), desc="Analyzing prompts", unit="batch" ) else: batches = range(0, len(uncached_prompts), batch_size) for i in batches: batch_prompts = uncached_prompts[i:i + batch_size] batch_probs = self._predict_batch(batch_prompts) for j, prob in enumerate(batch_probs): prompt = batch_prompts[j] original_index = uncached_indices[i + j] # Use shared analysis method (NOW INCLUDES ALL FEATURES!) result = self._perform_analysis(prompt, prob, is_batch=True) # Store result final_results[original_index] = result # Cache it if self.use_cache and self.cache is not None: self.cache.set(prompt, result) return final_results
def _predict_batch(self, prompts: List[str]) -> List[float]: """ Get probabilities for a batch of prompts """ # Tokenize batch inputs = self.tokenizer( prompts, truncation=True, max_length=self.config.max_length, padding=True, return_tensors="pt" ) # Move to device inputs = {k: v.to(self.model_loader.device) for k, v in inputs.items()} # Predict with torch.no_grad(): outputs = self.model(**inputs) probabilities = torch.softmax(outputs.logits, dim=1) malicious_probs = probabilities[:, 1].cpu().numpy() return malicious_probs.tolist()
[docs] def classify_batch( self, prompts: List[str], threshold: Optional[float] = None, show_progress: bool = False ) -> List[Optional[bool]]: """ Simple binary classification for multiple prompts """ results = self.analyze_batch(prompts, show_progress=show_progress) if threshold is None: threshold = self.config.threshold return [ result.probability >= threshold if result is not None else None for result in results ]
@property def device(self) -> str: """Get the device being used for inference.""" return self.model_loader.device @property def threshold(self) -> float: """Get current classification threshold.""" return self.config.threshold @threshold.setter def threshold(self, value: float): """Set classification threshold.""" if not 0.0 <= value <= 1.0: raise ValueError(f"Threshold must be between 0 and 1, got {value}") self.config.threshold = value logger.info("Threshold updated to: %f", value)