"""
Enhanced Base Extension System for Wuhp Agents

Compatible with the modular app.py - combines state management with
capability-based discovery.
"""

import datetime
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from google.genai import types


class BaseExtension(ABC):
    """Enhanced base class for all agent extensions with capability-based discovery.

    Subclasses must implement ``name``, ``display_name``, ``description``,
    ``get_system_context`` and ``get_tools``; every other method has a working
    default that may be overridden.

    Per-user state is stored in ``self.state``, a dict keyed by ``user_id``.
    """

    def __init__(self):
        # NOTE(review): one flag shared by all users -- on_enable/on_disable
        # flip it regardless of which user_id they are called for.
        self.enabled = False
        # user_id -> that user's state dict
        self.state: Dict[str, Dict[str, Any]] = {}
        # state key -> predicate; a falsy result rejects the new value
        self._state_validators: Dict[str, Callable[[Any], bool]] = {}
        # hook type -> callbacks, run in registration order
        self._state_hooks: Dict[str, List[Callable]] = {
            'before_update': [],
            'after_update': [],
            'before_tool_call': [],
            'after_tool_call': [],
        }

    # ==========================================
    # REQUIRED PROPERTIES (Use @property decorator)
    # ==========================================

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for the extension (lowercase, no spaces)"""
        pass

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable name shown in UI"""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Brief description of what the extension does"""
        pass

    @property
    def icon(self) -> str:
        """Emoji icon for the extension"""
        return "🔧"

    @property
    def version(self) -> str:
        """Extension version for compatibility checking"""
        return "1.0.0"

    # ==========================================
    # REQUIRED ABSTRACT METHODS
    # ==========================================

    @abstractmethod
    def get_system_context(self) -> str:
        """Returns context to inject into system prompt when enabled"""
        pass

    @abstractmethod
    def get_tools(self) -> List[types.Tool]:
        """Returns Gemini function calling tools for this extension"""
        pass

    # ==========================================
    # CAPABILITY-BASED DISCOVERY SYSTEM
    # ==========================================

    def get_capabilities(self) -> Dict[str, Any]:
        """
        Declare what this extension can do.
        This enables automatic integration without hardcoding.

        Returns dict with capabilities like:
        {
            'provides_data': ['stock_history', 'financial_data'],
            'consumes_data': ['numerical_series', 'comparison_data'],
            'creates_output': ['visualization', 'report'],
            'keywords': ['stock', 'price', 'chart', 'graph'],
            'data_outputs': {
                'stock_history': {
                    'format': 'time_series',
                    'fields': ['dates', 'close_prices', 'ticker']
                }
            }
        }

        Override this in your extension to declare capabilities!
        """
        return {
            'provides_data': [],
            'consumes_data': [],
            'creates_output': [],
            'keywords': [],
            'data_outputs': {}
        }

    def can_consume(self, data_type: str) -> bool:
        """Check if this extension can consume a specific data type"""
        return data_type in self.get_capabilities().get('consumes_data', [])

    def can_provide(self, data_type: str) -> bool:
        """Check if this extension can provide a specific data type"""
        return data_type in self.get_capabilities().get('provides_data', [])

    def can_create(self, output_type: str) -> bool:
        """Check if this extension can create a specific output"""
        return output_type in self.get_capabilities().get('creates_output', [])

    def get_suggested_next_action(self, tool_result: Dict[str, Any],
                                  available_extensions: List['BaseExtension']) -> Optional[Dict[str, Any]]:
        """
        Given a tool result, suggest what should happen next.

        Returns None or dict with:
        {'extension': ext_name, 'tool': tool_name, 'reason': str, 'data': dict}

        This makes extensions self-aware of their integration opportunities!

        Example implementation:
        ```python
        # Find visualization extension by capability (NO HARDCODING!)
        viz_ext = next((ext for ext in available_extensions
                        if ext.can_create('visualization')), None)

        if viz_ext and 'dates' in tool_result:
            return {
                'extension': viz_ext.name,
                'tool': 'create_line_chart',
                'reason': 'Data is ready for time-series visualization',
                'data': tool_result
            }
        ```

        Override this in your extension to suggest next actions!
        """
        return None

    # ==========================================
    # STATE MANAGEMENT
    # ==========================================

    def get_state_summary(self, user_id: str) -> Optional[str]:
        """
        Override to provide a human-readable summary of current state.
        This will be included in the system prompt for context awareness.

        Example: "You have 2 active timers and 5 pending tasks"
        """
        return None

    def initialize_state(self, user_id: str) -> None:
        """Create the default state for ``user_id`` if it has none yet.

        Runs the ``after_update`` hooks with the new state. Does nothing when
        the user already has state.
        """
        if user_id not in self.state:
            self.state[user_id] = self._get_default_state()
            self._run_hooks('after_update', user_id, self.state[user_id])

    def _get_default_state(self) -> Dict[str, Any]:
        """Return a new default state dict; override to add fields.

        ``created_at`` and ``last_updated`` share a single timestamp so a
        fresh state never looks as if it was updated after creation.
        """
        now = datetime.datetime.now().isoformat()
        return {
            'created_at': now,
            'last_updated': now
        }

    def get_state(self, user_id: str) -> Dict[str, Any]:
        """Return the live state dict for ``user_id``, initializing it if needed.

        The returned dict is the stored object, so mutating it mutates the
        extension's state directly (``log_activity`` relies on this).
        """
        if user_id not in self.state:
            self.initialize_state(user_id)
        return self.state.get(user_id, {})

    def update_state(self, user_id: str, updates: Dict[str, Any]) -> None:
        """Merge ``updates`` into the state for ``user_id``.

        Order of operations: ``before_update`` hooks (given the updates),
        validation, ``last_updated`` timestamp, merge, then ``after_update``
        hooks (given the full state).

        ``updates`` is copied first, so the caller's dict is never mutated.

        Raises:
            ValueError: if a registered validator rejects a value; no updates
                are applied in that case.
        """
        if user_id not in self.state:
            self.initialize_state(user_id)

        # Work on a copy so adding 'last_updated' does not leak into the
        # caller's dict.
        updates = dict(updates)

        self._run_hooks('before_update', user_id, updates)
        self._validate_state_updates(updates)
        updates['last_updated'] = datetime.datetime.now().isoformat()
        self.state[user_id].update(updates)
        self._run_hooks('after_update', user_id, self.state[user_id])

    def _validate_state_updates(self, updates: Dict[str, Any]) -> None:
        """Validate state updates using registered validators.

        Raises:
            ValueError: for the first key whose validator returns a falsy value.
        """
        for key, value in updates.items():
            validator = self._state_validators.get(key)
            if validator is not None and not validator(value):
                raise ValueError(f"Invalid value for state key '{key}': {value}")

    def register_state_validator(self, key: str, validator: Callable[[Any], bool]) -> None:
        """Register a validation function for a state key (replaces any previous one)."""
        self._state_validators[key] = validator

    def add_hook(self, hook_type: str, func: Callable) -> None:
        """Add a hook function to be called at specific points.

        Valid ``hook_type`` values are the keys created in ``__init__``;
        unknown types are silently ignored.
        """
        if hook_type in self._state_hooks:
            self._state_hooks[hook_type].append(func)

    def _run_hooks(self, hook_type: str, *args, **kwargs) -> None:
        """Run all registered hooks of a specific type.

        Hook failures are printed and swallowed so that one bad hook cannot
        break state updates or tool calls.
        """
        for hook_func in self._state_hooks.get(hook_type, []):
            try:
                hook_func(*args, **kwargs)
            except Exception as e:
                print(f"Hook error in {self.name}.{hook_type}: {e}")

    # ==========================================
    # TOOL EXECUTION
    # ==========================================

    def handle_tool_call(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Handle a tool call from Gemini with hooks.

        Override this to implement tool logic, OR override _execute_tool instead.

        Any exception raised by ``_execute_tool`` is converted into
        ``{"success": False, "error": str(e), "tool": tool_name}`` rather than
        propagated. The ``after_tool_call`` hooks receive either the result or
        that error dict.
        """
        self._run_hooks('before_tool_call', user_id, tool_name, args)

        try:
            result = self._execute_tool(user_id, tool_name, args)
        except Exception as e:
            # Deliberate boundary catch: tool failures are reported back to
            # the model as data instead of crashing the agent loop.
            result = {
                "success": False,
                "error": str(e),
                "tool": tool_name
            }

        self._run_hooks('after_tool_call', user_id, tool_name, args, result)
        return result

    def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Override this method to implement actual tool logic.
        This is called by handle_tool_call after running before hooks.
        """
        return {"error": f"Tool {tool_name} not implemented"}

    def get_tool_by_name(self, tool_name: str) -> Optional[types.FunctionDeclaration]:
        """Return the function declaration named ``tool_name``, or None.

        ``Tool.function_declarations`` may be None (e.g. for tools that only
        carry other tool kinds), so it is treated as an empty list.
        """
        for tool in self.get_tools():
            for func_decl in getattr(tool, 'function_declarations', None) or []:
                if func_decl.name == tool_name:
                    return func_decl
        return None

    # ==========================================
    # LIFECYCLE HOOKS
    # ==========================================

    def on_enable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is enabled for a user.
        Return a message to show to the user, or None.

        NOTE(review): initializes this user's state but sets the shared
        ``self.enabled`` flag, which is not per-user.
        """
        self.initialize_state(user_id)
        self.enabled = True
        return None

    def on_disable(self, user_id: str) -> Optional[str]:
        """
        Called when extension is disabled for a user.
        Return a message to show to the user, or None.

        NOTE(review): clears the shared ``self.enabled`` flag for all users.
        """
        self.enabled = False
        return None

    def get_proactive_message(self, user_id: str) -> Optional[str]:
        """
        Called periodically to check if extension wants to proactively message user.
        Return message string or None.

        Override this to implement proactive notifications (timers, reminders, etc.)
        """
        return None

    # ==========================================
    # PERSISTENCE
    # ==========================================

    def serialize_state(self, user_id: str) -> str:
        """Serialize state to JSON for persistence.

        Non-JSON values (e.g. datetime objects) are stringified via ``str``.
        """
        return json.dumps(self.get_state(user_id), indent=2, default=str)

    def deserialize_state(self, user_id: str, state_json: str) -> bool:
        """Replace the state for ``user_id`` with the JSON object in ``state_json``.

        Returns:
            True on success. False (after printing the error) if the text is
            not valid JSON or does not decode to a JSON object; the existing
            state is left untouched in that case.
        """
        try:
            loaded = json.loads(state_json)
        except (TypeError, ValueError) as e:
            print(f"Error loading state for {self.name}: {e}")
            return False

        # A list/str/number would break later dict operations on the state.
        if not isinstance(loaded, dict):
            print(f"Error loading state for {self.name}: "
                  f"expected a JSON object, got {type(loaded).__name__}")
            return False

        self.state[user_id] = loaded
        return True

    def save_state_to_file(self, user_id: str, filepath: Optional[str] = None) -> str:
        """Save state to a file (UTF-8) and return the path written.

        Parent directories are created as needed.
        """
        if filepath is None:
            # NOTE(review): user_id is embedded in the path unsanitized; if it
            # can come from untrusted input, path separators in it would let
            # the file land outside the working directory.
            filepath = f"state_{self.name}_{user_id}.json"

        state_path = Path(filepath)
        state_path.parent.mkdir(parents=True, exist_ok=True)
        with open(state_path, 'w', encoding='utf-8') as f:
            f.write(self.serialize_state(user_id))

        return str(state_path)

    def load_state_from_file(self, user_id: str, filepath: str) -> bool:
        """Load state from a UTF-8 JSON file.

        Returns:
            False if the file cannot be read, deserialization raises, or
            ``deserialize_state`` reports failure by returning False; True
            otherwise. (Overrides of ``deserialize_state`` that return None
            are treated as success.)
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                state_json = f.read()
            loaded_ok = self.deserialize_state(user_id, state_json)
        except Exception as e:
            # Best-effort loader: report and signal failure to the caller.
            print(f"Error loading state from {filepath}: {e}")
            return False
        return loaded_ok is not False

    def clear_state(self, user_id: str) -> None:
        """Reset a user's state to the default (useful for testing/reset).

        No hooks are run.
        """
        self.state[user_id] = self._get_default_state()

    def export_state(self, user_id: str) -> Dict[str, Any]:
        """Export state in a format suitable for external use"""
        return {
            'extension': self.name,
            'version': self.version,
            'exported_at': datetime.datetime.now().isoformat(),
            'state': self.get_state(user_id)
        }

    def import_state(self, user_id: str, exported_data: Dict[str, Any]) -> bool:
        """Import state produced by ``export_state``.

        Returns:
            True if the state was imported. False (after printing why) if the
            payload belongs to another extension, is malformed, or its
            ``state`` entry is not a dict.
        """
        try:
            if exported_data.get('extension') != self.name:
                print(f"Extension name mismatch: {exported_data.get('extension')} != {self.name}")
                return False
            # Could add version compatibility checks here
            imported = exported_data['state']
        except (AttributeError, KeyError, TypeError) as e:
            print(f"Error importing state: {e}")
            return False

        if not isinstance(imported, dict):
            print(f"Error importing state: expected a dict, got {type(imported).__name__}")
            return False

        self.state[user_id] = imported
        return True

    # ==========================================
    # DEPENDENCIES & VALIDATION
    # ==========================================

    def get_dependencies(self) -> List[str]:
        """
        Return list of other extension names this extension depends on.
        The orchestrator can use this to ensure dependencies are loaded.
        """
        return []

    def validate_dependencies(self, available_extensions: List[str]) -> bool:
        """Check if all required dependencies are available"""
        return all(dep in available_extensions for dep in self.get_dependencies())

    # ==========================================
    # LOGGING & METRICS
    # ==========================================

    def log_activity(self, user_id: str, activity: str,
                     details: Optional[Dict[str, Any]] = None) -> None:
        """
        Log extension activity for debugging/auditing.

        Entries are appended to ``state['activity_log']`` directly (no hooks,
        no ``last_updated`` bump); only the newest 100 entries are kept.
        Override to implement custom logging.
        """
        log_entry = {
            'timestamp': datetime.datetime.now().isoformat(),
            'extension': self.name,
            'activity': activity,
            'details': details or {}
        }

        # get_state returns the live dict, so these writes persist.
        state = self.get_state(user_id)
        activity_log = state.setdefault('activity_log', [])
        activity_log.append(log_entry)

        # Keep only last 100 entries
        if len(activity_log) > 100:
            state['activity_log'] = activity_log[-100:]

    def get_recent_activity(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the most recent activity log entries.

        A ``limit`` of 0 or less returns an empty list. (A plain
        ``log[-0:]`` slice would return the whole log.)
        """
        if limit <= 0:
            return []
        activity_log = self.get_state(user_id).get('activity_log', [])
        return activity_log[-limit:]

    def get_metrics(self, user_id: str) -> Dict[str, Any]:
        """
        Override to provide usage metrics/statistics.

        Example: {"total_timers_created": 15, "active_timers": 2}
        """
        return {}

    def health_check(self, user_id: str) -> Dict[str, Any]:
        """
        Perform a health check on the extension state.

        Returns dict with 'healthy': bool and optional 'issues': list
        """
        return {
            'healthy': True,
            'extension': self.name,
            'version': self.version
        }


