# stanza-digphil / prepare-train-val-test.py
# al1808th's picture
# new repo structure
# 035e51c
#!/usr/bin/env python3
'''
Accepted language codes:
sv, diachron, bm, nn, dk, de_lit, is
Example arg:
python prepare-train-val-test.py sv diachron bm nn dk de_lit
'''
from collections import defaultdict
from pathlib import Path
import random
import sys
# ============================================================
# SETTINGS
# ============================================================
# Language/treebank group codes requested on the command line
# (see the module docstring for the accepted values).
ud_treebank_groups_used = sys.argv[1:]
# Dependency relations rare enough in the gold data that a purely
# random dev split could contain zero instances of them; the
# stratified split further down guarantees each one appears in dev.
AT_RISK_DEPRELS = {
    "acl",
    "advcl",
    "ccomp",
    "appos",
    "iobj",
    "parataxis",
    "nummod",
    "flat:name",
}
# ============================================================
# BASE PATHS
# ============================================================
# All paths are resolved relative to the current working directory.
BASE = Path.cwd()
# Per-language UD treebank directories (Swedish, Norwegian Bokmål,
# Nynorsk, Danish, German literary, Icelandic).
SVENSKA_PROJEKT = BASE / "ud-treebanks-sv"
NORSKA_PROJEKT = BASE / "ud-treebanks-bm"
NYNORSKA_PROJEKT = BASE / "ud-treebanks-nn"
DANSKA_PROJEKT = BASE / "ud-treebanks-dk"
TYSKA_PROJEKT = BASE / "ud-treebanks-de_lit"
ICELANDIC_PROJEKT = BASE / "ud-treebanks-is"
# DigPhil diachronic Swedish data: machine-parsed vs. manually
# validated ("gold") versions of the same files.
DIGPHIL_MACHINE = BASE / "alanev_raw_files/diachron"
DIGPHIL_GOLD = BASE / "alanev_raw_files/diachron-validated"
# ============================================================
# Redirect outputs to UD_Swedish-diachronic/
# ============================================================
OUTPUT_DIR = BASE / "ud" / "UD_Swedish-diachronic"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_TRAIN = OUTPUT_DIR / "sv_diachronic-ud-train.conllu"
OUTPUT_DEV = OUTPUT_DIR / "sv_diachronic-ud-dev.conllu"
OUTPUT_TEST = OUTPUT_DIR / "sv_diachronic-ud-test.conllu"
# Fixed seed so the shuffles below (and thus the splits) are reproducible.
random.seed(1337)
# ============================================================
# BASIC HELPERS
# ============================================================
def read_conllu(path: Path):
    """Read a CoNLL-U file and return its sentence blocks.

    An empty or whitespace-only file yields an empty list; otherwise the
    blank-line-separated blocks (comment + token lines) are returned as
    raw strings.
    """
    raw = path.read_text(encoding="utf-8").strip()
    if not raw:
        return []
    return raw.split("\n\n")
def extract_sent_id(block: str) -> str | None:
for line in block.split("\n"):
if line.startswith("# sent_id"):
parts = line.split("=", 1)
if len(parts) == 2:
return parts[1].strip()
return line.split("# sent_id", 1)[1].strip()
return None
def write_conllu(path: Path, sentences):
    """Write sentence blocks to *path*, separating blocks by a blank line.

    Each block is stripped of surrounding whitespace and followed by a
    trailing blank line, as the CoNLL-U format requires.
    """
    body = "".join(block.strip() + "\n\n" for block in sentences)
    path.write_text(body, encoding="utf-8")
def load_from_treebank_dir(directory: Path):
    """Collect sentence blocks from every *.conllu file under *directory*.

    The search is recursive; each file read is reported on stdout.
    """
    sentences = []
    for conllu_path in directory.rglob("*.conllu"):
        print(f"Reading: {conllu_path}")
        sentences += read_conllu(conllu_path)
    return sentences
def sentence_deprels(block: str) -> set[str]:
    """Return the set of DEPREL values used by the token lines of a block.

    Comment lines, malformed lines (not exactly 10 tab-separated fields),
    and empty/underscore relations are ignored.
    """
    relations = set()
    for row in block.split("\n"):
        if row.startswith("#"):
            continue
        cols = row.split("\t")
        if len(cols) != 10:
            continue
        rel = cols[7]
        if rel and rel != "_":
            relations.add(rel)
    return relations
