AIsumit123 commited on
Commit
edae885
·
verified ·
1 Parent(s): 3d0fceb

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +168 -0
  2. lstm_stock_model.pth +3 -0
  3. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import yfinance as yf
6
+ import pandas as pd
7
+ import numpy as np
8
+ from sklearn.preprocessing import MinMaxScaler
9
+ import matplotlib.pyplot as plt
10
+ import pickle
11
+ from datetime import datetime, timedelta
12
+
13
+ # Define the LSTM model architecture
14
+ class LSTMModel(nn.Module):
15
+ def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1, dropout=0.2):
16
+ super(LSTMModel, self).__init__()
17
+ self.hidden_size = hidden_size
18
+ self.num_layers = num_layers
19
+ self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
20
+ num_layers=num_layers, batch_first=True,
21
+ dropout=dropout if num_layers > 1 else 0)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.linear = nn.Linear(hidden_size, output_size)
24
+
25
+ def forward(self, x):
26
+ h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
27
+ c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
28
+ out, _ = self.lstm(x, (h0, c0))
29
+ out = out[:, -1, :]
30
+ out = self.dropout(out)
31
+ out = self.linear(out)
32
+ return out
33
+
34
+ def safe_load_model():
35
+ """Safely load model and scaler"""
36
+ try:
37
+ # Try weights_only=True first (secure)
38
+ checkpoint = torch.load('lstm_stock_model.pth', map_location='cpu', weights_only=True)
39
+ scaler = None
40
+ except:
41
+ try:
42
+ # Fallback to weights_only=False
43
+ checkpoint = torch.load('lstm_stock_model.pth', map_location='cpu', weights_only=False)
44
+ scaler = checkpoint.get('scaler', None)
45
+ except Exception as e:
46
+ raise Exception(f"Failed to load model: {e}")
47
+
48
+ # Load model architecture
49
+ model = LSTMModel()
50
+ model.load_state_dict(checkpoint['model_state_dict'])
51
+ model.eval()
52
+
53
+ # Load scaler from separate file if not in checkpoint
54
+ if scaler is None:
55
+ try:
56
+ with open('scaler.pkl', 'rb') as f:
57
+ scaler = pickle.load(f)
58
+ except:
59
+ raise Exception("Scaler not found. Please ensure scaler.pkl exists.")
60
+
61
+ sequence_length = checkpoint.get('sequence_length', 60)
62
+
63
+ return model, scaler, sequence_length
64
+
65
+ def predict_stock(ticker="AAPL", days=30):
66
+ try:
67
+ # Load model safely
68
+ model, scaler, sequence_length = safe_load_model()
69
+
70
+ # Fetch recent stock data
71
+ print(f"Fetching data for {ticker}...")
72
+ stock_data = yf.download(ticker, period="2y", interval="1d")
73
+ if stock_data.empty:
74
+ return create_error_plot(f"No data found for {ticker}")
75
+
76
+ # Use closing prices
77
+ closing_prices = stock_data['Close'].values.reshape(-1, 1)
78
+
79
+ # Scale the data
80
+ scaled_data = scaler.transform(closing_prices)
81
+
82
+ # Create sequence for prediction
83
+ if len(scaled_data) >= sequence_length:
84
+ last_sequence = scaled_data[-sequence_length:].reshape(1, sequence_length, 1)
85
+ else:
86
+ padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
87
+ last_sequence = np.vstack([padding, scaled_data]).reshape(1, sequence_length, 1)
88
+
89
+ # Generate predictions
90
+ predictions = []
91
+ current_sequence = torch.FloatTensor(last_sequence)
92
+
93
+ with torch.no_grad():
94
+ for _ in range(days):
95
+ next_pred = model(current_sequence)
96
+ predictions.append(next_pred.item())
97
+
98
+ # Update sequence
99
+ new_sequence = torch.cat([
100
+ current_sequence[:, 1:, :],
101
+ next_pred.reshape(1, 1, 1)
102
+ ], dim=1)
103
+ current_sequence = new_sequence
104
+
105
+ # Convert back to original scale
106
+ predictions_array = np.array(predictions).reshape(-1, 1)
107
+ predictions_original = scaler.inverse_transform(predictions_array).flatten()
108
+
109
+ # Create plot
110
+ return create_forecast_plot(stock_data, predictions_original, ticker, days)
111
+
112
+ except Exception as e:
113
+ print(f"Error in prediction: {e}")
114
+ return create_error_plot(str(e))
115
+
116
+ def create_forecast_plot(stock_data, predictions, ticker, days):
117
+ """Create forecast plot"""
118
+ from datetime import timedelta
119
+
120
+ last_date = stock_data.index[-1]
121
+ forecast_dates = [last_date + timedelta(days=i+1) for i in range(days)]
122
+
123
+ plt.figure(figsize=(12, 6))
124
+
125
+ # Historical data
126
+ historical_days = min(100, len(stock_data))
127
+ plt.plot(stock_data.index[-historical_days:],
128
+ stock_data['Close'][-historical_days:],
129
+ label='Historical Prices', color='blue', linewidth=2)
130
+
131
+ # Forecast
132
+ plt.plot(forecast_dates, predictions,
133
+ label='Forecast', color='red', linewidth=2, linestyle='--', marker='o')
134
+
135
+ plt.title(f'{ticker} Stock Price Forecast - Next {days} Days')
136
+ plt.xlabel('Date')
137
+ plt.ylabel('Price (USD)')
138
+ plt.legend()
139
+ plt.grid(True, alpha=0.3)
140
+ plt.xticks(rotation=45)
141
+ plt.tight_layout()
142
+
143
+ return plt
144
+
145
+ def create_error_plot(error_message):
146
+ """Create error plot"""
147
+ plt.figure(figsize=(10, 6))
148
+ plt.text(0.5, 0.5, f'Error: {error_message}',
149
+ ha='center', va='center', transform=plt.gca().transAxes,
150
+ fontsize=12, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
151
+ plt.title('Prediction Error')
152
+ plt.axis('off')
153
+ return plt
154
+
155
+ # Create Gradio interface
156
+ iface = gr.Interface(
157
+ fn=predict_stock,
158
+ inputs=[
159
+ gr.Textbox(value="AAPL", label="Stock Ticker"),
160
+ gr.Number(value=30, label="Days to Forecast", minimum=1, maximum=365)
161
+ ],
162
+ outputs=gr.Plot(label="Forecast Results"),
163
+ title="Stock Price Forecaster (PyTorch LSTM)",
164
+ description="Predict future stock prices using LSTM neural networks.",
165
+ )
166
+
167
+ if __name__ == "__main__":
168
+ iface.launch()
lstm_stock_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e7b2155b56e4dc18b3ee846a8193c2c0dc2da3e73333945ee1a1debec651e52
3
+ size 45938
requirements.txt ADDED
Binary file (4.7 kB). View file