# ==========================================
# HELPER FUNCTIONS FOR CAPABILITY MATCHING
# ==========================================

def find_extensions_with_capability(extensions: List[BaseExtension],
                                    capability_type: str,
                                    capability_value: str) -> List[BaseExtension]:
    """
    Helper function to find extensions with a specific capability.

    Args:
        extensions: List of extension instances to search
        capability_type: 'provides_data', 'consumes_data', or 'creates_output'
        capability_value: The specific capability to search for

    Returns:
        List of extensions that have the specified capability, in input order

    Example:
        viz_extensions = find_extensions_with_capability(
            all_extensions, 'creates_output', 'visualization'
        )
    """
    return [
        ext for ext in extensions
        if capability_value in ext.get_capabilities().get(capability_type, [])
    ]


def get_data_flow_possibilities(extensions: List[BaseExtension]) -> List[Dict[str, Any]]:
    """
    Analyze a list of extensions and return all possible data flow chains.

    Returns list of dicts with:
    {
        'provider': extension_name,
        'data_type': data_type,
        'consumers': [list of extension names that can consume this data]
    }

    Only data types with at least one consumer other than the provider itself
    are included.

    Example usage:
        flows = get_data_flow_possibilities(enabled_extensions)
        for flow in flows:
            print(f"{flow['provider']} produces {flow['data_type']}")
            print(f"  → can be consumed by: {', '.join(flow['consumers'])}")
    """
    flows = []
    for provider in extensions:
        provided_data = provider.get_capabilities().get('provides_data', [])
        for data_type in provided_data:
            consumers = [
                ext.name for ext in extensions
                if ext.can_consume(data_type) and ext.name != provider.name
            ]
            if consumers:
                flows.append({
                    'provider': provider.name,
                    'data_type': data_type,
                    'consumers': consumers
                })
    return flows


def detect_relevant_extensions(query: str, extensions: List[BaseExtension]) -> List[str]:
    """
    Detect which extensions are relevant to a query based on keywords.

    Matching is a case-insensitive substring test of each declared keyword
    against the query. Empty keywords are ignored; an empty string would
    otherwise match every query.

    Args:
        query: The user's query string
        extensions: List of available extensions

    Returns:
        List of extension names that are relevant to the query

    Example:
        query = "Show me a chart of AAPL stock prices"
        relevant = detect_relevant_extensions(query, all_extensions)
        # Returns: ['yfinance', 'visualization']
    """
    query_lower = query.lower()
    relevant = []
    for ext in extensions:
        keywords = ext.get_capabilities().get('keywords', [])
        if any(keyword and keyword.lower() in query_lower for keyword in keywords):
            relevant.append(ext.name)
    return relevant