# IS-335-Demo / app.py
# Ric
# Document tree legend
# e06319a
import io
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# Model input columns. The classifier is trained on ``data[FEATURE_NAMES]``,
# and predict_study_location assembles its sample vector in this same order.
FEATURE_NAMES = [
    "temperature_f",
    "humidity_percent",
    "wind_mph",
    "hour_of_day",
    "is_weekend",
]

# Point matplotlib's config/cache directory (and the generic XDG cache) at a
# ``.cache`` folder beside this file. This has to happen before matplotlib is
# imported, and it only fills in variables the environment has not already set.
CACHE_ROOT = Path(__file__).parent / ".cache"
CACHE_ROOT.mkdir(exist_ok=True)
matplotlib_cache = CACHE_ROOT / "matplotlib"
matplotlib_cache.mkdir(parents=True, exist_ok=True)
if "MPLCONFIGDIR" not in os.environ:
    os.environ["MPLCONFIGDIR"] = str(matplotlib_cache)
if "XDG_CACHE_HOME" not in os.environ:
    os.environ["XDG_CACHE_HOME"] = str(CACHE_ROOT)
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
# Stylesheet passed to ``gr.Blocks(css=...)``. ``#predict-button`` targets the
# button created with ``elem_id="predict-button"``; ``.prediction-card`` and
# ``.path-list`` style the HTML built by predict_study_location, where the
# final leaf step carries the extra ``leaf`` class and is tinted green.
CUSTOM_CSS = """
#predict-button {
background: linear-gradient(135deg, #ce1126, #8b0000);
color: white;
font-weight: 700;
font-size: 1.05rem;
border: none;
box-shadow: 0 6px 18px rgba(139, 0, 0, 0.35);
}
#predict-button:hover {
transform: translateY(-1px);
box-shadow: 0 10px 24px rgba(139, 0, 0, 0.45);
}
.prediction-card {
border: 3px solid #ce1126;
border-radius: 18px;
padding: 1.5rem;
background: #fff4f4;
text-align: center;
box-shadow: 0 12px 30px rgba(206, 17, 38, 0.15);
}
.prediction-card .location {
font-size: 2rem;
font-weight: 800;
color: #7b0011;
letter-spacing: 0.5px;
}
.prediction-card .location span {
text-transform: uppercase;
}
.prediction-card .confidence {
margin-top: 0.75rem;
font-size: 1.05rem;
color: #333;
}
.prediction-card .secondary {
margin-top: 0.25rem;
font-size: 0.95rem;
color: #555;
}
.path-list {
list-style: none;
padding: 0;
margin: 0;
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.path-list li {
background: #f8f9fa;
border-left: 5px solid #ce1126;
padding: 0.75rem 1rem;
border-radius: 10px;
box-shadow: inset 0 0 0 1px rgba(0, 0, 0, 0.05);
}
.path-list li .headline {
font-weight: 700;
color: #7b0011;
margin-bottom: 0.2rem;
}
.path-list li .meta {
color: #333;
font-size: 0.95rem;
}
.path-list li.leaf {
border-left-color: #1b5e20;
background: #e8f5e9;
}
.path-list li.leaf .headline {
color: #1b5e20;
}
"""
def build_dataset(n_samples: int = 200, seed: int = 42) -> pd.DataFrame:
    """Generate the synthetic weather / study-location dataset.

    Integer features are drawn uniformly from fixed ranges (upper bounds are
    exclusive, as with ``Generator.integers``). Each row is then labelled by a
    few hand-written rules. Rows that match no rule get a random label, with
    ``Library`` weighted 0.6 and ``Outdoors`` 0.4.

    Args:
        n_samples: Number of rows to generate.
        seed: Seed for ``numpy.random.default_rng``, so the data is reproducible.

    Returns:
        A DataFrame with the five feature columns followed by a
        ``study_location`` label column.
    """
    rng = np.random.default_rng(seed)
    # Every draw advances ``rng``, so the columns must be generated in exactly
    # this order for a given seed to reproduce the same dataset.
    frame = pd.DataFrame(
        {
            "temperature_f": rng.integers(60, 115, n_samples),
            "humidity_percent": rng.integers(10, 40, n_samples),
            "wind_mph": rng.integers(0, 25, n_samples),
            "hour_of_day": rng.integers(8, 22, n_samples),
            "is_weekend": rng.integers(0, 2, n_samples),
        }
    )

    def label_row(temp, wind, hour):
        # Mild, calm, daytime conditions are the only rule that means "outside".
        if temp < 85 and wind < 15 and 8 <= hour <= 18:
            return "Outdoors"
        # Extreme heat, strong wind, or late evening all send students indoors.
        if temp > 105 or wind > 20 or hour > 19:
            return "Library"
        # Ambiguous rows: one weighted random draw per row, taken in row order.
        return rng.choice(["Library", "Outdoors"], p=[0.6, 0.4])

    frame["study_location"] = [
        label_row(temp, wind, hour)
        for temp, wind, hour in zip(
            frame["temperature_f"], frame["wind_mph"], frame["hour_of_day"]
        )
    ]
    return frame
