Protein Attention Maps with ESM-2

We can continue with the BioLM endpoint that provides contact maps, embeddings, and logits from ESM-2. This time, we plot the attention map for a sequence, which gives some insight into how the model attends to each residue.

import time
from biolmai import BioLM
from matplotlib import pyplot as plt
SEQ = "KRVPRTGWVYRNVEKPESVSDHMYRMAVMAMVTRDDRLNKDRCIRLALVHDMAECIVGDIAPADNIPKEEKHRREEEAMKQITQLLPEDLRKELYELWEEYETQSSEEAKFVKQLDQCEMILQASEYEDLENKPGRLQDFYDSTAGKFSHPEIVQLVSELETERNASMATASAEPG"

print("Sequence length: {}".format(len(SEQ)))
Sequence length: 176

Define Endpoint Params

Let's make a secure REST API request to the BioLM API to quickly run the prediction on GPU. Here we use the biolmai Python client imported above rather than calling the HTTP endpoint directly.

sequence = "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEFSPGQQLPQLADQLAALGESDLLFALSQHAVAFAQSQLHQQDRKWPRLPDYFAIGRTTALALHTVSGQKILYPQDREISEVLLQLPELQNIAGKRALILRGNGGRELIGDTLTARGAEVTFCECYQRCAIHYDGAEEAMRWQAREVTMVVVTSGEMLQQLWSLIPQWYREHWLLHCRLLVVSERLAKLARELGWQDIKVADNADNDALLRALQ"
params = {
    "include": [
        "mean",
        "contacts",
        "logits",
        "attentions"
    ]
}

start = time.time()
result = BioLM(entity="esm2-650m", action="encode", type="sequence", items=[sequence], params=params)
end = time.time()
print(f"ESM2 attention map generation took {end - start:.4f} seconds.")
ESM2 attention map generation took 38.6824 seconds.
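
If you prefer plain HTTP to the Python client, the same call can be made directly with requests. The sketch below is illustrative only: the URL path, payload shape, and token header are assumptions based on the client call above, so check the API docs in your console for the exact endpoint.

import os
import requests

# Sketch only: the exact URL path may differ; see the BioLM API docs in your console.
url = "https://biolm.ai/api/v2/esm2-650m/encode/"
headers = {
    "Authorization": f"Token {os.environ['BIOLMAI_TOKEN']}",  # assumes your API token is exported
    "Content-Type": "application/json",
}
payload = {
    "params": {"include": ["mean", "contacts", "logits", "attentions"]},
    "items": [{"sequence": sequence}],  # payload shape assumed to mirror the client call above
}

resp = requests.post(url, json=payload, headers=headers, timeout=600)
resp.raise_for_status()
data = resp.json()  # assumed to contain the same keys the client returns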

The result contains keys for:

  • our attentions, or attention map, which is n_layers x len(seq)
  • contacts, or contact map, which is a len(seq) x len(seq) matrix
  • the logits from the final hidden state, a vector of length len(seq)
  • mean_representations, the mean-pooled protein embedding, a vector of length 1280
  • lastly name, which is simply the index of the sequence in the order it was POSTed
attention_map = result['attentions']

# Straight from the model, the attentions would also cover the start/end tokens,
# but the endpoint strips those out for us
nrow = len(attention_map)
ncol = len(attention_map[0])

print(f'({nrow}, {ncol})')
(33, 246)

We have 33 rows here since the model has that many layers. The 246 columns correspond to the residues of the sequence we encoded (the 246-aa `sequence` above, not the 176-aa SEQ).
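
Before plotting, it can be useful to sanity-check the other keys from the same result. A small sketch (assuming numpy is available, and using the key names listed above) is below.

import numpy as np

contact_map = np.asarray(result['contacts'])                  # per the key list above: len(seq) x len(seq)
mean_embedding = np.asarray(result['mean_representations'])   # per the key list above: length-1280 vector
logits = np.asarray(result['logits'])

print(contact_map.shape, mean_embedding.shape, logits.shape)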

plt.figure(figsize=(14, 9))
plt.xlabel('Residue')
plt.ylabel('Layer')
plt.title('Attention Map for ESM2 and Example Protein')
plt.imshow(attention_map, cmap='viridis', interpolation='nearest')
plt.show()
[Figure: Attention Map for ESM2 and Example Protein, heatmap with residue position on the x-axis and layer on the y-axis]
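
Since we also requested contacts in the same call, the predicted contact map can be plotted the same way. A quick sketch, reusing the contacts key described above:

plt.figure(figsize=(9, 9))
plt.xlabel('Residue')
plt.ylabel('Residue')
plt.title('Predicted Contact Map for ESM2 and Example Protein')
plt.imshow(result['contacts'], cmap='viridis', interpolation='nearest')
plt.show()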

See more use-cases and APIs on your BioLM Console Catalog.


BioLM hosts deep learning models and runs inference at scale. You do the science.


Contact us to learn more.
