"""Gradio app that forecasts stock closing prices with a pre-trained PyTorch LSTM.

Expects ``lstm_stock_model.pth`` (a checkpoint dict containing at least a
``model_state_dict`` key) in the working directory, plus ``scaler.pkl`` (a
pickled fitted scaler) unless the checkpoint itself embeds the scaler.
"""

import functools
import pickle
from datetime import datetime, timedelta

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import yfinance as yf
from matplotlib.figure import Figure
from sklearn.preprocessing import MinMaxScaler

MODEL_PATH = "lstm_stock_model.pth"
SCALER_PATH = "scaler.pkl"
DEFAULT_SEQUENCE_LENGTH = 60  # used when the checkpoint does not record one
HISTORY_DAYS_SHOWN = 100  # trailing rows of history drawn before the forecast


# Define the LSTM model architecture
class LSTMModel(nn.Module):
    """LSTM regressor that predicts from the last time step of a sequence.

    Input ``x`` is batch-first, ``(batch, seq_len, input_size)``. Only the
    final time step's hidden output feeds the linear head, so the result has
    shape ``(batch, output_size)``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=1,
                 output_size=1, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM applies dropout only *between* stacked layers (and warns if
        # asked to with a single layer), so disable it for num_layers == 1.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Allocate the zero initial state on x's device/dtype so the model
        # keeps working after .to("cuda") or .double().
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # keep only the last time step
        out = self.dropout(out)
        return self.linear(out)


def safe_load_model(model_path=MODEL_PATH, scaler_path=SCALER_PATH):
    """Load the trained LSTM, its input scaler and the training window length.

    The checkpoint is first loaded with ``weights_only=True`` (tensors and
    plain containers only). If that fails, e.g. because the checkpoint also
    pickles the scaler object, it is reloaded with full unpickling, and a
    ``'scaler'`` entry is taken from it if present. When no scaler came from
    the checkpoint, it is read from ``scaler_path``.

    Args:
        model_path: Path of the ``torch.save`` checkpoint.
        scaler_path: Path of the pickled scaler used as a fallback.

    Returns:
        ``(model, scaler, sequence_length)``. The model is in eval mode on CPU.

    Raises:
        Exception: if the checkpoint or the scaler cannot be loaded.
        KeyError: if the checkpoint lacks ``'model_state_dict'``.
    """
    scaler = None
    try:
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception:
        # SECURITY: weights_only=False executes arbitrary pickle code. This is
        # acceptable only because the checkpoint is a local file shipped with
        # the app; never point this at a user-supplied path.
        try:
            checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        except Exception as e:
            raise Exception(f"Failed to load model: {e}") from e
        scaler = checkpoint.get("scaler", None)

    model = LSTMModel()
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # disables dropout for inference

    if scaler is None:
        # SECURITY: pickle.load executes code from the file; only load a
        # trusted, locally shipped scaler file.
        try:
            with open(scaler_path, "rb") as f:
                scaler = pickle.load(f)
        except Exception as e:
            raise Exception(
                f"Scaler not found. Please ensure {scaler_path} exists."
            ) from e

    sequence_length = int(checkpoint.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
    return model, scaler, sequence_length


@functools.lru_cache(maxsize=1)
def _load_model_cached():
    """Load model/scaler once per process so requests don't re-read disk."""
    # lru_cache does not cache exceptions, so a failed load is retried on the
    # next request.
    return safe_load_model()


def _close_series(stock_data):
    """Return the 'Close' column of a yfinance frame as a 1-D Series."""
    close = stock_data["Close"]
    # yfinance may return (field, ticker) MultiIndex columns even for a single
    # ticker (recent versions do by default); 'Close' is then a DataFrame.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close


def _last_window(scaled_data, sequence_length):
    """Return the last ``sequence_length`` rows shaped ``(1, L, 1)``.

    If the history is shorter than the window, it is left-padded by repeating
    the first available value.
    """
    if len(scaled_data) >= sequence_length:
        window = scaled_data[-sequence_length:]
    else:
        padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
        window = np.vstack([padding, scaled_data])
    return window.reshape(1, sequence_length, 1)


