renhehuang/traditional-chinese-classification
Viewer • Updated • 206k • 46 • 1
How to use renhehuang/bert-traditional-chinese-classifier with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="renhehuang/bert-traditional-chinese-classifier")

# Load model directly.
# NOTE(review): the auto-generated snippet used AutoModelForMaskedLM, which does not
# match this checkpoint's sequence-classification head (it is served through the
# "text-classification" pipeline above); AutoModelForSequenceClassification is the
# correct auto class here.
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("renhehuang/bert-traditional-chinese-classifier")
model = AutoModelForSequenceClassification.from_pretrained("renhehuang/bert-traditional-chinese-classifier")

# A BERT-based classifier to distinguish Mainland Traditional vs. Taiwan Traditional Chinese usage.
import torch, torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- Basic config ---
REPO_ID = "renhehuang/bert-traditional-chinese-classifier"
LABELS = {0: "Mainland Traditional", 1: "Taiwan Traditional"}
MAX_LEN, STRIDE = 384, 128

# --- Device selection: prefer Apple MPS, then CUDA, otherwise CPU ---
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# --- Load model & tokenizer (downloads are cached locally under .cache) ---
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, cache_dir=".cache")
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID, cache_dir=".cache")
# Move to the selected device and switch to inference mode (disables dropout).
model.to(device)
model.eval()
# --- Long-text chunking ---
def chunk_encode(text, max_len=MAX_LEN, stride=STRIDE):
    """Tokenize ``text`` into one or more overlapping fixed-size chunks.

    Long inputs are split by the tokenizer's overflow mechanism: each chunk is
    at most ``max_len`` tokens (including special tokens) and consecutive
    chunks share ``stride`` tokens of overlap so no context is lost at a
    boundary.

    The original implementation tokenized the text twice (once to count
    tokens, once to encode) and returned a ``BatchEncoding`` for short texts
    but plain dicts for long ones; this version tokenizes once and always
    returns a list of dicts with ``input_ids`` / ``attention_mask`` tensors of
    shape (1, seq_len) — the only keys the callers read.

    Args:
        text: Raw input string.
        max_len: Maximum tokens per chunk (default ``MAX_LEN``).
        stride: Token overlap between consecutive chunks (default ``STRIDE``).

    Returns:
        list[dict]: one entry per chunk, each a single-example batch.
    """
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_len,
        stride=stride,
        return_overflowing_tokens=True,  # short texts simply yield one chunk
        return_attention_mask=True,
        return_tensors="pt",
    )
    return [
        {"input_ids": enc["input_ids"][i:i + 1],
         "attention_mask": enc["attention_mask"][i:i + 1]}
        for i in range(enc["input_ids"].shape[0])
    ]
# --- Single-text inference ---
@torch.inference_mode()
def predict(text: str):
    """Classify one text, averaging per-chunk softmax probabilities.

    The text is chunked via ``chunk_encode``; each chunk is scored
    independently and the class probabilities are mean-pooled across chunks
    before taking the argmax.

    Returns a dict with the predicted label id/name, its confidence, the full
    probability distribution, the chunk count, and the device used.
    """
    encoded_chunks = chunk_encode(text)
    chunk_probs = []
    for enc in encoded_chunks:
        out = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
        chunk_probs.append(F.softmax(out.logits, dim=-1).cpu())
    # Mean-pool probabilities over all chunks, then pick the top class.
    mean_probs = torch.cat(chunk_probs, dim=0).mean(dim=0)
    best = int(mean_probs.argmax())
    preview = text if len(text) <= 100 else text[:100] + "..."
    return {
        "text_preview": preview,
        "predicted_id": best,
        "predicted_name": LABELS[best],
        "confidence": float(mean_probs[best]),
        "probabilities": {LABELS[0]: float(mean_probs[0]), LABELS[1]: float(mean_probs[1])},
        "num_chunks": len(encoded_chunks),
        "device": device,
    }
# --- Quick test ---
# --- Quick smoke test: paired Mainland/Taiwan phrasings of the same sentence ---
if __name__ == "__main__":
    samples = (
        "這個軟件的界面設計得很好。",
        "這個軟體的介面設計得很好。",
        "我需要下載這個程序到計算機上。",
        "我需要下載這個程式到電腦上。",
    )
    for sample in samples:
        result = predict(sample)
        print(f"{result['predicted_name']} | conf={result['confidence']:.2%} | {result['text_preview']}")
from collections import Counter
@torch.inference_mode()
def predict_runs(text: str, n_runs: int = 3, enable_dropout: bool = True):
    """Ensemble prediction over several stochastic forward passes (MC Dropout).

    Runs inference ``n_runs`` times; with ``enable_dropout=True`` the model is
    put in ``train()`` mode so dropout stays active, making each pass a sample
    from the model's predictive distribution (gradients remain disabled via
    ``inference_mode``). The final label is chosen by majority vote across
    runs, with the mean probability as tie-breaker.

    Args:
        text: Raw input string (long texts are chunked as in ``predict``).
        n_runs: Number of stochastic forward passes.
        enable_dropout: Keep dropout active during inference (MC Dropout).

    Returns:
        tuple: (label_name, mean_probability_of_winner, vote_counts_dict).
    """
    chunks = chunk_encode(text)
    prev_training = model.training
    run_prob_list = []
    try:
        # Set the mode explicitly; a conditional *statement* is used rather
        # than a conditional expression evaluated only for its side effect.
        if enable_dropout:
            model.train()  # keeps dropout layers active (MC Dropout)
        else:
            model.eval()
        for _ in range(n_runs):
            probs_all = []
            for ch in chunks:
                logits = model(
                    input_ids=ch["input_ids"].to(device),
                    attention_mask=ch["attention_mask"].to(device),
                ).logits
                probs_all.append(F.softmax(logits, dim=-1).cpu())
            # Mean-pool chunk probabilities into one distribution per run.
            run_prob_list.append(torch.cat(probs_all, 0).mean(0))
    finally:
        # Always restore the training mode the model had on entry.
        if prev_training:
            model.train()
        else:
            model.eval()
    probs_stack = torch.stack(run_prob_list, 0)   # (n_runs, num_labels)
    per_run_ids = probs_stack.argmax(-1).tolist()
    vote_counts = Counter(per_run_ids)
    mean_probs = probs_stack.mean(0)
    # Majority vote; mean probability breaks ties between equally-voted labels.
    voted_id = max(vote_counts.items(), key=lambda kv: (kv[1], mean_probs[kv[0]].item()))[0]
    return LABELS[voted_id], float(mean_probs[voted_id]), dict(vote_counts)
For full learning curves and diagnostic plots, see repository outputs.
Intended: origin-style identification, data cleaning, annotation assistance, pre-normalization, and hybrid use with rules/other models.
Limitations: not listed in this extract — presumably lost during page extraction; see the model card on Hugging Face for the full list.
If you use this model, please cite:
@misc{bert-traditional-chinese-classifier,
author = {renhehuang},
title = {BERT Traditional Chinese Classifier},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/renhehuang/bert-traditional-chinese-classifier}}
}
Apache-2.0
Please open an issue on the Hugging Face model page or GitHub repository.
Base model
ckiplab/bert-base-chinese