Spaces:
Build error
Build error
File size: 6,138 Bytes
edae885 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import gradio as gr
import torch
import torch.nn as nn
import yfinance as yf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import pickle
from datetime import datetime, timedelta
# Define the LSTM model architecture
class LSTMModel(nn.Module):
    """LSTM regressor mapping a window of values to the next value(s).

    Input is batch-first with shape ``(batch, seq_len, input_size)``. The
    LSTM output at the last time step goes through dropout and a linear
    head, giving an output of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1, dropout=0.2):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM's built-in dropout only acts between stacked layers (and it
        # warns when num_layers == 1), so it is disabled for a single layer.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return predictions of shape ``(batch, output_size)`` for batch-first ``x``."""
        # Create the zero initial states on x's device and dtype. Plain
        # torch.zeros() defaults to CPU float32, which breaks the model after
        # .to('cuda') or .double().
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the hidden output of the final time step.
        out = out[:, -1, :]
        out = self.dropout(out)
        return self.linear(out)
def safe_load_model():
    """Load the trained LSTM, its fitted scaler, and the window length.

    Reads ``lstm_stock_model.pth`` from the current working directory. If the
    checkpoint does not carry a ``'scaler'`` entry, the scaler is read from
    ``scaler.pkl`` instead.

    Returns:
        tuple: ``(model, scaler, sequence_length)``:
            - ``model`` is an ``LSTMModel`` in eval mode.
            - ``scaler`` is the unpickled scaler object.
            - ``sequence_length`` is ``checkpoint['sequence_length']``,
              or 60 when that key is absent.

    Raises:
        Exception: if the checkpoint or the scaler cannot be loaded. The
            underlying error is chained as ``__cause__``.
    """
    try:
        # Prefer the restricted loader, which only rebuilds tensors and
        # primitive containers.
        checkpoint = torch.load('lstm_stock_model.pth', map_location='cpu', weights_only=True)
        scaler = None
    except Exception:
        # The restricted loader rejects arbitrary pickled objects, e.g. a
        # checkpoint that embeds the scaler.
        # SECURITY: weights_only=False can execute arbitrary pickle code. That
        # is acceptable only because this file ships with the app; never
        # point it at user-supplied files.
        try:
            checkpoint = torch.load('lstm_stock_model.pth', map_location='cpu', weights_only=False)
            scaler = checkpoint.get('scaler', None)
        except Exception as e:
            raise Exception(f"Failed to load model: {e}") from e

    model = LSTMModel()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if scaler is None:
        # SECURITY: pickle.load trusts the file; scaler.pkl must be an app asset.
        try:
            with open('scaler.pkl', 'rb') as f:
                scaler = pickle.load(f)
        except FileNotFoundError as e:
            raise Exception("Scaler not found. Please ensure scaler.pkl exists.") from e
        except Exception as e:
            # Corrupt or unreadable file: don't misreport it as "not found".
            raise Exception(f"Failed to load scaler: {e}") from e

    sequence_length = checkpoint.get('sequence_length', 60)
    return model, scaler, sequence_length
