# ReactionT5 / app.py
# Snapshot metadata: last commit "Update app.py" by sagawa (60bbe72, verified); file size 15.1 kB.
import gc
import os
import warnings
from types import SimpleNamespace
import pandas as pd
import numpy as np
import streamlit as st
import torch
# Local imports
from generation_utils import (
ReactionT5Dataset,
decode_output,
save_multiple_predictions,
)
from models import ReactionT5Yield2
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import seed_everything
# Suppress every Python warning for the lifetime of the app process.
warnings.filterwarnings("ignore")

# ------------------------------
# Page setup
# ------------------------------
# Wide layout, no custom icon; the title doubles as the browser tab name.
st.set_page_config(page_title="ReactionT5", page_icon=None, layout="wide")

st.title("ReactionT5")
st.caption(
    "Predict reaction products, reactants, or yields from your inputs "
    "using a pretrained ReactionT5 model."
)
# ------------------------------
# Sidebar: configuration
# ------------------------------
with st.sidebar:
    st.header("Configuration")
    task = st.selectbox(
        "Task",
        options=["product prediction", "retrosynthesis prediction", "yield prediction"],
        index=0,
        help="Choose the task to run.",
    )

    # CSV-format guidance for the two generation tasks; any other task value
    # falls through to the yield-prediction text below.
    csv_format_help_by_task = {
        "product prediction": (
            "\n"
            "- `REACTANT` column is required.\n"
            "- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.\n"
            "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
            "- For details, download **demo_reaction_data.csv** and check its contents.\n"
        ),
        "retrosynthesis prediction": (
            "\n"
            "- `PRODUCT` column is required.\n"
            "- No optional columns are used.\n"
            "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
            "- For details, download **demo_retro_data.csv** and check its contents.\n"
        ),
    }
    yield_csv_format_help = (
        "\n"
        "- `REACTANT` and `PRODUCT` columns are required.\n"
        "- Optional columns: `REAGENT`, `SOLVENT`, `CATALYST`.\n"
        "- If a field lists multiple compounds, separate them with a dot (`.`).\n"
        "- For details, download **demo_yield_data.csv** and check its contents.\n"
        "- Output contains predicted **reaction yield** on a **0–100% scale**.\n"
    )

    with st.expander("How to format your CSV", expanded=False):
        st.markdown(csv_format_help_by_task.get(task, yield_csv_format_help))
