Spaces:
Sleeping
Sleeping
Update extensions/visualization.py
Browse files- extensions/visualization.py +90 -23
extensions/visualization.py
CHANGED
|
@@ -1,20 +1,21 @@
|
|
| 1 |
"""
|
| 2 |
-
Data Visualization Extension
|
| 3 |
Create charts, graphs, and visualizations from data
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
from base_extension import BaseExtension
|
| 8 |
from google.genai import types
|
| 9 |
-
from typing import Dict, Any, List
|
| 10 |
import json
|
| 11 |
import matplotlib
|
| 12 |
-
matplotlib.use('Agg')
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
import io
|
| 15 |
import base64
|
| 16 |
from pathlib import Path
|
| 17 |
import numpy as np
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class VisualizationExtension(BaseExtension):
|
|
@@ -35,6 +36,10 @@ class VisualizationExtension(BaseExtension):
|
|
| 35 |
def icon(self) -> str:
|
| 36 |
return "π"
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def get_system_context(self) -> str:
|
| 39 |
return """
|
| 40 |
You have access to a Data Visualization system for creating charts and graphs.
|
|
@@ -107,7 +112,33 @@ create_bar_chart(
|
|
| 107 |
def _get_default_state(self) -> Dict[str, Any]:
|
| 108 |
return {
|
| 109 |
"charts": [],
|
| 110 |
-
"output_dir": "visualizations"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
}
|
| 112 |
|
| 113 |
def get_tools(self) -> List[types.Tool]:
|
|
@@ -220,7 +251,6 @@ create_bar_chart(
|
|
| 220 |
output_dir.mkdir(exist_ok=True)
|
| 221 |
|
| 222 |
# Generate filename
|
| 223 |
-
import datetime
|
| 224 |
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 225 |
safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
|
| 226 |
safe_title = safe_title.replace(' ', '_')[:50]
|
|
@@ -247,22 +277,17 @@ create_bar_chart(
|
|
| 247 |
"timestamp": timestamp
|
| 248 |
}
|
| 249 |
state["charts"].append(chart_info)
|
|
|
|
| 250 |
self.update_state(user_id, state)
|
| 251 |
|
| 252 |
-
|
| 253 |
-
filepath_str = str(filepath)
|
| 254 |
-
return filepath_str, img_base64
|
| 255 |
|
| 256 |
-
def
|
| 257 |
-
|
| 258 |
-
if user_id not in self.state:
|
| 259 |
-
self.initialize_state(user_id)
|
| 260 |
-
|
| 261 |
try:
|
| 262 |
if tool_name == "create_line_chart":
|
| 263 |
fig, ax = plt.subplots(figsize=(12, 7))
|
| 264 |
|
| 265 |
-
# Plot each series
|
| 266 |
data = args["data"]
|
| 267 |
|
| 268 |
print(f"π Creating line chart with {len(data)} series")
|
|
@@ -281,15 +306,18 @@ create_bar_chart(
|
|
| 281 |
ax.legend(fontsize=10)
|
| 282 |
ax.grid(True, alpha=0.3, linestyle='--')
|
| 283 |
|
| 284 |
-
# Rotate x-axis labels for readability (especially for dates)
|
| 285 |
plt.xticks(rotation=45, ha='right')
|
| 286 |
-
|
| 287 |
-
# Tight layout to prevent label cutoff
|
| 288 |
plt.tight_layout()
|
| 289 |
|
| 290 |
-
# Save chart and get both filepath and base64
|
| 291 |
filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"])
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
print(f"β
Line chart saved: {filepath}")
|
| 294 |
|
| 295 |
return {
|
|
@@ -320,7 +348,6 @@ create_bar_chart(
|
|
| 320 |
ax.set_title(args["title"], fontsize=14, fontweight='bold')
|
| 321 |
ax.grid(True, alpha=0.3, axis='y')
|
| 322 |
|
| 323 |
-
# Rotate x-labels if needed
|
| 324 |
if len(categories) > 5:
|
| 325 |
plt.xticks(rotation=45, ha='right')
|
| 326 |
|
|
@@ -328,6 +355,12 @@ create_bar_chart(
|
|
| 328 |
|
| 329 |
filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"])
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
return {
|
| 332 |
"success": True,
|
| 333 |
"message": f"Bar chart created: {args['title']}",
|
|
@@ -360,6 +393,12 @@ create_bar_chart(
|
|
| 360 |
|
| 361 |
filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"])
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
return {
|
| 364 |
"success": True,
|
| 365 |
"message": f"Scatter plot created: {args['title']}",
|
|
@@ -384,7 +423,6 @@ create_bar_chart(
|
|
| 384 |
textprops={'fontsize': 11, 'fontweight': 'bold'}
|
| 385 |
)
|
| 386 |
|
| 387 |
-
# Enhance text
|
| 388 |
for text in texts:
|
| 389 |
text.set_fontsize(11)
|
| 390 |
for autotext in autotexts:
|
|
@@ -396,6 +434,12 @@ create_bar_chart(
|
|
| 396 |
|
| 397 |
filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"])
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
return {
|
| 400 |
"success": True,
|
| 401 |
"message": f"Pie chart created: {args['title']}",
|
|
@@ -410,7 +454,7 @@ create_bar_chart(
|
|
| 410 |
|
| 411 |
return {
|
| 412 |
"total_charts": len(charts),
|
| 413 |
-
"charts": charts
|
| 414 |
}
|
| 415 |
|
| 416 |
except Exception as e:
|
|
@@ -427,4 +471,27 @@ create_bar_chart(
|
|
| 427 |
|
| 428 |
def on_enable(self, user_id: str) -> str:
|
| 429 |
self.initialize_state(user_id)
|
| 430 |
-
return "π Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Enhanced Data Visualization Extension
|
| 3 |
Create charts, graphs, and visualizations from data
|
| 4 |
+
Now with better state management and orchestrator integration
|
| 5 |
"""
|
| 6 |
|
| 7 |
from base_extension import BaseExtension
|
| 8 |
from google.genai import types
|
| 9 |
+
from typing import Dict, Any, List, Optional
|
| 10 |
import json
|
| 11 |
import matplotlib
|
| 12 |
+
matplotlib.use('Agg')
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
import io
|
| 15 |
import base64
|
| 16 |
from pathlib import Path
|
| 17 |
import numpy as np
|
| 18 |
+
import datetime
|
| 19 |
|
| 20 |
|
| 21 |
class VisualizationExtension(BaseExtension):
|
|
|
|
| 36 |
def icon(self) -> str:
|
| 37 |
return "π"
|
| 38 |
|
| 39 |
+
@property
|
| 40 |
+
def version(self) -> str:
|
| 41 |
+
return "2.0.0"
|
| 42 |
+
|
| 43 |
def get_system_context(self) -> str:
|
| 44 |
return """
|
| 45 |
You have access to a Data Visualization system for creating charts and graphs.
|
|
|
|
| 112 |
def _get_default_state(self) -> Dict[str, Any]:
|
| 113 |
return {
|
| 114 |
"charts": [],
|
| 115 |
+
"output_dir": "visualizations",
|
| 116 |
+
"total_created": 0,
|
| 117 |
+
"created_at": datetime.datetime.now().isoformat(),
|
| 118 |
+
"last_updated": datetime.datetime.now().isoformat()
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def get_state_summary(self, user_id: str) -> Optional[str]:
|
| 122 |
+
"""Provide state summary for system prompt"""
|
| 123 |
+
state = self.get_state(user_id)
|
| 124 |
+
chart_count = len(state.get("charts", []))
|
| 125 |
+
if chart_count > 0:
|
| 126 |
+
return f"{chart_count} visualizations created this session"
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
def get_metrics(self, user_id: str) -> Dict[str, Any]:
|
| 130 |
+
"""Provide usage metrics"""
|
| 131 |
+
state = self.get_state(user_id)
|
| 132 |
+
charts = state.get("charts", [])
|
| 133 |
+
chart_types = {}
|
| 134 |
+
for chart in charts:
|
| 135 |
+
chart_type = chart.get("type", "unknown")
|
| 136 |
+
chart_types[chart_type] = chart_types.get(chart_type, 0) + 1
|
| 137 |
+
|
| 138 |
+
return {
|
| 139 |
+
"total_created": state.get("total_created", 0),
|
| 140 |
+
"current_session": len(charts),
|
| 141 |
+
"by_type": chart_types
|
| 142 |
}
|
| 143 |
|
| 144 |
def get_tools(self) -> List[types.Tool]:
|
|
|
|
| 251 |
output_dir.mkdir(exist_ok=True)
|
| 252 |
|
| 253 |
# Generate filename
|
|
|
|
| 254 |
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 255 |
safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
|
| 256 |
safe_title = safe_title.replace(' ', '_')[:50]
|
|
|
|
| 277 |
"timestamp": timestamp
|
| 278 |
}
|
| 279 |
state["charts"].append(chart_info)
|
| 280 |
+
state["total_created"] = state.get("total_created", 0) + 1
|
| 281 |
self.update_state(user_id, state)
|
| 282 |
|
| 283 |
+
return str(filepath), img_base64
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
|
| 286 |
+
"""Execute tool logic"""
|
|
|
|
|
|
|
|
|
|
| 287 |
try:
|
| 288 |
if tool_name == "create_line_chart":
|
| 289 |
fig, ax = plt.subplots(figsize=(12, 7))
|
| 290 |
|
|
|
|
| 291 |
data = args["data"]
|
| 292 |
|
| 293 |
print(f"π Creating line chart with {len(data)} series")
|
|
|
|
| 306 |
ax.legend(fontsize=10)
|
| 307 |
ax.grid(True, alpha=0.3, linestyle='--')
|
| 308 |
|
|
|
|
| 309 |
plt.xticks(rotation=45, ha='right')
|
|
|
|
|
|
|
| 310 |
plt.tight_layout()
|
| 311 |
|
|
|
|
| 312 |
filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"])
|
| 313 |
|
| 314 |
+
# Log activity
|
| 315 |
+
self.log_activity(user_id, "chart_created", {
|
| 316 |
+
"type": "line",
|
| 317 |
+
"title": args["title"],
|
| 318 |
+
"series_count": len(data)
|
| 319 |
+
})
|
| 320 |
+
|
| 321 |
print(f"β
Line chart saved: {filepath}")
|
| 322 |
|
| 323 |
return {
|
|
|
|
| 348 |
ax.set_title(args["title"], fontsize=14, fontweight='bold')
|
| 349 |
ax.grid(True, alpha=0.3, axis='y')
|
| 350 |
|
|
|
|
| 351 |
if len(categories) > 5:
|
| 352 |
plt.xticks(rotation=45, ha='right')
|
| 353 |
|
|
|
|
| 355 |
|
| 356 |
filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"])
|
| 357 |
|
| 358 |
+
self.log_activity(user_id, "chart_created", {
|
| 359 |
+
"type": "bar",
|
| 360 |
+
"title": args["title"],
|
| 361 |
+
"categories": len(categories)
|
| 362 |
+
})
|
| 363 |
+
|
| 364 |
return {
|
| 365 |
"success": True,
|
| 366 |
"message": f"Bar chart created: {args['title']}",
|
|
|
|
| 393 |
|
| 394 |
filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"])
|
| 395 |
|
| 396 |
+
self.log_activity(user_id, "chart_created", {
|
| 397 |
+
"type": "scatter",
|
| 398 |
+
"title": args["title"],
|
| 399 |
+
"points": len(x_vals)
|
| 400 |
+
})
|
| 401 |
+
|
| 402 |
return {
|
| 403 |
"success": True,
|
| 404 |
"message": f"Scatter plot created: {args['title']}",
|
|
|
|
| 423 |
textprops={'fontsize': 11, 'fontweight': 'bold'}
|
| 424 |
)
|
| 425 |
|
|
|
|
| 426 |
for text in texts:
|
| 427 |
text.set_fontsize(11)
|
| 428 |
for autotext in autotexts:
|
|
|
|
| 434 |
|
| 435 |
filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"])
|
| 436 |
|
| 437 |
+
self.log_activity(user_id, "chart_created", {
|
| 438 |
+
"type": "pie",
|
| 439 |
+
"title": args["title"],
|
| 440 |
+
"slices": len(labels)
|
| 441 |
+
})
|
| 442 |
+
|
| 443 |
return {
|
| 444 |
"success": True,
|
| 445 |
"message": f"Pie chart created: {args['title']}",
|
|
|
|
| 454 |
|
| 455 |
return {
|
| 456 |
"total_charts": len(charts),
|
| 457 |
+
"charts": charts[-10:] # Last 10 charts
|
| 458 |
}
|
| 459 |
|
| 460 |
except Exception as e:
|
|
|
|
| 471 |
|
| 472 |
def on_enable(self, user_id: str) -> str:
|
| 473 |
self.initialize_state(user_id)
|
| 474 |
+
return "π Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"
|
| 475 |
+
|
| 476 |
+
def on_disable(self, user_id: str) -> str:
|
| 477 |
+
state = self.get_state(user_id)
|
| 478 |
+
total = state.get("total_created", 0)
|
| 479 |
+
return f"π Data Visualization disabled. You created {total} visualizations this session."
|
| 480 |
+
|
| 481 |
+
def health_check(self, user_id: str) -> Dict[str, Any]:
|
| 482 |
+
"""Check extension health"""
|
| 483 |
+
try:
|
| 484 |
+
import matplotlib
|
| 485 |
+
return {
|
| 486 |
+
"healthy": True,
|
| 487 |
+
"extension": self.name,
|
| 488 |
+
"version": self.version,
|
| 489 |
+
"matplotlib_available": True
|
| 490 |
+
}
|
| 491 |
+
except ImportError:
|
| 492 |
+
return {
|
| 493 |
+
"healthy": False,
|
| 494 |
+
"extension": self.name,
|
| 495 |
+
"version": self.version,
|
| 496 |
+
"issues": ["matplotlib library not installed"]
|
| 497 |
+
}
|