wuhp commited on
Commit
dceea05
Β·
verified Β·
1 Parent(s): 81d389f

Update extensions/visualization.py

Browse files
Files changed (1) hide show
  1. extensions/visualization.py +90 -23
extensions/visualization.py CHANGED
@@ -1,20 +1,21 @@
1
  """
2
- Data Visualization Extension
3
  Create charts, graphs, and visualizations from data
4
- Enhanced for stock market data integration
5
  """
6
 
7
  from base_extension import BaseExtension
8
  from google.genai import types
9
- from typing import Dict, Any, List
10
  import json
11
  import matplotlib
12
- matplotlib.use('Agg') # Use non-GUI backend
13
  import matplotlib.pyplot as plt
14
  import io
15
  import base64
16
  from pathlib import Path
17
  import numpy as np
 
18
 
19
 
20
  class VisualizationExtension(BaseExtension):
@@ -35,6 +36,10 @@ class VisualizationExtension(BaseExtension):
35
  def icon(self) -> str:
36
  return "πŸ“Š"
37
 
 
 
 
 
38
  def get_system_context(self) -> str:
39
  return """
40
  You have access to a Data Visualization system for creating charts and graphs.
@@ -107,7 +112,33 @@ create_bar_chart(
107
  def _get_default_state(self) -> Dict[str, Any]:
108
  return {
109
  "charts": [],
110
- "output_dir": "visualizations"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  }
112
 
113
  def get_tools(self) -> List[types.Tool]:
@@ -220,7 +251,6 @@ create_bar_chart(
220
  output_dir.mkdir(exist_ok=True)
221
 
222
  # Generate filename
223
- import datetime
224
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
225
  safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
226
  safe_title = safe_title.replace(' ', '_')[:50]
@@ -247,22 +277,17 @@ create_bar_chart(
247
  "timestamp": timestamp
248
  }
249
  state["charts"].append(chart_info)
 
250
  self.update_state(user_id, state)
251
 
252
- # Return as tuple: (filepath_string, base64_string)
253
- filepath_str = str(filepath)
254
- return filepath_str, img_base64
255
 
256
- def handle_tool_call(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
257
- # Ensure state is initialized
258
- if user_id not in self.state:
259
- self.initialize_state(user_id)
260
-
261
  try:
262
  if tool_name == "create_line_chart":
263
  fig, ax = plt.subplots(figsize=(12, 7))
264
 
265
- # Plot each series
266
  data = args["data"]
267
 
268
  print(f"πŸ“Š Creating line chart with {len(data)} series")
@@ -281,15 +306,18 @@ create_bar_chart(
281
  ax.legend(fontsize=10)
282
  ax.grid(True, alpha=0.3, linestyle='--')
283
 
284
- # Rotate x-axis labels for readability (especially for dates)
285
  plt.xticks(rotation=45, ha='right')
286
-
287
- # Tight layout to prevent label cutoff
288
  plt.tight_layout()
289
 
290
- # Save chart and get both filepath and base64
291
  filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"])
292
 
 
 
 
 
 
 
 
293
  print(f"βœ… Line chart saved: {filepath}")
294
 
295
  return {
@@ -320,7 +348,6 @@ create_bar_chart(
320
  ax.set_title(args["title"], fontsize=14, fontweight='bold')
321
  ax.grid(True, alpha=0.3, axis='y')
322
 
323
- # Rotate x-labels if needed
324
  if len(categories) > 5:
325
  plt.xticks(rotation=45, ha='right')
326
 
@@ -328,6 +355,12 @@ create_bar_chart(
328
 
329
  filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"])
330
 
 
 
 
 
 
 
331
  return {
332
  "success": True,
333
  "message": f"Bar chart created: {args['title']}",
@@ -360,6 +393,12 @@ create_bar_chart(
360
 
361
  filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"])
362
 
 
 
 
 
 
 
363
  return {
364
  "success": True,
365
  "message": f"Scatter plot created: {args['title']}",
@@ -384,7 +423,6 @@ create_bar_chart(
384
  textprops={'fontsize': 11, 'fontweight': 'bold'}
385
  )
386
 
387
- # Enhance text
388
  for text in texts:
389
  text.set_fontsize(11)
390
  for autotext in autotexts:
@@ -396,6 +434,12 @@ create_bar_chart(
396
 
397
  filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"])
398
 
 
 
 
 
 
 
399
  return {
400
  "success": True,
401
  "message": f"Pie chart created: {args['title']}",
@@ -410,7 +454,7 @@ create_bar_chart(
410
 
411
  return {
412
  "total_charts": len(charts),
413
- "charts": charts
414
  }
415
 
416
  except Exception as e:
@@ -427,4 +471,27 @@ create_bar_chart(
427
 
428
  def on_enable(self, user_id: str) -> str:
429
  self.initialize_state(user_id)
430
- return "πŸ“Š Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Enhanced Data Visualization Extension
3
  Create charts, graphs, and visualizations from data
4
+ Now with better state management and orchestrator integration
5
  """
6
 
7
  from base_extension import BaseExtension
8
  from google.genai import types
9
+ from typing import Dict, Any, List, Optional
10
  import json
11
  import matplotlib
12
+ matplotlib.use('Agg')
13
  import matplotlib.pyplot as plt
14
  import io
15
  import base64
16
  from pathlib import Path
17
  import numpy as np
18
+ import datetime
19
 
20
 
21
  class VisualizationExtension(BaseExtension):
 
36
  def icon(self) -> str:
37
  return "πŸ“Š"
38
 
39
+ @property
40
+ def version(self) -> str:
41
+ return "2.0.0"
42
+
43
  def get_system_context(self) -> str:
