from __future__ import annotations

import os
import re
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd
import yaml
from huggingface_hub import HfApi

# Root directory of the project
ROOT = Path(__file__).resolve().parent.parent

# Detect Streamlit runtime
try:
    import streamlit as st
    has_streamlit = True
except ImportError:
    has_streamlit = False

# Load environment variables when running locally
if os.getenv("ENV") == "local" or not has_streamlit:
    from dotenv import load_dotenv
    load_dotenv(ROOT / ".env")

# Read Hugging Face dataset repo ID from config
with open(ROOT / "config.yaml") as f:
    cfg = yaml.safe_load(f)
REPO_ID: str = cfg["repo_id"]
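
# Illustrative config.yaml shape; the only key this module reads is `repo_id`,
# and the value below is a placeholder, not the real dataset name:
#
#   repo_id: "your-username/your-reddit-dataset"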

# Initialize Hugging Face API client
api = HfApi()

# Direct URL for the summary CSV in the dataset
# (load_summary() downloads via the Hub API instead of fetching this URL)
CSV_URL = (
    f"https://huggingface.co/datasets/{REPO_ID}/resolve/main/subreddit_daily_summary.csv"
)


def get_secret(key: str, default=None) -> str | None:
    """Fetch a secret from environment variables or Streamlit secrets."""
    val = os.getenv(key)
    if val is None and has_streamlit:
        val = st.secrets.get(key, default)
    return val if val is not None else default
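
# Illustrative usage (the secret name here is hypothetical):
#   hf_token = get_secret("HF_TOKEN")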


def _cache_10min(func):
    """Apply Streamlit's 10-minute data cache when available; no-op otherwise."""
    return st.cache_data(ttl=600)(func) if has_streamlit else func


@_cache_10min
def load_summary() -> pd.DataFrame:
    """Download and return the subreddit daily summary as a DataFrame using the HF Hub API.

    Cached for 10 minutes.
    """
    # Use the HF Hub API to download the file instead of the direct URL
    local_file = api.hf_hub_download(
        repo_id=REPO_ID,
        filename="subreddit_daily_summary.csv",
        repo_type="dataset",
    )
    df = pd.read_csv(local_file, parse_dates=["date"])
    needed = {"date", "subreddit", "mean_sentiment", "community_weighted_sentiment", "count"}
    if not needed.issubset(df.columns):
        missing = needed - set(df.columns)
        raise ValueError(f"Missing columns in summary CSV: {missing}")
    return df


def _sanitize(name: str) -> str:
    """Make a subreddit name safe for use in filenames (removes slashes, spaces, etc.)."""
    name = name.strip().lower()
    name = re.sub(r"[^\w\-\.]", "_", name)
    return name
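
# Illustrative behaviour: _sanitize("r/Machine Learning") returns "r_machine_learning".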


def load_day(date: str, subreddit: str) -> pd.DataFrame:
    """Lazily download the parquet shard for one day and return that subreddit's slice.

    Args:
        date: Date string in YYYY-MM-DD format
        subreddit: Subreddit name to filter by

    Returns:
        DataFrame containing posts from the specified subreddit on the given day
    """
    # Download the subreddit-specific file using the sanitized subreddit name
    safe_sub = _sanitize(subreddit)
    fname = f"data_scored_subreddit/{date}__{safe_sub}.parquet"
    local = api.hf_hub_download(REPO_ID, fname, repo_type="dataset")
    df_day = pd.read_parquet(local)
    # File contains only the selected subreddit; just reset the index
    return df_day.reset_index(drop=True)
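
# Illustrative call (placeholder date and subreddit):
#   posts = load_day("2024-01-31", "python")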


def get_last_updated_hf(repo_id: str) -> datetime:
    """
    Retrieve the dataset repo's last modified datetime via the HF Hub API.
    Returns a timezone-aware datetime in UTC.
    """
    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    dt: datetime = info.lastModified  # already a datetime object
    if dt.tzinfo is None:
        # Treat a naive timestamp as UTC so the return value is always timezone-aware
        dt = dt.replace(tzinfo=timezone.utc)
    else:
        dt = dt.astimezone(timezone.utc)
    return dt


def get_last_updated_hf_caption() -> str:
    """
    Build a markdown-formatted caption string showing the dataset source and last update.
    Uses REPO_ID and the HF Hub API to fetch the timestamp.
    """
    # Generate dataset link and timestamp
    dataset_url = f"https://huggingface.co/datasets/{REPO_ID}"
    last_update_dt = get_last_updated_hf(REPO_ID)
    last_update = last_update_dt.strftime("%Y-%m-%d %H:%M:%S UTC")
    # Return the small-caption HTML/markdown string
    return (
        "<small>"
        f"Data source: <a href='{dataset_url}' target='_blank'>{REPO_ID}</a> • "
        f"Last updated: {last_update}"
        "</small>"
    )
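
# Illustrative use in the Streamlit app (an assumption; the caption contains raw HTML,
# so it would need unsafe_allow_html when rendered):
#   st.markdown(get_last_updated_hf_caption(), unsafe_allow_html=True)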


def add_rolling(df: pd.DataFrame, window: int = 7) -> pd.DataFrame:
    """Add a per-subreddit rolling mean of community_weighted_sentiment over the given window."""
    out = df.copy()
    for _sub, grp in out.groupby("subreddit"):
        grp_sorted = grp.sort_values("date")
        roll = grp_sorted["community_weighted_sentiment"].rolling(window).mean()
        # Write the rolling values back through the group's original index labels
        out.loc[grp_sorted.index, f"roll_{window}"] = roll
    return out
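
# Illustrative call: adds a `roll_7` column holding the 7-day rolling mean per subreddit.
#   summary = add_rolling(load_summary(), window=7)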


def get_subreddit_colors(subreddits: list[str]) -> dict[str, str]:
    """Provide a consistent color map for each subreddit."""
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
        "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
    ]
    return {sub: palette[i % len(palette)] for i, sub in enumerate(sorted(subreddits))}
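
# Illustrative call (placeholder subreddit names); sorting keeps the mapping stable:
#   colors = get_subreddit_colors(["python", "datascience"])
#   # -> {"datascience": "#1f77b4", "python": "#ff7f0e"}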