Alignment post-training for LLMs: lessons learned making it work

I recently completed a research project where the objective was to get a Language Model (LM) to perform entity and relationship (ER) extraction over noisy, short-form web documents, at scale. Since the task involved processing a large volume of documents every hour, inference speed was key - both time-to-first-token (TTFT) and tokens-per-second (TPS).

The LM was required to read the documents and extract a list of key entities together with succinct descriptions. Then it had to infer both the explicit and implied relationships between these entities and summarise them. Furthermore, the client had very specific requirements about which kinds of entities and relationships they were interested in, and which they were not.

This use case - fast inference on a specific task - is a great example of a situation where customising a small, efficient LM made more sense than using a frontier service. By hosting a small LM, we could engineer the set-up to achieve blazingly fast inference speeds at low cost. Furthermore, by applying post-training techniques, we could bring the task-specific performance of the smaller LM up to a level equivalent to that of the leading frontier models.

In this blog, I’ll explain how we did it and what we learned along the way.

Firstly, what this post is not. There’s a host of material out there covering the theory behind post-training techniques or providing example code. I don’t cover that here. Instead, expect a description of the process we followed, a summary of the learnings and (what I hope is) a helpful bundle of practical tips for anyone else doing the same.

Start with prompt engineering

A rule-of-thumb for language models is that pre-training adds knowledge, fine-tuning controls behaviour and style, and alignment post-training matches user preferences. That said, it’s rarely clear from the outset whether the behaviours you want in your LM can be achieved via inference-time techniques or supervised fine-tuning (SFT), or whether they’ll require an alignment step. As always, work your way up the ladder of complexity, moving to the next stage only when you have confidence that your current technique isn’t going to achieve the results you require.

Therefore, we started with prompt engineering.

We began by taking the time to write - and to successively improve - an excellent prompt template. To evaluate the prompts, we needed some examples of “excellent” outputs - enough examples to show some diversity and to enable us to detect differences in prompt efficacy, but no more than that, since we had to write them all by hand.

We crafted 105 ER extraction exemplars. We started with outputs from a frontier model (using a basic prompt) and edited them until they were perfect examples of what the client was looking for. We then crafted, tested and iterated a base prompt (still using the frontier model, trying to optimise its behaviour). At this stage, our evaluation metric was a simple F1 score, measuring the overlap between the entities and relationships extracted by the model and those in the exemplars (requiring an exact string match).

We could see two failure modes in the results:

  • Sometimes the extracted reports didn’t capture the “essence” of the concepts - i.e. salient facts were missing.

  • Sometimes, entities or relationships were missed, or “irrelevant” (from the client’s perspective) items were extracted.

With the prompt refined, we used DSPy to help us automatically select the 5 best exemplars to be added to the prompt template. (Recall that inference speed was key. Since a long prompt slows things down, when we came to training the local LM, we pared the prompt right back and relied on post-training to steer the behaviour. For now though, richer is better.)
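For illustration, here’s roughly what that DSPy step can look like, assuming a recent DSPy release. The signature, metric and optimiser shown (ExtractER, er_f1, BootstrapFewShot) are stand-ins rather than our exact code, and the exemplar file path is a placeholder:

```python
import json

import dspy
from dspy.teleprompt import BootstrapFewShot

# Assumes dspy.settings.configure(lm=...) has already pointed DSPy at the frontier model.

# Hypothetical signature for the ER extraction task
class ExtractER(dspy.Signature):
    """Extract key entities and relationships from a short web document."""
    document = dspy.InputField()
    report = dspy.OutputField(desc="entities, relationships and short summaries")

def er_f1(example, pred, trace=None):
    # Stand-in for our exact-string-match F1 over extracted entities/relationships
    gold, predicted = set(example.report.splitlines()), set(pred.report.splitlines())
    if not gold or not predicted:
        return 0.0
    tp = len(gold & predicted)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(predicted), tp / len(gold)
    return 2 * precision * recall / (precision + recall)

# The hand-crafted exemplars, loaded from a placeholder JSONL of {"document": ..., "report": ...}
rows = [json.loads(line) for line in open("exemplars.jsonl", encoding="utf-8")]
trainset = [
    dspy.Example(document=row["document"], report=row["report"]).with_inputs("document")
    for row in rows
]

# Select up to 5 demonstrations to bake into the prompt, scored with the metric above
optimizer = BootstrapFewShot(metric=er_f1, max_labeled_demos=5, max_bootstrapped_demos=5)
compiled_extractor = optimizer.compile(dspy.Predict(ExtractER), trainset=trainset)
```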

Fork-in-the-road question: if we now switched the frontier model for the local model, would the quality metrics drop? Yes, they did. Quite a lot, in fact. We played around with the prompt template a bit more but couldn’t resolve the problems. This was the signal that we needed to move on to some fine-tuning.

Constructing the post-training datasets

At this point, we could get our frontier model to generate extraction reports that were pretty close to what we wanted. Now to get the smaller LM to do the same. We started by selecting 1,000 examples from our documents and having the frontier model generate the ER reports.

One thing to remember with fine-tuning is that the behaviours you are teaching are not necessarily preserved when the input data is “out of distribution” - i.e. when documents presented at inference time look very different to those presented at training time. To counter this, when building datasets for SFT or preference learning, we sampled our examples for diversity. There are many ways to do this; here is what we did:

  1. We embedded 100,000 of our documents. (The choice of embedding model is not super-important; just pick something that ranks highly on the Clustering task of the MTEB leaderboard.)

  2. We built a nearest-neighbour index using Spotify’s Annoy library (it’s fast and the underlying algorithm doesn’t require all-to-all comparisons).

  3. We constructed a graph representation of the dataset with documents as nodes and nearest neighbour similarities as edges. We tuned the similarity cut-off to a point where the graph no longer had a giant component but also had as few isolated nodes as possible.

  4. We iterated through the connected components of the graph and picked a single document from each component.

The result was a dataset consisting of highly diverse examples.
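Here is a minimal sketch of that sampling procedure, assuming sentence-transformers for the embeddings, Spotify’s annoy for the index and networkx for the graph step; the model name, similarity cut-off, neighbour count and file path are placeholders to tune:

```python
import networkx as nx
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer

# Placeholder corpus: one document per line
documents = [line.strip() for line in open("documents.txt", encoding="utf-8")]

# 1. Embed the documents (any model that scores well on MTEB Clustering)
embedder = SentenceTransformer("all-MiniLM-L6-v2")        # placeholder model choice
embeddings = embedder.encode(documents, normalize_embeddings=True)

# 2. Approximate nearest-neighbour index (Annoy's "angular" metric ~ cosine)
index = AnnoyIndex(embeddings.shape[1], "angular")
for i, vector in enumerate(embeddings):
    index.add_item(i, vector)
index.build(50)                                           # number of trees

# 3. Graph: documents as nodes, sufficiently-similar neighbours as edges.
#    Tune the cut-off so there is no giant component but few isolated nodes.
SIMILARITY_CUTOFF = 0.80                                  # placeholder value
graph = nx.Graph()
graph.add_nodes_from(range(len(documents)))
for i in range(len(documents)):
    neighbours, distances = index.get_nns_by_item(i, 20, include_distances=True)
    for j, dist in zip(neighbours, distances):
        cosine_similarity = 1 - dist**2 / 2               # convert Annoy's angular distance
        if i != j and cosine_similarity >= SIMILARITY_CUTOFF:
            graph.add_edge(i, j)

# 4. One representative document per connected component
diverse_sample = [documents[min(component)] for component in nx.connected_components(graph)]
```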

N.B. what a sampling process like this won’t address are the additional complications introduced by non-stationary datasets (i.e. where the topics, length or structure of the documents change over time). Unfortunately, there’s no real way to future-proof a fine-tuned model against this, other than to periodically re-run the post-training processes when model drift is detected.

Don’t skip the supervised fine-tuning stage

So, we had a well-diversified set of 1,000 examples of reasonable completions for the ER extraction task. Alignment post-training was going to take a lot of effort, so it made sense to see whether we could improve our local LM’s behaviours enough simply using a standard SFT set-up.

As I mentioned previously, there’s plenty of content out there on fine-tuning and I want to get on to the alignment post-training, so I’ll just leave a few notes on the process here, followed by a sketch of the set-up.

  • During the fine-tuning, we’re minimising the cross-entropy loss between the next-token distributions predicted by the language-model head and the tokens of the target completions. For evaluation, we tracked the score on our evaluation harness (the “gold standard” completions). Although the process improves both the entity extraction and the summarisation capability, at this point it was easier for us simply to track the former.

  • In our use case, we weren’t really worried about degrading other capabilities of the LM. If we were, we’d need to mix in additional fine-tuning sources to ensure that we didn’t suffer performance regressions on other tasks.

  • Even so, too much fine-tuning can really hurt the generalisation of performance to out-of-distribution data (which we knew we would have). Therefore, we stopped the process conservatively (after a single epoch, in fact).

  • Fine-tuning did close the gap between the performance of our local LM and the frontier model. Recall that not even the frontier model was perfect - and we’re going to explore how alignment post-training can help - but it always makes sense to get as much improvement as you can from a simpler set-up like SFT before moving on to something more complicated.
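For completeness, a minimal sketch of that SFT set-up using the HuggingFace trl/peft stack. The model name, hyperparameters and dataset file are illustrative, and the SFTTrainer/SFTConfig interface has shifted between trl versions, so treat this as a starting point rather than our exact code:

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Placeholder file: one record per example with a "text" field containing the
# (pared-back) prompt plus the gold ER report, rendered with the chat template
dataset = load_dataset("json", data_files="sft_er_reports.jsonl", split="train")

peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

training_args = SFTConfig(
    output_dir="sft-er-extractor",
    num_train_epochs=1,                 # stop conservatively: a single epoch
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
)

trainer = SFTTrainer(
    model="mistralai/Mistral-7B-Instruct-v0.2",   # placeholder 7B-class base model
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model("sft-er-extractor/final")      # the LoRA adaptor checkpoint
```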

Alignment post-training

This phase of training is referred to by a variety of names, including “reward learning”, “preference tuning” and “preference optimisation”. Simply put, it involves techniques that attempt to steer an LM towards producing generations that will be preferred by end users.

Two motivating factors which might prompt you to try an alignment exercise are:

  1. When optimising a small LM, the AI-generated SFT dataset can only ever make your local model as good as the frontier model; if even the frontier model has failure modes then there is a ceiling on your data quality - unless you create it all by hand…

  2. It’s much easier for human annotators to simply contrast and rate model outputs than it is for them to edit them by hand - especially when you’re dealing with a lot of examples.

The most well-known approach to alignment uses reinforcement learning, commonly with an algorithm called Proximal Policy Optimisation (PPO) (see Training language models to follow instructions with human feedback from the OpenAI team for a canonical overview). The process is known as “reinforcement learning from human feedback” - or simply RLHF. RLHF was broadly responsible for transforming OpenAI’s GPT-3 model – an impressive but wayward curiosity which generated reams of made-up text – into ChatGPT – the essential companion for anyone who likes all questions to be answered with bullet points followed by a patronizing equivocation.

RLHF using PPO has a straightforward-enough recipe (note this is not the only way to do it, but it is how you’ll do it via the HuggingFace ecosystem):

  1. You start with a large number of well diversified inputs;

  2. You have the LM (the “policy model”) generate multiple outputs for each input;

  3. You get your annotation team to score the outputs using a preference-ranking system;

  4. You train a “reward model” to be capable of predicting which output the humans will prefer.
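Before getting to the loop: here is a minimal sketch of step 4 (training the reward model) with trl’s RewardTrainer. The checkpoint paths, hyperparameters and preference file are placeholders, and depending on your trl version the tokenizer argument may instead be called processing_class:

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

base = "sft-er-extractor/merged"       # placeholder: the SFT checkpoint merged into the base model
tokenizer = AutoTokenizer.from_pretrained(base)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# A single-logit sequence classifier: the scalar output is the reward
model = AutoModelForSequenceClassification.from_pretrained(base, num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id

def tokenize_pair(row):
    chosen = tokenizer(row["prompt"] + row["chosen"], truncation=True, max_length=2048)
    rejected = tokenizer(row["prompt"] + row["rejected"], truncation=True, max_length=2048)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }

# Placeholder file of annotated pairs: {"prompt": ..., "chosen": ..., "rejected": ...}
pairs = load_dataset("json", data_files="preference_pairs.jsonl", split="train")
pairs = pairs.map(tokenize_pair, remove_columns=pairs.column_names)

trainer = RewardTrainer(
    model=model,
    args=RewardConfig(output_dir="er-reward-model", num_train_epochs=1),
    train_dataset=pairs,
    tokenizer=tokenizer,               # "processing_class" in newer trl releases
    peft_config=LoraConfig(r=16, lora_alpha=32, task_type="SEQ_CLS"),
)
trainer.train()
trainer.save_model("er-reward-model/final")
```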

And now you do the following, in a loop, until the improvements level off or you run out of cloud credits:

  1. You select a set of example inputs;

  2. You generate two outputs per input;

  3. You use the reward model to predict which output will be preferred;

  4. You use PPO to adjust the policy model’s weights, steering it towards the outputs with higher reward.

PPO has some downsides, however. Principal among these is that it’s memory hungry. (Typically, your reward model is your full LM; you just swap the language head for a reward head.) We were using low-rank adaptors for both the reward model and the policy model, so this wasn’t so much a problem, but it is an overhead. (You’ll also need to maintain a copy of the original fine-tuned checkpoint, which constitutes a “reference model” used by the PPO algorithm to prevent over-fitting to the reward model.)

All this means that even in a parameter-efficient training regime, you’ll still be hot-swapping adaptors during the RLHF process which adds to the memory and compute requirements of the whole exercise.
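To make the loop above concrete, here is a compressed sketch using the classic trl PPOTrainer interface (pre-0.12 releases; the API was later reworked substantially). The checkpoint paths, generation settings and reward scoring are placeholders consistent with the earlier sketches rather than our exact code:

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

sft_checkpoint = "sft-er-extractor/merged"        # placeholder: merged SFT checkpoint
tokenizer = AutoTokenizer.from_pretrained(sft_checkpoint)

# Policy model with a value head, plus a frozen reference copy for the KL penalty
policy = AutoModelForCausalLMWithValueHead.from_pretrained(sft_checkpoint)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(sft_checkpoint)

# Placeholder reward scoring: the single-logit classifier from the previous sketch
reward_model = AutoModelForSequenceClassification.from_pretrained("er-reward-model/final")
reward_model.eval()

def reward_score(query, response):
    inputs = tokenizer(query + response, return_tensors="pt", truncation=True, max_length=2048)
    with torch.no_grad():
        return reward_model(**inputs).logits[0, 0]

# 1. A well-diversified prompt dataset with a "query" column (placeholder file)
dataset = load_dataset("json", data_files="ppo_prompts.jsonl", split="train")
dataset = dataset.map(lambda row: {"input_ids": tokenizer.encode(row["query"])})
dataset.set_format(type="torch", columns=["input_ids"], output_all_columns=True)

config = PPOConfig(learning_rate=1e-5, batch_size=16, mini_batch_size=4)
ppo_trainer = PPOTrainer(
    config, policy, ref_model, tokenizer,
    dataset=dataset,
    data_collator=lambda data: {key: [d[key] for d in data] for key in data[0]},
)

gen_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9,
              "pad_token_id": tokenizer.eos_token_id}

for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]

    # 2. sample fresh completions from the current policy (PPO is on-policy)
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **gen_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # 3. score each (prompt, completion) pair with the reward model
    rewards = [reward_score(q, r) for q, r in zip(batch["query"], batch["response"])]

    # 4. PPO update, constrained by the KL penalty against the reference model
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```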

A recent(ish) alternative to RLHF with PPO is the “Direct Preference Optimisation” (DPO) algorithm, originally proposed by researchers at Stanford. DPO is a rather elegant re-parameterisation of the RL set-up, resulting in an optimisation objective with an implicit rather than an explicit reward function.
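For reference (notation varies a little from paper to paper), the DPO objective from the original paper replaces the explicit reward model with the following loss, where x is the prompt, y_w and y_l are the preferred and rejected completions, π_θ is the policy being trained, π_ref the frozen reference model, σ the logistic function and β a scale on the implicit reward:

$$\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]$$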

With no need to train a reward model, nor to keep one in memory during training, DPO is undeniably much more efficient. But is it as good?

DPO vs RLHF: the showdown

Irritatingly, the situation is not very clear cut. Since the publication of Direct Preference Optimization: Your Language Model is Secretly a Reward Model in mid-2023, various teams of researchers have weighed in on the DPO vs. PPO debate - often with conflicting and nuanced results. For example:

  • In Reward Model Learning vs. Direct Policy Optimization: A Comparative Analysis of Learning from Human Preferences, authors Nika et al. perform a mathematical analysis of both RLHF and DPO. Among the results is an important one on non-realisable rewards - i.e. when the reward function cannot precisely model human preferences. This can (actually, let’s say will) happen when you ask humans to preference-score model outputs. Some level of inconsistency is almost guaranteed for all but the most trivial of comparisons. In these situations, RLHF solutions have an irreducible bias whereas DPO solutions can have an asymptotically reducing bias, achieved by simply increasing the size of the preference-learning dataset.

  • In Unpacking DPO and PPO: Disentangling Best Practices for Learning from Preference Feedback, authors Ivison et al. conduct a set of controlled experiments on reward learning. Among the results are two which stand out:

    • Larger reward modelling datasets lead to improved downstream performance on the specific task being demonstrated.

    • PPO outperforms DPO across the range of reward-learning datasets and benchmarks being tested.

  • Recently (at the time of writing - and too late for our project) the Allen Institute for AI have published a detailed report into post-training recipes, in which they noted that PPO reached similar levels of performance to DPO but took 14 times as much compute.

So which did we choose? Our interpretation of the available evidence was that - although (theoretically speaking) DPO ought to perform better - in practice it’s often hard to realise these advantages and PPO can offer superior results.

We chose to use RLHF with PPO. To summarise, the key factors in this decision were:

  1. Project timescales did not allow for a full schedule of experimentation and hyperparameter tuning for both methods;

  2. The balance of both anecdotal and rigorous reports we read suggested that PPO - although sometimes unstable - often gave better results than DPO, despite the additional encumbrance of having to train a reward model;

  3. Our local model was a 7B parameter class GPT, small enough that the additional memory and compute overhead of using a reward model was not an issue.

So what did we learn as a result?

Lesson 1: put the effort into getting excellent preference annotations

There is a world of difference between the results you’ll get from a well-run annotation project and those from a poorly run one. There are plenty of experts and companies out there who can help you run a good data annotation project (ahem, consider contacting me if interested). Our experience of RLHF annotation suggests that the following are particularly important:

  1. Make sure you write a comprehensive brief for your annotators. Have a trial annotator read it, then check that their interpretation of the task exactly matches your intention by having them do some annotations and walk you through their thinking. Use this exercise to properly refine the brief - for example adding explicit instructions for handling edge cases.

  2. In general, it’s better to ask an annotator to score multiple aspects of a pair of completions, rather than to give a single preference score. It forces the annotators to explicitly and separately consider different aspects of the outputs. (For us, we asked questions along the lines of “which result produced a more comprehensive and accurate list of the significant entities?”, “which result did a better job of inferring the key relationships?” and “which result produced more concise and accurate summaries?”. We then simply summed the aspect-based preferences to get an overall preference.)

  3. Annotation is a cognitively demanding task and can quickly lead to exhaustion. Typically, annotators will just crack on, but the quality of their work will degrade as they tire. To help mitigate this, it’s well worth writing a simple annotation app tuned for your task (e.g. using Gradio) to assist them. Simple tricks like parsing, highlighting and structuring the texts and coordinating the layout will make it much easier on the eye, and you’ll really see the difference in quality. As an added bonus, an app will make it easy to track the rate of annotations and release new batches. This is helpful for managing people and budgets, and for forecasting the number of samples you’ll generate.

  4. Don’t let the full team of annotators loose until they’ve been trained and have completed some trial annotations. Review the trial annotations as a group and discuss. You will discover problems, further edge cases and factors that you hadn’t considered. This pre-work will improve the quality of the results.

  5. Be sure to track the inter-annotator agreement (IAA) during the annotator training and during the project as a whole. With human preference ranking, you will never get 100% agreement, but it is important to know what the IAA is at the outset (once the annotator training is complete). If it drops during the full exercise, call a halt to proceedings and investigate. The other important fact about the IAA is that it places an upper bound on the performance of your reward model: remember that it cannot pick winning completions more successfully than the degree of agreement within your annotation team!

  6. Finally (and this is an obvious point but easy to forget), if you are generating three or more completions per example, make sure you split your examples into training, validation and test splits before you generate your pairs, not afterwards. For example: if you are generating 4 responses to be preference-ordered, your resulting dataset will contain C(4, 2) = 6 preference pairs. All six of those pairs need to live within the same split to prevent leakage - as sketched just below.
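Here is a minimal sketch of that split-then-pair step; the file layout, split fractions and seed are placeholders:

```python
import json
import random
from itertools import combinations

# Placeholder file: one record per example, {"prompt": ..., "responses": [r1, r2, ...]}
records = [json.loads(line) for line in open("preference_candidates.jsonl", encoding="utf-8")]

random.seed(17)
random.shuffle(records)

# Split at the *example* level first...
n = len(records)
splits = {
    "train": records[: int(0.8 * n)],
    "valid": records[int(0.8 * n): int(0.9 * n)],
    "test": records[int(0.9 * n):],
}

# ...then expand each example's k responses into C(k, 2) pairs, so every pair
# derived from the same example stays inside a single split (no leakage).
pair_splits = {
    name: [
        {"prompt": rec["prompt"], "response_a": a, "response_b": b}
        for rec in recs
        for a, b in combinations(rec["responses"], 2)
    ]
    for name, recs in splits.items()
}
```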

Lesson 2: finalise SFT before conducting preference tuning

First off, make sure that there is a hard end to your supervised fine-tuning process before you start the RLHF phase. The samples you draw for preference annotation should come from the checkpoint of the model that you will be training. If you continue to finesse the LM after you’ve drawn and annotated the RLHF samples, then you’re effectively conducting off-policy reward training and PPO just won’t work as well. We made this mistake early on and had to re-annotate our first batch of samples, which pleased no one…

Lesson 3: tune your sample generation policy

On to the sample generation itself. PPO is an on-policy algorithm, so you’ll be drawing fresh samples at each epoch of the training run. One seemingly underexplored question is this: what should the sampling policy be? With a little intuition and absolutely no theoretical justification, we settled on the following:

  1. For each example, generate a deterministic completion (i.e. with a sampling temperature of 0). Make this response one of your pairs.

  2. Now set the model’s temperature parameter to a reasonably high value and sample your other responses.

  3. Make sure you record the joint log-probability of each response in the resulting dataset.

  4. Shuffle the order of the responses. (i.e. the deterministic response should not always be the first option in the preference dataset!)

To tune the temperature setting, we looked at the relationship between the log-probs of the completions and the preference scores of the trial annotator. We then pushed the temperature as high as it could go before the annotator started to prefer the deterministic result more than 50% of the time.
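For what it’s worth, a sketch of that sampling step with transformers; the checkpoint path, generation settings and helper names are our own stand-ins, and compute_transition_scores is simply the standard way to recover per-token log-probs from generate:

```python
import random

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "sft-er-extractor/merged"        # placeholder: the SFT checkpoint being sampled
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

def generate_with_logprob(prompt: str, temperature: float = 0.0):
    inputs = tokenizer(prompt, return_tensors="pt")
    sampling = {"do_sample": True, "temperature": temperature} if temperature > 0 else {"do_sample": False}
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
        **sampling,
    )
    # Per-token log-probs of the generated tokens; their sum is the joint log-probability
    token_logprobs = model.compute_transition_scores(out.sequences, out.scores, normalize_logits=True)
    completion = tokenizer.decode(out.sequences[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return completion, token_logprobs.sum().item()

def build_candidates(prompt: str, temperature: float = 1.0, n_sampled: int = 1):
    # 1. the deterministic (greedy) completion, then 2. temperature-sampled alternatives
    candidates = [generate_with_logprob(prompt, 0.0)]
    candidates += [generate_with_logprob(prompt, temperature) for _ in range(n_sampled)]
    random.shuffle(candidates)   # 4. the greedy response shouldn't always appear first
    return candidates            # 3. each item carries (completion_text, joint_log_prob)
```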

The intuition here is that you’ll get a better reward model by correctly balancing the diversity of the completions with their likely preference scores. Did this really enhance the end results? I’m not sure, but it’s something I hope to explore further in the future.

Lesson 4: the reward model doesn’t have to be perfect

We did our RLHF using the trl library. Our first attempts yielded a reward model that correctly predicted the preferred completion a little over 60% of the time (vs. 50% for random guessing). With some finessing of the data and the training procedure we were able to raise this to around 75%. It turns out that this was more than sufficient for PPO to yield a dramatic improvement in model behaviour, despite our worries about accuracy. I’ve not come across a detailed analysis of reward model quality, but I thought it would be useful to note that we got good results despite the reward model being highly imperfect.

Lesson 5: track the distribution of reward values

PPO is fiddly and computationally intensive.  trl makes things easier with its multi-adaptor learning, enabling you to hot-swap between the reward adaptor and the RL adaptor.

Technical caveat: if you’re using the HuggingFace ecosystem and you trained an adaptor during the SFT stage, you’ll want to merge the adaptor with the base model (at full precision) before starting PPO. This will ensure that the SFT checkpoint becomes the reference model used by PPO. If you do not - e.g. you simply initialise the policy adaptor using the SFT adaptor - then the reference model will be the base model and this will lead to problems with the PPO updates.
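A sketch of that merge step with peft (the base model name and paths are placeholders); merge_and_unload folds the LoRA weights into the base model, so the saved checkpoint can act as both the PPO starting point and the frozen reference model:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_name = "mistralai/Mistral-7B-Instruct-v0.2"          # placeholder base model
base = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype=torch.float32)  # merge at full precision

# Fold the SFT LoRA adaptor into the base weights
model = PeftModel.from_pretrained(base, "sft-er-extractor/final")   # placeholder adaptor path
merged = model.merge_and_unload()

merged.save_pretrained("sft-er-extractor/merged")
AutoTokenizer.from_pretrained(base_name).save_pretrained("sft-er-extractor/merged")
```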

One thing we discovered was the need to pay attention to the distribution of reward values generated during PPO. We kept a running log of these values and compared their distribution with the one we saw when evaluating the reward model. We noticed that these distributions would sometimes diverge, with the reward model starting to score more of the policy’s completions with values that would have been outliers on its training set. When this happened, PPO would swiftly go awry (either crashing due to numerical instabilities or yielding a model which produced nonsense).
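A sketch of the kind of check we mean, assuming you’ve kept the reward scores the reward model produced on its own held-out evaluation set; the file path, quantiles and threshold are placeholders to tune:

```python
import numpy as np

# Reward scores the reward model assigned to its own held-out evaluation set (placeholder path)
reference_rewards = np.load("reward_model_eval_rewards.npy")
low, high = np.quantile(reference_rewards, [0.01, 0.99])

def fraction_outlying(batch_rewards) -> float:
    """Share of a PPO batch scored outside the reward model's familiar range."""
    rewards = np.asarray(batch_rewards)
    return float(np.mean((rewards < low) | (rewards > high)))

# Inside the PPO loop, after scoring each batch, something along these lines:
#     if fraction_outlying([r.item() for r in rewards]) > 0.2:   # placeholder threshold
#         stop, restore the latest checkpoint and lower the learning rate
```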

Sometimes, restoring from the most recent checkpoint and restarting PPO with a more aggressively decaying learning-rate schedule helped. But eventually, we learned that a more reliable approach was to perform RLHF in iterations.

Lesson 6: stabilise the PPO updates by proceeding iteratively

Recall from above that the further your PPO-optimised model diverges from your SFT-optimised model, the less representative your RLHF annotations will be of the model’s current outputs, and so the less reliable your reward model will be.

We ended up running three iterations of RLHF. At each iteration we repeated the initial sampling, annotation, reward training and PPO steps. Despite the overhead involved in repeating the process several times, this did end up helping to stabilise the training. (I’d be interested to hear whether anyone else does the same thing!)

The Payoff

So how did it all pan out?

At the end of the project we did a blind test. We asked our annotators to choose between ER reports generated by the frontier model and ER reports generated by our model. The result: the frontier model’s responses were preferred 27% of the time; ours 49% of the time; and no strong preference was expressed 24% of the time - and all this using 4-bit quantization at inference time!

So that’s very nearly a 2:1 preference ratio. It sounds like an incredible claim given the small size of the LM we were optimising, but let’s just put it into context for a moment:

  1. We were hyper-focused on performance at one task and really didn’t care about regressions on other emergent LM capabilities. Our model extracts entities and relationships very well, but I wouldn’t use it for anything else!

  2. Our model is highly tuned to the client’s preferences: specifically, the types of entities and relationships they were interested in. It will almost certainly be less useful than a frontier model for a general ER use case, or across a different document set.

Those caveats aside, what we ended up with fit the bill exactly: something which did an excellent job at a well-described ER task and which delivered blazingly fast, high-throughput inference on a local deployment.