44
  return """
45
  You have access to a Data Visualization system for creating charts and graphs.
 
112
  def _get_default_state(self) -> Dict[str, Any]:
113
  return {
114
  "charts": [],
115
+ "output_dir": "visualizations",
116
+ "total_created": 0,
117
+ "created_at": datetime.datetime.now().isoformat(),
118
+ "last_updated": datetime.datetime.now().isoformat()
119
+ }
120
+
121
+ def get_state_summary(self, user_id: str) -> Optional[str]:
122
+ """Provide state summary for system prompt"""
123
+ state = self.get_state(user_id)
124
+ chart_count = len(state.get("charts", []))
125
+ if chart_count > 0:
126
+ return f"{chart_count} visualizations created this session"
127
+ return None
128
+
129
+ def get_metrics(self, user_id: str) -> Dict[str, Any]:
130
+ """Provide usage metrics"""
131
+ state = self.get_state(user_id)
132
+ charts = state.get("charts", [])
133
+ chart_types = {}
134
+ for chart in charts:
135
+ chart_type = chart.get("type", "unknown")
136
+ chart_types[chart_type] = chart_types.get(chart_type, 0) + 1
137
+
138
+ return {
139
+ "total_created": state.get("total_created", 0),
140
+ "current_session": len(charts),
141
+ "by_type": chart_types
142
  }
143
 
144
  def get_tools(self) -> List[types.Tool]:
 
251
  output_dir.mkdir(exist_ok=True)
252
 
253
  # Generate filename
 
254
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
255
  safe_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
256
  safe_title = safe_title.replace(' ', '_')[:50]
 
277
  "timestamp": timestamp
278
  }
279
  state["charts"].append(chart_info)
280
+ state["total_created"] = state.get("total_created", 0) + 1
281
  self.update_state(user_id, state)
282
 
283
+ return str(filepath), img_base64
 
 
284
 
285
+ def _execute_tool(self, user_id: str, tool_name: str, args: Dict[str, Any]) -> Any:
286
+ """Execute tool logic"""
 
 
 
287
  try:
288
  if tool_name == "create_line_chart":
289
  fig, ax = plt.subplots(figsize=(12, 7))
290
 
 
291
  data = args["data"]
292
 
293
  print(f"πŸ“Š Creating line chart with {len(data)} series")
 
306
  ax.legend(fontsize=10)
307
  ax.grid(True, alpha=0.3, linestyle='--')
308
 
 
309
  plt.xticks(rotation=45, ha='right')
 
 
310
  plt.tight_layout()
311
 
 
312
  filepath, img_base64 = self._save_chart(fig, user_id, "line_chart", args["title"])
313
 
314
+ # Log activity
315
+ self.log_activity(user_id, "chart_created", {
316
+ "type": "line",
317
+ "title": args["title"],
318
+ "series_count": len(data)
319
+ })
320
+
321
  print(f"βœ… Line chart saved: {filepath}")
322
 
323
  return {
 
348
  ax.set_title(args["title"], fontsize=14, fontweight='bold')
349
  ax.grid(True, alpha=0.3, axis='y')
350
 
 
351
  if len(categories) > 5:
352
  plt.xticks(rotation=45, ha='right')
353
 
 
355
 
356
  filepath, img_base64 = self._save_chart(fig, user_id, "bar_chart", args["title"])
357
 
358
+ self.log_activity(user_id, "chart_created", {
359
+ "type": "bar",
360
+ "title": args["title"],
361
+ "categories": len(categories)
362
+ })
363
+
364
  return {
365
  "success": True,
366
  "message": f"Bar chart created: {args['title']}",
 
393
 
394
  filepath, img_base64 = self._save_chart(fig, user_id, "scatter_plot", args["title"])
395
 
396
+ self.log_activity(user_id, "chart_created", {
397
+ "type": "scatter",
398
+ "title": args["title"],
399
+ "points": len(x_vals)
400
+ })
401
+
402
  return {
403
  "success": True,
404
  "message": f"Scatter plot created: {args['title']}",
 
423
  textprops={'fontsize': 11, 'fontweight': 'bold'}
424
  )
425
 
 
426
  for text in texts:
427
  text.set_fontsize(11)
428
  for autotext in autotexts:
 
434
 
435
  filepath, img_base64 = self._save_chart(fig, user_id, "pie_chart", args["title"])
436
 
437
+ self.log_activity(user_id, "chart_created", {
438
+ "type": "pie",
439
+ "title": args["title"],
440
+ "slices": len(labels)
441
+ })
442
+
443
  return {
444
  "success": True,
445
  "message": f"Pie chart created: {args['title']}",
 
454
 
455
  return {
456
  "total_charts": len(charts),
457
+ "charts": charts[-10:] # Last 10 charts
458
  }
459
 
460
  except Exception as e:
 
471
 
472
  def on_enable(self, user_id: str) -> str:
473
  self.initialize_state(user_id)
474
+ return "πŸ“Š Data Visualization enabled! I can create charts from stock data and other numerical data. Just ask me to visualize something!"
475
+
476
+ def on_disable(self, user_id: str) -> str:
477
+ state = self.get_state(user_id)
478
+ total = state.get("total_created", 0)
479
+ return f"πŸ“Š Data Visualization disabled. You created {total} visualizations this session."
480
+
481
+ def health_check(self, user_id: str) -> Dict[str, Any]:
482
+ """Check extension health"""
483
+ try:
484
+ import matplotlib
485
+ return {
486
+ "healthy": True,
487
+ "extension": self.name,
488
+ "version": self.version,
489
+ "matplotlib_available": True
490
+ }
491
+ except ImportError:
492
+ return {
493
+ "healthy": False,
494
+ "extension": self.name,
495
+ "version": self.version,
496
+ "issues": ["matplotlib library not installed"]
497
+ }