def train_model(data: pd.DataFrame) -> Tuple[DecisionTreeClassifier, Dict[str, float]]:
    """Fit the shallow decision tree and report its train/test accuracy.

    Args:
        data: Frame containing every column in ``FEATURE_NAMES`` plus the
            ``study_location`` label.

    Returns:
        ``(model, metrics)``, where ``metrics`` maps ``"train_accuracy"`` and
        ``"test_accuracy"`` to accuracy scores on an 80/20 split.
    """
    features = data[FEATURE_NAMES]
    target = data["study_location"]
    # A fixed random_state keeps the split, and therefore the metrics shown in
    # the UI, reproducible between runs.
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.2, random_state=42
    )
    # Depth 3 keeps the tree small enough to draw and explain in class.
    model = DecisionTreeClassifier(
        max_depth=3, min_samples_split=10, random_state=42
    )
    model.fit(X_train, y_train)
    evaluation_sets = (
        ("train_accuracy", X_train, y_train),
        ("test_accuracy", X_test, y_test),
    )
    scores = {
        metric_name: accuracy_score(y_true, model.predict(X_part))
        for metric_name, X_part, y_true in evaluation_sets
    }
    return model, scores
def describe_path(
    model: DecisionTreeClassifier, sample: np.ndarray
) -> Tuple[List[Dict[str, object]], List[int]]:
    """Explain, step by step, how ``model`` routes one ``sample`` to a leaf.

    Args:
        model: Fitted decision tree whose input columns follow ``FEATURE_NAMES``.
        sample: One observation as a 1-D array in ``FEATURE_NAMES`` order.

    Returns:
        A ``(steps, path_nodes)`` pair.

        ``steps`` holds one dict per split node on the path, with keys
        ``step``, ``node_id``, ``next_node``, ``feature``, ``feature_label``,
        ``threshold``, ``value``, ``direction`` and ``decision``. It ends with
        one leaf dict, with keys ``step``, ``node_id``, ``leaf``,
        ``prediction``, ``confidence`` and ``samples``.

        ``path_nodes`` lists the visited node ids from the root to the leaf,
        each exactly once; its last element is the leaf.
    """
    tree = model.tree_
    row = sample.reshape(1, -1)
    node_indicator = model.decision_path(row)
    start, end = node_indicator.indptr[:2]
    # ``decision_path`` already lists every visited node, root first and leaf
    # last, so the leaf must not be appended again.
    path_nodes = [int(node_id) for node_id in node_indicator.indices[start:end]]
    leaf_id = path_nodes[-1]

    steps: List[Dict[str, object]] = []
    # Only the nodes *before* the leaf are decisions. sklearn stores
    # placeholder feature/threshold values for leaves, which must not be
    # described as a split.
    for position, (node_id, next_node) in enumerate(
        zip(path_nodes, path_nodes[1:]), start=1
    ):
        feature_index = int(tree.feature[node_id])
        threshold = float(tree.threshold[node_id])
        feature_name = FEATURE_NAMES[feature_index]
        feature_value = float(sample[feature_index])
        # Take the branch from the path sklearn actually followed, rather than
        # re-evaluating ``value <= threshold`` (sklearn compares in float32).
        # This way the explanation cannot disagree with the prediction.
        go_left = next_node == int(tree.children_left[node_id])
        steps.append(
            {
                "step": position,
                "node_id": node_id,
                "next_node": next_node,
                "feature": feature_name,
                "feature_label": feature_name.replace("_", " ").title(),
                "threshold": threshold,
                "value": feature_value,
                "direction": "≤" if go_left else ">",
                "decision": "left" if go_left else "right",
            }
        )

    leaf_samples = int(tree.n_node_samples[leaf_id])
    confidences = model.predict_proba(row)[0]
    class_idx = int(confidences.argmax())
    steps.append(
        {
            # Splits are numbered 1..k, so the leaf is step k + 1 == len(path_nodes).
            "step": len(path_nodes),
            "node_id": leaf_id,
            "leaf": True,
            "prediction": model.classes_[class_idx],
            "confidence": float(confidences[class_idx]),
            "samples": leaf_samples,
        }
    )
    return steps, path_nodes
