PD03 committed on
Commit
af504e3
·
verified ·
1 Parent(s): c829fa9

Create agent_tools/ml_tools.py

Browse files
Files changed (1) hide show
  1. agent_tools/ml_tools.py +204 -0
agent_tools/ml_tools.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ML Tools optimized for Hugging Face Spaces
3
+ """
4
+
5
+ from smolagents import tool
6
+ import joblib
7
+ import pandas as pd
8
+ import numpy as np
9
+ import json
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ import duckdb
13
+ import streamlit as st
14
+
15
+ # Global model cache for HF Spaces
16
+ _model_cache = {}
17
+
18
+ def load_model_with_cache(model_name: str = 'churn_model_v1'):
19
+ """Load model with HF Spaces caching"""
20
+ if model_name not in _model_cache:
21
+ model_path = Path(f'models/{model_name}.pkl')
22
+ if model_path.exists():
23
+ _model_cache[model_name] = joblib.load(model_path)
24
+ else:
25
+ return None
26
+ return _model_cache[model_name]
27
+
28
@tool
def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float = 0.6) -> str:
    """
    HF Spaces optimized churn prediction with performance constraints.

    Args:
        customer_ids: Comma-separated customer IDs (optional)
        risk_threshold: Risk threshold for alerts (default 0.6)

    Returns:
        JSON with churn predictions optimized for HF Spaces
    """
    try:
        # Load the trained bundle (model + fitted encoders + feature list).
        model_data = load_model_with_cache()
        if model_data is None:
            return json.dumps({"error": "Model not found. Please wait for training to complete."})

        model = model_data['model']
        label_encoders = model_data['label_encoders']
        feature_columns = model_data['feature_columns']

        # Load data with row limits for HF Spaces performance.
        conn = duckdb.connect(':memory:')
        try:
            conn.execute("""
                CREATE TABLE customers AS
                SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
                LIMIT 2000
            """)  # Limit for performance

            conn.execute("""
                CREATE TABLE sales_docs AS
                SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
                LIMIT 5000
            """)  # Limit for performance

            # BUG FIX: the original interpolated the WHERE clause *after*
            # GROUP BY and the LIMIT *before* it, producing invalid SQL in
            # both branches. WHERE must precede GROUP BY; LIMIT must follow.
            # IDs are bound as '?' parameters rather than spliced into the
            # SQL string, so a quote in an ID cannot break/inject the query.
            if customer_ids:
                id_params = [cid.strip() for cid in customer_ids.split(',')]
                placeholders = ','.join('?' for _ in id_params)
                where_sql = f"WHERE c.Customer IN ({placeholders})"
                limit_sql = ""
            else:
                id_params = []
                where_sql = ""
                limit_sql = "LIMIT 500"  # Further limit for demo

            # Per-customer order aggregates used as model inputs.
            customer_data = conn.execute(f"""
                SELECT
                    c.Customer,
                    c.CustomerName,
                    c.Country,
                    c.CustomerGroup,
                    COUNT(s.SalesDocument) as total_orders,
                    MAX(s.CreationDate) as last_order_date,
                    MIN(s.CreationDate) as first_order_date
                FROM customers c
                LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
                {where_sql}
                GROUP BY c.Customer, c.CustomerName, c.Country, c.CustomerGroup
                {limit_sql}
            """, id_params).df()
        finally:
            conn.close()  # Release the in-memory DB (original leaked it)

        if len(customer_data) == 0:
            return json.dumps({"error": "No customers found"})

        # Feature engineering (same as training)
        reference_date = pd.to_datetime('2024-12-31')
        customer_data['last_order_date'] = pd.to_datetime(customer_data['last_order_date'])
        customer_data['first_order_date'] = pd.to_datetime(customer_data['first_order_date'])

        # RFM features; customers with no orders get worst-case recency (365).
        customer_data['Recency'] = (reference_date - customer_data['last_order_date']).dt.days
        customer_data['Recency'] = customer_data['Recency'].fillna(365)
        customer_data['Frequency'] = customer_data['total_orders'].fillna(0)

        # NOTE(review): Monetary is synthetic (seeded random draw) — the demo
        # dataset has no revenue column. Confirm before any production use.
        np.random.seed(42)
        customer_data['Monetary'] = customer_data['Frequency'] * np.random.exponential(500, len(customer_data))

        customer_data['Tenure'] = (reference_date - customer_data['first_order_date']).dt.days
        customer_data['Tenure'] = customer_data['Tenure'].fillna(0)
        # Orders per ~month of tenure; +1 avoids division by zero for new customers.
        customer_data['OrderVelocity'] = customer_data['Frequency'] / (customer_data['Tenure'] / 30 + 1)

        # Encode categoricals with the encoders fitted at training time.
        for col in ['Country', 'CustomerGroup']:
            if col in label_encoders:
                try:
                    customer_data[f'{col}_encoded'] = label_encoders[col].transform(
                        customer_data[col].fillna('Unknown')
                    )
                except ValueError:
                    # Unseen categories: fall back to a neutral code
                    # (was a bare `except:`, which also hid real errors).
                    customer_data[f'{col}_encoded'] = 0

        # Make predictions
        try:
            X = customer_data[feature_columns].fillna(0)
            # Positive-class probability; the unused `model.predict` call
            # from the original was dropped.
            probabilities = model.predict_proba(X)[:, 1]

            results = customer_data.copy()
            results['churn_probability'] = probabilities
            results['risk_level'] = results['churn_probability'].apply(
                lambda x: 'CRITICAL' if x > 0.8 else 'HIGH' if x > 0.6 else 'MEDIUM' if x > 0.4 else 'LOW'
            )

            # High risk customers, capped at 20 for HF Spaces payload size.
            high_risk = results[results['churn_probability'] >= risk_threshold].sort_values(
                'churn_probability', ascending=False
            ).head(20)

            # Generate per-customer action recommendations.
            recommendations = []
            for _, customer in high_risk.iterrows():
                recommendations.append({
                    "customer_id": customer['Customer'],
                    "customer_name": customer['CustomerName'],
                    "churn_probability": round(float(customer['churn_probability']), 3),
                    "risk_level": customer['risk_level'],
                    "recommended_action": "Immediate contact" if customer['churn_probability'] > 0.8 else "Schedule follow-up",
                    "days_since_order": int(customer['Recency']) if not pd.isna(customer['Recency']) else 0
                })

            return json.dumps({
                "analysis_date": datetime.now().isoformat(),
                "customers_analyzed": len(results),
                "high_risk_count": len(high_risk),
                "churn_rate_predicted": round(len(high_risk) / len(results) * 100, 2) if len(results) > 0 else 0,
                "urgent_actions": recommendations,
                "model_performance": f"Accuracy: {model_data.get('accuracy', 'N/A')}",
                "hf_spaces_note": "Results limited for demo performance"
            })

        except Exception as e:
            return json.dumps({"error": f"Prediction failed: {str(e)}"})

    except Exception as e:
        return json.dumps({
            "error": f"Churn analysis failed: {str(e)}",
            "suggestion": "Please ensure model is trained"
        })
168
+
169
@tool
def get_model_status() -> str:
    """
    Get ML model status for HF Spaces.

    Returns:
        JSON with model information and health
    """
    try:
        metadata_file = Path('models/model_metadata.json')
        model_file = Path('models/churn_model_v1.pkl')

        # Guard clause: without both artifacts the model is not usable yet.
        if not (metadata_file.exists() and model_file.exists()):
            return json.dumps({
                "model_status": "Not Found",
                "message": "Model will be trained automatically on first use",
                "training_time": "Approximately 1-2 minutes"
            })

        metadata = json.loads(metadata_file.read_text())
        return json.dumps({
            "model_status": "Ready",
            "model_info": metadata,
            "files_present": {
                "model_file": model_file.exists(),
                "metadata_file": metadata_file.exists()
            },
            "recommendation": "Model is ready for predictions"
        })

    except Exception as e:
        return json.dumps({
            "error": f"Status check failed: {str(e)}"
        })