Utilities

logger module-attribute

logger = getLogger(__name__)

BatchInfo

BatchInfo(batch_info_prefix)
Source code in tokensmith/utils.py
def __init__(self, batch_info_prefix: str):
    self.doc_idx = np.load(f"{batch_info_prefix}_doc_idx.npy", allow_pickle=True, mmap_mode="r")
    self.sample_idx = np.load(f"{batch_info_prefix}_sample_idx.npy", allow_pickle=True, mmap_mode="r")
    self.shuffle_idx = np.load(f"{batch_info_prefix}_shuffle_idx.npy", allow_pickle=True, mmap_mode="r")

doc_idx instance-attribute

doc_idx = load(f'{batch_info_prefix}_doc_idx.npy', allow_pickle=True, mmap_mode='r')

sample_idx instance-attribute

sample_idx = load(f'{batch_info_prefix}_sample_idx.npy', allow_pickle=True, mmap_mode='r')

shuffle_idx instance-attribute

shuffle_idx = load(f'{batch_info_prefix}_shuffle_idx.npy', allow_pickle=True, mmap_mode='r')

get_example_details_by_id

get_example_details_by_id(example_loc)
Source code in tokensmith/utils.py
def get_example_details_by_id(self, example_loc: int) -> dict:
    pt_shuffle_idx = self.shuffle_idx[example_loc]
    doc_index_f = self.sample_idx[pt_shuffle_idx][0]
    doc_index_l = self.sample_idx[pt_shuffle_idx + 1][0]
    offset_f = self.sample_idx[pt_shuffle_idx][1]
    offset_l = self.sample_idx[pt_shuffle_idx + 1][1]
    return {
        "doc_index_f": doc_index_f,
        "doc_index_l": doc_index_l,
        "offset_f": offset_f,
        "offset_l": offset_l
    }

get_doc_index_in_corpus

get_doc_index_in_corpus(doc_index)
Source code in tokensmith/utils.py
def get_doc_index_in_corpus(self, doc_index: int) -> int:
    return self.doc_idx[doc_index]
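
Taken together, these helpers map a position in the shuffled training order back to the corpus documents it was packed from. A minimal usage sketch, assuming the three index files already exist under a prefix (the path below is hypothetical):

from tokensmith.utils import BatchInfo

# Hypothetical prefix; point it at real *_doc_idx.npy / *_sample_idx.npy / *_shuffle_idx.npy files.
batch_info = BatchInfo("/data/batches/my_run_train_indexmap_32000ns_2048sl_42s_packedpi")

details = batch_info.get_example_details_by_id(123)  # 123rd sample in training order
first_doc = batch_info.get_doc_index_in_corpus(details["doc_index_f"])
last_doc = batch_info.get_doc_index_in_corpus(details["doc_index_l"])
print(details, first_doc, last_doc)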

WriteableMMapIndexedDataset

WriteableMMapIndexedDataset(dataset_prefix, batch_info_save_prefix, train_seq_len, train_iters, train_batch_size, seed, splits_string, packing_impl, allow_chopped, add_extra_token_to_seq)
Source code in tokensmith/utils.py
def __init__(self, 
             dataset_prefix: str, 
             batch_info_save_prefix: str, 
             train_seq_len: int, 
             train_iters: int,
             train_batch_size: int,
             seed: int,
             splits_string: str,
             packing_impl: str,
             allow_chopped: bool,
             add_extra_token_to_seq: int):
    logger.debug(f"Initializing WriteableMMapIndexedDataset with pointer: {dataset_prefix}.bin and index: {dataset_prefix}.idx")

    self.corpus_pointer = open(f"{dataset_prefix}.bin", 'r+b')
    self.corpus_index = MMapIndexedDataset.Index(f"{dataset_prefix}.idx")
    self.corpus_dtype = self.corpus_index.dtype

    self.num_samples = train_batch_size * train_iters
    self.num_documents = len(self.corpus_index.sizes)

    batch_info_save_path = f"{batch_info_save_prefix}_train_indexmap_{train_iters*train_batch_size}ns_{train_seq_len}sl_{seed}s_{packing_impl}pi"
    if allow_chopped:
        batch_info_save_path += "_ac"
    logger.debug(f"Loading doc/sample/shuffle indexes with prefix: {batch_info_save_path}*")
    if not os.path.exists(f"{batch_info_save_path}_doc_idx.npy"):
        self.simulate_training_run(batch_info_prefix=batch_info_save_prefix,
                                   train_seq_len=train_seq_len,
                                   train_iters=train_iters,
                                   train_batch_size=train_batch_size,
                                   seed=seed,
                                   splits_string=splits_string,
                                   packing_impl=packing_impl,
                                   allow_chopped=allow_chopped
                                   )
    self.batch_info = BatchInfo(batch_info_save_path)
    self.train_seq_len = train_seq_len
    self.add_extra_token_to_seq = add_extra_token_to_seq  # Defaults to 1 to account for the appended EOS token

corpus_pointer instance-attribute

corpus_pointer = open(f'{dataset_prefix}.bin', 'r+b')

corpus_index instance-attribute

corpus_index = Index(f'{dataset_prefix}.idx')

corpus_dtype instance-attribute

corpus_dtype = dtype

num_samples instance-attribute

num_samples = train_batch_size * train_iters

num_documents instance-attribute

num_documents = len(sizes)

batch_info instance-attribute

batch_info = BatchInfo(batch_info_save_path)

train_seq_len instance-attribute

train_seq_len = train_seq_len

add_extra_token_to_seq instance-attribute

add_extra_token_to_seq = add_extra_token_to_seq

close

close()

Closes the corpus pointer to release the file handle.

Source code in tokensmith/utils.py
def close(self):
    """
    Closes the corpus pointer to release the file handle.
    """
    logger.debug("Closing corpus pointer.")
    self.corpus_pointer.close()
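
Constructing the dataset opens the .bin file in read/write mode, so it should always be paired with close(). A construction-and-teardown sketch; every path and hyperparameter below is an illustrative placeholder:

from tokensmith.utils import WriteableMMapIndexedDataset

dataset = WriteableMMapIndexedDataset(
    dataset_prefix="/data/corpus/pile_tokenized",     # expects pile_tokenized.bin / pile_tokenized.idx
    batch_info_save_prefix="/data/batches/my_run",
    train_seq_len=2048,
    train_iters=1000,
    train_batch_size=32,
    seed=42,
    splits_string="969,30,1",
    packing_impl="packed",
    allow_chopped=True,
    add_extra_token_to_seq=1,                          # 1 accounts for the appended EOS token
)
try:
    first_doc = dataset.get_corpus_document_by_id(0)
finally:
    dataset.close()                                    # release the read/write handle on the .bin file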

simulate_training_run

simulate_training_run(batch_info_prefix, train_seq_len, train_iters, train_batch_size, seed, splits_string, packing_impl, allow_chopped)

Simulates a training run by creating doc_idx, sample_idx, and shuffle_idx files for the training sets. This is a placeholder method. It is better to use files generated by the training run.

Source code in tokensmith/utils.py
def simulate_training_run(self, 
                          batch_info_prefix: str,
                          train_seq_len: int,
                          train_iters: int,
                          train_batch_size: int,
                          seed: int,
                          splits_string: str,
                          packing_impl: str,
                          allow_chopped: bool):
    """
    Simulates a training run by creating doc_idx, sample_idx, and shuffle_idx files for the training sets.
    This is a placeholder method. It is better to use files generated by the training run.
    """
    logger.warning("Simulating training run. This is method generates the shuffling and batching order for the training data.")

    total_num_of_documents = self.corpus_index.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
    documents_ = np.arange(start=splits[0], stop=splits[1], step=1, dtype=np.int32)

    num_samples = train_iters * train_batch_size
    data_prefix = batch_info_prefix

    logger.warning(f"> Data prefix: {data_prefix}, num_samples: {num_samples}, seq_length: {train_seq_len}, seed: {seed}, packing_impl: {packing_impl}")
    build_index_mappings(
        'train',
        data_prefix,
        documents_,
        self.corpus_index.sizes,
        None,
        num_samples,
        None,
        train_seq_len,
        seed,
        packing_impl,
        allow_chopped=allow_chopped,
    )

get_corpus_document_by_id

get_corpus_document_by_id(doc_index)

Reads a document from the MMapIndexedDataset by its index.

Parameters:

    doc_index (int): The index of the document to read. Required.

Returns:

    np.ndarray: A numpy array containing the data read from the document.

Source code in tokensmith/utils.py
def get_corpus_document_by_id(self, doc_index: int) -> np.ndarray:
    """
    Reads a document from the MMapIndexedDataset by its index.

    Args:
        doc_index (int): The index of the document to read.

    Returns:
        np.ndarray: A numpy array containing the data read from the document.
    """
    pt_byte_offset, size = self.corpus_index[doc_index]
    self.corpus_pointer.seek(pt_byte_offset)
    return np.frombuffer(self.corpus_pointer.read(size * np.dtype(self.corpus_dtype).itemsize),
                         dtype=np.dtype(self.corpus_dtype))

get_train_document_by_id

get_train_document_by_id(doc_index)

Reads a document from the MMapIndexedDataset by its index.

Parameters:

    doc_index (int): The index of the document to read. Required.

Returns:

    np.ndarray: A numpy array containing the data read from the document.

Source code in tokensmith/utils.py
def get_train_document_by_id(self, doc_index: int) -> np.ndarray:
    """
    Reads a document from the MMapIndexedDataset by its index.

    Args:
        doc_index (int): The index of the document to read.

    Returns:
        np.ndarray: A numpy array containing the data read from the document.
    """
    corpus_doc_index = self.batch_info.get_doc_index_in_corpus(doc_index)
    pt_byte_offset, size = self.corpus_index[corpus_doc_index]
    self.corpus_pointer.seek(pt_byte_offset)
    return np.frombuffer(self.corpus_pointer.read(size * np.dtype(self.corpus_dtype).itemsize),
                         dtype=np.dtype(self.corpus_dtype))

get_example_by_id

get_example_by_id(example_loc, return_doc_details=False)

Reads an example from the MMapIndexedDataset by its location in a training run.

Parameters:

    example_loc (int): The index of the example to read. Required.
    return_doc_details (bool): If True, returns the document details along with the data. Defaults to False.

Returns:

    list of np.ndarray: A list of numpy arrays, each containing the data read from the corresponding document segment.
    doc_details (dict, optional): Returned only if return_doc_details is True; a dictionary with document details.

Notes:

    - If the sequence is contained within a single document, only one array is returned in the list.
    - If the sequence spans multiple documents, the list contains one array per document segment.
    - The dtype used for reading is inferred from corpus_index_.dtype.

Source code in tokensmith/utils.py
def get_example_by_id(self, example_loc: int, return_doc_details: bool = False):
    """
    Reads an example from the MMapIndexedDataset by its location in a training run.

    Args:
        example_loc (int): The index of the example to read.
        return_doc_details (bool): If True, returns the document details along with the data.

    Returns:
        list of np.ndarray: A list of numpy arrays, each containing the data read from the corresponding document segment.
        doc_details (dict, optional): If `return_doc_details` is True, returns a dictionary with document details.
    Notes:
        - If the sequence is contained within a single document, only one array is returned in the list.
        - If the sequence spans multiple documents, the list contains one array per document segment.
        - The dtype used for reading is inferred from `corpus_index_.dtype`.
    """
    output_seq = []

    doc_details = self.batch_info.get_example_details_by_id(example_loc)
    doc_index_f_, doc_index_l_ = doc_details["doc_index_f"], doc_details["doc_index_l"]
    offset_f_, offset_l_ = doc_details["offset_f"], doc_details["offset_l"]
    if doc_index_f_ == doc_index_l_:
        pt_byte_offset, _ = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_f_)]
        pt_byte_offset += offset_f_ * np.dtype(self.corpus_dtype).itemsize
        item_length = (offset_l_ - offset_f_ + self.add_extra_token_to_seq) * np.dtype(self.corpus_dtype).itemsize
        self.corpus_pointer.seek(pt_byte_offset)
        output_seq.append(np.frombuffer(self.corpus_pointer.read(item_length),
                                        dtype=np.dtype(self.corpus_dtype)))
    else:
        pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_f_)]
        pt_byte_offset += offset_f_ * np.dtype(self.corpus_dtype).itemsize
        item_length = (size - offset_f_) * np.dtype(self.corpus_dtype).itemsize
        self.corpus_pointer.seek(pt_byte_offset)
        output_seq.append(np.frombuffer(self.corpus_pointer.read(item_length),
                                        dtype=np.dtype(self.corpus_dtype)))

        for i in range(doc_index_f_ + 1, doc_index_l_):
            pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(i)]
            item_length = size * np.dtype(self.corpus_dtype).itemsize
            self.corpus_pointer.seek(pt_byte_offset)
            output_seq.append(np.frombuffer(self.corpus_pointer.read(item_length),
                                            dtype=np.dtype(self.corpus_dtype)))

        pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_l_)]
        item_length = (offset_l_ + self.add_extra_token_to_seq) * np.dtype(self.corpus_dtype).itemsize
        self.corpus_pointer.seek(pt_byte_offset)
        output_seq.append(np.frombuffer(self.corpus_pointer.read(item_length),
                                        dtype=np.dtype(self.corpus_dtype)))
    if return_doc_details:
        return output_seq, doc_details
    return output_seq
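
A sketch of reading one packed training sample together with its provenance, again using the hypothetical dataset constructed above:

segments, doc_details = dataset.get_example_by_id(42, return_doc_details=True)

# One array per document segment that the packed sequence was drawn from.
print(len(segments), [len(seg) for seg in segments])
print(doc_details)  # doc_index_f, doc_index_l, offset_f, offset_l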

write_example_into_corpus

write_example_into_corpus(injection_loc, injection_data, dry_run=False)

Writes an example into the corpus at the specified location (sample number in a training run).

Parameters:

    injection_loc (int): The index of the example to write. Required.
    injection_data (np.ndarray): The data to write into the corpus. Required.
    dry_run (bool): If True, only simulates the write operation without actually modifying the corpus. Defaults to False.

Returns:

    doc_details (dict): A dictionary containing details about the corpus documents where the data was injected.

Source code in tokensmith/utils.py
def write_example_into_corpus(self, injection_loc: int, injection_data: np.ndarray, dry_run: bool = False):
    """
    Writes an example into the corpus at the specified location (sample number in a training run).

    Args:
        injection_loc (int): The index of the example to write.
        injection_data (np.ndarray): The data to write into the corpus.
        dry_run (bool): If True, only simulates the write operation without actually modifying the corpus.

    Returns:
        doc_details (dict): A dictionary containing details about the corpus documents where the data was injected.
    """

    if not dry_run:
        warn_once(logger, "This warning will not be shown again. The process will wait for 10 seconds before starting the edit process.")

    doc_details = self.batch_info.get_example_details_by_id(injection_loc)
    doc_index_f_, doc_index_l_ = doc_details["doc_index_f"], doc_details["doc_index_l"]
    offset_f_, offset_l_ = doc_details["offset_f"], doc_details["offset_l"]

    if doc_index_f_ == doc_index_l_:
        pt_byte_offset, _ = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_f_)]
        pt_byte_offset += offset_f_ * np.dtype(self.corpus_dtype).itemsize
        assert (offset_l_ - offset_f_ + 1) >= len(injection_data)
        logger.debug(f'>>> Inserting {len(injection_data)} tokens starting at position {pt_byte_offset}')
        if not dry_run:
            self.corpus_pointer.seek(pt_byte_offset)
            self.corpus_pointer.write(injection_data.tobytes(order="C"))
    else:
        pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_f_)]
        pt_byte_offset += offset_f_ * np.dtype(self.corpus_dtype).itemsize
        ex_space = size - offset_f_


        logger.debug(f'>>> Inserting {min(ex_space, len(injection_data))} of {len(injection_data)} tokens starting at position {pt_byte_offset}')

        ex_left = len(injection_data) - ex_space
        if not dry_run:
            logger.debug(f'>>> Injection Data Shape {injection_data[:ex_space].shape} | Injection Data Dtype {injection_data[:ex_space].dtype}')
            self.corpus_pointer.seek(pt_byte_offset)
            self.corpus_pointer.write(injection_data[:ex_space].tobytes(order="C"))
        ex_done = ex_space

        if ex_left > 0:
            for i in range(doc_index_f_ + 1, doc_index_l_):
                pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(i)]
                ex_space = size

                logger.debug(f'>>> Inserting {min(ex_space, ex_left)} of {len(injection_data)} tokens starting at position {pt_byte_offset}')
                ex_left -= ex_space
                if not dry_run:
                    self.corpus_pointer.seek(pt_byte_offset)
                    self.corpus_pointer.write(injection_data[ex_done:ex_done+ex_space].tobytes(order="C"))
                ex_done += ex_space
                if ex_left <= 0:
                    break

        if ex_left > 0:
            pt_byte_offset, size = self.corpus_index[self.batch_info.get_doc_index_in_corpus(doc_index_l_)]
            ex_space = offset_l_ + 1

            logger.debug(f'>>> Inserting {min(ex_left, ex_space)} of {len(injection_data)} tokens starting at position {pt_byte_offset}')
            ex_left -= ex_space
            if not dry_run:
                self.corpus_pointer.seek(pt_byte_offset)
                self.corpus_pointer.write(injection_data[ex_done:ex_done+ex_space].tobytes(order="C"))
            ex_done += ex_space

        assert ex_left <= 0
        assert ex_done >= len(injection_data)
    return doc_details
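
Because this overwrites tokens of the .bin file in place, a cautious pattern is to preview the affected documents with dry_run=True before committing the write. A sketch with hypothetical replacement tokens:

import numpy as np

new_tokens = np.arange(128, dtype=dataset.corpus_dtype)  # hypothetical 128-token replacement

# Inspect which corpus documents would be touched, without modifying the file.
preview = dataset.write_example_into_corpus(17, new_tokens, dry_run=True)
print(preview)

# Then perform the actual in-place write.
dataset.write_example_into_corpus(17, new_tokens, dry_run=False)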

inject_example_into_corpus

inject_example_into_corpus(injection_loc, injection_data, injection_type, rng, dry_run=False)

Injects an example into the corpus at the specified location (sample number in a training run).

Parameters:

    injection_loc (int): The index of the example to inject. Required.
    injection_data (np.ndarray): The data to inject into the corpus. Required.
    injection_type (str): The type of injection, e.g., "seq_start" or "seq_shuffle". Required.
    rng (np.random.Generator): Random number generator for sampling positions. Required.
    dry_run (bool): If True, only simulates the injection operation without actually modifying the corpus. Defaults to False.

Source code in tokensmith/utils.py
def inject_example_into_corpus(self, injection_loc: int, injection_data: np.ndarray,
                               injection_type: str, rng: np.random.Generator, 
                               dry_run: bool = False):
    """
    Injects an example into the corpus at the specified location (sample number in a training run).

    Args:
        injection_loc (int): The index of the example to inject.
        injection_data (np.ndarray): The data to inject into the corpus.
        injection_type (str): The type of injection, e.g., "seq_start" or "seq_shuffle".
        rng (np.random.Generator): Random number generator for sampling positions.
        dry_run (bool): If True, only simulates the injection operation without actually modifying the corpus.
    """
    injection_details = {}

    # Cast injection_data to the corpus dtype
    if injection_data.dtype != self.corpus_dtype:
        logger.warning(f">> Casting injection data from {injection_data.dtype} to {self.corpus_dtype}")
        injection_data = injection_data.astype(self.corpus_dtype)

    if injection_type == "seq_shuffle":
        """
        Step 1: Read in the training sequence
        Step 2: Sample an injection position and offset for the new training sequence window
        Step 3: Create the  full sequence
        Step 4: Shorten the sequence to the train_seq_len
        Step 5: Inject it into the tokenized corpus
        """
        # Step 1: Read in the orig training sequence
        pt_train_seqs = self.get_example_by_id(injection_loc)
        pt_train_szs = [len(one_seq) for one_seq in pt_train_seqs]
        assert len(np.concatenate(pt_train_seqs)) == self.train_seq_len + self.add_extra_token_to_seq
        # Step 2: Sample an injection position and offset for the new training sequence window
        injection_pos = rng.choice(len(pt_train_seqs), 1)[0]
        if injection_pos == 0:
            window_offset = 0
        else:
            window_offset = rng.choice(len(injection_data), 1)[0]
        # Step 3: Create the full sequence
        pt_train_seqs.insert(injection_pos, injection_data)
        concat_pt_seq = np.concatenate(pt_train_seqs)
        # Step 4: Shorten the sequence to the train_seq_len
        if injection_pos > 0:
            injection_start = sum(pt_train_szs[:injection_pos])
            injection_end = injection_start + len(injection_data)
            used_check = False
            if injection_start < window_offset:
                used_check = True
                logger.debug(f"> Window check: Restricting window_offset ({window_offset}) <= injection_start ({injection_start}).")
                window_offset = min(window_offset, injection_start)
            if (window_offset + self.train_seq_len + self.add_extra_token_to_seq) < injection_end:
                used_check = True
                logger.debug(f"> Window check: Restricting window_offset ({window_offset}) + train_seq_len ({self.train_seq_len + self.add_extra_token_to_seq}) >= injection_end ({injection_end}).")
                window_offset = max(window_offset, injection_end - self.train_seq_len - 1)
            if used_check:
                injection_details['check_used'] = 1
        else:
            injection_start = 0

        injection_details["pt_injection_pos"] = int(injection_pos)
        injection_details["pt_window_offset"] = int(window_offset)
        injection_details["pt_injection_len"] = len(injection_data)
        injection_details["orig_doc_seq_sizes"] = pt_train_szs


        concat_pt_seq = concat_pt_seq[window_offset:window_offset + self.train_seq_len + self.add_extra_token_to_seq]
        assert (concat_pt_seq[injection_start - window_offset: injection_start - window_offset + len(injection_data)] == injection_data).all()

        # Step 5: Replace perturbation object with the full sequence and 
        # inject it into the tokenized corpus
        injection_data = concat_pt_seq

        if dry_run:
            injection_details['injection_data'] = injection_data

        assert len(injection_data) == self.train_seq_len + self.add_extra_token_to_seq
    else:
        # The injected data will overwrite the existing training sequence
        # Overwriting will happen at the start of the sequence
        if dry_run:
            # Generate the training sequence for visualization purposes only
            # Step 1: Read in the orig training sequence
            pt_train_seqs = self.get_example_by_id(injection_loc)
            assert len(np.concatenate(pt_train_seqs)) == self.train_seq_len + self.add_extra_token_to_seq
            # Step 2: Create the full sequence
            concat_pt_seq = np.concatenate(pt_train_seqs)
            # Step 3: Over-write the start of the sequence with the injected data
            concat_pt_seq[:len(injection_data)] = injection_data

            # Step 5: Replace perturbation object with the full sequence and 
            # inject it into the tokenized corpus
            injection_details['injection_data'] = concat_pt_seq

    injection_doc_details = self.write_example_into_corpus(
        injection_loc=injection_loc,
        injection_data=injection_data,
        dry_run=dry_run
    )
    injection_details.update(injection_doc_details)
    return injection_details
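
A sketch of a "seq_shuffle" injection, which splices the payload between the original document segments and re-crops the window to train_seq_len; the values are illustrative and dry_run=True keeps the corpus untouched:

import numpy as np

rng = np.random.default_rng(0)
payload = np.full(64, 7, dtype=dataset.corpus_dtype)  # hypothetical 64-token payload

details = dataset.inject_example_into_corpus(
    injection_loc=42,
    injection_data=payload,
    injection_type="seq_shuffle",
    rng=rng,
    dry_run=True,  # preview only: 'injection_data' holds the composed sequence
)
print(details["pt_injection_pos"], details["pt_window_offset"], len(details["injection_data"]))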

warn_once cached

warn_once(logger, msg)
Source code in tokensmith/utils.py
@lru_cache(1)
def warn_once(logger: logging.Logger, msg: str):
    logger.warning(msg)
    time.sleep(10)

generate_training_sample

generate_training_sample(tokenized_segments, tokenizer)
Source code in tokensmith/utils.py
def generate_training_sample(tokenized_segments: List[List[int]], tokenizer: AutoTokenizer) -> str:
    concat_training_sample = np.concatenate(tokenized_segments)
    return tokenizer.decode(
        concat_training_sample,
    )
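
A sketch of decoding the segments returned by get_example_by_id back into text; the tokenizer checkpoint is an arbitrary example, not a project requirement:

from transformers import AutoTokenizer
from tokensmith.utils import generate_training_sample

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # hypothetical choice
segments = dataset.get_example_by_id(42)
text = generate_training_sample(segments, tokenizer)
print(text[:200])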

perturb_dataset

perturb_dataset(raw_dataset, batch_info, perturbation_dir, max_train_samples, max_train_batches, train_seq_len, add_extra_token_to_seq, injection_type, loc_sampler, seed, dry_run=False, perturbation_include_filters=None)
Source code in tokensmith/utils.py
def perturb_dataset(raw_dataset: str,
                    batch_info: str,
                    perturbation_dir: str,
                    max_train_samples: int,
                    max_train_batches: int,
                    train_seq_len: int,
                    add_extra_token_to_seq: int,
                    injection_type: str,
                    loc_sampler: str,
                    seed: int,
                    dry_run: bool = False,
                    perturbation_include_filters: Optional[List[str]] = None
                    ) -> None:
    logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>")
    logger.warning("WARNING: Index sizes will be inconsistent with the actual document boundaries.")
    logger.warning("<<<<<<<<<<<<<<<<<<<<<<<<<")

    perturbation_files = sorted([pfnm[:-len('.bin')] for pfnm in os.listdir(perturbation_dir) if pfnm.endswith('.bin')])
    if perturbation_include_filters is not None:
        logger.info("> Filtering perturbation files")
        perturbation_files = [pfnm for pfnm in perturbation_files if any(incl_filter in pfnm for incl_filter in perturbation_include_filters)]
        logger.info(f">> Including perturbations with: {perturbation_include_filters}")
        logger.info(f">> Perturbations to be used: {perturbation_files}")

    samples_to_insert = 0
    for perturbation_nm in perturbation_files:
        logger.info(f"> Loading: {os.path.split(perturbation_nm)}")
        mmap_perturb_set = MMapIndexedDataset(os.path.join(perturbation_dir, perturbation_nm))
        samples_to_insert += len(mmap_perturb_set)
    assert samples_to_insert <= max_train_samples
    logger.info(f"> Inserting {samples_to_insert} samples.")
    logger.info(f"> Perturbing {samples_to_insert / max_train_samples * 100}% of the training samples.")
    logger.info(f"> Perturbing {samples_to_insert / max_train_batches * 100}% of the training batches (expectation).")

    rng = np.random.default_rng(seed)
    if loc_sampler == "seq":
        perturbation_locs = rng.choice(max_train_samples, samples_to_insert, replace=False)
    elif loc_sampler == "batch":
        assert samples_to_insert < max_train_batches, f"More perturbations ({samples_to_insert}) than batches ({max_train_batches})"
        assert max_train_samples % max_train_batches == 0, "Inferred batch size is not an integer"
        batch_sz = max_train_samples // max_train_batches
        logger.info(f"> Inferred batch size = {batch_sz}")
        perturbation_batches = rng.choice(max_train_batches, samples_to_insert, replace=False)
        perturbation_offsets = rng.choice(batch_sz, samples_to_insert, replace=True)
        perturbation_locs = perturbation_batches * batch_sz + perturbation_offsets
        assert len(set(perturbation_locs.flatten())) == samples_to_insert
        assert perturbation_locs.max() < max_train_samples
    else:
        raise NotImplementedError(f"Unknown loc_sampler: {loc_sampler}")

    writeable_dataset = WriteableMMapIndexedDataset(raw_dataset, batch_info, train_seq_len=train_seq_len, add_extra_token_to_seq=add_extra_token_to_seq)

    p_ctr = 0
    prior_perturbation_locs = set()
    perturbation_info = []
    check_used = 0
    printed_ctr = 0
    for perturbation_nm in perturbation_files:
        logger.info(f'> Begin adding file {perturbation_nm}')
        mmap_perturb_set = MMapIndexedDataset(os.path.join(perturbation_dir, perturbation_nm))
        assert max(mmap_perturb_set._index.sizes) <= train_seq_len

        for idx_ in trange(len(mmap_perturb_set)):
            one_pt_ex = mmap_perturb_set.get(idx_)

            pt_loc = perturbation_locs[p_ctr]
            assert pt_loc not in prior_perturbation_locs
            prior_perturbation_locs.add(pt_loc)

            injection_details = {
                "perturbation_file": perturbation_nm,
                "perturbation_idx": idx_,
                "pt_loc": int(pt_loc),
            }

            corpus_injection_details = writeable_dataset.inject_example_into_corpus(
                injection_loc=pt_loc,
                injection_data=one_pt_ex,
                injection_type=injection_type,
                rng=rng,
                dry_run=dry_run
            )

            injection_details.update(corpus_injection_details)
            perturbation_info.append(injection_details)

            check_used += injection_details.get('check_used', 0)
            p_ctr += 1
            if dry_run and printed_ctr > 10:
                break

    rand_suffix = uuid.uuid4().hex[:8]
    out_filename = f"{raw_dataset}_perturbation_info_{rand_suffix}.json"
    logger.info(f"Writing perturbation info to {out_filename}")
    with open(out_filename, 'w') as fout:
        out_str = '\n'.join([json.dumps(one_info) for one_info in perturbation_info]) + '\n'
        fout.write(out_str)

    logger.debug(f"> Window check used {check_used} times.")
    writeable_dataset.close()
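
The "batch" sampler guarantees at most one perturbation per batch by first drawing distinct batch indices and then a random offset inside each batch. A self-contained sketch of that sampling step, with illustrative numbers:

import numpy as np

# Illustrative configuration: 1,000 batches of 32 samples, 50 perturbations to place.
max_train_samples, max_train_batches, samples_to_insert = 32_000, 1_000, 50
batch_sz = max_train_samples // max_train_batches

rng = np.random.default_rng(42)
perturbation_batches = rng.choice(max_train_batches, samples_to_insert, replace=False)
perturbation_offsets = rng.choice(batch_sz, samples_to_insert, replace=True)
perturbation_locs = perturbation_batches * batch_sz + perturbation_offsets

assert len(set(perturbation_locs)) == samples_to_insert  # all locations are distinct
assert perturbation_locs.max() < max_train_samples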