---
license: apache-2.0
language:
- en
tags:
- medical
- biomedical
- drug-safety
- adverse-drug-reactions
- pharmacovigilance
- relation-extraction
- dual-encoder
- clinical-nlp
- biolinkbert
- entity-markers
- hard-negative-mining
- focal-loss
datasets:
- ade-benchmark-corpus/ade_corpus_v2
metrics:
- f1
- roc_auc
pipeline_tag: text-classification
model-index:
- name: CRAG-dual-encoder-ade
  results:
  - task:
      type: text-classification
      name: Drug-ADR Relation Extraction
    dataset:
      name: ADE Corpus V2
      type: ade-benchmark-corpus/ade_corpus_v2
      config: Ade_corpus_v2_drug_ade_relation
    metrics:
    - type: f1
      value: 0.975
      name: F1 Score
    - type: roc_auc
      value: 0.991
      name: ROC-AUC
---
# CRAG-dual-encoder-ade

**CRAG: Causal Reasoning for Adversomics Graphs**

This is the enhanced ADE-trained model in the CRAG dual-encoder family. It incorporates multiple architectural and training improvements over the base model, achieving 97.5% F1 and 99.1% ROC-AUC on drug-ADR relation extraction.
## Model Description
CRAG-dual-encoder-ade builds upon the base architecture with several key improvements:
- BioLinkBERT backbone (pre-trained with link prediction for better relation understanding)
- Entity markers ([DRUG]...[/DRUG], [ADR]...[/ADR]) for explicit entity boundary signaling
- Hard negative mining (semantically similar but unrelated pairs)
- Focal loss for handling class imbalance
- Attention pooling instead of [CLS] token
- Layer-wise learning rate decay for stable fine-tuning
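The entity markers are plain strings wrapped around each mention. The helper below (`add_entity_markers` is an illustrative name, not taken from the training code) shows the format; in practice the four markers must also be registered with the tokenizer via `add_special_tokens` so they are never split into subwords:

```python
def add_entity_markers(text: str, entity: str, tag: str) -> str:
    """Wrap the first occurrence of `entity` in [TAG] ... [/TAG] markers.

    Illustrative helper; e.g. tag="DRUG" produces [DRUG] ... [/DRUG].
    """
    return text.replace(entity, f"[{tag}] {entity} [/{tag}]", 1)

marked = add_entity_markers("aspirin caused severe bleeding", "aspirin", "DRUG")
# "[DRUG] aspirin [/DRUG] caused severe bleeding"
```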
## Architecture

```
                  CRAG Dual-Encoder ADE
  ─────────────────────────────────────────────────────

  "[DRUG] aspirin [/DRUG]      "[ADR] bleeding [/ADR]
   caused bleeding..."          from aspirin..."
           │                           │
           ▼                           ▼
   ┌───────────────┐           ┌───────────────┐
   │  BioLinkBERT  │           │  BioLinkBERT  │  (separate)
   │ Drug Encoder  │           │  ADR Encoder  │
   └───────┬───────┘           └───────┬───────┘
           ▼                           ▼
   ┌───────────────┐           ┌───────────────┐
   │   Attention   │           │   Attention   │
   │    Pooling    │           │    Pooling    │
   │   (4 heads)   │           │   (4 heads)   │
   └───────┬───────┘           └───────┬───────┘
           ▼                           ▼
   ┌───────────────┐           ┌───────────────┐
   │  Projection   │           │  Projection   │
   │    768→256    │           │    768→256    │
   │  +LayerNorm   │           │  +LayerNorm   │
   │  +GELU+Drop   │           │  +GELU+Drop   │
   └───────┬───────┘           └───────┬───────┘
           └─────────────┬─────────────┘
                         ▼
                 ┌───────────────┐
                 │   Bilinear    │
                 │    Fusion     │
                 │   (256×256)   │
                 └───────┬───────┘
                         ▼
                 ┌───────────────┐
                 │  Classifier   │
                 │ 512→256→128→1 │
                 │  +LayerNorm   │
                 └───────┬───────┘
                         ▼
                     P(causal)
```
## Key Improvements Over Base Model
| Feature | Base | ADE (this model) |
|---|---|---|
| Base Encoder | PubMedBERT | BioLinkBERT |
| Pooling | [CLS] token | Multi-head Attention |
| Entity Marking | None | [DRUG]/[ADR] tokens |
| Negative Sampling | Random | 50% Hard negatives |
| Loss Function | BCE | Focal Loss (γ=2.0) |
| LR Schedule | Linear warmup | Cosine + layer-wise decay |
| Gradient Accumulation | None | 4 steps (effective batch=64) |
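The layer-wise learning-rate decay in the table can be sketched as follows. `layerwise_lrs` is a hypothetical helper (not from the training script) that computes per-layer rates for a 12-layer encoder under the stated decay factor of 0.9:

```python
def layerwise_lrs(base_lr: float = 2e-5, decay: float = 0.9, num_layers: int = 12):
    """Per-layer learning rates: the top encoder layer keeps base_lr,
    each layer below is scaled by `decay`, so the embeddings (index 0)
    train most conservatively. A sketch of the schedule, not the exact
    training code."""
    return [base_lr * decay ** (num_layers - i) for i in range(num_layers + 1)]

lrs = layerwise_lrs()
# lrs[-1] is the top layer (2e-5); lrs[0] is the embedding layer.
```

In practice each rate would be attached to the corresponding parameter group when constructing the optimizer.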
## Model Specifications

- Base Model: `michiyasunaga/BioLinkBERT-base`
- Hidden Dimension: 768
- Fusion Dimension: 256
- Attention Heads (pooling): 4
- Total Parameters: ~238M
- Special Tokens: `[DRUG]`, `[/DRUG]`, `[ADR]`, `[/ADR]`
## Training Procedure

### Phase 1: Contrastive Pre-training (5 epochs)

```python
CONFIG = {
    "temperature": 0.05,               # Sharper similarity distribution
    "hard_negative_ratio": 0.5,        # 50% hard negatives
    "batch_size": 16,
    "gradient_accumulation_steps": 4,  # Effective batch = 64
}
```
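Assuming an InfoNCE-style in-batch contrastive objective (the exact loss used in training is not shown here), the temperature of 0.05 enters as a divisor on the pairwise similarity matrix:

```python
import torch
import torch.nn.functional as F

def info_nce(drug_emb: torch.Tensor, adr_emb: torch.Tensor,
             temperature: float = 0.05) -> torch.Tensor:
    """Symmetric InfoNCE over in-batch pairs: row i of each matrix is
    a positive pair; all other rows serve as negatives. A sketch under
    the stated temperature, not the exact training objective."""
    drug_emb = F.normalize(drug_emb, dim=-1)
    adr_emb = F.normalize(adr_emb, dim=-1)
    logits = drug_emb @ adr_emb.T / temperature   # (B, B) similarity matrix
    labels = torch.arange(logits.size(0))
    return (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.T, labels)) / 2

# Example: random embeddings for a batch of 8 pairs
loss = info_nce(torch.randn(8, 256), torch.randn(8, 256))
```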
Hard Negative Mining Strategy:
- Same drug, different ADR (tests ADR discrimination)
- Same ADR, different drug (tests drug discrimination)
- Semantically similar but unrelated pairs
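The first two mining strategies above amount to recombining observed pairs. A minimal sketch (function name and sampling details are illustrative assumptions, not the training code):

```python
import random

def mine_hard_negatives(positive_pairs, ratio=0.5, seed=0):
    """Build hard negatives by recombination: same drug with a
    different ADR, or same ADR with a different drug. Only pairs not
    observed as positives are kept."""
    rng = random.Random(seed)
    positives = set(positive_pairs)
    drugs = sorted({d for d, _ in positives})
    adrs = sorted({a for _, a in positives})
    negatives = []
    n_hard = int(len(positive_pairs) * ratio)
    while len(negatives) < n_hard:
        drug, adr = rng.choice(positive_pairs)
        if rng.random() < 0.5:
            candidate = (drug, rng.choice(adrs))   # same drug, different ADR
        else:
            candidate = (rng.choice(drugs), adr)   # same ADR, different drug
        if candidate not in positives and candidate not in negatives:
            negatives.append(candidate)
    return negatives

pairs = [("aspirin", "bleeding"), ("metformin", "lactic acidosis"),
         ("ibuprofen", "gastric ulcer"), ("warfarin", "hemorrhage")]
hard_negs = mine_hard_negatives(pairs)
```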
### Phase 2: Classification Fine-tuning (8 epochs)

```python
CONFIG = {
    "learning_rate": 2e-5,
    "warmup_ratio": 0.1,
    "layerwise_lr_decay": 0.9,  # Lower layers get 0.9× LR
    "focal_gamma": 2.0,         # Focus on hard examples
    "focal_alpha": 0.75,        # Positive class weight
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
}
```
Focal Loss:

FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

where γ = 2.0 down-weights easy examples, focusing learning on hard cases, and α = 0.75 weights the positive class.
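A minimal PyTorch implementation consistent with the configuration above (a sketch, not the exact training code):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: float = 0.75) -> torch.Tensor:
    """Binary focal loss: BCE re-weighted by (1 - p_t)^gamma, with
    alpha on the positive class and (1 - alpha) on the negative."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # model probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

# Example: three pairs, mixed labels
loss = focal_loss(torch.tensor([2.0, -1.0, 0.3]),
                  torch.tensor([1.0, 0.0, 1.0]))
```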
## Training Data

- Dataset: ADE Corpus V2
- Configuration: `Ade_corpus_v2_drug_ade_relation`
- Training Examples: 13,642 (balanced positive/negative with hard mining)
- Validation Examples: 2,047
- Entity Marker Format: `"[DRUG] aspirin [/DRUG] caused [ADR] bleeding [/ADR]"`
## Performance

### Metrics
| Metric | Value |
|---|---|
| F1 Score | 97.5% |
| ROC-AUC | 99.1% |
| Optimal Threshold | 0.55 |
### Training Curves
| Phase | Epochs | Final Loss | Final Metric |
|---|---|---|---|
| Contrastive | 5 | 0.021 | - |
| Classification | 8 | 0.008 | F1: 97.5% |
## Comparison with CRAG Family
| Model | F1 | AUC | Improvement |
|---|---|---|---|
| CRAG-dual-encoder-base | 88.3% | - | Baseline |
| CRAG-dual-encoder-ade | 97.5% | 99.1% | +9.2% F1 |
| CRAG-dual-encoder-mimicause | 98.9% | 99.8% | +10.6% F1 |
## Usage

### Loading the Model

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Define the model architecture
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool()
        pooled, _ = self.attention(query, hidden_states, hidden_states,
                                   key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)

class DualEncoderADE(nn.Module):
    def __init__(self, model_name="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        self.drug_encoder = AutoModel.from_pretrained(model_name)
        self.adr_encoder = AutoModel.from_pretrained(model_name)
        self.drug_pooler = AttentionPooling(768)
        self.adr_pooler = AttentionPooling(768)
        # ... (see full architecture in training script)

# Load tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained("chrisvoncsefalvay/CRAG-dual-encoder-ade")

# Load model weights (download pytorch_model.bin from this repository first)
model = DualEncoderADE()
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```
### Inference Example

```python
def score_drug_adr_pair(model, tokenizer, drug_text, adr_text, drug_entity, adr_entity):
    """Score a drug-ADR pair for causal relationship."""
    # Add entity markers
    drug_context = drug_text.replace(drug_entity, f"[DRUG] {drug_entity} [/DRUG]")
    adr_context = adr_text.replace(adr_entity, f"[ADR] {adr_entity} [/ADR]")

    # Tokenize
    drug_inputs = tokenizer(drug_context, return_tensors="pt", max_length=128,
                            truncation=True, padding="max_length")
    adr_inputs = tokenizer(adr_context, return_tensors="pt", max_length=128,
                           truncation=True, padding="max_length")

    # Get prediction
    with torch.no_grad():
        drug_repr = model.encode_drug(**drug_inputs)
        adr_repr = model.encode_adr(**adr_inputs)
        logit = model.classify(drug_repr, adr_repr)
        prob = torch.sigmoid(logit).item()
    return prob

# Example usage
prob = score_drug_adr_pair(
    model, tokenizer,
    drug_text="Patient was started on metformin 500mg twice daily.",
    adr_text="She developed lactic acidosis requiring ICU admission.",
    drug_entity="metformin",
    adr_entity="lactic acidosis",
)
print(f"Causal probability: {prob:.3f}")
# Output: Causal probability: 0.923
```
## Intended Uses

### Primary Use Cases
- Pharmacovigilance Systems: Automated ADR detection in medical literature
- Drug Safety Databases: Populating causal knowledge graphs
- Clinical Trial Analysis: Mining safety signals from trial reports
- Regulatory Submission Review: Screening documents for ADR mentions
- Post-Market Surveillance: Monitoring real-world drug safety
### Best Practices

- Use entity markers `[DRUG]`/`[/DRUG]` and `[ADR]`/`[/ADR]` for optimal performance
- Apply a threshold of 0.55 for balanced precision/recall
- Validate high-confidence predictions with domain experts
- Consider an ensemble with CRAG-dual-encoder-mimicause for critical applications
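The threshold and ensemble recommendations can be sketched as simple soft voting (helper names and equal default weights are illustrative assumptions):

```python
def classify_pair(prob: float, threshold: float = 0.55) -> bool:
    """Apply the tuned decision threshold from the metrics above."""
    return prob >= threshold

def ensemble_probability(probs, weights=None) -> float:
    """Soft-voting average over multiple CRAG models, e.g. this model
    plus CRAG-dual-encoder-mimicause. Equal weights unless given."""
    weights = weights or [1.0] * len(probs)
    return sum(p * w for p, w in zip(probs, weights)) / sum(weights)

combined = ensemble_probability([0.92, 0.88])  # approximately 0.9
flag = classify_pair(combined)
```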
## Limitations
- English Only: Trained on English biomedical text
- Explicit Mentions Required: Both drug and ADR must appear in the text
- Binary Classification: Does not distinguish causation types (e.g., dose-dependent)
- Training Data Bias: Reflects the drug/ADR distribution in ADE Corpus V2
- Context Window: Maximum 128 tokens per input
## Ethical Considerations
- Not for Direct Clinical Use: Predictions require expert validation
- Bias in Coverage: Common drugs/ADRs better represented than rare ones
- Automation Risks: Over-reliance may miss nuanced relationships
- Transparency: Model confidence should be communicated to end users
## Technical Specifications
| Specification | Value |
|---|---|
| Framework | PyTorch |
| Base Model | BioLinkBERT-base |
| Model Size | ~955 MB |
| Vocabulary Size | 30,522 + 4 special |
| Max Sequence Length | 128 tokens |
| Inference Speed | ~100 pairs/sec (GPU) |
## Citation

```bibtex
@misc{crag-dual-encoder-ade-2024,
  title={CRAG: Causal Reasoning for Adversomics Graphs - Enhanced Dual-Encoder with Hard Negative Mining},
  author={von Csefalvay, Chris},
  year={2024},
  publisher={Hugging Face},
  url={https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-ade}
}
```
## References
- Gurulingappa, H., et al. (2012). Development of a benchmark corpus to support the automatic extraction of drug-related adverse effects from medical case reports. Journal of Biomedical Informatics.
- Yasunaga, M., et al. (2022). LinkBERT: Pretraining Language Models with Document Links. ACL.
- Lin, T.Y., et al. (2017). Focal Loss for Dense Object Detection. ICCV.
## Model Card Authors

Chris von Csefalvay (@chrisvoncsefalvay)

## Model Card Contact

For questions or issues, please open a discussion on this model's repository or contact chris@chrisvoncsefalvay.com.