Reshaping Bonsai

Pruning LLMs for Mathematical Reasoning. Can we prune LLMs while maintaining their mathematical reasoning abilities? How does a novel comprehensive metric affect pruning?

You can find more details about the project in the full paper (PDF).

Introduction

This project, completed for the Advanced Natural Language Processing course, focused on improving the Bonsai pruning method for Large Language Models (LLMs) with a specific emphasis on mathematical reasoning capabilities.

Key Concepts

  1. Bonsai Pruning: A forward-only, regression-based neural network pruning method that decides which modules to prune based on estimates of module importance.

  2. Comprehensive Metric: A novel metric combining lexicographical similarity, semantic similarity, and accuracy to evaluate model-generated outputs against ground truth during pruning.

Technical Background

Bonsai Pruning Method

The Bonsai pruning method aims to solve the following optimization problem:

\[m^* = \arg\max_{\bar{m} \in F_p} U(M|_{\bar{m}})\]

where \(F_p = \{\bar{m} \subseteq m \mid \sum_{j:\, m_j \in \bar{m}} s_j \leq (1-p)D\}\)

Here, \(m^*\) represents the optimal sub-model, \(p\) is the target sparsity, \(U\) is the utility function measuring model performance, \(s_j\) is the number of parameters in module \(m_j\), and \(D\) is the total number of parameters.
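
To make the constraint concrete, here is a minimal sketch (with hypothetical module sizes, not values from the paper) of checking whether a candidate set of retained modules fits the \((1-p)D\) parameter budget:

```python
# Minimal sketch (hypothetical module sizes): check that a candidate
# set of retained modules m_bar fits the (1 - p) * D parameter budget.

def is_feasible(retained_sizes, all_sizes, target_sparsity):
    """retained_sizes: parameter counts s_j of the modules kept in m_bar,
    all_sizes: parameter counts of every prunable module (summing to D),
    target_sparsity: p, the fraction of parameters to remove."""
    total_params = sum(all_sizes)                    # D
    budget = (1.0 - target_sparsity) * total_params  # (1 - p) * D
    return sum(retained_sizes) <= budget

# Example: keeping 3 of 5 equally sized modules violates a 50% sparsity target.
sizes = [1_000_000] * 5
print(is_feasible(sizes[:3], sizes, target_sparsity=0.5))  # False: 3M > 2.5M
```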

Bonsai estimates module importance using a regression-based approach:

\[\hat{\beta} = \arg\min_{\beta \in \mathbb{R}^N} \left\{\frac{1}{n}\sum_{(\bar{m}_k, U_k) \in \mathcal{D}} (U_k - \beta^T \alpha_{\bar{m}_k})^2 + \gamma\|\beta\|^2\right\}\]

where \(\mathcal{D}\) is the dataset of sampled sub-modules and their performances, and \(\alpha_{\bar{m}_k}\) is a binary mask.
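
As a sketch of this estimation step, the following assumes we already have sampled binary masks and their measured utilities; the mask data and the ridge strength \(\gamma\) below are synthetic placeholders, not values from the paper:

```python
import numpy as np

# Minimal sketch of the regression step: given sampled sub-models
# (binary masks over N modules) and their measured utilities U_k,
# fit ridge-regularized coefficients beta as per-module importance.
rng = np.random.default_rng(0)
num_modules, num_samples, gamma = 32, 200, 1e-2

masks = rng.integers(0, 2, size=(num_samples, num_modules)).astype(float)  # alpha_{m_bar_k}
utilities = rng.normal(size=num_samples)  # placeholder for the measured U_k values

# Closed-form ridge solution: (A^T A / n + gamma I) beta = A^T U / n
n = num_samples
A = masks
beta_hat = np.linalg.solve(A.T @ A / n + gamma * np.eye(num_modules),
                           A.T @ utilities / n)

# Modules with the smallest estimated importance are pruning candidates.
prune_order = np.argsort(beta_hat)
print(prune_order[:5])
```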

Limitation: Bonsai shows great promise, performing competitively on four of the six tasks of the Hugging Face Open LLM Leaderboard within its parameter category. The notable exception is GSM-8K, a mathematical reasoning dataset, on which it achieves only ~6% accuracy even in its best hyperparameter setting. In this work, we investigate whether its performance on mathematical reasoning tasks can be improved.

Our Novel Comprehensive Metric

Key Insight: While the usual gradient-based pruning methods require the metric \(U\) to be differentiable, Bonsai's regression-based approach allows us to use any well-defined metric, so long as it yields a good estimate of module importance. For instance, accuracy is not differentiable, yet it can still be used here.

We asked whether a metric that rewards better reasoning during pruning would help downstream performance. How do we construct such a metric?

Building on this insight, we experimented with combining accuracy (to capture the quality of the final output), lexicographical similarity (to check that intermediate numbers are correct), and semantic similarity (to capture similarity in meaning) between the true and generated tokens.

We introduce a new metric \(U^\dagger\) that combines lexicographical similarity, semantic similarity, and accuracy:

\[U^\dagger = \sum_{i=1}^n a_i M_i \quad \text{where} \quad \sum_{i=1}^n a_i = 100\]

Here, \(M_i\) represents individual metrics (e.g., lexicographical similarity, semantic similarity, accuracy), and \(a_i\) are their respective weights.
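
Below is a minimal sketch of how \(U^\dagger\) could be computed for a single generated solution. The 50/25/25 weights are illustrative rather than the weights used in our experiments, difflib's sequence matcher stands in for lexicographical similarity, and all-MiniLM-L6-v2 provides the semantic similarity as in our setup:

```python
import difflib
from sentence_transformers import SentenceTransformer, util

# Illustrative weights (accuracy, lexicographical, semantic); they must sum to 100.
WEIGHTS = {"accuracy": 50, "lexical": 25, "semantic": 25}
embedder = SentenceTransformer("all-MiniLM-L6-v2")

def comprehensive_metric(generated: str, reference: str, answer: str) -> float:
    # Accuracy: the ground-truth answer string appears in the generation (0 or 1).
    accuracy = float(answer in generated)
    # Lexicographical similarity: character-level overlap of the two solutions.
    lexical = difflib.SequenceMatcher(None, generated, reference).ratio()
    # Semantic similarity: cosine similarity of sentence embeddings, clipped to [0, 1].
    emb = embedder.encode([generated, reference], convert_to_tensor=True)
    semantic = max(float(util.cos_sim(emb[0], emb[1])), 0.0)
    return (WEIGHTS["accuracy"] * accuracy
            + WEIGHTS["lexical"] * lexical
            + WEIGHTS["semantic"] * semantic)

# Example usage on a toy GSM-8K-style item.
print(comprehensive_metric("2 + 3 = 5, so the answer is 5",
                           "Adding 2 and 3 gives 5. The answer is 5",
                           "5"))
```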

Research Focus

The project explored various aspects of LLM pruning through several experiments and ablation studies.

Key Findings

Challenges and Limitations

  1. Computational Constraints: We had to reduce the number of generated tokens from 100 to 20 and increase the pruning step size from 5% to 20% per iteration.

  2. Accuracy Measurement: Defining accuracy based on the presence of the ground truth answer string anywhere in the output may have led to false positives (see the sketch after this list).

  3. Embedding Model Limitations: The sentence embedding model used for semantic similarity (all-MiniLM-L6-v2) may not have captured the nuances of mathematical reasoning effectively.
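
To illustrate the false-positive risk from point 2: a substring check can be satisfied by an intermediate number even when the final answer is wrong. A stricter (hypothetical) variant would compare only the last number in the generation:

```python
import re

def substring_accuracy(generated: str, answer: str) -> bool:
    # Our original definition: any occurrence of the answer string counts.
    return answer in generated

def last_number_accuracy(generated: str, answer: str) -> bool:
    # Stricter variant: compare only the final number in the generation.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", generated)
    return bool(numbers) and numbers[-1] == answer

generation = "First compute 12 * 3 = 36, then subtract 6 to get 30."
print(substring_accuracy(generation, "36"))    # True  (false positive: 36 is only intermediate)
print(last_number_accuracy(generation, "36"))  # False (the final answer produced is 30)
```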

Future Directions

While our results were mixed, this project demonstrates the potential of using task-specific datasets and comprehensive metrics for pruning language models while maintaining their reasoning capabilities. As we continue to refine our approach, we hope to contribute to the development of more efficient and capable language models for specific reasoning tasks.