# ------------------------------
# Demo data download
# ------------------------------
import io
@st.cache_data(show_spinner=False)
def parse_csv_from_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Parse the raw bytes of a CSV file into a DataFrame.

    The bytes are wrapped in an in-memory buffer and handed to
    ``pandas.read_csv`` with its default settings, so no explicit text
    decoding step is performed here. The result is cached by Streamlit.

    Args:
        file_bytes: Complete contents of the CSV file.

    Returns:
        The parsed table.
    """
    buffer = io.BytesIO(file_bytes)
    return pd.read_csv(buffer)
@st.cache_data(show_spinner=False)
def load_demo_csv_as_bytes(path: str = "data/demo_reaction_data.csv") -> bytes:
    """Read a demo CSV from disk and return it re-serialised as UTF-8 bytes.

    Args:
        path: Location of the demo CSV. Defaults to the forward-reaction demo
            file so existing zero-argument calls keep working.

    Returns:
        The CSV content (without an index column) encoded as UTF-8.
    """
    demo_df = pd.read_csv(path)
    return demo_df.to_csv(index=False).encode("utf-8")


# The sidebar help refers each task to its own demo file, so offer that file
# here. If it is not shipped with the app, fall back to the forward-reaction
# demo, which is the file this button has always served.
_DEFAULT_DEMO_FILE = "demo_reaction_data.csv"
_DEMO_FILE_BY_TASK = {
    "product prediction": "demo_reaction_data.csv",
    "retrosynthesis prediction": "demo_retro_data.csv",
    "yield prediction": "demo_yield_data.csv",
}
demo_file_name = _DEMO_FILE_BY_TASK.get(task, _DEFAULT_DEMO_FILE)
if not os.path.exists(os.path.join("data", demo_file_name)):
    demo_file_name = _DEFAULT_DEMO_FILE

st.download_button(
    label=f"Download {demo_file_name}",
    data=load_demo_csv_as_bytes(os.path.join("data", demo_file_name)),
    file_name=demo_file_name,
    mime="text/csv",
    use_container_width=True,
)
st.divider()
# ------------------------------
# Sidebar: configuration
# ------------------------------
with st.sidebar:
    st.header("Configuration")

    # Task-specific model list, default sequence lengths and preprocessing
    # routine. Only the selected task's training module is imported.
    if task == "product prediction":
        from task_forward.train import preprocess_df

        model_help = "Recommended models for product prediction."
        model_options = [
            "sagawa/ReactionT5v2-forward",
            "sagawa/ReactionT5v2-forward-USPTO_MIT",
        ]
        input_max_length_default, output_max_length_default = 400, 300
    elif task == "retrosynthesis prediction":
        from task_retrosynthesis.train import preprocess_df

        model_help = "Recommended models for retrosynthesis prediction."
        model_options = [
            "sagawa/ReactionT5v2-retrosynthesis",
            "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
        ]
        input_max_length_default, output_max_length_default = 100, 400
    else:  # yield prediction: single model, no output-length default
        from task_yield.train import preprocess_df

        model_help = "Default model for yield prediction."
        model_options = ["sagawa/ReactionT5v2-yield"]
        input_max_length_default = 400

    model_name_or_path = st.selectbox(
        "Model", options=model_options, index=0, help=model_help
    )

    # Beam search and output-length controls only exist for the two
    # sequence-generation tasks.
    generates_sequences = task != "yield prediction"

    if generates_sequences:
        num_beams = st.slider(
            "Beam size",
            min_value=1,
            max_value=10,
            value=5,
            step=1,
            help="Number of beams for beam search.",
        )

    seed = st.number_input(
        "Random seed",
        min_value=0,
        max_value=2**32 - 1,
        value=42,
        step=1,
        help="Seed for reproducibility.",
    )

    with st.expander("Advanced generation", expanded=False):
        input_max_length = st.number_input(
            "Input max length",
            min_value=8,
            max_value=1024,
            value=input_max_length_default,
            step=8,
        )
        if generates_sequences:
            output_max_length = st.number_input(
                "Output max length",
                min_value=8,
                max_value=1024,
                value=output_max_length_default,
                step=8,
            )
            output_min_length = st.number_input(
                "Output min length",
                min_value=-1,
                max_value=1024,
                value=-1,
                step=1,
                help="Use -1 to let the model decide.",
            )
        batch_size = st.number_input(
            "Batch size", min_value=1, max_value=16, value=1, step=1
        )
        num_workers = st.number_input(
            "DataLoader workers",
            min_value=0,
            max_value=8,
            value=4,
            step=1,
            help="Set to 0 if multiprocessing is restricted in your environment.",
        )

    # Use the GPU when PyTorch can see one, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.caption(f"Detected device: **{device.type.upper()}**")
# ------------------------------
# Cached loaders
# ------------------------------
@st.cache_resource(show_spinner=False)
def load_tokenizer(model_ref: str):
    """Load the tokenizer for ``model_ref`` as a Streamlit-cached resource.

    Args:
        model_ref: Model reference. When it names an existing filesystem path
            it is converted to an absolute path; otherwise it is passed to
            ``AutoTokenizer.from_pretrained`` unchanged (presumably a Hub
            model id - verify against callers).

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    if os.path.exists(model_ref):
        model_ref = os.path.abspath(model_ref)
    # NOTE(review): return_tensors is forwarded to from_pretrained here, not
    # to a tokenization call - confirm it has the intended effect.
    return AutoTokenizer.from_pretrained(model_ref, return_tensors="pt")
@st.cache_resource(show_spinner=True)
def load_model(model_ref: str, device_str: str, task: str):
    """Load the model for ``task``, move it to the device and set eval mode.

    Args:
        model_ref: Model reference; made absolute when it is an existing
            filesystem path, otherwise passed to ``from_pretrained`` as is.
        device_str: Device string accepted by ``torch.device``.
        task: Selected task. ``"yield prediction"`` loads ``ReactionT5Yield2``;
            every other value loads ``AutoModelForSeq2SeqLM``.

    Returns:
        The loaded model, placed on the requested device in evaluation mode.
    """
    location = model_ref
    if os.path.exists(model_ref):
        location = os.path.abspath(model_ref)

    if task == "yield prediction":
        model = ReactionT5Yield2.from_pretrained(location)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(location)

    target_device = torch.device(device_str)
    model.to(target_device)
    model.eval()
    return model
