| from typing import Dict, Iterable, List, Optional, Tuple, Union |
|
|
| import collections |
| import glob |
| import json |
| import hashlib |
| import itertools |
| import logging |
| import multiprocessing |
| import os |
| import pickle |
| import random |
| import requests |
| import sys |
| import zipfile |
|
|
| import datasets |
| import numpy as np |
| import torch |
| import tqdm |
| import transformers |
|
|
| from cde.lib.dist import get_num_proc, get_rank |
|
|
|
|
def get_cde_cache_dir() -> str:
    """Return the repository-level ``data`` directory.

    Resolved as ``<this file's directory>/../../data`` and normalized,
    so the result is stable regardless of the current working directory.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.normpath(os.path.join(here, os.pardir, os.pardir))
    return os.path.join(repo_root, "data")
|
|
|
|
def get_cache_location_from_kwargs(**kwargs):
    """Build a per-kwargs cache file path under ``<cache_dir>/cluster``.

    The ``cluster`` directory is created if missing; the filename is the
    MD5 digest of the (stringified) kwargs, so identical kwargs always map
    to the same file.
    """
    cluster_dir = os.path.join(get_cde_cache_dir(), "cluster")
    os.makedirs(cluster_dir, exist_ok=True)
    return os.path.join(cluster_dir, md5_hash_kwargs(**kwargs))
|
|
|
|
def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    """Map each query ID to matching corpus row indices and relevance scores.

    For every qrel row, the position of its document (matched by ``_id``)
    is looked up in *corpus*. Rows whose document is absent from the corpus
    are skipped and counted in a final warning.

    Returns:
        (qrels_idxs, qrels_scores): dicts keyed by query ID holding the
        matched corpus row indices and the corresponding scores.
    """
    idxs_by_query = collections.defaultdict(list)
    scores_by_query = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    num_skipped = 0

    for row in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        query_id = str(row['query-id'])
        matches = (corpus_ids == str(row['corpus-id'])).nonzero()[0]
        # Corpus IDs are expected to be unique; more than one match means
        # the corpus itself is malformed.
        assert len(matches) <= 1, f"error - duplicate corpus ID? (found {len(matches)} matches)"
        if len(matches):
            idxs_by_query[query_id].append(matches[0])
            scores_by_query[query_id].append(row['score'])
        else:
            num_skipped += 1

    if num_skipped > 0:
        logging.warning(f'Warning: Skipped {num_skipped}/{len(qrels)} qrels.')

    return idxs_by_query, scores_by_query
|
|
|
|
def process_qrels(
    corpus: datasets.Dataset, qrels: datasets.Dataset,
    use_cache: bool = True
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    """Like :func:`process_qrels_uncached`, with an on-disk pickle cache.

    The cache path is derived from the HF cache filenames of *corpus* and
    *qrels*, so a new dataset version produces a new cache file.

    Args:
        corpus: corpus dataset (must have cached files on disk).
        qrels: qrels dataset (must have cached files on disk).
        use_cache: read/write the processed-qrels pickle when True.

    Returns:
        (qrels_idxs, qrels_scores) as from :func:`process_qrels_uncached`.
    """
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if use_cache and os.path.exists(cache_file):
        # Fix: use context managers so file handles are closed promptly
        # (the original left both pickle handles open).
        with open(cache_file, 'rb') as f:
            qrels_idxs, qrels_scores = pickle.load(f)
    else:
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            with open(cache_file, 'wb') as f:
                pickle.dump((qrels_idxs, qrels_scores), f)

    return qrels_idxs, qrels_scores
|
|
|
|
def strip_extension(filename: str) -> str:
    """Strips file extension.

    Ex:
        >> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    root, _ext = os.path.splitext(filename)
    return root
|
|
|
|
def md5_hash(t: Tuple[str]) -> str:
    """MD5 hex digest of the strings in *t* joined by '__'."""
    joined = '__'.join(t)
    return hashlib.md5(joined.encode()).hexdigest()
