---
license: apache-2.0
language:
  - en
tags:
  - medical
  - biomedical
  - drug-safety
  - adverse-drug-reactions
  - pharmacovigilance
  - relation-extraction
  - dual-encoder
  - clinical-nlp
  - biolinkbert
  - entity-markers
  - hard-negative-mining
  - focal-loss
datasets:
  - ade-benchmark-corpus/ade_corpus_v2
metrics:
  - f1
  - roc_auc
pipeline_tag: text-classification
model-index:
  - name: CRAG-dual-encoder-ade
    results:
      - task:
          type: text-classification
          name: Drug-ADR Relation Extraction
        dataset:
          name: ADE Corpus V2
          type: ade-benchmark-corpus/ade_corpus_v2
          config: Ade_corpus_v2_drug_ade_relation
        metrics:
          - type: f1
            value: 0.975
            name: F1 Score
          - type: roc_auc
            value: 0.991
            name: ROC-AUC
---

CRAG-dual-encoder-ade

CRAG: Causal Reasoning for Adversomics Graphs

This is the enhanced ADE-trained model in the CRAG dual-encoder family. It incorporates multiple architectural and training improvements over the base model, achieving 97.5% F1 and 99.1% AUC on drug-ADR relation extraction.

Model Description

CRAG-dual-encoder-ade builds upon the base architecture with several key improvements:

  1. BioLinkBERT backbone (pre-trained with link prediction for better relation understanding)
  2. Entity markers ([DRUG]...[/DRUG], [ADR]...[/ADR]) for explicit entity boundary signaling
  3. Hard negative mining (semantically similar but unrelated pairs)
  4. Focal loss for handling class imbalance
  5. Attention pooling instead of [CLS] token
  6. Layer-wise learning rate decay for stable fine-tuning

Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    CRAG Dual-Encoder ADE                        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚   "[DRUG] aspirin [/DRUG]       "[ADR] bleeding [/ADR]          β”‚
β”‚    caused bleeding..."           from aspirin..."               β”‚
β”‚           β”‚                            β”‚                        β”‚
β”‚           β–Ό                            β–Ό                        β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                  β”‚
β”‚   β”‚ BioLinkBERT β”‚              β”‚ BioLinkBERT β”‚  (separate)      β”‚
β”‚   β”‚   Drug      β”‚              β”‚    ADR      β”‚                  β”‚
β”‚   β”‚  Encoder    β”‚              β”‚   Encoder   β”‚                  β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜              β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜                  β”‚
β”‚          β”‚                            β”‚                         β”‚
β”‚          β–Ό                            β–Ό                         β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                  β”‚
β”‚   β”‚  Attention  β”‚              β”‚  Attention  β”‚                  β”‚
β”‚   β”‚   Pooling   β”‚              β”‚   Pooling   β”‚                  β”‚
β”‚   β”‚  (4 heads)  β”‚              β”‚  (4 heads)  β”‚                  β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜              β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜                  β”‚
β”‚          β”‚                            β”‚                         β”‚
β”‚          β–Ό                            β–Ό                         β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                  β”‚
β”‚   β”‚ Projection  β”‚              β”‚ Projection  β”‚                  β”‚
β”‚   β”‚  768β†’256    β”‚              β”‚  768β†’256    β”‚                  β”‚
β”‚   β”‚ +LayerNorm  β”‚              β”‚ +LayerNorm  β”‚                  β”‚
β”‚   β”‚ +GELU+Drop  β”‚              β”‚ +GELU+Drop  β”‚                  β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜              β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜                  β”‚
β”‚          β”‚                            β”‚                         β”‚
β”‚          β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                         β”‚
β”‚                      β”‚                                          β”‚
β”‚                      β–Ό                                          β”‚
β”‚              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                   β”‚
β”‚              β”‚   Bilinear   β”‚                                   β”‚
β”‚              β”‚   Fusion     β”‚                                   β”‚
β”‚              β”‚  (256Γ—256)   β”‚                                   β”‚
β”‚              β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜                                   β”‚
β”‚                     β”‚                                           β”‚
β”‚                     β–Ό                                           β”‚
β”‚              β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                   β”‚
β”‚              β”‚  Classifier  β”‚                                   β”‚
β”‚              β”‚ 512β†’256β†’128β†’1β”‚                                   β”‚
β”‚              β”‚ +LayerNorm   β”‚                                   β”‚
β”‚              β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜                                   β”‚
β”‚                     β”‚                                           β”‚
β”‚                     β–Ό                                           β”‚
β”‚                 P(causal)                                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Key Improvements Over Base Model

| Feature | Base | ADE (this model) |
|---|---|---|
| Base Encoder | PubMedBERT | BioLinkBERT |
| Pooling | [CLS] token | Multi-head attention |
| Entity Marking | None | [DRUG]/[ADR] tokens |
| Negative Sampling | Random 50% | Hard negatives |
| Loss Function | BCE | Focal loss (Ξ³ = 2.0) |
| LR Schedule | Linear warmup | Cosine + layer-wise decay |
| Gradient Accumulation | None | 4 steps (effective batch = 64) |
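The gradient-accumulation row above can be illustrated with a minimal training-loop sketch; the model, data, and loss here are dummy stand-ins, not the actual training script:

```python
import torch
import torch.nn as nn

# Dummy stand-ins: a tiny model and random mini-batches (illustration only).
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
accumulation_steps = 4   # 4 x micro-batch of 16 = effective batch of 64
loss_fn = nn.BCEWithLogitsLoss()

optimizer.zero_grad()
for step in range(8):
    x = torch.randn(16, 8)
    y = torch.randint(0, 2, (16, 1)).float()
    loss = loss_fn(model(x), y)
    (loss / accumulation_steps).backward()   # scale so gradients average over the window
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # max_grad_norm from the config
        optimizer.step()
        optimizer.zero_grad()
```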

Model Specifications

  • Base Model: michiyasunaga/BioLinkBERT-base
  • Hidden Dimension: 768
  • Fusion Dimension: 256
  • Attention Heads (pooling): 4
  • Total Parameters: ~238M
  • Special Tokens: [DRUG], [/DRUG], [ADR], [/ADR]

Training Procedure

Phase 1: Contrastive Pre-training (5 epochs)

```python
CONFIG = {
    "temperature": 0.05,           # Sharper similarity distribution
    "hard_negative_ratio": 0.5,    # 50% hard negatives
    "batch_size": 16,
    "gradient_accumulation_steps": 4,  # Effective batch = 64
}
```
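The card does not give the exact contrastive objective, so the following is a standard InfoNCE-style sketch showing where the temperature of 0.05 enters; matched in-batch (drug, ADR) rows act as positives and all other pairs as negatives:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(drug_emb, adr_emb, temperature=0.05):
    """InfoNCE-style contrastive loss: matched (drug, ADR) rows are positives,
    all other in-batch pairs act as negatives. A sharper (lower) temperature
    concentrates probability mass on the hardest pairs."""
    drug = F.normalize(drug_emb, dim=-1)
    adr = F.normalize(adr_emb, dim=-1)
    logits = drug @ adr.T / temperature        # (B, B) similarity matrix
    targets = torch.arange(drug.size(0))       # diagonal entries are the positive pairs
    return F.cross_entropy(logits, targets)

loss = info_nce_loss(torch.randn(16, 256), torch.randn(16, 256))
```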

Hard Negative Mining Strategy:

  • Same drug, different ADR (tests ADR discrimination)
  • Same ADR, different drug (tests drug discrimination)
  • Semantically similar but unrelated pairs
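The first two strategies can be sketched as a simple recombination of the known positive pairs; the helper below (`mine_hard_negatives`, hypothetical) illustrates them, while the third (semantic similarity) would additionally need an embedding model:

```python
import random

def mine_hard_negatives(positive_pairs, seed=0):
    """Construct hard negatives from known positive (drug, adr) pairs:
    recombine drugs and ADRs so each negative shares one entity with a
    true pair but the combination itself is unattested."""
    rng = random.Random(seed)
    positives = set(positive_pairs)
    drugs = sorted({d for d, _ in positives})
    adrs = sorted({a for _, a in positives})
    negatives = []
    for drug, adr in positives:
        other_adr = rng.choice([a for a in adrs if (drug, a) not in positives] or [adr])
        other_drug = rng.choice([d for d in drugs if (d, adr) not in positives] or [drug])
        if (drug, other_adr) not in positives:
            negatives.append((drug, other_adr))    # same drug, different ADR
        if (other_drug, adr) not in positives:
            negatives.append((other_drug, adr))    # same ADR, different drug
    return negatives

pairs = [("metformin", "lactic acidosis"), ("aspirin", "bleeding"), ("warfarin", "bleeding")]
hard_negatives = mine_hard_negatives(pairs)
```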

Phase 2: Classification Fine-tuning (8 epochs)

```python
CONFIG = {
    "learning_rate": 2e-5,
    "warmup_ratio": 0.1,
    "layerwise_lr_decay": 0.9,     # Lower layers get 0.9Γ— LR
    "focal_gamma": 2.0,            # Focus on hard examples
    "focal_alpha": 0.75,           # Positive class weight
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
}
```
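The `layerwise_lr_decay` entry can be turned into optimizer parameter groups as follows; the stack of plain linear layers stands in for the twelve BioLinkBERT transformer layers, whose real module names differ:

```python
import torch
import torch.nn as nn

def layerwise_lr_groups(layers, base_lr=2e-5, decay=0.9, weight_decay=0.01):
    """Assign each layer a learning rate of base_lr * decay**depth, counting
    depth from the top: the last (top) layer trains at base_lr, and each
    layer below it at 0.9x the layer above."""
    groups = []
    num_layers = len(layers)
    for i, layer in enumerate(layers):
        lr = base_lr * decay ** (num_layers - 1 - i)
        groups.append({"params": layer.parameters(), "lr": lr, "weight_decay": weight_decay})
    return groups

encoder = nn.ModuleList([nn.Linear(16, 16) for _ in range(12)])  # stand-in for 12 BERT layers
optimizer = torch.optim.AdamW(layerwise_lr_groups(encoder), lr=2e-5)
```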

Focal Loss: $FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$

where Ξ³ = 2.0 down-weights easy examples, focusing learning on hard cases, and Ξ± = 0.75 weights the positive class.
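A direct PyTorch translation of the formula with the card's defaults (Ξ³ = 2.0, Ξ± = 0.75); this is a generic binary focal loss, not the verbatim training code:

```python
import torch

def focal_loss(logits, targets, gamma=2.0, alpha=0.75):
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), computed from raw logits.
    alpha weights the positive class; (1 - p_t)^gamma down-weights easy examples."""
    p = torch.sigmoid(logits)
    p_t = torch.where(targets == 1, p, 1 - p)
    alpha_t = torch.where(targets == 1, torch.full_like(p, alpha), torch.full_like(p, 1 - alpha))
    return (-alpha_t * (1 - p_t) ** gamma * torch.log(p_t.clamp_min(1e-8))).mean()

logits = torch.tensor([4.0, -4.0, 0.5])   # two easy examples, one hard
targets = torch.tensor([1.0, 0.0, 0.0])
loss = focal_loss(logits, targets)
```

With Ξ³ = 0 and Ξ± = 0.5 this reduces to half the standard binary cross-entropy, which is a useful sanity check.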

Training Data

  • Dataset: ADE Corpus V2
  • Configuration: Ade_corpus_v2_drug_ade_relation
  • Training Examples: 13,642 (balanced positive/negative with hard mining)
  • Validation Examples: 2,047
  • Entity Marker Format: "[DRUG] aspirin [/DRUG] caused [ADR] bleeding [/ADR]"
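Inserting the markers with a plain string replace would mark every occurrence of the entity; a word-boundary-aware sketch (helper name hypothetical) that wraps only the first whole-word match:

```python
import re

def add_entity_markers(text, entity, tag):
    """Wrap the first whole-word occurrence of `entity` in [TAG]...[/TAG] markers.
    Word boundaries prevent marking substrings (e.g. 'pain' inside 'painful')."""
    pattern = re.compile(rf"\b{re.escape(entity)}\b", flags=re.IGNORECASE)
    return pattern.sub(f"[{tag}] \\g<0> [/{tag}]", text, count=1)

marked = add_entity_markers(
    add_entity_markers("Aspirin caused severe bleeding.", "aspirin", "DRUG"),
    "bleeding", "ADR",
)
# marked == "[DRUG] Aspirin [/DRUG] caused severe [ADR] bleeding [/ADR]."
```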

Performance

Metrics

| Metric | Value |
|--------|-------|
| F1 Score | 97.5% |
| ROC-AUC | 99.1% |
| Optimal Threshold | 0.55 |

Training Curves

| Phase | Epochs | Final Loss | Final Metric |
|-------|--------|------------|--------------|
| Contrastive | 5 | 0.021 | n/a |
| Classification | 8 | 0.008 | F1: 97.5% |

Comparison with CRAG Family

| Model | F1 | AUC | Improvement |
|-------|----|-----|-------------|
| CRAG-dual-encoder-base | 88.3% | n/a | Baseline |
| CRAG-dual-encoder-ade | 97.5% | 99.1% | +9.2% F1 |
| CRAG-dual-encoder-mimicause | 98.9% | 99.8% | +10.6% F1 |

Usage

Loading the Model

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Define the model architecture
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool()
        pooled, _ = self.attention(query, hidden_states, hidden_states,
                                   key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)

class DualEncoderADE(nn.Module):
    def __init__(self, model_name="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        self.drug_encoder = AutoModel.from_pretrained(model_name)
        self.adr_encoder = AutoModel.from_pretrained(model_name)
        self.drug_pooler = AttentionPooling(768)
        self.adr_pooler = AttentionPooling(768)
        # ... (see full architecture in training script)

# Load tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained("chrisvoncsefalvay/CRAG-dual-encoder-ade")

# Load model weights
model = DualEncoderADE()
# Resize both encoders' embeddings to cover the four added special tokens
# before loading the checkpoint; otherwise the state dict shapes will not match.
model.drug_encoder.resize_token_embeddings(len(tokenizer))
model.adr_encoder.resize_token_embeddings(len(tokenizer))
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

Inference Example

```python
def score_drug_adr_pair(model, tokenizer, drug_text, adr_text, drug_entity, adr_entity):
    """Score a drug-ADR pair for causal relationship."""

    # Add entity markers
    drug_context = drug_text.replace(
        drug_entity,
        f"[DRUG] {drug_entity} [/DRUG]"
    )
    adr_context = adr_text.replace(
        adr_entity,
        f"[ADR] {adr_entity} [/ADR]"
    )

    # Tokenize
    drug_inputs = tokenizer(
        drug_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    adr_inputs = tokenizer(
        adr_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )

    # Get prediction
    with torch.no_grad():
        drug_repr = model.encode_drug(**drug_inputs)
        adr_repr = model.encode_adr(**adr_inputs)
        logit = model.classify(drug_repr, adr_repr)
        prob = torch.sigmoid(logit).item()

    return prob

# Example usage
prob = score_drug_adr_pair(
    model, tokenizer,
    drug_text="Patient was started on metformin 500mg twice daily.",
    adr_text="She developed lactic acidosis requiring ICU admission.",
    drug_entity="metformin",
    adr_entity="lactic acidosis"
)
print(f"Causal probability: {prob:.3f}")
# Output: Causal probability: 0.923
```

Intended Uses

Primary Use Cases

  • Pharmacovigilance Systems: Automated ADR detection in medical literature
  • Drug Safety Databases: Populating causal knowledge graphs
  • Clinical Trial Analysis: Mining safety signals from trial reports
  • Regulatory Submission Review: Screening documents for ADR mentions
  • Post-Market Surveillance: Monitoring real-world drug safety

Best Practices

  1. Use entity markers [DRUG]/[ADR] for optimal performance
  2. Apply threshold of 0.55 for balanced precision/recall
  3. Validate high-confidence predictions with domain experts
  4. Consider ensemble with CRAG-dual-encoder-mimicause for critical applications
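Practices 2 and 4 amount to simple post-processing over per-model probabilities; a minimal sketch, where the probability-averaging ensemble is an illustrative assumption:

```python
def ensemble_decision(probs, threshold=0.55):
    """Average probabilities from multiple models, then apply the 0.55
    operating threshold recommended for balanced precision/recall."""
    avg = sum(probs) / len(probs)
    return avg, avg >= threshold

# e.g. scores from CRAG-dual-encoder-ade and CRAG-dual-encoder-mimicause
avg_prob, is_causal = ensemble_decision([0.92, 0.88])
```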

Limitations

  1. English Only: Trained on English biomedical text
  2. Explicit Mentions Required: Both drug and ADR must appear in the text
  3. Binary Classification: Does not distinguish causation types (e.g., dose-dependent)
  4. Training Data Bias: Reflects the drug/ADR distribution in ADE Corpus V2
  5. Context Window: Maximum 128 tokens per input

Ethical Considerations

  • Not for Direct Clinical Use: Predictions require expert validation
  • Bias in Coverage: Common drugs/ADRs better represented than rare ones
  • Automation Risks: Over-reliance may miss nuanced relationships
  • Transparency: Model confidence should be communicated to end users

Technical Specifications

| Specification | Value |
|---------------|-------|
| Framework | PyTorch |
| Base Model | BioLinkBERT-base |
| Model Size | ~955 MB |
| Vocabulary Size | 30,522 + 4 special tokens |
| Max Sequence Length | 128 tokens |
| Inference Speed | ~100 pairs/sec (GPU) |

Citation

```bibtex
@misc{crag-dual-encoder-ade-2024,
  title={CRAG: Causal Reasoning for Adversomics Graphs - Enhanced Dual-Encoder with Hard Negative Mining},
  author={von Csefalvay, Chris},
  year={2024},
  publisher={Hugging Face},
  url={https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-ade}
}
```

References

  • Gurulingappa, H., et al. (2012). Development of a benchmark corpus to support the automatic extraction of drug-related adverse effects from medical case reports. Journal of Biomedical Informatics.
  • Yasunaga, M., et al. (2022). LinkBERT: Pretraining Language Models with Document Links. ACL.
  • Lin, T.Y., et al. (2017). Focal Loss for Dense Object Detection. ICCV.

Model Card Authors

Chris von Csefalvay (@chrisvoncsefalvay)

Model Card Contact

For questions or issues, please open a discussion on this model's repository or contact chris@chrisvoncsefalvay.com.