@st.cache_data(show_spinner=False)
def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise ``df`` to CSV (index omitted) and encode it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# ------------------------------
# Main interaction
# ------------------------------
left, right = st.columns([1.4, 1.0], vertical_alignment="top")

with left:
    # Required and optional columns differ per task, so the uploader tooltip
    # follows the same rules as the CSV-format guidance in the sidebar.
    if task == "product prediction":
        upload_help = (
            "Must contain a REACTANT column. Optional: REAGENT, SOLVENT, CATALYST."
        )
    elif task == "retrosynthesis prediction":
        upload_help = "Must contain a PRODUCT column."
    else:  # yield prediction
        upload_help = (
            "Must contain REACTANT and PRODUCT columns. "
            "Optional: REAGENT, SOLVENT, CATALYST."
        )

    with st.form("predict_form", clear_on_submit=False):
        uploaded = st.file_uploader(
            "Upload a CSV file with reactions",
            type=["csv"],
            accept_multiple_files=False,
            help=upload_help,
        )
        run = st.form_submit_button("Predict", use_container_width=True)

    # Preview the uploaded file; a parse failure is reported, not raised.
    if uploaded is not None:
        try:
            file_bytes = uploaded.getvalue()
            raw_df = parse_csv_from_bytes(file_bytes)
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")
        else:
            st.subheader("Input preview")
            st.dataframe(raw_df.head(20), use_container_width=True)

with right:
    st.subheader("Notes")
    if task == "product prediction":
        # Product prediction outputs products, not reactants.
        st.markdown(
            "\n"
            "- Approximate time: about **3 seconds per reaction** when `beam size = 5` (varies by hardware).\n"
            "- Output contains predicted **product SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).\n"
        )
    elif task == "retrosynthesis prediction":
        st.markdown(
            "\n"
            "- Approximate time: about **5 seconds per reaction** when `beam size = 5` (varies by hardware).\n"
            "- Output contains predicted **sets of reactant SMILES** and their log-likelihoods, sorted by log-likelihood (index 0 is most probable).\n"
        )
    else:  # yield prediction
        st.markdown(
            "\n"
            "- Approximate time: about **0.25 seconds per reaction** when `batch size = 1` (varies by hardware).\n"
            "- Output contains predicted **reaction yield** on a **0–100% scale**.\n"
        )
    st.info(
        "In this space, CPU is used for inference. So the speed is slower than using a GPU."
    )
# ------------------------------
# Inference
# ------------------------------
# Session-state slots that survive Streamlit reruns: the latest prediction
# table and the latest error message (None when there is none).
if "results_df" not in st.session_state:
    st.session_state["results_df"] = None
if "last_error" not in st.session_state:
    st.session_state["last_error"] = None

