AdrianM0 committed (verified)
Commit 028e0b0 · 1 Parent(s): 59543a5

Upload folder using huggingface_hub
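
The folder was pushed with the huggingface_hub client; a typical invocation looks roughly like the sketch below (the local folder path and repo_type are assumptions for illustration, not taken from this commit).

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./smiles-to-iupac-app",            # hypothetical local folder
    repo_id="AdrianM0/smiles-to-iupac-translator",  # target repo referenced in app.py below
    repo_type="space",                              # assuming this repo is a Gradio Space
    commit_message="Upload folder using huggingface_hub",
)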

Files changed (1)
  1. app.py +370 -244
app.py CHANGED
@@ -1,22 +1,23 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
- import torch.nn.functional as F # <--- Added import
5
- import pytorch_lightning as pl # <--- Added import (needed for type hints, model access)
6
  import os
7
  import json
8
  import logging
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
- import gc # For garbage collection on potential OOM
12
- import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
13
 
14
  # --- Configuration ---
15
- MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
16
- CHECKPOINT_FILENAME = "last.ckpt"
 
17
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
18
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
19
- CONFIG_FILENAME = "config.json"
20
  # --- End Configuration ---
21
 
22
  # --- Logging ---
@@ -24,43 +25,41 @@ logging.basicConfig(
24
  level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
25
  )
26
 
27
- # --- Load Helper Code (Only Model Definition Needed) ---
28
  try:
29
- # We only need the LightningModule definition and the mask function now
 
30
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
31
-
32
  logging.info("Successfully imported from enhanced_trainer.py.")
33
 
34
- # We will define beam_search_decode and translate locally in this file
35
- # REMOVED: from test_ckpt import beam_search_decode, translate
36
 
37
  except ImportError as e:
38
  logging.error(
39
- f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
 
40
  )
41
- gr.Error(
 
42
  f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
43
  )
44
- exit()
45
  except Exception as e:
46
  logging.error(
47
  f"An unexpected error occurred during helper code import: {e}", exc_info=True
48
  )
49
- gr.Error(
50
  f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
51
  )
52
- exit()
53
 
54
  # --- Global Variables (Load Model Once) ---
55
- model: pl.LightningModule | None = None # Added type hint
56
  smiles_tokenizer: Tokenizer | None = None
57
  iupac_tokenizer: Tokenizer | None = None
58
  device: torch.device | None = None
59
  config: dict | None = None
60
 
61
- # --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---
62
-
63
-
64
  def beam_search_decode(
65
  model: pl.LightningModule,
66
  src: torch.Tensor,
@@ -68,36 +67,34 @@ def beam_search_decode(
68
  max_len: int,
69
  sos_idx: int,
70
  eos_idx: int,
71
- pad_idx: int, # Needed for padding mask check if src has padding
72
  device: torch.device,
73
  beam_width: int = 5,
74
- n_best: int = 5, # Number of top sequences to return
75
- length_penalty: float = 0.6, # Alpha for length normalization (0=no penalty, 1=full penalty)
76
  ) -> list[torch.Tensor]:
77
  """
78
  Performs beam search decoding using the LightningModule's model.
79
- (Code copied and pasted from test_ckpt.py)
80
  """
81
- # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
82
- model.eval()
83
- transformer_model = model.model # Access the underlying Seq2SeqTransformer
84
- n_best = min(n_best, beam_width) # Cannot return more than beam_width sequences
85
 
86
  try:
87
  with torch.no_grad():
88
  # --- Encode Source ---
89
  memory = transformer_model.encode(
90
  src, src_padding_mask
91
- ) # [1, src_len, emb_size]
92
  memory = memory.to(device)
93
- # Ensure memory_key_padding_mask is also on the correct device for decode
94
- memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
95
 
96
  # --- Initialize Beams ---
97
  initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
98
  sos_idx
99
- ) # [1, 1]
100
- initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
101
  active_beams = [(initial_beam_seq, initial_beam_score)]
102
  finished_beams = []
103
 
@@ -108,99 +105,125 @@ def beam_search_decode(
108
 
109
  potential_next_beams = []
110
  for current_seq, current_score in active_beams:
 
111
  if current_seq[0, -1].item() == eos_idx:
 
112
  finished_beams.append((current_seq, current_score))
113
  continue
114
 
115
- tgt_input = current_seq # [1, current_len]
 
116
  tgt_seq_len = tgt_input.shape[1]
117
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
118
  device
119
- ) # [curr_len, curr_len]
 
120
  tgt_padding_mask = torch.zeros(
121
  tgt_input.shape, dtype=torch.bool, device=device
122
- ) # [1, curr_len]
123
 
 
124
  decoder_output = transformer_model.decode(
125
  tgt=tgt_input,
126
  memory=memory,
127
  tgt_mask=tgt_mask,
128
  tgt_padding_mask=tgt_padding_mask,
129
  memory_key_padding_mask=memory_key_padding_mask,
130
- ) # [1, curr_len, emb_size]
131
 
 
132
  next_token_logits = transformer_model.generator(
133
- decoder_output[:, -1, :]
134
- ) # [1, tgt_vocab_size]
 
 
135
  log_probs = F.log_softmax(
136
  next_token_logits, dim=-1
137
- ) # [1, tgt_vocab_size]
 
138
 
 
139
  topk_log_probs, topk_indices = torch.topk(
140
- log_probs + current_score, beam_width, dim=-1
141
- )
142
 
 
143
  for i in range(beam_width):
144
  next_token_id = topk_indices[0, i].item()
145
- next_score = topk_log_probs[0, i].reshape(
146
- 1
147
- ) # Keep as tensor [1]
148
  next_token_tensor = torch.tensor(
149
  [[next_token_id]], dtype=torch.long, device=device
150
- ) # [1, 1]
151
  new_seq = torch.cat(
152
  [current_seq, next_token_tensor], dim=1
153
- ) # [1, current_len + 1]
154
  potential_next_beams.append((new_seq, next_score))
155
 
 
 
156
  potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
157
 
 
158
  active_beams = []
159
- added_count = 0
160
  for seq, score in potential_next_beams:
 
 
 
161
  is_finished = seq[0, -1].item() == eos_idx
162
  if is_finished:
163
- finished_beams.append((seq, score))
164
- elif added_count < beam_width:
165
- active_beams.append((seq, score))
166
- added_count += 1
167
- elif added_count >= beam_width:
168
- break
169
-
 
 
 
 
 
 
 
 
 
170
  finished_beams.extend(active_beams)
171
 
172
  # Apply length penalty and sort
173
- # Handle potential division by zero if sequence length is 1 (or 0?)
174
- def get_score(beam_tuple):
175
  seq, score = beam_tuple
176
  seq_len = seq.shape[1]
177
- if length_penalty == 0.0 or seq_len <= 1:
 
178
  return score.item()
179
  else:
180
- # Ensure seq_len is float for pow
181
- return score.item() / (float(seq_len) ** length_penalty)
 
 
 
182
 
183
- finished_beams.sort(key=get_score, reverse=True) # Higher score is better
184
 
 
185
  top_sequences = [
186
- seq[:, 1:] for seq, score in finished_beams[:n_best]
187
- ] # seq shape [1, len] -> [1, len-1]
188
  return top_sequences
189
 
190
  except RuntimeError as e:
191
- logging.error(f"Runtime error during beam search decode: {e}")
192
- if "CUDA out of memory" in str(e):
193
  gc.collect()
194
  torch.cuda.empty_cache()
195
- return [] # Return empty list on error
196
  except Exception as e:
197
  logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
198
  return []
199
 
200
-
201
- # --- Translation Function (Moved from test_ckpt.py) ---
202
-
203
-
204
  def translate(
205
  model: pl.LightningModule,
206
  src_sentence: str,
@@ -217,38 +240,43 @@ def translate(
217
  ) -> list[str]:
218
  """
219
  Translates a single SMILES string using beam search.
220
- (Code copied and pasted from test_ckpt.py)
221
  """
222
- model.eval() # Ensure model is in eval mode
223
  translations = []
 
224
 
225
  # --- Tokenize Source ---
226
  try:
 
 
227
  src_encoded = smiles_tokenizer.encode(src_sentence)
228
  if not src_encoded or not src_encoded.ids:
229
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
230
  return ["[Encoding Error]"] * n_best
231
- src_ids = src_encoded.ids[:max_len] # Truncate source
232
- if not src_ids:
233
- logging.warning(f"Source empty after truncation: {src_sentence}")
234
- return ["[Encoding Error - Empty Src]"] * n_best
235
  except Exception as e:
236
- logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
237
  return ["[Encoding Error]"] * n_best
238
 
239
  # --- Prepare Input Tensor and Mask ---
240
  src = (
241
  torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
242
- ) # [1, src_len]
243
- src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
 
244
 
245
  # --- Perform Beam Search Decoding ---
246
- # Calls the beam_search_decode function defined above in this file
 
 
247
  tgt_tokens_list = beam_search_decode(
248
  model=model,
249
  src=src,
250
  src_padding_mask=src_padding_mask,
251
- max_len=max_len,
252
  sos_idx=sos_idx,
253
  eos_idx=eos_idx,
254
  pad_idx=pad_idx,
@@ -256,25 +284,28 @@ def translate(
256
  beam_width=beam_width,
257
  n_best=n_best,
258
  length_penalty=length_penalty,
259
- ) # Returns list of tensors
260
 
261
  # --- Decode Generated Tokens ---
262
  if not tgt_tokens_list:
263
  logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
 
264
  return ["[Decoding Error - Empty Output]"] * n_best
265
 
266
- for tgt_tokens_tensor in tgt_tokens_list:
267
- if tgt_tokens_tensor.numel() > 0:
268
  tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
269
  try:
 
270
  translation = iupac_tokenizer.decode(
271
  tgt_tokens, skip_special_tokens=True
272
  )
273
  translations.append(translation)
274
  except Exception as e:
275
- logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
276
  translations.append("[Decoding Error]")
277
  else:
 
278
  translations.append("[Decoding Error - Empty Tensor]")
279
 
280
  # Pad with error messages if fewer than n_best results were generated
@@ -283,39 +314,53 @@ def translate(
283
 
284
  return translations
285
 
286
-
287
- # --- Model/Tokenizer Loading Function (Unchanged) ---
288
  def load_model_and_tokenizers():
289
  """Loads tokenizers, config, and model from Hugging Face Hub."""
290
  global model, smiles_tokenizer, iupac_tokenizer, device, config
291
- if model is not None: # Already loaded
292
  logging.info("Model and tokenizers already loaded.")
293
  return
294
 
295
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
296
  try:
297
- device = torch.device("cpu")
298
- logging.info(f"Using device: {device}")
 
 
 
 
 
 
 
 
 
 
299
 
300
  # Download files from HF Hub
301
  logging.info("Downloading files from Hugging Face Hub...")
302
  try:
 
 
 
 
 
303
  checkpoint_path = hf_hub_download(
304
- repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME
305
  )
306
  smiles_tokenizer_path = hf_hub_download(
307
- repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME
308
  )
309
  iupac_tokenizer_path = hf_hub_download(
310
- repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME
311
  )
312
  config_path = hf_hub_download(
313
- repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME
314
  )
315
  logging.info("Files downloaded successfully.")
316
  except Exception as e:
317
  logging.error(
318
- f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
319
  exc_info=True,
320
  )
321
  raise gr.Error(
@@ -329,41 +374,81 @@ def load_model_and_tokenizers():
329
  config = json.load(f)
330
  logging.info("Configuration loaded.")
331
  # --- Validate essential config keys ---
 
 
 
332
  required_keys = [
333
- "src_vocab_size",
334
- "tgt_vocab_size",
 
 
335
  "emb_size",
336
  "nhead",
337
  "ffn_hid_dim",
338
  "num_encoder_layers",
339
  "num_decoder_layers",
340
  "dropout",
341
- "max_len",
342
- "bos_token_id",
343
- "eos_token_id",
344
- "pad_token_id",
 
 
345
  ]
346
- missing_keys = [key for key in required_keys if key not in config]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  if missing_keys:
348
- raise ValueError(
349
- f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}"
350
- )
351
- # --- End Validation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  except FileNotFoundError:
353
- logging.error(
354
- f"Config file not found locally after download attempt: {config_path}"
355
- )
356
- raise gr.Error(
357
- f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
358
- )
359
  except json.JSONDecodeError as e:
360
  logging.error(f"Error decoding JSON from config file {config_path}: {e}")
361
- raise gr.Error(
362
- f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
363
- )
364
- except ValueError as e:
365
  logging.error(f"Config validation error: {e}")
366
  raise gr.Error(f"Config Error: {e}")
 
 
 
 
367
 
368
  # Load tokenizers
369
  logging.info("Loading tokenizers...")
@@ -371,25 +456,33 @@ def load_model_and_tokenizers():
371
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
372
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
373
  logging.info("Tokenizers loaded.")
374
- # --- Validate Tokenizer Special Tokens ---
375
- # Add more robust checks if necessary
376
- if (
377
- smiles_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
378
- or smiles_tokenizer.token_to_id("<unk>") is None
379
- ):
380
- logging.warning(
381
- "SMILES tokenizer special tokens might not match config or are missing."
382
- )
383
- if (
384
- iupac_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
385
- or iupac_tokenizer.token_to_id("<sos>") != config["bos_token_id"]
386
- or iupac_tokenizer.token_to_id("<eos>") != config["eos_token_id"]
387
- or iupac_tokenizer.token_to_id("<unk>") is None
388
- ):
389
- logging.warning(
390
- "IUPAC tokenizer special tokens might not match config or are missing."
391
- )
392
- # --- End Validation ---
 
 
 
 
 
 
 
 
393
  except Exception as e:
394
  logging.error(
395
  f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
@@ -402,115 +495,139 @@ def load_model_and_tokenizers():
402
  # Load model
403
  logging.info("Loading model from checkpoint...")
404
  try:
 
 
405
  model = SmilesIupacLitModule.load_from_checkpoint(
406
  checkpoint_path,
407
- src_vocab_size=config["src_vocab_size"],
408
- tgt_vocab_size=config["tgt_vocab_size"],
409
- map_location=device,
410
- hparams_dict=config,
411
- strict=False,
412
- device="cpu",
 
 
413
  )
 
 
414
  model.to(device)
415
  model.eval()
416
- model.freeze()
417
  logging.info(
418
- "Model loaded successfully, set to eval mode, frozen, and moved to device."
419
  )
420
 
421
  except FileNotFoundError:
422
- logging.error(
423
- f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
424
- )
425
- raise gr.Error(
426
- f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
427
- )
428
  except Exception as e:
429
  logging.error(
430
- f"Error loading model from checkpoint {checkpoint_path}: {e}",
431
- exc_info=True,
432
  )
433
- if "memory" in str(e).lower():
 
 
 
 
 
 
 
434
  gc.collect()
435
- if device == torch.device("cuda"):
436
- torch.cuda.empty_cache()
437
- raise gr.Error(
438
- f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
439
- )
440
 
441
- except gr.Error:
442
  raise
443
- except Exception as e:
444
- logging.error(
445
- f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
446
- )
447
- raise gr.Error(
448
- f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
449
- )
450
 
451
 
452
- # --- Inference Function for Gradio (Unchanged, calls local translate) ---
453
- def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
454
  """
455
  Performs SMILES to IUPAC translation using the loaded model and beam search.
 
456
  """
457
  global model, smiles_tokenizer, iupac_tokenizer, device, config
458
 
459
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
460
- error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
461
- # Ensure n_best is int for range, default to 1 if conversion fails early
462
- try:
463
- n_best_int = int(n_best)
464
- except:
465
- n_best_int = 1
466
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
467
 
468
  if not smiles_string or not smiles_string.strip():
469
  error_msg = "Error: Please enter a valid SMILES string."
470
- try:
471
- n_best_int = int(n_best)
472
- except:
473
- n_best_int = 1
474
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
475
 
476
  smiles_input = smiles_string.strip()
 
 
477
  try:
478
- beam_width = int(beam_width)
479
- n_best = int(n_best)
480
- length_penalty = float(length_penalty)
 
 
 
 
 
481
  except ValueError as e:
482
- error_msg = f"Error: Invalid input parameter type ({e})."
483
- return f"1. {error_msg}" # Cannot determine n_best here
 
 
484
 
485
  logging.info(
486
- f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})"
487
  )
488
 
489
  try:
490
- # Calls the translate function defined *above in this file*
 
 
 
 
 
 
491
  predicted_names = translate(
492
  model=model,
493
  src_sentence=smiles_input,
494
  smiles_tokenizer=smiles_tokenizer,
495
  iupac_tokenizer=iupac_tokenizer,
496
  device=device,
497
- max_len=config["max_len"],
498
- sos_idx=config["bos_token_id"],
499
- eos_idx=config["eos_token_id"],
500
- pad_idx=config["pad_token_id"],
501
  beam_width=beam_width,
502
  n_best=n_best,
503
  length_penalty=length_penalty,
504
  )
505
  logging.info(f"Predictions returned: {predicted_names}")
506
 
 
507
  if not predicted_names:
508
- output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
509
  else:
510
- output_text = f"Input SMILES: {smiles_input}\n\nTop {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
 
 
 
511
  output_text += "\n".join(
512
- [f"{i + 1}. {name}" for i, name in enumerate(predicted_names)]
513
  )
 
 
 
 
514
 
515
  return output_text
516
 
@@ -519,9 +636,9 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
519
  error_msg = f"Runtime Error during translation: {e}"
520
  if "memory" in str(e).lower():
521
  gc.collect()
522
- if device == torch.device("cuda"):
523
- torch.cuda.empty_cache()
524
- error_msg += " (Potential OOM)"
525
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
526
 
527
  except Exception as e:
@@ -530,85 +647,94 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
530
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
531
 
532
 
533
- # --- Load Model on App Start (Unchanged) ---
 
 
534
  try:
535
  load_model_and_tokenizers()
536
- except gr.Error:
537
- pass # Error already raised for Gradio UI
 
 
 
538
  except Exception as e:
539
- logging.error(
540
- f"Critical error during initial model loading sequence: {e}", exc_info=True
541
- )
542
- gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
543
 
544
 
545
- # --- Create Gradio Interface (Unchanged) ---
546
  title = "SMILES to IUPAC Name Translator"
547
  description = f"""
548
- Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
549
- Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
550
- Adjust beam search parameters below. Higher beam width explores more possibilities but is slower. Length penalty influences the preference for shorter/longer names.
551
  """
552
 
 
553
  examples = [
554
- ["CCO", 5, 3, 0.6],
555
- ["C1=CC=CC=C1", 5, 3, 0.6],
556
- ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
557
- ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
558
- [
559
- "CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C",
560
- 5,
561
- 1,
562
- 0.6,
563
- ], # Complex example
564
- ["INVALID_SMILES", 5, 1, 0.6],
565
  ]
566
 
 
567
  smiles_input = gr.Textbox(
568
  label="SMILES String",
569
  placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
570
  lines=1,
571
  )
 
572
  beam_width_input = gr.Slider(
573
- minimum=1,
574
- maximum=10,
575
- value=5,
576
- step=1,
577
- label="Beam Width (k)",
578
- info="Number of sequences to keep at each decoding step (higher = more exploration, slower).",
579
  )
580
  n_best_input = gr.Slider(
581
- minimum=1,
582
- maximum=10,
583
- value=3,
584
- step=1,
585
- label="Number of Results (n_best)",
586
- info="How many top-scoring sequences to return (must be <= Beam Width).",
587
  )
588
  length_penalty_input = gr.Slider(
589
- minimum=0.0,
590
- maximum=2.0,
591
- value=0.6,
592
- step=0.1,
593
- label="Length Penalty (alpha)",
594
- info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty.",
595
  )
596
  output_text = gr.Textbox(
597
  label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
598
  )
599
 
 
600
  iface = gr.Interface(
601
- fn=predict_iupac,
602
- inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
603
- outputs=output_text,
 
 
 
 
 
604
  title=title,
605
  description=description,
606
- examples=examples,
607
- allow_flagging="never",
608
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
609
- article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results.",
 
 
 
 
 
 
610
  )
611
 
612
- # --- Launch the App (Unchanged) ---
613
  if __name__ == "__main__":
614
- iface.launch(share=True)
 
 
 
 
 
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ import torch.nn.functional as F # Needed for beam search log_softmax
5
+ import pytorch_lightning as pl # Needed for LightningModule and loading
6
  import os
7
  import json
8
  import logging
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
+ import gc # For garbage collection on potential OOM
12
+ import math # Potentially needed by imported classes
13
 
14
  # --- Configuration ---
15
+ # Ensure these match the files uploaded to your Hugging Face Hub repository
16
+ MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
17
+ CHECKPOINT_FILENAME = "last.ckpt" # Or "best_model.ckpt" or whatever you uploaded
18
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
19
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
20
+ CONFIG_FILENAME = "config.json" # Assumes you saved hparams to config.json during/after training
21
  # --- End Configuration ---
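
For reference, a minimal sketch of the config.json shape that the loading code further down expects; the key names follow the required_keys/remapping logic in load_model_and_tokenizers, and every value below is a placeholder rather than the trained model's actual hyperparameters.

# Sketch of the expected config.json contents, written as a Python dict for readability.
# All values are illustrative placeholders.
example_config = {
    "actual_src_vocab_size": 30000,
    "actual_tgt_vocab_size": 30000,
    "emb_size": 512,
    "nhead": 8,
    "ffn_hid_dim": 2048,
    "num_encoder_layers": 6,
    "num_decoder_layers": 6,
    "dropout": 0.1,
    "max_len": 256,
    "pad_token_id": 0,
    "bos_token_id": 1,
    "eos_token_id": 2,
}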
22
 
23
  # --- Logging ---
 
25
  level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
26
  )
27
 
28
+ # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
29
  try:
30
+ # We need the LightningModule definition and the mask function
31
+ # Ensure enhanced_trainer.py is present in the root of your HF Repo
32
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
 
33
  logging.info("Successfully imported from enhanced_trainer.py.")
34
 
35
+ # REMOVED: Redundant import from test_ckpt as functions are defined below
36
+ # from test_ckpt import beam_search_decode, translate
37
 
38
  except ImportError as e:
39
  logging.error(
40
+ f"Failed to import helper code from enhanced_trainer.py: {e}. "
41
+ f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
42
  )
43
+ # Raise error visible in Gradio UI and logs
44
+ raise gr.Error(
45
  f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
46
  )
 
47
  except Exception as e:
48
  logging.error(
49
  f"An unexpected error occurred during helper code import: {e}", exc_info=True
50
  )
51
+ raise gr.Error(
52
  f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
53
  )
 
54
 
55
  # --- Global Variables (Load Model Once) ---
56
+ model: pl.LightningModule | None = None
57
  smiles_tokenizer: Tokenizer | None = None
58
  iupac_tokenizer: Tokenizer | None = None
59
  device: torch.device | None = None
60
  config: dict | None = None
61
 
62
+ # --- Beam Search Decoding Logic (Locally defined) ---
 
 
63
  def beam_search_decode(
64
  model: pl.LightningModule,
65
  src: torch.Tensor,
 
67
  max_len: int,
68
  sos_idx: int,
69
  eos_idx: int,
70
+ pad_idx: int,
71
  device: torch.device,
72
  beam_width: int = 5,
73
+ n_best: int = 5,
74
+ length_penalty: float = 0.6,
75
  ) -> list[torch.Tensor]:
76
  """
77
  Performs beam search decoding using the LightningModule's model.
78
+ (Ensures this code is self-contained within app.py or correctly imported)
79
  """
80
+ model.eval() # Ensure model is in evaluation mode
81
+ transformer_model = model.model # Access the underlying Seq2SeqTransformer
82
+ n_best = min(n_best, beam_width)
 
83
 
84
  try:
85
  with torch.no_grad():
86
  # --- Encode Source ---
87
  memory = transformer_model.encode(
88
  src, src_padding_mask
89
+ ) # [1, src_len, emb_size]
90
  memory = memory.to(device)
91
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
 
92
 
93
  # --- Initialize Beams ---
94
  initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
95
  sos_idx
96
+ ) # [1, 1]
97
+ initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
98
  active_beams = [(initial_beam_seq, initial_beam_score)]
99
  finished_beams = []
100
 
 
105
 
106
  potential_next_beams = []
107
  for current_seq, current_score in active_beams:
108
+ # Check if the beam already ended
109
  if current_seq[0, -1].item() == eos_idx:
110
+ # If already finished, add directly to finished beams and skip expansion
111
  finished_beams.append((current_seq, current_score))
112
  continue
113
 
114
+ # Prepare inputs for the decoder
115
+ tgt_input = current_seq # [1, current_len]
116
  tgt_seq_len = tgt_input.shape[1]
117
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
118
  device
119
+ ) # [curr_len, curr_len]
120
+ # No padding in target during generation yet
121
  tgt_padding_mask = torch.zeros(
122
  tgt_input.shape, dtype=torch.bool, device=device
123
+ ) # [1, curr_len]
124
 
125
+ # Decode one step
126
  decoder_output = transformer_model.decode(
127
  tgt=tgt_input,
128
  memory=memory,
129
  tgt_mask=tgt_mask,
130
  tgt_padding_mask=tgt_padding_mask,
131
  memory_key_padding_mask=memory_key_padding_mask,
132
+ ) # [1, curr_len, emb_size]
133
 
134
+ # Get logits for the *next* token prediction
135
  next_token_logits = transformer_model.generator(
136
+ decoder_output[:, -1, :] # Use output corresponding to the last input token
137
+ ) # [1, tgt_vocab_size]
138
+
139
+ # Calculate log probabilities and add current beam score
140
  log_probs = F.log_softmax(
141
  next_token_logits, dim=-1
142
+ ) # [1, tgt_vocab_size]
143
+ combined_scores = log_probs + current_score # Add score of the current path
144
 
145
+ # Find top k candidates for the *next* step
146
  topk_log_probs, topk_indices = torch.topk(
147
+ combined_scores, beam_width, dim=-1
148
+ ) # [1, beam_width], [1, beam_width]
149
 
150
+ # Expand potential beams
151
  for i in range(beam_width):
152
  next_token_id = topk_indices[0, i].item()
153
+ # Score is the cumulative log probability of the new sequence
154
+ next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
 
155
  next_token_tensor = torch.tensor(
156
  [[next_token_id]], dtype=torch.long, device=device
157
+ ) # [1, 1]
158
  new_seq = torch.cat(
159
  [current_seq, next_token_tensor], dim=1
160
+ ) # [1, current_len + 1]
161
  potential_next_beams.append((new_seq, next_score))
162
 
163
+ # --- Prune Beams ---
164
+ # Sort all potential next beams by score
165
  potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
166
 
167
+ # Select the top `beam_width` beams for the next iteration
168
  active_beams = []
169
+ temp_finished_beams = [] # Collect beams finished in *this* step
170
  for seq, score in potential_next_beams:
171
+ if len(active_beams) >= beam_width and len(temp_finished_beams) >= beam_width:
172
+ break # Optimization: Stop if we have enough active and finished candidates
173
+
174
  is_finished = seq[0, -1].item() == eos_idx
175
  if is_finished:
176
+ # Add to temporary finished list for this step
177
+ if len(temp_finished_beams) < beam_width:
178
+ temp_finished_beams.append((seq, score))
179
+ elif len(active_beams) < beam_width:
180
+ # Add to active beams for next step
181
+ active_beams.append((seq, score))
182
+
183
+ # Add the newly finished beams to the main finished list
184
+ finished_beams.extend(temp_finished_beams)
185
+ # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
186
+ finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
187
+ finished_beams = finished_beams[:beam_width * 2] # Keep a reasonable number
188
+
189
+
190
+ # --- Final Selection ---
191
+ # Add any remaining active beams (which didn't finish) to the finished list
192
  finished_beams.extend(active_beams)
193
 
194
  # Apply length penalty and sort
195
+ def get_score_with_penalty(beam_tuple):
 
196
  seq, score = beam_tuple
197
  seq_len = seq.shape[1]
198
+ # Avoid division by zero or negative exponent issues
199
+ if length_penalty <= 0.0 or seq_len <= 1:
200
  return score.item()
201
  else:
202
+ # Length penalty calculation
203
+ penalty = ((5.0 + float(seq_len)) / 6.0) ** length_penalty # Common formula
204
+ return score.item() / penalty
205
+ # Alternative simpler penalty:
206
+ # return score.item() / (float(seq_len) ** length_penalty)
207
 
208
+ finished_beams.sort(key=get_score_with_penalty, reverse=True) # Higher score is better
209
 
210
+ # Return the top n_best sequences (excluding the initial SOS token)
211
  top_sequences = [
212
+ seq[:, 1:] for seq, score in finished_beams[:n_best] if seq.shape[1] > 1 # Ensure seq not just SOS
213
+ ] # seq shape [1, len] -> [1, len-1]
214
  return top_sequences
215
 
216
  except RuntimeError as e:
217
+ logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
218
+ if "CUDA out of memory" in str(e) and device.type == 'cuda':
219
  gc.collect()
220
  torch.cuda.empty_cache()
221
+ return [] # Return empty list on error
222
  except Exception as e:
223
  logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
224
  return []
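
The get_score_with_penalty helper above applies a GNMT-style length normalization. A small standalone sketch with made-up scores shows how it can reorder finished beams:

def normalized_score(log_prob_sum: float, seq_len: int, alpha: float = 0.6) -> float:
    # Mirrors get_score_with_penalty: divide the summed log-probability by ((5 + len) / 6) ** alpha.
    if alpha <= 0.0 or seq_len <= 1:
        return log_prob_sum
    return log_prob_sum / (((5.0 + seq_len) / 6.0) ** alpha)

# A longer hypothesis with a worse raw score can still rank first after normalization:
print(normalized_score(-6.0, seq_len=10))  # ~ -3.46
print(normalized_score(-5.5, seq_len=5))   # ~ -4.05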
225
 
226
+ # --- Translation Function (Locally defined) ---
 
 
 
227
  def translate(
228
  model: pl.LightningModule,
229
  src_sentence: str,
 
240
  ) -> list[str]:
241
  """
242
  Translates a single SMILES string using beam search.
243
+ (Ensures this code is self-contained within app.py or correctly imported)
244
  """
245
+ model.eval() # Ensure model is in eval mode
246
  translations = []
247
+ n_best = min(n_best, beam_width) # Can't return more than beam width
248
 
249
  # --- Tokenize Source ---
250
  try:
251
+ # Ensure tokenizer has truncation/padding configured if needed, or handle manually
252
+ smiles_tokenizer.enable_truncation(max_length=max_len)
253
  src_encoded = smiles_tokenizer.encode(src_sentence)
254
  if not src_encoded or not src_encoded.ids:
255
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
256
  return ["[Encoding Error]"] * n_best
257
+ # Use the truncated IDs directly
258
+ src_ids = src_encoded.ids
259
+ # Note: max_len here applies to source *tokenizer*, generation length is separate
 
260
  except Exception as e:
261
+ logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
262
  return ["[Encoding Error]"] * n_best
263
 
264
  # --- Prepare Input Tensor and Mask ---
265
  src = (
266
  torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
267
+ ) # [1, src_len]
268
+ # Create padding mask (True where it's a pad token, should be all False here)
269
+ src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
270
 
271
  # --- Perform Beam Search Decoding ---
272
+ # Calls the beam_search_decode function defined *above in this file*
273
+ # Note: max_len for generation should come from config if it dictates output length
274
+ generation_max_len = config.get("max_len", 256) # Use config max_len for output limit
275
  tgt_tokens_list = beam_search_decode(
276
  model=model,
277
  src=src,
278
  src_padding_mask=src_padding_mask,
279
+ max_len=generation_max_len, # Use generation limit
280
  sos_idx=sos_idx,
281
  eos_idx=eos_idx,
282
  pad_idx=pad_idx,
 
284
  beam_width=beam_width,
285
  n_best=n_best,
286
  length_penalty=length_penalty,
287
+ ) # Returns list of tensors
288
 
289
  # --- Decode Generated Tokens ---
290
  if not tgt_tokens_list:
291
  logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
292
+ # Provide n_best error messages
293
  return ["[Decoding Error - Empty Output]"] * n_best
294
 
295
+ for i, tgt_tokens_tensor in enumerate(tgt_tokens_list):
296
+ if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0:
297
  tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
298
  try:
299
+ # Decode using the target tokenizer, skipping special tokens
300
  translation = iupac_tokenizer.decode(
301
  tgt_tokens, skip_special_tokens=True
302
  )
303
  translations.append(translation)
304
  except Exception as e:
305
+ logging.error(f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}", exc_info=True)
306
  translations.append("[Decoding Error]")
307
  else:
308
+ logging.warning(f"Beam {i} result was empty or None for SMILES: {src_sentence}")
309
  translations.append("[Decoding Error - Empty Tensor]")
310
 
311
  # Pad with error messages if fewer than n_best results were generated
 
314
 
315
  return translations
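
Once load_model_and_tokenizers() below has populated the globals, translate() is exercised essentially as in predict_iupac; a hedged usage sketch (the SMILES string and search parameters are just examples):

# Hypothetical call, assuming model, tokenizers, device and config were loaded successfully
names = translate(
    model=model,
    src_sentence="CC(=O)Oc1ccccc1C(=O)O",  # aspirin, as in the examples list further down
    smiles_tokenizer=smiles_tokenizer,
    iupac_tokenizer=iupac_tokenizer,
    device=device,
    max_len=config["max_len"],
    sos_idx=config["bos_token_id"],
    eos_idx=config["eos_token_id"],
    pad_idx=config["pad_token_id"],
    beam_width=5,
    n_best=3,
    length_penalty=0.6,
)
for rank, name in enumerate(names, start=1):
    print(f"{rank}. {name}")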
316
 
317
+ # --- Model/Tokenizer Loading Function ---
 
318
  def load_model_and_tokenizers():
319
  """Loads tokenizers, config, and model from Hugging Face Hub."""
320
  global model, smiles_tokenizer, iupac_tokenizer, device, config
321
+ if model is not None: # Already loaded
322
  logging.info("Model and tokenizers already loaded.")
323
  return
324
 
325
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
326
  try:
327
+ # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
328
+ # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
329
+ if torch.cuda.is_available():
330
+ logging.warning("CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended.")
331
+ device = torch.device("cpu")
332
+ # Uncomment below and comment above line to try using GPU if available
333
+ # device = torch.device("cuda")
334
+ # logging.info("CUDA available, using GPU.")
335
+ else:
336
+ device = torch.device("cpu")
337
+ logging.info("CUDA not available, using CPU.")
338
+
339
 
340
  # Download files from HF Hub
341
  logging.info("Downloading files from Hugging Face Hub...")
342
  try:
343
+ # Use cache directory for Spaces persistence if possible
344
+ cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache") # Gradio sets cache dir
345
+ os.makedirs(cache_dir, exist_ok=True)
346
+ logging.info(f"Using cache directory: {cache_dir}")
347
+
348
  checkpoint_path = hf_hub_download(
349
+ repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
350
  )
351
  smiles_tokenizer_path = hf_hub_download(
352
+ repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir
353
  )
354
  iupac_tokenizer_path = hf_hub_download(
355
+ repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir
356
  )
357
  config_path = hf_hub_download(
358
+ repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
359
  )
360
  logging.info("Files downloaded successfully.")
361
  except Exception as e:
362
  logging.error(
363
+ f"Failed to download files from {MODEL_REPO_ID}. Check filenames ({CHECKPOINT_FILENAME}, {SMILES_TOKENIZER_FILENAME}, etc.) and repo status. Error: {e}",
364
  exc_info=True,
365
  )
366
  raise gr.Error(
 
374
  config = json.load(f)
375
  logging.info("Configuration loaded.")
376
  # --- Validate essential config keys ---
377
+ # Use hparams logged during training if available, map them carefully
378
+ # These keys are based on SmilesIupacLitModule and Seq2SeqTransformer init args
379
+ # Mappings might be needed if keys in config.json differ from these exact names
380
  required_keys = [
381
+ # Need vocab sizes used during *training* for loading
382
+ "actual_src_vocab_size", # Assuming this was saved in hparams
383
+ "actual_tgt_vocab_size", # Assuming this was saved in hparams
384
+ # Model architecture params
385
  "emb_size",
386
  "nhead",
387
  "ffn_hid_dim",
388
  "num_encoder_layers",
389
  "num_decoder_layers",
390
  "dropout",
391
+ "max_len", # Needed for generation limit and tokenizer setting
392
+ # Special token IDs needed for generation
393
+ # Assuming standard names, adjust if your config uses different keys
394
+ "pad_token_id", # Often 0
395
+ "bos_token_id", # Often 1 (used as SOS)
396
+ "eos_token_id", # Often 2
397
  ]
398
+ # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
399
+ config_key_mapping = {
400
+ "actual_src_vocab_size": config.get("actual_src_vocab_size", config.get("src_vocab_size")),
401
+ "actual_tgt_vocab_size": config.get("actual_tgt_vocab_size", config.get("tgt_vocab_size")),
402
+ "emb_size": config.get("emb_size"),
403
+ "nhead": config.get("nhead"),
404
+ "ffn_hid_dim": config.get("ffn_hid_dim"),
405
+ "num_encoder_layers": config.get("num_encoder_layers"),
406
+ "num_decoder_layers": config.get("num_decoder_layers"),
407
+ "dropout": config.get("dropout"),
408
+ "max_len": config.get("max_len"),
409
+ "pad_token_id": config.get("pad_token_id", PAD_IDX), # Use default if missing? Risky.
410
+ "bos_token_id": config.get("bos_token_id", SOS_IDX), # Use default if missing? Risky.
411
+ "eos_token_id": config.get("eos_token_id", EOS_IDX), # Use default if missing? Risky.
412
+ }
413
+ # Update config with potentially remapped values
414
+ config.update(config_key_mapping)
415
+
416
+ missing_keys = [key for key in required_keys if config.get(key) is None]
417
  if missing_keys:
418
+ # Try to load defaults for token IDs if absolutely necessary, but warn heavily
419
+ defaults_used = []
420
+ if "pad_token_id" in missing_keys and 'PAD_IDX' in globals(): config["pad_token_id"] = PAD_IDX; defaults_used.append("pad_token_id")
421
+ if "bos_token_id" in missing_keys and 'SOS_IDX' in globals(): config["bos_token_id"] = SOS_IDX; defaults_used.append("bos_token_id")
422
+ if "eos_token_id" in missing_keys and 'EOS_IDX' in globals(): config["eos_token_id"] = EOS_IDX; defaults_used.append("eos_token_id")
423
+
424
+ # Re-check missing keys after attempting defaults
425
+ missing_keys = [key for key in required_keys if config.get(key) is None]
426
+ if missing_keys:
427
+ raise ValueError(
428
+ f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
429
+ f"Ensure these were saved in the hyperparameters during training."
430
+ )
431
+ else:
432
+ logging.warning(f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!")
433
+
434
+ # Log the final config values being used
435
+ logging.info(f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
436
+ f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
437
+ f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}")
438
+
439
  except FileNotFoundError:
440
+ logging.error(f"Config file not found locally after download attempt: {config_path}")
441
+ raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
 
 
 
 
442
  except json.JSONDecodeError as e:
443
  logging.error(f"Error decoding JSON from config file {config_path}: {e}")
444
+ raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
445
+ except ValueError as e: # Catch our custom validation error
 
 
446
  logging.error(f"Config validation error: {e}")
447
  raise gr.Error(f"Config Error: {e}")
448
+ except Exception as e: # Catch other potential errors during config processing
449
+ logging.error(f"Unexpected error loading or validating config: {e}", exc_info=True)
450
+ raise gr.Error(f"Config Error: Unexpected error processing config. Check logs. Error: {e}")
451
+
452
 
453
  # Load tokenizers
454
  logging.info("Loading tokenizers...")
 
456
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
457
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
458
  logging.info("Tokenizers loaded.")
459
+
460
+ # --- Validate Tokenizer Special Tokens Against Config ---
461
+ pad_token = "<pad>"
462
+ sos_token = "<sos>"
463
+ eos_token = "<eos>"
464
+ unk_token = "<unk>"
465
+
466
+ issues = []
467
+ if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
468
+ issues.append(f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
469
+ if smiles_tokenizer.token_to_id(unk_token) is None:
470
+ issues.append("SMILES UNK token not found")
471
+
472
+ if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
473
+ issues.append(f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
474
+ if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
475
+ issues.append(f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})")
476
+ if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
477
+ issues.append(f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})")
478
+ if iupac_tokenizer.token_to_id(unk_token) is None:
479
+ issues.append("IUPAC UNK token not found")
480
+
481
+ if issues:
482
+ logging.warning("Tokenizer validation issues detected: " + "; ".join(issues))
483
+ # Decide if this is fatal or just a warning
484
+ # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.") # Make it fatal if IDs must match
485
+
486
  except Exception as e:
487
  logging.error(
488
  f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
 
495
  # Load model
496
  logging.info("Loading model from checkpoint...")
497
  try:
498
+ # Load the LightningModule state dict
499
+ # Use the actual vocab sizes and hparams from the loaded config
500
  model = SmilesIupacLitModule.load_from_checkpoint(
501
  checkpoint_path,
502
+ # Pass necessary __init__ args that might not be in saved hparams automatically
503
+ # Ensure these keys exist in your loaded 'config' dict after validation/mapping
504
+ src_vocab_size=config["actual_src_vocab_size"],
505
+ tgt_vocab_size=config["actual_tgt_vocab_size"],
506
+ hparams_dict=config, # Pass the loaded config as hparams
507
+ map_location=device, # Map model to the chosen device (CPU or CUDA)
508
+ strict=False, # Be less strict about matching keys, useful for PTL versions or minor changes
509
+ # REMOVED invalid argument: device="cpu",
510
  )
511
+
512
+ # Ensure model is on the correct device, in eval mode, and frozen
513
  model.to(device)
514
  model.eval()
515
+ model.freeze() # Disables gradient calculations
516
  logging.info(
517
+ f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
518
  )
519
 
520
  except FileNotFoundError:
521
+ logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
522
+ raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
 
 
 
 
523
  except Exception as e:
524
  logging.error(
525
+ f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True
 
526
  )
527
+ # Check for common errors
528
+ if "size mismatch" in str(e):
529
+ error_detail = (f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
530
+ f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers.")
531
+ logging.error(error_detail)
532
+ raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
533
+ elif "memory" in str(e).lower():
534
+ logging.warning("Potential Out-of-Memory error during model loading.")
535
  gc.collect()
536
+ if device.type == 'cuda': torch.cuda.empty_cache()
537
+ raise gr.Error(f"Model Error: Out of memory loading model. Check Space resources. Error: {e}")
538
+ else:
539
+ raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")
 
540
 
541
+ except gr.Error: # Re-raise Gradio errors to be displayed
542
  raise
543
+ except Exception as e: # Catch any other unexpected errors
544
+ logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
545
+ raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")
 
 
 
 
546
 
547
 
548
+ # --- Inference Function for Gradio ---
549
+ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str):
550
  """
551
  Performs SMILES to IUPAC translation using the loaded model and beam search.
552
+ Takes string inputs from Gradio sliders/inputs and converts them.
553
  """
554
  global model, smiles_tokenizer, iupac_tokenizer, device, config
555
 
556
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
557
+ error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
558
+ logging.error(error_msg)
559
+ # Try to determine n_best for error output formatting
560
+ try: n_best_int = int(n_best_str)
561
+ except: n_best_int = 1
 
562
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
563
 
564
  if not smiles_string or not smiles_string.strip():
565
  error_msg = "Error: Please enter a valid SMILES string."
566
+ try: n_best_int = int(n_best_str)
567
+ except: n_best_int = 1
 
 
568
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
569
 
570
  smiles_input = smiles_string.strip()
571
+
572
+ # --- Safely parse numerical inputs ---
573
  try:
574
+ beam_width = int(beam_width_str)
575
+ n_best = int(n_best_str)
576
+ length_penalty = float(length_penalty_str)
577
+ if beam_width < 1 or n_best < 1 or n_best > beam_width:
578
+ raise ValueError("Beam width and n_best must be >= 1, and n_best <= beam width.")
579
+ if length_penalty < 0:
580
+ logging.warning(f"Length penalty {length_penalty} is negative, using 0.0 instead.")
581
+ length_penalty = 0.0
582
  except ValueError as e:
583
+ error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
584
+ logging.error(error_msg)
585
+ # Cannot determine n_best if its input was invalid, default to 1 error line
586
+ return f"1. {error_msg}"
587
 
588
  logging.info(
589
+ f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty:.2f})"
590
  )
591
 
592
  try:
593
+ # --- Call the core translation logic ---
594
+ # Retrieve necessary IDs from the loaded config
595
+ sos_idx = config['bos_token_id']
596
+ eos_idx = config['eos_token_id']
597
+ pad_idx = config['pad_token_id']
598
+ gen_max_len = config['max_len'] # Max length for generation
599
+
600
  predicted_names = translate(
601
  model=model,
602
  src_sentence=smiles_input,
603
  smiles_tokenizer=smiles_tokenizer,
604
  iupac_tokenizer=iupac_tokenizer,
605
  device=device,
606
+ max_len=gen_max_len, # Pass generation length limit
607
+ sos_idx=sos_idx,
608
+ eos_idx=eos_idx,
609
+ pad_idx=pad_idx,
610
  beam_width=beam_width,
611
  n_best=n_best,
612
  length_penalty=length_penalty,
613
  )
614
  logging.info(f"Predictions returned: {predicted_names}")
615
 
616
+ # --- Format Output ---
617
  if not predicted_names:
618
+ output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated (beam search might have failed)."
619
  else:
620
+ # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
621
+ display_names = predicted_names[:n_best]
622
+ output_text = (f"Input SMILES: {smiles_input}\n\n"
623
+ f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n")
624
  output_text += "\n".join(
625
+ [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
626
  )
627
+ # Add a note if fewer results than requested were generated
628
+ if len(display_names) < n_best:
629
+ output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."
630
+
631
 
632
  return output_text
633
 
 
636
  error_msg = f"Runtime Error during translation: {e}"
637
  if "memory" in str(e).lower():
638
  gc.collect()
639
+ if device.type == 'cuda': torch.cuda.empty_cache()
640
+ error_msg += " (Potential OOM - try reducing beam width or input length)"
641
+ # Return n_best error messages
642
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
643
 
644
  except Exception as e:
 
647
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
648
 
649
 
650
+ # --- Load Model on App Start ---
651
+ # Wrap in try/except to prevent app from crashing completely if loading fails
652
+ # The error should be caught and displayed by Gradio via gr.Error raised in the function.
653
  try:
654
  load_model_and_tokenizers()
655
+ except gr.Error as ge:
656
+ logging.error(f"Gradio Initialization Error: {ge}")
657
+ # Gradio handles displaying gr.Error, but we log it too.
658
+ # We might want to display a placeholder UI or message if loading fails critically.
659
+ pass # Allow Gradio to potentially start with an error message
660
  except Exception as e:
661
+ # Catch any non-Gradio errors during the initial load sequence
662
+ logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
663
+ # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
664
+ # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
665
 
666
 
667
+ # --- Create Gradio Interface ---
668
  title = "SMILES to IUPAC Name Translator"
669
  description = f"""
670
+ Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
671
+ Translation uses beam search decoding. Adjust parameters below.
672
+ **Note:** Model loaded on **{str(device).upper()}**. Performance may vary. Check `config.json` in the repo for model details.
673
  """
674
 
675
+ # Define examples using the input types expected by the interface
676
  examples = [
677
+ ["CCO", 5, 3, 0.6], # Ethanol
678
+ ["C1=CC=CC=C1", 5, 3, 0.6], # Benzene
679
+ ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
680
+ ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
681
+ # Very complex example - might take time or fail on CPU/low memory
682
+ # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7], # Gleevec (Imatinib) - simplified SMILES structure
683
+ ["INVALID_SMILES", 3, 1, 0.6], # Example of invalid input
 
 
 
 
684
  ]
685
 
686
+ # Ensure input components match the `predict_iupac` function signature order and types
687
  smiles_input = gr.Textbox(
688
  label="SMILES String",
689
  placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
690
  lines=1,
691
  )
692
+ # Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
693
  beam_width_input = gr.Slider(
694
+ minimum=1, maximum=10, value=5, step=1, label="Beam Width (k)",
695
+ info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage."
 
 
 
 
696
  )
697
  n_best_input = gr.Slider(
698
+ minimum=1, maximum=10, value=3, step=1, label="Number of Results (n_best)",
699
+ info="How many top sequences to return (must be <= Beam Width)."
 
 
 
 
700
  )
701
  length_penalty_input = gr.Slider(
702
+ minimum=0.0, maximum=2.0, value=0.6, step=0.1, label="Length Penalty (alpha)",
703
+ info="Controls preference for sequence length. >1 favors longer, <1 favors shorter, 0 no penalty."
 
 
 
 
704
  )
705
  output_text = gr.Textbox(
706
  label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
707
  )
708
 
709
+ # Create the interface instance
710
  iface = gr.Interface(
711
+ fn=predict_iupac, # The function to call
712
+ inputs=[ # List of input components
713
+ smiles_input,
714
+ beam_width_input,
715
+ n_best_input,
716
+ length_penalty_input
717
+ ],
718
+ outputs=output_text, # Output component
719
  title=title,
720
  description=description,
721
+ examples=examples, # Examples to populate the interface
722
+ allow_flagging="never", # Disable flagging
723
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"), # Optional theme
724
+ article="""
725
+ **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
726
+ Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
727
+ Beam search parameters (width, penalty) significantly impact results and performance.
728
+ """,
729
+ # Optional: Add live=True for real-time updates as sliders change (can be slow/resource intensive)
730
+ # live=False,
731
  )
732
 
733
+ # --- Launch the App ---
734
  if __name__ == "__main__":
735
+ # Launch the Gradio interface
736
+ # share=True generates a public link (useful for testing outside HF Spaces, but temporary)
737
+ # Set share=False or remove for deployment on Spaces.
738
+ # Use server_name="0.0.0.0" to make it accessible on the network if running locally
739
+ # Use auth=("username", "password") for basic authentication
740
+ iface.launch() # Default launch for Spaces; pass share=True only for a temporary public link when running locally
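
Once the Space is up, the same function can be called programmatically; a minimal sketch using gradio_client (the Space name and api_name are assumptions based on this app, not verified against the deployed Space):

from gradio_client import Client

client = Client("AdrianM0/smiles-to-iupac-translator")  # assumes the Space shares the repo name
result = client.predict(
    "CCO",  # SMILES string
    5,      # beam width
    3,      # n_best
    0.6,    # length penalty
    api_name="/predict",
)
print(result)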