def _autoregressive_forecast(model, window, steps):
    """Predict ``steps`` values, feeding each prediction back into the window."""
    current = torch.as_tensor(window, dtype=torch.float32)
    predictions = []
    with torch.no_grad():
        for _ in range(steps):
            next_pred = model(current)  # shape (1, 1)
            predictions.append(next_pred.item())
            # Slide the window: drop the oldest step, append the prediction.
            current = torch.cat([current[:, 1:, :], next_pred.reshape(1, 1, 1)], dim=1)
    return predictions


def predict_stock(ticker="AAPL", days=30):
    """Forecast ``days`` future closing prices for ``ticker`` and plot them.

    Two years of daily closes are downloaded and scaled with the training
    scaler. The model is then rolled forward autoregressively, one trading
    day per step.

    Returns:
        A matplotlib ``Figure``. On any failure the figure shows the error
        message instead of raising, so the Gradio UI always receives a plot.
    """
    try:
        ticker = (ticker or "").strip().upper()
        if not ticker:
            return create_error_plot("Please enter a stock ticker")
        # gr.Number delivers floats (e.g. 30.0) unless precision=0; range()
        # needs an int.
        days = int(days)
        if days < 1:
            return create_error_plot("Days to forecast must be at least 1")

        model, scaler, sequence_length = _load_model_cached()

        # Fetch recent stock data
        print(f"Fetching data for {ticker}...")
        stock_data = yf.download(ticker, period="2y", interval="1d")
        if stock_data.empty:
            return create_error_plot(f"No data found for {ticker}")

        close = _close_series(stock_data).dropna()
        if close.empty:
            return create_error_plot(f"No data found for {ticker}")
        closing_prices = close.to_numpy(dtype=float).reshape(-1, 1)

        scaled_data = scaler.transform(closing_prices)
        window = _last_window(scaled_data, sequence_length)
        predictions = _autoregressive_forecast(model, window, days)

        # Convert back to original price scale
        predictions_array = np.asarray(predictions).reshape(-1, 1)
        predictions_original = scaler.inverse_transform(predictions_array).flatten()

        return create_forecast_plot(stock_data, predictions_original, ticker, days)
    except Exception as e:
        # UI boundary: report the failure in the plot rather than crashing.
        print(f"Error in prediction: {e}")
        return create_error_plot(str(e))


def create_forecast_plot(stock_data, predictions, ticker, days):
    """Plot recent closing prices followed by the forecast.

    Each forecast step corresponds to one row of daily trading data, so the
    forecast points are placed on consecutive business days (Mon-Fri) after
    the last observed date. Exchange holidays are not accounted for.

    A standalone ``Figure`` is used instead of pyplot so repeated calls do not
    accumulate open figures in pyplot's global state.
    """
    close = _close_series(stock_data)
    last_date = stock_data.index[-1]
    forecast_dates = pd.bdate_range(
        start=last_date + pd.offsets.BDay(1), periods=len(predictions)
    )

    fig = Figure(figsize=(12, 6), layout="tight")
    ax = fig.subplots()

    # Historical data (up to the last HISTORY_DAYS_SHOWN rows)
    history = close.iloc[-HISTORY_DAYS_SHOWN:]
    ax.plot(history.index, history.to_numpy(), label="Historical Prices",
            color="blue", linewidth=2)

    # Forecast
    ax.plot(forecast_dates, predictions, label="Forecast", color="red",
            linewidth=2, linestyle="--", marker="o")

    ax.set_title(f"{ticker} Stock Price Forecast - Next {days} Days")
    ax.set_xlabel("Date")
    ax.set_ylabel("Price (USD)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis="x", labelrotation=45)
    return fig


def create_error_plot(error_message):
    """Return a figure that displays ``error_message`` in a red box."""
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.text(0.5, 0.5, f"Error: {error_message}", ha="center", va="center",
            transform=ax.transAxes, fontsize=12, wrap=True,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
    ax.set_title("Prediction Error")
    ax.axis("off")
    return fig


# Create Gradio interface
iface = gr.Interface(
    fn=predict_stock,
    inputs=[
        gr.Textbox(value="AAPL", label="Stock Ticker"),
        # precision=0 makes Gradio round and pass an int instead of a float.
        gr.Number(value=30, label="Days to Forecast", minimum=1, maximum=365,
                  precision=0),
    ],
    outputs=gr.Plot(label="Forecast Results"),
    title="Stock Price Forecaster (PyTorch LSTM)",
    description="Predict future stock prices using LSTM neural networks.",
)

if __name__ == "__main__":
    iface.launch()