Parthiban97 committed on
Commit
58e5d1f
·
verified ·
1 Parent(s): 96090ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -246
app.py CHANGED
@@ -1,246 +1,256 @@
1
- import streamlit as st
2
- import os
3
- from langchain_core.messages import AIMessage, HumanMessage
4
- from langchain_core.prompts import ChatPromptTemplate
5
- from langchain_core.runnables import RunnablePassthrough
6
- from langchain_community.utilities import SQLDatabase
7
- from langchain_core.output_parsers import StrOutputParser
8
- from langchain_openai import ChatOpenAI
9
- from langchain_groq import ChatGroq
10
- import toml
11
-
12
- # Function to update secrets.toml file
13
- def update_secrets_file(data):
14
- secrets_file_path = ".streamlit/secrets.toml"
15
- if os.path.exists(secrets_file_path):
16
- with open(secrets_file_path, "r") as file:
17
- secrets_data = toml.load(file)
18
- else:
19
- secrets_data = {}
20
-
21
- secrets_data.update(data)
22
-
23
- with open(secrets_file_path, "w") as file:
24
- toml.dump(secrets_data, file)
25
-
26
- # Initialize database connections
27
- def init_databases():
28
- secrets_file_path = ".streamlit/secrets.toml"
29
- with open(secrets_file_path, "r") as file:
30
- secrets_data = toml.load(file)
31
-
32
- db_connections = {}
33
- for database in secrets_data["Databases"].split(','):
34
- database = database.strip()
35
- db_uri = f"mysql+mysqlconnector://{secrets_data['User']}:{secrets_data['Password']}@{secrets_data['Host']}:{secrets_data['Port']}/{database}"
36
- db_connections[database] = SQLDatabase.from_uri(db_uri)
37
- return db_connections
38
-
39
- # Function to get SQL chain
40
- def get_sql_chain(dbs, llm):
41
- template = """
42
- You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
43
- You are interacting with a user who is asking you questions about the company's databases.
44
- Based on the table schemas below, write SQL queries that would answer the user's question. Take the conversation history into account.
45
-
46
- <SCHEMAS>{schemas}</SCHEMAS>
47
-
48
- Conversation History: {chat_history}
49
-
50
- Write the SQL queries for each relevant database, prefixed by the database name (e.g., DB1: SELECT * FROM ...; DB2: SELECT * FROM ...).
51
- Do not wrap the SQL queries in any other text, not even backticks.
52
-
53
- For example:
54
- Question: which 3 artists have the most tracks?
55
- SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
56
- Question: Name 10 artists
57
- SQL Query: SELECT Name FROM Artist LIMIT 10;
58
- Question: How much is the price of the inventory for all small size t-shirts?
59
- SQL Query: SELECT SUM(price * stock_quantity) FROM t_shirts WHERE size = 'S';
60
- Question: If we have to sell all the Levi's T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?
61
- SQL Query: SELECT SUM(a.total_amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) AS total_revenue
62
- FROM (SELECT SUM(price * stock_quantity) AS total_amount, t_shirt_id
63
- FROM t_shirts
64
- WHERE brand = 'Levi' GROUP BY t_shirt_id) a
65
- LEFT JOIN discounts ON a.t_shirt_id = discounts.t_shirt_id;
66
- Question: For each brand, find the total revenue generated from t-shirts with a discount applied, grouped by the discount percentage.
67
- SQL Query: SELECT brand, COALESCE(discounts.pct_discount, 0) AS discount_pct, SUM(t.price * t.stock_quantity * (1 - COALESCE(discounts.pct_discount, 0) / 100)) AS total_revenue
68
- FROM t_shirts t
69
- LEFT JOIN discounts ON t.t_shirt_id = discounts.t_shirt_id
70
- GROUP BY brand, COALESCE(discounts.pct_discount, 0);
71
- Question: Find the top 3 most popular colors for each brand, based on the total stock quantity.
72
- SQL Query: SELECT brand, color, SUM(stock_quantity) AS total_stock
73
- FROM t_shirts
74
- GROUP BY brand, color
75
- ORDER BY brand, total_stock DESC;
76
-
77
- Question: Calculate the average price per size for each brand, excluding sizes with less than 10 t-shirts in stock.
78
- SQL Query: SELECT brand, size, AVG(price) AS avg_price
79
- FROM t_shirts
80
- WHERE stock_quantity >= 10
81
- GROUP BY brand, size
82
- HAVING COUNT(*) >= 10;
83
-
84
- Question: Find the brand and color combination with the highest total revenue, considering discounts.
85
- SQL Query: SELECT brand, color, SUM(t.price * t.stock_quantity * (1 - COALESCE(d.pct_discount, 0) / 100)) AS total_revenue
86
- FROM t_shirts t
87
- LEFT JOIN discounts d ON t.t_shirt_id = d.t_shirt_id
88
- GROUP BY brand, color
89
- ORDER BY total_revenue DESC
90
- LIMIT 1;
91
-
92
- Question: Create a view that shows the total stock quantity and revenue for each brand, size, and color combination.
93
- SQL Query: CREATE VIEW brand_size_color_stats AS
94
- SELECT brand, size, color, SUM(stock_quantity) AS total_stock, SUM(price * stock_quantity) AS total_revenue
95
- FROM t_shirts
96
- GROUP BY brand, size, color;
97
-
98
- Question: How much is the price of the inventory for all varients t-shirts and group them y brands?
99
- SQL Query: SELECT brand, SUM(price * stock_quantity) FROM t_shirts GROUP BY brand;
100
-
101
- Question: List the total revenue of t-shirts of L size for all brands
102
- SQL Query: SELECT brand, SUM(price * stock_quantity) AS total_revenue FROM t_shirts WHERE size = 'L' GROUP BY brand;
103
-
104
- Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
105
- SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
106
-
107
- Your turn:
108
-
109
- Question: {question}
110
- SQL Queries:
111
- """
112
-
113
- prompt = ChatPromptTemplate.from_template(template)
114
- llm = llm
115
-
116
- def get_schema(_):
117
- schemas = {db_name: db.get_table_info() for db_name, db in dbs.items()}
118
- return schemas
119
-
120
- return (
121
- RunnablePassthrough.assign(schemas=get_schema)
122
- | prompt
123
- | llm
124
- | StrOutputParser()
125
- | (lambda result: {line.split(":")[0]: line.split(":")[1].strip() for line in result.strip().split("\n") if ":" in line and line.strip()})
126
- )
127
-
128
- # Function to get response
129
- def get_response(user_query, dbs, chat_history, llm):
130
- sql_chain = get_sql_chain(dbs, llm)
131
-
132
- template = """
133
- You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
134
- You are interacting with a user who is asking you questions about the company's databases.
135
- Based on the table schemas below, question, sql queries, and sql responses, write an
136
- accurate natural language response so that the end user can understand things
137
- and make sure do not include words like "Based on the SQL queries I ran".
138
- Just provide only the answer with some text that the user expects.
139
- <SCHEMAS>{schemas}</SCHEMAS>
140
- Conversation History: {chat_history}
141
- SQL Queries: <SQL>{queries}</SQL>
142
- User question: {question}
143
- SQL Responses: {responses}"""
144
-
145
- prompt = ChatPromptTemplate.from_template(template)
146
- llm = llm
147
-
148
- def run_queries(var):
149
- responses = {}
150
- for db_name, query in var["queries"].items():
151
- responses[db_name] = dbs[db_name].run(query)
152
- return responses
153
-
154
- chain = (
155
- RunnablePassthrough.assign(queries=sql_chain).assign(
156
- schemas=lambda _: {db_name: db.get_table_info() for db_name, db in dbs.items()},
157
- responses=run_queries) # The comma at the end of the assign() method call is used to indicate that there may be more keyword arguments or method calls following it
158
- | prompt
159
- | llm
160
- | StrOutputParser()
161
- )
162
-
163
- return chain.invoke({
164
- "question": user_query,
165
- "chat_history": chat_history,
166
- })
167
-
168
- # Streamlit app configuration
169
- if "chat_history" not in st.session_state:
170
- st.session_state.chat_history = [
171
- AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
172
- ]
173
-
174
- st.set_page_config(page_title="Chat with MySQL", page_icon="🛢️")
175
- st.title("Chat with MySQL")
176
-
177
- with st.sidebar:
178
- st.subheader("Settings")
179
- st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
180
-
181
- if "db" not in st.session_state:
182
- st.session_state.Host = st.text_input("Host", value=st.secrets.get("Host", ""))
183
- st.session_state.Port = st.text_input("Port", value=st.secrets.get("Port", ""))
184
- st.session_state.User = st.text_input("User", value=st.secrets.get("User", ""))
185
- st.session_state.Password = st.text_input("Password", type="password", value=st.secrets.get("Password", ""))
186
- st.session_state.Databases = st.text_input("Databases", placeholder="Enter DB's separated by (,)", value=st.secrets.get("Databases", ""))
187
- st.session_state.openai_api_key = st.text_input("OpenAI API Key", type="password", help="Get your API key from [OpenAI Website](https://platform.openai.com/api-keys)", value=st.secrets.get("openai_api_key", ""))
188
- st.session_state.groq_api_key = st.text_input("Groq API Key", type="password", help="Get your API key from [GROQ Console](https://console.groq.com/keys)", value=st.secrets.get("groq_api_key", ""))
189
-
190
- st.info("Note: For interacting multiple databases, GPT-4 Model is recommended for accurate results else proceed with Groq Model")
191
-
192
- os.environ["OPENAI_API_KEY"] = str(st.session_state.openai_api_key)
193
-
194
- if st.button("Connect"):
195
- with st.spinner("Connecting to databases..."):
196
-
197
- # Update secrets.toml with connection details
198
- update_secrets_file({
199
- "Host": st.session_state.Host,
200
- "Port": st.session_state.Port,
201
- "User": st.session_state.User,
202
- "Password": st.session_state.Password,
203
- "Databases": st.session_state.Databases
204
- })
205
-
206
- dbs = init_databases()
207
- st.session_state.dbs = dbs
208
-
209
- if len(dbs) > 1:
210
- st.success(f"Connected to {len(dbs)} databases")
211
- else:
212
- st.success("Connected to database")
213
-
214
-
215
-
216
- if st.session_state.openai_api_key == "" and st.session_state.groq_api_key == "":
217
- st.error("Enter one API Key At least")
218
- elif st.session_state.openai_api_key:
219
- st.session_state.llm = ChatOpenAI(model="gpt-4-turbo", api_key=st.session_state.openai_api_key)
220
- elif st.session_state.groq_api_key:
221
- st.session_state.llm = ChatGroq(model="llama3-70b-8192", temperature=0.4, api_key=st.session_state.groq_api_key)
222
- else:
223
- pass
224
-
225
- # Display chat messages
226
- for message in st.session_state.chat_history:
227
- if isinstance(message, AIMessage):
228
- with st.chat_message("AI"):
229
- st.markdown(message.content)
230
- elif isinstance(message, HumanMessage):
231
- with st.chat_message("Human"):
232
- st.markdown(message.content)
233
-
234
- # Handle user input
235
- user_query = st.chat_input("Type a message...")
236
- if user_query is not None and user_query.strip() != "":
237
- st.session_state.chat_history.append(HumanMessage(content=user_query))
238
-
239
- with st.chat_message("Human"):
240
- st.markdown(user_query)
241
-
242
- with st.chat_message("AI"):
243
- response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
244
- st.markdown(response)
245
-
246
- st.session_state.chat_history.append(AIMessage(content=response))
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from langchain_core.messages import AIMessage, HumanMessage
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from langchain_community.utilities import SQLDatabase
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_groq import ChatGroq
10
+ import toml
11
+
12
+ # Function to update secrets.toml file
13
+ def update_secrets_file(data):
14
+ secrets_file_path = ".streamlit/secrets.toml"
15
+ secrets_data = {}
16
+
17
+ # Load existing data from secrets.toml
18
+ if os.path.exists(secrets_file_path):
19
+ with open(secrets_file_path, "r") as file:
20
+ secrets_data = toml.load(file)
21
+
22
+ # Update secrets data with new data
23
+ secrets_data.update(data)
24
+
25
+ # Write updated data back to secrets.toml
26
+ with open(secrets_file_path, "w+") as file:
27
+ toml.dump(secrets_data, file)
28
+
29
+
30
+ # Initialize database connections
31
+ def init_databases():
32
+ secrets_file_path = ".streamlit/secrets.toml"
33
+ secrets_data = {}
34
+ if os.path.exists(secrets_file_path):
35
+ with open(secrets_file_path, "r") as file:
36
+ content = file.read().strip()
37
+ if content:
38
+ secrets_data = toml.loads(content)
39
+
40
+ db_connections = {}
41
+ for database in secrets_data.get("Databases", "").split(','):
42
+ database = database.strip()
43
+ if database:
44
+ db_uri = f"mysql+mysqlconnector://{secrets_data['User']}:{secrets_data['Password']}@{secrets_data['Host']}:{secrets_data['Port']}/{database}"
45
+ db_connections[database] = SQLDatabase.from_uri(db_uri)
46
+ return db_connections
47
+
48
+
49
+ # Function to get SQL chain
50
+ def get_sql_chain(dbs, llm):
51
+ template = """
52
+ You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
53
+ You are interacting with a user who is asking you questions about the company's databases.
54
+ Based on the table schemas below, write SQL queries that would answer the user's question. Take the conversation history into account.
55
+
56
+ <SCHEMAS>{schemas}</SCHEMAS>
57
+
58
+ Conversation History: {chat_history}
59
+
60
+ Write the SQL queries for each relevant database, prefixed by the database name (e.g., DB1: SELECT * FROM ...; DB2: SELECT * FROM ...).
61
+ Do not wrap the SQL queries in any other text, not even backticks.
62
+
63
+ For example:
64
+ Question: which 3 artists have the most tracks?
65
+ SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
66
+ Question: Name 10 artists
67
+ SQL Query: SELECT Name FROM Artist LIMIT 10;
68
+ Question: How much is the price of the inventory for all small size t-shirts?
69
+ SQL Query: SELECT SUM(price * stock_quantity) FROM t_shirts WHERE size = 'S';
70
+ Question: If we have to sell all the Levi's T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?
71
+ SQL Query: SELECT SUM(a.total_amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) AS total_revenue
72
+ FROM (SELECT SUM(price * stock_quantity) AS total_amount, t_shirt_id
73
+ FROM t_shirts
74
+ WHERE brand = 'Levi' GROUP BY t_shirt_id) a
75
+ LEFT JOIN discounts ON a.t_shirt_id = discounts.t_shirt_id;
76
+ Question: For each brand, find the total revenue generated from t-shirts with a discount applied, grouped by the discount percentage.
77
+ SQL Query: SELECT brand, COALESCE(discounts.pct_discount, 0) AS discount_pct, SUM(t.price * t.stock_quantity * (1 - COALESCE(discounts.pct_discount, 0) / 100)) AS total_revenue
78
+ FROM t_shirts t
79
+ LEFT JOIN discounts ON t.t_shirt_id = discounts.t_shirt_id
80
+ GROUP BY brand, COALESCE(discounts.pct_discount, 0);
81
+ Question: Find the top 3 most popular colors for each brand, based on the total stock quantity.
82
+ SQL Query: SELECT brand, color, SUM(stock_quantity) AS total_stock
83
+ FROM t_shirts
84
+ GROUP BY brand, color
85
+ ORDER BY brand, total_stock DESC;
86
+
87
+ Question: Calculate the average price per size for each brand, excluding sizes with less than 10 t-shirts in stock.
88
+ SQL Query: SELECT brand, size, AVG(price) AS avg_price
89
+ FROM t_shirts
90
+ WHERE stock_quantity >= 10
91
+ GROUP BY brand, size
92
+ HAVING COUNT(*) >= 10;
93
+
94
+ Question: Find the brand and color combination with the highest total revenue, considering discounts.
95
+ SQL Query: SELECT brand, color, SUM(t.price * t.stock_quantity * (1 - COALESCE(d.pct_discount, 0) / 100)) AS total_revenue
96
+ FROM t_shirts t
97
+ LEFT JOIN discounts d ON t.t_shirt_id = d.t_shirt_id
98
+ GROUP BY brand, color
99
+ ORDER BY total_revenue DESC
100
+ LIMIT 1;
101
+
102
+ Question: Create a view that shows the total stock quantity and revenue for each brand, size, and color combination.
103
+ SQL Query: CREATE VIEW brand_size_color_stats AS
104
+ SELECT brand, size, color, SUM(stock_quantity) AS total_stock, SUM(price * stock_quantity) AS total_revenue
105
+ FROM t_shirts
106
+ GROUP BY brand, size, color;
107
+
108
+ Question: How much is the price of the inventory for all varients t-shirts and group them y brands?
109
+ SQL Query: SELECT brand, SUM(price * stock_quantity) FROM t_shirts GROUP BY brand;
110
+
111
+ Question: List the total revenue of t-shirts of L size for all brands
112
+ SQL Query: SELECT brand, SUM(price * stock_quantity) AS total_revenue FROM t_shirts WHERE size = 'L' GROUP BY brand;
113
+
114
+ Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
115
+ SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
116
+
117
+ Your turn:
118
+
119
+ Question: {question}
120
+ SQL Queries:
121
+ """
122
+
123
+ prompt = ChatPromptTemplate.from_template(template)
124
+ llm = llm
125
+
126
+ def get_schema(_):
127
+ schemas = {db_name: db.get_table_info() for db_name, db in dbs.items()}
128
+ return schemas
129
+
130
+ return (
131
+ RunnablePassthrough.assign(schemas=get_schema)
132
+ | prompt
133
+ | llm
134
+ | StrOutputParser()
135
+ | (lambda result: {line.split(":")[0]: line.split(":")[1].strip() for line in result.strip().split("\n") if ":" in line and line.strip()})
136
+ )
137
+
138
+ # Function to get response
139
+ def get_response(user_query, dbs, chat_history, llm):
140
+ sql_chain = get_sql_chain(dbs, llm)
141
+
142
+ template = """
143
+ You are a Senior and vastly experienced Data analyst at a company with around 20 years of experience.
144
+ You are interacting with a user who is asking you questions about the company's databases.
145
+ Based on the table schemas below, question, sql queries, and sql responses, write an
146
+ accurate natural language response so that the end user can understand things
147
+ and make sure do not include words like "Based on the SQL queries I ran".
148
+ Just provide only the answer with some text that the user expects.
149
+ <SCHEMAS>{schemas}</SCHEMAS>
150
+ Conversation History: {chat_history}
151
+ SQL Queries: <SQL>{queries}</SQL>
152
+ User question: {question}
153
+ SQL Responses: {responses}"""
154
+
155
+ prompt = ChatPromptTemplate.from_template(template)
156
+ llm = llm
157
+
158
+ def run_queries(var):
159
+ responses = {}
160
+ for db_name, query in var["queries"].items():
161
+ responses[db_name] = dbs[db_name].run(query)
162
+ return responses
163
+
164
+ chain = (
165
+ RunnablePassthrough.assign(queries=sql_chain).assign(
166
+ schemas=lambda _: {db_name: db.get_table_info() for db_name, db in dbs.items()},
167
+ responses=run_queries) # The comma at the end of the assign() method call is used to indicate that there may be more keyword arguments or method calls following it
168
+ | prompt
169
+ | llm
170
+ | StrOutputParser()
171
+ )
172
+
173
+ return chain.invoke({
174
+ "question": user_query,
175
+ "chat_history": chat_history,
176
+ })
177
+
178
+ # Streamlit app configuration
179
+ if "chat_history" not in st.session_state:
180
+ st.session_state.chat_history = [
181
+ AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database."),
182
+ ]
183
+
184
+ st.set_page_config(page_title="Chat with MySQL", page_icon="🛢️")
185
+ st.title("Chat with MySQL")
186
+
187
+ with st.sidebar:
188
+ st.subheader("Settings")
189
+ st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
190
+
191
+ if "db" not in st.session_state:
192
+ st.session_state.Host = st.text_input("Host")
193
+ st.session_state.Port = st.text_input("Port")
194
+ st.session_state.User = st.text_input("User")
195
+ st.session_state.Password = st.text_input("Password", type="password")
196
+ st.session_state.Databases = st.text_input("Databases", placeholder="Enter DB's separated by (,)")
197
+ st.session_state.openai_api_key = st.text_input("OpenAI API Key", type="password", help="Get your API key from [OpenAI Website](https://platform.openai.com/api-keys)")
198
+ st.session_state.groq_api_key = st.text_input("Groq API Key", type="password", help="Get your API key from [GROQ Console](https://console.groq.com/keys)")
199
+
200
+ st.info("Note: For interacting multiple databases, GPT-4 Model is recommended for accurate results else proceed with Groq Model")
201
+
202
+ os.environ["OPENAI_API_KEY"] = str(st.session_state.openai_api_key)
203
+
204
+ if st.button("Connect"):
205
+ with st.spinner("Connecting to databases..."):
206
+
207
+ # Update secrets.toml with connection details
208
+ update_secrets_file({
209
+ "Host": st.session_state.Host,
210
+ "Port": st.session_state.Port,
211
+ "User": st.session_state.User,
212
+ "Password": st.session_state.Password,
213
+ "Databases": st.session_state.Databases
214
+ })
215
+
216
+ dbs = init_databases()
217
+ st.session_state.dbs = dbs
218
+
219
+ if len(dbs) > 1:
220
+ st.success(f"Connected to {len(dbs)} databases")
221
+ else:
222
+ st.success("Connected to database")
223
+
224
+
225
+
226
+ if st.session_state.openai_api_key == "" and st.session_state.groq_api_key == "":
227
+ st.error("Enter one API Key At least")
228
+ elif st.session_state.openai_api_key:
229
+ st.session_state.llm = ChatOpenAI(model="gpt-4-turbo", api_key=st.session_state.openai_api_key)
230
+ elif st.session_state.groq_api_key:
231
+ st.session_state.llm = ChatGroq(model="llama3-70b-8192", temperature=0.4, api_key=st.session_state.groq_api_key)
232
+ else:
233
+ pass
234
+
235
+ # Display chat messages
236
+ for message in st.session_state.chat_history:
237
+ if isinstance(message, AIMessage):
238
+ with st.chat_message("AI"):
239
+ st.markdown(message.content)
240
+ elif isinstance(message, HumanMessage):
241
+ with st.chat_message("Human"):
242
+ st.markdown(message.content)
243
+
244
+ # Handle user input
245
+ user_query = st.chat_input("Type a message...")
246
+ if user_query is not None and user_query.strip() != "":
247
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
248
+
249
+ with st.chat_message("Human"):
250
+ st.markdown(user_query)
251
+
252
+ with st.chat_message("AI"):
253
+ response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
254
+ st.markdown(response)
255
+
256
+ st.session_state.chat_history.append(AIMessage(content=response))