def _node_id_from_label(label: str) -> Optional[int]:
    """Return the node id embedded in a ``plot_tree(node_ids=True)`` label, or ``None``."""
    for line in label.split("\n"):
        line = line.strip()
        # sklearn writes the id as "node #<id>" (or "#<id>" when labels are
        # suppressed). It never writes "node id = <id>".
        if line.startswith("node #"):
            digits = line[len("node #"):]
        elif line.startswith("#"):
            digits = line[1:]
        else:
            continue
        return int(digits) if digits.isdigit() else None
    return None


def render_tree_image(
    model: DecisionTreeClassifier,
    highlighted_nodes: Optional[List[int]] = None,
    leaf_id: Optional[int] = None,
) -> Image.Image:
    """Draw ``model`` as a PNG image, optionally highlighting a decision path.

    Args:
        model: Fitted decision tree to draw.
        highlighted_nodes: Node ids to emphasise with a red box. ``None`` or
            empty draws every node in the neutral style.
        leaf_id: Node id, expected to be among ``highlighted_nodes``, that is
            drawn in green as the final leaf.

    Returns:
        A PIL image of the tree, rendered at 200 dpi.
    """
    highlighted = set(highlighted_nodes or [])
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        texts = plot_tree(
            model,
            feature_names=[name.replace("_", " ").title() for name in FEATURE_NAMES],
            class_names=model.classes_,
            filled=False,
            rounded=True,
            node_ids=True,
            fontsize=9,
            ax=ax,
        )
        fig.tight_layout()
        for text in texts:
            node_id = _node_id_from_label(text.get_text())
            if node_id is None:
                # Not a node box (or an unrecognised label): leave it untouched.
                continue
            if node_id in highlighted:
                is_leaf = leaf_id is not None and node_id == leaf_id
                text.set_bbox(
                    dict(
                        boxstyle="round,pad=0.45",
                        facecolor="#e8f5e9" if is_leaf else "#ffe5e9",
                        edgecolor="#2e7d32" if is_leaf else "#ce1126",
                        linewidth=2.5,
                    )
                )
                text.set_color("#1b5e20" if is_leaf else "#7b0011")
            else:
                text.set_bbox(
                    dict(
                        boxstyle="round,pad=0.45",
                        facecolor="#f5f7fa",
                        edgecolor="#c3cfe2",
                        linewidth=1.2,
                    )
                )
                text.set_color("#2e2e2e")
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", dpi=200, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive, so close it even on failure.
        plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)
def load_html_snippet() -> str:
    """Return the optional HTML prototype stored beside this file.

    Reads ``interactive_decision_tree.html`` from the same directory as this
    module as UTF-8 text. Returns an empty string when that file is absent.
    """
    snippet_file = Path(__file__).with_name("interactive_decision_tree.html")
    if not snippet_file.exists():
        return ""
    return snippet_file.read_text(encoding="utf-8")
