Guide to Perceiver: A Scalable Transformer-based Model
Real-world data comes in several modalities, such as audio, text, video and images. Understanding each kind of data and extracting patterns from it usually requires algorithms and models specific to that modality. For example, CNNs are the usual choice for image data, whereas attention-based models are preferred for text. Biological systems, however, do not use disparate models for data of different modalities: they process inputs from different senses simultaneously. Inspired by this biological approach to perception, Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals and Joao Carreira introduced the Perceiver, a transformer-based model, on 4 March 2021. In this post, let’s explore the Perceiver model.
Architecture
The Perceiver is a transformer-based model that combines cross-attention and self-attention layers to build representations of multimodal data. A small latent array extracts information from the (potentially very large) input byte array through cross-attention: the queries come from the latent array, while the keys and values come from the byte array. This is a form of top-down, or feedback, processing: what the model attends to in the byte array is controlled by the latent array, which already carries information about the byte array from the previous layers.
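A single cross-attention step can be sketched with PyTorch's built-in nn.MultiheadAttention, which accepts keys and values of a different width from the queries. This is a minimal illustration, not the Perceiver's actual block: it omits the layer norm, residual connection and MLP that surround the attention, and the sizes (a 32x32 image with 29 features per pixel, 32 latents of width 128) are assumptions chosen to mirror the configuration used later in this post.

```python
import torch
import torch.nn as nn

batch, M, C = 4, 32 * 32, 29   # byte array: 1024 pixels, 3 colour + 26 position features each (assumed)
N, D = 32, 128                 # latent array: 32 latents of width 128 (assumed)

# Learned latent array, shared across all inputs
latents = nn.Parameter(torch.randn(N, D))
# Single-head cross-attention: queries have width D, keys/values have width C
cross_attn = nn.MultiheadAttention(embed_dim=D, num_heads=1, kdim=C, vdim=C, batch_first=True)

byte_array = torch.randn(batch, M, C)                    # stand-in for a batch of flattened images
queries = latents.unsqueeze(0).expand(batch, -1, -1)     # (batch, N, D): queries come from the latents
out, weights = cross_attn(queries, byte_array, byte_array)  # keys and values come from the byte array

print(out.shape)      # torch.Size([4, 32, 128]): one updated vector per latent
print(weights.shape)  # torch.Size([4, 32, 1024]): an N x M attention map, not M x M
```

Because the output has one row per latent rather than one per pixel, everything after this step works on a fixed-size array regardless of how large the input is.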
Complexity
The attention layers are the most memory- and time-intensive blocks in the Perceiver. Vanilla self-attention has quadratic complexity: with $M$ elements in the byte array, computing the attention matrix takes $O(M^2)$ memory. The Perceiver avoids this by routing attention through a much smaller latent array of size $N$. Self-attention in the latent transformer blocks then costs $O(N^2)$, and the cross-attention layer costs $O(MN)$. Because $N$ is much smaller than $M$, this is especially helpful for high-bandwidth modalities such as images, audio and video, where $M$ can be very large.
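To make the difference concrete, the short calculation below counts the entries of a single attention matrix (per head, per layer) for each kind of attention. The function name is hypothetical, and the 224x224 image with 512 latents is an illustrative assumption; the CIFAR-10 case uses the 32 latents configured later in this post.

```python
def attention_scores(M, N):
    """Entries in one attention matrix (per head, per layer) for each attention type."""
    return {
        "input_self_attention": M * M,   # O(M^2): a standard transformer on the byte array
        "cross_attention": N * M,        # O(MN): N latents attend to M inputs
        "latent_self_attention": N * N,  # O(N^2): attention inside a latent transformer block
    }

# A 32x32 CIFAR-10 image (M = 1024 pixels) with 32 latents
print(attention_scores(M=32 * 32, N=32))
# {'input_self_attention': 1048576, 'cross_attention': 32768, 'latent_self_attention': 1024}

# A 224x224 image (M = 50,176 pixels) with 512 latents (illustrative values)
print(attention_scores(M=224 * 224, N=512))
# {'input_self_attention': 2517630976, 'cross_attention': 25690112, 'latent_self_attention': 262144}
```

For the larger image, full self-attention over pixels would need roughly 100 times more attention entries than the cross-attention layer.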
Iterative Attention
The byte array is fed into the architecture several times. Each time, a cross-attention layer attends to the byte array with queries from the latent array. Its output has drawn some information from the byte array; after passing through a latent transformer block, this output is used to query the byte array again. Repeating this cross-attend and latent-transform pattern makes the network deep and lets the model iteratively refine which information it extracts from the byte array. To keep the number of parameters down, the weights can be shared across the repeated blocks, which makes the model behave much like an RNN that is unrolled in depth rather than in time.
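The loop below is a minimal sketch of this iterative scheme with weight sharing, built from standard PyTorch layers. The class TinyPerceiver and all of its sizes are hypothetical; for brevity it shares every block (the paper keeps the first cross-attention unshared) and leaves out the MLP that follows each cross-attention in the real model.

```python
import torch
import torch.nn as nn

class TinyPerceiver(nn.Module):
    """Hypothetical, simplified Perceiver: cross-attend -> latent transformer, repeated `depth` times."""

    def __init__(self, in_dim=29, latent_dim=128, num_latents=32, depth=6, num_classes=10):
        super().__init__()
        self.depth = depth
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        # One cross-attention and one latent block, reused at every iteration:
        # this weight sharing is what makes the loop resemble an RNN unrolled in depth
        self.cross_attn = nn.MultiheadAttention(latent_dim, 1, kdim=in_dim, vdim=in_dim, batch_first=True)
        self.latent_block = nn.TransformerEncoderLayer(latent_dim, nhead=8, dim_feedforward=256, batch_first=True)
        self.head = nn.Linear(latent_dim, num_classes)

    def forward(self, x):  # x: (batch, M, in_dim) byte array
        z = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        for _ in range(self.depth):
            attended, _ = self.cross_attn(z, x, x)  # latents query the byte array again
            z = self.latent_block(z + attended)     # residual update, then latent self-attention
        return self.head(z.mean(dim=1))             # average the latents, then classify


def count_params(module):
    return sum(p.numel() for p in module.parameters())


tiny = TinyPerceiver()
logits = tiny(torch.randn(2, 32 * 32, 29))
print(logits.shape)  # torch.Size([2, 10])
# The parameter count does not grow with depth because the blocks are shared
print(count_params(TinyPerceiver(depth=2)) == count_params(TinyPerceiver(depth=6)))  # True
```

Without sharing, each extra iteration would add a full cross-attention and latent block; with sharing, depth only adds computation.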
Positional Embeddings
Transformer models are permutation invariant: attention ignores the order of its inputs. While this helps the architecture generalize across modalities, it is a problem for data where spatial or sequential structure is crucial, such as images and text.
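The check below illustrates the problem with a small cross-attention layer standing in for the Perceiver's: shuffling the byte array (keys and values together) leaves the output unchanged, so without extra information the model cannot tell where any element came from. The sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=1, batch_first=True)

latents = torch.randn(1, 4, 16)      # queries: 4 latents
byte_array = torch.randn(1, 10, 16)  # keys/values: 10 inputs with no positional information
perm = torch.randperm(10)            # a random reordering of the inputs

out, _ = attn(latents, byte_array, byte_array)
out_perm, _ = attn(latents, byte_array[:, perm], byte_array[:, perm])

# Softmax-weighted sums do not depend on the order of the summed terms
print(torch.allclose(out, out_perm, atol=1e-6))  # True
```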
Just like many other transformer architectures, the Perceiver solves this problem by injecting positional encodings into the inputs. It uses Fourier features as positional encodings, which adapt flexibly to both the number of input dimensions (axes) and the size of each dimension. For a position $x_d \in [-1, 1]$ along axis $d$, the encoding contains $\sin(f_k \pi x_d)$ and $\cos(f_k \pi x_d)$ for $K$ frequency bands $f_k$, together with $x_d$ itself. The number of bands $K$ is a hyperparameter we choose, which gives us control over the resolution of the encoding while retaining its flexibility. The paper spaces the bands evenly between 1 and $\mu/2$, where $\mu$ is the maximum frequency (some implementations sample them log-uniformly instead). The resulting features are concatenated with the input features rather than added to them.
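The function below is a minimal sketch of this encoding for a grid of any number of axes, following the description above: coordinates in [-1, 1], $K$ bands spaced linearly between 1 and max_freq / 2, and the raw coordinate included. The function name is hypothetical, and perceiver-pytorch's internal encoding may differ in detail.

```python
import math
import torch

def fourier_position_encoding(shape, num_bands, max_freq):
    """Fourier-feature position encoding with d * (2K + 1) features per grid point."""
    # One normalized coordinate vector per axis, e.g. 32 values in [-1, 1]
    axes = [torch.linspace(-1.0, 1.0, steps=n) for n in shape]
    # Grid of positions, shape (*shape, d)
    pos = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    # Frequency bank f_1..f_K spaced linearly between 1 and max_freq / 2
    freqs = torch.linspace(1.0, max_freq / 2, steps=num_bands)
    # Scaled positions f_k * pi * x_d, shape (*shape, d, K)
    scaled = pos.unsqueeze(-1) * freqs * math.pi
    # Per axis: raw coordinate + K sines + K cosines = 2K + 1 features
    feats = torch.cat([pos.unsqueeze(-1), scaled.sin(), scaled.cos()], dim=-1)
    # Merge the axis and band dimensions into one feature dimension
    return feats.flatten(start_dim=-2)

# Usage: concatenate the encoding with a channels-last batch of 32x32 RGB images
images = torch.randn(4, 32, 32, 3)
enc = fourier_position_encoding((32, 32), num_bands=6, max_freq=10.0)
inputs = torch.cat([images, enc.expand(4, -1, -1, -1)], dim=-1)
print(enc.shape)     # torch.Size([32, 32, 26])
print(inputs.shape)  # torch.Size([4, 32, 32, 29])
```

With two axes and six bands, each pixel gains 2 x (2 x 6 + 1) = 26 position features, giving 29 features in total alongside the three colour channels.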
The Perceiver performs competitively with, or outperforms, strong specialized models that build in assumptions about the structure of their data. The paper reports results on images, raw audio, video, combined audio and video, and point clouds in 3D space. Let’s see how we can train a Perceiver model and run inference with it.
Code
Installation
An open-source PyTorch implementation of the model, perceiver-pytorch, is available on PyPI (it is a community implementation rather than the authors' official code). We can install it directly with pip using the following command.
pip install perceiver-pytorch
Data Loading
Let’s use the CIFAR-10 image dataset to train the model. The following boilerplate code loads the data; it is adapted from the official PyTorch CIFAR-10 tutorial.
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Model
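The model can be instantiated as follows. The library expects channels-last input of shape (batch, height, width, channels) and adds the Fourier position encoding itself.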
from perceiver_pytorch import Perceiver

# cross_dim (accepted by early releases) is omitted: recent perceiver-pytorch versions reject it
model = Perceiver(
    input_channels = 3,        # number of channels for each token of the input (RGB)
    input_axis = 2,            # number of axes of the input data (2 for images, 3 for video)
    num_freq_bands = 6,        # number of frequency bands K; each axis gets 2 * K + 1 position features
    max_freq = 10.,            # maximum frequency, a hyperparameter depending on how fine the data is
    depth = 6,                 # depth of the network
    num_latents = 32,          # number of latents (called induced set points or centroids in other papers)
    latent_dim = 128,          # latent dimension
    cross_heads = 1,           # number of heads for cross-attention (the paper uses 1)
    latent_heads = 2,          # number of heads for latent self-attention (8 in the library's README)
    cross_dim_head = 8,        # dimension per cross-attention head
    latent_dim_head = 8,       # dimension per latent self-attention head
    num_classes = 10,          # number of output classes (CIFAR-10)
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = False  # whether to share weights across layers (see Iterative Attention)
)
Training Loop
import torch.optim as optim

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# use the first GPU if one is available, otherwise fall back to the CPU
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(dev)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(dev), labels.to(dev)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        # (the Perceiver expects channels-last input, so NCHW -> NHWC)
        outputs = model(inputs.permute(0, 2, 3, 1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print the average loss over every 1000 mini-batches
        running_loss += loss.item()
        if i % 1000 == 999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 1000))
            running_loss = 0.0

print('Finished Training')
Inference
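import numpy as np

def predict(inputs):
    model.eval()  # switch off dropout for inference
    with torch.no_grad():  # gradients are not needed at inference time
        preds = model(inputs).cpu().numpy()
    return [classes[np.argmax(p)] for p in preds]

# classify one test batch, converting NCHW -> NHWC as during training
images, labels = next(iter(testloader))
print(predict(images.permute(0, 2, 3, 1).to(dev)))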
The post Guide to Perceiver: A Scalable Transformer-based Model appeared first on Analytics India Magazine.




