"""Gradio app that fine-tunes a token-classification model with LoRA.

The trained adapter is pushed to ``TARGET_REPO`` (as a pull request) and to a
dedicated variant repository on the Hugging Face Hub.
"""

import os
import re
import threading
import time
import traceback
from typing import Any, Callable, Dict, List, Optional

import gradio as gr
from huggingface_hub import HfApi, create_repo, hf_hub_url

DEFAULT_BASE_MODEL = "dmis-lab/biobert-base-cased-v1.1"
DEFAULT_DATASET = "conll2003"  # stronger baseline default
TARGET_REPO = os.getenv("MEDVLLM_TARGET_REPO", "Junaidi-AI/med-vllm")

# Short medical aliases -> (BigBio dataset repo, NER config name).
# The script-based configs are preferred with datasets<3.0.
_BIGBIO_ALIASES = {
    "bc5cdr": ("bigbio/bc5cdr", "bigbio_ner"),
    "ncbi_disease": ("bigbio/ncbi_disease", "bigbio_ner"),
}

_PARQUET_REVISION = "refs/convert/parquet"


def _slugify(text: str) -> str:
    """Return the last ``/``-separated part of *text* as a lowercase slug.

    Every run of characters other than ``a-z``, ``0-9``, ``.``, ``_`` and ``-``
    becomes a single ``-`` so that specs such as ``wikiann:en`` can be embedded
    in a repository name.
    """
    tail = (text or "").split("/")[-1].lower()
    return re.sub(r"[^a-z0-9._-]+", "-", tail).strip("-.")