# Import-time setup shared by the UI:
# - generate the synthetic data;
# - train the tree and record its train/test accuracy;
# - load the optional HTML prototype;
# - build a per-class report;
# - pre-render the un-highlighted tree shown before the first prediction.
DATAFRAME = build_dataset()
MODEL, METRICS = train_model(DATAFRAME)
HTML_SNIPPET = load_html_snippet()
# NOTE: this report scores every row, including the rows the model was trained
# on, so it is likely optimistic compared with METRICS["test_accuracy"].
CLASS_REPORT = classification_report(
DATAFRAME["study_location"],
MODEL.predict(DATAFRAME[FEATURE_NAMES]),
target_names=MODEL.classes_,
zero_division=0,
)
DEFAULT_TREE_IMAGE = render_tree_image(MODEL)
def _path_step_html(detail: Dict[str, object]) -> str:
    """Render one entry produced by ``describe_path`` as an ``<li>`` element."""
    if detail.get("leaf"):
        return f"""
<li class="leaf">
<div class="headline">Leaf: predict {detail['prediction']}</div>
<div class="meta">Confidence {detail['confidence']:.0%} · Support {detail['samples']} samples</div>
</li>
""".strip()
    threshold = detail["threshold"]
    value = detail["value"]
    direction = detail["direction"]
    feature_label = detail["feature_label"]
    decision = detail["decision"]
    next_node = detail["next_node"]
    return f"""
<li>
<div class="headline">
Step {detail['step']}: {feature_label} {direction} {threshold:.1f}
</div>
<div class="meta">
Observed value {value:.1f} → take {decision} branch (node {next_node})
</div>
</li>
""".strip()


def _prediction_card_html(
    label: str,
    confidence: float,
    secondary_label: str,
    secondary_confidence: float,
) -> str:
    """Render the headline card showing the top and runner-up classes."""
    return f"""
<div class="prediction-card">
<div class="location">Study at the <span>{label}</span></div>
<div class="confidence">Confidence: {confidence:.1%}</div>
<div class="secondary">Next best: {secondary_label} ({secondary_confidence:.1%})</div>
</div>
""".strip()


def predict_study_location(
    temperature: int,
    humidity: int,
    wind: int,
    hour: int,
    weekend: bool,
) -> Tuple[str, str, str, Image.Image]:
    """Predict a study location from the UI inputs and explain the decision.

    Args:
        temperature: Temperature in °F.
        humidity: Relative humidity in percent.
        wind: Wind speed in mph.
        hour: Hour of day (24h clock).
        weekend: Whether it is the weekend; encoded as 1/0 for the model.

    Returns:
        A 4-tuple of:
        - the prediction card HTML;
        - a Markdown confidence summary;
        - the step-by-step decision path as an HTML ordered list;
        - the tree image with the path highlighted.
    """
    # Feature order must match FEATURE_NAMES, the columns the model was trained on.
    sample = np.array(
        [temperature, humidity, wind, hour, 1 if weekend else 0],
        dtype=float,
    )
    probabilities = MODEL.predict_proba(sample.reshape(1, -1))[0]
    # Rank classes by descending probability. The stable sort keeps ties in
    # class order, so the first entry equals ``argmax``. The runner-up is then
    # always a *different* class, even when a leaf is split exactly 50/50.
    ranking = np.argsort(-probabilities, kind="stable")
    top_index = int(ranking[0])
    secondary_index = int(ranking[1]) if len(ranking) > 1 else top_index
    label = MODEL.classes_[top_index]
    confidence = probabilities[top_index]
    secondary_label = MODEL.classes_[secondary_index]
    secondary_confidence = probabilities[secondary_index]

    step_details, path_nodes = describe_path(MODEL, sample)
    leaf_id = path_nodes[-1] if path_nodes else None
    path_html = (
        "<ol class='path-list'>"
        + "\n".join(_path_step_html(detail) for detail in step_details)
        + "</ol>"
    )
    prediction_html = _prediction_card_html(
        label, confidence, secondary_label, secondary_confidence
    )
    confidence_text = (
        f"Primary recommendation: **{label}** (`{confidence:.1%}`) · "
        f"Alternate: **{secondary_label}** (`{secondary_confidence:.1%}`)"
    )
    highlighted_image = render_tree_image(
        MODEL,
        highlighted_nodes=path_nodes,
        leaf_id=leaf_id,
    )
    return prediction_html, confidence_text, path_html, highlighted_image