def predict_stock(ticker="AAPL", days=30):
    """Forecast a ticker's closing price ``days`` steps ahead and plot it.

    Downloads two years of daily data with yfinance and scales the closing
    prices with the saved scaler. It then rolls the LSTM forward one step at a
    time, feeding each prediction back into the input window.

    Args:
        ticker: Symbol to forecast. Surrounding whitespace is stripped and
            the symbol is upper-cased.
        days: Number of steps to forecast. May arrive as a float from
            ``gr.Number``; it is truncated to ``int`` and must be >= 1.

    Returns:
        The forecast plot from ``create_forecast_plot``, or an error plot from
        ``create_error_plot`` on invalid input, missing data, or any exception.
    """
    try:
        ticker = str(ticker or "").strip().upper()
        if not ticker:
            return create_error_plot("Please enter a stock ticker")
        # gr.Number yields floats such as 30.0, but range() requires an int.
        days = int(days)
        if days < 1:
            return create_error_plot("Days to forecast must be at least 1")

        model, scaler, sequence_length = safe_load_model()

        print(f"Fetching data for {ticker}...")
        stock_data = yf.download(ticker, period="2y", interval="1d")
        if stock_data.empty:
            return create_error_plot(f"No data found for {ticker}")

        close = stock_data['Close']
        # yf.download may return 'Close' as a DataFrame with one column per
        # ticker. Accept exactly one column; reshape(-1, 1) would otherwise
        # interleave several tickers into a single series.
        if isinstance(close, pd.DataFrame):
            if close.shape[1] != 1:
                return create_error_plot("Please enter a single ticker symbol")
            close = close.iloc[:, 0]
        # Drop missing bars: one NaN in the window would propagate through
        # every subsequent prediction.
        closing_prices = close.dropna().to_numpy(dtype=float).reshape(-1, 1)
        if closing_prices.size == 0:
            return create_error_plot(f"No data found for {ticker}")

        scaled_data = scaler.transform(closing_prices)

        # Build the (1, sequence_length, 1) input window.
        if len(scaled_data) >= sequence_length:
            window = scaled_data[-sequence_length:]
        else:
            # Left-pad short histories with the oldest available value.
            padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
            window = np.vstack([padding, scaled_data])
        current_sequence = torch.FloatTensor(window.reshape(1, sequence_length, 1))

        predictions = []
        with torch.no_grad():
            for _ in range(days):
                next_pred = model(current_sequence)
                predictions.append(next_pred.item())
                # Slide the window: drop the oldest step, append the prediction.
                current_sequence = torch.cat([
                    current_sequence[:, 1:, :],
                    next_pred.reshape(1, 1, 1),
                ], dim=1)

        # Map the scaled predictions back to price units.
        predictions_array = np.array(predictions).reshape(-1, 1)
        predictions_original = scaler.inverse_transform(predictions_array).flatten()

        return create_forecast_plot(stock_data, predictions_original, ticker, days)
    except Exception as e:
        # UI boundary: show any failure to the user instead of crashing.
        print(f"Error in prediction: {e}")
        return create_error_plot(str(e))
def create_forecast_plot(stock_data, predictions, ticker, days):
    """Plot up to the last 100 closing prices followed by the forecast.

    Args:
        stock_data: DataFrame indexed by date with a ``'Close'`` column, as
            returned by ``yf.download``.
        predictions: Forecast prices, one per future step.
        ticker: Symbol shown in the title.
        days: Forecast horizon shown in the title; truncated to ``int``.

    Returns:
        matplotlib.figure.Figure: a standalone figure. It is not registered
        with pyplot, so nothing has to be closed, and concurrent calls never
        share pyplot's global "current figure".
    """
    days = int(days)

    # Daily market data only contains trading days and the model advances one
    # bar per step, so put forecast points on business days after the last bar.
    # Holidays are not modelled.
    last_date = stock_data.index[-1]
    forecast_dates = pd.bdate_range(start=last_date + pd.offsets.BDay(1),
                                    periods=len(predictions))

    fig = plt.Figure(figsize=(12, 6))
    ax = fig.add_subplot(1, 1, 1)

    recent = stock_data.tail(100)
    ax.plot(recent.index, recent['Close'],
            label='Historical Prices', color='blue', linewidth=2)
    ax.plot(forecast_dates, predictions,
            label='Forecast', color='red', linewidth=2, linestyle='--', marker='o')

    ax.set_title(f'{ticker} Stock Price Forecast - Next {days} Days')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (USD)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    return fig
def create_error_plot(error_message):
    """Render ``error_message`` as a figure so the Plot output can show it.

    Args:
        error_message: Text displayed after an ``"Error: "`` prefix.

    Returns:
        matplotlib.figure.Figure: a standalone figure that is not registered
        with pyplot. It therefore never leaks into pyplot's open-figure list
        and never collides with other requests' figures.
    """
    fig = plt.Figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)
    # wrap=True keeps long exception messages inside the figure.
    ax.text(0.5, 0.5, f'Error: {error_message}',
            ha='center', va='center', transform=ax.transAxes,
            fontsize=12, wrap=True,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
    ax.set_title('Prediction Error')
    ax.axis('off')
    return fig
# Create Gradio interface: ticker + horizon in, forecast plot out.
iface = gr.Interface(
    fn=predict_stock,
    inputs=[
        gr.Textbox(value="AAPL", label="Stock Ticker"),
        # precision=0 makes the component return an int; predict_stock
        # passes this value to range(), which rejects floats.
        gr.Number(value=30, label="Days to Forecast", minimum=1, maximum=365, precision=0),
    ],
    outputs=gr.Plot(label="Forecast Results"),
    title="Stock Price Forecaster (PyTorch LSTM)",
    description="Predict future stock prices using LSTM neural networks.",
)

if __name__ == "__main__":
    iface.launch()