import io import os from pathlib import Path from typing import Dict, List, Optional, Tuple FEATURE_NAMES = [ "temperature_f", "humidity_percent", "wind_mph", "hour_of_day", "is_weekend", ] CACHE_ROOT = Path(__file__).with_name(".cache") CACHE_ROOT.mkdir(exist_ok=True) matplotlib_cache = CACHE_ROOT / "matplotlib" matplotlib_cache.mkdir(parents=True, exist_ok=True) os.environ.setdefault("MPLCONFIGDIR", str(matplotlib_cache)) os.environ.setdefault("XDG_CACHE_HOME", str(CACHE_ROOT)) import gradio as gr import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import pandas as pd from PIL import Image from sklearn.metrics import accuracy_score, classification_report from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier, plot_tree CUSTOM_CSS = """ #predict-button { background: linear-gradient(135deg, #ce1126, #8b0000); color: white; font-weight: 700; font-size: 1.05rem; border: none; box-shadow: 0 6px 18px rgba(139, 0, 0, 0.35); } #predict-button:hover { transform: translateY(-1px); box-shadow: 0 10px 24px rgba(139, 0, 0, 0.45); } .prediction-card { border: 3px solid #ce1126; border-radius: 18px; padding: 1.5rem; background: #fff4f4; text-align: center; box-shadow: 0 12px 30px rgba(206, 17, 38, 0.15); } .prediction-card .location { font-size: 2rem; font-weight: 800; color: #7b0011; letter-spacing: 0.5px; } .prediction-card .location span { text-transform: uppercase; } .prediction-card .confidence { margin-top: 0.75rem; font-size: 1.05rem; color: #333; } .prediction-card .secondary { margin-top: 0.25rem; font-size: 0.95rem; color: #555; } .path-list { list-style: none; padding: 0; margin: 0; display: flex; flex-direction: column; gap: 0.75rem; } .path-list li { background: #f8f9fa; border-left: 5px solid #ce1126; padding: 0.75rem 1rem; border-radius: 10px; box-shadow: inset 0 0 0 1px rgba(0, 0, 0, 0.05); } .path-list li .headline { font-weight: 700; color: #7b0011; margin-bottom: 0.2rem; } .path-list li .meta { color: #333; font-size: 0.95rem; } .path-list li.leaf { border-left-color: #1b5e20; background: #e8f5e9; } .path-list li.leaf .headline { color: #1b5e20; } """ def build_dataset(n_samples: int = 200, seed: int = 42) -> pd.DataFrame: rng = np.random.default_rng(seed) data = pd.DataFrame( { "temperature_f": rng.integers(60, 115, n_samples), "humidity_percent": rng.integers(10, 40, n_samples), "wind_mph": rng.integers(0, 25, n_samples), "hour_of_day": rng.integers(8, 22, n_samples), "is_weekend": rng.integers(0, 2, n_samples), } ) labels: list[str] = [] for idx in range(n_samples): temp = data.at[idx, "temperature_f"] wind = data.at[idx, "wind_mph"] hour = data.at[idx, "hour_of_day"] if temp < 85 and wind < 15 and 8 <= hour <= 18: labels.append("Outdoors") elif temp > 105: labels.append("Library") elif wind > 20: labels.append("Library") elif hour > 19: labels.append("Library") else: labels.append(rng.choice(["Library", "Outdoors"], p=[0.6, 0.4])) data["study_location"] = labels return data def train_model(data: pd.DataFrame) -> Tuple[DecisionTreeClassifier, Dict[str, float]]: X = data[FEATURE_NAMES] y = data["study_location"] X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) clf = DecisionTreeClassifier( max_depth=3, min_samples_split=10, random_state=42, ) clf.fit(X_train, y_train) y_pred_train = clf.predict(X_train) y_pred_test = clf.predict(X_test) metrics = { "train_accuracy": accuracy_score(y_train, y_pred_train), "test_accuracy": accuracy_score(y_test, y_pred_test), } return clf, metrics def describe_path( model: DecisionTreeClassifier, sample: np.ndarray ) -> Tuple[List[Dict[str, object]], List[int]]: tree = model.tree_ node_indicator = model.decision_path(sample.reshape(1, -1)) leaf_id = model.apply(sample.reshape(1, -1))[0] start, end = node_indicator.indptr[:2] node_index = node_indicator.indices[start:end] steps: List[Dict[str, object]] = [] path_nodes = list(node_index) + [leaf_id] for node_position, node_id in enumerate(node_index, start=1): feature_index = int(tree.feature[node_id]) threshold = float(tree.threshold[node_id]) feature_name = FEATURE_NAMES[feature_index] feature_value = float(sample[feature_index]) go_left = feature_value <= threshold direction = "≤" if go_left else ">" next_node = int(tree.children_left[node_id] if go_left else tree.children_right[node_id]) steps.append( { "step": node_position, "node_id": int(node_id), "next_node": next_node, "feature": feature_name, "feature_label": feature_name.replace("_", " ").title(), "threshold": threshold, "value": feature_value, "direction": direction, "decision": "left" if go_left else "right", } ) leaf_samples = int(tree.n_node_samples[leaf_id]) confidences = model.predict_proba(sample.reshape(1, -1))[0] class_idx = int(confidences.argmax()) class_name = model.classes_[class_idx] steps.append( { "step": len(node_index) + 1, "node_id": int(leaf_id), "leaf": True, "prediction": class_name, "confidence": float(confidences[class_idx]), "samples": leaf_samples, } ) return steps, path_nodes def render_tree_image( model: DecisionTreeClassifier, highlighted_nodes: Optional[List[int]] = None, leaf_id: Optional[int] = None, ) -> Image.Image: highlighted = set(highlighted_nodes or []) fig, ax = plt.subplots(figsize=(10, 6)) texts = plot_tree( model, feature_names=[name.replace("_", " ").title() for name in FEATURE_NAMES], class_names=model.classes_, filled=False, rounded=True, node_ids=True, fontsize=9, ax=ax, ) fig.tight_layout() for text in texts: label = text.get_text() node_id = None for line in label.split("\n"): if line.startswith("node id ="): try: node_id = int(line.split("=")[1].strip()) except ValueError: node_id = None break if node_id is None: continue if node_id in highlighted: is_leaf = leaf_id is not None and node_id == leaf_id text.set_bbox( dict( boxstyle="round,pad=0.45", facecolor="#e8f5e9" if is_leaf else "#ffe5e9", edgecolor="#2e7d32" if is_leaf else "#ce1126", linewidth=2.5, ) ) text.set_color("#1b5e20" if is_leaf else "#7b0011") else: text.set_bbox( dict( boxstyle="round,pad=0.45", facecolor="#f5f7fa", edgecolor="#c3cfe2", linewidth=1.2, ) ) text.set_color("#2e2e2e") buffer = io.BytesIO() fig.savefig(buffer, format="png", dpi=200, bbox_inches="tight") plt.close(fig) buffer.seek(0) return Image.open(buffer) def load_html_snippet() -> str: html_path = Path(__file__).with_name("interactive_decision_tree.html") if html_path.exists(): return html_path.read_text(encoding="utf-8") return "" DATAFRAME = build_dataset() MODEL, METRICS = train_model(DATAFRAME) HTML_SNIPPET = load_html_snippet() CLASS_REPORT = classification_report( DATAFRAME["study_location"], MODEL.predict(DATAFRAME[FEATURE_NAMES]), target_names=MODEL.classes_, zero_division=0, ) DEFAULT_TREE_IMAGE = render_tree_image(MODEL) def predict_study_location( temperature: int, humidity: int, wind: int, hour: int, weekend: bool, ) -> Tuple[str, str, str, Image.Image]: sample = np.array( [ temperature, humidity, wind, hour, 1 if weekend else 0, ], dtype=float, ) probabilities = MODEL.predict_proba(sample.reshape(1, -1))[0] top_index = probabilities.argmax() label = MODEL.classes_[top_index] confidence = probabilities[top_index] secondary_index = ( probabilities.argsort()[::-1][1] if len(probabilities) > 1 else top_index ) secondary_confidence = probabilities[secondary_index] secondary_label = MODEL.classes_[secondary_index] step_details, path_nodes = describe_path(MODEL, sample) leaf_id = path_nodes[-1] if path_nodes else None path_html_items: list[str] = [] for detail in step_details: if detail.get("leaf"): path_html_items.append( f"""
Run a prediction to walk through each decision rule.
" ) with gr.Accordion("Explore the Decision Tree", open=False): tree_image = gr.Image( value=DEFAULT_TREE_IMAGE, image_mode="RGB", show_label=False, ) gr.Markdown( f"Train accuracy: `{METRICS['train_accuracy']:.1%}` | " f"Test accuracy: `{METRICS['test_accuracy']:.1%}`" ) gr.Markdown( """ **Legend:** • *gini* – how mixed the classes are (0 = pure, higher = more mixed) • *samples* – number of training rows that reached the node • *class* – majority label assigned when the tree predicts at that node """, ) with gr.Accordion("Inspect the Synthetic Dataset", open=False): gr.DataFrame( value=DATAFRAME.head(20), wrap=True, label="Sample of the training data (20 rows)", interactive=False, ) gr.Markdown( "Class balance and precision/recall on the full dataset:" f"\n```\n{CLASS_REPORT}\n```" ) if HTML_SNIPPET: with gr.Accordion("HTML Prototype (original demo)", open=False): gr.HTML(HTML_SNIPPET) run_button.click( predict_study_location, inputs=[temperature, humidity, wind, hour, weekend], outputs=[prediction_box, confidence_box, path_box, tree_image], ) if __name__ == "__main__": queued_app = demo.queue() is_hf_space = bool(os.environ.get("SPACE_ID")) default_port = int( os.environ.get( "PORT", os.environ.get("GRADIO_SERVER_PORT", "7860"), ) ) launch_kwargs = { "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"), "server_port": default_port, "show_error": True, "share": os.environ.get("GRADIO_SHARE", "").lower() == "true", } if not is_hf_space: launch_kwargs["prevent_thread_lock"] = True try: queued_app.launch(**launch_kwargs) except OSError: fallback_port = int( os.environ.get( "GRADIO_FALLBACK_PORT", str(default_port + 1111), ) ) if fallback_port == launch_kwargs["server_port"]: raise launch_kwargs["server_port"] = fallback_port queued_app.launch(**launch_kwargs)