File size: 6,138 Bytes
edae885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168

import gradio as gr
import torch
import torch.nn as nn
import yfinance as yf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import pickle
from datetime import datetime, timedelta

# Define the LSTM model architecture
class LSTMModel(nn.Module):
    """Sequence-to-one LSTM regressor.

    Runs an LSTM over a batch of input sequences and maps the LSTM output of
    the final time step to ``output_size`` values through dropout and a
    linear layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced per sequence.
        dropout: Dropout probability applied to the final-step output and,
            only when ``num_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM's inter-layer dropout only applies to stacked LSTMs, so it is
        # disabled for a single layer.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Predict one output vector per sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Create the zero initial state on the same device and dtype as the
        # input. Plain torch.zeros() would always give CPU float32, which
        # breaks the model after .to('cuda') or .double().
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        last_step = out[:, -1, :]
        return self.linear(self.dropout(last_step))

def safe_load_model(model_path='lstm_stock_model.pth', scaler_path='scaler.pkl'):
    """Load the trained LSTM, its fitted scaler, and the input window length.

    The checkpoint is first loaded with ``weights_only=True``. If that fails,
    it is reloaded with ``weights_only=False`` and a ``'scaler'`` entry is
    taken from it when present. If no scaler came from the checkpoint, one is
    unpickled from ``scaler_path``.

    Security: ``weights_only=False`` and ``pickle.load`` can execute arbitrary
    code embedded in the file; only point these paths at trusted files.

    Args:
        model_path: Path to the checkpoint. It must contain
            ``'model_state_dict'`` matching ``LSTMModel()``'s default
            architecture, and may contain ``'scaler'`` and ``'sequence_length'``.
        scaler_path: Fallback pickle file holding the scaler.

    Returns:
        ``(model, scaler, sequence_length)``: the model in eval mode on CPU,
        the scaler, and the checkpoint's ``'sequence_length'`` (60 if absent).

    Raises:
        Exception: If the checkpoint or the scaler cannot be loaded (the
            underlying error is chained as ``__cause__``). ``load_state_dict``
            errors (e.g. KeyError, RuntimeError) propagate unchanged.
    """
    try:
        # Safer path: restricted unpickling of tensors and plain containers.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
        scaler = None
    except Exception:
        # NOTE(review): full unpickling fallback; acceptable only for a
        # trusted, locally produced checkpoint.
        try:
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
            scaler = checkpoint.get('scaler', None)
        except Exception as e:
            raise Exception(f"Failed to load model: {e}") from e

    # Architecture is fixed to LSTMModel's defaults; the checkpoint must match.
    model = LSTMModel()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if scaler is None:
        try:
            with open(scaler_path, 'rb') as f:
                # NOTE(review): pickle.load executes arbitrary code; trusted files only.
                scaler = pickle.load(f)
        except FileNotFoundError as e:
            raise Exception(f"Scaler not found. Please ensure {scaler_path} exists.") from e
        except Exception as e:
            raise Exception(f"Failed to load scaler from {scaler_path}: {e}") from e

    sequence_length = checkpoint.get('sequence_length', 60)

    return model, scaler, sequence_length

def predict_stock(ticker="AAPL", days=30):
    """Forecast future closing prices for ``ticker`` and return a plot.

    Downloads two years of daily prices with yfinance and scales the closes
    with the saved scaler. It then rolls the LSTM forward autoregressively:
    each prediction is appended to the input window and the oldest step is
    dropped.

    Args:
        ticker: Ticker symbol; surrounding whitespace is ignored.
        days: Number of steps to forecast. Floats such as ``30.0`` (what a
            ``gr.Number`` delivers by default) are converted with ``int()``.

    Returns:
        The result of ``create_forecast_plot`` on success. On any failure it
        returns the result of ``create_error_plot`` with a message; errors are
        not raised to the caller.
    """
    try:
        # range() below rejects floats, and gr.Number yields floats by default.
        days = int(days)
        if days < 1:
            return create_error_plot("Days to forecast must be at least 1")
        ticker = str(ticker or "").strip().upper()
        if not ticker:
            return create_error_plot("Please enter a stock ticker")

        # Load model safely
        model, scaler, sequence_length = safe_load_model()

        # Fetch recent stock data
        print(f"Fetching data for {ticker}...")
        stock_data = yf.download(ticker, period="2y", interval="1d")
        if stock_data.empty:
            return create_error_plot(f"No data found for {ticker}")

        # 'Close' may be a Series or a one-column DataFrame depending on the
        # yfinance version. More than one column means several tickers were
        # requested, and reshape(-1, 1) would silently interleave them.
        closing = np.asarray(stock_data['Close'], dtype=float)
        if closing.ndim == 2 and closing.shape[1] != 1:
            return create_error_plot(f"Please enter a single ticker (got {ticker})")
        closing_prices = closing.reshape(-1, 1)
        # A NaN in the input window would turn every prediction into NaN.
        closing_prices = closing_prices[~np.isnan(closing_prices[:, 0])]
        if closing_prices.size == 0:
            return create_error_plot(f"No closing prices found for {ticker}")

        scaled_data = scaler.transform(closing_prices)

        # Model input is (batch=1, sequence_length, 1). Histories shorter than
        # the window are left-padded with the oldest scaled value.
        if len(scaled_data) >= sequence_length:
            window = scaled_data[-sequence_length:]
        else:
            padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
            window = np.vstack([padding, scaled_data])
        current_sequence = torch.FloatTensor(window.reshape(1, sequence_length, 1))

        predictions = []
        with torch.no_grad():
            for _ in range(days):
                next_pred = model(current_sequence)
                predictions.append(next_pred.item())
                # Slide the window: drop the oldest step, append the prediction.
                current_sequence = torch.cat(
                    [current_sequence[:, 1:, :], next_pred.reshape(1, 1, 1)], dim=1
                )

        # Convert back to the original price scale.
        predictions_original = scaler.inverse_transform(
            np.array(predictions).reshape(-1, 1)
        ).flatten()

        return create_forecast_plot(stock_data, predictions_original, ticker, days)

    except Exception as e:
        print(f"Error in prediction: {e}")
        return create_error_plot(str(e))

def create_forecast_plot(stock_data, predictions, ticker, days):
    """Plot up to 100 recent closing prices followed by the forecast.

    Args:
        stock_data: Date-indexed frame with a ``'Close'`` column (as returned
            by ``yf.download``); its last index value anchors the forecast.
        predictions: Forecast prices, one per future step.
        ticker: Symbol used in the title.
        days: Forecast horizon shown in the title; integral floats are accepted.

    Returns:
        The ``matplotlib.pyplot`` module with the new figure current.
    """
    days = int(days)
    # pyplot keeps every figure alive until it is closed. Release figures from
    # earlier calls so a long-running server does not accumulate them.
    plt.close('all')

    # Each model step is one trading day, so place forecasts on business days
    # rather than consecutive calendar days. Exchange holidays are not excluded.
    last_date = stock_data.index[-1]
    forecast_dates = pd.bdate_range(
        start=last_date + pd.Timedelta(days=1), periods=len(predictions)
    )

    plt.figure(figsize=(12, 6))

    # Historical data (positional slice of the most recent rows)
    historical_days = min(100, len(stock_data))
    recent_close = stock_data['Close'].iloc[-historical_days:]
    plt.plot(recent_close.index, recent_close,
             label='Historical Prices', color='blue', linewidth=2)

    # Forecast
    plt.plot(forecast_dates, predictions,
             label='Forecast', color='red', linewidth=2, linestyle='--', marker='o')

    plt.title(f'{ticker} Stock Price Forecast - Next {days} Days')
    plt.xlabel('Date')
    plt.ylabel('Price (USD)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xticks(rotation=45)
    plt.tight_layout()

    return plt

def create_error_plot(error_message):
    """Render ``error_message`` as a figure so errors show in the plot output.

    Args:
        error_message: Text to display, prefixed with ``'Error: '``.

    Returns:
        The ``matplotlib.pyplot`` module with the error figure current.
    """
    import textwrap

    # Release figures from earlier calls; pyplot keeps them alive otherwise.
    plt.close('all')
    plt.figure(figsize=(10, 6))
    # Wrap long messages (e.g. raw exception text) so they are not clipped at
    # the figure edge.
    wrapped = textwrap.fill(f'Error: {error_message}', width=80)
    plt.text(0.5, 0.5, wrapped,
             ha='center', va='center', transform=plt.gca().transAxes,
             fontsize=12, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
    plt.title('Prediction Error')
    plt.axis('off')
    return plt

# Create Gradio interface
iface = gr.Interface(
    fn=predict_stock,
    inputs=[
        gr.Textbox(value="AAPL", label="Stock Ticker"),
        # precision=0 makes Gradio round to and pass an int. The default
        # (precision=None) passes a float such as 30.0, which range() rejects.
        gr.Number(value=30, label="Days to Forecast", minimum=1, maximum=365, precision=0),
    ],
    outputs=gr.Plot(label="Forecast Results"),
    title="Stock Price Forecaster (PyTorch LSTM)",
    description="Predict future stock prices using LSTM neural networks.",
)

if __name__ == "__main__":
    iface.launch()