Upload folder using huggingface_hub
app.py
CHANGED
@@ -1,22 +1,23 @@
# app.py
import gradio as gr
import torch
import torch.nn.functional as F  # Needed for beam search log_softmax
import pytorch_lightning as pl  # Needed for LightningModule and loading
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc  # For garbage collection on potential OOM
import math  # Potentially needed by imported classes

# --- Configuration ---
# Ensure these match the files uploaded to your Hugging Face Hub repository
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"  # <-- Make sure this is your repo ID
CHECKPOINT_FILENAME = "last.ckpt"  # Or "best_model.ckpt" or whatever you uploaded
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"  # Assumes you saved hparams to config.json during/after training
# --- End Configuration ---

# --- Logging ---
@@ -24,43 +25,41 @@ logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
try:
    # We need the LightningModule definition and the mask function
    # Ensure enhanced_trainer.py is present in the root of your HF Repo
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
    logging.info("Successfully imported from enhanced_trainer.py.")

    # REMOVED: Redundant import from test_ckpt as functions are defined below
    # from test_ckpt import beam_search_decode, translate

except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    # Raise error visible in Gradio UI and logs
    raise gr.Error(
        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
    )
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise gr.Error(
        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    )

# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None

# --- Beam Search Decoding Logic (Locally defined) ---
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
@@ -68,36 +67,34 @@ def beam_search_decode(
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[torch.Tensor]:
    """
    Performs beam search decoding using the LightningModule's model.
    (Ensures this code is self-contained within app.py or correctly imported)
    """
    model.eval()  # Ensure model is in evaluation mode
    transformer_model = model.model  # Access the underlying Seq2SeqTransformer
    n_best = min(n_best, beam_width)

    try:
        with torch.no_grad():
            # --- Encode Source ---
            memory = transformer_model.encode(
                src, src_padding_mask
            )  # [1, src_len, emb_size]
            memory = memory.to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

            # --- Initialize Beams ---
            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                sos_idx
            )  # [1, 1]
            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
            active_beams = [(initial_beam_seq, initial_beam_score)]
            finished_beams = []