def _safe_run_name(raw: Optional[str]) -> str:
    """Turn a user-supplied run name into a single safe folder name.

    Path separators and other unusual characters are replaced, so the result
    can never escape the ``outputs`` directory. An empty result falls back to
    a timestamped name.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "-", (raw or "").strip()).strip("-.")
    return cleaned or f"run-{int(time.time())}"


def _load_bigbio_dataset(
    repo_id: str,
    subset: str,
    trust_dataset_scripts: bool,
    log: Callable[[str], None],
):
    """Load a BigBio subset, via its script first and Parquet files second.

    Raises:
        RuntimeError: if the Parquet fallback finds no train files.
    """
    from datasets import load_dataset

    # 1) Try script loader first (requires datasets<3.0)
    try:
        log(f"Trying BigBio script loader: load_dataset('{repo_id}', '{subset}')")
        return load_dataset(repo_id, subset, trust_remote_code=trust_dataset_scripts)
    except Exception as e_script:  # deliberate best effort: fall through to Parquet
        log(f"Script loader failed: {e_script}")

    # 2) Fallback to Parquet discovery via HTTPS
    log("Falling back to Parquet discovery via refs/convert/parquet")
    files = HfApi().list_repo_files(
        repo_id=repo_id, repo_type="dataset", revision=_PARQUET_REVISION
    )

    def split_files(split: str) -> List[str]:
        # Accept both sharded files ("<subset>/<split>-*.parquet") and
        # per-split directories ("<subset>/<split>/*.parquet").
        prefixes = (f"{subset}/{split}-", f"{subset}/{split}/")
        return sorted(
            hf_hub_url(
                repo_id=repo_id,
                filename=path,
                repo_type="dataset",
                revision=_PARQUET_REVISION,
            )
            for path in files
            if path.endswith(".parquet") and path.startswith(prefixes)
        )

    train_files = split_files("train")
    val_files = split_files("validation") or split_files("valid") or split_files("dev")
    test_files = split_files("test")
    if not train_files:
        raise RuntimeError(
            "No train parquet files found for BigBio subset; merge PR to pin datasets<3.0 or choose another dataset"
        )

    data_files = {"train": train_files}
    if val_files:
        data_files["validation"] = val_files
    if test_files:
        data_files["test"] = test_files
    return load_dataset("parquet", data_files=data_files)


def _load_ner_dataset(
    ds_spec: str,
    trust_dataset_scripts: bool,
    log: Callable[[str], None],
):
    """Load the dataset described by *ds_spec*.

    *ds_spec* is either a known alias (see ``_BIGBIO_ALIASES``), a
    ``name:config`` pair (e.g. ``wikiann:en``) or a plain dataset name.
    ``trust_dataset_scripts`` is passed through unchanged as
    ``trust_remote_code``; it is never silently forced on.

    Raises:
        ValueError: if *ds_spec* is empty.
        Exception: any loader error is logged and re-raised.
    """
    from datasets import load_dataset

    if not ds_spec:
        raise ValueError("Dataset name must not be empty")

    try:
        alias = _BIGBIO_ALIASES.get(ds_spec.lower())
        if alias is not None:
            repo_id, subset = alias
            return _load_bigbio_dataset(repo_id, subset, trust_dataset_scripts, log)
        if ":" in ds_spec:
            # Support optional config via 'name:config' (e.g., 'wikiann:en')
            ds_name, ds_config = (part.strip() for part in ds_spec.split(":", 1))
            return load_dataset(
                ds_name, ds_config, trust_remote_code=trust_dataset_scripts
            )
        return load_dataset(ds_spec, trust_remote_code=trust_dataset_scripts)
    except Exception as e:
        log(f"Dataset load failed: {e}")
        raise


def _train_ner_lora(
    base_model: str,
    dataset_name: str,
    output_dir: str,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 8,
    learning_rate: float = 2e-5,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    trust_dataset_scripts: bool = True,
    log_cb: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
    """
    Minimal LoRA token-classification trainer.

    Uses conll2003 by default to be robust in Spaces. Extend to medical
    datasets later.

    Args:
        base_model: Model id passed to ``from_pretrained``.
        dataset_name: Alias, ``name:config`` pair or plain dataset name.
        output_dir: Folder for Trainer output and the saved adapter; its
            basename becomes the run name used in the Hub path.
        num_train_epochs: Number of training epochs.
        per_device_train_batch_size: Batch size, also used for evaluation.
        learning_rate: Learning rate handed to ``TrainingArguments``.
        lora_r: LoRA ``r`` value.
        lora_alpha: LoRA ``lora_alpha`` value.
        lora_dropout: LoRA ``lora_dropout`` value.
        trust_dataset_scripts: Forwarded as ``trust_remote_code`` when loading
            the dataset.
        log_cb: Optional callable receiving each log line; ``print`` is used
            when it is not given.

    Returns:
        A dict with ``commit`` (string form of the upload result),
        ``path_in_repo`` and ``metrics`` (the last evaluation metrics, empty
        if no evaluation ran).

    Raises:
        RuntimeError: if the dataset lacks a train split, token/tag columns,
            or label names on the tag column.
    """
    # Avoid importing any local dataset scripts even if present in working dir
    os.environ.setdefault("HF_DATASETS_DISABLE_LOCAL_IMPORTS", "1")

    # Heavy ML dependencies are imported lazily, only when training starts.
    from transformers import (
        AutoTokenizer,
        AutoModelForTokenClassification,
        DataCollatorForTokenClassification,
        TrainingArguments,
        Trainer,
    )
    from transformers.trainer_utils import set_seed
    from seqeval.metrics import f1_score, accuracy_score, precision_score, recall_score
    from peft import LoraConfig, get_peft_model, TaskType

    def log(msg: str):
        if log_cb:
            log_cb(msg)
        else:
            print(msg)

    set_seed(42)

    ds_spec = (dataset_name or "").strip()
    log(f"Loading dataset: {ds_spec}")
    ds = _load_ner_dataset(ds_spec, trust_dataset_scripts, log)

    if "train" not in ds:
        raise RuntimeError("Dataset must have a train split")

    # Detect token and label columns across common schemas
    features = ds["train"].features
    token_candidates = ["tokens", "words"]
    tag_candidates = ["ner_tags", "tags", "labels", "ner_tags_general"]
    token_col = next((c for c in token_candidates if c in features), None)
    tag_col = next((c for c in tag_candidates if c in features), None)
    if not token_col or not tag_col:
        raise RuntimeError(
            "Dataset must provide token and tag columns. Looked for tokens/words and ner_tags/tags/labels."
        )

    # Label names are read from the tag column's inner feature; fail with a
    # clear message when the column does not carry them.
    label_list = getattr(getattr(features[tag_col], "feature", None), "names", None)
    if not label_list:
        raise RuntimeError(
            f"Column '{tag_col}' does not expose label names; expected a sequence of class labels."
        )
    id2label = {i: l for i, l in enumerate(label_list)}
    label2id = {l: i for i, l in enumerate(label_list)}

    log(f"Loading tokenizer/model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForTokenClassification.from_pretrained(
        base_model, num_labels=len(label_list), id2label=id2label, label2id=label2id
    )

    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model = get_peft_model(base, peft_config)

    # Tokenize with alignment
    def tokenize_align(batch):
        # Tokenize the batch once, unpadded: the data collator pads inputs
        # and labels together for each training batch.
        enc = tokenizer(
            batch[token_col], is_split_into_words=True, truncation=True, padding=False
        )
        aligned = []
        for i, tags in enumerate(batch[tag_col]):
            lab = []
            prev_wid = None
            for wid in enc.word_ids(batch_index=i):
                if wid is None:
                    # Special tokens carry no label.
                    lab.append(-100)
                elif wid != prev_wid:
                    # Only label first subword
                    lab.append(tags[wid])
                    prev_wid = wid
                else:
                    lab.append(-100)
            aligned.append(lab)
        enc["labels"] = aligned
        return enc

    log("Tokenizing dataset...")
    tokenized = ds.map(tokenize_align, batched=True)

    data_collator = DataCollatorForTokenClassification(tokenizer)

    metrics_holder: Dict[str, float] = {}

    def compute_metrics(p):
        preds, labels = p
        preds = preds.argmax(-1)
        true_predictions = []
        true_labels = []
        for pred, lab in zip(preds, labels):
            # Positions labelled -100 are ignored.
            kept = [(int(p_i), int(l_i)) for p_i, l_i in zip(pred, lab) if l_i != -100]
            true_predictions.append([id2label[p_i] for p_i, _ in kept])
            true_labels.append([id2label[l_i] for _, l_i in kept])
        out = {
            "f1": f1_score(true_labels, true_predictions),
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "accuracy": accuracy_score(true_labels, true_predictions),
        }
        metrics_holder.update(out)
        return out

    # Pick the first non-empty evaluation split explicitly instead of relying
    # on dataset truthiness; evaluation is disabled when none exists.
    eval_dataset = next(
        (
            tokenized[split]
            for split in ("validation", "dev", "test")
            if split in tokenized and len(tokenized[split]) > 0
        ),
        None,
    )
    if eval_dataset is None:
        log("No validation/dev/test split found; evaluation is disabled")

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_train_batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        eval_strategy="epoch" if eval_dataset is not None else "no",
        save_strategy="epoch",
        logging_steps=10,
        report_to=[],
        fp16=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    log("Starting training...")
    trainer.train()

    log("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Compose commit description with metrics
    desc_lines = [
        f"base_model: {base_model}",
        f"dataset: {dataset_name}",
        f"epochs: {num_train_epochs}",
        f"batch_size: {per_device_train_batch_size}",
        f"learning_rate: {learning_rate}",
        f"lora_r: {lora_r}",
        f"lora_alpha: {lora_alpha}",
        f"lora_dropout: {lora_dropout}",
        "",
        "metrics:",
        *(f"- {k}: {v:.4f}" for k, v in metrics_holder.items()),
    ]
    commit_description = "\n".join(desc_lines)

    # Push to the umbrella repo under checkpoints/
    api = HfApi()
    run_name = os.path.basename(output_dir.rstrip("/"))
    path_in_repo = f"checkpoints/ner-{run_name}"
    log(f"Pushing to {TARGET_REPO}:{path_in_repo}")
    commit = api.upload_folder(
        repo_id=TARGET_REPO,
        repo_type="model",
        folder_path=output_dir,
        path_in_repo=path_in_repo,
        commit_message=f"Add NER LoRA checkpoint ({run_name})",
        commit_description=commit_description,
        create_pr=True,
    )
    log(f"Pushed: {commit}")

    # Also publish to a dedicated med-vllm-* variant repo. This is best
    # effort: a failure here is logged and does not fail the run.
    try:
        base_short = _slugify(base_model)
        # Slugify the spec so 'name:config' cannot put ':' into the repo id.
        ds_short = _slugify(ds_spec)
        variant_name = f"Junaidi-AI/med-vllm-ner-{ds_short}-{base_short}-lora-v1"
        log(f"Ensuring repo exists: {variant_name}")
        try:
            create_repo(repo_id=variant_name, repo_type="model", exist_ok=True, private=False)
        except Exception as e_repo:
            # Still attempt the upload: the repo may already exist.
            log(f"Warning: could not ensure repo exists: {e_repo}")
        commit2 = api.upload_folder(
            repo_id=variant_name,
            repo_type="model",
            folder_path=output_dir,
            path_in_repo=".",
            commit_message=f"Initial LoRA checkpoint from {base_model} on {dataset_name}",
            commit_description=commit_description,
            create_pr=False,
        )
        log(f"Variant published: {commit2}")
    except Exception as e:
        log(f"Warning: failed to publish variant repo: {e}")

    return {"commit": str(commit), "path_in_repo": path_in_repo, "metrics": metrics_holder}


class TrainerThread:
    """Runs one training job at a time on a background thread.

    ``logs``, ``result`` and ``error`` hold the state of the most recent job
    and are reset whenever a new job starts.
    """

    def __init__(self):
        self.thread: Optional[threading.Thread] = None
        self.logs = ""
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        # Guards log appends and makes the "already running?" check atomic
        # with starting the thread.
        self._lock = threading.Lock()

    def _log(self, msg: str):
        """Append one line to the log buffer."""
        with self._lock:
            self.logs += msg + "\n"

    def _run(self, **kwargs):
        """Thread target: run training and record its result or error."""
        try:
            self._log("Initializing training...")
            res = _train_ner_lora(log_cb=self._log, **kwargs)
            self.result = res
            self._log("Training complete")
        except Exception as e:
            # str(e) can be empty; keep the error truthy so the failure is
            # still reported by status().
            self.error = str(e) or type(e).__name__
            self._log(f"ERROR: {e}")
            self._log(traceback.format_exc())

    def start(self, **kwargs):
        """Start a training job with *kwargs* forwarded to ``_train_ner_lora``.

        Raises:
            gr.Error: if a job is still running.
        """
        with self._lock:
            if self.thread and self.thread.is_alive():
                raise gr.Error("Training is already running")
            self.logs = ""
            self.result = None
            self.error = None
            self.thread = threading.Thread(target=self._run, kwargs=kwargs, daemon=True)
            self.thread.start()

    def status(self):
        """Return ``(running, logs, result, error)`` for the latest job."""
        running = self.thread.is_alive() if self.thread else False
        return running, self.logs, self.result, self.error


TRAINER = TrainerThread()


def build_ui():
    """Build and return the Gradio Blocks app for configuring a training run."""
    intro = "\n".join(
        [
            "# Med vLLM Train (LoRA NER)",
            "This Space fine-tunes a token-classification model with LoRA.",
            "",
            f"- Base model default: `{DEFAULT_BASE_MODEL}`",
            f"- Dataset default: `{DEFAULT_DATASET}` (robust demo). Medical sets like `bc5cdr`/`ncbi_disease` may require custom preprocessing.",
            f"- Checkpoints will be pushed to `{TARGET_REPO}` under `checkpoints/` as a PR.",
        ]
    )

    with gr.Blocks(title="Med vLLM Train (LoRA NER)") as demo:
        gr.Markdown(intro)
        with gr.Row():
            base_model = gr.Textbox(value=DEFAULT_BASE_MODEL, label="Base model")
            dataset_name = gr.Dropdown(
                choices=[
                    "bc5cdr",
                    "ncbi_disease",
                    "wikiann:en",
                    "conll2003",
                ],
                value=DEFAULT_DATASET,
                allow_custom_value=True,
                label="Dataset (token classification)",
            )
            trust_scripts = gr.Checkbox(
                value=True,
                label="Trust dataset script (required for many HF datasets incl. conll2003)",
            )
        with gr.Row():
            epochs = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Epochs")
            batch = gr.Slider(minimum=4, maximum=16, step=2, value=8, label="Batch size")
            lr = gr.Textbox(value="2e-5", label="Learning rate")
        with gr.Row():
            lora_r = gr.Slider(minimum=4, maximum=32, step=2, value=8, label="LoRA r")
            lora_alpha = gr.Slider(minimum=8, maximum=64, step=8, value=16, label="LoRA alpha")
            lora_dropout = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, value=0.1, label="LoRA dropout")
        with gr.Row():
            run_name = gr.Textbox(value=f"run-{int(time.time())}", label="Run name (folder)")
        with gr.Row():
            start_btn = gr.Button("Start Training")
            status_btn = gr.Button("Refresh Status")
        logs = gr.Textbox(label="Logs", lines=18)
        result = gr.Textbox(label="Result / Commit info")

        def on_start(bm, ds, ep, bs, lr_s, r, alpha, drop, rn, trust):
            """Validate the form values and start a background training job."""
            try:
                # The run name comes from the UI; sanitize it so it cannot
                # point outside the outputs/ directory.
                out_dir = os.path.join("outputs", _safe_run_name(rn))
                os.makedirs(out_dir, exist_ok=True)
                TRAINER.start(
                    base_model=bm,
                    dataset_name=ds,
                    output_dir=out_dir,
                    num_train_epochs=int(ep),
                    per_device_train_batch_size=int(bs),
                    learning_rate=float(lr_s),
                    lora_r=int(r),
                    lora_alpha=int(alpha),
                    lora_dropout=float(drop),
                    trust_dataset_scripts=bool(trust),
                )
                return "Started"
            except Exception as e:
                return f"ERROR starting: {e}"

        def on_status():
            """Return the status-prefixed logs and the result text."""
            running, log_text, res, err = TRAINER.status()
            info = "Running" if running else ("Error" if err else "Idle/Done")
            res_s = str(res) if res else ""
            return f"[{info}]\n" + log_text, res_s

        start_btn.click(
            on_start,
            inputs=[base_model, dataset_name, epochs, batch, lr, lora_r, lora_alpha, lora_dropout, run_name, trust_scripts],
            outputs=[logs],
        )
        status_btn.click(on_status, outputs=[logs, result])

    return demo


if __name__ == "__main__":
    ui = build_ui()
    ui.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))