# ============================================================
# CoNLL-U VALIDATOR
# ============================================================
class CoNLLUValidator:
    """Lightweight structural validator for a single CoNLL-U sentence.

    Checks field counts, integer IDs/heads, exactly one root, head
    references, acyclicity, and presence of FORM/UPOS/DEPREL.  Problems
    accumulate in ``self.errors``, which is reset on every call to
    ``validate_sentence``.
    """

    def __init__(self):
        # Messages describing every problem found in the last sentence checked.
        self.errors = []

    def validate_sentence(self, sentence_lines, sent_id=None):
        """Validate one sentence given as token lines (comments excluded).

        Returns True when no structural problems were found.  The
        *sent_id* argument is accepted for caller convenience but unused.
        """
        self.errors = []
        if not sentence_lines:
            self.errors.append("Empty sentence")
            return False
        tokens = []
        roots = []
        token_ids = set()
        line_num = 0
        for line in sentence_lines:
            line_num += 1
            try:
                fields = line.split('\t')
                if len(fields) != 10:
                    self.errors.append(f"Line {line_num}: Expected 10 fields, got {len(fields)}")
                    continue
                token_id, form, lemma, upos, xpos, feats, head, deprel, deps, misc = fields
                # Multiword-token ranges ("1-2") and empty nodes ("1.1")
                # carry no head of their own and are skipped.
                if '-' in token_id or '.' in token_id:
                    continue
                try:
                    token_id_int = int(token_id)
                    head_int = int(head)
                except ValueError:
                    self.errors.append(f"Line {line_num}: Invalid token ID or head")
                    continue
                token_ids.add(token_id_int)
                if head_int == 0:
                    roots.append(token_id_int)
                tokens.append({
                    'id': token_id_int,
                    'form': form,
                    'lemma': lemma,
                    'upos': upos,
                    'head': head_int,
                    'deprel': deprel,
                })
            except Exception as e:
                self.errors.append(f"Line {line_num}: Error: {e}")
        # Exactly one root is required.
        if not roots:
            self.errors.append("No root found")
        elif len(roots) > 1:
            self.errors.append(f"Multiple roots found: {roots}")
        # Every non-root head must refer to an existing token.
        for tok in tokens:
            if tok['head'] != 0 and tok['head'] not in token_ids:
                self.errors.append(f"Token {tok['id']} has invalid head {tok['head']}")
        if not self._check_no_cycles(tokens):
            self.errors.append("Dependency cycle detected")
        # FORM, UPOS and DEPREL must all be filled in.
        for tok in tokens:
            if tok['form'] in ('', '_'):
                self.errors.append(f"Token {tok['id']}: Missing form")
            if tok['upos'] in ('', '_'):
                self.errors.append(f"Token {tok['id']}: Missing UPOS")
            if tok['deprel'] in ('', '_'):
                self.errors.append(f"Token {tok['id']}: Missing deprel")
        return not self.errors

    def _check_no_cycles(self, tokens):
        """Return True when following head links never revisits a token."""
        parent = {t['id']: t['head'] for t in tokens}
        for tok in tokens:
            seen = set()
            node = tok['id']
            while node != 0 and node in parent:
                if node in seen:
                    return False
                seen.add(node)
                node = parent[node]
        return True

    def get_errors(self):
        """Return the error messages from the most recent validation."""
        return self.errors
# ============================================================
# CLEANING PIPELINE
# ============================================================
def clean_sentences(sentence_blocks):
    """Run every block through CoNLLUValidator and keep only valid ones.

    Removed sentences are reported on stdout together with their sent_id
    (when present) and the validator's error list.
    """
    validator = CoNLLUValidator()
    kept = []
    for block in sentence_blocks:
        all_lines = block.split("\n")
        token_lines = [ln for ln in all_lines if not ln.startswith("#")]
        sent_id = None
        for ln in all_lines:
            if ln.startswith("# sent_id"):
                sent_id = ln.split("=", 1)[1].strip() if "=" in ln else None
        if validator.validate_sentence(token_lines, sent_id):
            kept.append(block)
        else:
            print(f"[REMOVED] sent_id={sent_id} ERRORS={validator.get_errors()}")
    return kept
# ============================================================
# Load only requested treebanks
# ============================================================
# TRAIN starts as the union of every treebank group named on the CLI;
# groups not requested are simply skipped.
train_sentences = []
if "sv" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(SVENSKA_PROJEKT))
if "bm" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(NORSKA_PROJEKT))
if "nn" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(NYNORSKA_PROJEKT))
if "dk" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(DANSKA_PROJEKT))
if "is" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(ICELANDIC_PROJEKT))
if "de_lit" in ud_treebank_groups_used:
    train_sentences.extend(load_from_treebank_dir(TYSKA_PROJEKT))
