Dataset Sampling Tutorial¶
This tutorial shows how to use TokenSmith's sampling API to retrieve individual samples and whole batches: by explicit index, by batch ID, or through custom policy functions that decide what to fetch.
Prerequisites:
- Complete tutorials 1 and 2 (basic setup and inspection)
- Have a tokenized dataset ready with batch info generated
What you'll learn:
- How to sample data by specific indices
- How to sample batches by IDs
- Creating and using custom sampling policies
- Policy-based sampling for individual samples and batches
- Advanced sampling strategies for research and analysis
- Performance considerations for different sampling methods
Setup¶
Let's start by setting up our environment and dataset manager, building on the previous tutorials.
# Make local checkouts of TokenSmith and GPT-NeoX importable (replace these paths with your own)
import sys
sys.path.insert(0, "/NS/llm-pretraining/work/afkhan/tokensmith")
sys.path.insert(0, "/NS/llm-pretraining/work/afkhan/USC_Colab/gpt-neox")
# Import required libraries
import numpy as np
import random
from transformers import AutoTokenizer
from tokensmith.manager import DatasetManager
# Load tokenizer
TOKENIZER_NAME_OR_PATH = "EleutherAI/gpt-neox-20b"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME_OR_PATH, add_eos_token=True)
print(f"Loaded tokenizer: {TOKENIZER_NAME_OR_PATH}")
Loaded tokenizer: EleutherAI/gpt-neox-20b
# Initialize DatasetManager
dataset_manager = DatasetManager()
# Setup the dataset for sampling
dataset_manager.setup_edit_inspect_sample_export(
    dataset_prefix='../../artifacts/data_tokenized_text_document',
    batch_info_save_prefix='../../artifacts/batch_info',
    train_iters=100,
    train_batch_size=16,
    train_seq_len=2048,
    seed=42,
    splits_string='990,5,5',
    packing_impl='packed',
    allow_chopped=True,
)
print("Dataset manager setup complete!")
    warming up index mmap file...
    reading sizes...
    reading pointers...
    reading document index...
Dataset manager setup complete!
Basic Sampling by Indices¶
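The simplest way to sample data is to pass exact sample indices, which is useful when you already know which samples you want to examine. By default each sample comes back as a list of token arrays (segments); pass return_detokenized=True together with a tokenizer to get text instead.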
# Sample specific indices
sample_indices = [0, 5, 10, 25, 50]
# Get samples as tokenized arrays
tokenized_samples = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices
)
print("Tokenized samples:")
for idx, sample in zip(sample_indices, tokenized_samples):
    total_tokens = sum(len(segment) for segment in sample)
    print(f"Sample {idx}: {len(sample)} segments, {total_tokens} total tokens")
    print(f"  First segment shape: {sample[0].shape}")
    print(f"  First 10 tokens: {sample[0][:10]}")
    print()
Tokenized samples:
Sample 0: 12 segments, 2049 total tokens
  First segment shape: (70,)
  First 10 tokens: [ 2181 4592 15 32817 434 1652 4929 2210 3515 285]
Sample 5: 8 segments, 2049 total tokens
  First segment shape: (495,)
  First 10 tokens: [1466 434 1007 387 253 7968 15 1583 1537 2028]
Sample 10: 12 segments, 2049 total tokens
  First segment shape: (291,)
  First 10 tokens: [ 1388 13 597 1089 247 1943 5453 327 253 10583]
Sample 25: 8 segments, 2049 total tokens
  First segment shape: (55,)
  First 10 tokens: [ 326 253 5101 1160 30244 24493 745 273 616 22513]
Sample 50: 9 segments, 2049 total tokens
  First segment shape: (261,)
  First 10 tokens: [6029 434 1113 6497 327 253 3216 285 9377 15]
# Same samples but detokenized (human-readable text)
text_samples = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices,
    return_detokenized=True,
    tokenizer=tokenizer
)
print("Detokenized samples:")
for idx, text in zip(sample_indices, text_samples):
    print(f"Sample {idx} (length: {len(text)} chars):")
    print(f"  Preview: {text[:150]}...")
    print(f"  Ending: ...{text[-50:]}")
    print()
Detokenized samples:
Sample 0 (length: 8990 chars):
  Preview: thing happened. Lily's little brother came running and accidentally stepped on the diamond. Oh no! The diamond was destroyed! Lily was very sad, but ...
  Ending: ...hey explored the forest, gathered flowers and made
Sample 5 (length: 8404 chars):
  Preview: Let's look at the pictures. They might tell us something." Lila and Ben look at the pictures on the map. They see a sun, a cloud, a star, a fish, and ...
  Ending: ...ets on!" They both put on their bracelets and rode
Sample 10 (length: 8530 chars):
  Preview: day, they find a big club on the grass. It is brown and heavy. "Look, a club!" Lily says. "Let's play with it!" "OK!" Ben says. "What can we do with ...
  Ending: ...t was scary but fun.<|endoftext|>Once upon a time,
Sample 25 (length: 8254 chars):
  Preview: that the sun made droplets scatter off of their backs! They felt so refreshed in the warm, wet water. After a while, the frogs decided it was time to...
  Ending: ... to Ben. "I'm sorry, Ben," she said. "I was wrong.
Sample 50 (length: 8495 chars):
  Preview: Ben's car fell on the ground and broke. The wheel came off and the paint scratched. "Uh oh!" Lily said, looking at Ben's car. "I'm sorry, Ben. I did ...
  Ending: ...r and closed her eyes. She tried to think of happy
# Get samples with document details
samples_with_details = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices[:3],  # Just first 3 for brevity
    return_detokenized=True,
    return_doc_details=True,
    tokenizer=tokenizer
)
print("Samples with document details:")
for idx, (text, doc_details) in zip(sample_indices[:3], samples_with_details):
    print(f"Sample {idx}:")
    print(f"  Text length: {len(text)} characters")
    print(f"  Document range: docs {doc_details['doc_index_f']} to {doc_details['doc_index_l']}")
    print(f"  Offset range: {doc_details['offset_f']} to {doc_details['offset_l']}")
    print(f"  Spans multiple docs: {doc_details['doc_index_f'] != doc_details['doc_index_l']}")
    print(f"  Preview: {text[:100]}...")
    print()
Samples with document details:
Sample 0:
  Text length: 8990 characters
  Document range: docs 11212 to 11223
  Offset range: 67 to 154
  Spans multiple docs: True
  Preview: thing happened. Lily's little brother came running and accidentally stepped on the diamond. Oh no! ...
Sample 5:
  Text length: 8404 characters
  Document range: docs 5991 to 5998
  Offset range: 226 to 134
  Spans multiple docs: True
  Preview: Let's look at the pictures. They might tell us something." Lila and Ben look at the pictures on the ...
Sample 10:
  Text length: 8530 characters
  Document range: docs 7983 to 7994
  Offset range: 16 to 4
  Spans multiple docs: True
  Preview: day, they find a big club on the grass. It is brown and heavy. "Look, a club!" Lily says. "Let's pl...
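The document details can be turned into a document count with simple arithmetic. The sketch below is illustrative only: the helper names doc_span and spans_multiple_docs are hypothetical, and it assumes a sample covers every document index between doc_index_f and doc_index_l. That assumption is consistent with the segment counts printed earlier (12, 8, and 12 segments for samples 0, 5, and 10).

```python
def doc_span(doc_details):
    """Number of documents a sample touches, from its first and last document index."""
    return doc_details["doc_index_l"] - doc_details["doc_index_f"] + 1


def spans_multiple_docs(doc_details):
    """True when the sample starts in one document and ends in another."""
    return doc_details["doc_index_f"] != doc_details["doc_index_l"]


# Values copied from the output above for sample 0.
sample_0 = {"doc_index_f": 11212, "doc_index_l": 11223, "offset_f": 67, "offset_l": 154}

print(doc_span(sample_0))             # 12
print(spans_multiple_docs(sample_0))  # True
```

Note that offset_l can be smaller than offset_f (as for samples 5 and 10), which is expected if each offset is measured within its own document.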
Batch Sampling by IDs¶
When working with training data, you often want to examine entire batches as they would appear during training. TokenSmith allows you to sample specific batches by their IDs.
# Sample specific batches
batch_ids = [0, 2, 5]
batch_size = 4  # Small batch size for easier examination
batches = dataset_manager.sample.get_batches_by_ids(
    batch_ids=batch_ids,
    batch_size=batch_size,
    return_detokenized=True,
    tokenizer=tokenizer
)
print(f"Sampled {len(batches)} batches:")
for batch_idx, batch in enumerate(batches):
    batch_id = batch_ids[batch_idx]
    print(f"\nBatch {batch_id} (size: {len(batch)}):")
    for sample_idx, sample in enumerate(batch):
        global_sample_id = batch_id * batch_size + sample_idx
        print(f"  Sample {sample_idx} (global ID {global_sample_id}): {len(sample)} chars")
        print(f"    Preview: {sample[:80]}...")
Sampled 3 batches:
Batch 0 (size: 4):
  Sample 0 (global ID 0): 8990 chars
    Preview:  thing happened. Lily's little brother came running and accidentally stepped on ...
  Sample 1 (global ID 1): 8388 chars
    Preview: . She had gone to the office for a minute. Lily had an idea. "Let's steal some c...
  Sample 2 (global ID 2): 8789 chars
    Preview:  agreed to marry him. They had a wonderful wedding and were very happy together....
  Sample 3 (global ID 3): 8700 chars
    Preview:  sleep, Maggie's mommy saw something very rare and wet. It was raining outside a...
Batch 2 (size: 4):
  Sample 0 (global ID 8): 8547 chars
    Preview:  park. They see the slide and the swing. They wish they had been nice. They wish...
  Sample 1 (global ID 9): 8731 chars
    Preview:  run around the faucet, letting the water spray all over him. His mom, seeing th...
  Sample 2 (global ID 10): 8530 chars
    Preview:  day, they find a big club on the grass. It is brown and heavy. "Look, a club!" ...
  Sample 3 (global ID 11): 8813 chars
    Preview:  and dad said it seemed too dangerous, so they said no. The boy didn't listen an...
Batch 5 (size: 4):
  Sample 0 (global ID 20): 8650 chars
    Preview:  saying thank you and sorry. They did not know that they had missed a chance to ...
  Sample 1 (global ID 21): 8398 chars
    Preview:  it and take care of it. It will be your new friend." Tim and Lily hug their par...
  Sample 2 (global ID 22): 8857 chars
    Preview:  Suddenly, a big, brown pony appeared in front of her! Lily was so happy and hug...
  Sample 3 (global ID 23): 8998 chars
    Preview:  also about being kind and working hard. From that day on, Timmy worked hard and...
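The global IDs printed above follow from the batch ID, the batch size passed to the call, and the position inside the batch. A minimal sketch of that mapping and its inverse is shown here; the helper names are hypothetical and not part of TokenSmith.

```python
def global_sample_id(batch_id, batch_size, position):
    """Global sample index of the sample at `position` inside batch `batch_id`."""
    return batch_id * batch_size + position


def locate_sample(global_id, batch_size):
    """Inverse mapping: return (batch_id, position) for a global sample index."""
    return divmod(global_id, batch_size)


print(global_sample_id(2, 4, 2))  # 10
print(locate_sample(10, 4))       # (2, 2)
print(locate_sample(10, 3))       # (3, 1)
```

The same sample therefore lands in a different batch when batch_size changes, so keep the batch size fixed when comparing batch IDs across runs.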
# Get batches with document details
detail_batch_ids = [1, 3]  # Just 2 batches for detailed examination
detail_batch_size = 3
batches_with_details = dataset_manager.sample.get_batches_by_ids(
    batch_ids=detail_batch_ids,
    batch_size=detail_batch_size,
    return_detokenized=True,
    return_doc_details=True,
    tokenizer=tokenizer
)
print("Batches with document details:")
for batch_id, batch in zip(detail_batch_ids, batches_with_details):
    print(f"\nBatch {batch_id}:")
    
    for sample_idx, (text, doc_details) in enumerate(batch):
        global_sample_id = batch_id * detail_batch_size + sample_idx
        print(f"  Sample {sample_idx} (global ID {global_sample_id}):")
        print(f"    Length: {len(text)} chars")
        print(f"    Doc range: {doc_details['doc_index_f']}-{doc_details['doc_index_l']}")
        print(f"    Multi-doc: {doc_details['doc_index_f'] != doc_details['doc_index_l']}")
        print(f"    Preview: {text[:60]}...")
Batches with document details:
Batch 1:
  Sample 0 (global ID 3):
    Length: 8700 chars
    Doc range: 5168-5178
    Multi-doc: True
    Preview:  sleep, Maggie's mommy saw something very rare and wet. It w...
  Sample 1 (global ID 4):
    Length: 8976 chars
    Doc range: 1942-1954
    Multi-doc: True
    Preview:  Benny learned that not all big, wild animals are scary. Som...
  Sample 2 (global ID 5):
    Length: 8404 chars
    Doc range: 5991-5998
    Multi-doc: True
    Preview: Let's look at the pictures. They might tell us something." L...
Batch 3:
  Sample 0 (global ID 9):
    Length: 8731 chars
    Doc range: 6888-6899
    Multi-doc: True
    Preview:  run around the faucet, letting the water spray all over him...
  Sample 1 (global ID 10):
    Length: 8530 chars
    Doc range: 7983-7994
    Multi-doc: True
    Preview:  day, they find a big club on the grass. It is brown and hea...
  Sample 2 (global ID 11):
    Length: 8813 chars
    Doc range: 13066-13077
    Multi-doc: True
    Preview:  and dad said it seemed too dangerous, so they said no. The ...
Custom Sampling Policies¶
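Policy-based selection is the most flexible way to sample with TokenSmith. A policy is an ordinary Python function that returns a list of sample indices or batch IDs. You pass it as policy_fn, and in the examples below its own arguments (such as num_samples or max_index) are supplied as extra keyword arguments in the same call.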
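The idea behind a policy can be illustrated with a toy stand-in that runs without TokenSmith. This is not TokenSmith's implementation: sample_by_policy, every_other_policy, and toy_dataset are hypothetical names used only to show the contract, which is that the policy receives keyword arguments and returns a list of indices that are then fetched.

```python
def sample_by_policy(fetch, policy_fn, **policy_kwargs):
    """Toy stand-in for policy-based sampling (not TokenSmith's implementation).

    `policy_fn` decides which indices to read; `fetch` reads one index.
    """
    indices = policy_fn(**policy_kwargs)
    return [fetch(i) for i in indices]


def every_other_policy(start_index, num_samples):
    """Example policy: every second index, starting at start_index."""
    return [start_index + 2 * i for i in range(num_samples)]


toy_dataset = [f"sample-{i}" for i in range(10)]
picked = sample_by_policy(
    toy_dataset.__getitem__, every_other_policy, start_index=1, num_samples=3
)
print(picked)  # ['sample-1', 'sample-3', 'sample-5']
```

Keeping the selection logic in a separate function means the same policy can be unit-tested on its own, without loading a dataset.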
Policy Functions for Individual Samples¶
Let's start by creating various policy functions for individual sample selection.
# Define policy functions for sample selection
def random_sample_policy(num_samples, max_index, rng_seed=42):
    """
    Policy function that returns random sample indices (drawn with replacement, so an index can repeat).
    
    Args:
        num_samples: Number of samples to return
        max_index: Maximum index value (exclusive)
        rng_seed: Random seed for reproducibility
    
    Returns:
        List of random sample indices
    """
    rng = np.random.default_rng(rng_seed)
    return rng.integers(0, max_index, size=num_samples).tolist()
def sequential_sample_policy(start_index, num_samples):
    """
    Policy function that returns sequential sample indices.
    
    Args:
        start_index: Starting index
        num_samples: Number of consecutive samples to return
    
    Returns:
        List of sequential sample indices
    """
    return list(range(start_index, start_index + num_samples))
def sparse_sample_policy(start_index, num_samples, step_size=10):
    """
    Policy function that returns sparsely distributed sample indices.
    
    Args:
        start_index: Starting index
        num_samples: Number of samples to return
        step_size: Step size between samples
    
    Returns:
        List of sparse sample indices
    """
    return [start_index + i * step_size for i in range(num_samples)]
def prime_sample_policy(max_index):
    """
    Policy function that returns the first 10 prime-numbered sample indices below max_index.
    
    Args:
        max_index: Maximum index to consider
    
    Returns:
        List of prime-numbered sample indices
    """
    def is_prime(n):
        if n < 2:
            return False
        for i in range(2, int(n**0.5) + 1):
            if n % i == 0:
                return False
        return True
    
    return [i for i in range(2, max_index) if is_prime(i)][:10]  # Limit to first 10 primes
print("Defined custom sampling policy functions!")
Defined custom sampling policy functions!
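Because random_sample_policy uses rng.integers, it draws with replacement and can return the same index more than once. When distinct samples are needed, a without-replacement variant is a small change. The sketch below uses only the standard library; unique_random_sample_policy is a hypothetical name, and it takes the same arguments as random_sample_policy.

```python
import random


def unique_random_sample_policy(num_samples, max_index, rng_seed=42):
    """Return num_samples distinct random indices in [0, max_index)."""
    if num_samples > max_index:
        raise ValueError("cannot draw more unique indices than max_index")
    rng = random.Random(rng_seed)  # Local generator, so global random state is untouched
    return rng.sample(range(max_index), num_samples)


indices = unique_random_sample_policy(num_samples=5, max_index=200)
print(indices)
```

Since the signature matches, it should be usable as policy_fn in place of random_sample_policy with the same keyword arguments.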
# Example 1: Random sampling policy
print("=== Random Sample Policy Example ===")
random_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=random_sample_policy,
    num_samples=5,
    max_index=200,  # Sample from first 200 indices
    rng_seed=42,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, sample in enumerate(random_samples):
    print(f"Random Sample {i+1}: {len(sample)} chars")
    print(f"  Text: {sample[:100]}...")
    print()
=== Random Sample Policy Example ===
Random Sample 1: 8936 chars
  Text: and Dad's hands. They were alone in the crowd. "Mom! Dad! Where are you?" Lily shouted. "Help! Help...
Random Sample 2: 8663 chars
  Text: He was so excited to tell her he was being patient! Mom was very proud of him and said he was growi...
Random Sample 3: 8682 chars
  Text: 's parents took him to the park and he saw the bird again. The old man was there as well. This time,...
Random Sample 4: 8544 chars
  Text: let go of the car. "I am sorry, Ben," Anna says. "I am sorry, Anna," Ben says. They hug each other....
Random Sample 5: 8654 chars
  Text: make a picture for his mom, who was at work. He took a big paper and some colors and started to dra...
# Example 2: Sequential sampling policy
print("=== Sequential Sample Policy Example ===")
sequential_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=sequential_sample_policy,
    start_index=100,
    num_samples=4,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, sample in enumerate(sequential_samples):
    sample_id = 100 + i
    print(f"Sequential Sample {sample_id}: {len(sample)} chars")
    print(f"  Text: {sample[:100]}...")
    print()
=== Sequential Sample Policy Example ===
Sequential Sample 100: 8699 chars
  Text: and peeked over the grass to get a better look. Sally asked the sheep, “What’s your name?” The youn...
Sequential Sample 101: 8723 chars
  Text: it." Lily nodded and put the pebble in her pocket. She played for a while and then went back inside...
Sequential Sample 102: 8948 chars
  Text: better than the brick! Mandy was very pleased with her new snack. She ate the carrot and thanked th...
Sequential Sample 103: 8639 chars
  Text: ran to the house. She was cold and sad. She wished she had listened to Tom. She reached the house a...
# Example 3: Sparse sampling policy with document details
print("=== Sparse Sample Policy with Document Details ===")
sparse_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=sparse_sample_policy,
    start_index=50,
    num_samples=4,
    step_size=25,
    return_doc_details=True,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, (sample, doc_details) in enumerate(sparse_samples):
    sample_id = 50 + i * 25
    print(f"Sparse Sample {sample_id}:")
    print(f"  Length: {len(sample)} chars")
    print(f"  Document range: {doc_details['doc_index_f']} to {doc_details['doc_index_l']}")
    print(f"  Text: {sample[:80]}...")
    print()
=== Sparse Sample Policy with Document Details ===
Sparse Sample 50:
  Length: 8495 chars
  Document range: 14417 to 14425
  Text: Ben's car fell on the ground and broke. The wheel came off and the paint scratc...
Sparse Sample 75:
  Length: 8828 chars
  Document range: 16055 to 16064
  Text: the museum. Sam was very happy that she was able to help BRO when he was broken...
Sparse Sample 100:
  Length: 8699 chars
  Document range: 21659 to 21670
  Text: and peeked over the grass to get a better look. Sally asked the sheep, “What’s ...
Sparse Sample 125:
  Length: 8419 chars
  Document range: 1534 to 1544
  Text: a time there was a little ice cream cone. It was filled with white, creamy ice ...
# Example 4: Prime number sampling policy
print("=== Prime Sample Policy Example ===")
prime_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=prime_sample_policy,
    max_index=100,
    return_detokenized=True,
    tokenizer=tokenizer
)
print(f"Found {len(prime_samples)} samples at prime indices:")
prime_indices = prime_sample_policy(100)
for i, (sample, prime_idx) in enumerate(zip(prime_samples, prime_indices)):
    print(f"Prime Index {prime_idx}: {len(sample)} chars")
    print(f"  Text: {sample[:70]}...")
    if i >= 4:  # Show only first 5
        print(f"  ... and {len(prime_samples) - 5} more")
        break
=== Prime Sample Policy Example ===
Found 10 samples at prime indices:
Prime Index 2: 8789 chars
  Text: agreed to marry him. They had a wonderful wedding and were very happy...
Prime Index 3: 8700 chars
  Text: sleep, Maggie's mommy saw something very rare and wet. It was raining...
Prime Index 5: 8404 chars
  Text: Let's look at the pictures. They might tell us something." Lila and Be...
Prime Index 7: 8719 chars
  Text: said. Bella was shocked and excited. She hugged her parents and thank...
Prime Index 11: 8813 chars
  Text: and dad said it seemed too dangerous, so they said no. The boy didn't...
  ... and 5 more
Policy Functions for Batch Sampling¶
Now let's create policy functions for batch selection, which can be useful for examining training dynamics or data distribution across batches.
# Define policy functions for batch selection
def random_batch_policy(num_batches, max_batch_id, rng_seed=42):
    """
    Policy function that returns random batch IDs (drawn with replacement, so an ID can repeat).
    
    Args:
        num_batches: Number of batches to return
        max_batch_id: Maximum batch ID value (exclusive)
        rng_seed: Random seed for reproducibility
    
    Returns:
        List of random batch IDs
    """
    rng = np.random.default_rng(rng_seed)
    return rng.integers(0, max_batch_id, size=num_batches).tolist()
def sequential_batch_policy(start_batch_id, num_batches):
    """
    Policy function that returns sequential batch IDs.
    
    Args:
        start_batch_id: Starting batch ID
        num_batches: Number of consecutive batches to return
    
    Returns:
        List of sequential batch IDs
    """
    return list(range(start_batch_id, start_batch_id + num_batches))
def stride_batch_policy(start_batch_id, num_batches, stride=5):
    """
    Policy function that returns batch IDs with a specific stride.
    
    Args:
        start_batch_id: Starting batch ID
        num_batches: Number of batches to return
        stride: Stride between batch IDs
    
    Returns:
        List of strided batch IDs
    """
    return [start_batch_id + i * stride for i in range(num_batches)]
def fibonacci_batch_policy(max_batch_id):
    """
    Policy function that returns batch IDs at Fibonacci numbers.
    
    Args:
        max_batch_id: Maximum batch ID to consider
    
    Returns:
        List of Fibonacci-numbered batch IDs
    """
    fib = [1, 1]  # The sequence starts 1, 1, so batch ID 1 appears twice in the result
    while fib[-1] < max_batch_id:
        fib.append(fib[-1] + fib[-2])
    
    return [f for f in fib if f < max_batch_id][:8]  # Limit to first 8
print("Defined custom batch sampling policy functions!")
Defined custom batch sampling policy functions!
# Example 1: Random batch policy
print("=== Random Batch Policy Example ===")
random_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=random_batch_policy,
    batch_size=3,
    num_batches=2,
    max_batch_id=20,
    rng_seed=123,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(random_batches):
    print(f"Random Batch {batch_idx + 1}:")
    for sample_idx, sample in enumerate(batch):
        print(f"  Sample {sample_idx + 1}: {sample[:60]}...")
    print()
=== Random Batch Policy Example ===
Random Batch 1:
  Sample 1: thing happened. Lily's little brother came running and acci...
  Sample 2: . She had gone to the office for a minute. Lily had an idea....
  Sample 3: agreed to marry him. They had a wonderful wedding and were ...
Random Batch 2:
  Sample 1: play is a good thing to do and it can make people happy.<|e...
  Sample 2: . The ball is empty and flat. "Oh no, we broke the ball!" Li...
  Sample 3: dad said they could go and play on the sand and look for sh...
# Example 2: Sequential batch policy
print("=== Sequential Batch Policy Example ===")
sequential_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=sequential_batch_policy,
    batch_size=2,
    start_batch_id=5,
    num_batches=3,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(sequential_batches):
    batch_id = 5 + batch_idx
    print(f"Sequential Batch {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f"  Sample {sample_idx + 1}: {sample[:60]}...")
    print()
=== Sequential Batch Policy Example ===
Sequential Batch 5:
  Sample 1: day, they find a big club on the grass. It is brown and hea...
  Sample 2: and dad said it seemed too dangerous, so they said no. The ...
Sequential Batch 6:
  Sample 1: He thought it was a good idea to play under it. So Bob walk...
  Sample 2: dog named Spot. Spot liked nothing better than to play in t...
Sequential Batch 7:
  Sample 1: . She says: "That is a bad man! He is not your friend! He wa...
  Sample 2: ," she said with a big, happy grin. Dad smiled and said, "Ye...
# Example 3: Stride batch policy with document details
print("=== Stride Batch Policy with Document Details ===")
stride_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=stride_batch_policy,
    batch_size=2,
    start_batch_id=3,
    num_batches=3,
    stride=4,  # Skip 3 batches between selections
    return_doc_details=True,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(stride_batches):
    batch_id = 3 + batch_idx * 4
    print(f"Stride Batch {batch_id}:")
    for sample_idx, (sample, doc_details) in enumerate(batch):
        print(f"  Sample {sample_idx + 1}:")
        print(f"    Doc range: {doc_details['doc_index_f']}-{doc_details['doc_index_l']}")
        print(f"    Text: {sample[:50]}...")
    print()
=== Stride Batch Policy with Document Details ===
Stride Batch 3:
  Sample 1:
    Doc range: 18928-18938
    Text:  baby, saving them from certain injury. The creati...
  Sample 2:
    Doc range: 7508-7515
    Text:  said. Bella was shocked and excited. She hugged h...
Stride Batch 7:
  Sample 1:
    Doc range: 16979-16987
    Text: . She says: "That is a bad man! He is not your fri...
  Sample 2:
    Doc range: 4665-4676
    Text: ," she said with a big, happy grin. Dad smiled and...
Stride Batch 11:
  Sample 1:
    Doc range: 15364-15375
    Text:  Suddenly, a big, brown pony appeared in front of ...
  Sample 2:
    Doc range: 14762-14772
    Text:  also about being kind and working hard. From that...
# Example 4: Fibonacci batch policy
print("=== Fibonacci Batch Policy Example ===")
fib_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=fibonacci_batch_policy,
    batch_size=2,
    max_batch_id=25,
    return_detokenized=True,
    tokenizer=tokenizer
)
fib_batch_ids = fibonacci_batch_policy(25)
print(f"Fibonacci batch IDs: {fib_batch_ids}")
for batch_idx, batch in enumerate(fib_batches[:3]):  # Show first 3
    batch_id = fib_batch_ids[batch_idx]
    print(f"Fibonacci Batch {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f"  Sample {sample_idx + 1}: {sample[:50]}...")
    print()
if len(fib_batches) > 3:
    print(f"... and {len(fib_batches) - 3} more batches")
=== Fibonacci Batch Policy Example ===
Fibonacci batch IDs: [1, 1, 2, 3, 5, 8, 13, 21]
Fibonacci Batch 1:
  Sample 1: agreed to marry him. They had a wonderful wedding...
  Sample 2: sleep, Maggie's mommy saw something very rare and...
Fibonacci Batch 1:
  Sample 1: agreed to marry him. They had a wonderful wedding...
  Sample 2: sleep, Maggie's mommy saw something very rare and...
Fibonacci Batch 2:
  Sample 1: Benny learned that not all big, wild animals are ...
  Sample 2: Let's look at the pictures. They might tell us som...
... and 5 more batches
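The output above lists batch ID 1 twice because the Fibonacci sequence is seeded with 1, 1, so the same batch is fetched and printed two times. If duplicates are unwanted, one option is to seed the sequence with 1, 2. The sketch below is a hypothetical variant (unique_fibonacci_batch_policy and its limit parameter are not part of the tutorial's code).

```python
def unique_fibonacci_batch_policy(max_batch_id, limit=8):
    """Return distinct Fibonacci batch IDs below max_batch_id, at most `limit` of them."""
    ids = []
    a, b = 1, 2  # Seeding with 1, 2 skips the repeated leading 1
    while a < max_batch_id and len(ids) < limit:
        ids.append(a)
        a, b = b, a + b
    return ids


print(unique_fibonacci_batch_policy(25))  # [1, 2, 3, 5, 8, 13, 21]
```

With max_batch_id=25 this yields seven batch IDs instead of eight, since the repeated 1 is gone.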
Advanced Policy Examples¶
Let's explore some more sophisticated sampling policies that might be useful for research and analysis.
# Advanced policy: Sample based on text length distribution
def length_based_sample_policy(dataset_manager, tokenizer_for_policy, num_samples=5, target_lengths=None):
    """
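    Policy that picks, for each target length, the sample whose detokenized
    text length (in characters) is closest to it.
    
    Unlike the index-only policies above, this policy reads and detokenizes
    real samples, and it rescans the whole search range once per target, so
    it is much slower than policies that only compute indices.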
    """
    if target_lengths is None:
        target_lengths = [7000, 8000, 8500, 9000, 9500]  # Different length targets
    
    selected_indices = []
    search_range = min(500, num_samples * 50)  # Reasonable search space
    
    for target_length in target_lengths[:num_samples]:
        best_idx = None
        best_diff = float('inf')
        
        # Search through a range of samples
        for idx in range(search_range):
            try:
                # Get sample text to check length
                text = dataset_manager.sample.get_samples_by_indices(
                    indices=[idx],
                    return_detokenized=True,
                    tokenizer=tokenizer_for_policy
                )[0]
                
                diff = abs(len(text) - target_length)
                if diff < best_diff:
                    best_diff = diff
                    best_idx = idx
                    
            except Exception:
                continue
        
        if best_idx is not None:
            selected_indices.append(best_idx)
    
    return selected_indices
# Use the advanced policy
print("=== Length-Based Sample Policy Example ===")
length_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=length_based_sample_policy,
    dataset_manager=dataset_manager,
    tokenizer_for_policy=tokenizer,
    num_samples=3,
    target_lengths=[7500, 8500, 9000],
    return_detokenized=True,
    tokenizer=tokenizer
)
target_lengths = [7500, 8500, 9000]
for i, sample in enumerate(length_samples):
    target = target_lengths[i]
    actual = len(sample)
    print(f"Target length {target}, actual length {actual} (diff: {abs(actual-target)}):")
    print(f"  Text: {sample[:80]}...")
    print()
=== Length-Based Sample Policy Example ===
Target length 7500, actual length 8195 (diff: 695):
  Text: Lily liked to gather bugs in her jar. She would look for them in the grass, unde...
Target length 8500, actual length 8500 (diff: 0):
  Text: and dirt, and soon the necklace was clean and light. Jim was so happy, he hugge...
Target length 9000, actual length 8998 (diff: 2):
  Text: also about being kind and working hard. From that day on, Timmy worked hard and...
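The length-based policy reads every candidate sample again for each target length. One way to reduce that cost, sketched here under the assumption that sample lengths can be collected once into a dictionary, is to separate the measuring step from the matching step. The helper closest_indices_by_length is hypothetical, and the cached lengths are the character counts reported earlier for samples 0 to 3.

```python
def closest_indices_by_length(lengths, target_lengths):
    """For each target, return the index whose cached length is closest.

    `lengths` maps sample index -> text length in characters.
    Ties are broken in favor of the lowest index.
    """
    if not lengths:
        return []
    return [
        min(lengths, key=lambda idx: (abs(lengths[idx] - target), idx))
        for target in target_lengths
    ]


# Lengths measured once (values taken from the batch 0 output earlier in this tutorial).
cached_lengths = {0: 8990, 1: 8388, 2: 8789, 3: 8700}

print(closest_indices_by_length(cached_lengths, [8400, 8800, 9000]))  # [1, 2, 0]
```

With this split, the number of sample reads depends only on the size of the search range, not on how many target lengths are requested.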
# Advanced policy: Lambda functions for complex logic
print("=== Lambda Policy Functions Example ===")
# Example 1: Sample indices that are perfect squares
square_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=lambda max_idx: [i*i for i in range(1, int(max_idx**0.5)+1) if i*i < max_idx][:5],
    max_idx=100,
    return_detokenized=True,
    tokenizer=tokenizer
)
square_indices = [i*i for i in range(1, int(100**0.5)+1) if i*i < 100][:5]
print("Samples at perfect square indices:")
for sample, sq_idx in zip(square_samples, square_indices):
    print(f"Square Index {sq_idx}: {sample[:60]}...")
print("\n" + "="*50)
# Example 2: Batch IDs that are powers of 2
power_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=lambda max_power: [2**i for i in range(1, max_power+1)],
    batch_size=2,
    max_power=4,  # 2^1, 2^2, 2^3, 2^4 = batches 2, 4, 8, 16
    return_detokenized=True,
    tokenizer=tokenizer
)
power_batch_ids = [2**i for i in range(1, 5)]
print("Batches at power-of-2 IDs:")
for batch_idx, batch in enumerate(power_batches):
    batch_id = power_batch_ids[batch_idx]
    print(f"Batch ID {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f"  Sample {sample_idx + 1}: {sample[:50]}...")
    print()
=== Lambda Policy Functions Example ===
Samples at perfect square indices:
Square Index 1: . She had gone to the office for a minute. Lily had an idea....
Square Index 4: Benny learned that not all big, wild animals are scary. Som...
Square Index 9: run around the faucet, letting the water spray all over him...
Square Index 16: you too," she says. "But you have to promise me that you wo...
Square Index 25: that the sun made droplets scatter off of their backs! They...
==================================================
Batches at power-of-2 IDs:
Batch ID 2:
  Sample 1: Benny learned that not all big, wild animals are ...
  Sample 2: Let's look at the pictures. They might tell us som...
Batch ID 4:
  Sample 1: park. They see the slide and the swing. They wish...
  Sample 2: run around the faucet, letting the water spray al...
Batch ID 8:
  Sample 1: you too," she says. "But you have to promise me t...
  Sample 2: and Dad's hands. They were alone in the crowd. "M...
Batch ID 16:
  Sample 1: my took him to the airport to see a famous plane. ...
  Sample 2: off. The man finally smiled back at the little gi...
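Policies that grow quickly, such as powers of 2, can produce IDs beyond what the dataset contains. A small wrapper that filters a policy's output is one way to guard against that. This is a sketch only: bounded is a hypothetical helper, and the upper bound of 100 is an illustrative value, not a limit taken from TokenSmith.

```python
def bounded(policy_fn, upper_bound):
    """Wrap a policy so that only IDs in [0, upper_bound) are returned."""
    def wrapped(**kwargs):
        return [i for i in policy_fn(**kwargs) if 0 <= i < upper_bound]
    return wrapped


def powers_of_two(max_power):
    """Same logic as the power-of-2 lambda above."""
    return [2**i for i in range(1, max_power + 1)]


safe_policy = bounded(powers_of_two, upper_bound=100)  # 100 is an assumed limit
print(safe_policy(max_power=8))  # [2, 4, 8, 16, 32, 64]
```

Because the wrapper accepts keyword arguments only, it matches the way policy arguments are passed in the calls above.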
Practical Sampling Utilities¶
Let's create some utility functions that demonstrate practical applications of sampling for dataset analysis and debugging.
def sample_quality_check(dataset_manager, tokenizer, num_samples=10, seed=42):
    """
    Utility function to perform quality checks on random samples.
    
    Returns statistics about sample lengths, document boundaries, etc.
    """
    # Get random samples with document details
    samples = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=num_samples,
        max_index=500,
        rng_seed=seed,
        return_doc_details=True,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    
    stats = {
        'num_samples': len(samples),
        'lengths': [],
        'multi_doc_count': 0,
        'doc_spans': [],
        'avg_length': 0,
        'length_std': 0
    }
    
    for text, doc_details in samples:
        length = len(text)
        stats['lengths'].append(length)
        
        doc_span = doc_details['doc_index_l'] - doc_details['doc_index_f'] + 1
        stats['doc_spans'].append(doc_span)
        
        if doc_details['doc_index_f'] != doc_details['doc_index_l']:
            stats['multi_doc_count'] += 1
    
    stats['avg_length'] = np.mean(stats['lengths'])
    stats['length_std'] = np.std(stats['lengths'])
    stats['avg_doc_span'] = np.mean(stats['doc_spans'])
    # Divide by the number of samples actually returned, in case it differs from num_samples
    stats['multi_doc_percentage'] = (stats['multi_doc_count'] / len(samples)) * 100
    
    return stats
# Run quality check
print("=== Dataset Quality Check Example ===")
quality_stats = sample_quality_check(dataset_manager, tokenizer, num_samples=15)
print(f"Quality check results ({quality_stats['num_samples']} samples):")
print(f"  Average length: {quality_stats['avg_length']:.1f} ± {quality_stats['length_std']:.1f} chars")
print(f"  Length range: {min(quality_stats['lengths'])} - {max(quality_stats['lengths'])} chars")
print(f"  Multi-document samples: {quality_stats['multi_doc_count']} ({quality_stats['multi_doc_percentage']:.1f}%)")
print(f"  Average document span: {quality_stats['avg_doc_span']:.1f} documents")
print(f"  Document span range: {min(quality_stats['doc_spans'])} - {max(quality_stats['doc_spans'])} documents")
=== Dataset Quality Check Example ===
Quality check results (15 samples):
  Average length: 8713.5 ± 175.3 chars
  Length range: 8173 - 8938 chars
  Multi-document samples: 15 (100.0%)
  Average document span: 11.1 documents
  Document span range: 9 - 13 documents
def compare_sampling_strategies(dataset_manager, tokenizer):
    """
    Compare different sampling strategies to understand their characteristics.
    """
    strategies = {
        'Random': lambda: random_sample_policy(5, 200, 42),
        'Sequential': lambda: sequential_sample_policy(100, 5),
        'Sparse': lambda: sparse_sample_policy(50, 5, 20),
        'Prime': lambda: prime_sample_policy(50)[:5]
    }
    
    results = {}
    
    for strategy_name, policy_fn in strategies.items():
        indices = policy_fn()
        samples = dataset_manager.sample.get_samples_by_indices(
            indices=indices,
            return_detokenized=True,
            return_doc_details=True,
            tokenizer=tokenizer
        )
        
        lengths = [len(text) for text, _ in samples]
        multi_doc_count = sum(1 for _, doc_details in samples 
                             if doc_details['doc_index_f'] != doc_details['doc_index_l'])
        
        results[strategy_name] = {
            'indices': indices,
            'avg_length': np.mean(lengths),
            'length_std': np.std(lengths),
            'multi_doc_percentage': (multi_doc_count / len(samples)) * 100
        }
    
    return results
# Compare strategies
print("=== Sampling Strategy Comparison ===")
comparison = compare_sampling_strategies(dataset_manager, tokenizer)
for strategy, stats in comparison.items():
    print(f"{strategy} sampling:")
    print(f"  Indices: {stats['indices']}")
    print(f"  Avg length: {stats['avg_length']:.1f} ± {stats['length_std']:.1f}")
    print(f"  Multi-doc %: {stats['multi_doc_percentage']:.1f}%")
    print()
=== Sampling Strategy Comparison ===
Random sampling:
  Indices: [17, 154, 130, 87, 86]
  Avg length: 8695.8 ± 129.4
  Multi-doc %: 100.0%
Sequential sampling:
  Indices: [100, 101, 102, 103, 104]
  Avg length: 8739.8 ± 107.6
  Multi-doc %: 100.0%
Sparse sampling:
  Indices: [50, 70, 90, 110, 130]
  Avg length: 8708.8 ± 154.4
  Multi-doc %: 100.0%
Prime sampling:
  Indices: [2, 3, 5, 7, 11]
  Avg length: 8685.0 ± 146.7
  Multi-doc %: 100.0%
def batch_diversity_analysis(dataset_manager, tokenizer, batch_ids, batch_size):
    """
    Analyze diversity within and across batches.
    """
    batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=batch_ids,
        batch_size=batch_size,
        return_detokenized=True,
        return_doc_details=True,
        tokenizer=tokenizer
    )
    
    analysis = {
        'batch_stats': [],
        'overall_length_variance': 0,
        'cross_batch_variance': 0
    }
    
    all_lengths = []
    batch_avg_lengths = []
    
    for batch_idx, batch in enumerate(batches):
        batch_id = batch_ids[batch_idx]
        lengths = [len(text) for text, _ in batch]
        multi_docs = [1 for _, doc_details in batch 
                     if doc_details['doc_index_f'] != doc_details['doc_index_l']]
        
        batch_stat = {
            'batch_id': batch_id,
            'avg_length': np.mean(lengths),
            'length_variance': np.var(lengths),
            'multi_doc_count': sum(multi_docs)
        }
        
        analysis['batch_stats'].append(batch_stat)
        all_lengths.extend(lengths)
        batch_avg_lengths.append(batch_stat['avg_length'])
    
    analysis['overall_length_variance'] = np.var(all_lengths)
    analysis['cross_batch_variance'] = np.var(batch_avg_lengths)
    
    return analysis
# Analyze batch diversity
print("=== Batch Diversity Analysis ===")
diversity_analysis = batch_diversity_analysis(
    dataset_manager, tokenizer, 
    batch_ids=[2, 4, 6, 8], 
    batch_size=3
)
print("Batch-by-batch analysis:")
for stats in diversity_analysis['batch_stats']:
    print(f"  Batch {stats['batch_id']}:")
    print(f"    Avg length: {stats['avg_length']:.1f}")
    print(f"    Length variance: {stats['length_variance']:.1f}")
    print(f"    Multi-doc samples: {stats['multi_doc_count']}")
print("\nOverall statistics:")
print(f"  Overall length variance: {diversity_analysis['overall_length_variance']:.1f}")
print(f"  Cross-batch variance: {diversity_analysis['cross_batch_variance']:.1f}")
print(f"  Diversity ratio: {diversity_analysis['cross_batch_variance'] / diversity_analysis['overall_length_variance']:.3f}")
=== Batch Diversity Analysis ===
Batch-by-batch analysis:
  Batch 2:
    Avg length: 8605.3
    Length variance: 6461.6
    Multi-doc samples: 3
  Batch 4:
    Avg length: 8614.0
    Length variance: 66940.7
    Multi-doc samples: 3
  Batch 6:
    Avg length: 8777.3
    Length variance: 20437.6
    Multi-doc samples: 3
  Batch 8:
    Avg length: 8395.0
    Length variance: 20778.0
    Multi-doc samples: 3
Overall statistics:
  Overall length variance: 47074.2
  Cross-batch variance: 18419.8
  Diversity ratio: 0.391
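The diversity ratio compares the variance of the per-batch mean lengths with the variance of all sample lengths. When every batch has the same size, the overall (population) variance splits exactly into the mean within-batch variance plus the variance of the batch means, so the ratio is the share of length variation explained by which batch a sample landed in. The sketch below checks that identity with NumPy on made-up lengths; the numbers are illustrative assumptions, not values from the dataset.

```python
import numpy as np

# Hypothetical character lengths for 4 batches of 3 samples each (illustrative only)
batches = np.array([
    [8500, 8600, 8700],
    [8300, 8650, 8900],
    [8600, 8800, 8950],
    [8200, 8400, 8600],
], dtype=float)

overall_var = np.var(batches)                   # variance over all 12 samples
within_var = np.mean(np.var(batches, axis=1))   # mean of the per-batch variances
between_var = np.var(np.mean(batches, axis=1))  # variance of the per-batch means

# With equal batch sizes: overall = within + between
assert np.isclose(overall_var, within_var + between_var)

diversity_ratio = between_var / overall_var
print(f"Diversity ratio: {diversity_ratio:.3f}")
```

A ratio near 0 means the batches look alike on average, while a ratio near 1 means most of the variation sits between batches. The output above is consistent with this split: the four within-batch variances average 28654.5, and adding the cross-batch variance of 18419.8 gives about 47074.3, matching the reported overall variance of 47074.2 up to rounding.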
Performance Considerations¶
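When working with large datasets, it helps to know how the sampling methods compare in cost. The benchmark below times four retrieval paths on 10 samples each: index-based, policy-based, batch, and index-based with document details.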
import time
def benchmark_sampling_methods(dataset_manager, tokenizer):
    """
    Benchmark different sampling methods to understand their performance.
    """
    benchmarks = {}
    
    # Test 1: Index-based sampling
    start_time = time.time()
    indices = list(range(0, 50, 5))  # Every 5th index from 0 to 45 (10 samples)
    samples = dataset_manager.sample.get_samples_by_indices(
        indices=indices,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Index-based'] = time.time() - start_time
    
    # Test 2: Policy-based sampling (simple)
    start_time = time.time()
    samples = dataset_manager.sample.get_samples_by_policy(
        policy_fn=lambda: list(range(0, 50, 5)),
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Policy-based (simple)'] = time.time() - start_time
    
    # Test 3: Batch sampling
    start_time = time.time()
    batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=[0, 1, 2, 3, 4],
        batch_size=2,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Batch sampling'] = time.time() - start_time
    
    # Test 4: With document details
    start_time = time.time()
    samples = dataset_manager.sample.get_samples_by_indices(
        indices=indices,
        return_detokenized=True,
        return_doc_details=True,
        tokenizer=tokenizer
    )
    benchmarks['With doc details'] = time.time() - start_time
    
    return benchmarks
# Run benchmarks
print("=== Performance Benchmarks ===")
benchmarks = benchmark_sampling_methods(dataset_manager, tokenizer)
print("Sampling method performance (10 samples each):")
for method, duration in benchmarks.items():
    print(f"  {method}: {duration:.4f} seconds")
# Calculate relative performance
fastest = min(benchmarks.values())
print("\nRelative to fastest method:")
for method, duration in benchmarks.items():
    relative = duration / fastest
    print(f"  {method}: {relative:.2f}x")
=== Performance Benchmarks ===
Sampling method performance (10 samples each):
  Index-based: 0.0183 seconds
  Policy-based (simple): 0.0174 seconds
  Batch sampling: 0.0174 seconds
  With doc details: 0.0191 seconds
Relative to fastest method:
  Index-based: 1.05x
  Policy-based (simple): 1.00x
  Batch sampling: 1.00x
  With doc details: 1.10x
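Each method above is timed once with `time.time()`, and the first call may also pay one-off costs such as cold file caches, so gaps of a millisecond or less are hard to interpret. A more robust approach, sketched below, is to warm up first, repeat the call several times, and report the median using `time.perf_counter()`. The `time_call` helper is a hypothetical name introduced here for illustration, not part of TokenSmith, and the workload is a stand-in.

```python
import statistics
import time

def time_call(fn, repeats=5, warmup=1):
    """Return the median wall-clock time of fn() over `repeats` runs.

    Hypothetical helper for illustration; not part of TokenSmith.
    """
    for _ in range(warmup):
        fn()  # warm caches so the first timed run is not penalized
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)

# Stand-in workload; in the tutorial this would be a sampling call such as
# lambda: dataset_manager.sample.get_samples_by_indices(indices=indices)
median_s = time_call(lambda: sum(range(100_000)))
print(f"Median duration: {median_s:.6f} seconds")
```

The median is less sensitive than the mean to a single slow run caused by unrelated system activity.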
Best Practices and Tips¶
Here are some practical tips for effective sampling with TokenSmith:
# Tip 1: Efficient batch processing for large-scale analysis
def efficient_large_scale_sampling(dataset_manager, tokenizer, total_samples=100):
    """
    Demonstrate efficient sampling for large-scale analysis.
    """
    print("=== Efficient Large-Scale Sampling ===")
    
    # Fetch whole batches in one call instead of sampling one index at a time
    batch_size = 10
    num_batches = total_samples // batch_size  # any remainder is dropped
    
    # Use consecutive batch IDs: 0 .. num_batches - 1
    batch_ids = list(range(num_batches))
    
    start_time = time.time()
    all_batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=batch_ids,
        batch_size=batch_size,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    
    # Flatten for analysis
    all_samples = [sample for batch in all_batches for sample in batch]
    duration = time.time() - start_time
    
    print(f"Processed {len(all_samples)} samples in {duration:.4f} seconds")
    print(f"Rate: {len(all_samples)/duration:.1f} samples/second")
    
    # Quick statistics
    lengths = [len(sample) for sample in all_samples]
    print(f"Length statistics: {np.mean(lengths):.1f} ± {np.std(lengths):.1f}")
    
    return all_samples
# Demonstrate efficient sampling
efficient_samples = efficient_large_scale_sampling(dataset_manager, tokenizer, 50)
=== Efficient Large-Scale Sampling ===
Processed 50 samples in 0.0913 seconds
Rate: 547.8 samples/second
Length statistics: 8668.9 ± 211.0
# Tip 2: Reproducible sampling with seeds
def reproducible_sampling_demo(dataset_manager, tokenizer):
    """
    Demonstrate reproducible sampling using seeds.
    """
    print("=== Reproducible Sampling Demo ===")
    
    # Same seed should give same results
    seed = 12345
    
    # First run
    samples1 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    
    # Second run with same seed
    samples2 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    
    # Check if they're identical
    identical = all(s1 == s2 for s1, s2 in zip(samples1, samples2))
    print(f"Samples identical across runs: {identical}")
    
    # Different seed should give different results
    samples3 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed + 1,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    
    different = any(s1 != s3 for s1, s3 in zip(samples1, samples3))
    print(f"Different seed gives different samples: {different}")
reproducible_sampling_demo(dataset_manager, tokenizer)
=== Reproducible Sampling Demo ===
Samples identical across runs: True
Different seed gives different samples: True
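Reproducibility here depends on the policy drawing its random numbers from a generator seeded with `rng_seed`. The sketch below shows one way such a policy could be written, using a local `random.Random` instance so the global random state is left alone. `seeded_index_policy` is a hypothetical name used for illustration; it is not the `random_sample_policy` defined earlier in this tutorial, and its half-open index range is an assumption.

```python
import random

def seeded_index_policy(num_samples, max_index, rng_seed):
    """Pick `num_samples` distinct indices in [0, max_index) reproducibly.

    Hypothetical illustration; not the tutorial's random_sample_policy.
    """
    rng = random.Random(rng_seed)  # local generator: global random state is untouched
    return rng.sample(range(max_index), num_samples)

first = seeded_index_policy(5, 100, rng_seed=12345)
second = seeded_index_policy(5, 100, rng_seed=12345)
other = seeded_index_policy(5, 100, rng_seed=12346)

print(f"Same seed, same indices: {first == second}")
print(f"Different seed, same indices: {first == other}")
```

Keeping the generator local means other code that calls `random.seed()` or consumes random numbers cannot change which indices the policy returns.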
# Tip 3: Memory-efficient sampling for very large datasets
def memory_efficient_analysis(dataset_manager, tokenizer, sample_indices):
    """
    Demonstrate memory-efficient analysis techniques.
    """
    print("=== Memory-Efficient Analysis ===")
    
    # Process samples in chunks to avoid loading everything at once
    chunk_size = 10
    total_length = 0
    multi_doc_count = 0
    processed_count = 0
    
    for i in range(0, len(sample_indices), chunk_size):
        chunk_indices = sample_indices[i:i + chunk_size]
        
        # Process chunk
        chunk_samples = dataset_manager.sample.get_samples_by_indices(
            indices=chunk_indices,
            return_detokenized=True,
            return_doc_details=True,
            tokenizer=tokenizer
        )
        
        # Accumulate statistics without storing all data
        for text, doc_details in chunk_samples:
            total_length += len(text)
            if doc_details['doc_index_f'] != doc_details['doc_index_l']:
                multi_doc_count += 1
            processed_count += 1
        
        # Release the chunk before fetching the next one; without `del`, it would
        # stay alive until `chunk_samples` is rebound on the next iteration
        del chunk_samples
    
    # Final statistics (guard against an empty index list)
    if processed_count == 0:
        print("No samples processed")
        return
    avg_length = total_length / processed_count
    multi_doc_percentage = (multi_doc_count / processed_count) * 100
    
    print(f"Processed {processed_count} samples in chunks of {chunk_size}")
    print(f"Average length: {avg_length:.1f} characters")
    print(f"Multi-document samples: {multi_doc_percentage:.1f}%")
# Demonstrate memory-efficient processing
large_sample_indices = list(range(0, 80, 2))  # Every other sample from 0 to 80
memory_efficient_analysis(dataset_manager, tokenizer, large_sample_indices)
=== Memory-Efficient Analysis ===
Processed 40 samples in chunks of 10
Average length: 8687.5 characters
Multi-document samples: 100.0%
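The chunked loop above keeps only running totals, which is enough for a mean but not for the spread reported elsewhere in this tutorial. Welford's online algorithm extends the same idea to the standard deviation without storing individual lengths. The sketch below is a minimal version; `RunningStats` is a hypothetical helper introduced here, and the lengths are made up.

```python
import math

class RunningStats:
    """Welford's online algorithm for mean and population standard deviation."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value):
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (value - self.mean)

    @property
    def std(self):
        # Population standard deviation (ddof=0), matching np.std's default
        return math.sqrt(self._m2 / self.count) if self.count else 0.0

# Made-up sample lengths arriving in two chunks
running = RunningStats()
for chunk in ([8600, 8700, 8800], [8500, 8900]):
    for length in chunk:
        running.update(length)

print(f"Mean: {running.mean:.1f}, std: {running.std:.1f} over {running.count} samples")
```

In `memory_efficient_analysis`, calling `update(len(text))` inside the inner loop would replace the `total_length` accumulator while keeping memory use constant.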
Summary¶
Congratulations! You've successfully learned how to use TokenSmith's flexible sampling capabilities. Here's what we covered:
Key Concepts Learned:¶
- Index-based Sampling: Direct sampling by specifying exact sample indices
- Batch Sampling: Retrieving complete training batches by their IDs
- Policy-based Sampling: Using custom functions to define sampling strategies
- Advanced Policies: Creating sophisticated sampling logic for research needs
- Performance Optimization: Understanding trade-offs and efficiency considerations
- Best Practices: Reproducible, memory-efficient, and scalable sampling techniques
Key Methods Used:¶
- dataset_manager.sample.get_samples_by_indices() - Sample by specific indices
- dataset_manager.sample.get_batches_by_ids() - Sample complete batches
- dataset_manager.sample.get_samples_by_policy() - Policy-based sample selection
- dataset_manager.sample.get_batches_by_policy() - Policy-based batch selection
Sampling Strategies Explored:¶
- Random sampling: For unbiased dataset exploration
- Sequential sampling: For examining consecutive samples
- Sparse sampling: For distributed dataset coverage
- Mathematical patterns: Prime numbers, Fibonacci, powers of 2
- Length-based sampling: Targeting specific text characteristics
- Custom lambda policies: Flexible, inline sampling logic
Performance Insights:¶
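- Index-based, policy-based, and batch sampling ran within about 5% of each other in the benchmark above (0.0174 to 0.0183 seconds for 10 samples); differences this small in a single run can easily be timing noise
- Policy-based sampling adds negligible overhead for simple policies
- Batch sampling is efficient for processing multiple samples
- Document details add a small processing cost (about 10% in the benchmark above) but provide valuable metadata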
- Memory-efficient chunking enables large-scale analysis
Pro Tips:¶
- Use seeds for reproducible sampling in research
- Prefer batch-based operations for large-scale analysis
- Create reusable policy functions for common sampling patterns
- Consider memory usage when processing large numbers of samples
- Combine sampling with inspection for comprehensive dataset understanding
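As an illustration of the reusable-policy tip, one option is a small factory that returns zero-argument policies, the same shape as the `lambda: list(range(0, 50, 5))` policy used in the benchmark. The sketch below is a hypothetical example: `make_strided_policy` is a name introduced here, not a TokenSmith function.

```python
def make_strided_policy(start, stride, count):
    """Return a zero-argument policy producing `count` indices spaced by `stride`.

    Hypothetical helper for illustration; not part of TokenSmith.
    """
    def policy():
        return [start + i * stride for i in range(count)]
    return policy

# Build two reusable policies from the same factory
every_20th = make_strided_policy(start=50, stride=20, count=5)
first_ten = make_strided_policy(start=0, stride=1, count=10)

print(every_20th())  # [50, 70, 90, 110, 130]
print(first_ten())
```

Each returned policy could then be passed as `policy_fn` to `get_samples_by_policy()`, so the stride and range are defined once and reused across analyses.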