|
|
|
|
def md5_hash_kwargs(**kwargs) -> str:
    """Deterministic MD5 digest of keyword arguments.

    Values are stringified and keys beginning with '_' are dropped, then the
    remainder is serialized as sorted JSON, so the digest is stable across
    runs and independent of kwarg ordering.
    """
    public_kwargs = {
        key: str(value)
        for key, value in kwargs.items()
        if not key.startswith('_')
    }
    canonical = json.dumps(public_kwargs, sort_keys=True)
    return hashlib.md5(canonical.encode()).hexdigest()
|
|
def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    response = requests.get(url, stream=True)
    # Content-Length may be absent; fall back to an indeterminate bar.
    total_bytes = int(response.headers.get('Content-Length', 0))
    progress = tqdm.tqdm(
        desc=save_path,
        total=total_bytes,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    )
    with open(save_path, 'wb') as fd, progress as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            written = fd.write(chunk)
            bar.update(written)
|
|
|
|
def unzip(zip_file: str, out_dir: str):
    """Extract *zip_file* into *out_dir*.

    Fix: open the archive with a context manager so the file handle is
    closed even if extraction raises (the original leaked it on error).
    """
    print("unzipping =>", zip_file)
    with zipfile.ZipFile(zip_file, "r") as archive:
        archive.extractall(path=out_dir)
|
|
|
|
def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    """Download a zip from *url* into *out_dir* and extract it, skipping
    each step if its output already exists.

    Returns:
        Path to the extracted directory (zip name minus ".zip").
    """
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    # Skip re-downloading if the archive is already on disk.
    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    extracted_dir = zip_file.replace(".zip", "")
    # Skip re-extracting if the output directory already exists.
    if not os.path.isdir(extracted_dir):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))
|
|
|
|
def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    """Wrap *iterable* in a tqdm progress bar on rank 0; pass through
    unchanged on all other ranks (avoids duplicate bars in DDP)."""
    if get_rank() != 0:
        return iterable
    return tqdm.tqdm(iterable, **kwargs)
|
|
|
|
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        # Only keep values that survive JSON serialization, since the
        # config will be dumped to a JSON file.
        for name, value in kwargs.items():
            try:
                json.dumps(value)
            except TypeError:
                # Non-serializable values (models, tensors, ...) are dropped.
                continue
            setattr(self, name, value)
        super().__init__()