@@ -108,99 +105,125 @@ def beam_search_decode(

                potential_next_beams = []
                for current_seq, current_score in active_beams:
                    # Check if the beam already ended
                    if current_seq[0, -1].item() == eos_idx:
                        # If already finished, add directly to finished beams and skip expansion
                        finished_beams.append((current_seq, current_score))
                        continue

                    # Prepare inputs for the decoder
                    tgt_input = current_seq  # [1, current_len]
                    tgt_seq_len = tgt_input.shape[1]
                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                        device
                    )  # [curr_len, curr_len]
                    # No padding in target during generation yet
                    tgt_padding_mask = torch.zeros(
                        tgt_input.shape, dtype=torch.bool, device=device
                    )  # [1, curr_len]

                    # Decode one step
                    decoder_output = transformer_model.decode(
                        tgt=tgt_input,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )  # [1, curr_len, emb_size]

                    # Get logits for the *next* token prediction
                    next_token_logits = transformer_model.generator(
                        decoder_output[:, -1, :]  # Use output corresponding to the last input token
                    )  # [1, tgt_vocab_size]

                    # Calculate log probabilities and add current beam score
                    log_probs = F.log_softmax(
                        next_token_logits, dim=-1
                    )  # [1, tgt_vocab_size]
                    combined_scores = log_probs + current_score  # Add score of the current path

                    # Find top k candidates for the *next* step
                    topk_log_probs, topk_indices = torch.topk(
                        combined_scores, beam_width, dim=-1
                    )  # [1, beam_width], [1, beam_width]

                    # Expand potential beams
                    for i in range(beam_width):
                        next_token_id = topk_indices[0, i].item()
                        # Score is the cumulative log probability of the new sequence
                        next_score = topk_log_probs[0, i].reshape(1)  # Keep as tensor [1]
                        next_token_tensor = torch.tensor(
                            [[next_token_id]], dtype=torch.long, device=device
                        )  # [1, 1]
                        new_seq = torch.cat(
                            [current_seq, next_token_tensor], dim=1
                        )  # [1, current_len + 1]
                        potential_next_beams.append((new_seq, next_score))

                # --- Prune Beams ---
                # Sort all potential next beams by score
                potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)

                # Select the top `beam_width` beams for the next iteration
                active_beams = []
                temp_finished_beams = []  # Collect beams finished in *this* step
                for seq, score in potential_next_beams:
                    if len(active_beams) >= beam_width and len(temp_finished_beams) >= beam_width:
                        break  # Optimization: Stop if we have enough active and finished candidates

                    is_finished = seq[0, -1].item() == eos_idx
                    if is_finished:
                        # Add to temporary finished list for this step
                        if len(temp_finished_beams) < beam_width:
                            temp_finished_beams.append((seq, score))
                    elif len(active_beams) < beam_width:
                        # Add to active beams for next step
                        active_beams.append((seq, score))

                # Add the newly finished beams to the main finished list
                finished_beams.extend(temp_finished_beams)
                # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
                finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
                finished_beams = finished_beams[:beam_width * 2]  # Keep a reasonable number

            # --- Final Selection ---
            # Add any remaining active beams (which didn't finish) to the finished list
            finished_beams.extend(active_beams)

            # Apply length penalty and sort
            def get_score_with_penalty(beam_tuple):
                seq, score = beam_tuple
                seq_len = seq.shape[1]
                # Avoid division by zero or negative exponent issues
                if length_penalty <= 0.0 or seq_len <= 1:
                    return score.item()
                else:
                    # Length penalty calculation
                    penalty = ((5.0 + float(seq_len)) / 6.0) ** length_penalty  # Common formula
                    return score.item() / penalty
                    # Alternative simpler penalty:
                    # return score.item() / (float(seq_len) ** length_penalty)

            finished_beams.sort(key=get_score_with_penalty, reverse=True)  # Higher score is better

            # Return the top n_best sequences (excluding the initial SOS token)
            top_sequences = [
                seq[:, 1:] for seq, score in finished_beams[:n_best] if seq.shape[1] > 1  # Ensure seq not just SOS
            ]  # seq shape [1, len] -> [1, len-1]
            return top_sequences

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == 'cuda':
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []

# --- Translation Function (Locally defined) ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
@@ -217,38 +240,43 @@ def translate(
) -> list[str]:
    """
    Translates a single SMILES string using beam search.
    (Ensures this code is self-contained within app.py or correctly imported)
    """
    model.eval()  # Ensure model is in eval mode
    translations = []
    n_best = min(n_best, beam_width)  # Can't return more than beam width

    # --- Tokenize Source ---
    try:
        # Ensure tokenizer has truncation/padding configured if needed, or handle manually
        smiles_tokenizer.enable_truncation(max_length=max_len)
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        # Use the truncated IDs directly
        src_ids = src_encoded.ids
        # Note: max_len here applies to source *tokenizer*, generation length is separate
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
        return ["[Encoding Error]"] * n_best

    # --- Prepare Input Tensor and Mask ---
    src = (
        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    )  # [1, src_len]
    # Create padding mask (True where it's a pad token, should be all False here)
    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]

    # --- Perform Beam Search Decoding ---
    # Calls the beam_search_decode function defined *above in this file*
    # Note: max_len for generation should come from config if it dictates output length
    generation_max_len = config.get("max_len", 256)  # Use config max_len for output limit
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=generation_max_len,  # Use generation limit
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
@@ -256,25 +284,28 @@ def translate(
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )  # Returns list of tensors

    # --- Decode Generated Tokens ---
    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        # Provide n_best error messages
        return ["[Decoding Error - Empty Output]"] * n_best

    for i, tgt_tokens_tensor in enumerate(tgt_tokens_list):
        if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0:
            tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
            try:
                # Decode using the target tokenizer, skipping special tokens
                translation = iupac_tokenizer.decode(
                    tgt_tokens, skip_special_tokens=True
                )
                translations.append(translation)
            except Exception as e:
                logging.error(f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}", exc_info=True)
                translations.append("[Decoding Error]")
        else:
            logging.warning(f"Beam {i} result was empty or None for SMILES: {src_sentence}")
            translations.append("[Decoding Error - Empty Tensor]")

    # Pad with error messages if fewer than n_best results were generated
