# AIsumit123's picture
# Upload 3 files
# edae885 verified
import gradio as gr
import torch
import torch.nn as nn
import yfinance as yf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import pickle
from datetime import datetime, timedelta
# Define the LSTM model architecture
class LSTMModel(nn.Module):
    """LSTM regressor that maps an input sequence to one output vector.

    The LSTM runs with ``batch_first=True``, so inputs are
    ``(batch, seq_len, input_size)``. Only the hidden output of the last time
    step is kept; it passes through dropout and a linear layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Width of the final linear projection.
        dropout: Dropout rate applied to the last hidden output. It is also
            used between stacked LSTM layers, but only when ``num_layers > 1``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout only has an effect with stacked layers;
            # nn.LSTM warns if it is non-zero for a single layer.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return a ``(batch, output_size)`` tensor for ``x`` of shape ``(batch, seq_len, input_size)``."""
        # Zero initial states are created on x's device and dtype, so the
        # model keeps working after .to(device) / .double(). Plain
        # torch.zeros always allocates CPU float32 tensors.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]  # keep only the last time step's hidden output
        out = self.dropout(out)
        return self.linear(out)
def safe_load_model():
    """Load the trained LSTM, its input scaler and the training window length.

    ``lstm_stock_model.pth`` is first loaded with ``weights_only=True``, which
    uses the restricted unpickler. If that fails, the file is reloaded with
    full unpickling, and a ``'scaler'`` entry in the checkpoint is used if
    present. Presumably this fallback is needed when the checkpoint pickles
    non-tensor objects such as a fitted scaler. If no scaler came from the
    checkpoint, it is read from ``scaler.pkl``.

    SECURITY: both the ``weights_only=False`` fallback and ``pickle.load``
    can execute arbitrary code embedded in the file. Only deploy files you
    trust.

    Returns:
        ``(model, scaler, sequence_length)``: the model is in eval mode, and
        ``sequence_length`` is the checkpoint's ``'sequence_length'`` or 60.

    Raises:
        Exception: if the checkpoint cannot be loaded, or ``scaler.pkl`` is
            missing or unreadable.
    """
    model_path = 'lstm_stock_model.pth'
    scaler = None
    try:
        # Restricted load first: only tensors and plain containers are allowed.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    except Exception:
        try:
            # SECURITY: full unpickling - trusted checkpoint files only.
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
            scaler = checkpoint.get('scaler', None)
        except Exception as e:
            raise Exception(f"Failed to load model: {e}") from e

    model = LSTMModel()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if scaler is None:
        try:
            with open('scaler.pkl', 'rb') as f:
                # SECURITY: pickle.load runs arbitrary code - trusted files only.
                scaler = pickle.load(f)
        except FileNotFoundError as e:
            raise Exception("Scaler not found. Please ensure scaler.pkl exists.") from e
        except Exception as e:
            # A corrupt or incompatible pickle is not the same as a missing file.
            raise Exception(f"Failed to load scaler.pkl: {e}") from e

    sequence_length = checkpoint.get('sequence_length', 60)
    return model, scaler, sequence_length
def predict_stock(ticker="AAPL", days=30):
    """Forecast future closing prices for ``ticker`` and return a plot.

    Steps:
    1. Load the saved LSTM and scaler.
    2. Download two years of daily data with ``yf.download``.
    3. Scale the closing prices.
    4. Roll the model forward autoregressively. Each prediction is appended
       to the input window and used to predict the next step.

    Args:
        ticker: Symbol passed to ``yf.download``; surrounding whitespace is
            stripped.
        days: Number of steps to forecast. Gradio's ``gr.Number`` sends a
            float unless ``precision=0``, so the value is coerced to ``int``.

    Returns:
        The result of ``create_forecast_plot`` on success. On any failure it
        returns the result of ``create_error_plot``; errors are never raised
        to the caller.
    """
    try:
        ticker = str(ticker or "").strip()
        if not ticker:
            return create_error_plot("Please enter a stock ticker")
        # range() below requires an int; Gradio may deliver e.g. 30.0.
        days = int(days)
        if days < 1:
            return create_error_plot("Days to forecast must be at least 1")

        model, scaler, sequence_length = safe_load_model()

        print(f"Fetching data for {ticker}...")
        stock_data = yf.download(ticker, period="2y", interval="1d")
        if stock_data.empty:
            return create_error_plot(f"No data found for {ticker}")

        close = stock_data['Close']
        # 'Close' can be a DataFrame (e.g. MultiIndex columns). More than one
        # column means several tickers were requested; flattening them with
        # reshape(-1, 1) would interleave unrelated prices.
        if isinstance(close, pd.DataFrame):
            if close.shape[1] != 1:
                return create_error_plot(
                    f"Expected one ticker, got {close.shape[1]} price series")
            close = close.iloc[:, 0]
        # Missing closes would propagate NaN through the whole forecast.
        closing_prices = close.dropna().to_numpy(dtype=float).reshape(-1, 1)
        if closing_prices.size == 0:
            return create_error_plot(f"No closing prices found for {ticker}")

        scaled_data = scaler.transform(closing_prices)
        last_sequence = _last_window(scaled_data, sequence_length)
        predictions = _autoregressive_forecast(model, last_sequence, days)
        predictions_original = scaler.inverse_transform(predictions.reshape(-1, 1)).flatten()

        return create_forecast_plot(stock_data, predictions_original, ticker, days)
    except Exception as e:
        print(f"Error in prediction: {e}")
        return create_error_plot(str(e))