|
|
|
|
def independent_crop(
    input_ids: torch.Tensor, pad_token_id: int,
    l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
    [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Effective sequence length: index of the first pad token, or the full
    # tensor length when no padding is present.
    is_pad = (input_ids == pad_token_id)
    if is_pad.sum() == 0:
        N = len(input_ids)
    else:
        N = is_pad.int().argmax().item()

    # Cap each crop at half the effective length so a valid start exists.
    len1 = min(N // 2, l1)
    len2 = min(N // 2, l2)

    # Sample start offsets (position 0 is reserved for the begin token).
    # NOTE: two randint calls in this exact order, to keep RNG behavior
    # identical to the previous implementation.
    start1 = random.randint(1, N - len1)
    start2 = random.randint(1, N - len2)

    def _crop(start: int, length: int) -> torch.Tensor:
        # Re-attach the first and last in-sequence tokens around the span.
        idxs = [0] + list(range(start, start + length)) + [N - 1]
        return input_ids[torch.tensor(idxs)]

    return (_crop(start1, len1), _crop(start2, len2))
|
|
|
|
def load_dataset_tables(
    files: Iterable[str], num_workers: int = 16
) -> Iterable[datasets.table.MemoryMappedTable]:
    """Memory-map a list of Arrow table files in parallel.

    Args:
        files: paths to ``.arrow`` shard files (must support ``len``).
        num_workers: requested parallelism; capped at 32 and at
            ``len(files)``, with a floor of 1.

    Returns:
        List of memory-mapped tables, in the same order as *files*.
    """
    # Fix: `import concurrent` alone does not guarantee the
    # `concurrent.futures` submodule is loaded; import it explicitly.
    import concurrent.futures
    from multiprocessing import Pool

    # Fix: honor the caller's num_workers (the original overwrote it) and
    # never drop to 0 workers when `files` is empty.
    num_workers = max(1, min(num_workers, 32, len(files)))

    # Threads suffice here: memory-mapping is I/O-bound and avoids pickling
    # the resulting tables across process boundaries.
    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        # Only show a progress bar when there are enough files to matter.
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88"
            )
        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result
|
|
|
|
def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    """Reassemble a dataset saved with ``Dataset.save_to_disk`` by loading
    its arrow shards in parallel.

    Faster than ``datasets.load_from_disk`` for many-shard datasets because
    the per-shard memory-mapping is parallelized.

    Args:
        cache_path: directory containing ``dataset_info.json``,
            ``state.json``, and ``data-*.arrow`` shards.
    """
    # Fix: the original called logging.info(f"... path:", cache_path) — an
    # f-string with no %-placeholder plus a positional arg, so the path was
    # never logged and logging reported a formatting error instead.
    logging.info("fast_load_from_disk called with path: %s", cache_path)

    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    # Sort shards so the concatenated table preserves the saved row order.
    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )
|
|
|
|
def tokenize_dataset(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
    text_key: str,
    padding_strategy: str
) -> datasets.Dataset:
    """Tokenize ``dataset[text_key]`` and store outputs under prefixed keys.

    Adds ``{text_key}_input_ids`` (and the tokenizer's other output fields,
    similarly prefixed) plus a per-example ``length`` column. A custom
    fingerprint derived from the dataset fingerprint, the tokenizer vocab,
    and the settings keys the HF map cache, so identical reruns hit it.
    """
    def tokenize_text(ex: Dict) -> Dict:
        tokenized = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for field, values in tokenized.items():
            ex[f"{text_key}_{field}"] = values
        ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]
        return ex

    # Hash the vocab (words ordered by token id) so the fingerprint changes
    # whenever the tokenizer's vocabulary does.
    vocab = tokenizer.vocab
    words_by_token_id = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(words_by_token_id)

    fingerprint_source = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    new_fingerprint = md5_hash(fingerprint_source)
    return dataset.map(
        tokenize_text,
        new_fingerprint=new_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
|
|
|
|
class TensorRunningAverages:
    """Tracks running means of scalar metrics, keyed by string name."""
    # Sum of all observed values, per key.
    _store_sum: Dict[str, torch.Tensor]
    # Number of observations, per key.
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        """Record one observation of *val* under *key*."""
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        """Return the running mean for *key*, or 0.0 if never updated.

        Fix: the original called ``self._store_total.get(key).item()``,
        which raised AttributeError (``None.item()``) for unknown keys.
        """
        if key not in self._store_sum:
            return 0.0
        # Guard against division by zero for keys cleared but not updated.
        total = max(self._store_total[key].item(), 1.0)
        return (self._store_sum[key] / float(total)).item() or 0.0

    def clear(self, key: str) -> None:
        """Reset *key*'s sum and count to zero."""
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        """Reset every tracked key (keys remain registered)."""
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        """Return all current means, then reset every key."""
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics
|
|
def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    """Load the (model, tokenizer) pair identified by *name*.

    Dispatches on the model name: Nomic-BERT variants, GTR, Pile-T5
    encoder/decoder halves, causal LMs (GPT-2 / Llama), ModernBERT, with
    a generic AutoModel fallback. Each branch returns as soon as both
    objects are ready.
    """
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            # Randomly-initialized weights: build from config only.
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        return model, transformers.AutoTokenizer.from_pretrained(name)

    if name in ["gtr-base", "gtr_base"]:
        hub_name = "sentence-transformers/gtr-t5-base"
        model = transformers.AutoModel.from_pretrained(hub_name).encoder
        return model, transformers.AutoTokenizer.from_pretrained(hub_name)

    if name == "pile-t5-base-encoder":
        hub_name = "EleutherAI/pile-t5-base"
        model = transformers.AutoModel.from_pretrained(hub_name).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_name)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    if name == "pile-t5-base-decoder":
        hub_name = "EleutherAI/pile-t5-base"
        model = transformers.AutoModel.from_pretrained(hub_name).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_name)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    if name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            attn_implementation="sdpa",
            low_cpu_mem_usage=True,
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        # Causal LMs ship without a pad token; reuse EOS and pad right.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
        tokenizer.padding_side = "right"
        return model, tokenizer

    if "Modern" in name:
        print("special loading for ModernBERT!")
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=True)
        return model, transformers.AutoTokenizer.from_pretrained(name)

    # Generic fallback for any other checkpoint name.
    model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
    return model, transformers.AutoTokenizer.from_pretrained(name)
