Dataset Sampling Tutorial
This tutorial shows how to use TokenSmith's sampling API to retrieve individual samples and whole batches, either by explicit index or ID, or through custom policy functions that compute which indices to fetch. We'll cover the basic lookup methods first, then custom policies, and finally a few utilities for dataset analysis.
Prerequisites:
- Complete tutorials 1 and 2 (basic setup and inspection)
- Have a tokenized dataset ready with batch info generated
What you'll learn:
- How to sample data by specific indices
- How to sample batches by IDs
- Creating and using custom sampling policies
- Policy-based sampling for individual samples and batches
- Advanced sampling strategies for research and analysis
- Performance considerations for different sampling methods
Setup
Let's start by setting up our environment and dataset manager, building on the previous tutorials.
# Fix paths for imports (these are machine-specific; point them at your own TokenSmith and gpt-neox checkouts)
import sys
sys.path.insert(0, "/NS/llm-pretraining/work/afkhan/tokensmith")
sys.path.insert(0, "/NS/llm-pretraining/work/afkhan/USC_Colab/gpt-neox")
# Import required libraries
import numpy as np
import random
from transformers import AutoTokenizer
from tokensmith.manager import DatasetManager
# Load tokenizer
TOKENIZER_NAME_OR_PATH = "EleutherAI/gpt-neox-20b"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME_OR_PATH, add_eos_token=True)
print(f"Loaded tokenizer: {TOKENIZER_NAME_OR_PATH}")
Loaded tokenizer: EleutherAI/gpt-neox-20b
# Initialize DatasetManager
dataset_manager = DatasetManager()
# Setup the dataset for sampling
dataset_manager.setup_edit_inspect_sample_export(
    dataset_prefix='../../artifacts/data_tokenized_text_document',
    batch_info_save_prefix='../../artifacts/batch_info',
    train_iters=100,
    train_batch_size=16,
    train_seq_len=2048,
    seed=42,
    splits_string='990,5,5',
    packing_impl='packed',
    allow_chopped=True,
)
print("Dataset manager setup complete!")
warming up index mmap file... reading sizes... reading pointers... reading document index... Dataset manager setup complete!
Basic Sampling by Indices
The most straightforward way to sample data is by specifying exact sample indices. This is useful when you know exactly which samples you want to examine.
# Sample specific indices
sample_indices = [0, 5, 10, 25, 50]
# Get samples as tokenized arrays
tokenized_samples = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices
)
print("Tokenized samples:")
for idx, sample in zip(sample_indices, tokenized_samples):
    total_tokens = sum(len(segment) for segment in sample)
    print(f"Sample {idx}: {len(sample)} segments, {total_tokens} total tokens")
    print(f" First segment shape: {sample[0].shape}")
    print(f" First 10 tokens: {sample[0][:10]}")
    print()
Tokenized samples:
Sample 0: 12 segments, 2049 total tokens
 First segment shape: (70,)
 First 10 tokens: [ 2181 4592 15 32817 434 1652 4929 2210 3515 285]

Sample 5: 8 segments, 2049 total tokens
 First segment shape: (495,)
 First 10 tokens: [1466 434 1007 387 253 7968 15 1583 1537 2028]

Sample 10: 12 segments, 2049 total tokens
 First segment shape: (291,)
 First 10 tokens: [ 1388 13 597 1089 247 1943 5453 327 253 10583]

Sample 25: 8 segments, 2049 total tokens
 First segment shape: (55,)
 First 10 tokens: [ 326 253 5101 1160 30244 24493 745 273 616 22513]

Sample 50: 9 segments, 2049 total tokens
 First segment shape: (261,)
 First 10 tokens: [6029 434 1113 6497 327 253 3216 285 9377 15]
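Each tokenized sample above is a list of per-document segments whose lengths sum to 2049, one more than `train_seq_len=2048`. The sketch below shows how such a sample could be flattened and split into inputs and next-token targets. That the extra token exists for the one-position label shift is an assumption based on common GPT-style training loops, not something this tutorial states, and `flatten_and_shift` is a hypothetical helper rather than part of TokenSmith.

```python
import numpy as np


def flatten_and_shift(segments):
    """Concatenate segments and return (inputs, targets) shifted by one token."""
    tokens = np.concatenate(segments)
    # Assumption: position i predicts token i + 1, so drop one token at each end.
    return tokens[:-1], tokens[1:]


# Synthetic stand-in for one sample: three segments totalling 9 tokens.
segments = [np.arange(0, 4), np.arange(4, 6), np.arange(6, 9)]
inputs, targets = flatten_and_shift(segments)
print(inputs.shape, targets.shape)
```

With a real sample from `get_samples_by_indices`, the same helper would yield two arrays of length 2048 under that assumption.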
# Same samples but detokenized (human-readable text)
text_samples = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices,
    return_detokenized=True,
    tokenizer=tokenizer
)
print("Detokenized samples:")
for idx, text in zip(sample_indices, text_samples):
    print(f"Sample {idx} (length: {len(text)} chars):")
    print(f" Preview: {text[:150]}...")
    print(f" Ending: ...{text[-50:]}")
    print()
Detokenized samples:
Sample 0 (length: 8990 chars):
 Preview: thing happened. Lily's little brother came running and accidentally stepped on the diamond. Oh no! The diamond was destroyed! Lily was very sad, but ...
 Ending: ...hey explored the forest, gathered flowers and made

Sample 5 (length: 8404 chars):
 Preview: Let's look at the pictures. They might tell us something." Lila and Ben look at the pictures on the map. They see a sun, a cloud, a star, a fish, and ...
 Ending: ...ets on!" They both put on their bracelets and rode

Sample 10 (length: 8530 chars):
 Preview: day, they find a big club on the grass. It is brown and heavy. "Look, a club!" Lily says. "Let's play with it!" "OK!" Ben says. "What can we do with ...
 Ending: ...t was scary but fun.<|endoftext|>Once upon a time,

Sample 25 (length: 8254 chars):
 Preview: that the sun made droplets scatter off of their backs! They felt so refreshed in the warm, wet water. After a while, the frogs decided it was time to...
 Ending: ... to Ben. "I'm sorry, Ben," she said. "I was wrong.

Sample 50 (length: 8495 chars):
 Preview: Ben's car fell on the ground and broke. The wheel came off and the paint scratched. "Uh oh!" Lily said, looking at Ben's car. "I'm sorry, Ben. I did ...
 Ending: ...r and closed her eyes. She tried to think of happy
# Get samples with document details
samples_with_details = dataset_manager.sample.get_samples_by_indices(
    indices=sample_indices[:3],  # Just first 3 for brevity
    return_detokenized=True,
    return_doc_details=True,
    tokenizer=tokenizer
)
print("Samples with document details:")
for idx, (text, doc_details) in zip(sample_indices[:3], samples_with_details):
    print(f"Sample {idx}:")
    print(f" Text length: {len(text)} characters")
    print(f" Document range: docs {doc_details['doc_index_f']} to {doc_details['doc_index_l']}")
    print(f" Offset range: {doc_details['offset_f']} to {doc_details['offset_l']}")
    print(f" Spans multiple docs: {doc_details['doc_index_f'] != doc_details['doc_index_l']}")
    print(f" Preview: {text[:100]}...")
    print()
Samples with document details:
Sample 0:
 Text length: 8990 characters
 Document range: docs 11212 to 11223
 Offset range: 67 to 154
 Spans multiple docs: True
 Preview: thing happened. Lily's little brother came running and accidentally stepped on the diamond. Oh no! ...

Sample 5:
 Text length: 8404 characters
 Document range: docs 5991 to 5998
 Offset range: 226 to 134
 Spans multiple docs: True
 Preview: Let's look at the pictures. They might tell us something." Lila and Ben look at the pictures on the ...

Sample 10:
 Text length: 8530 characters
 Document range: docs 7983 to 7994
 Offset range: 16 to 4
 Spans multiple docs: True
 Preview: day, they find a big club on the grass. It is brown and heavy. "Look, a club!" Lily says. "Let's pl...
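The document range lines up with the segment counts printed earlier: sample 0 covers docs 11212 to 11223, which is 12 documents, and it was reported as 12 segments. A small sketch of a helper that summarizes a `doc_details` dictionary follows. The key names are the ones used above; `describe_span` itself is a hypothetical helper, and the reading of `offset_f` and `offset_l` as positions inside the first and last document is an assumption.

```python
def describe_span(doc_details):
    """Summarize how many documents a sample touches."""
    first = doc_details["doc_index_f"]
    last = doc_details["doc_index_l"]
    return {
        "num_docs": last - first + 1,  # Inclusive range of document indices
        "multi_doc": first != last,
    }


# Values copied from the output for sample 0 above.
details = {"doc_index_f": 11212, "doc_index_l": 11223, "offset_f": 67, "offset_l": 154}
summary = describe_span(details)
print(summary)
```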
Batch Sampling by IDs
When working with training data, you often want to examine entire batches as they would appear during training. TokenSmith allows you to sample specific batches by their IDs.
# Sample specific batches
batch_ids = [0, 2, 5]
batch_size = 4 # Small batch size for easier examination
batches = dataset_manager.sample.get_batches_by_ids(
    batch_ids=batch_ids,
    batch_size=batch_size,
    return_detokenized=True,
    tokenizer=tokenizer
)
print(f"Sampled {len(batches)} batches:")
for batch_idx, batch in enumerate(batches):
    batch_id = batch_ids[batch_idx]
    print(f"\nBatch {batch_id} (size: {len(batch)}):")
    for sample_idx, sample in enumerate(batch):
        global_sample_id = batch_id * batch_size + sample_idx
        print(f" Sample {sample_idx} (global ID {global_sample_id}): {len(sample)} chars")
        print(f" Preview: {sample[:80]}...")
Sampled 3 batches:

Batch 0 (size: 4):
 Sample 0 (global ID 0): 8990 chars
 Preview: thing happened. Lily's little brother came running and accidentally stepped on ...
 Sample 1 (global ID 1): 8388 chars
 Preview: . She had gone to the office for a minute. Lily had an idea. "Let's steal some c...
 Sample 2 (global ID 2): 8789 chars
 Preview: agreed to marry him. They had a wonderful wedding and were very happy together....
 Sample 3 (global ID 3): 8700 chars
 Preview: sleep, Maggie's mommy saw something very rare and wet. It was raining outside a...

Batch 2 (size: 4):
 Sample 0 (global ID 8): 8547 chars
 Preview: park. They see the slide and the swing. They wish they had been nice. They wish...
 Sample 1 (global ID 9): 8731 chars
 Preview: run around the faucet, letting the water spray all over him. His mom, seeing th...
 Sample 2 (global ID 10): 8530 chars
 Preview: day, they find a big club on the grass. It is brown and heavy. "Look, a club!" ...
 Sample 3 (global ID 11): 8813 chars
 Preview: and dad said it seemed too dangerous, so they said no. The boy didn't listen an...

Batch 5 (size: 4):
 Sample 0 (global ID 20): 8650 chars
 Preview: saying thank you and sorry. They did not know that they had missed a chance to ...
 Sample 1 (global ID 21): 8398 chars
 Preview: it and take care of it. It will be your new friend." Tim and Lily hug their par...
 Sample 2 (global ID 22): 8857 chars
 Preview: Suddenly, a big, brown pony appeared in front of her! Lily was so happy and hug...
 Sample 3 (global ID 23): 8998 chars
 Preview: also about being kind and working hard. From that day on, Timmy worked hard and...
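Note that a batch ID only has meaning relative to the `batch_size` you pass: the examples compute the global sample ID as `batch_id * batch_size + sample_idx`, so batch 2 at size 4 covers samples 8 to 11. The sketch below spells out that arithmetic; `global_sample_ids` is a hypothetical helper written for illustration, not a TokenSmith function.

```python
def global_sample_ids(batch_ids, batch_size):
    """Map each batch ID to the global sample indices it covers."""
    return {
        batch_id: list(range(batch_id * batch_size, (batch_id + 1) * batch_size))
        for batch_id in batch_ids
    }


ids = global_sample_ids([0, 2, 5], batch_size=4)
for batch_id, sample_ids in ids.items():
    print(f"Batch {batch_id}: samples {sample_ids}")
```

Changing `batch_size` therefore changes which samples a given batch ID returns, which is worth keeping in mind when comparing against the `train_batch_size` used at setup.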
# Get batches with document details
batches_with_details = dataset_manager.sample.get_batches_by_ids(
    batch_ids=[1, 3],  # Just 2 batches for detailed examination
    batch_size=3,
    return_detokenized=True,
    return_doc_details=True,
    tokenizer=tokenizer
)
print("Batches with document details:")
for batch_idx, batch in enumerate(batches_with_details):
    batch_id = [1, 3][batch_idx]
    print(f"\nBatch {batch_id}:")
    for sample_idx, (text, doc_details) in enumerate(batch):
        global_sample_id = batch_id * 3 + sample_idx
        print(f" Sample {sample_idx} (global ID {global_sample_id}):")
        print(f" Length: {len(text)} chars")
        print(f" Doc range: {doc_details['doc_index_f']}-{doc_details['doc_index_l']}")
        print(f" Multi-doc: {doc_details['doc_index_f'] != doc_details['doc_index_l']}")
        print(f" Preview: {text[:60]}...")
Batches with document details:

Batch 1:
 Sample 0 (global ID 3):
 Length: 8700 chars
 Doc range: 5168-5178
 Multi-doc: True
 Preview: sleep, Maggie's mommy saw something very rare and wet. It w...
 Sample 1 (global ID 4):
 Length: 8976 chars
 Doc range: 1942-1954
 Multi-doc: True
 Preview: Benny learned that not all big, wild animals are scary. Som...
 Sample 2 (global ID 5):
 Length: 8404 chars
 Doc range: 5991-5998
 Multi-doc: True
 Preview: Let's look at the pictures. They might tell us something." L...

Batch 3:
 Sample 0 (global ID 9):
 Length: 8731 chars
 Doc range: 6888-6899
 Multi-doc: True
 Preview: run around the faucet, letting the water spray all over him...
 Sample 1 (global ID 10):
 Length: 8530 chars
 Doc range: 7983-7994
 Multi-doc: True
 Preview: day, they find a big club on the grass. It is brown and hea...
 Sample 2 (global ID 11):
 Length: 8813 chars
 Doc range: 13066-13077
 Multi-doc: True
 Preview: and dad said it seemed too dangerous, so they said no. The ...
Custom Sampling Policies
The real power of TokenSmith's sampling comes from policy-based selection. You can define custom functions that determine which samples or batches to retrieve, enabling sophisticated sampling strategies.
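Conceptually, a policy is just a function that returns a list of indices, which the sampler then looks up. The sketch below illustrates that pattern with a toy in-memory dataset. It is an assumption about the general shape of the mechanism, inferred from how the policy calls in this tutorial pass extra keyword arguments through to `policy_fn`; TokenSmith's actual implementation may differ, and `sample_by_policy`, `fetch`, and `every_third` are hypothetical names.

```python
def sample_by_policy(policy_fn, fetch_fn, **policy_kwargs):
    """Call the policy to get indices, then fetch those indices."""
    indices = policy_fn(**policy_kwargs)
    return fetch_fn(indices)


# Toy dataset standing in for the real tokenized data.
fake_dataset = [f"sample-{i}" for i in range(20)]


def fetch(indices):
    """Stand-in for an index-based lookup such as get_samples_by_indices."""
    return [fake_dataset[i] for i in indices]


def every_third(start, count):
    """Example policy: indices start, start + 3, start + 6, ..."""
    return [start + 3 * i for i in range(count)]


picked = sample_by_policy(every_third, fetch, start=2, count=3)
print(picked)
```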
Policy Functions for Individual Samples
Let's start by creating various policy functions for individual sample selection.
# Define policy functions for sample selection
def random_sample_policy(num_samples, max_index, rng_seed=42):
    """
    Policy function that returns random sample indices.

    Indices are drawn independently (with replacement), so duplicates are possible.

    Args:
        num_samples: Number of samples to return
        max_index: Maximum index value (exclusive)
        rng_seed: Random seed for reproducibility
    Returns:
        List of random sample indices
    """
    rng = np.random.default_rng(rng_seed)
    return rng.integers(0, max_index, size=num_samples).tolist()


def sequential_sample_policy(start_index, num_samples):
    """
    Policy function that returns sequential sample indices.

    Args:
        start_index: Starting index
        num_samples: Number of consecutive samples to return
    Returns:
        List of sequential sample indices
    """
    return list(range(start_index, start_index + num_samples))


def sparse_sample_policy(start_index, num_samples, step_size=10):
    """
    Policy function that returns sparsely distributed sample indices.

    Args:
        start_index: Starting index
        num_samples: Number of samples to return
        step_size: Step size between samples
    Returns:
        List of sparse sample indices
    """
    return [start_index + i * step_size for i in range(num_samples)]


def prime_sample_policy(max_index):
    """
    Policy function that returns sample indices at prime numbers.

    Args:
        max_index: Maximum index to consider (exclusive)
    Returns:
        List of the first 10 prime-numbered sample indices below max_index
    """
    def is_prime(n):
        if n < 2:
            return False
        for i in range(2, int(n**0.5) + 1):
            if n % i == 0:
                return False
        return True

    return [i for i in range(2, max_index) if is_prime(i)][:10]  # Limit to first 10 primes

print("Defined custom sampling policy functions!")
Defined custom sampling policy functions!
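Because `random_sample_policy` draws each index independently with `rng.integers`, the same index can appear more than once. If distinct samples are needed, one option is to draw without replacement instead. The variant below is a sketch of that idea; `unique_random_sample_policy` is a hypothetical name, and it requires `num_samples <= max_index`.

```python
import numpy as np


def unique_random_sample_policy(num_samples, max_index, rng_seed=42):
    """Return num_samples distinct random indices in [0, max_index)."""
    rng = np.random.default_rng(rng_seed)
    # replace=False guarantees that no index is drawn twice.
    return rng.choice(max_index, size=num_samples, replace=False).tolist()


indices = unique_random_sample_policy(5, 200)
print(indices)
```

It can be passed as `policy_fn` in the same way as the policies defined above.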
# Example 1: Random sampling policy
print("=== Random Sample Policy Example ===")
random_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=random_sample_policy,
    num_samples=5,
    max_index=200,  # Sample from first 200 indices
    rng_seed=42,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, sample in enumerate(random_samples):
    print(f"Random Sample {i+1}: {len(sample)} chars")
    print(f" Text: {sample[:100]}...")
    print()
=== Random Sample Policy Example ===
Random Sample 1: 8936 chars
 Text: and Dad's hands. They were alone in the crowd. "Mom! Dad! Where are you?" Lily shouted. "Help! Help...

Random Sample 2: 8663 chars
 Text: He was so excited to tell her he was being patient! Mom was very proud of him and said he was growi...

Random Sample 3: 8682 chars
 Text: 's parents took him to the park and he saw the bird again. The old man was there as well. This time,...

Random Sample 4: 8544 chars
 Text: let go of the car. "I am sorry, Ben," Anna says. "I am sorry, Anna," Ben says. They hug each other....

Random Sample 5: 8654 chars
 Text: make a picture for his mom, who was at work. He took a big paper and some colors and started to dra...
# Example 2: Sequential sampling policy
print("=== Sequential Sample Policy Example ===")
sequential_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=sequential_sample_policy,
    start_index=100,
    num_samples=4,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, sample in enumerate(sequential_samples):
    sample_id = 100 + i
    print(f"Sequential Sample {sample_id}: {len(sample)} chars")
    print(f" Text: {sample[:100]}...")
    print()
=== Sequential Sample Policy Example ===
Sequential Sample 100: 8699 chars
 Text: and peeked over the grass to get a better look. Sally asked the sheep, “What’s your name?” The youn...

Sequential Sample 101: 8723 chars
 Text: it." Lily nodded and put the pebble in her pocket. She played for a while and then went back inside...

Sequential Sample 102: 8948 chars
 Text: better than the brick! Mandy was very pleased with her new snack. She ate the carrot and thanked th...

Sequential Sample 103: 8639 chars
 Text: ran to the house. She was cold and sad. She wished she had listened to Tom. She reached the house a...
# Example 3: Sparse sampling policy with document details
print("=== Sparse Sample Policy with Document Details ===")
sparse_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=sparse_sample_policy,
    start_index=50,
    num_samples=4,
    step_size=25,
    return_doc_details=True,
    return_detokenized=True,
    tokenizer=tokenizer
)
for i, (sample, doc_details) in enumerate(sparse_samples):
    sample_id = 50 + i * 25
    print(f"Sparse Sample {sample_id}:")
    print(f" Length: {len(sample)} chars")
    print(f" Document range: {doc_details['doc_index_f']} to {doc_details['doc_index_l']}")
    print(f" Text: {sample[:80]}...")
    print()
=== Sparse Sample Policy with Document Details ===
Sparse Sample 50:
 Length: 8495 chars
 Document range: 14417 to 14425
 Text: Ben's car fell on the ground and broke. The wheel came off and the paint scratc...

Sparse Sample 75:
 Length: 8828 chars
 Document range: 16055 to 16064
 Text: the museum. Sam was very happy that she was able to help BRO when he was broken...

Sparse Sample 100:
 Length: 8699 chars
 Document range: 21659 to 21670
 Text: and peeked over the grass to get a better look. Sally asked the sheep, “What’s ...

Sparse Sample 125:
 Length: 8419 chars
 Document range: 1534 to 1544
 Text: a time there was a little ice cream cone. It was filled with white, creamy ice ...
# Example 4: Prime number sampling policy
print("=== Prime Sample Policy Example ===")
prime_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=prime_sample_policy,
    max_index=100,
    return_detokenized=True,
    tokenizer=tokenizer
)
print(f"Found {len(prime_samples)} samples at prime indices:")
prime_indices = prime_sample_policy(100)
for i, (sample, prime_idx) in enumerate(zip(prime_samples, prime_indices)):
    print(f"Prime Index {prime_idx}: {len(sample)} chars")
    print(f" Text: {sample[:70]}...")
    if i >= 4:  # Show only first 5
        print(f" ... and {len(prime_samples) - 5} more")
        break
=== Prime Sample Policy Example ===
Found 10 samples at prime indices:
Prime Index 2: 8789 chars
 Text: agreed to marry him. They had a wonderful wedding and were very happy...
Prime Index 3: 8700 chars
 Text: sleep, Maggie's mommy saw something very rare and wet. It was raining...
Prime Index 5: 8404 chars
 Text: Let's look at the pictures. They might tell us something." Lila and Be...
Prime Index 7: 8719 chars
 Text: said. Bella was shocked and excited. She hugged her parents and thank...
Prime Index 11: 8813 chars
 Text: and dad said it seemed too dangerous, so they said no. The boy didn't...
 ... and 5 more
Policy Functions for Batch Sampling
Now let's create policy functions for batch selection, which can be useful for examining training dynamics or data distribution across batches.
# Define policy functions for batch selection
def random_batch_policy(num_batches, max_batch_id, rng_seed=42):
    """
    Policy function that returns random batch IDs.

    IDs are drawn independently (with replacement), so duplicates are possible.

    Args:
        num_batches: Number of batches to return
        max_batch_id: Maximum batch ID value (exclusive)
        rng_seed: Random seed for reproducibility
    Returns:
        List of random batch IDs
    """
    rng = np.random.default_rng(rng_seed)
    return rng.integers(0, max_batch_id, size=num_batches).tolist()


def sequential_batch_policy(start_batch_id, num_batches):
    """
    Policy function that returns sequential batch IDs.

    Args:
        start_batch_id: Starting batch ID
        num_batches: Number of consecutive batches to return
    Returns:
        List of sequential batch IDs
    """
    return list(range(start_batch_id, start_batch_id + num_batches))


def stride_batch_policy(start_batch_id, num_batches, stride=5):
    """
    Policy function that returns batch IDs with a specific stride.

    Args:
        start_batch_id: Starting batch ID
        num_batches: Number of batches to return
        stride: Stride between batch IDs
    Returns:
        List of strided batch IDs
    """
    return [start_batch_id + i * stride for i in range(num_batches)]


def fibonacci_batch_policy(max_batch_id):
    """
    Policy function that returns batch IDs at Fibonacci numbers.

    Args:
        max_batch_id: Maximum batch ID to consider (exclusive)
    Returns:
        List of the first 8 Fibonacci-numbered batch IDs below max_batch_id
    """
    fib = [1, 1]  # The sequence starts 1, 1, so batch ID 1 is returned twice
    while fib[-1] < max_batch_id:
        fib.append(fib[-1] + fib[-2])
    return [f for f in fib if f < max_batch_id][:8]  # Limit to first 8

print("Defined custom batch sampling policy functions!")
Defined custom batch sampling policy functions!
# Example 1: Random batch policy
print("=== Random Batch Policy Example ===")
random_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=random_batch_policy,
    batch_size=3,
    num_batches=2,
    max_batch_id=20,
    rng_seed=123,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(random_batches):
    print(f"Random Batch {batch_idx + 1}:")
    for sample_idx, sample in enumerate(batch):
        print(f" Sample {sample_idx + 1}: {sample[:60]}...")
    print()
=== Random Batch Policy Example ===
Random Batch 1:
 Sample 1: thing happened. Lily's little brother came running and acci...
 Sample 2: . She had gone to the office for a minute. Lily had an idea....
 Sample 3: agreed to marry him. They had a wonderful wedding and were ...

Random Batch 2:
 Sample 1: play is a good thing to do and it can make people happy.<|e...
 Sample 2: . The ball is empty and flat. "Oh no, we broke the ball!" Li...
 Sample 3: dad said they could go and play on the sand and look for sh...
# Example 2: Sequential batch policy
print("=== Sequential Batch Policy Example ===")
sequential_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=sequential_batch_policy,
    batch_size=2,
    start_batch_id=5,
    num_batches=3,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(sequential_batches):
    batch_id = 5 + batch_idx
    print(f"Sequential Batch {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f" Sample {sample_idx + 1}: {sample[:60]}...")
    print()
=== Sequential Batch Policy Example ===
Sequential Batch 5:
 Sample 1: day, they find a big club on the grass. It is brown and hea...
 Sample 2: and dad said it seemed too dangerous, so they said no. The ...

Sequential Batch 6:
 Sample 1: He thought it was a good idea to play under it. So Bob walk...
 Sample 2: dog named Spot. Spot liked nothing better than to play in t...

Sequential Batch 7:
 Sample 1: . She says: "That is a bad man! He is not your friend! He wa...
 Sample 2: ," she said with a big, happy grin. Dad smiled and said, "Ye...
# Example 3: Stride batch policy with document details
print("=== Stride Batch Policy with Document Details ===")
stride_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=stride_batch_policy,
    batch_size=2,
    start_batch_id=3,
    num_batches=3,
    stride=4,  # Skip 3 batches between selections
    return_doc_details=True,
    return_detokenized=True,
    tokenizer=tokenizer
)
for batch_idx, batch in enumerate(stride_batches):
    batch_id = 3 + batch_idx * 4
    print(f"Stride Batch {batch_id}:")
    for sample_idx, (sample, doc_details) in enumerate(batch):
        print(f" Sample {sample_idx + 1}:")
        print(f" Doc range: {doc_details['doc_index_f']}-{doc_details['doc_index_l']}")
        print(f" Text: {sample[:50]}...")
    print()
=== Stride Batch Policy with Document Details ===
Stride Batch 3:
 Sample 1:
 Doc range: 18928-18938
 Text: baby, saving them from certain injury. The creati...
 Sample 2:
 Doc range: 7508-7515
 Text: said. Bella was shocked and excited. She hugged h...

Stride Batch 7:
 Sample 1:
 Doc range: 16979-16987
 Text: . She says: "That is a bad man! He is not your fri...
 Sample 2:
 Doc range: 4665-4676
 Text: ," she said with a big, happy grin. Dad smiled and...

Stride Batch 11:
 Sample 1:
 Doc range: 15364-15375
 Text: Suddenly, a big, brown pony appeared in front of ...
 Sample 2:
 Doc range: 14762-14772
 Text: also about being kind and working hard. From that...
# Example 4: Fibonacci batch policy
print("=== Fibonacci Batch Policy Example ===")
fib_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=fibonacci_batch_policy,
    batch_size=2,
    max_batch_id=25,
    return_detokenized=True,
    tokenizer=tokenizer
)
fib_batch_ids = fibonacci_batch_policy(25)
print(f"Fibonacci batch IDs: {fib_batch_ids}")
for batch_idx, batch in enumerate(fib_batches[:3]):  # Show first 3
    batch_id = fib_batch_ids[batch_idx]
    print(f"Fibonacci Batch {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f" Sample {sample_idx + 1}: {sample[:50]}...")
    print()
if len(fib_batches) > 3:
    print(f"... and {len(fib_batches) - 3} more batches")
=== Fibonacci Batch Policy Example ===
Fibonacci batch IDs: [1, 1, 2, 3, 5, 8, 13, 21]
Fibonacci Batch 1:
 Sample 1: agreed to marry him. They had a wonderful wedding...
 Sample 2: sleep, Maggie's mommy saw something very rare and...

Fibonacci Batch 1:
 Sample 1: agreed to marry him. They had a wonderful wedding...
 Sample 2: sleep, Maggie's mommy saw something very rare and...

Fibonacci Batch 2:
 Sample 1: Benny learned that not all big, wild animals are ...
 Sample 2: Let's look at the pictures. They might tell us som...

... and 5 more batches
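The output lists batch 1 twice because the Fibonacci sequence used by the policy begins 1, 1. If each batch should be fetched only once, a variant that starts from 1, 2 avoids the repeat. This is a sketch; `unique_fibonacci_batch_policy` and its `limit` parameter are hypothetical names introduced here.

```python
def unique_fibonacci_batch_policy(max_batch_id, limit=8):
    """Return distinct Fibonacci batch IDs below max_batch_id (at most `limit`)."""
    ids = []
    current, following = 1, 2  # Start at 1, 2 so that 1 is not repeated
    while current < max_batch_id and len(ids) < limit:
        ids.append(current)
        current, following = following, current + following
    return ids


fib_ids = unique_fibonacci_batch_policy(25)
print(fib_ids)
```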
Advanced Policy Examples
Let's explore some more sophisticated sampling policies that might be useful for research and analysis.
# Advanced policy: Sample based on text length distribution
def length_based_sample_policy(dataset_manager, tokenizer_for_policy, num_samples=5, target_lengths=None):
    """
    Policy that samples based on desired text lengths.

    This is a more complex policy that examines actual samples to find ones
    matching certain criteria. Every candidate is re-fetched and detokenized
    for each target, so the cost grows with search_range * number of targets.
    The same index may be selected for more than one target.
    """
    if target_lengths is None:
        target_lengths = [7000, 8000, 8500, 9000, 9500]  # Different length targets
    selected_indices = []
    search_range = min(500, num_samples * 50)  # Reasonable search space
    for target_length in target_lengths[:num_samples]:
        best_idx = None
        best_diff = float('inf')
        # Search through a range of samples
        for idx in range(search_range):
            try:
                # Get sample text to check length
                text = dataset_manager.sample.get_samples_by_indices(
                    indices=[idx],
                    return_detokenized=True,
                    tokenizer=tokenizer_for_policy
                )[0]
                diff = abs(len(text) - target_length)
                if diff < best_diff:
                    best_diff = diff
                    best_idx = idx
            except Exception:
                # Skip indices that cannot be fetched
                continue
        if best_idx is not None:
            selected_indices.append(best_idx)
    return selected_indices

# Use the advanced policy
print("=== Length-Based Sample Policy Example ===")
length_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=length_based_sample_policy,
    dataset_manager=dataset_manager,
    tokenizer_for_policy=tokenizer,
    num_samples=3,
    target_lengths=[7500, 8500, 9000],
    return_detokenized=True,
    tokenizer=tokenizer
)
target_lengths = [7500, 8500, 9000]
for i, sample in enumerate(length_samples):
    target = target_lengths[i]
    actual = len(sample)
    print(f"Target length {target}, actual length {actual} (diff: {abs(actual-target)}):")
    print(f" Text: {sample[:80]}...")
    print()
=== Length-Based Sample Policy Example ===
Target length 7500, actual length 8195 (diff: 695):
 Text: Lily liked to gather bugs in her jar. She would look for them in the grass, unde...

Target length 8500, actual length 8500 (diff: 0):
 Text: and dirt, and soon the necklace was clean and light. Jim was so happy, he hugge...

Target length 9000, actual length 8998 (diff: 2):
 Text: also about being kind and working hard. From that day on, Timmy worked hard and...
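The length-based policy fetches and detokenizes every candidate once per target. A cheaper approach is to measure each candidate once and then pick the closest match for every target from the cached lengths. The sketch below shows only the selection step on a plain list of lengths; `closest_length_indices` is a hypothetical helper, and the example lengths are the character counts of samples 0 to 3 printed in the batch example earlier.

```python
def closest_length_indices(lengths, target_lengths):
    """For each target, return the index whose length is closest to it."""
    return [
        min(range(len(lengths)), key=lambda i: abs(lengths[i] - target))
        for target in target_lengths
    ]


# Character lengths of samples 0-3, taken from the batch output earlier.
lengths = [8990, 8388, 8789, 8700]
picked = closest_length_indices(lengths, [8400, 8800, 9000])
print(picked)
```

In a real policy, `lengths` would be built by a single pass over the search range before any target is considered.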
# Advanced policy: Lambda functions for complex logic
print("=== Lambda Policy Functions Example ===")
# Example 1: Sample indices that are perfect squares
square_samples = dataset_manager.sample.get_samples_by_policy(
    policy_fn=lambda max_idx: [i*i for i in range(1, int(max_idx**0.5)+1) if i*i < max_idx][:5],
    max_idx=100,
    return_detokenized=True,
    tokenizer=tokenizer
)
square_indices = [i*i for i in range(1, int(100**0.5)+1) if i*i < 100][:5]
print("Samples at perfect square indices:")
for i, (sample, sq_idx) in enumerate(zip(square_samples, square_indices)):
    print(f"Square Index {sq_idx}: {sample[:60]}...")
print("\n" + "="*50)
# Example 2: Batch IDs that are powers of 2
power_batches = dataset_manager.sample.get_batches_by_policy(
    policy_fn=lambda max_power: [2**i for i in range(1, max_power+1)],
    batch_size=2,
    max_power=4,  # 2^1, 2^2, 2^3, 2^4 = batches 2, 4, 8, 16
    return_detokenized=True,
    tokenizer=tokenizer
)
power_batch_ids = [2**i for i in range(1, 5)]
print("Batches at power-of-2 IDs:")
for batch_idx, batch in enumerate(power_batches):
    batch_id = power_batch_ids[batch_idx]
    print(f"Batch ID {batch_id}:")
    for sample_idx, sample in enumerate(batch):
        print(f" Sample {sample_idx + 1}: {sample[:50]}...")
    print()
=== Lambda Policy Functions Example ===
Samples at perfect square indices:
Square Index 1: . She had gone to the office for a minute. Lily had an idea....
Square Index 4: Benny learned that not all big, wild animals are scary. Som...
Square Index 9: run around the faucet, letting the water spray all over him...
Square Index 16: you too," she says. "But you have to promise me that you wo...
Square Index 25: that the sun made droplets scatter off of their backs! They...

==================================================
Batches at power-of-2 IDs:
Batch ID 2:
 Sample 1: Benny learned that not all big, wild animals are ...
 Sample 2: Let's look at the pictures. They might tell us som...

Batch ID 4:
 Sample 1: park. They see the slide and the swing. They wish...
 Sample 2: run around the faucet, letting the water spray al...

Batch ID 8:
 Sample 1: you too," she says. "But you have to promise me t...
 Sample 2: and Dad's hands. They were alone in the crowd. "M...

Batch ID 16:
 Sample 1: my took him to the airport to see a famous plane. ...
 Sample 2: off. The man finally smiled back at the little gi...
Practical Sampling Utilities
Let's create some utility functions that demonstrate practical applications of sampling for dataset analysis and debugging.
def sample_quality_check(dataset_manager, tokenizer, num_samples=10, seed=42):
    """
    Utility function to perform quality checks on random samples.
    Returns statistics about sample lengths, document boundaries, etc.
    """
    # Get random samples with document details
    samples = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=num_samples,
        max_index=500,
        rng_seed=seed,
        return_doc_details=True,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    stats = {
        'num_samples': len(samples),
        'lengths': [],
        'multi_doc_count': 0,
        'doc_spans': [],
        'avg_length': 0,
        'length_std': 0
    }
    for text, doc_details in samples:
        length = len(text)
        stats['lengths'].append(length)
        doc_span = doc_details['doc_index_l'] - doc_details['doc_index_f'] + 1
        stats['doc_spans'].append(doc_span)
        if doc_details['doc_index_f'] != doc_details['doc_index_l']:
            stats['multi_doc_count'] += 1
    stats['avg_length'] = np.mean(stats['lengths'])
    stats['length_std'] = np.std(stats['lengths'])
    stats['avg_doc_span'] = np.mean(stats['doc_spans'])
    # Divide by the number of samples actually returned, matching 'num_samples' above
    stats['multi_doc_percentage'] = (stats['multi_doc_count'] / stats['num_samples']) * 100
    return stats

# Run quality check
print("=== Dataset Quality Check Example ===")
quality_stats = sample_quality_check(dataset_manager, tokenizer, num_samples=15)
print(f"Quality check results ({quality_stats['num_samples']} samples):")
print(f" Average length: {quality_stats['avg_length']:.1f} ± {quality_stats['length_std']:.1f} chars")
print(f" Length range: {min(quality_stats['lengths'])} - {max(quality_stats['lengths'])} chars")
print(f" Multi-document samples: {quality_stats['multi_doc_count']} ({quality_stats['multi_doc_percentage']:.1f}%)")
print(f" Average document span: {quality_stats['avg_doc_span']:.1f} documents")
print(f" Document span range: {min(quality_stats['doc_spans'])} - {max(quality_stats['doc_spans'])} documents")
=== Dataset Quality Check Example ===
Quality check results (15 samples):
 Average length: 8713.5 ± 175.3 chars
 Length range: 8173 - 8938 chars
 Multi-document samples: 15 (100.0%)
 Average document span: 11.1 documents
 Document span range: 9 - 13 documents
def compare_sampling_strategies(dataset_manager, tokenizer):
    """
    Compare different sampling strategies to understand their characteristics.
    """
    strategies = {
        'Random': lambda: random_sample_policy(5, 200, 42),
        'Sequential': lambda: sequential_sample_policy(100, 5),
        'Sparse': lambda: sparse_sample_policy(50, 5, 20),
        'Prime': lambda: prime_sample_policy(50)[:5]
    }
    results = {}
    for strategy_name, policy_fn in strategies.items():
        indices = policy_fn()
        samples = dataset_manager.sample.get_samples_by_indices(
            indices=indices,
            return_detokenized=True,
            return_doc_details=True,
            tokenizer=tokenizer
        )
        lengths = [len(text) for text, _ in samples]
        multi_doc_count = sum(1 for _, doc_details in samples
                              if doc_details['doc_index_f'] != doc_details['doc_index_l'])
        results[strategy_name] = {
            'indices': indices,
            'avg_length': np.mean(lengths),
            'length_std': np.std(lengths),
            'multi_doc_percentage': (multi_doc_count / len(samples)) * 100
        }
    return results

# Compare strategies
print("=== Sampling Strategy Comparison ===")
comparison = compare_sampling_strategies(dataset_manager, tokenizer)
for strategy, stats in comparison.items():
print(f"{strategy} sampling:")
print(f" Indices: {stats['indices']}")
print(f" Avg length: {stats['avg_length']:.1f} ± {stats['length_std']:.1f}")
print(f" Multi-doc %: {stats['multi_doc_percentage']:.1f}%")
print()
=== Sampling Strategy Comparison ===
Random sampling:
 Indices: [17, 154, 130, 87, 86]
 Avg length: 8695.8 ± 129.4
 Multi-doc %: 100.0%

Sequential sampling:
 Indices: [100, 101, 102, 103, 104]
 Avg length: 8739.8 ± 107.6
 Multi-doc %: 100.0%

Sparse sampling:
 Indices: [50, 70, 90, 110, 130]
 Avg length: 8708.8 ± 154.4
 Multi-doc %: 100.0%

Prime sampling:
 Indices: [2, 3, 5, 7, 11]
 Avg length: 8685.0 ± 146.7
 Multi-doc %: 100.0%
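The deterministic policies compared above are defined earlier in this tutorial. As a reminder of what they select, here is a minimal sketch that reproduces the printed indices. The function names (`sequential_indices`, `sparse_indices`, `prime_indices`) and their argument meanings are assumptions inferred from the calls and output above, not TokenSmith code; in particular, treating the argument of the prime policy as an upper bound is a guess.

```python
def sequential_indices(start, count):
    """Return `count` consecutive indices beginning at `start`."""
    return list(range(start, start + count))


def sparse_indices(start, count, step):
    """Return `count` indices spaced `step` apart, beginning at `start`."""
    return [start + i * step for i in range(count)]


def prime_indices(limit):
    """Return all primes below `limit` (assumed meaning) via a sieve."""
    if limit < 2:
        return []
    is_prime = [True] * limit
    is_prime[0:2] = [False, False]  # 0 and 1 are not prime
    for n in range(2, int(limit ** 0.5) + 1):
        if is_prime[n]:
            # Cross out every multiple of n, starting at n * n
            for multiple in range(n * n, limit, n):
                is_prime[multiple] = False
    return [n for n, flag in enumerate(is_prime) if flag]


print(sequential_indices(100, 5))  # [100, 101, 102, 103, 104]
print(sparse_indices(50, 5, 20))   # [50, 70, 90, 110, 130]
print(prime_indices(50)[:5])       # [2, 3, 5, 7, 11]
```

Because all four strategies return plain lists of integers, any of them can be passed to `get_samples_by_indices` in the same way, which is what makes the side-by-side comparison possible.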
def batch_diversity_analysis(dataset_manager, tokenizer, batch_ids, batch_size):
    """
    Analyze diversity within and across batches.
    """
    batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=batch_ids,
        batch_size=batch_size,
        return_detokenized=True,
        return_doc_details=True,
        tokenizer=tokenizer
    )
    analysis = {
        'batch_stats': [],
        'overall_length_variance': 0,
        'cross_batch_variance': 0
    }
    all_lengths = []
    batch_avg_lengths = []
    for batch_idx, batch in enumerate(batches):
        batch_id = batch_ids[batch_idx]
        lengths = [len(text) for text, _ in batch]
        multi_docs = [1 for _, doc_details in batch
                      if doc_details['doc_index_f'] != doc_details['doc_index_l']]
        batch_stat = {
            'batch_id': batch_id,
            'avg_length': np.mean(lengths),
            'length_variance': np.var(lengths),
            'multi_doc_count': sum(multi_docs)
        }
        analysis['batch_stats'].append(batch_stat)
        all_lengths.extend(lengths)
        batch_avg_lengths.append(batch_stat['avg_length'])
    analysis['overall_length_variance'] = np.var(all_lengths)
    analysis['cross_batch_variance'] = np.var(batch_avg_lengths)
    return analysis

# Analyze batch diversity
print("=== Batch Diversity Analysis ===")
diversity_analysis = batch_diversity_analysis(
    dataset_manager, tokenizer,
    batch_ids=[2, 4, 6, 8],
    batch_size=3
)
print("Batch-by-batch analysis:")
for stats in diversity_analysis['batch_stats']:
    print(f" Batch {stats['batch_id']}:")
    print(f" Avg length: {stats['avg_length']:.1f}")
    print(f" Length variance: {stats['length_variance']:.1f}")
    print(f" Multi-doc samples: {stats['multi_doc_count']}")
print(f"\nOverall statistics:")
print(f" Overall length variance: {diversity_analysis['overall_length_variance']:.1f}")
print(f" Cross-batch variance: {diversity_analysis['cross_batch_variance']:.1f}")
print(f" Diversity ratio: {diversity_analysis['cross_batch_variance'] / diversity_analysis['overall_length_variance']:.3f}")
=== Batch Diversity Analysis ===
Batch-by-batch analysis:
 Batch 2:
 Avg length: 8605.3
 Length variance: 6461.6
 Multi-doc samples: 3
 Batch 4:
 Avg length: 8614.0
 Length variance: 66940.7
 Multi-doc samples: 3
 Batch 6:
 Avg length: 8777.3
 Length variance: 20437.6
 Multi-doc samples: 3
 Batch 8:
 Avg length: 8395.0
 Length variance: 20778.0
 Multi-doc samples: 3

Overall statistics:
 Overall length variance: 47074.2
 Cross-batch variance: 18419.8
 Diversity ratio: 0.391
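The diversity ratio can be read through the law of total variance: when every batch has the same size, the overall (population) variance equals the mean within-batch variance plus the variance of the batch means. The sketch below checks this identity on made-up lengths. The `variance_decomposition` helper and the numbers are illustrative assumptions, not TokenSmith API or output.

```python
import numpy as np


def variance_decomposition(batches):
    """Split overall variance into within-batch and cross-batch parts.

    Assumes every batch has the same number of samples and uses the
    population variance (np.var default, ddof=0), as in the analysis above.
    """
    batches = np.asarray(batches, dtype=float)  # shape: (num_batches, batch_size)
    within = batches.var(axis=1).mean()         # mean of the per-batch variances
    across = batches.mean(axis=1).var()         # variance of the batch means
    overall = batches.var()                     # variance of all samples pooled
    return within, across, overall


# Hypothetical character lengths for 4 batches of 3 samples each
lengths = [
    [8600, 8500, 8710],
    [8300, 8900, 8640],
    [8800, 8600, 8930],
    [8200, 8450, 8540],
]
within, across, overall = variance_decomposition(lengths)
ratio = across / overall  # share of total variance explained by batch membership
print(f"within={within:.1f} across={across:.1f} overall={overall:.1f} ratio={ratio:.3f}")
```

The tutorial's own numbers satisfy the same identity: the four within-batch variances average 28654.5, and adding the cross-batch variance of 18419.8 gives 47074.3, which matches the reported overall variance of 47074.2 up to rounding. A ratio of 0.391 therefore means that roughly 39% of the length variance comes from differences between batches and the rest from variation inside them.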
Performance Considerations¶
When working with large datasets, it's important to understand the performance characteristics of different sampling methods.
import time

def benchmark_sampling_methods(dataset_manager, tokenizer):
    """
    Benchmark different sampling methods to understand their performance.
    """
    benchmarks = {}
    # Test 1: Index-based sampling
    start_time = time.time()
    indices = list(range(0, 50, 5))  # Every 5th sample from 0 to 45 (10 indices)
    samples = dataset_manager.sample.get_samples_by_indices(
        indices=indices,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Index-based'] = time.time() - start_time
    # Test 2: Policy-based sampling (simple)
    start_time = time.time()
    samples = dataset_manager.sample.get_samples_by_policy(
        policy_fn=lambda: list(range(0, 50, 5)),
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Policy-based (simple)'] = time.time() - start_time
    # Test 3: Batch sampling (5 batches of 2 samples)
    start_time = time.time()
    batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=[0, 1, 2, 3, 4],
        batch_size=2,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    benchmarks['Batch sampling'] = time.time() - start_time
    # Test 4: With document details
    start_time = time.time()
    samples = dataset_manager.sample.get_samples_by_indices(
        indices=indices,
        return_detokenized=True,
        return_doc_details=True,
        tokenizer=tokenizer
    )
    benchmarks['With doc details'] = time.time() - start_time
    return benchmarks

# Run benchmarks
print("=== Performance Benchmarks ===")
benchmarks = benchmark_sampling_methods(dataset_manager, tokenizer)
print("Sampling method performance (10 samples each):")
for method, duration in benchmarks.items():
    print(f" {method}: {duration:.4f} seconds")
# Calculate relative performance
fastest = min(benchmarks.values())
print(f"\nRelative to fastest method:")
for method, duration in benchmarks.items():
    relative = duration / fastest
    print(f" {method}: {relative:.2f}x")
=== Performance Benchmarks ===
Sampling method performance (10 samples each):
 Index-based: 0.0183 seconds
 Policy-based (simple): 0.0174 seconds
 Batch sampling: 0.0174 seconds
 With doc details: 0.0191 seconds

Relative to fastest method:
 Index-based: 1.05x
 Policy-based (simple): 1.00x
 Batch sampling: 1.00x
 With doc details: 1.10x
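Single-run timings of a few milliseconds are sensitive to caching and scheduler noise, so gaps of 5 to 10% between methods may not be meaningful. A minimal sketch of a steadier harness follows, using only the standard library. `time_call` is a hypothetical helper, and the workload is a stand-in for a sampling call.

```python
import statistics
import time


def time_call(fn, repeats=5, warmup=1):
    """Return the median wall-clock duration of fn() over several runs."""
    for _ in range(warmup):
        fn()  # warm caches so the first measured run is not penalized
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


# Stand-in workload; in this tutorial it would wrap a call such as
# dataset_manager.sample.get_samples_by_indices(...)
def workload():
    return sum(i * i for i in range(10_000))


median_seconds = time_call(workload, repeats=7)
print(f"median over 7 runs: {median_seconds:.6f} s")
```

Taking the median rather than the mean keeps one slow outlier run from dominating the result, and the warm-up call avoids charging one-time costs such as file-system caching to whichever method happens to run first.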
Best Practices and Tips¶
Here are some practical tips for effective sampling with TokenSmith:
# Tip 1: Efficient batch processing for large-scale analysis
def efficient_large_scale_sampling(dataset_manager, tokenizer, total_samples=100):
    """
    Demonstrate efficient sampling for large-scale analysis.
    """
    print("=== Efficient Large-Scale Sampling ===")
    # Instead of sampling one by one, use batch-based approaches
    batch_size = 10
    num_batches = total_samples // batch_size  # any remainder is dropped
    # Request a contiguous range of batch IDs in a single call
    batch_ids = list(range(0, num_batches))
    start_time = time.time()
    all_batches = dataset_manager.sample.get_batches_by_ids(
        batch_ids=batch_ids,
        batch_size=batch_size,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    # Flatten for analysis
    all_samples = [sample for batch in all_batches for sample in batch]
    duration = time.time() - start_time
    print(f"Processed {len(all_samples)} samples in {duration:.4f} seconds")
    print(f"Rate: {len(all_samples)/duration:.1f} samples/second")
    # Quick statistics (each sample is a detokenized string, so len() counts characters)
    lengths = [len(sample) for sample in all_samples]
    print(f"Length statistics: {np.mean(lengths):.1f} ± {np.std(lengths):.1f}")
    return all_samples

# Demonstrate efficient sampling
efficient_samples = efficient_large_scale_sampling(dataset_manager, tokenizer, 50)
=== Efficient Large-Scale Sampling ===
Processed 50 samples in 0.0913 seconds
Rate: 547.8 samples/second
Length statistics: 8668.9 ± 211.0
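Note that `total_samples // batch_size` in the function above rounds down, so a request that is not a multiple of the batch size silently drops the remainder. The sketch below makes that choice explicit. `plan_batch_ids` is a hypothetical helper, not part of TokenSmith.

```python
def plan_batch_ids(total_samples, batch_size, round_up=False):
    """Return batch IDs, starting at 0, that cover `total_samples`.

    With round_up=False the remainder is dropped (floor division, as in the
    tutorial code); with round_up=True one extra batch is added to cover it.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if round_up:
        num_batches = -(-total_samples // batch_size)  # integer ceiling division
    else:
        num_batches = total_samples // batch_size
    return list(range(num_batches))


print(plan_batch_ids(55, 10))                 # [0, 1, 2, 3, 4]
print(plan_batch_ids(55, 10, round_up=True))  # [0, 1, 2, 3, 4, 5]
```

When rounding up, the final batch brings in more samples than requested, so trim the flattened list to `total_samples` if an exact count matters.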
# Tip 2: Reproducible sampling with seeds
def reproducible_sampling_demo(dataset_manager, tokenizer):
    """
    Demonstrate reproducible sampling using seeds.
    """
    print("=== Reproducible Sampling Demo ===")
    # Same seed should give same results
    seed = 12345
    # First run
    samples1 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    # Second run with same seed
    samples2 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    # Check if they're identical
    identical = all(s1 == s2 for s1, s2 in zip(samples1, samples2))
    print(f"Samples identical across runs: {identical}")
    # Different seed should give different results
    samples3 = dataset_manager.sample.get_samples_by_policy(
        policy_fn=random_sample_policy,
        num_samples=5,
        max_index=100,
        rng_seed=seed + 1,
        return_detokenized=True,
        tokenizer=tokenizer
    )
    different = any(s1 != s3 for s1, s3 in zip(samples1, samples3))
    print(f"Different seed gives different samples: {different}")

reproducible_sampling_demo(dataset_manager, tokenizer)
=== Reproducible Sampling Demo ===
Samples identical across runs: True
Different seed gives different samples: True
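Reproducibility of this kind depends on the policy drawing from a generator seeded with `rng_seed` rather than from global random state. The tutorial's `random_sample_policy` is defined earlier. The version below, `seeded_random_policy`, is a hypothetical stand-in written only to illustrate the idea, and its indices will not necessarily match those shown in this tutorial.

```python
import random


def seeded_random_policy(num_samples, max_index, rng_seed):
    """Return `num_samples` distinct indices in [0, max_index) for a given seed."""
    rng = random.Random(rng_seed)  # local generator; global random state is untouched
    return rng.sample(range(max_index), num_samples)


first = seeded_random_policy(5, 100, rng_seed=12345)
second = seeded_random_policy(5, 100, rng_seed=12345)
print(first == second)  # True: the same seed yields the same indices
```

Using a local `random.Random` instance also means that other code calling `random.seed()` or drawing random numbers between runs cannot change which samples the policy selects.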
# Tip 3: Memory-efficient sampling for very large datasets
def memory_efficient_analysis(dataset_manager, tokenizer, sample_indices):
    """
    Demonstrate memory-efficient analysis techniques.
    """
    print("=== Memory-Efficient Analysis ===")
    # Process samples in chunks to avoid loading everything at once
    chunk_size = 10
    total_length = 0
    multi_doc_count = 0
    processed_count = 0
    for i in range(0, len(sample_indices), chunk_size):
        chunk_indices = sample_indices[i:i + chunk_size]
        # Process chunk
        chunk_samples = dataset_manager.sample.get_samples_by_indices(
            indices=chunk_indices,
            return_detokenized=True,
            return_doc_details=True,
            tokenizer=tokenizer
        )
        # Accumulate statistics without storing all data
        for text, doc_details in chunk_samples:
            total_length += len(text)
            if doc_details['doc_index_f'] != doc_details['doc_index_l']:
                multi_doc_count += 1
            processed_count += 1
        # Drop the reference so the chunk can be garbage-collected right away
        # (it would otherwise be released when chunk_samples is rebound next iteration)
        del chunk_samples
    # Final statistics
    avg_length = total_length / processed_count
    multi_doc_percentage = (multi_doc_count / processed_count) * 100
    print(f"Processed {processed_count} samples in chunks of {chunk_size}")
    print(f"Average length: {avg_length:.1f} characters")
    print(f"Multi-document samples: {multi_doc_percentage:.1f}%")

# Demonstrate memory-efficient processing
large_sample_indices = list(range(0, 80, 2))  # Every other sample from 0 to 78 (40 indices)
memory_efficient_analysis(dataset_manager, tokenizer, large_sample_indices)
=== Memory-Efficient Analysis ===
Processed 40 samples in chunks of 10
Average length: 8687.5 characters
Multi-document samples: 100.0%
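The chunked loop above tracks only a running sum, which is enough for the mean. If the standard deviation is also needed without keeping every length in memory, Welford's online algorithm is one option. The sketch below is self-contained: `RunningStats` and `chunked` are hypothetical helpers, and the lengths are made-up values rather than dataset output.

```python
import math


class RunningStats:
    """Welford's online algorithm for the mean and population variance."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value):
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (value - self.mean)

    @property
    def std(self):
        # Population standard deviation, matching np.std's default (ddof=0)
        return math.sqrt(self._m2 / self.count) if self.count else 0.0


def chunked(items, chunk_size):
    """Yield successive slices of `items` with at most `chunk_size` elements."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


lengths = [8713, 8650, 8938, 8173, 8702, 8590, 8811]  # hypothetical values
stats = RunningStats()
for chunk in chunked(lengths, chunk_size=3):
    for length in chunk:
        stats.update(length)
print(f"n={stats.count} mean={stats.mean:.1f} std={stats.std:.1f}")
```

In the tutorial's setting, each chunk would come from `get_samples_by_indices`, and `stats.update(len(text))` would replace the `total_length` accumulation.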
Summary¶
Congratulations! You've successfully learned how to use TokenSmith's flexible sampling capabilities. Here's what we covered:
Key Concepts Learned:¶
- Index-based Sampling: Direct sampling by specifying exact sample indices
- Batch Sampling: Retrieving complete training batches by their IDs
- Policy-based Sampling: Using custom functions to define sampling strategies
- Advanced Policies: Creating sophisticated sampling logic for research needs
- Performance Optimization: Understanding trade-offs and efficiency considerations
- Best Practices: Reproducible, memory-efficient, and scalable sampling techniques
Key Methods Used:¶
- dataset_manager.sample.get_samples_by_indices(): Sample by specific indices
- dataset_manager.sample.get_batches_by_ids(): Sample complete batches
- dataset_manager.sample.get_samples_by_policy(): Policy-based sample selection
- dataset_manager.sample.get_batches_by_policy(): Policy-based batch selection
Sampling Strategies Explored:¶
- Random sampling: For unbiased dataset exploration
- Sequential sampling: For examining consecutive samples
- Sparse sampling: For distributed dataset coverage
- Mathematical patterns: Prime numbers, Fibonacci, powers of 2
- Length-based sampling: Targeting specific text characteristics
- Custom lambda policies: Flexible, inline sampling logic
Performance Insights:¶
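- Index-based sampling is the most direct route when the indices are already known; in the single-run benchmark above, all four methods finished within about 10% of one another (0.0174 to 0.0191 seconds), so none is clearly fastest
- Policy-based sampling adds negligible overhead for simple policies
- Batch sampling is efficient for processing multiple samples
- Document details add a small processing cost (about 4% in the benchmark above) but provide valuable metadata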
- Memory-efficient chunking enables large-scale analysis
Pro Tips:¶
- Use seeds for reproducible sampling in research
- Prefer batch-based operations for large-scale analysis
- Create reusable policy functions for common sampling patterns
- Consider memory usage when processing large numbers of samples
- Combine sampling with inspection for comprehensive dataset understanding