Update src/streamlit_app.py
Browse files- src/streamlit_app.py +82 -44
src/streamlit_app.py
CHANGED
|
@@ -172,20 +172,19 @@ def create_bar_chart(df, view_type):
|
|
| 172 |
color_continuous_scale=px.colors.diverging.Fall,
|
| 173 |
orientation="v",
|
| 174 |
)
|
| 175 |
-
|
| 176 |
fig.update_layout(
|
| 177 |
xaxis_title_text="Model",
|
| 178 |
yaxis_title_text="Score (Lower is better)",
|
| 179 |
title_text="",
|
| 180 |
font=dict(size=15),
|
| 181 |
xaxis_tickangle=-45,
|
| 182 |
-
bargap=0.2,
|
| 183 |
-
height=600,
|
| 184 |
showlegend=False,
|
| 185 |
margin=dict(
|
| 186 |
l=80, # Left
|
| 187 |
r=0, # Right
|
| 188 |
-
b=80,
|
| 189 |
t=70, # Top
|
| 190 |
pad=0 # Padding
|
| 191 |
),
|
|
@@ -199,48 +198,87 @@ def create_bar_chart(df, view_type):
|
|
| 199 |
texttemplate="%{y:.2f}",
|
| 200 |
textposition="outside"
|
| 201 |
)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
# TODO
|
| 205 |
-
# fig = go.Figure(data=[
|
| 206 |
-
# go.Bar(
|
| 207 |
-
# x=df['Model'],
|
| 208 |
-
# y=df['Total Score'],
|
| 209 |
-
# orientation='v',
|
| 210 |
-
# marker_color=px.colors.sequential.Blues,
|
| 211 |
-
# text=df['Total Score'].round(1),
|
| 212 |
-
# textposition='outside',
|
| 213 |
-
# )
|
| 214 |
-
# ])
|
| 215 |
-
# fig.update_layout(
|
| 216 |
-
# title="Model Performance - Total Score",
|
| 217 |
-
# xaxis_title="Model",
|
| 218 |
-
# yaxis_title="Score",
|
| 219 |
-
# yaxis_range=[0, 100],
|
| 220 |
-
# height=500,
|
| 221 |
-
# )
|
| 222 |
|
| 223 |
elif view_type == "Per Embodiment":
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
#
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
#
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
#
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
#
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
#
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
else: # Per Category
|
| 246 |
# category_cols = [col for col in df.columns if col.startswith('Category-')]
|
|
|
|
| 172 |
color_continuous_scale=px.colors.diverging.Fall,
|
| 173 |
orientation="v",
|
| 174 |
)
|
|
|
|
| 175 |
fig.update_layout(
|
| 176 |
xaxis_title_text="Model",
|
| 177 |
yaxis_title_text="Score (Lower is better)",
|
| 178 |
title_text="",
|
| 179 |
font=dict(size=15),
|
| 180 |
xaxis_tickangle=-45,
|
| 181 |
+
bargap=0.2,
|
| 182 |
+
height=600,
|
| 183 |
showlegend=False,
|
| 184 |
margin=dict(
|
| 185 |
l=80, # Left
|
| 186 |
r=0, # Right
|
| 187 |
+
b=80, # Bottom
|
| 188 |
t=70, # Top
|
| 189 |
pad=0 # Padding
|
| 190 |
),
|
|
|
|
| 198 |
texttemplate="%{y:.2f}",
|
| 199 |
textposition="outside"
|
| 200 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
elif view_type == "Per Embodiment":
|
| 203 |
+
|
| 204 |
+
# Format df
|
| 205 |
+
df_fig = df.copy()
|
| 206 |
+
df_fig = df_fig[df_fig["score"] != np.inf]
|
| 207 |
+
|
| 208 |
+
# Calculate the model order
|
| 209 |
+
df_model_order = df_fig.groupby("model")[["score"]].mean().reset_index()
|
| 210 |
+
model_order = df_model_order.sort_values(by="score", ascending=True)["models"].tolist()
|
| 211 |
+
|
| 212 |
+
# Calculate mean score per model and embodiment
|
| 213 |
+
df_fig = df_fig.groupby(["model", "embodiment"])[["score"]].mean().reset_index()
|
| 214 |
+
|
| 215 |
+
# Sort the results from best to worst
|
| 216 |
+
df_fig = df_fig.sort_values(by="score", ascending=True)
|
| 217 |
+
|
| 218 |
+
# Convert the "model" column to a categorical type with the sorted order
|
| 219 |
+
df_fig["model"] = pd.Categorical(df_fig["model"], categories=model_order, ordered=True)
|
| 220 |
+
|
| 221 |
+
# Sort the DataFrame based on the new categorical order
|
| 222 |
+
df_fig = df_fig.sort_values(by=["model", "score"], ascending=[True, True])
|
| 223 |
+
|
| 224 |
+
# Create the Plotly figure
|
| 225 |
+
fig = px.bar(
|
| 226 |
+
df_fig,
|
| 227 |
+
x="model",
|
| 228 |
+
y="score",
|
| 229 |
+
color="embodiment",
|
| 230 |
+
color_continuous_scale=px.colors.qualitative.Plotly,
|
| 231 |
+
orientation="v",
|
| 232 |
+
)
|
| 233 |
+
max_score = df_fig["score"].max()
|
| 234 |
+
fig.update_layout(
|
| 235 |
+
xaxis_title_text="Model",
|
| 236 |
+
yaxis_title_text="Score (Lower is better)",
|
| 237 |
+
title_text="",
|
| 238 |
+
font=dict(size=15),
|
| 239 |
+
xaxis_tickangle=-45,
|
| 240 |
+
bargap=0.1,
|
| 241 |
+
barmode="group",
|
| 242 |
+
height=600,
|
| 243 |
+
margin=dict(
|
| 244 |
+
l=80, # Left
|
| 245 |
+
r=0, # Right
|
| 246 |
+
b=80, # Bottom
|
| 247 |
+
t=70, # Top
|
| 248 |
+
pad=0 # Padding
|
| 249 |
+
),
|
| 250 |
+
showlegend=True,
|
| 251 |
+
legend=dict(
|
| 252 |
+
x=0,
|
| 253 |
+
y=1.3,
|
| 254 |
+
xanchor="left",
|
| 255 |
+
yanchor="top",
|
| 256 |
+
bgcolor="rgba(255,255,255,0)",
|
| 257 |
+
bordercolor="rgba(0,0,0,0)",
|
| 258 |
+
borderwidth=1,
|
| 259 |
+
itemclick="toggle",
|
| 260 |
+
itemdoubleclick="toggleothers",
|
| 261 |
+
title=dict(
|
| 262 |
+
text=f"<b>Embodiments</b>",
|
| 263 |
+
font=dict(size=12, color="black"),
|
| 264 |
+
side="top center",
|
| 265 |
+
),
|
| 266 |
+
font=dict(size=10, color='black')
|
| 267 |
+
),
|
| 268 |
+
uniformtext_minsize=10,
|
| 269 |
+
uniformtext_mode="show",
|
| 270 |
+
yaxis_range=[0, max_score * 1.25]
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# Remove the color legend from the chart.
|
| 274 |
+
fig.update_coloraxes(showscale=False)
|
| 275 |
+
|
| 276 |
+
# Add annotations to show the exact score on each bar.
|
| 277 |
+
fig.update_traces(
|
| 278 |
+
texttemplate="%{y:.2f}",
|
| 279 |
+
textposition="outside",
|
| 280 |
+
textangle=-90,
|
| 281 |
+
)
|
| 282 |
|
| 283 |
else: # Per Category
|
| 284 |
# category_cols = [col for col in df.columns if col.startswith('Category-')]
|