from datasets import load_dataset
import timm
import torch
from semhash import SemHash
# Create a custom image encoder
class VisionEncoder:
    """Custom image encoder backed by a timm model.

    Implements the Encoder protocol expected by SemHash: an object exposing
    an ``encode`` method that maps a sequence of inputs to a 2-D embedding
    matrix.
    """

    def __init__(self, model_name: str = "mobilenetv3_small_100.lamb_in1k") -> None:
        """Load a pretrained timm backbone with its classifier head removed.

        Args:
            model_name: Any timm model identifier. ``num_classes=0`` makes
                the model emit pooled features instead of class logits.
        """
        self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
        # Recover the exact preprocessing (resize/crop/normalize) the model
        # was trained with, in inference configuration.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(**data_config, is_training=False)

    def encode(self, inputs, batch_size: int = 128):
        """Encode a sequence of PIL images into an embedding matrix.

        Args:
            inputs: PIL images in any mode; non-RGB images are converted so
                the 3-channel transform does not fail on grayscale/palette
                inputs.
            batch_size: Number of images pushed through the model per forward
                pass, bounding peak memory.

        Returns:
            ``numpy.ndarray`` of shape ``(len(inputs), embedding_dim)``;
            an empty ``(0, 0)`` float32 array when ``inputs`` is empty.
        """
        import numpy as np

        rgb_inputs = [img if img.mode == "RGB" else img.convert("RGB") for img in inputs]
        if not rgb_inputs:
            # np.vstack([]) raises ValueError; return an empty matrix instead.
            return np.empty((0, 0), dtype=np.float32)
        all_embeddings = []
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            for start in range(0, len(rgb_inputs), batch_size):
                chunk = rgb_inputs[start : start + batch_size]
                batch = torch.stack([self.transform(img) for img in chunk])
                all_embeddings.append(self.model(batch).numpy())
        return np.vstack(all_embeddings)
# Load the CIFAR-10 test split and build record dicts for SemHash.
dataset = load_dataset("uoft-cs/cifar10", split="test")
images = dataset["img"]
# NOTE(review): ids restart at 0 in the test slice (images 100-149) —
# confirm this is intended rather than offsetting by 100.
train_data = [{"img": image, "id": index} for index, image in enumerate(images[:100])]
test_data = [{"img": image, "id": index} for index, image in enumerate(images[100:150])]

# Index the training records with the custom vision encoder.
semhash = SemHash.from_records(train_data, columns=["img"], model=VisionEncoder())

# Single-dataset operations: run against the indexed training records.
self_dedup_result = semhash.self_deduplicate()
deduplicated = self_dedup_result.selected
self_outlier_result = semhash.self_filter_outliers()
outliers = self_outlier_result.selected
self_representative_result = semhash.self_find_representative()
representatives = self_representative_result.selected

# Cross-dataset operations: compare held-out records against the index.
cross_dedup_result = semhash.deduplicate(test_data)
test_deduplicated = cross_dedup_result.selected
cross_outlier_result = semhash.filter_outliers(test_data)
test_outliers = cross_outlier_result.selected
cross_representative_result = semhash.find_representative(test_data, selection_size=10)
test_representatives = cross_representative_result.selected