if run:
    # A new run supersedes any error left behind by an earlier one; otherwise
    # a stale message would keep being shown after later successful runs.
    st.session_state["last_error"] = None

    if uploaded is None:
        st.warning("Please upload a CSV file before running prediction.")
    else:
        is_generation_task = task != "yield prediction"

        # Config object passed to the dataset, preprocessing and decoding
        # helpers. Beam and output-length settings only exist for the
        # generation tasks. The input length is configurable for every task,
        # so it is always forwarded.
        CFG = SimpleNamespace(
            task=task,
            num_beams=int(num_beams) if is_generation_task else None,
            # One returned sequence per beam.
            num_return_sequences=int(num_beams) if is_generation_task else None,
            model_name_or_path=model_name_or_path,
            input_column="input",
            input_max_length=int(input_max_length),
            output_max_length=int(output_max_length) if is_generation_task else None,
            output_min_length=int(output_min_length) if is_generation_task else None,
            seed=int(seed),
            batch_size=int(batch_size),
            debug=False,
        )
        seed_everything(seed=CFG.seed)

        # Load model & tokenizer. The error is reported after the status
        # container closes so it is not hidden inside the collapsed widget.
        load_error = None
        with st.status("Loading model and tokenizer...", expanded=False) as status:
            try:
                tokenizer = load_tokenizer(CFG.model_name_or_path)
                CFG.tokenizer = tokenizer
                model = load_model(CFG.model_name_or_path, device.type, task)
                status.update(label="Model ready.", state="complete")
            except Exception as e:
                load_error = f"Failed to load model: {e}"
                status.update(label="Model load failed.", state="error")
        if load_error is not None:
            st.session_state["last_error"] = load_error
            st.error(load_error)
            st.stop()

        # Prepare data. Unreadable CSVs and preprocessing failures are
        # reported in the same way as the other failure paths.
        try:
            file_bytes = uploaded.getvalue()
            input_df = parse_csv_from_bytes(file_bytes)
            if is_generation_task:
                input_df = preprocess_df(input_df, drop_duplicates=False)
            else:
                input_df = preprocess_df(input_df, cfg=CFG, drop_duplicates=False)
        except Exception as e:
            st.session_state["last_error"] = f"Failed to prepare input data: {e}"
            st.error(st.session_state["last_error"])
            st.stop()

        if len(input_df) == 0:
            # Nothing to predict; np.concatenate would fail on an empty list.
            st.warning("The uploaded CSV contains no reactions to predict.")
            st.stop()

        # Dataset & loader (order preserved so predictions align with rows)
        dataset = ReactionT5Dataset(CFG, input_df)
        dataloader = DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=False,
            num_workers=int(num_workers),
            pin_memory=(device.type == "cuda"),
            drop_last=False,
        )

        if not is_generation_task:
            # Yield prediction: one forward pass per batch.
            batch_predictions = []
            total = len(dataloader)
            progress = st.progress(0, text="Predicting yields...")
            info_placeholder = st.empty()
            for i, inputs in enumerate(dataloader, start=1):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    y_preds = model(inputs)
                # Flatten each batch so the pieces concatenate into a single
                # 1-D array whether the model returns (batch,) or (batch, 1)
                # - TODO confirm the model's output shape.
                batch_predictions.append(y_preds.to("cpu").numpy().reshape(-1))
                del y_preds
                progress.progress(i / total, text=f"Predicting yields... {i}/{total}")
                info_placeholder.caption(f"Processed batch {i} of {total}")
            # Remove the progress widgets, as the generation branch does.
            progress.empty()
            info_placeholder.empty()

            prediction = np.concatenate(batch_predictions)
            output_df = input_df.copy()
            output_df["prediction"] = prediction
            # Keep reported values inside the 0-100 range.
            output_df["prediction"] = output_df["prediction"].clip(lower=0.0, upper=100.0)
            st.session_state["results_df"] = output_df
            st.success("Prediction complete.")
        else:
            # Generation loop with progress
            all_sequences, all_scores = [], []
            total = len(dataloader)
            progress = st.progress(0, text="Generating predictions...")
            info_placeholder = st.empty()
            for i, inputs in enumerate(dataloader, start=1):
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        min_length=CFG.output_min_length,
                        max_length=CFG.output_max_length,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output, CFG)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)
                # Release per-batch buffers before the next batch.
                del output
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()
                progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
                info_placeholder.caption(f"Processed batch {i} of {total}")
            progress.empty()
            info_placeholder.empty()

            # Save predictions
            try:
                output_df = save_multiple_predictions(
                    input_df, all_sequences, all_scores, CFG
                )
                st.session_state["results_df"] = output_df
                st.success("Prediction complete.")
            except Exception as e:
                st.session_state["last_error"] = f"Failed to assemble output: {e}"
                st.error(st.session_state["last_error"])
                st.stop()
# ------------------------------
# Results
# ------------------------------
# Show the most recent prediction table, if one is stored, with a download.
results_df = st.session_state.get("results_df")
if results_df is not None:
    st.subheader("Results preview")
    st.dataframe(results_df.head(50), use_container_width=True)
    st.download_button(
        label="Download predictions as CSV",
        data=df_to_csv_bytes(results_df),
        file_name="output.csv",
        mime="text/csv",
        use_container_width=True,
    )

# Show the stored error message, if there is one.
last_error = st.session_state.get("last_error")
if last_error:
    st.error(last_error)