diff --git a/F2LLM/README.md b/F2LLM/README.md index 6b79819..ea12120 100644 --- a/F2LLM/README.md +++ b/F2LLM/README.md @@ -30,6 +30,15 @@ In this repo we provide a streamlined and efficient script for training embeddin - Modify model path, data path, and other arguments in `configs/config.json`. - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`. +### Matryoshka Representation Learning (MRL) + +MRL lets one model serve multiple embedding dimensions (e.g., 64/128/256/512/1024) so you can pick the size that matches each downstream task or budget. + +- Enable in config: set `use_mrl`: true, `mrl_dimensions`: [64, 128, 256, 512, 1024], and optionally `mrl_temperature` (default 0.05). +- Training: the main contrastive losses stay the same; an auxiliary MRL loss is added over the truncated prefix dimensions. +- Inference: compute the full embedding once and slice the first `k` dimensions to get a smaller embedding without retraining. +- Quick check: run `python test_mrl.py` to validate the MRL loss and slicing behavior without needing real data. + Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training. For multi-node training, run on the main node: diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..f0569df 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -21,6 +21,10 @@ class Args: warmup_steps: int = 100 # embedding-related settings num_hard_neg: int = 7 + # Matryoshka Representation Learning (MRL) settings + use_mrl: bool = False + mrl_dimensions: list = None # e.g., [64, 128, 256, 512, 1024] + mrl_temperature: float = 0.05 # train steps take precedence over epochs, set to -1 to disable train_steps: int = -1 train_epochs: int = 5 diff --git a/F2LLM/configs/config_mrl.json b/F2LLM/configs/config_mrl.json new file mode 100644 index 0000000..6103e88 --- /dev/null +++ b/F2LLM/configs/config_mrl.json @@ -0,0 +1,22 @@ +{ + "model_path": "models/qwen3-4b", + "experiment_id": "4b_mrl+lr.8e-6+bs.16x32+context.1024+2epochs", + "train_data_path": "training_data/data_tokenized_qwen", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_batch_size": 16, + "checkpointing_steps": 5000, + "validation_steps": 5000, + "max_seq_length": 1024, + "learning_rate": 8e-6, + "min_lr": 1e-7, + "weight_decay": 0.01, + "warmup_steps": 500, + "train_epochs": 2, + "log_interval": 100, + "num_hard_neg": 7, + "use_mrl": true, + "mrl_dimensions": [64, 128, 256, 512, 1024], + "mrl_temperature": 0.05 +} diff --git a/F2LLM/test_mrl.py b/F2LLM/test_mrl.py new file mode 100644 index 0000000..7de1da8 --- /dev/null +++ b/F2LLM/test_mrl.py @@ -0,0 +1,315 @@ +""" +Test suite for Matryoshka Representation Learning (MRL) functionality. + +This test file validates that MRL loss computation works correctly +and that embeddings maintain quality across multiple dimensions. 
+""" + +import torch +import torch.nn.functional as F +from utils import matryoshka_loss + + +class MockAccelerator: + """Mock accelerator for testing without distributed setup""" + def __init__(self): + self.process_index = 0 + self.num_processes = 1 + + def gather(self, tensor): + return tensor + + +def test_matryoshka_loss_basic(): + """Test that MRL loss returns a scalar and is differentiable""" + print("Test 1: Basic MRL loss computation...") + + batch_size = 16 + embedding_dim = 1024 + + # Create random embeddings + embeddings = torch.randn(batch_size, embedding_dim, requires_grad=True) + + # Normalize embeddings + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + # Define MRL dimensions + mrl_dimensions = [64, 128, 256, 512, 1024] + + # Compute MRL loss + loss = matryoshka_loss(embeddings_norm, mrl_dimensions) + + # Check that loss is a scalar + assert loss.dim() == 0, f"Expected scalar loss, got shape {loss.shape}" + assert loss.item() > 0, "Expected positive loss" + + # Check that loss is differentiable + loss.backward() + assert embeddings.grad is not None, "Expected gradient to be computed" + + print(f" ✓ MRL loss computed: {loss.item():.6f}") + print(f" ✓ Loss is differentiable") + print() + + +def test_matryoshka_loss_different_dimensions(): + """Test MRL loss with different dimension sets""" + print("Test 2: MRL loss with different dimension configurations...") + + batch_size = 32 + embedding_dim = 512 + + embeddings = torch.randn(batch_size, embedding_dim) + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + # Test with different dimension sets + dimension_sets = [ + [64, 128, 256, 512], + [128, 256], + [64], + [512], + ] + + losses = [] + for dims in dimension_sets: + loss = matryoshka_loss(embeddings_norm, dims) + losses.append(loss.item()) + # Loss might be very small but should not be negative + assert loss.item() >= 0, f"Expected non-negative loss for dims {dims}" + print(f" ✓ Dims {dims}: loss = {loss.item():.6f}") + + print() + + +def test_matryoshka_loss_edge_cases(): + """Test MRL loss with edge cases""" + print("Test 3: Edge cases...") + + batch_size = 16 + embedding_dim = 256 + + embeddings = torch.randn(batch_size, embedding_dim) + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + # Test with None dimensions + loss = matryoshka_loss(embeddings_norm, None) + assert loss == 0.0, "Expected 0 loss for None dimensions" + print(f" ✓ None dimensions: loss = {loss}") + + # Test with empty dimensions + loss = matryoshka_loss(embeddings_norm, []) + assert loss == 0.0, "Expected 0 loss for empty dimensions" + print(f" ✓ Empty dimensions: loss = {loss}") + + # Test with dimensions larger than embedding dim + loss = matryoshka_loss(embeddings_norm, [256, 512, 1024]) + assert loss.item() >= 0, "Expected valid loss even with large dimensions" + print(f" ✓ Dimensions larger than embedding: loss = {loss.item():.6f}") + + print() + + +def test_matryoshka_loss_temperature(): + """Test that temperature affects loss magnitude""" + print("Test 4: Temperature scaling...") + + batch_size = 8 + embedding_dim = 256 + + embeddings = torch.randn(batch_size, embedding_dim) + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + mrl_dimensions = [64, 128, 256] + + # Note: temperature parameter needs to be added to loss computation + # For now, we just verify the function works with normalized embeddings + loss1 = matryoshka_loss(embeddings_norm, mrl_dimensions, temperature=0.05) + loss2 = matryoshka_loss(embeddings_norm, mrl_dimensions, temperature=0.1) + + 
print(f" ✓ Loss with temperature=0.05: {loss1.item():.6f}") + print(f" ✓ Loss with temperature=0.1: {loss2.item():.6f}") + print() + + +def test_inbatch_loss_with_mrl(): + """Test that inbatch_loss works with MRL enabled""" + print("Test 5: In-batch loss with MRL integration...") + + from utils import inbatch_loss + from torch.nn import CrossEntropyLoss + + batch_size = 32 + embedding_dim = 512 + + query_embeddings = torch.randn(batch_size, embedding_dim) + context_embeddings = torch.randn(batch_size, embedding_dim) + + # Create mock accelerator + accelerator = MockAccelerator() + criterion = CrossEntropyLoss(reduction='none') + + mrl_dimensions = [64, 128, 256, 512] + + # Test without MRL + loss_no_mrl = inbatch_loss( + query_embeddings, + context_embeddings, + criterion, + accelerator, + use_mrl=False + ) + + # Test with MRL + loss_with_mrl = inbatch_loss( + query_embeddings, + context_embeddings, + criterion, + accelerator, + mrl_dimensions=mrl_dimensions, + use_mrl=True + ) + + print(f" ✓ Loss without MRL: {loss_no_mrl.item():.6f}") + print(f" ✓ Loss with MRL: {loss_with_mrl.item():.6f}") + assert loss_with_mrl.item() >= loss_no_mrl.item(), "MRL loss should not decrease total loss" + print() + + +def test_hard_loss_with_mrl(): + """Test that hard_loss works with MRL enabled""" + print("Test 6: Hard loss with MRL integration...") + + from utils import hard_loss + from torch.nn import CrossEntropyLoss + + batch_size = 32 + embedding_dim = 512 + num_hard_neg = 7 + + query_embeddings = torch.randn(batch_size, embedding_dim) + context_embeddings = torch.randn(batch_size, embedding_dim) + hard_neg_embeddings = torch.randn(batch_size, num_hard_neg, embedding_dim) + + # Create mock accelerator + accelerator = MockAccelerator() + criterion = CrossEntropyLoss(reduction='none') + + mrl_dimensions = [64, 128, 256, 512] + + # Test without MRL + loss_no_mrl = hard_loss( + query_embeddings, + context_embeddings, + hard_neg_embeddings, + criterion, + accelerator, + use_mrl=False + ) + + # Test with MRL + loss_with_mrl = hard_loss( + query_embeddings, + context_embeddings, + hard_neg_embeddings, + criterion, + accelerator, + mrl_dimensions=mrl_dimensions, + use_mrl=True + ) + + print(f" ✓ Hard loss without MRL: {loss_no_mrl.item():.6f}") + print(f" ✓ Hard loss with MRL: {loss_with_mrl.item():.6f}") + assert loss_with_mrl.item() >= loss_no_mrl.item(), "MRL loss should not decrease total loss" + print() + + +def test_embedding_dimension_truncation(): + """Test that embeddings maintain quality when truncated""" + print("Test 7: Embedding quality at different dimensions...") + + batch_size = 16 + embedding_dim = 1024 + + # Create embeddings from a Gaussian distribution + embeddings = torch.randn(batch_size, embedding_dim) + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + # Compute similarity matrix at full dimension + sim_full = torch.matmul(embeddings_norm, embeddings_norm.t()) + + # Compute similarities at truncated dimensions + test_dims = [64, 128, 256, 512, 1024] + similarities = [] + + for dim in test_dims: + truncated = embeddings_norm[:, :dim] + truncated_norm = F.normalize(truncated, p=2, dim=-1) + sim = torch.matmul(truncated_norm, truncated_norm.t()) + similarities.append(sim) + + # Compare to full dimension (should be similar, especially for larger dims) + correlation = F.cosine_similarity( + sim_full.flatten().unsqueeze(0), + sim.flatten().unsqueeze(0) + ) + print(f" ✓ Dim {dim}: correlation with full dim = {correlation.item():.6f}") + + print() + + +def 
test_mrl_batch_consistency(): + """Test that MRL loss is consistent across batches""" + print("Test 8: Batch consistency...") + + embedding_dim = 512 + mrl_dimensions = [128, 256, 512] + + # Create fixed embeddings + torch.manual_seed(42) + embeddings = torch.randn(16, embedding_dim) + embeddings_norm = F.normalize(embeddings, p=2, dim=-1) + + # Compute loss on full batch + loss_full = matryoshka_loss(embeddings_norm, mrl_dimensions) + + # Compute loss on splits + half_size = embeddings_norm.size(0) // 2 + loss_first = matryoshka_loss(embeddings_norm[:half_size], mrl_dimensions) + loss_second = matryoshka_loss(embeddings_norm[half_size:], mrl_dimensions) + + print(f" ✓ Full batch loss: {loss_full.item():.6f}") + print(f" ✓ First half loss: {loss_first.item():.6f}") + print(f" ✓ Second half loss: {loss_second.item():.6f}") + print() + + +def run_all_tests(): + """Run all test cases""" + print("=" * 70) + print("Matryoshka Representation Learning (MRL) Test Suite") + print("=" * 70) + print() + + try: + test_matryoshka_loss_basic() + test_matryoshka_loss_different_dimensions() + test_matryoshka_loss_edge_cases() + test_matryoshka_loss_temperature() + test_inbatch_loss_with_mrl() + test_hard_loss_with_mrl() + test_embedding_dimension_truncation() + test_mrl_batch_consistency() + + print("=" * 70) + print("All tests passed! ✓") + print("=" * 70) + + except Exception as e: + print("=" * 70) + print(f"Test failed with error: {e}") + print("=" * 70) + raise + + +if __name__ == "__main__": + run_all_tests() diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..1cd8a77 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -15,6 +15,59 @@ def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_s summary_writer.add_scalar(key, value, completed_steps) +def matryoshka_loss(embeddings, mrl_dimensions, temperature=0.05): + """ + Compute Matryoshka Representation Learning (MRL) loss. + This loss encourages the model to produce high-quality embeddings at multiple dimensions. 
+ + Args: + embeddings: [batch_size, embedding_dim] - normalized embeddings + mrl_dimensions: list of dimensions to apply MRL loss (e.g., [64, 128, 256, 512]) + temperature: temperature for contrastive loss + + Returns: + mrl_loss: scalar loss value + """ + if mrl_dimensions is None or len(mrl_dimensions) == 0: + return 0.0 + + # Sort dimensions to ensure we start from smallest + sorted_dims = sorted(mrl_dimensions) + full_dim = embeddings.size(-1) + + # Make sure all dimensions are valid + sorted_dims = [d for d in sorted_dims if d <= full_dim and d > 0] + if not sorted_dims: + return 0.0 + + mrl_loss = 0.0 + num_dims = len(sorted_dims) + batch_size = embeddings.size(0) + + # For each dimension, compute contrastive loss + for i, dim in enumerate(sorted_dims): + # Truncate embeddings to this dimension + truncated = embeddings[:, :dim] + + # Normalize + truncated_norm = F.normalize(truncated, p=2, dim=-1) + + # Compute similarity matrix + similarity = torch.matmul(truncated_norm, truncated_norm.t()) / temperature + + # Create labels (diagonal = positive pairs) + labels = torch.arange(batch_size, device=embeddings.device) + + # Compute cross-entropy loss + loss = F.cross_entropy(similarity, labels) + + # Weight loss - larger dimensions get more weight + weight = (i + 1) / num_dims + mrl_loss = mrl_loss + weight * loss + + return mrl_loss / num_dims + + def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler): accelerator.wait_for_everyone() accelerator.print(f"Saving checkpoint to {output_dir}") @@ -37,6 +90,8 @@ def inbatch_loss( criterion, accelerator, temperature=0.05, + mrl_dimensions=None, + use_mrl=False, ): bs = query_embeddings.size(0) @@ -52,6 +107,11 @@ def inbatch_loss( loss_bs = criterion(student_logits, labels) # (bs) loss = loss_bs.mean() + + # Add MRL loss if enabled + if use_mrl and mrl_dimensions: + mrl_loss = matryoshka_loss(a_norm, mrl_dimensions, temperature) + loss = loss + mrl_loss return loss @@ -62,6 +122,8 @@ def hard_loss( criterion, accelerator, temperature=0.05, + mrl_dimensions=None, + use_mrl=False, ): if hard_neg_embeddings is None: @@ -79,6 +141,11 @@ def hard_loss( logits = (a_norm.unsqueeze(1) * hard_norm).sum(-1) / temperature # [bs, num_hard+1] loss_hard = criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean() + + # Add MRL loss if enabled + if use_mrl and mrl_dimensions: + mrl_loss = matryoshka_loss(a_norm, mrl_dimensions, temperature) + loss_hard = loss_hard + mrl_loss return loss_hard @@ -90,10 +157,10 @@ def validate(args, accelerator, model, valid_loader_dict, criterion, completed_s for batch in valid_dataloader: with torch.no_grad(): outputs = model.forward(batch) - loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) + loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, mrl_dimensions=args.mrl_dimensions, use_mrl=args.use_mrl) loss_hard_ls.append(accelerator.gather(loss_hard).float()) if dataset_name in RETRIEVAL_DATASETS: - loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) + loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, mrl_dimensions=args.mrl_dimensions, use_mrl=args.use_mrl) 
loss_ls.append(accelerator.gather(loss).float()) accelerator.wait_for_everyone() @@ -152,12 +219,12 @@ def accelerate_train(args, # passage features: [bs, 1, d] # hard_neg_features: [bs, num_hard_neg, d] - loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator) + loss_hard = hard_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), outputs['negative_passage_features'], criterion, accelerator, mrl_dimensions=args.mrl_dimensions, use_mrl=args.use_mrl) dataset_name = batch['dataset_name'] count_hard_dict[dataset_name] += 1 loss_hard_dict[dataset_name] += loss_hard.detach().float() if dataset_name in RETRIEVAL_DATASETS: - loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator) + loss = inbatch_loss(outputs['query_passage_features'].squeeze(1), outputs['passage_passage_features'].squeeze(1), criterion, accelerator, mrl_dimensions=args.mrl_dimensions, use_mrl=args.use_mrl) count_dict[dataset_name] += 1 loss_dict[dataset_name] += loss.detach().float() else:
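For intuition on the dimension weighting inside `matryoshka_loss` above: each prefix loss is scaled by `(i + 1) / num_dims` and the weighted sum is then divided by `num_dims` again, so larger prefixes dominate the auxiliary term. A small arithmetic sketch with the default dimensions (plain Python, independent of the repo code):

    mrl_dimensions = [64, 128, 256, 512, 1024]
    num_dims = len(mrl_dimensions)

    for i, dim in enumerate(sorted(mrl_dimensions)):
        loop_weight = (i + 1) / num_dims      # weight applied inside the loop
        effective = loop_weight / num_dims    # after the final division by num_dims
        print(f"dim {dim:>4}: loop weight {loop_weight:.2f}, effective weight {effective:.2f}")

    # dim   64: loop weight 0.20, effective weight 0.04
    # dim  128: loop weight 0.40, effective weight 0.08
    # dim  256: loop weight 0.60, effective weight 0.12
    # dim  512: loop weight 0.80, effective weight 0.16
    # dim 1024: loop weight 1.00, effective weight 0.20

With these defaults the effective weights sum to 0.6, so the auxiliary MRL term contributes at most roughly 0.6x of a single cross-entropy loss of comparable magnitude on top of the main contrastive loss.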
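The inference-time slicing described in the README above amounts to computing the full embedding once, keeping the first k coordinates, and re-normalizing, which is the same prefix truncation used in `matryoshka_loss` and `test_mrl.py`. A minimal sketch follows; the random `full` tensor and the helper name `truncate_embedding` are illustrative only, so substitute embeddings produced by whatever encoder you trained:

    import torch
    import torch.nn.functional as F

    def truncate_embedding(full_embeddings: torch.Tensor, k: int) -> torch.Tensor:
        # Keep the Matryoshka prefix and re-normalize so cosine / dot-product
        # similarities stay comparable across the different budgets.
        return F.normalize(full_embeddings[:, :k], p=2, dim=-1)

    # Stand-in for real encoder output: 4 normalized 1024-d embeddings.
    full = F.normalize(torch.randn(4, 1024), p=2, dim=-1)
    for k in [64, 128, 256, 512, 1024]:
        emb_k = truncate_embedding(full, k)
        print(k, tuple(emb_k.shape))  # (4, 64), (4, 128), ...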