SemHash provides a simple interface for representative sampling: selecting a subset of your data that best represents the entire dataset.
This is useful when you want to reduce the size of a dataset while retaining its diversity.
It works in two stages. First, it selects the samples with the highest average similarity to the other samples in the dataset (the most “central” samples).
Then, within that set, it iteratively selects samples using maximal marginal relevance (MMR), which balances a sample's centrality against its similarity to the samples already selected, so that each new pick adds something the current selection does not yet cover.
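To make the two-stage procedure concrete, here is a minimal NumPy sketch that operates on precomputed embeddings. It is not SemHash's implementation: the function `representative_sample` and its `diversity` and `pool_size` parameters are hypothetical, centrality is assumed to be the mean cosine similarity to all other samples, and the MMR score is assumed to be `(1 - diversity) * centrality - diversity * (max similarity to the selected set)`.

```python
from typing import List, Optional

import numpy as np


def representative_sample(
    embeddings: np.ndarray,
    k: int,
    diversity: float = 0.5,
    pool_size: Optional[int] = None,
) -> List[int]:
    """Return indices of k representative rows (hypothetical helper, not SemHash's API)."""
    n = embeddings.shape[0]
    k = min(k, n)
    if k <= 0:
        return []
    # Row-normalize so dot products are cosine similarities (assumes no all-zero rows).
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = unit @ unit.T
    # Stage 1: centrality = mean similarity to every *other* row (diagonal excluded).
    centrality = (sim.sum(axis=1) - np.diag(sim)) / max(n - 1, 1)
    # Shortlist the most central rows; by default every row is a candidate.
    pool = np.argsort(-centrality, kind="stable")[: max(pool_size or n, k)]
    # Stage 2: greedy MMR, seeded with the single most central row.
    selected = [int(pool[0])]
    # For each row, its highest similarity to anything selected so far.
    redundancy = sim[selected[0]].copy()
    while len(selected) < k:
        # Reward centrality, penalize redundancy with the current selection.
        scores = (1 - diversity) * centrality[pool] - diversity * redundancy[pool]
        scores[np.isin(pool, selected)] = -np.inf  # never pick a row twice
        best = int(pool[np.argmax(scores)])
        selected.append(best)
        redundancy = np.maximum(redundancy, sim[best])
    return selected
```

With `diversity=0` this reduces to ranking by centrality alone; raising it pushes later picks away from what has already been chosen, which is what lets a small sample cover more of the dataset.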
To perform representative sampling on a single dataset, use the `self_find_representative` method. It selects the subset of samples that best represents the entire dataset, based on their semantic similarity.
```python
from datasets import load_dataset

from semhash import SemHash

# Load a dataset to sample from
texts = load_dataset("ag_news", split="train")["text"]

# Initialize a SemHash instance
semhash = SemHash.from_records(records=texts)

# Find representative samples from the texts
representative_texts = semhash.self_find_representative().selected
```
To perform representative sampling across two datasets, use the `find_representative` method. It selects a subset of the records you pass in (here, the test split) that best represents the dataset SemHash was initialized with (here, the training split).
```python
from datasets import load_dataset

from semhash import SemHash

# Load the train and test splits
train_texts = load_dataset("ag_news", split="train")["text"]
test_texts = load_dataset("ag_news", split="test")["text"]

# Initialize a SemHash instance with the training data
semhash = SemHash.from_records(records=train_texts)

# Find representative samples from the test data against the training data
representative_test_texts = semhash.find_representative(records=test_texts).selected
```
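The cross-dataset case can be sketched the same way, assuming a candidate's relevance is its mean cosine similarity to the reference set (the dataset SemHash was initialized with) rather than to its own set, while redundancy is still measured among the candidates. `representative_against` is a hypothetical helper, not part of SemHash, and SemHash's own scoring may differ.

```python
from typing import List

import numpy as np


def _normalize(x: np.ndarray) -> np.ndarray:
    # Row-normalize so dot products are cosine similarities (assumes no all-zero rows).
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def representative_against(
    candidates: np.ndarray, reference: np.ndarray, k: int, diversity: float = 0.5
) -> List[int]:
    """Pick k candidate rows that best represent `reference` (hypothetical helper)."""
    cand, ref = _normalize(candidates), _normalize(reference)
    k = min(k, cand.shape[0])
    if k <= 0:
        return []
    # Relevance: mean similarity of each candidate to the whole reference set.
    relevance = (cand @ ref.T).mean(axis=1)
    # Candidate-to-candidate similarities, used to measure redundancy.
    cand_sim = cand @ cand.T
    # Seed with the single most relevant candidate.
    selected = [int(np.argmax(relevance))]
    redundancy = cand_sim[selected[0]].copy()
    while len(selected) < k:
        # MMR: reward relevance to the reference, penalize redundancy with picks so far.
        scores = (1 - diversity) * relevance - diversity * redundancy
        scores[selected] = -np.inf  # never pick a candidate twice
        best = int(np.argmax(scores))
        selected.append(best)
        redundancy = np.maximum(redundancy, cand_sim[best])
    return selected
```

This also shows the MMR trade-off: with a high enough `diversity`, a candidate that is unlike anything selected so far can win even if it is only weakly related to the reference set, while `diversity=0` simply ranks candidates by relevance.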
Representative Sampling from a Multi-Column Dataset
If you have a multi-column dataset, you can sample representatives from it by passing the columns to use for representative sampling to `from_records` through the `columns` argument.
```python
from datasets import load_dataset

from semhash import SemHash

# Load the dataset
dataset = load_dataset("squad_v2", split="train")

# Convert the dataset to a list of dictionaries
records = [dict(row) for row in dataset]

# Initialize SemHash with the columns to use for representative sampling
semhash = SemHash.from_records(records=records, columns=["question", "context"])

# Find representative samples from the records
representative_records = semhash.self_find_representative().selected
```