# Page-capture residue (not part of the program): "Spaces: Sleeping Sleeping"
| import io | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
# Model input columns.  The order matters: the training frame is sliced with
# this list and prediction samples are assembled in the same order.  It has to
# stay a list (not a tuple) because it is used as a DataFrame column selector.
FEATURE_NAMES = [
    "temperature_f",     # air temperature in °F
    "humidity_percent",  # relative humidity in %
    "wind_mph",          # wind speed in mph
    "hour_of_day",       # hour on a 24h clock
    "is_weekend",        # 1 on weekends, 0 otherwise
]
# Keep cache files in a ".cache" directory beside this module.  This runs
# before matplotlib is imported further down so the MPLCONFIGDIR value is
# already in the environment at that point.  setdefault leaves any value that
# the environment already provides untouched.
CACHE_ROOT = Path(__file__).parent / ".cache"
matplotlib_cache = CACHE_ROOT / "matplotlib"
CACHE_ROOT.mkdir(exist_ok=True)
matplotlib_cache.mkdir(parents=True, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", str(matplotlib_cache))
os.environ.setdefault("XDG_CACHE_HOME", str(CACHE_ROOT))
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from sklearn.metrics import accuracy_score, classification_report | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.tree import DecisionTreeClassifier, plot_tree | |
# Stylesheet handed to gr.Blocks(css=...).  It styles the predict button
# (elem_id "predict-button"), the prediction card and the decision-path list
# that predict_study_location emits as HTML.
CUSTOM_CSS = """
#predict-button {
    background: linear-gradient(135deg, #ce1126, #8b0000);
    color: white;
    font-weight: 700;
    font-size: 1.05rem;
    border: none;
    box-shadow: 0 6px 18px rgba(139, 0, 0, 0.35);
}
#predict-button:hover {
    transform: translateY(-1px);
    box-shadow: 0 10px 24px rgba(139, 0, 0, 0.45);
}
.prediction-card {
    border: 3px solid #ce1126;
    border-radius: 18px;
    padding: 1.5rem;
    background: #fff4f4;
    text-align: center;
    box-shadow: 0 12px 30px rgba(206, 17, 38, 0.15);
}
.prediction-card .location {
    font-size: 2rem;
    font-weight: 800;
    color: #7b0011;
    letter-spacing: 0.5px;
}
.prediction-card .location span {
    text-transform: uppercase;
}
.prediction-card .confidence {
    margin-top: 0.75rem;
    font-size: 1.05rem;
    color: #333;
}
.prediction-card .secondary {
    margin-top: 0.25rem;
    font-size: 0.95rem;
    color: #555;
}
.path-list {
    list-style: none;
    padding: 0;
    margin: 0;
    display: flex;
    flex-direction: column;
    gap: 0.75rem;
}
.path-list li {
    background: #f8f9fa;
    border-left: 5px solid #ce1126;
    padding: 0.75rem 1rem;
    border-radius: 10px;
    box-shadow: inset 0 0 0 1px rgba(0, 0, 0, 0.05);
}
.path-list li .headline {
    font-weight: 700;
    color: #7b0011;
    margin-bottom: 0.2rem;
}
.path-list li .meta {
    color: #333;
    font-size: 0.95rem;
}
.path-list li.leaf {
    border-left-color: #1b5e20;
    background: #e8f5e9;
}
.path-list li.leaf .headline {
    color: #1b5e20;
}
"""
def build_dataset(n_samples: int = 200, seed: int = 42) -> pd.DataFrame:
    """Generate a reproducible synthetic study-location dataset.

    Every feature is an integer drawn uniformly at random (upper bounds are
    exclusive): temperature 60-114 °F, humidity 10-39 %, wind 0-24 mph, hour
    8-21 and a 0/1 weekend flag.  Each row is then labelled by fixed rules,
    falling back to a 60/40 Library/Outdoors coin flip when no rule applies.

    Args:
        n_samples: Number of rows to generate.
        seed: Seed for the NumPy random generator; the same seed always
            yields the same frame.

    Returns:
        A DataFrame with the five feature columns plus a ``study_location``
        column whose values are the plain strings "Library" or "Outdoors".
    """
    rng = np.random.default_rng(seed)
    data = pd.DataFrame(
        {
            "temperature_f": rng.integers(60, 115, n_samples),
            "humidity_percent": rng.integers(10, 40, n_samples),
            "wind_mph": rng.integers(0, 25, n_samples),
            "hour_of_day": rng.integers(8, 22, n_samples),
            "is_weekend": rng.integers(0, 2, n_samples),
        }
    )

    labels: list[str] = []
    rows = zip(data["temperature_f"], data["wind_mph"], data["hour_of_day"])
    for temp, wind, hour in rows:
        if temp < 85 and wind < 15 and 8 <= hour <= 18:
            # Mild, calm and daytime: studying outside is pleasant.
            label = "Outdoors"
        elif temp > 105 or wind > 20 or hour > 19:
            # Extreme heat, strong wind or late evening: stay inside.
            label = "Library"
        else:
            # Ambiguous conditions get a noisy label.  rng.choice returns a
            # NumPy string scalar, so convert it to keep the column made of
            # plain str values only.  Rows are visited in order, so the
            # random stream is consumed exactly once per ambiguous row.
            label = str(rng.choice(["Library", "Outdoors"], p=[0.6, 0.4]))
        labels.append(label)

    data["study_location"] = labels
    return data
def train_model(data: pd.DataFrame) -> Tuple[DecisionTreeClassifier, Dict[str, float]]:
    """Fit a shallow decision tree and report its train/test accuracy.

    Args:
        data: Frame holding the FEATURE_NAMES columns and a
            ``study_location`` label column.

    Returns:
        The fitted classifier and a dict with ``train_accuracy`` and
        ``test_accuracy``.
    """
    features = data[FEATURE_NAMES]
    target = data["study_location"]

    # Hold out 20% of the rows; the fixed seed keeps the split repeatable.
    features_train, features_test, target_train, target_test = train_test_split(
        features, target, test_size=0.2, random_state=42
    )

    # Depth 3 keeps the tree small enough to read in the rendered diagram.
    classifier = DecisionTreeClassifier(
        max_depth=3, min_samples_split=10, random_state=42
    )
    classifier.fit(features_train, target_train)

    train_predictions = classifier.predict(features_train)
    test_predictions = classifier.predict(features_test)
    accuracy = {
        "train_accuracy": accuracy_score(target_train, train_predictions),
        "test_accuracy": accuracy_score(target_test, test_predictions),
    }
    return classifier, accuracy
def describe_path(
    model: DecisionTreeClassifier, sample: np.ndarray
) -> Tuple[List[Dict[str, object]], List[int]]:
    """Explain how the fitted tree routes one sample from the root to a leaf.

    Args:
        model: Fitted decision tree.
        sample: 1-D feature vector ordered like FEATURE_NAMES.

    Returns:
        A ``(steps, path_nodes)`` pair.  ``steps`` holds one dict per split
        node on the path (keys ``step``, ``node_id``, ``next_node``,
        ``feature``, ``feature_label``, ``threshold``, ``value``,
        ``direction``, ``decision``) followed by a final dict for the leaf
        (keys ``step``, ``node_id``, ``leaf``, ``prediction``, ``confidence``,
        ``samples``).  ``path_nodes`` lists the visited node ids in order,
        ending with the leaf.
    """
    tree = model.tree_
    row = sample.reshape(1, -1)
    node_indicator = model.decision_path(row)
    leaf_id = int(model.apply(row)[0])

    start, end = node_indicator.indptr[:2]
    visited = [int(node_id) for node_id in node_indicator.indices[start:end]]

    # decision_path reports every node the sample passes through, the leaf
    # included.  A leaf has no split (its feature index is a negative
    # sentinel), so it must not be described as a decision step and must not
    # appear twice in the returned path.
    split_nodes = [node_id for node_id in visited if node_id != leaf_id]
    path_nodes = split_nodes + [leaf_id]

    steps: List[Dict[str, object]] = []
    for position, node_id in enumerate(split_nodes, start=1):
        feature_index = int(tree.feature[node_id])
        feature_name = FEATURE_NAMES[feature_index]
        # position is 1-based, so it indexes the node visited right after
        # this one.  Reading the branch from the actual path keeps the
        # explanation in agreement with the traversal the model performed.
        next_node = path_nodes[position]
        go_left = next_node == int(tree.children_left[node_id])
        steps.append(
            {
                "step": position,
                "node_id": node_id,
                "next_node": next_node,
                "feature": feature_name,
                "feature_label": feature_name.replace("_", " ").title(),
                "threshold": float(tree.threshold[node_id]),
                "value": float(sample[feature_index]),
                "direction": "≤" if go_left else ">",
                "decision": "left" if go_left else "right",
            }
        )

    confidences = model.predict_proba(row)[0]
    class_idx = int(confidences.argmax())
    steps.append(
        {
            "step": len(split_nodes) + 1,
            "node_id": leaf_id,
            "leaf": True,
            "prediction": model.classes_[class_idx],
            "confidence": float(confidences[class_idx]),
            "samples": int(tree.n_node_samples[leaf_id]),
        }
    )
    return steps, path_nodes
def _node_id_from_label(label: str) -> Optional[int]:
    """Return the node id embedded in a plot_tree annotation, if any."""
    # plot_tree(node_ids=True) writes the id as "node #<id>".  The bare "#"
    # and "node id =" spellings are accepted as well so other label styles
    # degrade gracefully instead of silently disabling the highlighting.
    prefixes = ("node #", "#", "node id =")
    for line in label.split("\n"):
        text = line.strip()
        for prefix in prefixes:
            if text.startswith(prefix):
                try:
                    return int(text[len(prefix):].strip())
                except ValueError:
                    return None
    return None


def render_tree_image(
    model: DecisionTreeClassifier,
    highlighted_nodes: Optional[List[int]] = None,
    leaf_id: Optional[int] = None,
) -> Image.Image:
    """Draw the decision tree as a PNG image, optionally highlighting a path.

    Args:
        model: Fitted decision tree to plot.
        highlighted_nodes: Node ids to emphasise (typically the path taken by
            one sample).  ``None`` or an empty list highlights nothing.
        leaf_id: Node id drawn with the green "leaf" style when it is part of
            ``highlighted_nodes``; other highlighted nodes are drawn in red.

    Returns:
        The rendered tree as a fully loaded PIL image.
    """
    highlighted = {int(node_id) for node_id in (highlighted_nodes or [])}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        annotations = plot_tree(
            model,
            feature_names=[name.replace("_", " ").title() for name in FEATURE_NAMES],
            # Pass a plain list of str: some scikit-learn releases reject a
            # NumPy array here.
            class_names=[str(name) for name in model.classes_],
            filled=False,
            rounded=True,
            node_ids=True,
            fontsize=9,
            ax=ax,
        )
        fig.tight_layout()

        for annotation in annotations:
            node_id = _node_id_from_label(annotation.get_text())
            if node_id is None:
                # Not a node box (no id in the label): leave it as drawn.
                continue

            if node_id in highlighted:
                is_leaf = leaf_id is not None and node_id == leaf_id
                face = "#e8f5e9" if is_leaf else "#ffe5e9"
                edge = "#2e7d32" if is_leaf else "#ce1126"
                text_color = "#1b5e20" if is_leaf else "#7b0011"
                line_width = 2.5
            else:
                face, edge, text_color, line_width = "#f5f7fa", "#c3cfe2", "#2e2e2e", 1.2

            annotation.set_bbox(
                dict(
                    boxstyle="round,pad=0.45",
                    facecolor=face,
                    edgecolor=edge,
                    linewidth=line_width,
                )
            )
            annotation.set_color(text_color)

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    # Decode now so the returned image no longer depends on lazy reads.
    image.load()
    return image
def load_html_snippet() -> str:
    """Return the optional HTML prototype stored beside this module.

    Returns:
        The UTF-8 decoded contents of ``interactive_decision_tree.html``
        located in the same directory as this file, or an empty string when
        that file does not exist.
    """
    html_path = Path(__file__).with_name("interactive_decision_tree.html")
    # Attempt the read directly instead of checking exists() first: that
    # avoids a second filesystem lookup and the window in which the file
    # could disappear between the check and the read.
    try:
        return html_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return ""
# Everything the UI needs is built once, when the module is imported.
HTML_SNIPPET = load_html_snippet()
DATAFRAME = build_dataset()
MODEL, METRICS = train_model(DATAFRAME)

# The report is computed on the full synthetic dataset, i.e. on the rows used
# for training as well as the held-out ones.
CLASS_REPORT = classification_report(
    y_true=DATAFRAME["study_location"],
    y_pred=MODEL.predict(DATAFRAME[FEATURE_NAMES]),
    target_names=MODEL.classes_,
    zero_division=0,
)

# Tree diagram without any highlighted path, shown before the first prediction.
DEFAULT_TREE_IMAGE = render_tree_image(MODEL)
def _render_path_item(detail: Dict[str, object]) -> str:
    """Render one describe_path step as an ``<li>`` element."""
    if detail.get("leaf"):
        return f"""
        <li class="leaf">
            <div class="headline">Leaf: predict {detail['prediction']}</div>
            <div class="meta">Confidence {detail['confidence']:.0%} · Support {detail['samples']} samples</div>
        </li>
        """.strip()
    return f"""
    <li>
        <div class="headline">
            Step {detail['step']}: {detail['feature_label']} {detail['direction']} {detail['threshold']:.1f}
        </div>
        <div class="meta">
            Observed value {detail['value']:.1f} → take {detail['decision']} branch (node {detail['next_node']})
        </div>
    </li>
    """.strip()


def predict_study_location(
    temperature: int,
    humidity: int,
    wind: int,
    hour: int,
    weekend: bool,
) -> Tuple[str, str, str, Image.Image]:
    """Predict a study location and explain the decision for the UI.

    Args:
        temperature: Temperature in °F.
        humidity: Relative humidity in percent.
        wind: Wind speed in mph.
        hour: Hour of day on a 24h clock.
        weekend: Whether it is the weekend (encoded as 1/0 for the model).

    Returns:
        A 4-tuple matching the UI outputs: the prediction card HTML, a
        Markdown confidence summary, the decision-path HTML list and the tree
        image with the taken path highlighted.
    """
    sample = np.array(
        [temperature, humidity, wind, hour, 1 if weekend else 0],
        dtype=float,
    )
    probabilities = MODEL.predict_proba(sample.reshape(1, -1))[0]

    # Rank classes from most to least likely.  A stable sort keeps the lowest
    # index first among equal probabilities, so ranked[0] agrees with argmax
    # and ranked[1] is always a *different* class - even on an exact tie,
    # where reversing an ascending argsort would report the top class again.
    ranked = np.argsort(-probabilities, kind="stable")
    top_index = int(ranked[0])
    secondary_index = int(ranked[1]) if len(ranked) > 1 else top_index

    label = MODEL.classes_[top_index]
    confidence = probabilities[top_index]
    secondary_label = MODEL.classes_[secondary_index]
    secondary_confidence = probabilities[secondary_index]

    step_details, path_nodes = describe_path(MODEL, sample)
    leaf_id = path_nodes[-1] if path_nodes else None

    path_items = [_render_path_item(detail) for detail in step_details]
    path_html = "<ol class='path-list'>" + "\n".join(path_items) + "</ol>"

    prediction_html = f"""
    <div class="prediction-card">
        <div class="location">Study at the <span>{label}</span></div>
        <div class="confidence">Confidence: {confidence:.1%}</div>
        <div class="secondary">Next best: {secondary_label} ({secondary_confidence:.1%})</div>
    </div>
    """.strip()

    confidence_text = (
        f"Primary recommendation: **{label}** (`{confidence:.1%}`) · "
        f"Alternate: **{secondary_label}** (`{secondary_confidence:.1%}`)"
    )

    highlighted_image = render_tree_image(
        MODEL,
        highlighted_nodes=path_nodes,
        leaf_id=leaf_id,
    )
    return prediction_html, confidence_text, path_html, highlighted_image
# Gradio interface: input controls in one column, prediction output in the
# other, followed by collapsible panels for the tree, the dataset and the
# optional HTML prototype.
# NOTE(review): the source this was recovered from had lost its indentation,
# so the nesting of the accordions relative to the row is reconstructed -
# confirm against the deployed layout.
with gr.Blocks(css=CUSTOM_CSS, title="UNLV Study Location Predictor") as demo:
    gr.Markdown(
        """
        # UNLV Study Location Predictor
        Adjust the sliders to mirror the current Las Vegas weather and the decision tree
        will suggest the best study location. All numbers come from a synthetic dataset
        tailored for classroom walkthroughs.
        """
    )

    with gr.Row():
        # Inputs.
        with gr.Column():
            temperature = gr.Slider(
                label="Temperature (°F)",
                info="Normal daytime range in Las Vegas",
                minimum=60,
                maximum=115,
                step=1,
                value=85,
            )
            humidity = gr.Slider(
                label="Humidity (%)", minimum=10, maximum=40, step=1, value=20
            )
            wind = gr.Slider(
                label="Wind Speed (mph)", minimum=0, maximum=25, step=1, value=10
            )
            hour = gr.Slider(
                label="Hour of Day (24h)", minimum=8, maximum=22, step=1, value=14
            )
            weekend = gr.Checkbox(value=False, label="Is it the weekend?")
            # elem_id ties the button to the "#predict-button" CSS rules.
            run_button = gr.Button("Predict Study Location", elem_id="predict-button")

        # Outputs.
        with gr.Column():
            prediction_box = gr.HTML()
            confidence_box = gr.Markdown()
            path_box = gr.HTML(
                "<p style='color:#666;font-style:italic;'>Run a prediction to walk through each decision rule.</p>"
            )

    with gr.Accordion("Explore the Decision Tree", open=False):
        tree_image = gr.Image(
            value=DEFAULT_TREE_IMAGE,
            show_label=False,
            image_mode="RGB",
        )
        gr.Markdown(
            f"Train accuracy: `{METRICS['train_accuracy']:.1%}` | "
            f"Test accuracy: `{METRICS['test_accuracy']:.1%}`"
        )
        gr.Markdown(
            """
            **Legend:**
            • *gini* – how mixed the classes are (0 = pure, higher = more mixed)
            • *samples* – number of training rows that reached the node
            • *class* – majority label assigned when the tree predicts at that node
            """,
        )

    with gr.Accordion("Inspect the Synthetic Dataset", open=False):
        gr.DataFrame(
            value=DATAFRAME.head(20),
            label="Sample of the training data (20 rows)",
            interactive=False,
            wrap=True,
        )
        gr.Markdown(
            "Class balance and precision/recall on the full dataset:"
            f"\n```\n{CLASS_REPORT}\n```"
        )

    # The prototype panel only appears when the HTML file was found.
    if HTML_SNIPPET:
        with gr.Accordion("HTML Prototype (original demo)", open=False):
            gr.HTML(HTML_SNIPPET)

    # Input order must match predict_study_location's parameters; output
    # order must match the 4-tuple it returns.
    run_button.click(
        predict_study_location,
        inputs=[temperature, humidity, wind, hour, weekend],
        outputs=[prediction_box, confidence_box, path_box, tree_image],
    )
if __name__ == "__main__":
    # Script entry point: start the queued app, retrying once on a fallback
    # port when the preferred one cannot be bound.
    queued_app = demo.queue()
    is_hf_space = bool(os.environ.get("SPACE_ID"))

    # PORT takes precedence over GRADIO_SERVER_PORT; 7860 is the last resort.
    default_port = int(
        os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860"))
    )
    launch_kwargs = {
        "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        "server_port": default_port,
        "show_error": True,
        "share": os.environ.get("GRADIO_SHARE", "").lower() == "true",
    }
    if not is_hf_space:
        launch_kwargs["prevent_thread_lock"] = True

    try:
        queued_app.launch(**launch_kwargs)
    except OSError:
        fallback_port = int(
            os.environ.get("GRADIO_FALLBACK_PORT", str(default_port + 1111))
        )
        if fallback_port == launch_kwargs["server_port"]:
            # Retrying on the very same port cannot succeed.
            raise
        launch_kwargs["server_port"] = fallback_port
        queued_app.launch(**launch_kwargs)

    if launch_kwargs.get("prevent_thread_lock"):
        # A non-blocking launch returns as soon as the server is up, and the
        # server stops when the script finishes.  Block here so a local run
        # keeps serving until it is interrupted.
        queued_app.block_thread()