# ============================================================
# DigPhil machine ONLY added if "diachron" requested
# ============================================================
def map_sent_ids_by_file(directory: Path):
    """Map each *.conllu file name in *directory* to its set of sent_ids.

    The scan is non-recursive; blocks without a sent_id are ignored.

    Fix: the original set comprehension called ``extract_sent_id`` twice
    per block (once in the condition, once for the value); the walrus
    binding evaluates it exactly once.
    """
    mapping = {}
    for path in directory.glob("*.conllu"):
        blocks = read_conllu(path)
        mapping[path.name] = {sid for b in blocks if (sid := extract_sent_id(b))}
    return mapping
# sent_ids present in each gold file, keyed by file name; used below to
# keep gold sentences out of TRAIN (they are reserved for dev/test).
gold_ids = map_sent_ids_by_file(DIGPHIL_GOLD)
if "diachron" in ud_treebank_groups_used:
    print("Including DigPhil MACHINE in TRAIN (minus gold)…")
    for machine_file in DIGPHIL_MACHINE.glob("*.conllu"):
        blocks = read_conllu(machine_file)
        filename = machine_file.name
        # Gold ids for the machine file of the same name, if any.
        gold_for_this = gold_ids.get(filename, set())
        for block in blocks:
            sid = extract_sent_id(block)
            if sid and sid in gold_for_this:
                # Held out: a validated gold version of this sentence exists.
                continue
            train_sentences.append(block)
else:
    print("Skipping DigPhil MACHINE (diachron not requested).")
# ============================================================
# FINAL MODEL ONLY: DigPhil gold used for dev/test
# ============================================================
gold_sentences = []
for gold_file in DIGPHIL_GOLD.glob("*.conllu"):
    print(f"Reading GOLD: {gold_file}")
    gold_sentences.extend(read_conllu(gold_file))
# Previous plain-random 90/10 split, kept for reference:
#random.shuffle(gold_sentences)
#
#n = len(gold_sentences)
#dev_size = max(1, int(n * 0.10))
#
#dev_sentences = gold_sentences[:dev_size]
#test_sentences = gold_sentences[dev_size:]
# STRATIFIED DEV SPLIT (UPDATE TO PREVENT ZEROING OF DEPRELS)
random.shuffle(gold_sentences)
# Aim for roughly 10% of the gold data in dev, but never less than one.
TARGET_DEV_SIZE = max(1, int(len(gold_sentences) * 0.10))
MIN_PER_DEPREL = 5  # small but sufficient
dev_sentences = []
used = set()  # indices of gold_sentences already assigned to dev
# index sentences by deprel (only the at-risk relations are tracked)
by_deprel = defaultdict(list)
for i, sent in enumerate(gold_sentences):
    for d in sentence_deprels(sent):
        if d in AT_RISK_DEPRELS:
            by_deprel[d].append(i)
# ensure coverage: pick up to MIN_PER_DEPREL sentences per at-risk
# deprel so none of them ends up absent from dev
for d in AT_RISK_DEPRELS:
    candidates = by_deprel.get(d, [])
    random.shuffle(candidates)
    for idx in candidates[:MIN_PER_DEPREL]:
        if idx not in used:
            dev_sentences.append(gold_sentences[idx])
            used.add(idx)
# fill remainder randomly — the list was shuffled above, so taking from
# the front is unbiased.  NOTE: dev may exceed TARGET_DEV_SIZE when the
# coverage pass alone already selected more sentences.
for i, sent in enumerate(gold_sentences):
    if len(dev_sentences) >= TARGET_DEV_SIZE:
        break
    if i not in used:
        dev_sentences.append(sent)
        used.add(i)
# everything not chosen for dev becomes the test set
test_sentences = [
    sent for i, sent in enumerate(gold_sentences)
    if i not in used
]
# ============================================================
# CLEAN ALL OUTPUTS
# ============================================================
# Drop any sentence that fails structural validation before writing.
print("Cleaning TRAIN...")
train_sentences = clean_sentences(train_sentences)
print("Cleaning DEV...")
dev_sentences = clean_sentences(dev_sentences)
print("Cleaning TEST...")
test_sentences = clean_sentences(test_sentences)
# ============================================================
# WRITE FINAL OUTPUTS (now inside UD_Swedish-diachronic/)
# ============================================================
print(f"Writing TRAIN → {OUTPUT_TRAIN} ({len(train_sentences)} valid sentences)")
write_conllu(OUTPUT_TRAIN, train_sentences)
print(f"Writing DEV → {OUTPUT_DEV} ({len(dev_sentences)} valid sentences)")
write_conllu(OUTPUT_DEV, dev_sentences)
print(f"Writing TEST → {OUTPUT_TEST} ({len(test_sentences)} valid sentences)")
write_conllu(OUTPUT_TEST, test_sentences)
print("Done.")