Spaces:
Running
Running
| import os | |
| import random | |
| import numpy as np | |
| import warnings | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from torch.utils.data import Dataset, DataLoader | |
| import gc | |
| import streamlit as st | |
# Suppress every Python warning for the rest of the script run.
warnings.filterwarnings("ignore")

# Page heading followed by the usage instructions (rendered as markdown).
st.title('ReactionT5 task forward')
st.markdown('''
##### At this space, you can predict the products of reactions from their inputs.
##### The code expects input_data as a string or CSV file that contains an "input" column.
##### The format of the string or contents of the column should be "REACTANT:{reactants}REAGENT:{reagents}".
##### If there is no reagent, fill the blank with a space. For multiple compounds, concatenate them with ".".
##### The output contains SMILES of predicted products and the sum of log-likelihood for each prediction, ordered by their log-likelihood (0th is the most probable product).
''')

# Label of the free-text input box; read later by ``CFG.input_data``.
display_text = 'input the reaction smiles (e.g. REACTANT:COC(=O)C1=CCCN(C)C1.O.[Al+3].[H-].[Li+].[Na+].[OH-]REAGENT:C1CCOC1)'

# The bundled example file is parsed and re-serialised on every script run
# before being offered for download.
demo_csv = pd.read_csv('demo_input.csv').to_csv(index=False)
st.download_button(
    label="Download demo_input.csv",
    data=demo_csv,
    file_name='demo_input.csv',
    mime='text/csv',
)
class CFG:
    """Settings for one script run, read as plain class attributes.

    The widget-backed attributes call Streamlit while the class body is
    executed, so their relative order below is the order in which the
    widgets are created.
    """

    # --- values chosen by the user through Streamlit widgets ---
    num_beams = st.number_input(
        label='num beams',
        min_value=1,
        max_value=10,
        value=5,
        step=1,
    )
    # Every beam is returned, so both settings share one value.
    num_return_sequences = num_beams
    uploaded_file = st.file_uploader("Choose a CSV file")
    input_data = st.text_area(display_text)

    # --- fixed values ---
    model_name_or_path = 'sagawa/ReactionT5v2-forward'
    # Column read from an uploaded CSV.
    input_column = 'input'
    # Length that batch inputs are padded/truncated to during tokenization.
    input_max_length = 400
    model = 't5'
    seed = 42
    batch_size = 1
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def prepare_input(cfg, text):
    """Tokenize one reaction string into fixed-length model inputs.

    Uses the module-level ``tokenizer``, which must be assigned before the
    first call.

    Args:
        cfg: Configuration object; only ``cfg.input_max_length`` is read.
        text: Reaction string to tokenize.

    Returns:
        A dict ``{"input_ids": [tensor], "attention_mask": [tensor]}``. Each
        list holds exactly one ``torch.long`` tensor: row 0 of the tokenizer
        output, padded/truncated to ``cfg.input_max_length``.

    Raises:
        KeyError: If the tokenizer returns a field other than the two above.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=cfg.input_max_length,
        padding="max_length",
        truncation=True,
    )
    dic = {"input_ids": [], "attention_mask": []}
    for key, value in encoded.items():
        # ``value`` is already a tensor (return_tensors="pt"). Re-wrapping it
        # with torch.tensor() is the discouraged copy-construct form, so use
        # the documented equivalent clone().detach() instead.
        dic[key].append(value[0].clone().detach().to(dtype=torch.long))
    return dic
class ProductDataset(Dataset):
    """Dataset over the reaction strings stored in one DataFrame column."""

    def __init__(self, cfg, df):
        """Keep the config and extract the input column.

        Args:
            cfg: Configuration object; ``cfg.input_column`` names the column
                to read and the whole object is forwarded to
                ``prepare_input``.
            df: DataFrame that contains the reaction strings.
        """
        self.cfg = cfg
        column = df[cfg.input_column]
        self.inputs = column.values

    def __len__(self):
        """Return how many reaction strings are held."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the tokenized form of the ``idx``-th reaction string."""
        text = self.inputs[idx]
        return prepare_input(self.cfg, text)
