
Meta's GenAI moves from simple predictions to a chess game of consequences

In place of ChatGPT's conventional approach, Meta scientists argue for an optimized 'multi-token prediction' -- one that penalizes wrong answers.
Written by Tiernan Ray, Senior Contributing Writer

A schematic of Meta's approach, called multi-token prediction. During training of the AI model, the inputs are fed in as usual, but instead of the model being trained to produce a single token in response -- the next most likely word, say -- it is trained to simultaneously generate several likely tokens, four in the paper's main experiments.

Meta

Generative AI models such as GPT-4 have astounded us all with their ability to produce textual output that resembles thought, such as answers to multiple-choice questions. Reaching the "right" thought, however -- actually answering the question correctly -- remains a deeper problem, as evidenced by the phenomenon of "hallucinations," where AI models will assert -- with apparent confidence -- false statements.

In a new work, scientists at Meta have tweaked large language models (LLMs) to produce output that could be more correct in a given situation, by introducing the notion of penalties for wrong answers. 

Also: Meta's 'pruning' of Llama 2 model shows path to slimmer AI

The approach, known as "multi-token prediction," seeks to instill in an AI model a cost for less desirable answers. In that sense, it is analogous to popular approaches for establishing guardrails in AI such as "reinforcement learning from human feedback," or RLHF, a method OpenAI popularized to curb ChatGPT's most outrageous outputs.

(An "AI model" is part of an AI program containing numerous neural net parameters and activation functions that are the key elements for an AI program's functions.)

"Gains are especially pronounced on generative benchmarks like coding, where our models consistently outperform strong baselines by several percentage points," write the authors of "Better & Faster Large Language Models via Multi-token Prediction." Lead author Fabian Gloeckle, joined by colleagues at Facebook AI Research and collaborating institutions CERMICS École des Ponts ParisTech and LISN Université Paris-Saclay, posted the paper last month on the arXiv pre-print server.

The authors' principal concern is that LLMs -- despite their impressive accomplishments -- don't achieve things such as reasoning or planning. The conventional approach of ChatGPT and the rest, called "next-token prediction," they write, "remains an inefficient way of acquiring language, world knowledge, and reasoning capabilities."

Instead of simple next-token prediction, where the AI model is trained to predict a single "token," such as a word or character -- say, the next word in a sentence -- the Meta team's multi-token version is trained to predict several future tokens simultaneously: not just the next word, but the few words expected to follow it.

Technically, Gloeckle's team alters the basic structure of the LLM, known as a Transformer, so that it has four output "heads" that each produce a word, character, or other symbol, rather than the standard single head.
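
For readers who want the shape of the thing, a minimal sketch in PyTorch follows -- an illustration of the general idea rather than Meta's actual code, with made-up names and dimensions -- showing one shared trunk feeding four parallel output layers:

```python
import torch
import torch.nn as nn

class MultiTokenHeads(nn.Module):
    """Sketch of the multi-head idea: a shared Transformer trunk feeds
    n parallel output heads, head i predicting the token i + 1 positions
    ahead. Names and sizes are illustrative, not Meta's code."""

    def __init__(self, hidden_dim: int, vocab_size: int, n_heads: int = 4):
        super().__init__()
        # One linear "unembedding" layer per future-token position,
        # all reading the same trunk representation.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, vocab_size) for _ in range(n_heads)]
        )

    def forward(self, trunk_out: torch.Tensor) -> list[torch.Tensor]:
        # trunk_out: (batch, seq_len, hidden_dim) from the shared body.
        # Returns n logit tensors, one per head.
        return [head(trunk_out) for head in self.heads]
```

Each head here is just a linear layer, which keeps the sketch short; the essential point is the fan-out from one shared representation to several simultaneous predictions.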

The approach's immediate benefit is that it can be more efficient when the AI model is live, making predictions for users, known as the inference stage of AI. Because multiple output heads can work behind the scenes to try possibilities, a high degree of parallelism becomes possible. This form of "speculative decoding" means the multi-token approach "can speed up inference by a factor of 3×" versus predicting one token at a time.
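
The accept/reject step of greedy speculative decoding can be shown in a few lines. This toy -- a sketch of the general technique, not Meta's implementation -- keeps drafted tokens only up to the first position where a verification pass disagrees:

```python
import torch

def accept_drafts(draft: torch.Tensor, verified: torch.Tensor) -> torch.Tensor:
    """Toy accept step for greedy speculative decoding (illustrative,
    not Meta's code). `draft` holds tokens proposed cheaply by the extra
    heads; `verified` holds what the next-token head computes for the
    same positions in a single verification pass."""
    agree = (draft == verified).long()
    # Count how many leading positions match before the first mismatch.
    n_accept = int(agree.cumprod(dim=0).sum())
    # Keep the agreeing prefix, plus the verifier's token at the first
    # mismatch, which comes "free" with the verification pass.
    return verified[: n_accept + 1]

# Example: heads draft [5, 9, 2, 7]; verification computes [5, 9, 4, 7].
print(accept_drafts(torch.tensor([5, 9, 2, 7]),
                    torch.tensor([5, 9, 4, 7])))  # tensor([5, 9, 4])
```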

Also: Meta unveils second-gen AI training and inference chip

There's also a more profound insight. Normal AI models, picking one token at a time, are -- in a sense -- flat: They don't treat any single prediction as more important than any other, as long as the current prediction is a good one.

In fact, the team notes, there is a big difference between certain tokens in a phrase. In the oft-cited punctuation meme -- "stop clubbing, baby seals" -- the presence or absence of the comma in the middle of the phrase is the difference between an urgent plea for animal rights and an amusing image. The humor lands because the comma alters the semantics of the phrase.

The point, as others have observed, is that "not all token decisions are equally important for generating useful texts from language models," Gloeckle's team wrote. "While some tokens allow stylistic variations that do not constrain the remainder of the text, others represent choice points that are linked with higher-level semantic properties of the text and may decide whether an answer is perceived as useful or derailing."

Also: Rote automation is so last year: AI pushes more intelligence into software development

The multi-head, multi-token approach assigns fitness to each prediction based on the other simultaneous predictions. "Generally, we believe that the quality of text generations depends on picking the right decisions at choice points, and that n-token prediction losses promote those," the team wrote.

The "choice point" involves those moments where one prediction entails others down the road that can make or break the total phrase. "Multi-token prediction implicitly assigns weights to training tokens depending on how closely they are correlated with their successors," the team wrote.

By analogy, Gloeckle's team likens choosing the next word to moving through a maze: Each choice can be a route to the reward, or a route to some terrible fate. They use the maze image to illustrate the "sequential prediction task," as they refer to predicting the next word, where the next step could be a pivotal one that sends the AI model down the right path or the wrong one -- a "consequential choice," as they term it.


Choosing the next right token is like walking through a maze, write the authors: At certain moments, the choice is a "consequential" one that will send the program to success (the trophy) or defeat (the skull and crossbones).

Meta

In a striking fusion of technologies, the authors link the multi-token approach to the RLHF approach, trying to predict a reward far down the line: "Assume that the language model is deployed in a reinforcement learning setting like in reinforcement learning from human feedback … [where] actions are single tokens […] to generate." 

Linking text prediction to reward functions in that way brings into play all the areas where reward functions have made great strides, most famously gaming. Reward functions are used in all sorts of AI problems referred to as reinforcement learning, not just RLHF.

For example, Google's DeepMind unit used reinforcement learning to develop AlphaZero, the program that can beat humans at chess and Go. Reinforcement learning also powered AlphaStar, a program that competed against skilled human players in the real-time strategy game StarCraft II.

Also: Snowflake says its new LLM outperforms Meta's Llama 3 on half the training

This gamification has the immediate result of pushing the multi-token approach toward more "optimal" answers. The authors provide a variety of benchmark results. One, for example, compares how an AI model with 7 billion neural parameters, or weights, improves performance as it moves from single- to multi-token prediction.

On a test called "Mostly Basic Programming Problems," or MBPP, developed at Google in 2021, an AI model has to produce code, such as lines of Python, for a given function. On that benchmark, the model consistently achieves greater accuracy with multi-token prediction than with the standard single-token kind.

There's also a sweet spot. The AI model seems to perform best at four simultaneous tokens, while predicting more than that -- six or eight -- leads to results that are not as good.


On standardized tests such as "Mostly Basic Programming Problems," where an LLM has to generate programming code, the same-sized AI model, one with 7 billion neural parameters, or weights, achieves greater accuracy when more tokens are produced, as indicated by "n," the number of tokens simultaneously generated.

Meta

As with many things in neural networks, it's not immediately certain why multi-token prediction should be better than single-token prediction. The hunch the authors offer is that training for multi-token prediction helps the resulting model avoid a disconnect that occurs when the AI model makes live predictions from real user prompts -- what's called a "distribution mismatch between teacher-forced training and autoregressive generation."
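
The two regimes are easy to contrast in miniature. This toy snippet -- not from the paper -- shows the gap: during training the model is always fed the true history, while during generation it is fed its own previous guesses:

```python
from typing import Callable, List

def teacher_forced_inputs(tokens: List[int]) -> List[int]:
    # Training (teacher forcing): each position is conditioned on the
    # ground-truth prefix, so the model never sees its own mistakes.
    return tokens[:-1]

def generate(step_fn: Callable[[List[int]], int],
             prompt: List[int], n_new: int) -> List[int]:
    # Inference (autoregressive): each step is conditioned on the model's
    # *own* earlier outputs, so an early error skews every later input.
    out = list(prompt)
    for _ in range(n_new):
        out.append(step_fn(out))  # feed the prediction back in
    return out
```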

Also: You can make big money from AI - but only if people trust your data

There are still many things to figure out, Gloeckle and his colleagues wrote. One goal is to develop a method for automatically finding the sweet spot, the optimal number of simultaneous tokens that leads to the greatest accuracy. Another is tuning the vocabulary with which training text is tokenized, given that "optimal vocabulary sizes for multi-token prediction are likely different from those for next-token prediction, and tuning them could lead to better results."

A larger takeaway is that traditional reinforcement learning may have much more to offer generative AI than many have suspected to date, suggesting there will be more fusion of the two methodologies down the road.