# Page layout: input controls on the left, prediction outputs on the right,
# then collapsible sections for the tree, the dataset, and the optional
# HTML prototype.
with gr.Blocks(
    title="UNLV Study Location Predictor",
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown(
        """
# UNLV Study Location Predictor
Adjust the sliders to mirror the current Las Vegas weather and the decision tree
will suggest the best study location. All numbers come from a synthetic dataset
tailored for classroom walkthroughs.
"""
    )
    with gr.Row():
        with gr.Column():
            # Slider ranges roughly follow the ranges build_dataset samples from.
            temperature = gr.Slider(
                minimum=60,
                maximum=115,
                value=85,
                step=1,
                label="Temperature (°F)",
                info="Normal daytime range in Las Vegas",
            )
            humidity = gr.Slider(
                minimum=10,
                maximum=40,
                value=20,
                step=1,
                label="Humidity (%)",
            )
            wind = gr.Slider(
                minimum=0,
                maximum=25,
                value=10,
                step=1,
                label="Wind Speed (mph)",
            )
            hour = gr.Slider(
                minimum=8,
                maximum=22,
                value=14,
                step=1,
                label="Hour of Day (24h)",
            )
            weekend = gr.Checkbox(
                label="Is it the weekend?",
                value=False,
            )
            run_button = gr.Button(
                "Predict Study Location",
                elem_id="predict-button",
            )
        with gr.Column():
            prediction_box = gr.HTML()
            confidence_box = gr.Markdown()
            path_box = gr.HTML(
                "<p style='color:#666;font-style:italic;'>Run a prediction to walk through each decision rule.</p>"
            )
    with gr.Accordion("Explore the Decision Tree", open=False):
        tree_image = gr.Image(
            value=DEFAULT_TREE_IMAGE,
            image_mode="RGB",
            show_label=False,
        )
        gr.Markdown(
            f"Train accuracy: `{METRICS['train_accuracy']:.1%}` | "
            f"Test accuracy: `{METRICS['test_accuracy']:.1%}`"
        )
        # Use real Markdown list items: plain "•" lines on consecutive lines
        # are merged into a single paragraph when rendered.
        gr.Markdown(
            """
**Legend:**

- *node #* – node id, matching the node numbers in the step-by-step path
- *gini* – how mixed the classes are (0 = pure, higher = more mixed)
- *samples* – number of training rows that reached the node
- *class* – majority label assigned when the tree predicts at that node
""",
        )
    with gr.Accordion("Inspect the Synthetic Dataset", open=False):
        gr.DataFrame(
            value=DATAFRAME.head(20),
            wrap=True,
            label="Sample of the training data (20 rows)",
            interactive=False,
        )
        gr.Markdown(
            "Class balance and precision/recall on the full dataset:"
            f"\n```\n{CLASS_REPORT}\n```"
        )
    # The prototype section only appears when the HTML file was found.
    if HTML_SNIPPET:
        with gr.Accordion("HTML Prototype (original demo)", open=False):
            gr.HTML(HTML_SNIPPET)
    run_button.click(
        predict_study_location,
        inputs=[temperature, humidity, wind, hour, weekend],
        outputs=[prediction_box, confidence_box, path_box, tree_image],
    )
if __name__ == "__main__":
    queued_app = demo.queue()
    is_hf_space = bool(os.environ.get("SPACE_ID"))
    # Port precedence: PORT, then GRADIO_SERVER_PORT, then 7860.
    default_port = int(
        os.environ.get(
            "PORT",
            os.environ.get("GRADIO_SERVER_PORT", "7860"),
        )
    )
    launch_kwargs = {
        "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        "server_port": default_port,
        "show_error": True,
        "share": os.environ.get("GRADIO_SHARE", "").lower() == "true",
    }
    # With prevent_thread_lock=True, launch() returns immediately and the
    # server stops as soon as this script finishes. A plain `python app.py`
    # would therefore exit right away. Only enable it on explicit request,
    # and never on Hugging Face Spaces.
    wants_non_blocking = (
        os.environ.get("GRADIO_PREVENT_THREAD_LOCK", "").lower() == "true"
    )
    if wants_non_blocking and not is_hf_space:
        launch_kwargs["prevent_thread_lock"] = True
    try:
        queued_app.launch(**launch_kwargs)
    except OSError:
        # Presumably the requested port is unavailable: retry once on a
        # fallback port, unless that is the same port we already tried.
        fallback_port = int(
            os.environ.get(
                "GRADIO_FALLBACK_PORT",
                str(default_port + 1111),
            )
        )
        if fallback_port == launch_kwargs["server_port"]:
            raise
        launch_kwargs["server_port"] = fallback_port
        queued_app.launch(**launch_kwargs)