def predict_single_input(input_compound):
    """Generate predictions for a single reaction string.

    Reads the module-level ``tokenizer``, ``model``, ``device`` and ``CFG``.

    Args:
        input_compound: Reaction string to feed to the model.

    Returns:
        Whatever ``model.generate`` returns when called with
        ``return_dict_in_generate=True`` and ``output_scores=True``.
    """
    encoded = tokenizer(input_compound, return_tensors="pt").to(device)
    generation_kwargs = dict(
        num_beams=CFG.num_beams,
        num_return_sequences=CFG.num_return_sequences,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        return model.generate(**encoded, **generation_kwargs)
def decode_output(output):
    """Decode a generation result into cleaned strings and optional scores.

    Reads the module-level ``tokenizer`` and ``CFG``.

    Args:
        output: Mapping with a ``"sequences"`` entry and, when
            ``CFG.num_beams > 1``, a ``"sequences_scores"`` entry.

    Returns:
        ``(sequences, scores)``. ``sequences`` is a list of decoded strings
        with all spaces removed and trailing ``"."`` characters stripped.
        ``scores`` is a list taken from ``output["sequences_scores"]`` when
        ``CFG.num_beams > 1``, otherwise ``None``.
    """
    sequences = []
    for token_ids in output["sequences"]:
        decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
        sequences.append(decoded.replace(" ", "").rstrip("."))

    # "sequences_scores" is only read in the multi-beam case.
    scores = output["sequences_scores"].tolist() if CFG.num_beams > 1 else None
    return sequences, scores
def save_single_prediction(input_compound, output, scores):
    """Build a one-row DataFrame for a single-input prediction.

    Despite the name nothing is written to disk; the frame is returned.

    The column layout is derived from the data passed in rather than from
    global configuration, so the header always matches the row even if the
    number of returned sequences differs from the number of beams.

    Args:
        input_compound: Reaction string that was fed to the model.
        output: List of predicted strings, in the order they should appear.
        scores: List of scores, or ``None``/empty when there are none.

    Returns:
        A ``pandas.DataFrame`` with one row and the columns ``input``,
        ``0th`` .. ``{n-1}th`` and, if scores were given,
        ``0th score`` .. ``{m-1}th score``.
    """
    predictions = list(output)
    score_values = list(scores) if scores else []
    columns = (
        ["input"]
        + [f"{i}th" for i in range(len(predictions))]
        + [f"{i}th score" for i in range(len(score_values))]
    )
    row = [input_compound, *predictions, *score_values]
    return pd.DataFrame([row], columns=columns)
def save_multiple_predictions(
    input_data,
    sequences,
    scores,
    *,
    num_return_sequences=None,
    input_column=None,
):
    """Build a DataFrame with one row per input reaction.

    Despite the name nothing is written to disk; the frame is returned.
    ``sequences`` (and ``scores``) are flat lists in which each consecutive
    run of ``num_return_sequences`` items belongs to one input row.

    Args:
        input_data: DataFrame holding the original inputs.
        sequences: Flat list of predicted strings for all inputs.
        scores: Flat list of scores aligned with ``sequences``, or
            ``None``/empty when there are none.
        num_return_sequences: Predictions per input. Defaults to
            ``CFG.num_return_sequences``.
        input_column: Name of the input column in ``input_data``. Defaults
            to ``CFG.input_column``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``input``,
        ``0th`` .. ``{n-1}th`` and, if scores were given,
        ``0th score`` .. ``{n-1}th score``.
    """
    per_input = (
        CFG.num_return_sequences
        if num_return_sequences is None
        else num_return_sequences
    )
    column = CFG.input_column if input_column is None else input_column
    scores = scores or []

    # Row lookup is positional (.iloc), so frames whose index is not the
    # default 0..n-1 range are handled correctly.
    reactions = input_data[column]
    output_list = [
        [reactions.iloc[start // per_input]]
        + list(sequences[start:start + per_input])
        + list(scores[start:start + per_input])
        for start in range(0, len(sequences), per_input)
    ]
    columns = (
        ["input"]
        + [f"{i}th" for i in range(per_input)]
        + ([f"{i}th score" for i in range(per_input)] if scores else [])
    )
    return pd.DataFrame(output_list, columns=columns)
if st.button('predict'):
    with st.spinner('Now processing. If num beams=5, this process takes about 15 seconds per reaction.'):
        # ``device``, ``tokenizer`` and ``model`` are module-level names that
        # prepare_input / predict_single_input / decode_output read.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        seed_everything(seed=CFG.seed)
        tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors="pt")
        model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
        model.eval()

        if CFG.uploaded_file is not None:
            # Batch mode: one prediction set per row of the uploaded CSV.
            input_data = pd.read_csv(CFG.uploaded_file)
            dataset = ProductDataset(CFG, input_data)
            dataloader = DataLoader(
                dataset,
                batch_size=CFG.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
                drop_last=False,
            )
            generation_kwargs = dict(
                num_beams=CFG.num_beams,
                num_return_sequences=CFG.num_return_sequences,
                return_dict_in_generate=True,
                output_scores=True,
            )

            all_sequences = []
            all_scores = []
            for batch in dataloader:
                # prepare_input wraps each tensor in a one-element list, so
                # element 0 of every collated value is the batched tensor.
                batch = {name: tensors[0].to(device) for name, tensors in batch.items()}
                with torch.no_grad():
                    output = model.generate(**batch, **generation_kwargs)
                sequences, scores = decode_output(output)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)
                # Release the generation result before the next batch.
                del output
                torch.cuda.empty_cache()
                gc.collect()

            output_df = save_multiple_predictions(input_data, all_sequences, all_scores)
        else:
            # Single mode: predict for the string typed into the text area.
            input_compound = CFG.input_data
            output = predict_single_input(input_compound)
            sequences, scores = decode_output(output)
            output_df = save_single_prediction(input_compound, sequences, scores)

        # NOTE(review): the original indentation was lost in extraction; the
        # CSV export is kept inside the spinner block - confirm.
        def convert_df(df):
            """Serialise ``df`` to CSV text without the index column."""
            return df.to_csv(index=False)

        csv = convert_df(output_df)
        st.download_button(
            label="Download data as CSV",
            data=csv,
            file_name='output.csv',
            mime='text/csv',
        )