def _last_window(scaled_data, sequence_length):
    """Return the last ``sequence_length`` rows as a ``(1, sequence_length, 1)`` array.

    Short histories are left-padded with the first value.
    """
    if len(scaled_data) >= sequence_length:
        window = scaled_data[-sequence_length:]
    else:
        padding = np.full((sequence_length - len(scaled_data), 1), scaled_data[0, 0])
        window = np.vstack([padding, scaled_data])
    return window.reshape(1, sequence_length, 1)


def _autoregressive_forecast(model, last_sequence, days):
    """Run ``model`` for ``days`` steps, feeding each output back as the newest input.

    Returns the predictions as a 1-D numpy array.
    """
    current = torch.as_tensor(last_sequence, dtype=torch.float32)
    predictions = []
    with torch.no_grad():
        for _ in range(days):
            next_pred = model(current)
            predictions.append(next_pred.item())
            # Slide the window: drop the oldest step, append the prediction.
            current = torch.cat([current[:, 1:, :], next_pred.reshape(1, 1, 1)], dim=1)
    return np.asarray(predictions)
def create_forecast_plot(stock_data, predictions, ticker, days):
    """Plot up to the last 100 closing prices followed by the forecast.

    Forecast points are placed on consecutive calendar days after the last
    date in ``stock_data``.

    Args:
        stock_data: Date-indexed frame with a ``'Close'`` column.
        predictions: Forecast prices, one per step (expected length ``days``).
        ticker: Symbol shown in the title.
        days: Number of forecast steps.

    Returns:
        A matplotlib ``Figure``. It is closed in pyplot before being returned,
        so repeated calls do not pile up open figures. The figure object
        itself can still be saved or rendered.
    """
    last_date = stock_data.index[-1]
    forecast_dates = [last_date + timedelta(days=i + 1) for i in range(days)]

    # Draw on an explicit figure instead of pyplot's global "current figure".
    fig, ax = plt.subplots(figsize=(12, 6))

    historical_days = min(100, len(stock_data))
    recent = stock_data.iloc[-historical_days:]
    ax.plot(recent.index, recent['Close'],
            label='Historical Prices', color='blue', linewidth=2)
    ax.plot(forecast_dates, predictions,
            label='Forecast', color='red', linewidth=2, linestyle='--', marker='o')

    ax.set_title(f'{ticker} Stock Price Forecast - Next {days} Days')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (USD)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()

    plt.close(fig)  # release pyplot's reference; otherwise every request leaks a figure
    return fig
def create_error_plot(error_message):
    """Return a figure that shows ``error_message`` in a highlighted box instead of a chart.

    The figure is closed in pyplot before being returned, so repeated errors
    do not accumulate open figures. The figure itself can still be saved or
    rendered.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.text(0.5, 0.5, f'Error: {error_message}',
            ha='center', va='center', transform=ax.transAxes,
            fontsize=12, wrap=True,  # long exception texts stay inside the figure
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
    ax.set_title('Prediction Error')
    ax.axis('off')
    plt.close(fig)
    return fig
# Create Gradio interface: ticker text box + forecast horizon -> forecast plot.
iface = gr.Interface(
    fn=predict_stock,
    inputs=[
        gr.Textbox(value="AAPL", label="Stock Ticker"),
        # precision=0 makes Gradio pass an int. With the default precision the
        # value arrives as a float, which range(days) rejects.
        gr.Number(value=30, label="Days to Forecast", minimum=1, maximum=365, precision=0),
    ],
    outputs=gr.Plot(label="Forecast Results"),
    title="Stock Price Forecaster (PyTorch LSTM)",
    description="Predict future stock prices using LSTM neural networks.",
)

if __name__ == "__main__":
    iface.launch()