# Page-header residue kept for provenance: Chris9293 - "Update app.py" - 4b98612 (verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
# Load model + tokenizer
# These run once at module level; analyze_sentiment() reuses the resulting
# module-level `tokenizer`, `model`, `id2label` and `num_labels` on every call.
MODEL_NAME = "yiyanghkust/finbert-tone"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Mapping from class index to label name, taken from the checkpoint's config.
# NOTE(review): the index -> label order is defined by the checkpoint, not by
# this file; do not assume 0/1/2 = negative/neutral/positive -- confirm against
# the model card before hard-coding any index.
id2label = model.config.id2label
# Number of sentiment classes, derived from the label mapping.
num_labels = len(id2label)
def split_into_sentences(text: str):
    """Break *text* into trimmed, non-empty sentence-like pieces.

    Newlines are first flattened to single spaces. The text is then cut at
    every run of whitespace that directly follows a ``.``, ``?`` or ``!``, so
    the terminating punctuation stays attached to its sentence. Punctuation
    that is not followed by whitespace (e.g. the dot in ``3.5``) never splits.

    Returns a list of stripped strings; it is empty when *text* contains
    nothing but whitespace.
    """
    flattened = " ".join(text.split("\n"))
    candidates = re.split(r'(?<=[\.\?\!])\s+', flattened)
    # str.strip trims each piece; filter(None, ...) then drops the empty ones
    # (e.g. the trailing piece produced by whitespace after the last sentence).
    return list(filter(None, map(str.strip, candidates)))
def analyze_sentiment(text: str) -> str:
    """Classify the sentiment of *text* and return a Markdown report.

    The text is split with ``split_into_sentences``; every segment is scored
    separately by the module-level ``model`` (inputs truncated to 512 tokens),
    and the per-segment class probabilities are averaged with equal weight to
    choose the overall label.

    Args:
        text: Raw user input. ``None``, empty, or whitespace-only input is
            treated as "nothing to analyze".

    Returns:
        A Markdown string containing the input text, the analyzed segments,
        the overall predicted label and the averaged probability of every
        label -- or a short prompt when there is nothing to analyze.
    """
    # The UI sends a string, but a programmatic caller may pass None; treat
    # that like empty input instead of failing with AttributeError on .strip().
    text = (text or "").strip()
    if not text:
        return "Please enter some text."

    # Fall back to the whole text if the splitter produced no segments.
    segments = split_into_sentences(text) or [text]

    per_segment_probs = []
    with torch.no_grad():  # inference only; no gradients needed
        for segment in segments:
            inputs = tokenizer(
                segment, return_tensors="pt", truncation=True, max_length=512
            )
            logits = model(**inputs).logits
            # Softmax over the class axis; [0] selects the single input row.
            per_segment_probs.append(F.softmax(logits, dim=1)[0])

    # Unweighted mean of the per-segment probability vectors.
    avg_probs = torch.stack(per_segment_probs, dim=0).mean(dim=0)
    best_idx = int(torch.argmax(avg_probs))
    overall_label = id2label[best_idx]

    # Assemble the Markdown report line by line.
    report = [
        "### Sentiment Analysis Result",
        "",
        "**Full Input Text:** ",
        text,
        "",
        f"**Segments Analyzed ({len(segments)}):**",
        str(segments),
        "",
        f"**Overall Prediction (averaged):** {overall_label.capitalize()}",
        "",
        "**Probabilities:**",
    ]
    # List every class in index order so the output layout is stable.
    report.extend(
        f"- {id2label[idx].capitalize()}: {prob:.4f}"
        for idx, prob in enumerate(avg_probs.tolist())
    )
    return "\n".join(report)
# Build the UI: a text box for input, a Markdown pane for the report, and a
# button that runs analyze_sentiment on the box's content.
with gr.Blocks() as demo:
    # The title emoji (U+1F4D8) is written as an escape so the source stays
    # ASCII and cannot be corrupted again by a wrong file-encoding round trip.
    gr.Markdown("# \U0001F4D8 FinBERT Financial Sentiment Analyzer")
    gr.Markdown(
        "Paste financial news, earnings reports, or company statements below. "
        "The tool splits the text into sentences, runs FinBERT on each, and averages the sentiment."
    )
    input_box = gr.Textbox(lines=4, label="Input Text")
    output_box = gr.Markdown(label="Analysis Result")
    analyze_btn = gr.Button("Analyze")
    # Clicking the button sends the textbox value to analyze_sentiment and
    # renders the returned Markdown in the output pane.
    analyze_btn.click(fn=analyze_sentiment, inputs=input_box, outputs=output_box)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()