|
|
|
|
def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
    """Select entries of *inputs* whose keys start with ``"{key}_"`` and
    strip that prefix.

    Ex: key="query" maps {"query_input_ids": t} -> {"input_ids": t}.
    """
    prefix = key + "_"
    # Fix: slice off the leading prefix instead of str.replace, which also
    # removed later occurrences of the prefix inside the key
    # (e.g. "query_query_ids" was mangled to "ids" instead of "query_ids").
    return {k[len(prefix):]: v for k, v in inputs.items() if k.startswith(prefix)}
|
|
|
|
def count_cpus() -> int:
    """Number of CPUs usable by this process.

    Prefers the scheduler affinity mask (respects taskset/cgroup limits);
    falls back to the machine-wide count where ``sched_getaffinity`` is
    unavailable (e.g. macOS, Windows).
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return multiprocessing.cpu_count()
|
|
|
|
def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    """Shuffle indices *within* each batch tensor, never across batches.

    Returns the per-batch permutations concatenated into one flat list,
    preserving the order of the batches themselves.
    """
    shuffled: List[int] = []
    batches = tqdm_if_main_worker(
        list_of_tensors, colour="green", desc="Sampler shuffling per-batch"
    )
    for batch in batches:
        perm = torch.randperm(len(batch), generator=g)
        shuffled.extend(batch[perm].tolist())
    return shuffled
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def exit_if_running_or_finished_wandb(
    project_name: str,
    exp_group: str, exp_name: str
) -> None:
    """Exit the process (code 0) if a W&B run with the same display name and
    exp_group is already Running or Finished in *project_name*.

    Used to make experiment launches idempotent: relaunching a completed or
    in-flight experiment is a no-op.
    """
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    # Fix: query the project the caller asked for — the original hard-coded
    # path="cde-1" and silently ignored `project_name`, even though the
    # status message below reports `project_name`.
    running_runs = api.runs(
        path=project_name,
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)
| |
|
|
# Maps short hard-negative-filter names to their HF tokenizer checkpoints.
HN_FILTER_TOKENIZER_MAP = {
    "nomic": "nomic-ai/nomic-embed-text-v1",
    "stella": "dunzhang/stella_en_400M_v5",
    "sbert": "sentence-transformers/all-MiniLM-L6-v2",
    "sentence_t5": "sentence-transformers/sentence-t5-base",
    "gte": "Alibaba-NLP/gte-large-en-v1.5",
}
def load_hn_filter_tokenizer(tokenizer_name: str) -> Optional[transformers.PreTrainedTokenizer]:
    """Load the tokenizer for a known hard-negative-filter name.

    Returns None for names not present in HN_FILTER_TOKENIZER_MAP.
    """
    checkpoint = HN_FILTER_TOKENIZER_MAP.get(tokenizer_name)
    if checkpoint is None:
        return None
    return transformers.AutoTokenizer.from_pretrained(checkpoint)
|
|