diff --git a/cortex/config_manager.py b/cortex/config_manager.py old mode 100755 new mode 100644 index 9b6e22dd..d7f35fd0 --- a/cortex/config_manager.py +++ b/cortex/config_manager.py @@ -9,6 +9,7 @@ import os import re import subprocess +import threading from datetime import datetime from pathlib import Path from typing import Any, ClassVar @@ -43,17 +44,24 @@ class ConfigManager: def __init__(self, sandbox_executor=None): """ - Initialize ConfigManager. - - Args: - sandbox_executor: Optional SandboxExecutor instance for safe command execution - + Create a ConfigManager and prepare its runtime environment. + + Parameters: + sandbox_executor (optional): Executor used to run system commands inside a sandbox; if None, commands run directly. + + Notes: + - Initializes paths `~/.cortex` and `~/.cortex/preferences.yaml`. + - Creates the `~/.cortex` directory with mode 0700 if it does not exist. + - Enforces ownership and 0700 permissions on the directory. + - Creates an instance-level lock to serialize preferences file I/O. + Raises: - PermissionError: If directory ownership or permissions cannot be secured + PermissionError: If ownership or permissions for the cortex directory cannot be secured. """ self.sandbox_executor = sandbox_executor self.cortex_dir = Path.home() / ".cortex" self.preferences_file = self.cortex_dir / "preferences.yaml" + self._file_lock = threading.Lock() # Protect file I/O operations # Ensure .cortex directory exists with secure permissions self.cortex_dir.mkdir(mode=0o700, exist_ok=True) @@ -273,15 +281,18 @@ def _detect_os_version(self) -> str: def _load_preferences(self) -> dict[str, Any]: """ - Load user preferences from ~/.cortex/preferences.yaml. - + Load preferences from the user's preferences YAML file and return them as a dictionary. + + Reads the configured preferences file while holding the instance file lock. If the file does not exist, is empty, malformed, or cannot be read, this returns an empty dict. + Returns: - Dictionary of preferences + dict: Preferences mapping loaded from YAML, or an empty dict if no valid preferences could be loaded. """ if self.preferences_file.exists(): try: - with open(self.preferences_file) as f: - return yaml.safe_load(f) or {} + with self._file_lock: + with open(self.preferences_file) as f: + return yaml.safe_load(f) or {} except Exception: pass @@ -289,14 +300,20 @@ def _load_preferences(self) -> dict[str, Any]: def _save_preferences(self, preferences: dict[str, Any]) -> None: """ - Save user preferences to ~/.cortex/preferences.yaml. - - Args: - preferences: Dictionary of preferences to save + Save the provided preferences dictionary to the user's preferences file (~/.cortex/preferences.yaml). + + Writes the preferences as YAML while acquiring the instance file lock to serialize concurrent access. + + Parameters: + preferences (dict[str, Any]): Preferences to persist. + + Raises: + RuntimeError: If writing the preferences file fails. 
""" try: - with open(self.preferences_file, "w") as f: - yaml.safe_dump(preferences, f, default_flow_style=False) + with self._file_lock: + with open(self.preferences_file, "w") as f: + yaml.safe_dump(preferences, f, default_flow_style=False) except Exception as e: raise RuntimeError(f"Failed to save preferences: {e}") @@ -1064,4 +1081,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/cortex/context_memory.py b/cortex/context_memory.py index 55a13734..46b0e213 100644 --- a/cortex/context_memory.py +++ b/cortex/context_memory.py @@ -17,6 +17,8 @@ from pathlib import Path from typing import Any +from cortex.utils.db_pool import get_connection_pool, SQLiteConnectionPool + @dataclass class MemoryEntry: @@ -83,125 +85,131 @@ def __init__(self, db_path: str = "~/.cortex/context_memory.db"): """Initialize the context memory system""" self.db_path = Path(db_path).expanduser() self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._pool: SQLiteConnectionPool | None = None self._init_database() def _init_database(self): - """Initialize SQLite database schema""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Initialize the thread-safe SQLite connection pool and ensure the database schema exists. + + Sets up a pooled SQLite connection for the instance and creates required tables (memory_entries, patterns, suggestions, preferences) along with indexes used for query performance. + """ + # Initialize connection pool (thread-safe singleton) + self._pool = get_connection_pool(str(self.db_path), pool_size=5) + + with self._pool.get_connection() as conn: + cursor = conn.cursor() - # Memory entries table - cursor.execute( + # Memory entries table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS memory_entries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + category TEXT NOT NULL, + context TEXT, + action TEXT NOT NULL, + result TEXT, + success BOOLEAN DEFAULT 1, + confidence REAL DEFAULT 1.0, + frequency INTEGER DEFAULT 1, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) """ - CREATE TABLE IF NOT EXISTS memory_entries ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL, - category TEXT NOT NULL, - context TEXT, - action TEXT NOT NULL, - result TEXT, - success BOOLEAN DEFAULT 1, - confidence REAL DEFAULT 1.0, - frequency INTEGER DEFAULT 1, - metadata TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP ) - """ - ) - # Patterns table - cursor.execute( + # Patterns table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS patterns ( + pattern_id TEXT PRIMARY KEY, + pattern_type TEXT NOT NULL, + description TEXT, + frequency INTEGER DEFAULT 1, + last_seen TEXT, + confidence REAL DEFAULT 0.0, + actions TEXT, + context TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) """ - CREATE TABLE IF NOT EXISTS patterns ( - pattern_id TEXT PRIMARY KEY, - pattern_type TEXT NOT NULL, - description TEXT, - frequency INTEGER DEFAULT 1, - last_seen TEXT, - confidence REAL DEFAULT 0.0, - actions TEXT, - context TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP ) - """ - ) - # Suggestions table - cursor.execute( + # Suggestions table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS suggestions ( + suggestion_id TEXT PRIMARY KEY, + suggestion_type TEXT NOT NULL, + title TEXT NOT NULL, + description TEXT, + confidence REAL DEFAULT 0.0, + based_on TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + dismissed BOOLEAN DEFAULT 0 + ) """ - CREATE TABLE IF NOT EXISTS suggestions ( - 
suggestion_id TEXT PRIMARY KEY, - suggestion_type TEXT NOT NULL, - title TEXT NOT NULL, - description TEXT, - confidence REAL DEFAULT 0.0, - based_on TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP, - dismissed BOOLEAN DEFAULT 0 ) - """ - ) - # User preferences table - cursor.execute( + # User preferences table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS preferences ( + key TEXT PRIMARY KEY, + value TEXT, + category TEXT, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP + ) """ - CREATE TABLE IF NOT EXISTS preferences ( - key TEXT PRIMARY KEY, - value TEXT, - category TEXT, - updated_at TEXT DEFAULT CURRENT_TIMESTAMP ) - """ - ) - # Create indexes for performance - cursor.execute("CREATE INDEX IF NOT EXISTS idx_memory_category ON memory_entries(category)") - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_timestamp ON memory_entries(timestamp)" - ) - cursor.execute("CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(pattern_type)") - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_suggestions_type ON suggestions(suggestion_type)" - ) + # Create indexes for performance + cursor.execute("CREATE INDEX IF NOT EXISTS idx_memory_category ON memory_entries(category)") + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_memory_timestamp ON memory_entries(timestamp)" + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(pattern_type)") + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_suggestions_type ON suggestions(suggestion_type)" + ) - conn.commit() - conn.close() + conn.commit() def record_interaction(self, entry: MemoryEntry) -> int: """ - Record a user interaction in memory - - Args: - entry: MemoryEntry object containing interaction details - + Store a MemoryEntry in persistent storage and trigger pattern analysis. + + Parameters: + entry (MemoryEntry): Interaction record to persist. + Returns: - ID of the inserted memory entry + int: Row ID of the inserted memory entry. """ - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - INSERT INTO memory_entries - (timestamp, category, context, action, result, success, confidence, frequency, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - entry.timestamp, - entry.category, - entry.context, - entry.action, - entry.result, - entry.success, - entry.confidence, - entry.frequency, - json.dumps(entry.metadata), - ), - ) + cursor.execute( + """ + INSERT INTO memory_entries + (timestamp, category, context, action, result, success, confidence, frequency, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + entry.timestamp, + entry.category, + entry.context, + entry.action, + entry.result, + entry.success, + entry.confidence, + entry.frequency, + json.dumps(entry.metadata), + ), + ) - entry_id = cursor.lastrowid - conn.commit() - conn.close() + entry_id = cursor.lastrowid + conn.commit() # Trigger pattern analysis self._analyze_patterns(entry) @@ -210,43 +218,48 @@ def record_interaction(self, entry: MemoryEntry) -> int: def get_similar_interactions(self, context: str, limit: int = 10) -> list[MemoryEntry]: """ - Find similar past interactions based on context - - Args: - context: Context string to match against - limit: Maximum number of results - + Find past interactions whose context or action text matches keywords extracted from the given context, returning the most recent unique matches. 
+ + Parameters: + context (str): Text used to extract keywords for matching against stored interaction contexts and actions. + limit (int): Maximum number of returned entries. + Returns: - List of similar MemoryEntry objects + list[MemoryEntry]: Up to `limit` MemoryEntry objects that match the context or action (ordered by timestamp descending, duplicates removed). """ - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - # Simple keyword-based similarity for now - keywords = self._extract_keywords(context) - - results = [] - for keyword in keywords: - cursor.execute( - """ - SELECT * FROM memory_entries - WHERE context LIKE ? OR action LIKE ? - ORDER BY timestamp DESC - LIMIT ? - """, - (f"%{keyword}%", f"%{keyword}%", limit), - ) + with self._pool.get_connection() as conn: + cursor = conn.cursor() + + # Simple keyword-based similarity for now + keywords = self._extract_keywords(context) + + results = [] + for keyword in keywords: + cursor.execute( + """ + SELECT * FROM memory_entries + WHERE context LIKE ? OR action LIKE ? + ORDER BY timestamp DESC + LIMIT ? + """, + (f"%{keyword}%", f"%{keyword}%", limit), + ) - for row in cursor.fetchall(): - entry = self._row_to_memory_entry(row) - if entry not in results: - results.append(entry) + for row in cursor.fetchall(): + entry = self._row_to_memory_entry(row) + if entry not in results: + results.append(entry) - conn.close() return results[:limit] def _row_to_memory_entry(self, row: tuple) -> MemoryEntry: - """Convert database row to MemoryEntry object""" + """ + Convert a database row tuple into a MemoryEntry dataclass. + + Expected row layout: (id, timestamp, category, context, action, result, success, confidence, frequency, metadata_json). + + Parameters: + row (tuple): Database row with columns in the order listed above; `metadata_json` may be NULL or a JSON string. + + Returns: + MemoryEntry: A MemoryEntry populated from the row; `success` is converted to `bool` and `metadata` is parsed from JSON (empty dict if missing). + """ return MemoryEntry( id=row[0], timestamp=row[1], @@ -283,59 +296,66 @@ def _extract_keywords(self, text: str) -> list[str]: def _analyze_patterns(self, entry: MemoryEntry): """ - Analyze entry for patterns and update pattern database - - This runs after each new entry to detect recurring patterns + Detect recurring actions related to a memory entry and create or update corresponding pattern records. + + Scans recent memory entries in the same category (past 30 days) for actions that occur at least three times. For each recurring action found, inserts a new pattern or updates the existing pattern's frequency, last-seen timestamp, and confidence (confidence increases with observed frequency, capped at 1.0). + + Parameters: + entry (MemoryEntry): The newly recorded memory entry used as the basis for analyzing patterns. """ - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - # Look for similar actions in recent history - cursor.execute( - """ - SELECT action, COUNT(*) as count - FROM memory_entries - WHERE category = ?
- AND timestamp > datetime('now', '-30 days') - GROUP BY action - HAVING count >= 3 - """, - (entry.category,), - ) - - for row in cursor.fetchall(): - action, frequency = row - pattern_id = self._generate_pattern_id(entry.category, action) + with self._pool.get_connection() as conn: + cursor = conn.cursor() - # Update or create pattern + # Look for similar actions in recent history cursor.execute( """ - INSERT INTO patterns (pattern_id, pattern_type, description, frequency, last_seen, confidence, actions, context) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(pattern_id) DO UPDATE SET - frequency = ?, - last_seen = ?, - confidence = MIN(1.0, confidence + 0.1) + SELECT action, COUNT(*) as count + FROM memory_entries + WHERE category = ? + AND timestamp > datetime('now', '-30 days') + GROUP BY action + HAVING count >= 3 """, - ( - pattern_id, - entry.category, - f"Recurring pattern: {action}", - frequency, - entry.timestamp, - min(1.0, frequency / 10.0), # Confidence increases with frequency - json.dumps([action]), - json.dumps({"category": entry.category}), - frequency, - entry.timestamp, - ), + (entry.category,), ) - conn.commit() - conn.close() + for row in cursor.fetchall(): + action, frequency = row + pattern_id = self._generate_pattern_id(entry.category, action) + + # Update or create pattern + cursor.execute( + """ + INSERT INTO patterns (pattern_id, pattern_type, description, frequency, last_seen, confidence, actions, context) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(pattern_id) DO UPDATE SET + frequency = ?, + last_seen = ?, + confidence = MIN(1.0, confidence + 0.1) + """, + ( + pattern_id, + entry.category, + f"Recurring pattern: {action}", + frequency, + entry.timestamp, + min(1.0, frequency / 10.0), # Confidence increases with frequency + json.dumps([action]), + json.dumps({"category": entry.category}), + frequency, + entry.timestamp, + ), + ) + + conn.commit() def _generate_pattern_id(self, category: str, action: str) -> str: - """Generate unique pattern ID""" + """ + Generate a stable 16-character hexadecimal identifier for a pattern based on category and action. + + Returns: + 16-character hexadecimal string derived deterministically from the given category and action. + """ content = f"{category}:{action}".encode() return hashlib.sha256(content).hexdigest()[:16] @@ -343,58 +363,59 @@ def get_patterns( self, pattern_type: str | None = None, min_confidence: float = 0.5 ) -> list[Pattern]: """ - Retrieve learned patterns - - Args: - pattern_type: Filter by pattern type - min_confidence: Minimum confidence threshold - + Retrieve stored patterns that match an optional type and meet a minimum confidence threshold. + + Parameters: + pattern_type (str | None): If provided, only patterns with this type are returned. + min_confidence (float): Minimum confidence (0.0–1.0) required for returned patterns. + Returns: - List of Pattern objects + list[Pattern]: Patterns ordered by confidence descending then frequency descending. """ - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - query = """ - SELECT * FROM patterns - WHERE confidence >= ? - """ - params = [min_confidence] - - if pattern_type: - query += " AND pattern_type = ?" 
- params.append(pattern_type) - - query += " ORDER BY confidence DESC, frequency DESC" - - cursor.execute(query, params) - - patterns = [] - for row in cursor.fetchall(): - pattern = Pattern( - pattern_id=row[0], - pattern_type=row[1], - description=row[2], - frequency=row[3], - last_seen=row[4], - confidence=row[5], - actions=json.loads(row[6]), - context=json.loads(row[7]), - ) - patterns.append(pattern) + query = """ + SELECT * FROM patterns + WHERE confidence >= ? + """ + params = [min_confidence] + + if pattern_type: + query += " AND pattern_type = ?" + params.append(pattern_type) + + query += " ORDER BY confidence DESC, frequency DESC" + + cursor.execute(query, params) + + patterns = [] + for row in cursor.fetchall(): + pattern = Pattern( + pattern_id=row[0], + pattern_type=row[1], + description=row[2], + frequency=row[3], + last_seen=row[4], + confidence=row[5], + actions=json.loads(row[6]), + context=json.loads(row[7]), + ) + patterns.append(pattern) - conn.close() return patterns def generate_suggestions(self, context: str = None) -> list[Suggestion]: """ - Generate intelligent suggestions based on memory and patterns - - Args: - context: Optional context to focus suggestions - + Generate context-aware suggestions from recent memory and learned patterns. + + Builds optimization, alternative, and proactive suggestions using high-confidence patterns and memory entries from the last 7 days. If `context` is provided, suggestions are focused on that context. Each generated Suggestion is persisted to the suggestions table before being returned. + + Parameters: + context (str | None): Optional text used to focus or filter generated suggestions. + Returns: - List of Suggestion objects + list[Suggestion]: The list of generated Suggestion objects. """ suggestions = [] @@ -402,19 +423,19 @@ def generate_suggestions(self, context: str = None) -> list[Suggestion]: patterns = self.get_patterns(min_confidence=0.7) # Get recent memory entries - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( + cursor.execute( + """ + SELECT * FROM memory_entries + WHERE timestamp > datetime('now', '-7 days') + ORDER BY timestamp DESC + LIMIT 50 """ - SELECT * FROM memory_entries - WHERE timestamp > datetime('now', '-7 days') - ORDER BY timestamp DESC - LIMIT 50 - """ - ) + ) - recent_entries = [self._row_to_memory_entry(row) for row in cursor.fetchall()] + recent_entries = [self._row_to_memory_entry(row) for row in cursor.fetchall()] # Analyze for optimization opportunities suggestions.extend(self._suggest_optimizations(recent_entries, patterns)) @@ -425,8 +446,6 @@ def generate_suggestions(self, context: str = None) -> list[Suggestion]: # Suggest proactive actions based on patterns suggestions.extend(self._suggest_proactive_actions(patterns)) - conn.close() - # Store suggestions for suggestion in suggestions: self._store_suggestion(suggestion) @@ -507,233 +526,285 @@ def _generate_suggestion_id(self, suggestion_type: str, identifier: str) -> str: return hashlib.sha256(content).hexdigest()[:16] def _store_suggestion(self, suggestion: Suggestion): - """Store suggestion in database""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Insert a Suggestion into the suggestions table, ignoring duplicates. + + The suggestion's `based_on` list is JSON-encoded before storage. If a row with the same `suggestion_id` already exists, the insert is ignored. The change is committed to the database. 
+ + Parameters: + suggestion (Suggestion): Suggestion object to persist; its fields (suggestion_id, suggestion_type, title, description, confidence, based_on, created_at) are stored. + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - INSERT OR IGNORE INTO suggestions - (suggestion_id, suggestion_type, title, description, confidence, based_on, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - suggestion.suggestion_id, - suggestion.suggestion_type, - suggestion.title, - suggestion.description, - suggestion.confidence, - json.dumps(suggestion.based_on), - suggestion.created_at, - ), - ) + cursor.execute( + """ + INSERT OR IGNORE INTO suggestions + (suggestion_id, suggestion_type, title, description, confidence, based_on, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + suggestion.suggestion_id, + suggestion.suggestion_type, + suggestion.title, + suggestion.description, + suggestion.confidence, + json.dumps(suggestion.based_on), + suggestion.created_at, + ), + ) - conn.commit() - conn.close() + conn.commit() def get_active_suggestions(self, limit: int = 10) -> list[Suggestion]: - """Get active (non-dismissed) suggestions""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - cursor.execute( - """ - SELECT * FROM suggestions - WHERE dismissed = 0 - ORDER BY confidence DESC, created_at DESC - LIMIT ? - """, - (limit,), - ) + """ + Return a list of active (not dismissed) suggestions ordered by confidence and recency. + + Parameters: + limit (int): Maximum number of suggestions to return. + + Returns: + list[Suggestion]: Active suggestions ordered by descending confidence then descending creation time, limited to `limit`. + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - suggestions = [] - for row in cursor.fetchall(): - suggestion = Suggestion( - suggestion_id=row[0], - suggestion_type=row[1], - title=row[2], - description=row[3], - confidence=row[4], - based_on=json.loads(row[5]), - created_at=row[6], + cursor.execute( + """ + SELECT * FROM suggestions + WHERE dismissed = 0 + ORDER BY confidence DESC, created_at DESC + LIMIT ? + """, + (limit,), ) - suggestions.append(suggestion) - conn.close() + suggestions = [] + for row in cursor.fetchall(): + suggestion = Suggestion( + suggestion_id=row[0], + suggestion_type=row[1], + title=row[2], + description=row[3], + confidence=row[4], + based_on=json.loads(row[5]), + created_at=row[6], + ) + suggestions.append(suggestion) + return suggestions def dismiss_suggestion(self, suggestion_id: str): - """Mark a suggestion as dismissed""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Mark a suggestion as dismissed so it is excluded from active suggestions. + + Parameters: + suggestion_id (str): Unique identifier of the suggestion to dismiss. + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - UPDATE suggestions - SET dismissed = 1 - WHERE suggestion_id = ? - """, - (suggestion_id,), - ) + cursor.execute( + """ + UPDATE suggestions + SET dismissed = 1 + WHERE suggestion_id = ? + """, + (suggestion_id,), + ) - conn.commit() - conn.close() + conn.commit() def set_preference(self, key: str, value: Any, category: str = "general"): - """Store a user preference""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Store or update a user preference in the persistent preferences table. + + Parameters: + key (str): Preference key identifier. 
+ value (Any): Preference value; will be JSON-encoded before storage. + category (str): Preference category or namespace; defaults to "general". + + Detailed behavior: + - Inserts a new preference row or updates the existing row with the same key. + - Updates the `updated_at` timestamp to the current time in ISO 8601 format. + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - INSERT INTO preferences (key, value, category, updated_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(key) DO UPDATE SET - value = ?, - updated_at = ? - """, - ( - key, - json.dumps(value), - category, - datetime.now().isoformat(), - json.dumps(value), - datetime.now().isoformat(), - ), - ) + cursor.execute( + """ + INSERT INTO preferences (key, value, category, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(key) DO UPDATE SET + value = ?, + updated_at = ? + """, + ( + key, + json.dumps(value), + category, + datetime.now().isoformat(), + json.dumps(value), + datetime.now().isoformat(), + ), + ) - conn.commit() - conn.close() + conn.commit() def get_preference(self, key: str, default: Any = None) -> Any: - """Retrieve a user preference""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Get a stored preference value by key. + + Parameters: + key (str): Preference key to look up. + default: Value to return when the preference is not found. + + Returns: + The decoded preference value from storage if present, otherwise `default`. + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - SELECT value FROM preferences WHERE key = ? - """, - (key,), - ) + cursor.execute( + """ + SELECT value FROM preferences WHERE key = ? + """, + (key,), + ) - row = cursor.fetchone() - conn.close() + row = cursor.fetchone() if row: return json.loads(row[0]) return default def get_statistics(self) -> dict[str, Any]: - """Get memory system statistics""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + """ + Return a summary of usage and system statistics for the context memory. + + Returns: + stats (dict[str, Any]): A dictionary with the following keys: + - total_entries (int): Total number of recorded memory entries. + - by_category (dict[str, int]): Mapping of category names to their entry counts. + - success_rate (float): Percentage of entries marked successful, rounded to two decimals. + - total_patterns (int): Total number of learned patterns. + - active_suggestions (int): Count of suggestions that are not dismissed. + - recent_activity (int): Number of memory entries recorded in the last 7 days. 
+ """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() - stats = {} + stats = {} - # Total entries - cursor.execute("SELECT COUNT(*) FROM memory_entries") - stats["total_entries"] = cursor.fetchone()[0] + # Total entries + cursor.execute("SELECT COUNT(*) FROM memory_entries") + stats["total_entries"] = cursor.fetchone()[0] - # Entries by category - cursor.execute( + # Entries by category + cursor.execute( + """ + SELECT category, COUNT(*) + FROM memory_entries + GROUP BY category """ - SELECT category, COUNT(*) - FROM memory_entries - GROUP BY category - """ - ) - stats["by_category"] = dict(cursor.fetchall()) + ) + stats["by_category"] = dict(cursor.fetchall()) - # Success rate - cursor.execute( + # Success rate + cursor.execute( + """ + SELECT + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) * 100.0 / COUNT(*) as success_rate + FROM memory_entries """ - SELECT - SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) * 100.0 / COUNT(*) as success_rate - FROM memory_entries - """ - ) - stats["success_rate"] = round(cursor.fetchone()[0], 2) if stats["total_entries"] > 0 else 0 + ) + stats["success_rate"] = round(cursor.fetchone()[0], 2) if stats["total_entries"] > 0 else 0 - # Total patterns - cursor.execute("SELECT COUNT(*) FROM patterns") - stats["total_patterns"] = cursor.fetchone()[0] + # Total patterns + cursor.execute("SELECT COUNT(*) FROM patterns") + stats["total_patterns"] = cursor.fetchone()[0] - # Active suggestions - cursor.execute("SELECT COUNT(*) FROM suggestions WHERE dismissed = 0") - stats["active_suggestions"] = cursor.fetchone()[0] + # Active suggestions + cursor.execute("SELECT COUNT(*) FROM suggestions WHERE dismissed = 0") + stats["active_suggestions"] = cursor.fetchone()[0] - # Recent activity - cursor.execute( + # Recent activity + cursor.execute( + """ + SELECT COUNT(*) FROM memory_entries + WHERE timestamp > datetime('now', '-7 days') """ - SELECT COUNT(*) FROM memory_entries - WHERE timestamp > datetime('now', '-7 days') - """ - ) - stats["recent_activity"] = cursor.fetchone()[0] + ) + stats["recent_activity"] = cursor.fetchone()[0] - conn.close() return stats def export_memory(self, output_path: str, include_dismissed: bool = False): - """Export all memory data to JSON""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - data = { - "exported_at": datetime.now().isoformat(), - "entries": [], - "patterns": [], - "suggestions": [], - "preferences": [], - } - - # Export entries - cursor.execute("SELECT * FROM memory_entries") - for row in cursor.fetchall(): - entry = self._row_to_memory_entry(row) - data["entries"].append(asdict(entry)) - - # Export patterns - cursor.execute("SELECT * FROM patterns") - for row in cursor.fetchall(): - pattern = { - "pattern_id": row[0], - "pattern_type": row[1], - "description": row[2], - "frequency": row[3], - "last_seen": row[4], - "confidence": row[5], - "actions": json.loads(row[6]), - "context": json.loads(row[7]), - } - data["patterns"].append(pattern) - - # Export suggestions - query = "SELECT * FROM suggestions" - if not include_dismissed: - query += " WHERE dismissed = 0" - cursor.execute(query) - - for row in cursor.fetchall(): - suggestion = { - "suggestion_id": row[0], - "suggestion_type": row[1], - "title": row[2], - "description": row[3], - "confidence": row[4], - "based_on": json.loads(row[5]), - "created_at": row[6], + """ + Export the stored memory, patterns, suggestions, and preferences to a JSON file. 
+ + Parameters: + output_path (str): Filesystem path where the exported JSON will be written. + include_dismissed (bool): If True, include suggestions that have been dismissed; otherwise omit dismissed suggestions. + + Returns: + str: The path to the written JSON file (same as `output_path`). + """ + with self._pool.get_connection() as conn: + cursor = conn.cursor() + + data = { + "exported_at": datetime.now().isoformat(), + "entries": [], + "patterns": [], + "suggestions": [], + "preferences": [], } - data["suggestions"].append(suggestion) - # Export preferences - cursor.execute("SELECT key, value, category FROM preferences") - for row in cursor.fetchall(): - pref = {"key": row[0], "value": json.loads(row[1]), "category": row[2]} - data["preferences"].append(pref) + # Export entries + cursor.execute("SELECT * FROM memory_entries") + for row in cursor.fetchall(): + entry = self._row_to_memory_entry(row) + data["entries"].append(asdict(entry)) + + # Export patterns + cursor.execute("SELECT * FROM patterns") + for row in cursor.fetchall(): + pattern = { + "pattern_id": row[0], + "pattern_type": row[1], + "description": row[2], + "frequency": row[3], + "last_seen": row[4], + "confidence": row[5], + "actions": json.loads(row[6]), + "context": json.loads(row[7]), + } + data["patterns"].append(pattern) + + # Export suggestions + query = "SELECT * FROM suggestions" + if not include_dismissed: + query += " WHERE dismissed = 0" + cursor.execute(query) - conn.close() + for row in cursor.fetchall(): + suggestion = { + "suggestion_id": row[0], + "suggestion_type": row[1], + "title": row[2], + "description": row[3], + "confidence": row[4], + "based_on": json.loads(row[5]), + "created_at": row[6], + } + data["suggestions"].append(suggestion) + + # Export preferences + cursor.execute("SELECT key, value, category FROM preferences") + for row in cursor.fetchall(): + pref = {"key": row[0], "value": json.loads(row[1]), "category": row[2]} + data["preferences"].append(pref) with open(output_path, "w") as f: json.dump(data, f, indent=2) @@ -785,4 +856,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/cortex/dependency_resolver.py b/cortex/dependency_resolver.py index a7e72bb3..7e38e0a8 100644 --- a/cortex/dependency_resolver.py +++ b/cortex/dependency_resolver.py @@ -8,6 +8,7 @@ import logging import re import subprocess +import threading from dataclasses import asdict, dataclass logging.basicConfig(level=logging.INFO) @@ -64,6 +65,13 @@ class DependencyResolver: } def __init__(self): + """ + Initialize a DependencyResolver instance, setting up thread-safe caches and populating installed packages. + + Creates two locks for thread-safe access: one protecting the dependency cache and one protecting the installed packages set. Initializes an empty dependency_cache and installed_packages collection, then refreshes the installed_packages cache by calling _refresh_installed_packages to populate current system packages. 
+ """ + self._cache_lock = threading.Lock() # Protect dependency_cache + self._packages_lock = threading.Lock() # Protect installed_packages self.dependency_cache: dict[str, DependencyGraph] = {} self.installed_packages: set[str] = set() self._refresh_installed_packages() @@ -79,25 +87,45 @@ def _run_command(self, cmd: list[str]) -> tuple[bool, str, str]: return (False, "", str(e)) def _refresh_installed_packages(self) -> None: - """Refresh cache of installed packages""" + """ + Update the resolver's installed package cache from the system package database. + + Queries `dpkg -l` to collect installed package names and replaces the resolver's + installed_packages set with the discovered names while holding `_packages_lock`. + Logs the refresh start and the number of packages found. + """ logger.info("Refreshing installed packages cache...") success, stdout, _ = self._run_command(["dpkg", "-l"]) if success: + new_packages = set() for line in stdout.split("\n"): if line.startswith("ii"): parts = line.split() if len(parts) >= 2: - self.installed_packages.add(parts[1]) - - logger.info(f"Found {len(self.installed_packages)} installed packages") + new_packages.add(parts[1]) + + with self._packages_lock: + self.installed_packages = new_packages + logger.info(f"Found {len(self.installed_packages)} installed packages") def is_package_installed(self, package_name: str) -> bool: - """Check if package is installed""" - return package_name in self.installed_packages + """ + Check whether the given package is currently recorded as installed. + + Returns: + True if the package is in the installed package set, False otherwise. + """ + with self._packages_lock: + return package_name in self.installed_packages def get_installed_version(self, package_name: str) -> str | None: - """Get version of installed package""" + """ + Retrieve the installed package's version string. + + Returns: + str: The installed package version, or `None` if the package is not installed or the version cannot be determined. + """ if not self.is_package_installed(package_name): return None @@ -201,18 +229,24 @@ def get_predefined_dependencies(self, package_name: str) -> list[Dependency]: def resolve_dependencies(self, package_name: str, recursive: bool = True) -> DependencyGraph: """ - Resolve all dependencies for a package - - Args: - package_name: Package to resolve dependencies for - recursive: Whether to resolve transitive dependencies + Compute a DependencyGraph for the given package. + + Includes direct dependencies, optional transitive dependencies (when requested), detected conflicts, and a suggested installation order. The resolved graph is stored in the resolver's cache for subsequent calls. + + Parameters: + package_name (str): Package to analyze. + recursive (bool): If True, include transitive (second-level) dependencies in the graph. + + Returns: + DependencyGraph: Object containing `package_name`, `direct_dependencies`, `all_dependencies`, `conflicts`, and `installation_order`. 
""" logger.info(f"Resolving dependencies for {package_name}...") - # Check cache - if package_name in self.dependency_cache: - logger.info(f"Using cached dependencies for {package_name}") - return self.dependency_cache[package_name] + # Check cache (thread-safe) + with self._cache_lock: + if package_name in self.dependency_cache: + logger.info(f"Using cached dependencies for {package_name}") + return self.dependency_cache[package_name] # Get dependencies from multiple sources apt_deps = self.get_apt_dependencies(package_name) @@ -254,8 +288,9 @@ def resolve_dependencies(self, package_name: str, recursive: bool = True) -> Dep installation_order=installation_order, ) - # Cache result - self.dependency_cache[package_name] = graph + # Cache result (thread-safe) + with self._cache_lock: + self.dependency_cache[package_name] = graph return graph @@ -446,4 +481,4 @@ def export_graph_json(self, package_name: str, filepath: str) -> None: print(f"Total dependencies: {len(graph.all_dependencies)}") satisfied = sum(1 for d in graph.all_dependencies if d.is_satisfied) print(f"✅ Satisfied: {satisfied}") - print(f"❌ Missing: {len(graph.all_dependencies) - satisfied}") + print(f"❌ Missing: {len(graph.all_dependencies) - satisfied}") \ No newline at end of file diff --git a/cortex/graceful_degradation.py b/cortex/graceful_degradation.py index 30d82543..1c93d301 100644 --- a/cortex/graceful_degradation.py +++ b/cortex/graceful_degradation.py @@ -11,6 +11,7 @@ import logging import os import sqlite3 +import threading import time from collections.abc import Callable from dataclasses import dataclass, field @@ -19,6 +20,8 @@ from pathlib import Path from typing import Any +from cortex.utils.db_pool import get_connection_pool, SQLiteConnectionPool + logger = logging.getLogger(__name__) @@ -69,13 +72,26 @@ class ResponseCache: """SQLite-based cache for LLM responses.""" def __init__(self, db_path: Path | None = None): + """ + Initialize the response cache, ensuring the storage path exists and preparing the database. + + Parameters: + db_path (Path | None): Filesystem path to the SQLite database file. If omitted, defaults to + ~/.cortex/response_cache.db. The parent directory will be created if it does not exist. + """ self.db_path = db_path or Path.home() / ".cortex" / "response_cache.db" self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._pool: SQLiteConnectionPool | None = None self._init_db() def _init_db(self): - """Initialize the cache database.""" - with sqlite3.connect(self.db_path) as conn: + """ + Set up the response cache database and connection pool. + + Creates a connection pool for the cache file and ensures the required schema exists: the `response_cache` table (with columns `query_hash`, `query`, `response`, `created_at`, `hit_count`, `last_used`) and the `idx_last_used` index. + """ + self._pool = get_connection_pool(str(self.db_path), pool_size=5) + with self._pool.get_connection() as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS response_cache ( @@ -102,10 +118,19 @@ def _hash_query(self, query: str) -> str: return hashlib.sha256(normalized.encode()).hexdigest()[:16] def get(self, query: str) -> CachedResponse | None: - """Retrieve a cached response.""" + """ + Retrieve a cached response for the given query and update its usage metadata. + + If a cached entry exists for the normalized query hash, increments the entry's hit count + and updates its last-used timestamp before returning it. 
+ + Returns: + CachedResponse: The cached response with an incremented `hit_count` and updated `last_used`. + `None` if no cached entry exists for the query. + """ query_hash = self._hash_query(query) - with sqlite3.connect(self.db_path) as conn: + with self._pool.get_connection() as conn: conn.row_factory = sqlite3.Row cursor = conn.execute( "SELECT * FROM response_cache WHERE query_hash = ?", (query_hash,) @@ -136,10 +161,17 @@ def get(self, query: str) -> CachedResponse | None: return None def put(self, query: str, response: str) -> CachedResponse: - """Store a response in the cache.""" + """ + Store or replace a cached LLM response for the given query. + + The cache entry is created or replaced using a hash derived from the normalized query. The stored entry is initialized with a hit count of 0 and no last-used timestamp. + + Returns: + CachedResponse: The created cache record with `query_hash` derived from the normalized query and `created_at` set to the current time; `hit_count` will be 0 and `last_used` will be None. + """ query_hash = self._hash_query(query) - with sqlite3.connect(self.db_path) as conn: + with self._pool.get_connection() as conn: conn.execute( """ INSERT OR REPLACE INTO response_cache @@ -155,11 +187,20 @@ def put(self, query: str, response: str) -> CachedResponse: ) def get_similar(self, query: str, limit: int = 5) -> list[CachedResponse]: - """Get similar cached responses using simple keyword matching.""" + """ + Find cached responses whose queries share keywords with the given query, ranked by keyword overlap. + + Parameters: + query (str): The query text to match against cached entries. + limit (int): Maximum number of similar cached responses to return. + + Returns: + list[CachedResponse]: Up to `limit` cached responses ordered by descending keyword overlap (case-insensitive, whitespace token matching). + """ keywords = set(query.lower().split()) results = [] - with sqlite3.connect(self.db_path) as conn: + with self._pool.get_connection() as conn: conn.row_factory = sqlite3.Row cursor = conn.execute("SELECT * FROM response_cache ORDER BY hit_count DESC LIMIT 100") @@ -185,10 +226,18 @@ def get_similar(self, query: str, limit: int = 5) -> list[CachedResponse]: return [r[1] for r in results[:limit]] def clear_old_entries(self, days: int = 30) -> int: - """Remove entries older than specified days.""" + """ + Delete cache entries older than the specified number of days. + + Parameters: + days (int): Threshold age in days; entries with created_at earlier than now minus this many days will be removed. + + Returns: + deleted_count (int): Number of cache rows deleted. + """ cutoff = datetime.now() - timedelta(days=days) - with sqlite3.connect(self.db_path) as conn: + with self._pool.get_connection() as conn: cursor = conn.execute( "DELETE FROM response_cache WHERE created_at < ?", (cutoff.isoformat(),) ) @@ -196,8 +245,16 @@ def clear_old_entries(self, days: int = 30) -> int: return cursor.rowcount def get_stats(self) -> dict[str, Any]: - """Get cache statistics.""" - with sqlite3.connect(self.db_path) as conn: + """ + Return statistics about the response cache. + + Returns: + stats (dict): Dictionary with keys: + - total_entries (int): Number of rows in the cache. + - total_hits (int): Sum of `hit_count` across all cached entries. + - db_size_kb (float): Size of the cache database file in kilobytes (0 if the file does not exist). 
+ """ + with self._pool.get_connection() as conn: conn.row_factory = sqlite3.Row total = conn.execute("SELECT COUNT(*) as count FROM response_cache").fetchone()["count"] @@ -499,15 +556,48 @@ def reset(self): # CLI Integration +# Global instance for degradation manager (thread-safe) +_degradation_instance = None +_degradation_lock = threading.Lock() + + def get_degradation_manager() -> GracefulDegradation: - """Get or create the global degradation manager.""" - if not hasattr(get_degradation_manager, "_instance"): - get_degradation_manager._instance = GracefulDegradation() - return get_degradation_manager._instance + """ + Return the module-level GracefulDegradation singleton, initializing it once using a thread-safe double-checked locking pattern. + + Returns: + degradation_manager (GracefulDegradation): The shared GracefulDegradation instance used by the module. + """ + global _degradation_instance + # Fast path: avoid lock if already initialized + if _degradation_instance is None: + with _degradation_lock: + # Double-checked locking pattern + if _degradation_instance is None: + _degradation_instance = GracefulDegradation() + return _degradation_instance def process_with_fallback(query: str, llm_fn: Callable | None = None) -> dict[str, Any]: - """Convenience function for processing queries with fallback.""" + """ + Process a user query using the global GracefulDegradation manager and its configured fallback strategies. + + Parameters: + query (str): The user-provided query or command string to process. + llm_fn (Callable | None): Optional callable used to invoke an LLM with the query; if None the manager uses its configured LLM behaviour and fallbacks. + + Returns: + dict[str, Any]: Result dictionary containing at minimum: + - query (str): original query + - response (str): chosen textual response or suggestion + - source (str): origin of the response ("llm", "cache", "cache_similar", "pattern_matching", "manual_mode") + - confidence (float): confidence score for the response + - mode (FallbackMode): current operating mode + - cached (bool): whether the response was served from cache + Optional fields may include: + - command (str): suggested shell/apt command when applicable + - similar_query (str): a similar cached query that was used + """ manager = get_degradation_manager() return manager.process_query(query, llm_fn) @@ -539,4 +629,4 @@ def process_with_fallback(query: str, llm_fn: Callable | None = None) -> dict[st print(f" Confidence: {result['confidence']:.0%}") if result["command"]: print(f" Command: {result['command']}") - print() + print() \ No newline at end of file diff --git a/cortex/hardware_detection.py b/cortex/hardware_detection.py index a61eb0e4..c18b3ed1 100644 --- a/cortex/hardware_detection.py +++ b/cortex/hardware_detection.py @@ -16,6 +16,7 @@ import re import shutil import subprocess +import threading from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path @@ -190,11 +191,23 @@ class HardwareDetector: CACHE_MAX_AGE_SECONDS = 3600 # 1 hour def __init__(self, use_cache: bool = True): + """ + Initialize the HardwareDetector. + + Parameters: + use_cache (bool): If True, enable loading and saving hardware detection results to a local JSON cache. 
+ """ self.use_cache = use_cache self._info: SystemInfo | None = None + self._cache_lock = threading.RLock() # Reentrant lock for cache file access def _uname(self): - """Return uname-like info with nodename/release/machine attributes.""" + """ + Return system uname information, preferring os.uname() when available. + + Returns: + An object exposing uname-like attributes (`sysname`, `nodename`, `release`, `version`, `machine`), typically an `os.uname_result` or the value returned by `platform.uname()`. + """ uname_fn = getattr(os, "uname", None) if callable(uname_fn): return uname_fn() @@ -236,9 +249,14 @@ def detect(self, force_refresh: bool = False) -> SystemInfo: def detect_quick(self) -> dict[str, Any]: """ - Quick detection of essential hardware info. - - Returns minimal info for fast startup. + Return a minimal set of hardware metrics for fast startup. + + Returns: + dict[str, Any]: Dictionary with keys: + - "cpu_cores": number of CPU cores (int). + - "ram_gb": total RAM in gigabytes (float). + - "has_nvidia": `True` if an NVIDIA GPU is present, `False` otherwise (bool). + - "disk_free_gb": free disk space on root in gigabytes (float). """ return { "cpu_cores": self._get_cpu_cores(), @@ -248,64 +266,97 @@ def detect_quick(self) -> dict[str, Any]: } def _load_cache(self) -> SystemInfo | None: - """Load cached hardware info if valid.""" - try: - if not self.CACHE_FILE.exists(): - return None - - # Check age - import time - - if time.time() - self.CACHE_FILE.stat().st_mtime > self.CACHE_MAX_AGE_SECONDS: - return None - - with open(self.CACHE_FILE) as f: - data = json.load(f) - - # Reconstruct SystemInfo - info = SystemInfo() - info.hostname = data.get("hostname", "") - info.kernel_version = data.get("kernel_version", "") - info.distro = data.get("distro", "") - info.distro_version = data.get("distro_version", "") - - # CPU - cpu_data = data.get("cpu", {}) - info.cpu = CPUInfo( - vendor=CPUVendor(cpu_data.get("vendor", "unknown")), - model=cpu_data.get("model", "Unknown"), - cores=cpu_data.get("cores", 0), - threads=cpu_data.get("threads", 0), - ) - - # Memory - mem_data = data.get("memory", {}) - info.memory = MemoryInfo( - total_mb=mem_data.get("total_mb", 0), - available_mb=mem_data.get("available_mb", 0), - ) - - # Capabilities - info.has_nvidia_gpu = data.get("has_nvidia_gpu", False) - info.cuda_available = data.get("cuda_available", False) - - return info - - except Exception as e: - logger.debug(f"Cache load failed: {e}") + """ + Load a previously saved SystemInfo from the cache file when available and not expired. + + Attempts a thread-safe read of the cache file and reconstructs a partial SystemInfo containing system, CPU, memory, and capability fields. Returns None if caching is disabled, the cache file is absent, the cache is older than CACHE_MAX_AGE_SECONDS, or the cache cannot be parsed. + + Returns: + SystemInfo or `None`: SystemInfo reconstructed from cache; `None` if no valid cache is available. 
+ """ + if not self.use_cache: return None + + with self._cache_lock: + try: + if not self.CACHE_FILE.exists(): + return None + + # Check age + import time + + if time.time() - self.CACHE_FILE.stat().st_mtime > self.CACHE_MAX_AGE_SECONDS: + return None + + with open(self.CACHE_FILE) as f: + data = json.load(f) + + # Reconstruct SystemInfo + info = SystemInfo() + info.hostname = data.get("hostname", "") + info.kernel_version = data.get("kernel_version", "") + info.distro = data.get("distro", "") + info.distro_version = data.get("distro_version", "") + + # CPU + cpu_data = data.get("cpu", {}) + info.cpu = CPUInfo( + vendor=CPUVendor(cpu_data.get("vendor", "unknown")), + model=cpu_data.get("model", "Unknown"), + cores=cpu_data.get("cores", 0), + threads=cpu_data.get("threads", 0), + ) + + # Memory + mem_data = data.get("memory", {}) + info.memory = MemoryInfo( + total_mb=mem_data.get("total_mb", 0), + available_mb=mem_data.get("available_mb", 0), + ) + + # Capabilities + info.has_nvidia_gpu = data.get("has_nvidia_gpu", False) + info.cuda_available = data.get("cuda_available", False) + + return info + + except Exception as e: + logger.debug(f"Cache load failed: {e}") + return None - def _save_cache(self, info: SystemInfo): - """Save hardware info to cache.""" - try: - self.CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) - with open(self.CACHE_FILE, "w") as f: - json.dump(info.to_dict(), f, indent=2) - except Exception as e: - logger.debug(f"Cache save failed: {e}") + def _save_cache(self, info: SystemInfo) -> None: + """ + Persist the provided SystemInfo to the on-disk cache for later reuse. + + If caching is disabled, this is a no-op. The method acquires an internal lock to ensure thread-safe access, creates parent directories as needed, and writes the system information as JSON to the configured cache file. Failures are caught and logged without raising. + + Parameters: + info (SystemInfo): The system hardware information to serialize and store. + """ + if not self.use_cache: + return + + with self._cache_lock: + try: + self.CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(self.CACHE_FILE, "w") as f: + json.dump(info.to_dict(), f, indent=2) + except Exception as e: + logger.debug(f"Cache save failed: {e}") def _detect_system(self, info: SystemInfo): - """Detect basic system information.""" + """ + Populate the provided SystemInfo with basic system-level details: hostname, kernel version, distribution name and version, and system uptime. + + This fills these fields on the given `info` object when available: + - `hostname` (set to "unknown" if it cannot be determined) + - `kernel_version` + - `distro` + - `distro_version` + - `uptime_seconds` + + Detection uses system sources such as uname, /etc/os-release, and /proc/uptime. If individual values cannot be read, those fields are left unchanged (except `hostname`, which is set to "unknown" on failure). 
+ """ # Hostname try: info.hostname = self._uname().nodename @@ -741,4 +792,4 @@ def get_cpu_cores() -> int: if info.virtualization: print(f" Virtualization: {info.virtualization}") - print("\n✅ Detection complete!") + print("\n✅ Detection complete!") \ No newline at end of file diff --git a/cortex/installation_history.py b/cortex/installation_history.py index 1c3289a4..3cc9e6a2 100644 --- a/cortex/installation_history.py +++ b/cortex/installation_history.py @@ -17,6 +17,8 @@ from enum import Enum from pathlib import Path +from cortex.utils.db_pool import get_connection_pool, SQLiteConnectionPool + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -72,12 +74,25 @@ class InstallationHistory: """Manages installation history and rollback""" def __init__(self, db_path: str = "/var/lib/cortex/history.db"): + """ + Initialize the InstallationHistory manager using the given SQLite database path. + + Ensures the database directory exists (will fall back to the user's data directory on permission errors), initializes the connection pool, and creates the required database schema if missing. + + Parameters: + db_path (str): Filesystem path to the SQLite database file (default: "/var/lib/cortex/history.db"). + """ self.db_path = db_path self._ensure_db_directory() + self._pool: SQLiteConnectionPool | None = None self._init_database() def _ensure_db_directory(self): - """Ensure database directory exists""" + """ + Ensure the directory for the configured SQLite database exists. + + If creating the parent directory of self.db_path raises PermissionError, fall back to a per-user location (~/.cortex), create that directory, update self.db_path to point to ~/.cortex/history.db, and emit a warning. + """ db_dir = Path(self.db_path).parent try: db_dir.mkdir(parents=True, exist_ok=True) @@ -89,40 +104,45 @@ def _ensure_db_directory(self): logger.warning(f"Using user directory for database: {self.db_path}") def _init_database(self): - """Initialize SQLite database""" + """ + Initialize the installation history database and connection pool. + + Creates or opens the SQLite-backed database at the configured path, initializes a connection pool assigned to `self._pool`, ensures the installations table and a timestamp index exist, and commits the schema changes. Raises any exception encountered during initialization. 
+ """ try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + self._pool = get_connection_pool(self.db_path, pool_size=5) + + with self._pool.get_connection() as conn: + cursor = conn.cursor() - # Create installations table - cursor.execute( + # Create installations table + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS installations ( + id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + operation_type TEXT NOT NULL, + packages TEXT NOT NULL, + status TEXT NOT NULL, + before_snapshot TEXT, + after_snapshot TEXT, + commands_executed TEXT, + error_message TEXT, + rollback_available INTEGER, + duration_seconds REAL + ) """ - CREATE TABLE IF NOT EXISTS installations ( - id TEXT PRIMARY KEY, - timestamp TEXT NOT NULL, - operation_type TEXT NOT NULL, - packages TEXT NOT NULL, - status TEXT NOT NULL, - before_snapshot TEXT, - after_snapshot TEXT, - commands_executed TEXT, - error_message TEXT, - rollback_available INTEGER, - duration_seconds REAL ) - """ - ) - # Create index on timestamp - cursor.execute( + # Create index on timestamp + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_timestamp + ON installations(timestamp) """ - CREATE INDEX IF NOT EXISTS idx_timestamp - ON installations(timestamp) - """ - ) + ) - conn.commit() - conn.close() + conn.commit() logger.info(f"Database initialized at {self.db_path}") except Exception as e: @@ -255,10 +275,16 @@ def record_installation( start_time: datetime.datetime, ) -> str: """ - Record an installation operation - + Create and persist a new installation record and return its generated ID. + + Parameters: + operation_type (InstallationType): The type of installation operation (e.g., INSTALL, UPGRADE). + packages (list[str]): Package names involved; if empty, package names are inferred from `commands`. + commands (list[str]): The shell commands executed for the installation. + start_time (datetime.datetime): The timestamp to record as the installation start time. + Returns: - Installation ID + str: The generated installation ID. """ # If packages list is empty, try to extract from commands if not packages: @@ -277,12 +303,12 @@ def record_installation( timestamp = start_time.isoformat() try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute( - """ - INSERT INTO installations VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + cursor.execute( + """ + INSERT INTO installations VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( install_id, @@ -300,7 +326,6 @@ def record_installation( ) conn.commit() - conn.close() logger.info(f"Installation {install_id} recorded") return install_id @@ -311,23 +336,34 @@ def record_installation( def update_installation( self, install_id: str, status: InstallationStatus, error_message: str | None = None ): - """Update installation record after completion""" + """ + Update the stored installation record after an installation completes. + + Updates the installation identified by `install_id` with the provided `status`, captures an "after" package snapshot, calculates and stores the operation duration, and records an optional `error_message`. This persists changes to the history database. + + Parameters: + install_id (str): Unique identifier of the installation record to update. + status (InstallationStatus): Final status to store for the installation. + error_message (str | None): Optional error message to record if the installation failed. + + Raises: + Exception: If snapshot creation or database update fails. 
+ """ try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - # Get packages from record - cursor.execute( - "SELECT packages, timestamp FROM installations WHERE id = ?", (install_id,) - ) - result = cursor.fetchone() + # Get packages from record + cursor.execute( + "SELECT packages, timestamp FROM installations WHERE id = ?", (install_id,) + ) + result = cursor.fetchone() - if not result: - logger.error(f"Installation {install_id} not found") - conn.close() - return + if not result: + logger.error(f"Installation {install_id} not found") + return - packages = json.loads(result[0]) + packages = json.loads(result[0]) start_time = datetime.datetime.fromisoformat(result[1]) duration = (datetime.datetime.now() - start_time).total_seconds() @@ -354,7 +390,6 @@ def update_installation( ) conn.commit() - conn.close() logger.info(f"Installation {install_id} updated: {status.value}") except Exception as e: @@ -364,75 +399,92 @@ def update_installation( def get_history( self, limit: int = 50, status_filter: InstallationStatus | None = None ) -> list[InstallationRecord]: - """Get installation history""" + """ + Retrieve recent installation records from the history store. + + If `status_filter` is provided, only records with that status are returned. Results are ordered by timestamp from newest to oldest and limited to `limit`. Malformed database rows are skipped (a warning is logged); on unexpected failure an empty list is returned. + + Parameters: + limit (int): Maximum number of records to return. + status_filter (InstallationStatus | None): Optional status to filter records by. + + Returns: + list[InstallationRecord]: A list of InstallationRecord objects matching the query, newest first. + """ try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - if status_filter: - cursor.execute( - """ - SELECT * FROM installations - WHERE status = ? - ORDER BY timestamp DESC - LIMIT ? + if status_filter: + cursor.execute( + """ + SELECT * FROM installations + WHERE status = ? + ORDER BY timestamp DESC + LIMIT ? """, - (status_filter.value, limit), - ) - else: - cursor.execute( - """ - SELECT * FROM installations - ORDER BY timestamp DESC - LIMIT ? + (status_filter.value, limit), + ) + else: + cursor.execute( + """ + SELECT * FROM installations + ORDER BY timestamp DESC + LIMIT ? 
""", - (limit,), - ) - - records = [] - for row in cursor.fetchall(): - try: - record = InstallationRecord( - id=row[0], - timestamp=row[1], - operation_type=InstallationType(row[2]), - packages=json.loads(row[3]) if row[3] else [], - status=InstallationStatus(row[4]), - before_snapshot=[ - PackageSnapshot(**s) for s in (json.loads(row[5]) if row[5] else []) - ], - after_snapshot=[ - PackageSnapshot(**s) for s in (json.loads(row[6]) if row[6] else []) - ], - commands_executed=json.loads(row[7]) if row[7] else [], - error_message=row[8], - rollback_available=bool(row[9]) if row[9] is not None else True, - duration_seconds=row[10], + (limit,), ) - records.append(record) - except Exception as e: - logger.warning(f"Failed to parse record {row[0]}: {e}") - continue - conn.close() - return records + records = [] + for row in cursor.fetchall(): + try: + record = InstallationRecord( + id=row[0], + timestamp=row[1], + operation_type=InstallationType(row[2]), + packages=json.loads(row[3]) if row[3] else [], + status=InstallationStatus(row[4]), + before_snapshot=[ + PackageSnapshot(**s) for s in (json.loads(row[5]) if row[5] else []) + ], + after_snapshot=[ + PackageSnapshot(**s) for s in (json.loads(row[6]) if row[6] else []) + ], + commands_executed=json.loads(row[7]) if row[7] else [], + error_message=row[8], + rollback_available=bool(row[9]) if row[9] is not None else True, + duration_seconds=row[10], + ) + records.append(record) + except Exception as e: + logger.warning(f"Failed to parse record {row[0]}: {e}") + continue + + return records except Exception as e: logger.error(f"Failed to get history: {e}") return [] def get_installation(self, install_id: str) -> InstallationRecord | None: - """Get specific installation by ID""" + """ + Retrieve an installation record by its unique identifier. + + Parameters: + install_id (str): The unique ID of the installation record to retrieve. + + Returns: + InstallationRecord | None: The matching InstallationRecord, or `None` if no record exists or retrieval fails. + """ try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute("SELECT * FROM installations WHERE id = ?", (install_id,)) + cursor.execute("SELECT * FROM installations WHERE id = ?", (install_id,)) - row = cursor.fetchone() - conn.close() + row = cursor.fetchone() - if not row: - return None + if not row: + return None return InstallationRecord( id=row[0], @@ -457,14 +509,14 @@ def get_installation(self, install_id: str) -> InstallationRecord | None: def rollback(self, install_id: str, dry_run: bool = False) -> tuple[bool, str]: """ - Rollback an installation - - Args: - install_id: Installation to rollback - dry_run: If True, only show what would be done - + Perform a rollback for a recorded installation, executing package actions to restore the previous package state. + + Parameters: + install_id (str): ID of the installation record to rollback. + dry_run (bool): If True, do not execute commands and instead return the list of rollback actions. + Returns: - (success, message) + tuple[bool, str]: `True` and a human-readable success or actions string if rollback succeeded or dry-run; `False` and an error message if rollback could not be performed or failed. 
""" # Get installation record record = self.get_installation(install_id) @@ -546,14 +598,13 @@ def rollback(self, install_id: str, dry_run: bool = False) -> tuple[bool, str]: # Mark original as rolled back try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - cursor.execute( - "UPDATE installations SET status = ? WHERE id = ?", - (InstallationStatus.ROLLED_BACK.value, install_id), - ) - conn.commit() - conn.close() + with self._pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE installations SET status = ? WHERE id = ?", + (InstallationStatus.ROLLED_BACK.value, install_id), + ) + conn.commit() except Exception as e: logger.error(f"Failed to update rollback status: {e}") @@ -610,21 +661,28 @@ def export_history(self, filepath: str, format: str = "json"): logger.info(f"History exported to {filepath}") def cleanup_old_records(self, days: int = 90): - """Remove records older than specified days""" + """ + Delete installation records older than the specified number of days. + + Parameters: + days (int): Number of days; records with a timestamp earlier than now minus this value will be removed. + + Returns: + int: The number of records deleted. Returns 0 if the operation fails. + """ cutoff = datetime.datetime.now() - datetime.timedelta(days=days) cutoff_str = cutoff.isoformat() try: - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() + with self._pool.get_connection() as conn: + cursor = conn.cursor() - cursor.execute("DELETE FROM installations WHERE timestamp < ?", (cutoff_str,)) + cursor.execute("DELETE FROM installations WHERE timestamp < ?", (cutoff_str,)) - deleted = cursor.rowcount - conn.commit() - conn.close() + deleted = cursor.rowcount + conn.commit() - logger.info(f"Deleted {deleted} old records") + logger.info(f"Deleted {deleted} old records") return deleted except Exception as e: logger.error(f"Failed to cleanup records: {e}") @@ -756,4 +814,4 @@ def cleanup_old_records(self, days: int = 90): logger.exception("CLI error") sys.exit(1) - sys.exit(exit_code) + sys.exit(exit_code) \ No newline at end of file diff --git a/cortex/kernel_features/accelerator_limits.py b/cortex/kernel_features/accelerator_limits.py index 47a6f370..9bb62461 100644 --- a/cortex/kernel_features/accelerator_limits.py +++ b/cortex/kernel_features/accelerator_limits.py @@ -11,6 +11,8 @@ from enum import Enum from pathlib import Path +from cortex.utils.db_pool import get_connection_pool + CORTEX_DB = Path.home() / ".cortex/limits.db" CGROUP_ROOT = Path("/sys/fs/cgroup") @@ -52,24 +54,53 @@ def from_preset(cls, name: str, preset: str, gpus: int = 0): class LimitsDatabase: def __init__(self): + """ + Initialize the LimitsDatabase. + + Ensures the directory for the Cortex SQLite file exists, creates a connection pool for the database, and ensures the `profiles` table (columns: `name` primary key, `config`) is present. + """ CORTEX_DB.parent.mkdir(parents=True, exist_ok=True) - with sqlite3.connect(CORTEX_DB) as conn: + self._pool = get_connection_pool(str(CORTEX_DB), pool_size=5) + with self._pool.get_connection() as conn: conn.execute("CREATE TABLE IF NOT EXISTS profiles (name TEXT PRIMARY KEY, config TEXT)") def save(self, limits: ResourceLimits): - with sqlite3.connect(CORTEX_DB) as conn: + """ + Persist a ResourceLimits profile to the database, replacing any existing profile with the same name. + + Parameters: + limits (ResourceLimits): The resource limits profile to save; stored as JSON in the `profiles` table under its `name`. 
+ """ + with self._pool.get_connection() as conn: conn.execute( "INSERT OR REPLACE INTO profiles VALUES (?,?)", (limits.name, json.dumps(asdict(limits))), ) def get(self, name: str) -> ResourceLimits | None: - with sqlite3.connect(CORTEX_DB) as conn: + """ + Retrieve a saved resource limits profile by name. + + Parameters: + name (str): Profile name to look up in the database. + + Returns: + ResourceLimits | None: The deserialized ResourceLimits for the profile if found, `None` if no profile exists with that name. + """ + with self._pool.get_connection() as conn: row = conn.execute("SELECT config FROM profiles WHERE name=?", (name,)).fetchone() return ResourceLimits(**json.loads(row[0])) if row else None def list_all(self): - with sqlite3.connect(CORTEX_DB) as conn: + """ + Return all saved ResourceLimits profiles from the database. + + Retrieves each profile's JSON config from the profiles table and deserializes it into a ResourceLimits instance. + + Returns: + list[ResourceLimits]: List of ResourceLimits objects representing all stored profiles (empty list if none). + """ + with self._pool.get_connection() as conn: return [ ResourceLimits(**json.loads(r[0])) for r in conn.execute("SELECT config FROM profiles") @@ -130,4 +161,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/cortex/kernel_features/kv_cache_manager.py b/cortex/kernel_features/kv_cache_manager.py index 3d7f7610..6457ca76 100644 --- a/cortex/kernel_features/kv_cache_manager.py +++ b/cortex/kernel_features/kv_cache_manager.py @@ -9,6 +9,7 @@ import contextlib import json import sqlite3 +from cortex.utils.db_pool import get_connection_pool from dataclasses import asdict, dataclass from enum import Enum from multiprocessing import shared_memory @@ -45,8 +46,14 @@ class CacheEntry: class CacheDatabase: def __init__(self): + """ + Initialize the CacheDatabase by preparing the on-disk database and connection pool. + + Creates the database directory if missing, opens a pooled connection to the SQLite file, and ensures the schema for the following tables exists: `pools`, `entries`, and `stats`. + """ CORTEX_DB.parent.mkdir(parents=True, exist_ok=True) - with sqlite3.connect(CORTEX_DB) as conn: + self._pool = get_connection_pool(str(CORTEX_DB), pool_size=5) + with self._pool.get_connection() as conn: conn.executescript( """ CREATE TABLE IF NOT EXISTS pools (name TEXT PRIMARY KEY, config TEXT, shm_name TEXT); @@ -57,7 +64,17 @@ def __init__(self): ) def save_pool(self, cfg: CacheConfig, shm: str): - with sqlite3.connect(CORTEX_DB) as conn: + """ + Persist a cache pool's configuration and its shared-memory segment name in the database. + + Stores the given CacheConfig (as JSON) and the associated shared-memory name into the pools table, + replacing any existing entry for that pool, and ensures a corresponding row exists in the stats table. + + Parameters: + cfg (CacheConfig): Configuration for the cache pool to persist. + shm (str): Name of the shared-memory segment associated with the pool. + """ + with self._pool.get_connection() as conn: conn.execute( "INSERT OR REPLACE INTO pools VALUES (?,?,?)", (cfg.name, json.dumps(asdict(cfg)), shm), @@ -65,14 +82,31 @@ def save_pool(self, cfg: CacheConfig, shm: str): conn.execute("INSERT OR IGNORE INTO stats (pool) VALUES (?)", (cfg.name,)) def get_pool(self, name: str): - with sqlite3.connect(CORTEX_DB) as conn: + """ + Retrieve the stored cache pool configuration and its shared-memory segment name for a given pool. 
+ + Parameters: + name (str): The pool name to look up. + + Returns: + tuple: (CacheConfig, str) reconstructed CacheConfig and the shared-memory name, or None if the pool is not found. + """ + with self._pool.get_connection() as conn: row = conn.execute( "SELECT config, shm_name FROM pools WHERE name=?", (name,) ).fetchone() return (CacheConfig(**json.loads(row[0])), row[1]) if row else None def list_pools(self): - with sqlite3.connect(CORTEX_DB) as conn: + """ + Return all cache pool configurations persisted in the database. + + Queries the pools table and reconstructs each pool's configuration from the stored JSON. + + Returns: + list[CacheConfig]: A list of CacheConfig objects rebuilt from the stored JSON config for each pool. + """ + with self._pool.get_connection() as conn: return [ CacheConfig(**json.loads(r[0])) for r in conn.execute("SELECT config FROM pools").fetchall() @@ -116,10 +150,19 @@ def create_pool(self, cfg: CacheConfig) -> bool: return True def destroy_pool(self, name: str) -> bool: + """ + Destroy a named cache pool, removing its in-memory segment and deleting its metadata. + + Parameters: + name (str): The name of the pool to destroy. + + Returns: + bool: `True` if the pool destruction and metadata deletion were performed. + """ if name in self.pools: self.pools[name].destroy() del self.pools[name] - with sqlite3.connect(CORTEX_DB) as conn: + with self.db._pool.get_connection() as conn: conn.execute("DELETE FROM pools WHERE name=?", (name,)) print(f"✅ Destroyed pool '{name}'") return True @@ -164,4 +207,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/cortex/llm_router.py b/cortex/llm_router.py index d43f9eaa..8a5a4a26 100644 --- a/cortex/llm_router.py +++ b/cortex/llm_router.py @@ -15,6 +15,7 @@ import json import logging import os +import threading import time from dataclasses import dataclass from enum import Enum @@ -117,14 +118,14 @@ def __init__( track_costs: bool = True, ): """ - Initialize LLM Router. - - Args: - claude_api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env) - kimi_api_key: Moonshot API key (defaults to MOONSHOT_API_KEY env) - default_provider: Fallback provider if routing fails - enable_fallback: Try alternate LLM if primary fails - track_costs: Track token usage and costs + Create an LLMRouter and initialize provider clients, configuration, and usage tracking. + + Parameters: + claude_api_key (str | None): Anthropic (Claude) API key; if omitted, reads ANTHROPIC_API_KEY from the environment. + kimi_api_key (str | None): Moonshot (Kimi K2) API key; if omitted, reads MOONSHOT_API_KEY from the environment. + default_provider (LLMProvider): Provider to use when routing rules don't specify or preferred provider is unavailable. + enable_fallback (bool): If True, allows falling back to the alternate provider when the chosen provider is unavailable or fails. + track_costs (bool): If True, enable tracking of token usage and estimated USD costs (protected by an internal lock). 
""" self.claude_api_key = claude_api_key or os.getenv("ANTHROPIC_API_KEY") self.kimi_api_key = kimi_api_key or os.getenv("MOONSHOT_API_KEY") @@ -161,7 +162,8 @@ def __init__( # Rate limiting for parallel calls self._rate_limit_semaphore: asyncio.Semaphore | None = None - # Cost tracking + # Cost tracking (protected by lock for thread-safety) + self._stats_lock = threading.Lock() self.total_cost_usd = 0.0 self.request_count = 0 self.provider_stats = { @@ -382,42 +384,73 @@ def _complete_kimi( def _calculate_cost( self, provider: LLMProvider, input_tokens: int, output_tokens: int ) -> float: - """Calculate cost in USD for this request.""" + """ + Compute the USD cost for the given provider using the request's input and output token counts. + + The calculation uses the provider's per-million-token input and output rates from the router's COSTS mapping. + + Parameters: + provider (LLMProvider): Provider whose rates are applied. + input_tokens (int): Number of input tokens. + output_tokens (int): Number of output tokens. + + Returns: + float: Total cost in USD for the request. + """ costs = self.COSTS[provider] input_cost = (input_tokens / 1_000_000) * costs["input"] output_cost = (output_tokens / 1_000_000) * costs["output"] return input_cost + output_cost def _update_stats(self, response: LLMResponse): - """Update usage statistics.""" - self.total_cost_usd += response.cost_usd - self.request_count += 1 + """ + Update aggregate and per-provider usage statistics using data from `response` in a thread-safe manner. + + Parameters: + response (LLMResponse): Response whose `cost_usd`, `tokens_used`, and `provider` are applied to the router's aggregates. + + Detailed behavior: + - Increments `total_cost_usd` by `response.cost_usd`. + - Increments `request_count` by 1. + - For the provider `response.provider`, increments `requests`, adds `tokens_used` to `tokens`, and adds `cost_usd` to `cost`. + """ + with self._stats_lock: + self.total_cost_usd += response.cost_usd + self.request_count += 1 - stats = self.provider_stats[response.provider] - stats["requests"] += 1 - stats["tokens"] += response.tokens_used - stats["cost"] += response.cost_usd + stats = self.provider_stats[response.provider] + stats["requests"] += 1 + stats["tokens"] += response.tokens_used + stats["cost"] += response.cost_usd def get_stats(self) -> dict[str, Any]: """ - Get usage statistics. - + Return a snapshot of accumulated usage statistics. + + Provides a thread-safe snapshot of total request and cost aggregates and per-provider metrics. The returned dictionary contains: + - total_requests: total number of completed requests. + - total_cost_usd: total cost across providers in USD, rounded to 4 decimal places. + - providers: mapping with per-provider entries: + - claude: { requests, tokens, cost_usd } + - kimi_k2: { requests, tokens, cost_usd } + Returns: - Dictionary with request counts, tokens, costs per provider + dict: A dictionary with keys `total_requests`, `total_cost_usd`, and `providers` where each provider entry contains `requests`, `tokens`, and `cost_usd`. 
""" - return { - "total_requests": self.request_count, - "total_cost_usd": round(self.total_cost_usd, 4), - "providers": { - "claude": { - "requests": self.provider_stats[LLMProvider.CLAUDE]["requests"], - "tokens": self.provider_stats[LLMProvider.CLAUDE]["tokens"], - "cost_usd": round(self.provider_stats[LLMProvider.CLAUDE]["cost"], 4), - }, - "kimi_k2": { - "requests": self.provider_stats[LLMProvider.KIMI_K2]["requests"], - "tokens": self.provider_stats[LLMProvider.KIMI_K2]["tokens"], - "cost_usd": round(self.provider_stats[LLMProvider.KIMI_K2]["cost"], 4), + with self._stats_lock: + return { + "total_requests": self.request_count, + "total_cost_usd": round(self.total_cost_usd, 4), + "providers": { + "claude": { + "requests": self.provider_stats[LLMProvider.CLAUDE]["requests"], + "tokens": self.provider_stats[LLMProvider.CLAUDE]["tokens"], + "cost_usd": round(self.provider_stats[LLMProvider.CLAUDE]["cost"], 4), + }, + "kimi_k2": { + "requests": self.provider_stats[LLMProvider.KIMI_K2]["requests"], + "tokens": self.provider_stats[LLMProvider.KIMI_K2]["tokens"], + "cost_usd": round(self.provider_stats[LLMProvider.KIMI_K2]["cost"], 4), }, }, } @@ -902,4 +935,4 @@ async def check_hardware_configs_parallel( # Show stats print("=== Usage Statistics ===") stats = router.get_stats() - print(json.dumps(stats, indent=2)) + print(json.dumps(stats, indent=2)) \ No newline at end of file diff --git a/cortex/notification_manager.py b/cortex/notification_manager.py index c8648488..54d38122 100644 --- a/cortex/notification_manager.py +++ b/cortex/notification_manager.py @@ -2,6 +2,7 @@ import json import shutil import subprocess +import threading from pathlib import Path from rich.console import Console @@ -22,6 +23,11 @@ class NotificationManager: def __init__(self): # Set up configuration directory in user home + """ + Initialize the NotificationManager. + + Creates the configuration directory (~/.cortex) if missing, sets paths for the history and config JSON files, establishes default DND configuration (dnd_start="22:00", dnd_end="08:00", enabled=True), loads persisted configuration and notification history, and initializes a thread lock for protecting history list and file I/O. + """ self.config_dir = Path.home() / ".cortex" self.config_dir.mkdir(exist_ok=True) @@ -33,9 +39,14 @@ def __init__(self): self._load_config() self.history = self._load_history() + self._history_lock = threading.Lock() # Protect history list and file I/O def _load_config(self): - """Loads configuration from JSON. Creates default if missing.""" + """ + Load configuration from the config JSON file and merge it into the in-memory configuration. + + If the config file exists, parse it as JSON and update self.config with the parsed values. If the file contains invalid JSON, leave the current configuration unchanged and print a warning to the console. If the config file does not exist, create it by writing the current in-memory configuration to disk via _save_config(). + """ if self.config_file.exists(): try: with open(self.config_file) as f: @@ -46,12 +57,25 @@ def _load_config(self): self._save_config() def _save_config(self): - """Saves current configuration to JSON.""" + """ + Write the in-memory configuration to the configured JSON file, overwriting its contents. + + The file is written with an indentation of 4 spaces to produce human-readable JSON. 
+ """ with open(self.config_file, "w") as f: json.dump(self.config, f, indent=4) def _load_history(self) -> list[dict]: - """Loads notification history.""" + """ + Load the notification history from the configured history JSON file. + + If the history file exists and contains valid JSON, return the parsed list of history entry dicts. + If the file is missing or contains invalid JSON, return an empty list. + + Returns: + list[dict]: Parsed notification history entries, or an empty list if none are available. + """ + # Note: Called only during __init__, but protected for consistency if self.history_file.exists(): try: with open(self.history_file) as f: @@ -61,7 +85,12 @@ def _load_history(self) -> list[dict]: return [] def _save_history(self): - """Saves the last 100 notifications to history.""" + """ + Write the most recent 100 notification entries to the history JSON file. + + This method overwrites the history file with up to the last 100 entries from self.history, serialized as indented JSON. Caller must hold self._history_lock to ensure thread safety. + """ + # Caller must hold self._history_lock with open(self.history_file, "w") as f: json.dump(self.history[-100:], f, indent=4) @@ -93,9 +122,18 @@ def send( self, title: str, message: str, level: str = "normal", actions: list[str] | None = None ): """ - Sends a notification. - :param level: 'low', 'normal', 'critical'. Critical bypasses DND. - :param actions: List of button labels e.g. ["View Logs", "Retry"] + Send a desktop notification with optional action buttons, honoring Do Not Disturb (DND) rules. + + Parameters: + title (str): Notification title. + message (str): Notification body text. + level (str): Severity level; one of "low", "normal", or "critical". A "critical" notification bypasses DND. + actions (list[str] | None): Optional list of action button labels (e.g., ["View Logs", "Retry"]). When supported by the platform, these are delivered as notification actions/hints. + + Behavior: + - If DND is active and `level` is not "critical", the notification is suppressed. + - Attempts to send a native notification when available; otherwise logs a simulated notification to the console. + - Records every outcome to the notification history with a `status` of "suppressed", "sent", or "simulated". """ # 1. Check DND status if self.is_dnd_active() and level != "critical": @@ -136,7 +174,19 @@ def send( self._log_history(title, message, level, status="simulated", actions=actions) def _log_history(self, title, message, level, status, actions=None): - """Appends entry to history log.""" + """ + Append a notification event to the manager's history and persist it to disk in a thread-safe manner. + + Parameters: + title (str): Notification title. + message (str): Notification body text. + level (str): Notification severity (e.g., 'low', 'normal', 'critical'). + status (str): Outcome label for the entry (e.g., 'sent', 'suppressed', 'simulated'). + actions (list[str] | None): Optional list of action button labels; stored as an empty list if None. + + Notes: + This method acquires an internal lock to ensure atomic append and save of the history entry. The entry includes an ISO 8601 timestamp. 
+ """ entry = { "timestamp": datetime.datetime.now().isoformat(), "title": title, @@ -145,11 +195,12 @@ def _log_history(self, title, message, level, status, actions=None): "status": status, "actions": actions if actions else [], } - self.history.append(entry) - self._save_history() + with self._history_lock: + self.history.append(entry) + self._save_history() if __name__ == "__main__": mgr = NotificationManager() # Test with actions to verify the new feature - mgr.send("Action Test", "Testing buttons support", actions=["View Logs", "Retry"]) + mgr.send("Action Test", "Testing buttons support", actions=["View Logs", "Retry"]) \ No newline at end of file diff --git a/cortex/progress_indicators.py b/cortex/progress_indicators.py index a16ba1d4..09ffa852 100644 --- a/cortex/progress_indicators.py +++ b/cortex/progress_indicators.py @@ -115,46 +115,89 @@ class FallbackProgress: """Simple fallback progress indicator without Rich.""" def __init__(self): + """ + Initialize the fallback progress indicator's internal state used for a text-based spinner. + + Sets up the spinner character sequence and index, the current message, a running flag, the background thread reference, and a Lock to protect shared state during concurrent updates. + """ self._current_message = "" self._spinner_chars = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" self._spinner_idx = 0 self._running = False self._thread = None + self._lock = threading.Lock() # Protect shared state def start(self, message: str): - """Start showing progress.""" - self._current_message = message - self._running = True - self._thread = threading.Thread(target=self._animate, daemon=True) - self._thread.start() + """ + Begin displaying a progress spinner with the given message. + + Starts a background thread that updates the progress display until stopped; calling this sets the current message and marks the spinner as running. + + Parameters: + message (str): Text to show alongside the spinner. + """ + with self._lock: + self._current_message = message + self._running = True + self._thread = threading.Thread(target=self._animate, daemon=True) + self._thread.start() def _animate(self): - """Animate the spinner.""" - while self._running: - char = self._spinner_chars[self._spinner_idx % len(self._spinner_chars)] - sys.stdout.write(f"\r{char} {self._current_message}") + """ + Continuously renders a spinner and the current message to stdout until the spinner is stopped. + + This method loops, updating the displayed spinner frame and message at regular intervals (approximately every 0.1s), and exits when the internal running flag is cleared. It performs output using stdout and flushes after each update. + """ + while True: + with self._lock: + if not self._running: + break + char = self._spinner_chars[self._spinner_idx % len(self._spinner_chars)] + message = self._current_message + self._spinner_idx += 1 + + sys.stdout.write(f"\r{char} {message}") sys.stdout.flush() - self._spinner_idx += 1 time.sleep(0.1) def update(self, message: str): """Update the progress message.""" - self._current_message = message + with self._lock: + self._current_message = message def stop(self, final_message: str = ""): - """Stop the progress indicator.""" - self._running = False - if self._thread: - self._thread.join(timeout=0.5) - sys.stdout.write(f"\r✓ {final_message or self._current_message}\n") + """ + Stop the progress indicator and print a completion line. + + Parameters: + final_message (str): Optional message to display after stopping; if empty, the last shown spinner message is used. 
+ """ + with self._lock: + self._running = False + thread = self._thread + message = final_message or self._current_message + + if thread: + thread.join(timeout=0.5) + sys.stdout.write(f"\r✓ {message}\n") sys.stdout.flush() def fail(self, message: str = ""): - """Show failure.""" - self._running = False - if self._thread: - self._thread.join(timeout=0.5) - sys.stdout.write(f"\r✗ {message or self._current_message}\n") + """ + Stop the spinner and print a failure line for the current operation. + + If `message` is provided it will be used; otherwise the last displayed message is used. This also stops the background spinner thread. + Parameters: + message (str): Optional failure message to display; uses the current message when empty. + """ + with self._lock: + self._running = False + thread = self._thread + msg = message or self._current_message + + if thread: + thread.join(timeout=0.5) + sys.stdout.write(f"\r✗ {msg}\n") sys.stdout.flush() @@ -643,13 +686,21 @@ def finish(self): # Global instance for convenience _global_progress = None +_global_progress_lock = threading.Lock() def get_progress_indicator() -> ProgressIndicator: - """Get or create the global progress indicator.""" + """ + Get the module-level ProgressIndicator singleton, creating it on first access in a thread-safe manner. + + Returns: + ProgressIndicator: The global ProgressIndicator instance. + """ global _global_progress - if _global_progress is None: - _global_progress = ProgressIndicator() + if _global_progress is None: # Fast path + with _global_progress_lock: + if _global_progress is None: # Double-check + _global_progress = ProgressIndicator() return _global_progress @@ -722,4 +773,4 @@ def progress_bar(items: list[Any], description: str = "Processing"): tracker.finish() - print("\n✅ Demo complete!") + print("\n✅ Demo complete!") \ No newline at end of file diff --git a/cortex/semantic_cache.py b/cortex/semantic_cache.py index 67bef0dc..4f5df2d0 100644 --- a/cortex/semantic_cache.py +++ b/cortex/semantic_cache.py @@ -13,6 +13,8 @@ from datetime import datetime from pathlib import Path +from cortex.utils.db_pool import get_connection_pool, SQLiteConnectionPool + @dataclass(frozen=True) class CacheStats: @@ -52,12 +54,18 @@ def __init__( max_entries: int | None = None, similarity_threshold: float | None = None, ): - """Initialize semantic cache. - - Args: - db_path: Path to SQLite database file - max_entries: Maximum cache entries before LRU eviction (default: 500) - similarity_threshold: Cosine similarity threshold for matches (default: 0.86) + """ + Create a SemanticCache configured to persist LLM responses to a SQLite database. + + Ensures the database directory exists and initializes the SQLite connection pool and schema. + + Parameters: + db_path (str): Path to the SQLite database file. + max_entries (int | None): Maximum number of cache entries before LRU eviction. + If None, reads CORTEX_CACHE_MAX_ENTRIES from the environment or defaults to 500. + similarity_threshold (float | None): Minimum cosine similarity required to consider + a cached entry a semantic match. If None, reads CORTEX_CACHE_SIMILARITY_THRESHOLD + from the environment or defaults to 0.86. 
""" self.db_path = db_path self.max_entries = ( @@ -71,9 +79,15 @@ def __init__( else float(os.environ.get("CORTEX_CACHE_SIMILARITY_THRESHOLD", "0.86")) ) self._ensure_db_directory() + self._pool: SQLiteConnectionPool | None = None self._init_database() def _ensure_db_directory(self) -> None: + """ + Ensure the parent directory for the configured database path exists, and fall back to a user-local directory on permission errors. + + Attempts to create the parent directory for self.db_path (recursively). If directory creation raises PermissionError, creates ~/.cortex and updates self.db_path to use ~/ .cortex/cache.db. + """ db_dir = Path(self.db_path).parent try: db_dir.mkdir(parents=True, exist_ok=True) @@ -83,8 +97,15 @@ def _ensure_db_directory(self) -> None: self.db_path = str(user_dir / "cache.db") def _init_database(self) -> None: - conn = sqlite3.connect(self.db_path) - try: + # Initialize connection pool (thread-safe singleton) + """ + Initialize the persistent SQLite-backed cache schema and create a thread-safe connection pool. + + Creates or reuses a connection pool for the configured database path, ensures the cache schema exists (entries table with LRU index and unique constraint, and a single-row stats table), and initializes the stats row. + """ + self._pool = get_connection_pool(self.db_path, pool_size=5) + + with self._pool.get_connection() as conn: cur = conn.cursor() cur.execute( """ @@ -126,11 +147,15 @@ def _init_database(self) -> None: ) cur.execute("INSERT OR IGNORE INTO llm_cache_stats(id, hits, misses) VALUES (1, 0, 0)") conn.commit() - finally: - conn.close() @staticmethod def _utcnow_iso() -> str: + """ + Return the current UTC datetime in ISO 8601 format with seconds precision and a trailing "Z". + + Returns: + str: UTC datetime string formatted like "YYYY-MM-DDTHH:MM:SSZ". + """ return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" @staticmethod @@ -223,8 +248,7 @@ def get_commands( prompt_hash = self._hash_text(prompt) now = self._utcnow_iso() - conn = sqlite3.connect(self.db_path) - try: + with self._pool.get_connection() as conn: cur = conn.cursor() cur.execute( """ @@ -286,8 +310,6 @@ def get_commands( self._record_miss(conn) conn.commit() return None - finally: - conn.close() def put_commands( self, @@ -297,14 +319,18 @@ def put_commands( system_prompt: str, commands: list[str], ) -> None: - """Store commands in cache for future retrieval. - - Args: - prompt: User's natural language request - provider: LLM provider name - model: Model name - system_prompt: System prompt used for generation - commands: List of shell commands to cache + """ + Cache the list of commands generated for a specific prompt and system prompt. + + Parameters: + prompt (str): The user's natural language request. + provider (str): LLM provider identifier. + model (str): Model identifier. + system_prompt (str): System prompt used when generating the commands; its hash is used to scope the cache entry. + commands (list[str]): List of commands to store. + + Notes: + If an entry for (provider, model, system_prompt, prompt) already exists its `hit_count` is preserved; the entry's timestamps are set to the current time. After inserting the entry, the cache may evict old entries to respect the configured maximum size. 
""" system_hash = self._system_hash(system_prompt) prompt_hash = self._hash_text(prompt) @@ -312,8 +338,7 @@ def put_commands( vec = self._embed(prompt) embedding_blob = self._pack_embedding(vec) - conn = sqlite3.connect(self.db_path) - try: + with self._pool.get_connection() as conn: conn.execute( """ INSERT OR REPLACE INTO llm_cache_entries( @@ -342,10 +367,16 @@ def put_commands( ) self._evict_if_needed(conn) conn.commit() - finally: - conn.close() def _evict_if_needed(self, conn: sqlite3.Connection) -> None: + """ + Ensure the cache contains at most self.max_entries by removing the least-recently accessed rows. + + If the number of entries in llm_cache_entries exceeds self.max_entries, deletes the oldest rows ordered by last_accessed until the count equals self.max_entries. This operation modifies the provided SQLite connection's database. + + Parameters: + conn (sqlite3.Connection): An open SQLite connection used to execute the eviction statements. + """ cur = conn.cursor() cur.execute("SELECT COUNT(1) FROM llm_cache_entries") count = int(cur.fetchone()[0]) @@ -366,18 +397,16 @@ def _evict_if_needed(self, conn: sqlite3.Connection) -> None: ) def stats(self) -> CacheStats: - """Get current cache statistics. - + """ + Return current cache statistics. + Returns: - CacheStats object with hits, misses, and computed metrics + CacheStats: Hit and miss counts with derived metrics (total lookups and hit rate). """ - conn = sqlite3.connect(self.db_path) - try: + with self._pool.get_connection() as conn: cur = conn.cursor() cur.execute("SELECT hits, misses FROM llm_cache_stats WHERE id = 1") row = cur.fetchone() if row is None: return CacheStats(hits=0, misses=0) - return CacheStats(hits=int(row[0]), misses=int(row[1])) - finally: - conn.close() + return CacheStats(hits=int(row[0]), misses=int(row[1])) \ No newline at end of file diff --git a/cortex/stack_manager.py b/cortex/stack_manager.py index 952c83a0..fe183750 100644 --- a/cortex/stack_manager.py +++ b/cortex/stack_manager.py @@ -8,6 +8,7 @@ """ import json +import threading from pathlib import Path from typing import Any @@ -19,25 +20,53 @@ class StackManager: def __init__(self) -> None: # stacks.json is in the same directory as this file (cortex/) + """ + Initialize a StackManager by locating the stacks.json file and preparing the in-memory cache and its lock. + + Sets the path to the module-local stacks.json, initializes the cached stacks storage to None, and creates a threading lock to protect access to the cache. + """ self.stacks_file = Path(__file__).parent / "stacks.json" self._stacks = None + self._stacks_lock = threading.Lock() # Protect _stacks cache def load_stacks(self) -> dict[str, Any]: - """Load stacks from JSON file""" + """ + Load and cache stacks configuration from the module's stacks.json file in a thread-safe manner. + + Loads and parses the JSON file at self.stacks_file and caches the resulting dictionary on the instance. Subsequent calls return the cached value. The loading path is synchronized to be safe for concurrent callers. + + Returns: + dict[str, Any]: Parsed stacks configuration (typically contains a "stacks" key with the list of stacks). + + Raises: + FileNotFoundError: If the stacks file does not exist at self.stacks_file. + ValueError: If the stacks file contains invalid JSON. 
+ """ + # Fast path: check without lock if self._stacks is not None: return self._stacks - try: - with open(self.stacks_file) as f: - self._stacks = json.load(f) - return self._stacks - except FileNotFoundError as e: - raise FileNotFoundError(f"Stacks config not found at {self.stacks_file}") from e - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in {self.stacks_file}") from e + # Slow path: acquire lock and recheck + with self._stacks_lock: + if self._stacks is not None: + return self._stacks + + try: + with open(self.stacks_file) as f: + self._stacks = json.load(f) + return self._stacks + except FileNotFoundError as e: + raise FileNotFoundError(f"Stacks config not found at {self.stacks_file}") from e + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {self.stacks_file}") from e def list_stacks(self) -> list[dict[str, Any]]: - """Get all available stacks""" + """ + Return the list of available stack definitions. + + Returns: + list[dict[str, Any]]: A list of stack dictionaries from the loaded configuration; empty list if no stacks are defined. + """ stacks = self.load_stacks() return stacks.get("stacks", []) @@ -99,4 +128,4 @@ def describe_stack(self, stack_id: str) -> str: hardware = stack.get("hardware", "any") output += f"Hardware: {hardware}\n" - return output + return output \ No newline at end of file diff --git a/cortex/transaction_history.py b/cortex/transaction_history.py index 790ac6e2..c1331819 100644 --- a/cortex/transaction_history.py +++ b/cortex/transaction_history.py @@ -22,6 +22,9 @@ logger = logging.getLogger(__name__) +import threading # For thread-safe singleton pattern + + class TransactionType(Enum): """Types of package transactions.""" @@ -643,7 +646,15 @@ def undo( return result def undo_last(self, dry_run: bool = False) -> dict[str, Any]: - """Undo the most recent successful transaction.""" + """ + Trigger an undo operation for the most recent completed transaction. + + Parameters: + dry_run (bool): If True, do not execute rollback commands; return the actions that would be taken. + + Returns: + result (dict[str, Any]): Operation result. `success` is True for a successful (or simulated) undo; if False, an `error` key explains the failure. + """ recent = self.history.get_recent(limit=1, status_filter=TransactionStatus.COMPLETED) if not recent: @@ -652,24 +663,44 @@ def undo_last(self, dry_run: bool = False) -> dict[str, Any]: return self.undo(recent[0].id, dry_run=dry_run) -# CLI-friendly functions +# Global instances for easy access (thread-safe singletons) _history_instance = None +_history_lock = threading.Lock() _undo_manager_instance = None +_undo_manager_lock = threading.Lock() -def get_history() -> TransactionHistory: - """Get the global transaction history instance.""" +def get_history() -> "TransactionHistory": + """ + Return the module-wide TransactionHistory singleton, creating it if necessary in a thread-safe manner. + + Returns: + TransactionHistory: The global TransactionHistory singleton instance. 
+ """ global _history_instance + # Fast path: avoid lock if already initialized if _history_instance is None: - _history_instance = TransactionHistory() + with _history_lock: + # Double-checked locking pattern + if _history_instance is None: + _history_instance = TransactionHistory() return _history_instance -def get_undo_manager() -> UndoManager: - """Get the global undo manager instance.""" +def get_undo_manager() -> "UndoManager": + """ + Retrieve the module-wide singleton UndoManager, creating it lazily in a thread-safe manner. + + Returns: + undo_manager (UndoManager): The shared UndoManager instance used for rollback operations. + """ global _undo_manager_instance + # Fast path: avoid lock if already initialized if _undo_manager_instance is None: - _undo_manager_instance = UndoManager(get_history()) + with _undo_manager_lock: + # Double-checked locking pattern + if _undo_manager_instance is None: + _undo_manager_instance = UndoManager(get_history()) return _undo_manager_instance @@ -738,4 +769,4 @@ def undo_last(dry_run: bool = False) -> dict[str, Any]: print(f" Total transactions: {stats['total_transactions']}") print(f" By type: {stats['by_type']}") - print("\n✅ Demo complete!") + print("\n✅ Demo complete!") \ No newline at end of file diff --git a/cortex/utils/db_pool.py b/cortex/utils/db_pool.py new file mode 100644 index 00000000..2b03a614 --- /dev/null +++ b/cortex/utils/db_pool.py @@ -0,0 +1,231 @@ +""" +Thread-safe SQLite connection pooling for Cortex Linux. + +Provides connection pooling to prevent database lock contention +and enable safe concurrent access in Python 3.14 free-threading mode. + +Author: Cortex Linux Team +License: Apache 2.0 +""" + +import queue +import sqlite3 +import threading +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + + +class SQLiteConnectionPool: + """ + Thread-safe SQLite connection pool. + + SQLite has limited concurrency support: + - Multiple readers are OK with WAL mode + - Single writer at a time (database-level locking) + - SQLITE_BUSY errors occur under high write contention + + This pool manages connections and handles concurrent access gracefully. + + Usage: + pool = SQLiteConnectionPool("/path/to/db.sqlite", pool_size=5) + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT ...") + """ + + def __init__( + self, + db_path: str | Path, + pool_size: int = 5, + timeout: float = 5.0, + check_same_thread: bool = False, + ): + """ + Initialize the connection pool and pre-create the configured SQLite connections. + + Parameters: + db_path: File system path to the SQLite database. + pool_size: Maximum number of connections to maintain in the pool. + timeout: Default timeout (seconds) used when acquiring a connection. + check_same_thread: If False, allows connections to be used across threads; set True to enforce SQLite's same-thread restriction. + """ + self.db_path = str(db_path) + self.pool_size = pool_size + self.timeout = timeout + self.check_same_thread = check_same_thread + + # Connection pool (thread-safe queue) + self._pool: queue.Queue[sqlite3.Connection] = queue.Queue(maxsize=pool_size) + self._pool_lock = threading.Lock() + + # Initialize connections + for _ in range(pool_size): + conn = self._create_connection() + self._pool.put(conn) + + def _create_connection(self) -> sqlite3.Connection: + """ + Create a new SQLite connection configured for pooled concurrent access. 
+ + The connection is tuned for concurrency and performance using these PRAGMA settings: + journal_mode=WAL, synchronous=NORMAL, cache_size=-64000 (64MB), temp_store=MEMORY, and foreign_keys=ON. + + Returns: + A configured sqlite3.Connection connected to the pool's database path. + """ + conn = sqlite3.connect( + self.db_path, + timeout=self.timeout, + check_same_thread=self.check_same_thread, + ) + + # Enable WAL mode for better concurrency + # WAL allows multiple readers + single writer simultaneously + conn.execute("PRAGMA journal_mode=WAL") + + # NORMAL synchronous mode (faster, still safe with WAL) + conn.execute("PRAGMA synchronous=NORMAL") + + # Larger cache for better performance + conn.execute("PRAGMA cache_size=-64000") # 64MB cache + + # Store temp tables in memory + conn.execute("PRAGMA temp_store=MEMORY") + + # Enable foreign keys (if needed) + conn.execute("PRAGMA foreign_keys=ON") + + return conn + + @contextmanager + def get_connection(self) -> Iterator[sqlite3.Connection]: + """ + Acquire a connection from the pool and return it to the pool when the context exits. + + Used as a context manager; yields a `sqlite3.Connection` that callers can use for database operations. The connection is returned to the pool after the context block completes, even if an exception is raised. + + Returns: + sqlite3.Connection: A connection from the pool. + + Raises: + TimeoutError: If a connection cannot be acquired within the pool's configured timeout. + """ + try: + conn = self._pool.get(timeout=self.timeout) + except queue.Empty: + raise TimeoutError( + f"Could not acquire database connection within {self.timeout}s. " + f"Pool size: {self.pool_size}. Consider increasing pool size or timeout." + ) + + try: + yield conn + finally: + # Always return connection to pool + try: + self._pool.put(conn, block=False) + except queue.Full: + # Should never happen, but log if it does + import logging + logging.error(f"Connection pool overflow for {self.db_path}") + + def close_all(self): + """ + Close all connections currently stored in the pool in a thread-safe manner. + + Returns: + closed_count (int): Number of connections that were closed. + """ + with self._pool_lock: + closed_count = 0 + while not self._pool.empty(): + try: + conn = self._pool.get_nowait() + conn.close() + closed_count += 1 + except queue.Empty: + break + return closed_count + + def __enter__(self): + """ + Enter the runtime context and provide the pool instance. + + Returns: + SQLiteConnectionPool: The same pool instance to be used as the context manager target. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Close all connections when exiting context. + + For pools managed as global singletons via get_connection_pool(), + avoid closing connections here to prevent affecting other users + of the same shared pool. + """ + # If this pool is a global singleton, do not close it on context exit. + # This ensures that using a globally shared pool in a `with` block + # does not disrupt other parts of the application. + if self not in _pools.values(): + self.close_all() + return False + + +# Global connection pools (one per database path) +# Thread-safe lazy initialization +_pools: dict[str, SQLiteConnectionPool] = {} +_pools_lock = threading.Lock() + + +def get_connection_pool( + db_path: str | Path, + pool_size: int = 5, + timeout: float = 5.0, +) -> SQLiteConnectionPool: + """ + Retrieve or create a shared SQLiteConnectionPool for the given database path. 
+ + If a pool already exists for the path, that pool is returned; otherwise a new pool is created, registered, and returned. + + Parameters: + db_path (str | Path): Filesystem path to the SQLite database. + pool_size (int): Maximum number of connections the pool will hold. + timeout (float): Maximum seconds to wait when acquiring a connection from the pool. + + Returns: + SQLiteConnectionPool: The connection pool associated with the given database path. + """ + db_path = str(db_path) + + # Fast path: check without lock + if db_path in _pools: + return _pools[db_path] + + # Slow path: acquire lock and double-check + with _pools_lock: + if db_path not in _pools: + _pools[db_path] = SQLiteConnectionPool( + db_path, + pool_size=pool_size, + timeout=timeout, + ) + return _pools[db_path] + + +def close_all_pools(): + """ + Close and remove all global SQLiteConnectionPool instances. + + Closes every connection in the global pool registry and clears the registry. + + Returns: + int: Total number of connections closed. + """ + with _pools_lock: + total_closed = 0 + for pool in _pools.values(): + total_closed += pool.close_all() + _pools.clear() + return total_closed \ No newline at end of file diff --git a/tests/test_thread_safety.py b/tests/test_thread_safety.py new file mode 100644 index 00000000..c88fb2f4 --- /dev/null +++ b/tests/test_thread_safety.py @@ -0,0 +1,409 @@ +""" +Thread-safety tests for Python 3.14 free-threading compatibility. + +Run with: + python3.14 -m pytest tests/test_thread_safety.py -v # With GIL + PYTHON_GIL=0 python3.14t -m pytest tests/test_thread_safety.py -v # Without GIL + +Author: Cortex Linux Team +License: Apache 2.0 +""" + +import concurrent.futures +import os +import random +import tempfile +import time + +import pytest + + +def test_singleton_thread_safety_transaction_history(): + """Test that transaction history singleton is thread-safe.""" + from cortex.transaction_history import get_history + + results = [] + + def get_instance(): + """ + Obtain the transaction history singleton and record its identity. + + Appends the singleton object's id() to the enclosing `results` list for later uniqueness checks. + """ + history = get_history() + results.append(id(history)) + + # Hammer singleton initialization from 100 threads + with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor: + futures = [executor.submit(get_instance) for _ in range(1000)] + concurrent.futures.wait(futures) + + # All threads should get the SAME instance + unique_instances = len(set(results)) + assert unique_instances == 1, f"Multiple singleton instances created! Found {unique_instances} different instances" + + +def test_singleton_thread_safety_hardware_detection(): + """Test that hardware detector singleton is thread-safe.""" + from cortex.hardware_detection import get_detector + + results = [] + + def get_instance(): + """ + Record the identity of the hardware detector by appending its object id to the surrounding `results` list. + + This helper obtains the global detector and appends `id(detector)` to the `results` list captured from the enclosing scope. 
+ """ + detector = get_detector() + results.append(id(detector)) + + # 50 threads trying to get detector simultaneously + with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: + futures = [executor.submit(get_instance) for _ in range(500)] + concurrent.futures.wait(futures) + + # All threads should get the SAME instance + unique_instances = len(set(results)) + assert unique_instances == 1, f"Multiple detector instances created! Found {unique_instances} different instances" + + +def test_singleton_thread_safety_degradation_manager(): + """Test that degradation manager singleton is thread-safe.""" + from cortex.graceful_degradation import get_degradation_manager + + results = [] + + def get_instance(): + """ + Record the identity of the global degradation manager instance. + + Appends the object's id returned by get_degradation_manager() to the surrounding `results` list so concurrent callers can verify singleton identity. + """ + manager = get_degradation_manager() + results.append(id(manager)) + + # 50 threads trying to get manager simultaneously + with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: + futures = [executor.submit(get_instance) for _ in range(500)] + concurrent.futures.wait(futures) + + # All threads should get the SAME instance + unique_instances = len(set(results)) + assert unique_instances == 1, f"Multiple manager instances created! Found {unique_instances} different instances" + + +def test_connection_pool_concurrent_reads(): + """ + Verify that the SQLite connection pool returns consistent results under concurrent read load. + + Sets up a temporary SQLite database with 100 rows, creates a connection pool (pool_size=5), and launches 20 threads that each perform 50 SELECT COUNT(*) reads. Asserts every read returns 100 and ensures the pool and temporary database file are cleaned up. + """ + from cortex.utils.db_pool import get_connection_pool + + # Create temporary database + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + + try: + # Initialize database with test data + pool = get_connection_pool(db_path, pool_size=5) + with pool.get_connection() as conn: + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)") + for i in range(100): + conn.execute("INSERT INTO test (value) VALUES (?)", (f"value_{i}",)) + conn.commit() + + # Test concurrent reads + def read_data(thread_id: int): + """ + Read the row count from the `test` table 50 times using a connection from the pool. + + Parameters: + thread_id (int): Caller thread identifier used to correlate results with the caller. + + Returns: + list[int]: A list of 50 integers, each the result of `SELECT COUNT(*) FROM test` observed on each read. 
+ """ + results = [] + for _ in range(50): + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM test") + count = cursor.fetchone()[0] + results.append(count) + return results + + # 20 threads reading simultaneously + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(read_data, i) for i in range(20)] + all_results = [f.result() for f in futures] + + # All reads should return 100 + for results in all_results: + assert all(count == 100 for count in results), "Inconsistent read results" + + finally: + # Cleanup + pool.close_all() + os.unlink(db_path) + + +def test_connection_pool_concurrent_writes(): + """Test SQLite connection pool under concurrent write load.""" + from cortex.utils.db_pool import get_connection_pool + + # Create temporary database + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + + try: + # Initialize database + pool = get_connection_pool(db_path, pool_size=5) + with pool.get_connection() as conn: + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, thread_id INTEGER, value TEXT)") + conn.commit() + + errors = [] + + def write_data(thread_id: int): + """ + Insert 20 rows into the test table using the shared connection pool, recording any errors. + + Parameters: + thread_id (int): Identifier used to set the `thread_id` column and to distinguish inserted values. + + Details: + For i from 0 to 19, obtains a connection from the surrounding `pool`, inserts a row with `thread_id` + and `value` = "thread_{thread_id}_value_{i}", and commits each insert. On exception, appends a tuple + (thread_id, error_message) to the surrounding `errors` list. + """ + try: + for i in range(20): + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO test (thread_id, value) VALUES (?, ?)", + (thread_id, f"thread_{thread_id}_value_{i}") + ) + conn.commit() + except Exception as e: + errors.append((thread_id, str(e))) + + # 10 threads writing simultaneously + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(write_data, i) for i in range(10)] + concurrent.futures.wait(futures) + + # Should handle concurrency gracefully (no crashes) + if errors: + pytest.fail(f"Concurrent write errors: {errors}") + + # Verify all writes succeeded + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM test") + count = cursor.fetchone()[0] + assert count == 200, f"Expected 200 rows, got {count}" + + finally: + # Cleanup + pool.close_all() + os.unlink(db_path) + + +def test_hardware_detection_parallel(): + """Test hardware detection from multiple threads.""" + from cortex.hardware_detection import get_detector + + results = [] + errors = [] + + def detect_hardware(): + """ + Detect hardware using the shared detector and record the CPU core count or any error. + + Appends the detected CPU core count (uses 1 if detection reports 0) to the outer-scope list `results`. If an exception occurs, appends the exception message string to the outer-scope list `errors`. 
+        """
+        try:
+            detector = get_detector()
+            info = detector.detect()
+            # Store CPU core count as a simple check
+            # Fall back to 1 core if detection reports 0
+            cores = info.cpu.cores if info.cpu.cores > 0 else 1
+            results.append(cores)
+        except Exception as e:
+            errors.append(str(e))
+
+    # 10 threads detecting hardware simultaneously
+    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+        futures = [executor.submit(detect_hardware) for _ in range(10)]
+        concurrent.futures.wait(futures)
+
+    # Check for errors
+    assert len(errors) == 0, f"Hardware detection errors: {errors}"
+
+    # Should have results from all threads
+    assert len(results) == 10, f"Expected 10 results, got {len(results)}"
+
+    # All results should be identical (same hardware)
+    unique_results = len(set(results))
+    assert unique_results == 1, f"Inconsistent hardware detection! Got {unique_results} different results: {set(results)}"
+
+
+def test_connection_pool_timeout():
+    """
+    Verify that the SQLite connection pool raises a TimeoutError when all connections are in use.
+
+    Creates a temporary database and a pool configured with pool_size=2 and timeout=0.5, holds both connections via the pool's context-manager API, and asserts that attempting to acquire a third connection raises a TimeoutError with the message "Could not acquire database connection". Cleans up held connections, closes the pool, and removes the temporary database file.
+    """
+    from cortex.utils.db_pool import get_connection_pool
+
+    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
+        db_path = f.name
+
+    pool = None
+    conn1_cm = conn2_cm = None
+    try:
+        # Create small pool
+        pool = get_connection_pool(db_path, pool_size=2, timeout=0.5)
+
+        # Hold all connections via the public context manager API
+        conn1_cm = pool.get_connection()
+        conn1 = conn1_cm.__enter__()
+        conn2_cm = pool.get_connection()
+        conn2 = conn2_cm.__enter__()
+
+        # Try to get third connection (should timeout)
+        with pytest.raises(TimeoutError, match="Could not acquire database connection"):
+            with pool.get_connection() as conn:
+                pass
+
+    finally:
+        # Release held connections if they were acquired
+        if conn2_cm is not None:
+            conn2_cm.__exit__(None, None, None)
+        if conn1_cm is not None:
+            conn1_cm.__exit__(None, None, None)
+        if pool is not None:
+            pool.close_all()
+        os.unlink(db_path)
+
+
+def test_connection_pool_context_manager():
+    """Test that connection pool works as context manager."""
+    from cortex.utils.db_pool import SQLiteConnectionPool
+
+    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
+        db_path = f.name
+
+    try:
+        # Use pool as context manager
+        with SQLiteConnectionPool(db_path, pool_size=3) as pool:
+            with pool.get_connection() as conn:
+                conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)")
+                conn.commit()
+
+            # Pool should still work
+            with pool.get_connection() as conn:
+                cursor = conn.cursor()
+                cursor.execute("SELECT * FROM test")
+                cursor.fetchall()
+
+        # After exiting context, connections should be closed
+        # (pool._pool should be empty or inaccessible)
+
+    finally:
+        os.unlink(db_path)
+
+
+@pytest.mark.slow
+def test_stress_concurrent_operations():
+    """Stress test with many threads performing mixed read/write operations."""
+    from cortex.utils.db_pool import get_connection_pool
+
+    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
+        db_path = f.name
+
+    try:
+        pool = get_connection_pool(db_path, pool_size=5)
+
+        # Initialize
+        with pool.get_connection() as conn:
+
conn.execute("CREATE TABLE stress (id INTEGER PRIMARY KEY AUTOINCREMENT, data TEXT, timestamp REAL)") + conn.commit() + + errors = [] + + def mixed_operations(thread_id: int): + """ + Perform 50 database operations mixing reads and writes against the shared `pool`. + + Each iteration performs approximately 70% SELECT COUNT(*) reads and 30% INSERT writes that record the calling thread's id and a timestamp. On error, appends a tuple (thread_id, error_message) to the shared `errors` list. + + Parameters: + thread_id (int): Identifier used in inserted row data and included in any error records. + """ + try: + for i in range(50): + if random.random() < 0.7: # 70% reads + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM stress") + cursor.fetchone() + else: # 30% writes + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO stress (data, timestamp) VALUES (?, ?)", + (f"thread_{thread_id}", time.time()) + ) + conn.commit() + except Exception as e: + errors.append((thread_id, str(e))) + + # 20 threads doing mixed operations + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(mixed_operations, i) for i in range(20)] + concurrent.futures.wait(futures) + + if errors: + pytest.fail(f"Stress test errors: {errors[:5]}") # Show first 5 + + # Verify database integrity + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM stress") + count = cursor.fetchone()[0] + # Should have some writes (not exact count due to randomness) + assert count > 0, "No writes occurred" + + finally: + pool.close_all() + os.unlink(db_path) + + +if __name__ == "__main__": + # Quick standalone test + print("Running quick thread-safety tests...") + print("\n1. Testing transaction history singleton...") + test_singleton_thread_safety_transaction_history() + print("✅ PASSED") + + print("\n2. Testing hardware detection singleton...") + test_singleton_thread_safety_hardware_detection() + print("✅ PASSED") + + print("\n3. Testing degradation manager singleton...") + test_singleton_thread_safety_degradation_manager() + print("✅ PASSED") + + print("\n4. Testing connection pool concurrent reads...") + test_connection_pool_concurrent_reads() + print("✅ PASSED") + + print("\n5. Testing connection pool concurrent writes...") + test_connection_pool_concurrent_writes() + print("✅ PASSED") + + print("\n✅ All quick tests passed! Run with pytest for full suite.") \ No newline at end of file
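
Note for reviewers: the tests above exercise the pool surface expected from cortex/utils/db_pool.py (get_connection_pool, SQLiteConnectionPool, get_connection as a context manager, close_all, and a TimeoutError carrying "Could not acquire database connection"), but that module's implementation is not shown in this diff. The following is only a minimal illustrative sketch of such an interface, under the assumption of a fixed-size queue of connections and one pool per database path; it is not the implementation shipped in this PR.

# Minimal sketch, NOT the real cortex/utils/db_pool.py: approximates the
# interface exercised by the tests above so the expected behaviour is clear.
import queue
import sqlite3
import threading
from contextlib import contextmanager


class SQLiteConnectionPool:
    """Fixed-size pool of SQLite connections shared across threads."""

    def __init__(self, db_path: str, pool_size: int = 5, timeout: float = 30.0):
        self.db_path = db_path
        self.timeout = timeout
        self._pool: queue.Queue = queue.Queue(maxsize=pool_size)
        for _ in range(pool_size):
            # check_same_thread=False lets a connection created here be used by
            # whichever worker thread checks it out of the queue.
            self._pool.put(sqlite3.connect(db_path, check_same_thread=False))

    @contextmanager
    def get_connection(self):
        # Block up to `timeout` seconds for a free connection, then give up.
        try:
            conn = self._pool.get(timeout=self.timeout)
        except queue.Empty:
            raise TimeoutError("Could not acquire database connection") from None
        try:
            yield conn
        finally:
            self._pool.put(conn)

    def close_all(self):
        # Drain the queue and close every pooled connection.
        while True:
            try:
                self._pool.get_nowait().close()
            except queue.Empty:
                break

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close_all()


_pools: dict[str, SQLiteConnectionPool] = {}
_pools_lock = threading.Lock()


def get_connection_pool(db_path: str, pool_size: int = 5, timeout: float = 30.0) -> SQLiteConnectionPool:
    # One pool per database path, created lazily behind a lock so concurrent
    # callers never build two pools for the same file.
    with _pools_lock:
        if db_path not in _pools:
            _pools[db_path] = SQLiteConnectionPool(db_path, pool_size=pool_size, timeout=timeout)
        return _pools[db_path]

The real module presumably also handles SQLite-specific concerns such as busy timeouts or WAL mode and pool re-creation after close_all, which this sketch deliberately omits.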