Skip to content

Commit 70361e4

Browse files
First Commit
0 parents  commit 70361e4

File tree

8 files changed

+253
-0
lines changed

8 files changed

+253
-0
lines changed

Diff for: .gitignore

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Python
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# Virtual Environment
7+
venv/
8+
env/
9+
10+
# Streamlit
11+
.streamlit/
12+
13+
# Model files
14+
models/*.pkl
15+
16+
# OS files
17+
.DS_Store
18+
Thumbs.db
19+
20+
# IDE files
21+
.vscode/
22+
.idea/

Diff for: README.md

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Stock Market Prediction App
2+
3+
This Streamlit app uses a Random Forest model to predict stock prices based on historical data.
4+
5+
## Features
6+
7+
- Fetch and display historical stock data
8+
- Train a Random Forest model for price prediction
9+
- Visualize stock price trends and model performance
10+
- Make predictions based on user input
11+
12+
## Installation
13+
14+
1. Clone this repository:
15+
```
16+
git clone https://github.com/agsurajuthaliyan/stock-market-prediction.git
17+
cd stock-market-prediction
18+
```
19+
20+
2. Install the required packages:
21+
```
22+
pip install -r requirements.txt
23+
```
24+
25+
## Usage
26+
27+
Run the Streamlit app:
28+
```
29+
streamlit run app/main.py
30+
```
31+
32+
Navigate to the provided local URL in your web browser to use the app.

Diff for: app/__innit__.py

Whitespace-only changes.

Diff for: app/data_loader.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import yfinance as yf
2+
import pandas as pd
3+
import streamlit as st
4+
5+
class _EmptyHistoryError(Exception):
    """Raised by the cached fetch so that empty results are never cached."""


@st.cache_data(ttl=86400)  # Cache successful downloads for 24 hours
def _fetch_history(ticker, start_date, end_date):
    """Download price history and return it with the index reset to a column.

    Raises instead of returning an empty frame: Streamlit only caches values
    that are returned, so a failed or empty download is retried on the next
    rerun rather than being served from cache for the whole TTL.
    """
    stock_data = yf.download(ticker, start=start_date, end=end_date)
    if stock_data.empty:
        raise _EmptyHistoryError(ticker)
    return stock_data.reset_index()


def load_data(ticker, start_date, end_date):
    """Return historical price data for *ticker* between the two dates.

    Args:
        ticker: Symbol passed to ``yf.download``.
        start_date: Start of the requested range.
        end_date: End of the requested range.

    Returns:
        The downloaded data with its index reset to a regular column, or an
        empty ``DataFrame`` when nothing could be downloaded. Download errors
        are reported to the user with ``st.error`` instead of being raised.
    """
    try:
        return _fetch_history(ticker, start_date, end_date)
    except _EmptyHistoryError:
        # Same outcome as before for "no rows": an empty frame, no message;
        # the caller decides how to report it.
        return pd.DataFrame()
    except Exception as e:
        st.error(f"Error fetching data: {str(e)}")
        return pd.DataFrame()
14+
15+
def validate_date(date_str):
    """Parse *date_str* with ``pd.to_datetime``.

    Returns:
        The parsed value, or ``None`` after showing an error in the app when
        parsing raises ``ValueError``.
    """
    try:
        parsed = pd.to_datetime(date_str)
    except ValueError:
        st.error("Invalid date format. Please use YYYY-MM-DD.")
        return None
    else:
        return parsed

Diff for: app/main.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import streamlit as st
2+
import pandas as pd
3+
from data_loader import load_data, validate_date
4+
from model import load_model, predict_price, train_and_save_model, evaluate_model, train_test_split
5+
from utils import plot_stock_price, plot_model_performance, validate_ticker, validate_numeric_input
6+
7+
# Set page config
st.set_page_config(
    page_title="Stock Market Prediction",
    page_icon=":chart_with_upwards_trend:",
    layout="wide"
)

# Title of the app
st.title("Stock Market Prediction App📊")
st.subheader("Using Random Forest🌳")

# Sidebar: Stock selection and date range
with st.sidebar:
    st.header("Stock Selection")
    stock_ticker = validate_ticker(st.text_input("Enter Stock Ticker Symbol", value='NVDA'))
    start_date = st.date_input("Start Date", value=pd.to_datetime("2022-01-01"))
    end_date = st.date_input("End Date", value=pd.to_datetime("2024-09-01"))

# validate_ticker has already reported the problem and returned ""; do not
# send an empty symbol to the downloader or show a second, misleading error.
if not stock_ticker:
    st.stop()

# An empty or inverted range would otherwise surface as "check the ticker".
if start_date >= end_date:
    st.error("Start Date must be earlier than End Date.")
    st.stop()

# Load and display data
with st.spinner("Fetching stock data..."):
    hist = load_data(stock_ticker, start_date, end_date)

if not hist.empty:
    st.success("Data successfully loaded!")
    st.write(f"Displaying data for: **{stock_ticker}**")

    # Display stock price chart
    fig = plot_stock_price(hist, stock_ticker)
    st.plotly_chart(fig)

    # Display historical data
    st.write("**Filtered Historical Data** (sorted by Date)")
    st.dataframe(hist.sort_values(by='Date'))

    # Model training and evaluation: every remaining column is a feature,
    # the closing price is the target.
    X = hist.drop(columns=['Date', 'Close', 'Adj Close'])
    y = hist['Close']

    regressor = load_model(stock_ticker)
    if regressor is None:
        regressor, X_test, y_test = train_and_save_model(X, y, stock_ticker)
    else:
        # If the model is loaded, we need to create X_test and y_test for evaluation.
        # NOTE(review): a saved model may have been trained on a different date
        # range, so this split is not guaranteed to be unseen data for it.
        _, X_test, _, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Evaluate the model
    mse, rmse, mae, y_pred = evaluate_model(regressor, X_test, y_test)

    # Display model performance metrics
    st.subheader("Model Performance Metrics")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Mean Squared Error (MSE)", f"{mse:.4f}")
    with col2:
        st.metric("Root Mean Squared Error (RMSE)", f"{rmse:.4f}")
    with col3:
        st.metric("Mean Absolute Error (MAE)", f"{mae:.4f}")

    # Plot model performance
    plot_model_performance(y_test, y_pred)

    # Prediction inputs. Only the raw values are collected here: the widgets
    # start at 0, so validating on every rerun would show four errors before
    # the user has typed anything.
    with st.sidebar:
        st.header("Prediction Inputs")
        open_input = st.number_input("Open Price", min_value=0.0, step=0.1)
        high_input = st.number_input("High Price", min_value=0.0, step=0.1)
        low_input = st.number_input("Low Price", min_value=0.0, step=0.1)
        volume_input = st.number_input("Volume", min_value=0, step=1)

    # Predict button
    if st.sidebar.button("Predict Closing Price"):
        # Validate on demand; per-field errors still appear in the sidebar.
        with st.sidebar:
            open_price = validate_numeric_input(open_input, "Open Price")
            high_price = validate_numeric_input(high_input, "High Price")
            low_price = validate_numeric_input(low_input, "Low Price")
            volume = validate_numeric_input(volume_input, "Volume")

        if all([open_price, high_price, low_price, volume]):
            prediction = predict_price(regressor, open_price, high_price, low_price, volume)
            st.subheader(f"Predicted Closing Price for {stock_ticker}: {prediction:.2f}")

            # Model performance over the full data set (training rows included)
            y_pred = regressor.predict(X)
            plot_model_performance(y, y_pred)
        else:
            st.error("Please enter valid values for all inputs.")
else:
    st.error(f"Unable to load data for {stock_ticker}. Please check the ticker symbol.")

Diff for: app/model.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import joblib
3+
import streamlit as st
4+
from sklearn.ensemble import RandomForestRegressor
5+
from sklearn.model_selection import train_test_split
6+
from sklearn.metrics import mean_squared_error, mean_absolute_error
7+
import pandas as pd
8+
import numpy as np
9+
10+
def save_model(stock_ticker, model):
    """Persist *model* to ``models/<stock_ticker>_model.pkl``.

    Args:
        stock_ticker: Symbol used to build the file name.
        model: Object to serialize with ``joblib.dump``.

    The ``models`` directory is relative to the current working directory and
    is created when missing.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('models', exist_ok=True)
    filename = f"models/{stock_ticker}_model.pkl"
    joblib.dump(model, filename)
15+
16+
def load_model(stock_ticker):
    """Return the model saved for *stock_ticker*, or ``None`` if no file exists.

    The file is looked up at ``models/<stock_ticker>_model.pkl`` relative to
    the current working directory.
    """
    model_path = f"models/{stock_ticker}_model.pkl"
    if not os.path.exists(model_path):
        return None
    # NOTE(review): joblib.load unpickles the file; only load files this app wrote.
    return joblib.load(model_path)
22+
23+
def train_and_save_model(X, y, stock_ticker):
    """Train a random forest on an 80/20 split of the data and save it.

    Args:
        X: Feature data passed to ``train_test_split``.
        y: Target values passed to ``train_test_split``.
        stock_ticker: Symbol under which the fitted model is saved.

    Returns:
        A ``(model, X_test, y_test)`` tuple: the fitted regressor and the
        held-out 20% of the data.
    """
    # Fixed seeds keep the split and the forest reproducible between runs.
    features_train, features_test, target_train, target_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    with st.spinner("Training new model..."):
        forest.fit(features_train, target_train)

    save_model(stock_ticker, forest)
    st.success(f"Model training completed and saved for {stock_ticker}!")

    return forest, features_test, target_test
34+
35+
def predict_price(model, open_price, high_price, low_price, volume):
    """Predict one value from a single row of Open/High/Low/Volume inputs.

    Args:
        model: Object exposing ``predict``; it receives a one-row DataFrame
            with the columns ``Open``, ``High``, ``Low`` and ``Volume``.
        open_price: Value for the ``Open`` column.
        high_price: Value for the ``High`` column.
        low_price: Value for the ``Low`` column.
        volume: Value for the ``Volume`` column.

    Returns:
        The first element of whatever ``model.predict`` returns.
    """
    column_names = ('Open', 'High', 'Low', 'Volume')
    row_values = (open_price, high_price, low_price, volume)
    single_row = pd.DataFrame(
        {column: [value] for column, value in zip(column_names, row_values)}
    )
    return model.predict(single_row)[0]
44+
45+
def evaluate_model(model, X_test, y_test):
    """Score *model* on the given test data.

    Returns:
        A ``(mse, rmse, mae, y_pred)`` tuple, where ``rmse`` is the square
        root of ``mse`` and ``y_pred`` is the output of ``model.predict``.
    """
    predictions = model.predict(X_test)
    squared_error = mean_squared_error(y_test, predictions)
    absolute_error = mean_absolute_error(y_test, predictions)
    return squared_error, np.sqrt(squared_error), absolute_error, predictions

Diff for: app/utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import streamlit as st
2+
import plotly.express as px
3+
import plotly.graph_objects as go
4+
5+
def plot_stock_price(data, stock_ticker):
    """Build a line chart of the ``Close`` column against ``Date``.

    Args:
        data: Table with ``Date`` and ``Close`` columns.
        stock_ticker: Symbol shown in the chart title.

    Returns:
        The Plotly figure; the x axis spans the first to the last date in
        *data*.
    """
    dates = data['Date']
    chart_options = dict(
        x='Date',
        y='Close',
        title=f"{stock_ticker} Closing Price Over Time",
        range_x=[dates.min(), dates.max()],
    )
    return px.line(data, **chart_options)
14+
15+
def plot_model_performance(y_true, y_pred):
    """Render an actual-versus-predicted scatter plot in the app.

    A dashed red diagonal from the smallest to the largest actual value marks
    where perfect predictions would fall.
    """
    lowest, highest = y_true.min(), y_true.max()

    predicted_points = go.Scatter(x=y_true, y=y_pred, mode='markers', name='Predictions')
    ideal_line = go.Scatter(
        x=[lowest, highest],
        y=[lowest, highest],
        mode='lines',
        name='Ideal Prediction',
        line=dict(color='red', dash='dash'),
    )

    fig = go.Figure()
    fig.add_trace(predicted_points)
    fig.add_trace(ideal_line)
    fig.update_layout(
        title='Actual vs Predicted Closing Prices',
        xaxis_title='Actual Price',
        yaxis_title='Predicted Price',
    )
    st.plotly_chart(fig)
24+
25+
def validate_ticker(ticker):
    """Normalize and validate a ticker symbol typed by the user.

    Surrounding whitespace is removed and the symbol is upper-cased. Besides
    ASCII letters and digits, the characters ``.``, ``-``, ``^`` and ``=`` are
    accepted, so symbols such as ``BRK-B``, ``RELIANCE.NS`` or ``^GSPC`` are
    no longer rejected.

    Args:
        ticker: Raw text from the input widget (may be empty or ``None``).

    Returns:
        The normalized symbol, or ``""`` after showing an error in the app
        when the input is not a plausible symbol.
    """
    symbol = ticker.strip() if ticker else ""
    extra_chars = ".-^="
    is_valid = (
        bool(symbol)
        and symbol.isascii()
        and all(ch.isalnum() or ch in extra_chars for ch in symbol)
        # Punctuation alone (e.g. "..") is not a symbol.
        and any(ch.isalnum() for ch in symbol)
    )
    if not is_valid:
        st.error("Invalid ticker. Please enter a valid stock symbol.")
        return ""
    return symbol.upper()
30+
31+
def validate_numeric_input(value, name):
    """Accept *value* only when it is not zero or negative.

    Args:
        value: Number to check.
        name: Field label used in the error message.

    Returns:
        *value* unchanged, or ``None`` after showing an error in the app when
        ``value <= 0``.
    """
    is_rejected = value <= 0
    if is_rejected:
        st.error(f"Invalid {name}. Please enter a positive number.")
        return None
    return value

Diff for: requirements.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
streamlit==1.24.0
2+
pandas==1.5.3
3+
yfinance==0.2.18
4+
scikit-learn==1.2.2
5+
plotly==5.14.1
6+
joblib==1.2.0

0 commit comments

Comments
 (0)