@@ -283,39 +314,53 @@

    return translations

# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
        # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
        if torch.cuda.is_available():
            logging.warning("CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended.")
            device = torch.device("cpu")
            # Uncomment below and comment above line to try using GPU if available
            # device = torch.device("cuda")
            # logging.info("CUDA available, using GPU.")
        else:
            device = torch.device("cpu")
            logging.info("CUDA not available, using CPU.")

        # Download files from HF Hub
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            # Use cache directory for Spaces persistence if possible
            cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")  # Gradio sets cache dir
            os.makedirs(cache_dir, exist_ok=True)
            logging.info(f"Using cache directory: {cache_dir}")

            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
            )
            smiles_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir
            )
            iupac_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir
            )
            config_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
            )
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames ({CHECKPOINT_FILENAME}, {SMILES_TOKENIZER_FILENAME}, etc.) and repo status. Error: {e}",
                exc_info=True,
            )
            raise gr.Error(
@@ -329,41 +374,81 @@ def load_model_and_tokenizers():
            config = json.load(f)
            logging.info("Configuration loaded.")
            # --- Validate essential config keys ---
            # Use hparams logged during training if available, map them carefully
            # These keys are based on SmilesIupacLitModule and Seq2SeqTransformer init args
            # Mappings might be needed if keys in config.json differ from these exact names
            required_keys = [
                # Need vocab sizes used during *training* for loading
                "actual_src_vocab_size",  # Assuming this was saved in hparams
                "actual_tgt_vocab_size",  # Assuming this was saved in hparams
                # Model architecture params
                "emb_size",
                "nhead",
                "ffn_hid_dim",
                "num_encoder_layers",
                "num_decoder_layers",
                "dropout",
                "max_len",  # Needed for generation limit and tokenizer setting
                # Special token IDs needed for generation
                # Assuming standard names, adjust if your config uses different keys
                "pad_token_id",  # Often 0
                "bos_token_id",  # Often 1 (used as SOS)
                "eos_token_id",  # Often 2
            ]
            # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
            config_key_mapping = {
                "actual_src_vocab_size": config.get("actual_src_vocab_size", config.get("src_vocab_size")),
                "actual_tgt_vocab_size": config.get("actual_tgt_vocab_size", config.get("tgt_vocab_size")),
                "emb_size": config.get("emb_size"),
                "nhead": config.get("nhead"),
                "ffn_hid_dim": config.get("ffn_hid_dim"),
                "num_encoder_layers": config.get("num_encoder_layers"),
                "num_decoder_layers": config.get("num_decoder_layers"),
                "dropout": config.get("dropout"),
                "max_len": config.get("max_len"),
                "pad_token_id": config.get("pad_token_id", PAD_IDX),  # Use default if missing? Risky.
                "bos_token_id": config.get("bos_token_id", SOS_IDX),  # Use default if missing? Risky.
                "eos_token_id": config.get("eos_token_id", EOS_IDX),  # Use default if missing? Risky.
            }
            # Update config with potentially remapped values
            config.update(config_key_mapping)

            missing_keys = [key for key in required_keys if config.get(key) is None]
            if missing_keys:
                # Try to load defaults for token IDs if absolutely necessary, but warn heavily
                defaults_used = []
                if "pad_token_id" in missing_keys and 'PAD_IDX' in globals(): config["pad_token_id"] = PAD_IDX; defaults_used.append("pad_token_id")
                if "bos_token_id" in missing_keys and 'SOS_IDX' in globals(): config["bos_token_id"] = SOS_IDX; defaults_used.append("bos_token_id")
                if "eos_token_id" in missing_keys and 'EOS_IDX' in globals(): config["eos_token_id"] = EOS_IDX; defaults_used.append("eos_token_id")

                # Re-check missing keys after attempting defaults
                missing_keys = [key for key in required_keys if config.get(key) is None]
                if missing_keys:
                    raise ValueError(
                        f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                        f"Ensure these were saved in the hyperparameters during training."
                    )
                else:
                    logging.warning(f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!")

            # Log the final config values being used
            logging.info(f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
                         f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
                         f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}")

        except FileNotFoundError:
            logging.error(f"Config file not found locally after download attempt: {config_path}")
            raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
        except ValueError as e:  # Catch our custom validation error
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}")
        except Exception as e:  # Catch other potential errors during config processing
            logging.error(f"Unexpected error loading or validating config: {e}", exc_info=True)
            raise gr.Error(f"Config Error: Unexpected error processing config. Check logs. Error: {e}")

        # Load tokenizers
        logging.info("Loading tokenizers...")
@@ -371,25 +456,33 @@ def load_model_and_tokenizers():
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")

            # --- Validate Tokenizer Special Tokens Against Config ---
            pad_token = "<pad>"
            sos_token = "<sos>"
            eos_token = "<eos>"
            unk_token = "<unk>"

            issues = []
            if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append(f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
            if smiles_tokenizer.token_to_id(unk_token) is None:
                issues.append("SMILES UNK token not found")

            if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append(f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
            if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
                issues.append(f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})")
            if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
                issues.append(f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})")
            if iupac_tokenizer.token_to_id(unk_token) is None:
                issues.append("IUPAC UNK token not found")

            if issues:
                logging.warning("Tokenizer validation issues detected: " + "; ".join(issues))
                # Decide if this is fatal or just a warning
                # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.")  # Make it fatal if IDs must match

        except Exception as e:
            logging.error(
                f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
@@ -402,115 +495,139 @@ def load_model_and_tokenizers():
        # Load model
        logging.info("Loading model from checkpoint...")
        try:
            # Load the LightningModule state dict
            # Use the actual vocab sizes and hparams from the loaded config
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                # Pass necessary __init__ args that might not be in saved hparams automatically
                # Ensure these keys exist in your loaded 'config' dict after validation/mapping
                src_vocab_size=config["actual_src_vocab_size"],
                tgt_vocab_size=config["actual_tgt_vocab_size"],
                hparams_dict=config,  # Pass the loaded config as hparams
                map_location=device,  # Map model to the chosen device (CPU or CUDA)
                strict=False,  # Be less strict about matching keys, useful for PTL versions or minor changes
                # REMOVED invalid argument: device="cpu",
            )

            # Ensure model is on the correct device, in eval mode, and frozen
            model.to(device)
            model.eval()
            model.freeze()  # Disables gradient calculations
            logging.info(
                f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
            )

        except FileNotFoundError:
            logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
            raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
        except Exception as e:
            logging.error(
                f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True
            )
            # Check for common errors
            if "size mismatch" in str(e):
                error_detail = (f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
                                f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers.")
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
            elif "memory" in str(e).lower():
                logging.warning("Potential Out-of-Memory error during model loading.")
                gc.collect()
                if device.type == 'cuda': torch.cuda.empty_cache()
                raise gr.Error(f"Model Error: Out of memory loading model. Check Space resources. Error: {e}")
            else:
                raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")

    except gr.Error:  # Re-raise Gradio errors to be displayed
        raise
    except Exception as e:  # Catch any other unexpected errors
        logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
        raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")

# --- Inference Function for Gradio ---
def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str):
    """
    Performs SMILES to IUPAC translation using the loaded model and beam search.
    Takes string inputs from Gradio sliders/inputs and converts them.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        # Try to determine n_best for error output formatting
        try: n_best_int = int(n_best_str)
        except: n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        try: n_best_int = int(n_best_str)
        except: n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    smiles_input = smiles_string.strip()

    # --- Safely parse numerical inputs ---
    try:
        beam_width = int(beam_width_str)
        n_best = int(n_best_str)
        length_penalty = float(length_penalty_str)
        if beam_width < 1 or n_best < 1 or n_best > beam_width:
            raise ValueError("Beam width and n_best must be >= 1, and n_best <= beam width.")
        if length_penalty < 0:
            logging.warning(f"Length penalty {length_penalty} is negative, using 0.0 instead.")
            length_penalty = 0.0
    except ValueError as e:
        error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
        logging.error(error_msg)
        # Cannot determine n_best if its input was invalid, default to 1 error line
        return f"1. {error_msg}"

    logging.info(
        f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty:.2f})"
    )

    try:
        # --- Call the core translation logic ---
        # Retrieve necessary IDs from the loaded config
        sos_idx = config['bos_token_id']
        eos_idx = config['eos_token_id']
        pad_idx = config['pad_token_id']
        gen_max_len = config['max_len']  # Max length for generation

        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,  # Pass generation length limit
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty,
        )
        logging.info(f"Predictions returned: {predicted_names}")

        # --- Format Output ---
        if not predicted_names:
            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated (beam search might have failed)."
        else:
            # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
            display_names = predicted_names[:n_best]
            output_text = (f"Input SMILES: {smiles_input}\n\n"
                           f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n")
            output_text += "\n".join(
                [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
            )
            # Add a note if fewer results than requested were generated
            if len(display_names) < n_best:
                output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."

        return output_text

@@ -519,9 +636,9 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            if device.type == 'cuda': torch.cuda.empty_cache()
            error_msg += " (Potential OOM - try reducing beam width or input length)"
        # Return n_best error messages
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])

    except Exception as e:
@@ -530,85 +647,94 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])

# --- Load Model on App Start ---
# Wrap in try/except to prevent app from crashing completely if loading fails
# The error should be caught and displayed by Gradio via gr.Error raised in the function.
try:
    load_model_and_tokenizers()
except gr.Error as ge:
    logging.error(f"Gradio Initialization Error: {ge}")
    # Gradio handles displaying gr.Error, but we log it too.
    # We might want to display a placeholder UI or message if loading fails critically.
    pass  # Allow Gradio to potentially start with an error message
except Exception as e:
    # Catch any non-Gradio errors during the initial load sequence
    logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
    # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
    # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")


# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
Translation uses beam search decoding. Adjust parameters below.
**Note:** Model loaded on **{str(device).upper()}**. Performance may vary. Check `config.json` in the repo for model details.
"""

# Define examples using the input types expected by the interface
examples = [
    ["CCO", 5, 3, 0.6],  # Ethanol
    ["C1=CC=CC=C1", 5, 3, 0.6],  # Benzene
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6],  # Aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6],  # Ibuprofen
    # Very complex example - might take time or fail on CPU/low memory
    # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7],  # Gleevec (Imatinib) - simplified SMILES structure
    ["INVALID_SMILES", 3, 1, 0.6],  # Example of invalid input
]

