Learned Token Pruning for Transformers (2024)

Sehoon Kim (sehoonkim@berkeley.edu), University of California, Berkeley, Berkeley, CA, USA
Sheng Shen (sheng.s@berkeley.edu), University of California, Berkeley, Berkeley, CA, USA
David Thorsley (d.thorsley@samsung.com), Samsung Semiconductor, Inc., San Jose, CA, USA
Amir Gholami (amirgh@berkeley.edu), University of California, Berkeley, Berkeley, CA, USA
Woosuk Kwon (woosuk.kwon@berkeley.edu), University of California, Berkeley, Berkeley, CA, USA
Joseph Hassoun (j.hassoun@samsung.com), Samsung Semiconductor, Inc., San Jose, CA, USA
Kurt Keutzer (keutzer@berkeley.edu), University of California, Berkeley, Berkeley, CA, USA

(2022)

Abstract.

Efficient deployment of transformer models in practice is challenging due to their inference cost, including memory footprint, latency, and power consumption, which scales quadratically with input sequence length. To address this, we present a novel token reduction method dubbed Learned Token Pruning (LTP), which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In particular, LTP prunes tokens with an attention score below a threshold, whose value is learned for each layer during training. Our threshold-based method allows the length of the pruned sequence to vary adaptively based on the input sequence, and avoids algorithmically expensive operations such as top-k token selection. We extensively test the performance of LTP on GLUE and SQuAD tasks and show that our method outperforms the prior state-of-the-art token pruning methods with up to ~2.5% higher accuracy for the same amount of FLOPs. In particular, LTP achieves up to 2.1× FLOPs reduction with less than 1% accuracy drop, which results in up to 1.9× and 2.0× throughput improvement on Intel Haswell CPUs and NVIDIA V100 GPUs. Furthermore, we demonstrate that LTP is more robust than prior methods to variations in input sequence lengths. Our code has been developed in PyTorch and open-sourced at https://github.com/kssteven418/LTP.

Deep Learning, Network Pruning, Natural Language Processing

journalyear: 2022
copyright: acmcopyright
conference: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining; August 14–18, 2022; Washington, DC, USA
booktitle: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '22), August 14–18, 2022, Washington, DC, USA
price: 15.00
isbn: 978-1-4503-9385-0/22/08
doi: 10.1145/3534678.3539260
ccs: Computer systems organization → Neural networks
ccs: Computer systems organization → Natural language processing

1. Introduction

Transformer-based deep neural network architectures (Vaswani et al., 2017), such as BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019), achieve state-of-the-art results in Natural Language Processing (NLP) tasks such as sentence classification and question answering. However, efficiently deploying these models is increasingly challenging due to their large size, the need for real-time inference, and the limited energy, compute, and memory resources available. The heart of a transformer layer is the multi-head self-attention mechanism, where each token in the input sequence attends to every other token to compute a new representation of the sequence. Because all tokens attend to each other, the computational complexity is quadratic in the input sequence length, which limits the ability to apply transformer models to long input sequences.

Pruning is a popular technique to reduce the size of neural networks and the amount of computation required. Unstructured pruning allows arbitrary patterns of sparsification for parameters and feature maps and can, in theory, produce significant computational savings while preserving accuracy. However, commodity DNN accelerators cannot efficiently exploit unstructured sparsity patterns. Thus, structured pruning methods are typically preferred in practice due to their relative ease of deployment to hardware.

Multi-head self-attention provides several possibilities for structured pruning; for example, head pruning (Michel et al., 2019; Voita et al., 2019) decreases the size of the model by removing unneeded heads in each transformer layer. Another orthogonal approach that we consider in this paper is token pruning, which reduces computation by progressively removing unimportant tokens in the sequence during inference. For NLP tasks such as sentence classification, token pruning is an attractive approach to consider, as it exploits the intuitive observation that not all tokens (i.e., words) in an input sentence are necessarily required to make a successful inference.

There are two main classes of token pruning methods. In the first class, methods like PoWER-BERT (Goyal et al., 2020) and Length-Adaptive Transformer (LAT) (Kim and Cho, 2020) search for a single token pruning configuration (i.e., sequence length for each layer) for an entire dataset. In other words, they prune all input sequences to the same length. However, input sequence lengths can vary greatly within tasks and between training and validation sets, as shown in Figure 1, and thus applying a single pruning configuration to all input sequences can potentially under-prune shorter sequences or over-prune longer sequences.

In the other class, the token pruning method adjusts the configuration based on the input sequence. SpAtten (Wang et al., 2020b) uses a pruning configuration proportional to input sentence length; however, it does not adjust the proportion of pruned tokens based on the content of the input sequence. The recently published TR-BERT (Ye et al., 2021) uses reinforcement learning (RL) to find a policy network that dynamically reduces the number of tokens based on the length and content of the input sequence; however, it requires additional costly training for the RL-based method to converge. Additionally, all of these prior methods rely in part on selecting the k most significant tokens during inference or training. This selection can be computationally expensive without specialized hardware, such as the top-k engine introduced in SpAtten (Wang et al., 2020b).

To this end, we propose a learned threshold-based token pruning method that adapts to the length and content of individual examples and avoids the use of top-k operations. In particular, our contributions are as follows:

  • We propose Learned Token Pruning (LTP), a threshold-based token pruning method that only needs a simple threshold operation to detect unimportant tokens. In addition, LTP fully automates the search for optimal pruning configurations by introducing a differentiable soft binarized mask that allows the thresholds for different layers and tasks to be learned. (Section 3.3)

  • We apply LTP to RoBERTa and evaluate its performance on GLUE and SQuAD tasks. We show that LTP achieves up to 2.10× FLOPs reduction with less than 1% accuracy degradation, which results in up to 1.93× and 1.97× throughput improvement on NVIDIA V100 GPU and Intel Haswell CPU, respectively, compared to the unpruned FP16 baseline. We also show that LTP outperforms SpAtten and LAT in most cases, achieving additional FLOPs reduction for the same drop in accuracy. (Sections 4.2 and 4.5)

  • We show that LTP is highly robust against sentence length variations. LTP exhibits consistently better accuracy over different sentence length distributions, achieving up to 16.4% higher accuracy than LAT. (Section 4.3)

2. Related Work

2.1. Efficient Transformers

Multiple approaches have been proposed to improve the speed and reduce the memory footprint of transformers. These can be broadly categorized as follows: (i) efficient architecture design (Lan et al., 2019; Child et al., 2019; Kitaev et al., 2020; Wang et al., 2020a; Iandola et al., 2020; Vyas et al., 2020; Tay et al., 2020; Katharopoulos et al., 2020; Zaheer et al., 2020; Roy et al., 2021); (ii) knowledge distillation (Sun et al., 2020; Jiao et al., 2019; Tang et al., 2019; Sanh et al., 2019; Sun et al., 2019); (iii) quantization (Bhandare et al., 2019; Zafrir et al., 2019; Shen et al., 2020; Fan et al., 2020; Zadeh et al., 2020; Zhang et al., 2020; Bai et al., 2020; Kim et al., 2021); and (iv) pruning. Here, we focus only on pruning and briefly discuss the related work.

2.2. Transformer Pruning

Pruning methods can be categorized into unstructured pruning and structured pruning. For unstructured pruning, the lottery-ticket hypothesis (Frankle and Carbin, 2018) has been explored for transformers in (Prasanna et al., 2020; Chen et al., 2020). Recently, Zhao et al. (2020) leverage pruning as an effective way to fine-tune transformers on downstream tasks. Sanh et al. (2020) propose movement pruning, which achieves significant performance improvements over prior magnitude-based methods by taking into account how the weights change during fine-tuning. However, it is often quite difficult to efficiently deploy unstructured sparsity on commodity neural accelerators for meaningful speedup.

For this reason, a number of structured pruning methods have been introduced to remove structured sets of parameters. Michel et al. (2019) and Voita et al. (2019) drop attention heads in multi-head attention layers, and Sajjad et al. (2020) and Fan et al. (2019) prune entire transformer layers. Wang et al. (2019) structurally prune weight matrices via low-rank factorization, and Khetan and Karnin (2020) and Lin et al. (2020) attempt to jointly prune attention heads and filters of weight matrices. Liu et al. (2021) and Hou et al. (2020) dynamically determine structured pruning ratios during inference. Recent block pruning schemes chunk weight matrices into multiple blocks and prune them based on group Lasso optimization (Li et al., 2020), adaptive regularization (Yao et al., 2021), and movement pruning (Lagunas et al., 2021). All of these methods correspond to weight pruning, where model parameters (i.e., weights) are pruned.

Recently, there has been work on pruning input sentences to transformers, rather than model parameters. This is referred to as token pruning, where less important tokens are progressively removed during inference. PoWER-BERT (Goyal et al., 2020), one of the earliest works, proposes to directly learn token pruning configurations. LAT (Kim and Cho, 2020) extends this idea by introducing LengthDrop, a procedure in which a model is trained with different token pruning configurations, followed by an evolutionary search. This method has the advantage that the LengthDrop training need not be repeated for different pruning ratios of the same model. While these methods have shown a large computation reduction on various NLP downstream tasks, they fix a single token pruning configuration for the entire dataset. That is, they prune all input sequences to the same length. However, as shown in Figure 1, input sequence lengths vary greatly within a task. As a consequence, fixing a single pruning configuration can under-prune shorter sequences so as to retain sufficient tokens for processing longer sequences or, conversely, over-prune longer sequences to remove sufficient tokens to efficiently process shorter sequences. More importantly, a single pruning configuration lacks robustness against input sequence length variations, for example when input sentences at inference time are longer than those in the training dataset (Press et al., 2021).

In contrast, SpAtten (Wang et al., 2020b) circumvents this issue by assigning a pruning configuration proportional to the input sequence length. While this allows pruning more tokens from longer sequences and fewer tokens from shorter ones, it is not adaptive to individual input sequences, as it assigns the same configuration to all sequences with the same length regardless of their contents. In addition, the pruning configurations are determined heuristically and thus can result in a suboptimal solution. Recently, TR-BERT (Ye et al., 2021) proposed learning an RL policy network to apply adaptive pruning configurations for each input sequence. However, as noted by the authors, the problem has a large search space, which can be hard for RL to solve. This issue is mitigated by heuristics involving imitation learning and sampling of action sequences, which significantly increase the cost of training. Importantly, all of the aforementioned token pruning methods depend partially or entirely on the top-k operation for selecting the k most important tokens during inference or training. This operation can be costly without specialized hardware support, as discussed in (Wang et al., 2020b). LTP, on the other hand, is based on a fully learnable threshold-based pruning strategy. Therefore, it is (i) adaptive to both input length and content, (ii) robust to sentence length variations, (iii) computationally efficient, and (iv) easy to deploy.

3. Methodology

3.1. Background

BERT (Devlin et al., 2018) consists of multiple transformer encoder layers (Vaswani et al., 2017) stacked together. A basic transformer encoder layer consists of a multi-head attention (MHA) block followed by a point-wise feed-forward (FFN) block, with residual connections around each. Specifically, an MHA block consists of $N_h$ independently parameterized heads. An attention head $h$ in layer $l$ is parameterized by $\mathbf{W}_k^{(h,l)}, \mathbf{W}_q^{(h,l)}, \mathbf{W}_v^{(h,l)} \in \mathbb{R}^{d_h \times d}$ and $\mathbf{W}_o^{(h,l)} \in \mathbb{R}^{d \times d_h}$, where $d_h$ is typically set to $d/N_h$ and $d$ is the feature dimension. We drop the superscript $l$ for simplicity in the following formulas. The MHA measures the pairwise importance of each token on every other token in the input:

$$\text{MHA}(\text{x}) = \sum_{h=1}^{N_h} \text{Att}_{\mathbf{W}_{k,q,v,o}^{(h)}}(\text{x}), \tag{1}$$

where $\text{x} \in \mathbb{R}^{d \times n}$ is the input sequence with sequence length $n$, and $\text{Att}_{\mathbf{W}_{k,q,v,o}}$ is:

$$\text{Att}_{\mathbf{W}_{k,q,v,o}}(\text{x}) = \mathbf{W}_o \sum_{i=1}^{n} \mathbf{W}_v \text{x}_i \, \text{softmax}\left(\frac{\text{x}^T \mathbf{W}_q^T \mathbf{W}_k \text{x}_i}{\sqrt{d}}\right), \tag{2}$$
$$\text{x}_{\text{MHA}} = \text{LN}\big(\text{MHA}(\text{x}) + \text{x}\big), \tag{3}$$

where Eq. 3 applies the residual connection followed by LayerNorm (LN). The output of the MHA block is then fed into the FFN block, which applies two feed-forward layers to this input:

$$\text{FFN}(\text{x}_{\text{MHA}}) = \mathbf{W}_2 \, \sigma\big(\mathbf{W}_1 \text{x}_{\text{MHA}} + b_1\big) + b_2, \tag{4}$$
$$\text{x}_{\text{out}} = \text{LN}\big(\text{FFN}(\text{x}_{\text{MHA}}) + \text{x}_{\text{MHA}}\big), \tag{5}$$

where $\mathbf{W}_1, \mathbf{W}_2, b_1$, and $b_2$ are the FFN parameters, and $\sigma$ is the activation function (typically GELU for BERT).

3.2. Threshold Token Pruning

Let us denote the attention probability of head $h$ between tokens $\text{x}_i$ and $\text{x}_j$ as $\mathbf{A}^{(h,l)}$:

$$\mathbf{A}^{(h,l)}(\text{x}_i, \text{x}_j) = \text{softmax}\left(\frac{\text{x}^T \mathbf{W}_q^T \mathbf{W}_k \text{x}}{\sqrt{d}}\right)_{(j,i)} \in \mathbb{R}. \tag{6}$$

The computational complexity of computing the attention matrix is $\mathcal{O}(d^2 n + n^2 d)$, which scales quadratically with the sequence length. As such, the attention operation becomes a bottleneck when applied to long sequences. To address this, we apply token pruning, which removes unimportant tokens as the input passes through the transformer layers to reduce the sequence length $n$ for later blocks. This is schematically shown in Figure 2 (Left).
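As a rough illustration of the $\mathcal{O}(d^2 n + n^2 d)$ cost, the sketch below (ours, not taken from the LTP code base) evaluates the two terms separately with constant factors dropped: the $d^2 n$ term comes from the linear projections and the $n^2 d$ term from forming and applying the $n \times n$ attention matrix. The hidden size of 768 is an assumption chosen to match RoBERTa-base. The ratio of the two terms is $n/d$, so the quadratic term dominates once $n$ exceeds $d$.

```python
from typing import Tuple


def attention_cost_terms(n: int, d: int) -> Tuple[int, int]:
    """Return the two terms of O(d^2 n + n^2 d), ignoring constant factors.

    projection_term: d^2 * n, from the query/key/value/output projections
    score_term:      n^2 * d, from the n x n attention matrix and its product with V
    """
    projection_term = d * d * n
    score_term = n * n * d
    return projection_term, score_term


if __name__ == "__main__":
    d = 768  # assumed hidden size (RoBERTa-base)
    for n in (128, 768, 2048):
        proj, score = attention_cost_terms(n, d)
        print(f"n={n:5d}  d^2*n={proj:>13,}  n^2*d={score:>13,}  ratio={score / proj:.2f}")
```

Doubling $n$ doubles the first term but quadruples the second, which is why shortening the sequence in later layers pays off most for long inputs.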

For token pruning, we must define a metric to determine unimportant tokens. Following (Goyal et al., 2020; Wang et al., 2020b; Kim and Cho, 2020), we define the importance score of token $\text{x}_i$ in layer $l$ as:

$$s^{(l)}(\text{x}_i) = \frac{1}{N_h} \frac{1}{n} \sum_{h=1}^{N_h} \sum_{j=1}^{n} \mathbf{A}^{(h,l)}(\text{x}_i, \text{x}_j). \tag{7}$$

Intuitively, the attention probability $\mathbf{A}^{(h,l)}(\text{x}_i, \text{x}_j)$ is interpreted as the normalized amount that token $\text{x}_j$ attends to token $\text{x}_i$. Token $\text{x}_i$ is thus considered important if it receives more attention from all tokens across all heads, which leads directly to Eq. 7. The procedure for computing importance scores from attention probabilities is illustrated in Figure 2 (Right).
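A minimal pure-Python sketch of Eq. 7, assuming each head's attention probabilities are stored as an $n \times n$ matrix attn[h][q][k] in which row q is the softmax of query token q over all key tokens (our storage convention, not necessarily that of the released code). Under this convention, the attention that token $\text{x}_i$ receives from token $\text{x}_j$ is entry attn[h][j][i], so the score of $\text{x}_i$ is an average of column i over heads and query rows. Because every row sums to 1, the scores of one layer sum to 1.

```python
from typing import List

# attn[h][q][k]: probability that query token q attends to key token k in head h;
# each row attn[h][q] is a softmax over keys and therefore sums to 1.
AttentionProbs = List[List[List[float]]]


def importance_scores(attn: AttentionProbs) -> List[float]:
    """Eq. 7 sketch: s(x_i) = 1/(N_h * n) * sum over heads h and tokens j of A(x_i, x_j).

    A(x_i, x_j) is read as the attention token x_j pays to token x_i,
    i.e. entry attn[h][j][i], so each score is a column average.
    """
    num_heads = len(attn)
    n = len(attn[0])
    totals = [0.0] * n
    for head in attn:
        for row in head:  # row j: attention paid by query token x_j
            for i, p in enumerate(row):
                totals[i] += p  # accumulate attention received by token x_i
    return [t / (num_heads * n) for t in totals]


if __name__ == "__main__":
    one_head = [
        [0.6, 0.3, 0.1],
        [0.5, 0.4, 0.1],
        [0.7, 0.2, 0.1],
    ]
    print(importance_scores([one_head]))  # approximately [0.6, 0.3, 0.1]
```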

In (Goyal et al., 2020; Wang et al., 2020b; Kim and Cho, 2020), tokens are ranked by importance score and pruned using a top-$k$ selection strategy. Specifically, token $\text{x}_i$ is pruned at layer $l$ if its importance score $s^{(l)}(\text{x}_i)$ is not among the $k$ largest importance scores of all tokens. However, finding the $k$ largest importance scores is computationally inefficient without specialized hardware (Wang et al., 2020b); we provide empirical results showing this in Section A.2. Instead, we introduce a new threshold-based token pruning approach in which a token is pruned only if its importance score is below a threshold denoted by $\theta^{(l)} \in \mathbb{R}$. Specifically, we define a pruning strategy by imposing a binary mask $M^{(l)}(\cdot): \{1, \dots, n\} \to \{0, 1\}$, which indicates whether a token should be kept or pruned:

$$M^{(l)}(\text{x}_i) = \begin{cases} 1 & \text{if } s^{(l)}(\text{x}_i) > \theta^{(l)}, \\ 0 & \text{otherwise}. \end{cases} \tag{8}$$

Note that this operation only requires a simple comparison operator without any expensive top-$k$ calculation. Once a token is pruned, it is excluded from calculations in all succeeding layers, thereby gradually reducing computational complexity towards the output layers.
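The contrast between Eq. 8 and the top-$k$ rule can be sketched in a few lines of Python. threshold_mask needs one comparison per token and lets the number of kept tokens vary with the input, whereas topk_mask (included only for comparison) must rank the scores and always keeps exactly $k$ tokens. The function names and the plain-list representation are illustrative assumptions; a real implementation would operate on batched tensors.

```python
from typing import List


def threshold_mask(scores: List[float], theta: float) -> List[int]:
    """Eq. 8: keep token i iff s(x_i) > theta. One comparison per token, no sorting."""
    return [1 if s > theta else 0 for s in scores]


def topk_mask(scores: List[float], k: int) -> List[int]:
    """Top-k baseline for comparison: keep the k highest-scoring tokens (needs ranking)."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:k])
    return [1 if i in keep else 0 for i in range(len(scores))]


def prune(tokens: List[str], mask: List[int]) -> List[str]:
    """Drop masked-out tokens so that later layers see a shorter sequence."""
    return [t for t, m in zip(tokens, mask) if m]


if __name__ == "__main__":
    theta = 0.1  # hypothetical per-layer threshold
    short = [0.40, 0.05, 0.30, 0.02, 0.23]
    longer = [0.15, 0.12, 0.11, 0.20, 0.09, 0.08, 0.13, 0.12]
    # The same threshold keeps a different number of tokens for each input.
    print(sum(threshold_mask(short, theta)), sum(threshold_mask(longer, theta)))  # 3 6
```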

3.3. Learnable Threshold for Token Pruning

A key concern with the method above is how to determine the threshold values for each layer. Not only do threshold values change for different layers, they also vary between different tasks. We address this by making the thresholds (i.e., $\theta$ in Eq. 8) learnable. However, there are several challenges to consider. First, due to the binary nature of $M$, there is no gradient flow for pruned tokens. Second, the $M$ operator is non-differentiable, which prevents gradient flow into the thresholds. To address these challenges, we use a soft pruning scheme that simulates the original hard pruning while still propagating gradients to the thresholds, as shown in Figure 3.

Soft Pruning Scheme. In the soft pruning scheme, we replace the non-differentiable mask $M^{(l)}$ with a differentiable soft mask using the sigmoid operation $\sigma$:

$$\tilde{M}^{(l)}(\text{x}_i) = \sigma\left(\frac{s^{(l)}(\text{x}_i) - \theta^{(l)}}{T}\right), \tag{9}$$

where $T$ is the temperature and $\theta^{(l)}$ is the learnable threshold value for layer $l$. With a sufficiently small temperature $T$, $\tilde{M}^{(l)}(\text{x}_i)$ closely approximates the hard mask $M^{(l)}(\text{x}_i)$ in Eq. 8. In addition, instead of selecting tokens to be pruned or kept based on the hard mask of Eq. 8, we multiply the output activation of layer $l$ by the soft mask. That is,

$$\tilde{\text{x}}_{\text{out}}^{(l)} = \tilde{M}^{(l)}(\text{x}^{(l)}) \cdot \text{x}_{\text{out}}^{(l)} \tag{10}$$
$$\tilde{\text{x}}_{\text{out}}^{(l)} = \tilde{M}^{(l)}(\text{x}^{(l)}) \cdot \text{LN}\big(\text{FFN}(\text{x}_{\text{MHA}}^{(l)}) + \text{x}_{\text{MHA}}^{(l)}\big), \tag{11}$$

where $\text{x}_{\text{MHA}}^{(l)}$ is the output activation of MHA in layer $l$. If the importance score of token $\text{x}_i$ is below the threshold by a large margin, its layer output activation is close to zero, and thus it has little impact on the succeeding layer. Also, because the token then receives a near-zero importance score in the succeeding layer, i.e., $s^{(l+1)}(\text{x}_i) \approx 0$, it is likely to be pruned again. Therefore, the soft pruning scheme behaves nearly identically to hard pruning, yet its differentiable form allows for backpropagation and gradient-based optimization to make $\theta$ learnable. The resulting training procedure has three steps: (i) jointly train the model parameters and thresholds on downstream tasks with the soft pruning scheme; (ii) fix the thresholds and binarize the soft mask; and (iii) fine-tune the model parameters. The pseudo-code for this three-step algorithm is given in Algorithm 1. Intuitively, the magnitude of the gradient $d\tilde{M}^{(l)}(\text{x}_i)/d\theta^{(l)}$ is maximized when the importance score $s^{(l)}(\text{x}_i)$ is close to the threshold $\theta^{(l)}$ and is near zero elsewhere. Therefore, the threshold is effectively trained only by the tokens that are about to be pruned or retained.
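The sketch below (ours) evaluates the soft mask of Eq. 9, applies it to per-token output vectors as in Eq. 10, and computes the closed-form derivative $d\tilde{M}^{(l)}(\text{x}_i)/d\theta^{(l)} = -\tilde{M}(1-\tilde{M})/T$ that follows from the sigmoid. It illustrates the two properties described above: a small temperature makes the mask nearly binary, and the gradient with respect to the threshold is concentrated on tokens whose scores lie near $\theta^{(l)}$. Plain Python lists stand in for tensors, and the threshold and temperature values are hypothetical; a training implementation would rely on autograd rather than the hand-written derivative.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    # Numerically stable logistic function (never calls exp on a large positive value).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_mask(score: float, theta: float, temperature: float) -> float:
    """Eq. 9: sigma((s - theta) / T); tends to the hard mask of Eq. 8 as T -> 0."""
    return sigmoid((score - theta) / temperature)


def soft_mask_grad_theta(score: float, theta: float, temperature: float) -> float:
    """d M~ / d theta = -(1/T) * m * (1 - m); largest in magnitude when s is near theta."""
    m = soft_mask(score, theta, temperature)
    return -m * (1.0 - m) / temperature


def apply_soft_mask(outputs: List[List[float]], masks: List[float]) -> List[List[float]]:
    """Eq. 10: scale each token's output activation vector by its soft-mask value."""
    return [[m * v for v in row] for row, m in zip(outputs, masks)]


if __name__ == "__main__":
    theta, T = 0.1, 0.01  # hypothetical threshold and temperature
    for s in (0.0, 0.08, 0.1, 0.12, 0.3):
        m = soft_mask(s, theta, T)
        g = soft_mask_grad_theta(s, theta, T)
        print(f"s={s:.2f}  mask={m:.4f}  dmask/dtheta={g:.4f}")
```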

Algorithm 1
Input: $\mathcal{M}$: model fine-tuned on the target downstream task

Step 1: Apply the soft mask to $\mathcal{M}$ and train both the thresholds and the model parameters ▷ Soft pruning

Step 2: Binarize the mask and fix the thresholds

Step 3: Fine-tune the model parameters ▷ Hard pruning

Regularization. It is not possible to learn $\theta$ without regularization, as the optimizer generally obtains a lower loss when all tokens are present. As such, we add a regularization term to penalize the network if tokens are left unpruned. This is achieved by imposing an L1 loss on the masking operator $\tilde{M}$:

$$\mathcal{L}_{\text{new}} = \mathcal{L} + \lambda \mathcal{L}_{\text{reg}} \quad \text{where} \quad \mathcal{L}_{\text{reg}} = \frac{1}{L} \sum_{l=1}^{L} \big\|\tilde{M}^{(l)}(\text{x})\big\|_1. \tag{12}$$

Here, $\mathcal{L}$ is the original loss function (e.g., cross-entropy loss), $\lambda$ is the regularization parameter, and $L$ is the number of layers. Larger values of $\lambda$ result in higher pruning ratios. This regularization term induces an additional gradient on the threshold:

$$\frac{d\mathcal{L}_{\text{reg}}}{d\theta^{(l)}} = \frac{1}{L} \frac{d}{d\theta^{(l)}} \big\|\tilde{M}^{(l)}(\text{x})\big\|_1 = \frac{1}{L} \sum_{i=1}^{n} \frac{d\tilde{M}^{(l)}(\text{x}_i)}{d\theta^{(l)}}. \tag{13}$$

If there are more tokens near the threshold, the gradient $d\mathcal{L}_{\text{reg}}/d\theta^{(l)}$ is larger in magnitude. Since each term $d\tilde{M}^{(l)}(\text{x}_i)/d\theta^{(l)}$ is negative, gradient descent pushes the threshold to a larger value, which prunes more tokens near the threshold boundary.
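A small sketch of the regularizer for a single input sequence, differentiating Eq. 12 directly (including its $1/L$ factor), given per-layer importance scores and thresholds. It uses the sigmoid derivative instead of autograd and the identity $\sigma(z) = (1 + \tanh(z/2))/2$ to avoid overflow. The scores, thresholds, step size, and $\lambda$ in the example are hypothetical: layer 0 has scores far from its threshold and layer 1 has scores clustered around it, so the gradient magnitude is much larger for layer 1, and because the gradient is negative, a gradient-descent step on $\lambda \mathcal{L}_{\text{reg}}$ raises that threshold.

```python
import math
from typing import List, Tuple


def _sigmoid(z: float) -> float:
    # Equivalent to 1 / (1 + exp(-z)) but cannot overflow for large |z|.
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def reg_loss_and_grads(
    scores_per_layer: List[List[float]], thetas: List[float], temperature: float
) -> Tuple[float, List[float]]:
    """L_reg = (1/L) * sum_l ||M~^(l)(x)||_1 and dL_reg/dtheta^(l) for every layer.

    Soft-mask values are positive, so each L1 norm is a plain sum.
    """
    num_layers = len(thetas)
    loss = 0.0
    grads = []
    for scores, theta in zip(scores_per_layer, thetas):
        masks = [_sigmoid((s - theta) / temperature) for s in scores]
        loss += sum(masks) / num_layers
        # d/dtheta sigmoid((s - theta) / T) = -m * (1 - m) / T
        grads.append(sum(-m * (1.0 - m) / temperature for m in masks) / num_layers)
    return loss, grads


if __name__ == "__main__":
    scores = [
        [0.5, 0.5, 0.001],  # layer 0: scores far from theta
        [0.11, 0.09, 0.1],  # layer 1: scores close to theta
    ]
    thetas = [0.1, 0.1]
    loss, grads = reg_loss_and_grads(scores, thetas, temperature=0.01)
    lr, lam = 1e-3, 1.0  # hypothetical step size and lambda
    new_thetas = [t - lr * lam * g for t, g in zip(thetas, grads)]
    print(loss, grads, new_thetas)
```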

4. Experiments

4.1. Experiment Setup

We implemented LTP on RoBERTa$_{\text{base}}$ (Liu et al., 2019) using HuggingFace's repository (https://github.com/huggingface/transformers/) and tested it on (English) GLUE tasks (Wang et al., 2018) and SQuAD 2.0 (Rajpurkar et al., 2018). For GLUE, we use 7 tasks for evaluation, including sentence similarity (QQP (Iyer et al., 2017), MRPC (Dolan and Brockett, 2005), STS-B (Cer et al., 2017)), sentiment classification (SST-2 (Socher et al., 2013)), textual entailment (RTE (Dagan et al., 2005)), and natural language inference (MNLI (Williams et al., 2017), QNLI (Rajpurkar et al., 2016)). For evaluation, we measure classification accuracy and F1 score for MRPC and QQP, Pearson and Spearman correlation for STS-B, and classification accuracy for the remaining tasks, all on the validation sets. For the tasks with multiple metrics (i.e., MRPC, QQP, STS-B), we report their average. For SQuAD 2.0 (Rajpurkar et al., 2018), a question answering task, we measure the F1 score.

As mentioned in Section 3.3, the training procedure of LTP consists of two stages: soft pruning, which trains both the model parameters and thresholds on downstream tasks, followed by hard pruning, which fine-tunes the model parameters with fixed thresholds. We also compare LTP with the current state-of-the-art token pruning methods SpAtten (Wang et al., 2020b) and LAT (Kim and Cho, 2020), following the implementation details in their papers. See Section A.1 for details of the training process. We use PyTorch 1.8 throughout all experiments. For CPU inference speed experiments, we use an Intel Haswell CPU with 3.75 GB of memory on Google Cloud Platform. For GPU inference speed experiments, we use an AWS p3.2xlarge instance that has an NVIDIA V100 GPU with CUDA 11.1.

An important issue in previous work (Goyal et al., 2020; Kim and Cho, 2020) is that all input sequences for a specific task are padded to the nearest power of 2 from the 99th percentile of the sequence lengths, and the pruned performance is then compared with the padded baseline. This results in exaggerated performance gains over the baseline. For instance, in (Goyal et al., 2020), inputs from the SST-2 dataset are padded to 64, while its average sentence length is 26 (cf. Figure 1). With this approach, one can achieve a roughly 2.5× speedup just by removing padding. As such, we avoid any extra padding of input sequences, and all speedups and throughputs we report are compared with the unpadded baselines.
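The ~2.5× figure can be reproduced from the padding ratio alone, under our simplifying assumption that the cost of processing a sequence grows roughly linearly with its length (reasonable for SST-2, where $n \ll d$, so the linear $d^2 n$ term dominates the attention cost and the FFN cost is linear in $n$):

$$\text{speedup from removing padding} \approx \frac{\text{padded length}}{\text{average length}} = \frac{64}{26} \approx 2.46 \approx 2.5\times.$$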

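Table 1. Accuracy metric, GFLOPs, and speedup of LTP versus the RoBERTa$_{\text{base}}$ baseline.

| Task | Accuracy metric (RoBERTa) | Accuracy metric (LTP) | GFLOPs (RoBERTa) | GFLOPs (LTP) | Speedup (LTP) |
| --- | --- | --- | --- | --- | --- |
| MNLI-m | 87.53 | 86.53 | 6.83 | 3.64 | 1.88× |
| MNLI-mm | 87.36 | 86.37 | 7.15 | 3.63 | 1.97× |
| QQP | 90.39 | 89.69 | 5.31 | 2.53 | 2.10× |
| QNLI | 92.86 | 91.98 | 8.94 | 4.77 | 1.87× |
| SST-2 | 94.27 | 93.46 | 4.45 | 2.13 | 2.09× |
| STS-B | 90.89 | 90.03 | 5.53 | 2.84 | 1.95× |
| MRPC | 92.14 | 91.59 | 9.33 | 4.44 | 2.10× |
| RTE | 77.98 | 77.98 | 11.38 | 6.30 | 1.81× |
| SQuAD 2.0 | 83.04 | 82.25 | 32.12 | 16.99 | 1.89× |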


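Table 2. Sequence length quantiles of the training and evaluation sets, and the KL divergence between the two length distributions.

| Task | Quantiles (train) | Quantiles (eval) | KL Div. |
| --- | --- | --- | --- |
| MNLI-m | 27/38/50 | 26/37/50 | 0.0055 |
| MNLI-mm | 27/38/50 | 29/39/51 | 0.0042 |
| QQP | 23/28/36 | 23/28/36 | 0.0006 |
| QNLI | 39/48/59 | 39/49/61 | 0.0092 |
| SST-2 | 7/11/19 | 18/25/33 | 1.2250 |
| STS-B | 20/24/32 | 21/29/41 | 0.0925 |
| MRPC | 45/54/63 | 45/54/64 | 0.0033 |
| RTE | 44/57/86 | 42/54/78 | 0.0261 |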

4.2. Performance Evaluation

Table 1 lists the accuracy and GFLOPs for LTP. For each downstream task, we select the model that achieves the smallest GFLOPs while keeping the accuracy degradation from the baseline (RoBERTa$_{\text{base}}$) within 1%. With our method, sequence lengths in each layer can vary across input sentences; therefore, we report the GFLOPs averaged over all input sentences in the development set. As shown in the table, our method achieves a speedup (i.e., GFLOPs reduction) of 1.96× on average and up to 2.10× within 1% accuracy degradation.
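The Speedup column of Table 1 coincides with the ratio of baseline GFLOPs to LTP GFLOPs, and the 1.96× average includes SQuAD 2.0 alongside the eight GLUE rows. The short script below recomputes both from the GFLOPs values transcribed from Table 1; it is an arithmetic consistency check only and does not measure wall-clock time.

```python
# (baseline RoBERTa-base GFLOPs, LTP GFLOPs), transcribed from Table 1
GFLOPS = {
    "MNLI-m": (6.83, 3.64),
    "MNLI-mm": (7.15, 3.63),
    "QQP": (5.31, 2.53),
    "QNLI": (8.94, 4.77),
    "SST-2": (4.45, 2.13),
    "STS-B": (5.53, 2.84),
    "MRPC": (9.33, 4.44),
    "RTE": (11.38, 6.30),
    "SQuAD 2.0": (32.12, 16.99),
}


def flops_reduction(baseline_gflops: float, pruned_gflops: float) -> float:
    """Ratio of baseline to pruned GFLOPs; matches the Speedup column of Table 1."""
    return baseline_gflops / pruned_gflops


ratios = {task: flops_reduction(b, p) for task, (b, p) in GFLOPS.items()}
mean_ratio = sum(ratios.values()) / len(ratios)

if __name__ == "__main__":
    for task, r in ratios.items():
        print(f"{task:<10} {r:.2f}x")
    print(f"mean {mean_ratio:.2f}x, max {max(ratios.values()):.2f}x")  # mean 1.96x, max 2.10x
```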

Figure 4 plots the accuracy of LTP (blue lines) as well as the prior pruning methods (red lines for SpAtten and orange lines for LAT) at different FLOPs on GLUE tasks. LTP consistently outperforms SpAtten on all tasks, with up to ~2% higher accuracy under the same amount of FLOPs. Compared with LAT, LTP performs better on all tasks except QQP, with up to ~2.5% higher accuracy for the same target FLOPs. On QQP alone, LTP achieves at most ~0.2% lower accuracy than LAT.

An important observation is that for SST-2 and STS-B, where LTP (ours) outperforms LAT by large margins, the sequence length varies greatly from the training dataset to the evaluation dataset, as can be seen from the large KL divergence in Table 2 and Figure 1 (b, c). On the other hand, for QQP, the only dataset on which LAT slightly outperforms LTP (ours), the sequence length distribution of the training dataset is almost identical to that of the evaluation dataset, as can be seen from the small KL divergence in Table 2 and Figure 1 (a). This observation supports our claim in Sections 1 and 2: LTP is robust to sequence length variations because it does not fix the pruning configuration, unlike methods that use a single pruning configuration regardless of the input sequence length. This is important in practice because the sequence lengths during inference do not always follow the sequence length distribution of the training dataset, as in SST-2 and STS-B. We discuss this further in Section 4.3.
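Table 2 does not state how its KL divergence was computed. The sketch below shows one plausible procedure under our own assumptions: sequence lengths are bucketed into fixed-width bins (bin_width=8 is arbitrary), the divergence is taken as KL(train || eval) in nats, and a small additive smoothing constant keeps empty evaluation bins from producing log(0). It is not expected to reproduce the exact values in Table 2.

```python
import math
from collections import Counter
from typing import Dict, Iterable


def length_histogram(lengths: Iterable[int], bin_width: int = 8) -> Dict[int, float]:
    """Normalized histogram of sequence lengths: bin index -> probability."""
    counts = Counter(length // bin_width for length in lengths)
    total = sum(counts.values())
    return {b: c / total for b, c in counts.items()}


def kl_divergence(p: Dict[int, float], q: Dict[int, float], eps: float = 1e-6) -> float:
    """KL(p || q) in nats; q is smoothed with eps on every bin and renormalized."""
    bins = set(p) | set(q)
    q_smoothed = {b: q.get(b, 0.0) + eps for b in bins}
    z = sum(q_smoothed.values())
    # pb * z / q_smoothed[b] == pb / (normalized smoothed q at bin b)
    return sum(pb * math.log(pb * z / q_smoothed[b]) for b, pb in p.items() if pb > 0)


if __name__ == "__main__":
    # Synthetic lengths for illustration only (not the GLUE data).
    train = [12, 15, 18, 20, 22, 25, 30, 35]
    eval_shifted = [20, 24, 28, 30, 33, 38, 41, 45]
    print(kl_divergence(length_histogram(train), length_histogram(eval_shifted)))
```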

For SQuAD 2.0, we obtain results similar to GLUE. As can be seen in Table 1 and Figure 5 (Left), we obtain an F1 score nearly identical to the baseline at 0.58 relative FLOPs, and a 1.89× speedup with less than 1% drop in F1 score. The SQuAD 2.0 dataset is divided into two subsets: the subset of examples where the answer to the question is included in the context text, and the subset that has no answer. In Figure 5 (Right), we further plot the results on each subset of the dataset (black and red for the subsets with and without answers, respectively). We see that the F1 score decreases for the subset with answers and increases for the subset without answers as we decrease the relative FLOPs. This is to be expected, as the question answering head will predict no answer if the start and end points of the answer within the context cannot be determined due to high token pruning ratios. Thus, a careful setting of $\lambda$ in Eq. 12 is necessary to balance the accuracy between the two subsets.

Finally, we highlight that LTP has an additional advantage over prior top-k based approaches: it avoids computationally inefficient top-k operations, as further discussed in Section A.2.


4.3. Robustness to Sequence Length Variation

In Section 4.2, we claimed that LTP is more robust than prior methods to sequence length variations between training and evaluation time. Here, we analyze this more systematically. Ideally, performance should be independent of sequence length. To quantitatively test the robustness of pruning methods against sequence length variations, we train LTP and LAT on QNLI and QQP, but use only the training examples whose sequence lengths are below the median length of the evaluation dataset. We then evaluate the resulting models on the evaluation examples with sequence lengths (i) below the median (~Q2), (ii) between the median and the third quartile (Q2~Q3), and (iii) above the third quartile (Q3~) of the evaluation dataset. For a fair comparison, we choose LTP and LAT models that require similar FLOPs on ~Q2.
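The data split used in this experiment can be sketched as follows, assuming sequence lengths are measured in tokens and examples are referred to by index. The function names and the boundary convention (a length equal to Q2 falls into Q2~Q3, and one equal to Q3 into Q3~) are hypothetical choices, not taken from the paper's code.

```python
import statistics
from typing import Dict, List, Sequence, Tuple


def quartile_cutoffs(eval_lengths: Sequence[int]) -> Tuple[float, float]:
    """Median (Q2) and third quartile (Q3) of the evaluation-set sequence lengths."""
    _, q2, q3 = statistics.quantiles(eval_lengths, n=4)  # default "exclusive" method
    return q2, q3


def length_bucket(length: int, q2: float, q3: float) -> str:
    """Assumed boundary convention: [min, Q2), [Q2, Q3), [Q3, max]."""
    if length < q2:
        return "~Q2"
    if length < q3:
        return "Q2~Q3"
    return "Q3~"


def split_eval_by_length(lengths: Sequence[int], q2: float, q3: float) -> Dict[str, List[int]]:
    """Group evaluation example indices into the three length subsets."""
    buckets: Dict[str, List[int]] = {"~Q2": [], "Q2~Q3": [], "Q3~": []}
    for idx, length in enumerate(lengths):
        buckets[length_bucket(length, q2, q3)].append(idx)
    return buckets


def short_train_indices(train_lengths: Sequence[int], q2: float) -> List[int]:
    """Training examples shorter than the evaluation median (the only ones trained on)."""
    return [idx for idx, length in enumerate(train_lengths) if length < q2]
```

`statistics.quantiles` with `n=4` returns the three quartile cut points; with its default exclusive method, the middle one coincides with `statistics.median`.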

The results are listed in Table 3. LTP consistently achieves higher accuracy than LAT while keeping its FLOPs nearly constant across sequence length ranges, even for sequences significantly longer than those seen in training. In contrast, LAT suffers significant accuracy degradation because longer sequences are over-pruned, as reflected in its sharp FLOPs reduction for those ranges. In particular, LTP outperforms LAT by up to 16.44% and 9.20% on QNLI and QQP, respectively, for the Q3~ evaluation subset.

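Table 3
Method     | Metric         | QNLI ~Q2 | QNLI Q2~Q3 | QNLI Q3~ | QQP ~Q2 | QQP Q2~Q3 | QQP Q3~
LTP (ours) | Acc.           | 91.21    | 90.02      | 91.81    | 89.42   | 89.51     | 91.37
LTP (ours) | Rel. FLOPs (%) | 55.89    | 55.60      | 56.02    | 55.18   | 56.29     | 58.01
LAT        | Acc.           | 90.87    | 86.12      | 75.37    | 89.20   | 87.27     | 82.17
LAT        | Rel. FLOPs (%) | 56.21    | 46.55      | 35.89    | 55.17   | 46.61     | 34.14
Diff.      | Acc.           | +0.34    | +3.90      | +16.44   | +0.22   | +2.24     | +9.20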

4.4. Ablation Studies

Instead of learning the thresholds, we can set them manually. Because manually searching over the exponentially large search space is intractable, we constrain the search space by assigning linearly rising threshold values across layers, similar to how SpAtten (Wang et al., 2020b) assigns token retain ratios: given the threshold of the final layer θ^(L), the threshold for layer l is set to θ^(l) = θ^(L) · l / L. We plot the accuracy and FLOPs of this manual threshold approach in Figure 4 as black lines. While this approach yields decent results on all downstream tasks, the learned thresholds consistently outperform the manual thresholds under the same FLOPs. This provides empirical evidence for the effectiveness of our threshold learning method.
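The manual schedule θ^(l) = θ^(L) · l / L takes only a few lines. This is a sketch with a hypothetical function name, assuming layers are indexed from 1 to L; the same linearly rising values are what Section A.1 uses to initialize the learned thresholds.

```python
from typing import List


def linear_thresholds(final_threshold: float, num_layers: int) -> List[float]:
    """theta^(l) = theta^(L) * l / L for layers l = 1, ..., L (1-indexed, assumed)."""
    return [final_threshold * layer / num_layers for layer in range(1, num_layers + 1)]


if __name__ == "__main__":
    # 12 layers as in RoBERTa-base; 0.01 is the GLUE initial final-layer threshold in Table A.1.
    for layer, theta in enumerate(linear_thresholds(0.01, 12), start=1):
        print(f"layer {layer:2d}: theta = {theta:.5f}")
```

Only the single scalar θ^(L) is searched, which is what makes the manual approach tractable, at the cost of tying every layer's threshold to the same linear profile.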

4.5. Direct Throughput Measurement on Hardware

We directly measure throughput on real hardware by deploying LTP on an NVIDIA V100 GPU and an Intel Haswell CPU. For inference, we completely remove the pruned tokens and rearrange the retained tokens into a blank-free sequence to obtain a latency gain. One consequence of adaptive pruning, however, is that each sequence ends up with a different pruning pattern and sequence length. As such, a naive hardware implementation of batched inference may require padding all sequences in a batch to the same length (i.e., the maximum sequence length in the batch), which wastes a significant portion of the computation on padding tokens. To avoid this, we use NVIDIA's FasterTransformer (https://github.com/NVIDIA/FasterTransformer) for the GPU implementation, which requires large batch sizes. This framework dynamically removes and inserts padding tokens during inference so that most transformer operations effectively skip the padding tokens, enabling fast inference even when individual sequences have irregular pruned lengths. For the CPU implementation, we find naive batching (i.e., padding sequences to the maximum sequence length) sufficient for good throughput.
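To illustrate why padding wastes work and what blank-free packing buys, here is a small sketch that counts processed token positions under the two batching strategies. All names are hypothetical and this is not FasterTransformer's API. It only models token-wise operations (e.g., the feed-forward layers); attention still has to respect per-sequence boundaries, which the `offsets` returned by `pack` would provide.

```python
from typing import List, Sequence, Tuple


def remove_pruned(tokens: Sequence[int], keep: Sequence[bool]) -> List[int]:
    """Drop pruned tokens and close the gaps, yielding a blank-free sequence."""
    return [tok for tok, k in zip(tokens, keep) if k]


def padded_slots(lengths: Sequence[int]) -> int:
    """Token positions processed when every sequence is padded to the batch maximum."""
    return max(lengths) * len(lengths) if lengths else 0


def packed_slots(lengths: Sequence[int]) -> int:
    """Token positions processed when padding is removed and sequences are packed."""
    return sum(lengths)


def pack(batch: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate variable-length sequences; offsets[i] is where sequence i starts."""
    flat: List[int] = []
    offsets: List[int] = []
    for seq in batch:
        offsets.append(len(flat))
        flat.extend(seq)
    return flat, offsets


if __name__ == "__main__":
    # Hypothetical per-sequence lengths after adaptive pruning.
    lengths = [37, 120, 64, 15]
    wasted = 1 - packed_slots(lengths) / padded_slots(lengths)
    print(f"padding would waste {wasted:.0%} of token positions")
```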

The measured throughput results for different batch sizes are shown in Figure 6. For all experiments, relative throughput is evaluated 3 times on randomly shuffled datasets. LTP achieves up to ~1.9× and ~2.0× throughput improvement for QNLI and QQP on both CPU and GPU, as compared to the baseline. This is similar to the theoretical speedup inferred from the FLOPs reduction reported in Table 1. Importantly, the speedup of LTP increases with larger batch sizes on both CPU and GPU, demonstrating the effectiveness of LTP for batched inference.

4.6. LTP with Quantization and Knowledge Distillation


Here, we show that our token-level pruning method is compatible with other compression methods. In particular, we perform compression experiments that combine LTP with both quantization and knowledge distillation (Hinton et al., 2015). For quantization, we use static uniform symmetric integer quantization (Gholami et al., 2021), which is easy to deploy on commodity hardware with minimal run-time overhead. All model parameters are quantized to 8-bit integers, except for those of the embedding layer, whose bit-width does not affect the inference speed. We then apply knowledge distillation, which helps recover accuracy at high compression ratios. We use the baseline RoBERTa-base model as the teacher and the quantized LTP model as the student. Knowledge is distilled from the teacher into the student through a knowledge distillation loss that matches the output logits of the classification layer and the output representations of the embedding layer of the teacher model to their counterparts in the student model. The training objective is a convex combination of the original loss and the knowledge distillation loss. As shown in Figure 7, by combining quantization and knowledge distillation, we achieve up to 10× reduction in bit operations (BOPs) with less than 2% accuracy degradation compared to FP16 RoBERTa-base. These results empirically show that LTP remains effective when combined with other compression methods.
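A numerical sketch of the two ingredients, assuming a single per-tensor scale, a restricted symmetric range of [-127, 127], and mean-squared error as the distance in the distillation loss. These details are not specified in this section, so the MSE choice, the unweighted sum of the logit and embedding terms, and the mixing weight `alpha` are assumptions; the sketch works on plain Python lists rather than real integer kernels.

```python
from typing import List, Sequence

NUM_BITS = 8
QMAX = 2 ** (NUM_BITS - 1) - 1  # 127: restricted symmetric range [-127, 127]


def symmetric_scale(values: Sequence[float]) -> float:
    """Per-tensor scale computed once ahead of time (static quantization)."""
    max_abs = max(abs(v) for v in values)
    return max_abs / QMAX if max_abs > 0 else 1.0


def quantize(values: Sequence[float], scale: float) -> List[int]:
    """Uniform symmetric quantization: q = clamp(round(x / scale), -QMAX, QMAX)."""
    # Python's round() rounds exact halves to the nearest even integer.
    return [max(-QMAX, min(QMAX, round(v / scale))) for v in values]


def dequantize(q: Sequence[int], scale: float) -> List[float]:
    return [x * scale for x in q]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def training_objective(
    task_loss: float,
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    student_emb: Sequence[float],
    teacher_emb: Sequence[float],
    alpha: float,
) -> float:
    """Convex combination of the task loss and an (assumed) MSE distillation loss."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1] for a convex combination")
    kd_loss = mse(student_logits, teacher_logits) + mse(student_emb, teacher_emb)
    return (1.0 - alpha) * task_loss + alpha * kd_loss
```

With this scheme the dequantized values differ from the originals by at most half a quantization step, which is why a single well-chosen scale per tensor keeps the run-time overhead to one multiply.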

5. Conclusions

In this work, we present Learned Token Pruning (LTP), a fully automated token pruning framework for transformers. LTP only requires comparing token importance scores with threshold values to identify unimportant tokens, thus adding minimal complexity to the original transformer inference. Importantly, the threshold values are learned for each layer during training through a differentiable soft binarized mask that enables backpropagation of gradients to the threshold values. Compared to state-of-the-art token pruning methods, LTP achieves up to ~2.5% higher accuracy with the same amount of FLOPs. Extensive experiments on GLUE and SQuAD show the effectiveness of LTP, which achieves up to 2.10× FLOPs reduction over the baseline model with less than 1% accuracy degradation. Our preliminary (and not highly optimized) implementation shows up to 1.9× and 2.0× throughput improvement on an Intel Haswell CPU and an NVIDIA V100 GPU. Furthermore, LTP exhibits significantly better robustness and consistency across different input sequence lengths.

References

  • Bai et al. (2020) Haoli Bai, Wei Zhang, Lu Hou, Lifeng Shang, Jing Jin, Xin Jiang, Qun Liu, Michael Lyu, and Irwin King. 2020. BinaryBERT: Pushing the Limit of BERT Quantization. arXiv preprint arXiv:2012.15701 (2020).
  • Bhandare et al. (2019) Aishwarya Bhandare, Vamsi Sripathi, Deepthi Karkada, Vivek Menon, Sun Choi, Kushal Datta, and Vikram Saletore. 2019. Efficient 8-bit quantization of transformer neural machine language translation model. arXiv preprint arXiv:1906.00532 (2019).
  • Cer et al. (2017) Daniel Cer, Mona Diab, Eneko Agirre, Inigo Lopez-Gazpio, and Lucia Specia. 2017. Semeval-2017 task 1: Semantic textual similarity-multilingual and cross-lingual focused evaluation. arXiv preprint arXiv:1708.00055 (2017).
  • Chen et al. (2020) Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Zhangyang Wang, and Michael Carbin. 2020. The lottery ticket hypothesis for pre-trained BERT networks. arXiv preprint arXiv:2007.12223 (2020).
  • Child et al. (2019) Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. 2019. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509 (2019).
  • Dagan et al. (2005) Ido Dagan, Oren Glickman, and Bernardo Magnini. 2005. The PASCAL recognising textual entailment challenge. In Machine Learning Challenges Workshop. Springer, 177–190.
  • Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018).
  • Dolan and Brockett (2005) William B Dolan and Chris Brockett. 2005. Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing (IWP2005).
  • Fan et al. (2019) Angela Fan, Edouard Grave, and Armand Joulin. 2019. Reducing transformer depth on demand with structured dropout. arXiv preprint arXiv:1909.11556 (2019).
  • Fan et al. (2020) Angela Fan, Pierre Stock, Benjamin Graham, Edouard Grave, Rémi Gribonval, Hervé Jégou, and Armand Joulin. 2020. Training with quantization noise for extreme model compression. arXiv e-prints (2020), arXiv–2004.
  • Frankle and Carbin (2018) Jonathan Frankle and Michael Carbin. 2018. The lottery ticket hypothesis: Finding sparse, trainable neural networks. arXiv preprint arXiv:1803.03635 (2018).
  • Gholami et al. (2021) Amir Gholami, Sehoon Kim, Zhen Dong, Zhewei Yao, Michael W Mahoney, and Kurt Keutzer. 2021. A survey of quantization methods for efficient neural network inference. arXiv preprint arXiv:2103.13630 (2021).
  • Goyal et al. (2020) Saurabh Goyal, Anamitra Roy Choudhury, Saurabh Raje, Venkatesan Chakaravarthy, Yogish Sabharwal, and Ashish Verma. 2020. Power-bert: Accelerating bert inference via progressive word-vector elimination. In International Conference on Machine Learning. PMLR, 3690–3699.
  • Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015).
  • Hou et al. (2020) Lu Hou, Zhiqi Huang, Lifeng Shang, Xin Jiang, Xiao Chen, and Qun Liu. 2020. Dynabert: Dynamic bert with adaptive width and depth. arXiv preprint arXiv:2004.04037 (2020).
  • Iandola et al. (2020) Forrest N Iandola, Albert E Shaw, Ravi Krishna, and Kurt W Keutzer. 2020. SqueezeBERT: What can computer vision teach NLP about efficient neural networks? arXiv preprint arXiv:2006.11316 (2020).
  • Iyer et al. (2017) Shankar Iyer, Nikhil Dandekar, and Kornl Csernai. 2017. First Quora Dataset Release: Question Pairs. (2017). URL https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs (2017).
  • Jiao et al. (2019) Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2019. Tinybert: Distilling bert for natural language understanding. arXiv preprint arXiv:1909.10351 (2019).
  • Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. 2020. Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning. PMLR, 5156–5165.
  • Khetan and Karnin (2020) Ashish Khetan and Zohar Karnin. 2020. schubert: Optimizing elements of bert. arXiv preprint arXiv:2005.06628 (2020).
  • Kim and Cho (2020) Gyuwan Kim and Kyunghyun Cho. 2020. Length-Adaptive Transformer: Train Once with Length Drop, Use Anytime with Search. arXiv preprint arXiv:2010.07003 (2020).
  • Kim et al. (2021) Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W Mahoney, and Kurt Keutzer. 2021. I-BERT: Integer-only BERT Quantization. International conference on machine learning (2021).
  • Kitaev et al. (2020) Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. 2020. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451 (2020).
  • Lagunas et al. (2021) François Lagunas, Ella Charlaix, Victor Sanh, and Alexander M Rush. 2021. Block pruning for faster transformers. arXiv preprint arXiv:2109.04838 (2021).
  • Lan et al. (2019) Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. 2019. Albert: A lite bert for self-supervised learning of language representations. arXiv preprint arXiv:1909.11942 (2019).
  • Li et al. (2020) Bingbing Li, Zhenglun Kong, Tianyun Zhang, Ji Li, Zhengang Li, Hang Liu, and Caiwen Ding. 2020. Efficient transformer-based large scale language representations using hardware-friendly block structured pruning. arXiv preprint arXiv:2009.08065 (2020).
  • Lin et al. (2020) Zi Lin, Jeremiah Zhe Liu, Zi Yang, Nan Hua, and Dan Roth. 2020. Pruning Redundant Mappings in Transformer Models via Spectral-Normalized Identity Prior. arXiv preprint arXiv:2010.01791 (2020).
  • Liu et al. (2019) Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. RoBERTa: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692 (2019).
  • Liu et al. (2021) Zejian Liu, Fanrong Li, Gang Li, and Jian Cheng. 2021. EBERT: Efficient BERT Inference with Dynamic Structured Pruning. In Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021. 4814–4823.
  • Michel et al. (2019) Paul Michel, Omer Levy, and Graham Neubig. 2019. Are sixteen heads really better than one? arXiv preprint arXiv:1905.10650 (2019).
  • Prasanna et al. (2020) Sai Prasanna, Anna Rogers, and Anna Rumshisky. 2020. When BERT plays the lottery, all tickets are winning. arXiv preprint arXiv:2005.00561 (2020).
  • Press et al. (2021) Ofir Press, Noah A Smith, and Mike Lewis. 2021. Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. arXiv preprint arXiv:2108.12409 (2021).
  • Rajpurkar et al. (2018) Pranav Rajpurkar, Robin Jia, and Percy Liang. 2018. Know what you don't know: Unanswerable questions for SQuAD. arXiv preprint arXiv:1806.03822 (2018).
  • Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. SQuAD: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250 (2016).
  • Roy et al. (2021) Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. 2021. Efficient content-based sparse attention with routing transformers. Transactions of the Association for Computational Linguistics 9 (2021), 53–68.
  • Sajjad et al. (2020) Hassan Sajjad, Fahim Dalvi, Nadir Durrani, and Preslav Nakov. 2020. On the Effect of Dropping Layers of Pre-trained Transformer Models. arXiv preprint arXiv:2004.03844 (2020).
  • Sanh et al. (2019) Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108 (2019).
  • Sanh et al. (2020) Victor Sanh, Thomas Wolf, and Alexander M Rush. 2020. Movement pruning: Adaptive sparsity by fine-tuning. arXiv preprint arXiv:2005.07683 (2020).
  • Shen et al. (2020) Sheng Shen, Zhen Dong, Jiayu Ye, Linjian Ma, Zhewei Yao, Amir Gholami, Michael W Mahoney, and Kurt Keutzer. 2020. Q-BERT: Hessian Based Ultra Low Precision Quantization of BERT. In AAAI. 8815–8821.
  • Socher et al. (2013) Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D Manning, Andrew Y Ng, and Christopher Potts. 2013. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 conference on empirical methods in natural language processing. 1631–1642.
  • Sun et al. (2019) Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for bert model compression. arXiv preprint arXiv:1908.09355 (2019).
  • Sun et al. (2020) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 2020. Mobilebert: a compact task-agnostic bert for resource-limited devices. arXiv preprint arXiv:2004.02984 (2020).
  • Tang et al. (2019) Raphael Tang, Yao Lu, Linqing Liu, Lili Mou, Olga Vechtomova, and Jimmy Lin. 2019. Distilling task-specific knowledge from BERT into simple neural networks. arXiv preprint arXiv:1903.12136 (2019).
  • Tay et al. (2020) Yi Tay, Dara Bahri, Liu Yang, Donald Metzler, and Da-Cheng Juan. 2020. Sparse sinkhorn attention. In International Conference on Machine Learning. PMLR, 9438–9447.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in neural information processing systems. 5998–6008.
  • Voita et al. (2019) Elena Voita, David Talbot, Fedor Moiseev, Rico Sennrich, and Ivan Titov. 2019. Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned. arXiv preprint arXiv:1905.09418 (2019).
  • Vyas et al. (2020) Apoorv Vyas, Angelos Katharopoulos, and François Fleuret. 2020. Fast transformers with clustered attention. Advances in Neural Information Processing Systems 33 (2020).
  • Wang et al. (2018) Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. 2018. GLUE: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461 (2018).
  • Wang et al. (2020b) Hanrui Wang, Zhekai Zhang, and Song Han. 2020b. SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning. arXiv preprint arXiv:2012.09852 (2020).
  • Wang et al. (2020a) Sinong Wang, Belinda Li, Madian Khabsa, Han Fang, and Hao Ma. 2020a. Linformer: Self-Attention with Linear Complexity. arXiv preprint arXiv:2006.04768 (2020).
  • Wang et al. (2019) Ziheng Wang, Jeremy Wohlwend, and Tao Lei. 2019. Structured pruning of large language models. arXiv preprint arXiv:1910.04732 (2019).
  • Williams et al. (2017) Adina Williams, Nikita Nangia, and Samuel R Bowman. 2017. A broad-coverage challenge corpus for sentence understanding through inference. arXiv preprint arXiv:1704.05426 (2017).
  • Yao et al. (2021) Zhewei Yao, Linjian Ma, Sheng Shen, Kurt Keutzer, and Michael W Mahoney. 2021. MLPruning: A Multilevel Structured Pruning Framework for Transformer-based Models. arXiv preprint arXiv:2105.14636 (2021).
  • Ye et al. (2021) Deming Ye, Yankai Lin, Yufei Huang, and Maosong Sun. 2021. TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference. arXiv preprint arXiv:2105.11618 (2021).
  • Zadeh et al. (2020) Ali Hadi Zadeh, Isak Edo, Omar Mohamed Awad, and Andreas Moshovos. 2020. Gobo: Quantizing attention-based nlp models for low latency and energy efficient inference. In 2020 53rd Annual IEEE/ACM International Symposium on Microarchitecture (MICRO). IEEE, 811–824.
  • Zafrir et al. (2019) Ofir Zafrir, Guy Boudoukh, Peter Izsak, and Moshe Wasserblat. 2019. Q8BERT: Quantized 8bit bert. arXiv preprint arXiv:1910.06188 (2019).
  • Zaheer et al. (2020) Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. 2020. Big bird: Transformers for longer sequences. arXiv preprint arXiv:2007.14062 (2020).
  • Zhang et al. (2020) Wei Zhang, Lu Hou, Yichun Yin, Lifeng Shang, Xiao Chen, Xin Jiang, and Qun Liu. 2020. Ternarybert: Distillation-aware ultra-low bit bert. arXiv preprint arXiv:2009.12812 (2020).
  • Zhao et al. (2020) Mengjie Zhao, Tao Lin, Fei Mi, Martin Jaggi, and Hinrich Schütze. 2020. Masking as an efficient alternative to finetuning for pretrained language models. arXiv preprint arXiv:2004.12406 (2020).

Appendix A Appendix

A.1. Training Details

The training procedure of LTP consists of two separate stages: soft pruning followed by hard pruning. For soft pruning, we train both the model parameters and the thresholds on downstream tasks for 1 to 10 epochs, depending on the dataset size. We find it effective to initialize the thresholds with linearly rising values, as described in Section 4.4, with a fixed threshold for the final layer. For all experiments, we search for the optimal temperature T in the search space {1, 2, 5, 10, 20}e-4 and vary λ from 0.001 to 0.4 to control the number of tokens to be pruned (and thus the FLOPs). We then fix the thresholds and perform additional training with hard pruning to fine-tune only the model parameters. More detailed hyperparameter settings for GLUE and SQuAD 2.0 are listed in Table A.1.
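The two training stages differ in how the keep/prune decision is made. Below is a sketch assuming the soft mask takes the form sigmoid((s - θ)/T), where s is a token's importance score, θ the layer threshold, and T the temperature searched above; the exact definition is in the paper's method section, outside this excerpt, and the function names are hypothetical. The sketch checks numerically that a gradient reaches θ, which is what makes the thresholds learnable during soft pruning.

```python
import math


def soft_mask(score: float, threshold: float, temperature: float) -> float:
    """Soft keep value sigmoid((score - threshold) / temperature) (assumed form)."""
    z = (score - threshold) / temperature
    if z >= 0:  # numerically stable sigmoid for large |z|
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_mask_grad_threshold(score: float, threshold: float, temperature: float) -> float:
    """Analytic d(soft_mask)/d(threshold) = -m * (1 - m) / temperature."""
    m = soft_mask(score, threshold, temperature)
    return -m * (1.0 - m) / temperature


def hard_mask(score: float, threshold: float) -> float:
    """Hard-pruning stage: thresholds are frozen and tokens scoring below them are dropped."""
    return 1.0 if score >= threshold else 0.0


if __name__ == "__main__":
    theta, temp = 0.01, 1e-3  # temp = 10e-4, one value from the search space above
    for s in (0.005, 0.0095, 0.01, 0.0105, 0.015):
        print(f"score={s:.4f}  soft={soft_mask(s, theta, temp):.3f}  hard={hard_mask(s, theta):.0f}")
```

Smaller temperatures make the soft mask closer to the hard 0/1 decision used once the thresholds are frozen, at the cost of gradients that are concentrated on tokens whose scores lie near the threshold.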

SpAtten is trained following the implementation details in its paper: the first three layers retain all tokens, and the remaining layers are assigned linearly decaying token retain ratios until the final token retain ratio is reached at the last layer. We vary the final token retain ratio from 1.0 to -1.0 (all tokens are pruned for non-positive retain ratios) to control the FLOPs of SpAtten. For both LTP and SpAtten, we use learning rates of {0.5, 1, 2}e-5, except for the soft pruning stage of LTP, where we use 2e-5. We follow the optimizer setting of RoBERTa (Liu et al., 2019) and use a batch size of 64 for all experiments.
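A sketch of the SpAtten baseline schedule as described above. The exact interpolation endpoints are an assumption: here the decay starts right after the three full layers and reaches `final_ratio` exactly at the last layer, and negative intermediate ratios are clamped to zero to model "prune all tokens". The function names and the floor rounding in `tokens_retained` are hypothetical.

```python
from typing import List


def spatten_retain_ratios(num_layers: int, final_ratio: float, num_full_layers: int = 3) -> List[float]:
    """Per-layer token retain ratios for the SpAtten baseline (assumed interpolation)."""
    decay_layers = num_layers - num_full_layers
    ratios: List[float] = []
    for layer in range(num_layers):
        if layer < num_full_layers:
            ratio = 1.0  # the first layers retain every token
        else:
            step = layer - num_full_layers + 1  # 1, 2, ..., decay_layers
            ratio = 1.0 + (final_ratio - 1.0) * step / decay_layers
        ratios.append(max(ratio, 0.0))  # a non-positive ratio prunes all tokens
    return ratios


def tokens_retained(seq_len: int, ratio: float) -> int:
    """Number of tokens kept for a given ratio (floor rounding is an assumption)."""
    return int(seq_len * ratio)
```

Unlike LTP's thresholds, this schedule fixes the fraction of tokens kept at every layer regardless of the input, which is the property Section 4.3 examines.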

LAT is trained using the same hyperparameter and optimizer settings as in its paper, except for the length drop probabilities: to search more extensively over more aggressive pruning configurations, we use length drop probabilities of 0.25, 0.3, 0.35, and 0.4 instead of 0.2 in the original setting.


A.2. Computation Efficiency Comparison

Here we compare the efficiency of the top-k operation versus the threshold operation. To do this, we use a batch size of 32 and average the latency over 1000 independent runs. For each sequence length, we test five different token retain ratios from 10% to 50% (e.g., a 10% token retain ratio means selecting the top-k 10% of tokens from the input sequence).

With the above setting, we directly measure the latency of these two operations on an Intel Haswell CPU and report the results in Figure A.1. For the top-k operation, latency increases noticeably as the token retain ratio and sequence length grow, whereas this is not an issue for our threshold pruning method, which only requires a comparison operation. More importantly, the top-k operation incurs a large latency overhead: it is up to 7.4× and 33.4× slower than the threshold operation for sequence lengths of 128 and 1024, respectively. The inefficiency of top-k is further confirmed by Wang et al. (2020b), who report only a 1.1× speedup for GPT-2 without the top-k hardware engine that they developed.
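The structural difference between the two operations can be seen in a pure-Python sketch: top-k must order scores across the whole sequence (here via `heapq.nlargest`, O(n log k)), whereas the threshold test inspects each score independently in a single O(n) pass. Timings from this sketch in CPython are not comparable to the measurements in Figure A.1, and the function names are hypothetical.

```python
import heapq
import random
import timeit
from typing import List, Sequence


def select_topk(scores: Sequence[float], retain_ratio: float) -> List[int]:
    """Top-k selection: keep the k highest-scoring tokens, then restore sequence order."""
    k = max(1, int(len(scores) * retain_ratio))
    top = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return sorted(top)


def select_threshold(scores: Sequence[float], threshold: float) -> List[int]:
    """Threshold selection: one elementwise comparison per token, no ordering needed."""
    return [i for i, s in enumerate(scores) if s >= threshold]


if __name__ == "__main__":
    random.seed(0)
    scores = [random.random() for _ in range(1024)]
    t_topk = timeit.timeit(lambda: select_topk(scores, 0.5), number=200)
    t_thr = timeit.timeit(lambda: select_threshold(scores, 0.5), number=200)
    print(f"top-k: {t_topk:.4f}s  threshold: {t_thr:.4f}s  (CPython, illustrative only)")
```

When the threshold happens to equal the k-th largest score (and scores are distinct), both functions return the same indices; a learned threshold removes the need to find that value for every sequence.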

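Table A.1
Stage        | Hyperparameter                   | GLUE                 | SQuAD 2.0
Soft pruning | epochs                           | 1 - 10               | 1
Soft pruning | learning rate                    | 2e-5                 | 2e-5
Soft pruning | T                                | {1, 2, 5, 10, 20}e-4 | {1, 10}e-4
Soft pruning | λ                                | 0.001 - 0.2          | 0.001 - 0.4
Soft pruning | initial final-layer threshold    | 0.01                 | 0.003
Hard pruning | epochs                           | 10                   | 5
Hard pruning | learning rate                    | {0.5, 1, 2}e-5       | {0.5, 1, 2}e-5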

A.3. Discussion


A.3.1. Example Sequence Length Trajectories

Figure A.2 shows how the pruned sequence length decreases across layers for input sequences of varying lengths. For LAT, the token pruning configuration is fixed for all sequences in the dataset. In LTP, token pruning can be more or less aggressive depending on the sequence content and the number of important tokens in the sequence. On average, LTP processes 25.86% fewer tokens per layer than LAT for MNLI-m and 12.08% fewer tokens for SST-2. For both LTP and LAT, the models are trained to have a 1% accuracy drop compared to the baseline.
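The per-layer token reductions quoted above compare average retained lengths. Below is a sketch of that arithmetic on hypothetical trajectories (one list of per-layer lengths per input sequence); averaging over all (sequence, layer) pairs is an assumption about how the reported percentages are aggregated.

```python
from typing import Sequence


def mean_tokens_per_layer(trajectories: Sequence[Sequence[int]]) -> float:
    """Average retained sequence length over all (sequence, layer) pairs."""
    total = sum(sum(traj) for traj in trajectories)
    count = sum(len(traj) for traj in trajectories)
    return total / count


def percent_fewer(ours: float, reference: float) -> float:
    """How many percent fewer tokens `ours` processes than `reference`."""
    return 100.0 * (1.0 - ours / reference)


if __name__ == "__main__":
    # Hypothetical 4-layer trajectories for two input sequences.
    ltp = [[10, 8, 6, 4], [20, 15, 10, 5]]
    lat = [[10, 9, 8, 7], [20, 18, 16, 14]]
    print(f"{percent_fewer(mean_tokens_per_layer(ltp), mean_tokens_per_layer(lat)):.2f}% fewer tokens")
```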

A.3.2. Unbiased Token Pruning for Various Sequence Lengths

Figure A.3 shows the distributions of initial sequence lengths for sequences that are correctly classified and for those that are not. We see that for multiple tasks, there is no significant correlation between the length of the sequence and the accuracy of the pruned models. Importantly, this suggests that our method is not biased towards being more accurate on longer or shorter sequences.
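Figure A.3 compares the two length distributions visually. As an additional way to quantify "no significant correlation" (a hypothetical addition, not the paper's analysis), one could compute the point-biserial correlation, i.e., the Pearson correlation between initial length and a 0/1 correctness indicator; values near zero indicate no linear relationship. The sketch below does not include a significance test, and its function names are illustrative.

```python
import math
from typing import Dict, List, Sequence


def split_by_correctness(lengths: Sequence[int], correct: Sequence[bool]) -> Dict[bool, List[int]]:
    """Initial sequence lengths grouped by whether the pruned model was correct."""
    groups: Dict[bool, List[int]] = {True: [], False: []}
    for length, ok in zip(lengths, correct):
        groups[bool(ok)].append(length)
    return groups


def point_biserial(lengths: Sequence[int], correct: Sequence[bool]) -> float:
    """Pearson correlation between length and 0/1 correctness (point-biserial r)."""
    ys = [1.0 if ok else 0.0 for ok in correct]
    n = len(lengths)
    mx = sum(lengths) / n
    my = sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(lengths, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in lengths))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx == 0.0 or sy == 0.0:
        raise ValueError("correlation undefined when lengths or outcomes are constant")
    return cov / (sx * sy)
```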

A.4. Comparison with TR-BERT on GLUE

Unlike LAT and SpAtten, TR-BERT (Ye et al., 2021) does not report results on the GLUE benchmark tasks used in this paper. We attempted to run TR-BERT on the GLUE tasks using the TR-BERT repository (https://github.com/thunlp/TR-BERT), but were unable to get the algorithm to converge to high accuracy, despite varying the learning rate between 1e-6 and 1e-3 and varying α, the parameter that defines the length penalty, over the search space {0.01, 0.05, 0.1, 0.5, 1, 2, 5}. We also varied the number of training epochs based on the number of examples in each task's training set. The authors of TR-BERT note the convergence difficulties of reinforcement learning (RL) when describing the algorithm in their paper.