# Ensure input components match the `predict_iupac` function signature order and types
smiles_input = gr.Textbox(
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
    lines=1,
)
# Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
beam_width_input = gr.Slider(
    minimum=1, maximum=10, value=5, step=1, label="Beam Width (k)",
    info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage."
)
n_best_input = gr.Slider(
    minimum=1, maximum=10, value=3, step=1, label="Number of Results (n_best)",
    info="How many top sequences to return (must be <= Beam Width)."
)
length_penalty_input = gr.Slider(
    minimum=0.0, maximum=2.0, value=0.6, step=0.1, label="Length Penalty (alpha)",
    info="Controls preference for sequence length. >1 favors longer, <1 favors shorter, 0 no penalty."
)
output_text = gr.Textbox(
    label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
)

# Create the interface instance
iface = gr.Interface(
    fn=predict_iupac,  # The function to call
    inputs=[  # List of input components
        smiles_input,
        beam_width_input,
        n_best_input,
        length_penalty_input
    ],
    outputs=output_text,  # Output component
    title=title,
    description=description,
    examples=examples,  # Examples to populate the interface
    allow_flagging="never",  # Disable flagging
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),  # Optional theme
    article="""
    **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
    Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
    Beam search parameters (width, penalty) significantly impact results and performance.
    """,
    # Optional: Add live=True for real-time updates as sliders change (can be slow/resource intensive)
    # live=False,
)

# --- Launch the App ---
if __name__ == "__main__":
    # Launch the Gradio interface
    # share=True generates a public link (useful for testing outside HF Spaces, but temporary)
    # Set share=False or remove for deployment on Spaces.
    # Use server_name="0.0.0.0" to make it accessible on the network if running locally
    # Use auth=("username", "password") for basic authentication
    iface.launch()  # share=True is deprecated, use launch()