Protein LM Embeddings and Logits as Features¶
One of the best ways to leverage a large language model is feature generation. The internal numeric representations the neural network uses to make predictions can be extracted and used for downstream machine learning tasks. These vectors often encode powerful information well beyond simple one-hot encodings. Feature engineering for biology is usually heavily task-specific, but these embeddings can serve a wide variety of classification, regression, and other tasks.
On the backend, the input sequence is tokenized and passed through the pre-trained model's layers, producing multiple representations of the protein, such as numeric vectors and matrices like attention maps. Here we will quickly demo ESM2 via a GPU-backed REST API, transforming a sequence into embeddings without installing packages, setting up a GPU, or downloading the model.
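As a preview of what that round trip looks like, the whole request is a single client call; the full example, with additional outputs requested, follows below. The sequence here is just a hypothetical placeholder.
from biolmai import BioLM  # preview only; the notebook's imports follow below
preview = BioLM(entity="esm2-650m", action="encode", type="sequence",
                items=["MKTAYIAKQR"], params={"include": ["mean"]})  # placeholder toy sequence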
from IPython.display import JSON  # Helpful UI for JSON display
import time
from biolmai import BioLM
import os
import pandas as pd
from glob import glob
import json
import asyncio
import random
import seaborn as sns
import xgboost as xgb
from sklearn import model_selection
from matplotlib import pyplot as plt
# Load sequences and fluorescence data
df = pd.concat([
    pd.read_csv(f) for f in
    glob(os.path.join('./', 'data', 'protein', 'data', 'fluor*.train.csv'))
])
df.drop_duplicates('seq', inplace=True)
df.shape
# SAMPLE some of the data
df_orig = df  # Retain original DF as `df_orig`
df = df_orig.copy().sample(50, random_state=42)
df.head(6)
Here we have a DataFrame containing sequences and their measured fluorescence. We can use the ESM2 embeddings as features to perform a quick regression to predict fluorescence values; but first we need to actually generate the embeddings.
Let's request embeddings for a single sequence via the REST API using the BioLM client. The ESM-2 encode endpoint documentation provides examples of the structure of a response.
Single Sequence Embeddings¶
test_protein = df.sample(1).seq.iloc[0]
print("Sequence length: {}\n{}".format(len(test_protein), test_protein))We can POST that sequence:
params = {
    "include": [
        "mean",
        "contacts",
        "logits",
        "attentions"
    ]
}
response = BioLM(entity="esm2-650m", action="encode", type="sequence", items=[test_protein], params=params)
JSON(response)
Each dictionary item in embeddings contains the mean representation from one or more ESM2 layers. In this case, we return the embeddings from the final hidden layer, 33.
Let's load this representation and look at its shape.
embed_single = pd.DataFrame(response['embeddings'][0]['embedding']).T
embed_single
We can see that while the original sequence is 237 residues long, the model represents it with a fixed-length vector of 1,280 values. Any sequence we embed comes back the same size, regardless of its length, which makes downstream ML, especially with other neural networks, easy since we don't have to worry about padding.
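As a quick sanity check (an addition here, reusing the response structure from the single-sequence call above), we can embed two sequences of different lengths and confirm the vectors come back the same size:
two_seqs = df.sample(2, random_state=0).seq.to_list()
two_resp = BioLM(entity="esm2-650m", action="encode", type="sequence", items=two_seqs, params=params)
# Both mean embeddings should have length 1280, regardless of sequence length
print([len(r['embeddings'][0]['embedding']) for r in two_resp])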
Let's get the embeddings for all sequences in the DF, asynchronously, via REST API.
Embeddings for All Seqs in DF¶
rets = BioLM(entity="esm2-650m", action="encode", type="sequence", items=df['seq'].to_list(), params=params)
len(rets)
Now, instead of the embeddings, we could retrieve the logits. In fact, they were returned by the same request; we simply need to use a different key. We can look at the single-sequence response we received earlier to find them:
logits_single = pd.DataFrame(response['logits']).T
logits_single
One could also use the sum or mean of the logits as a representation of the sequence:
logits_single.iloc[0].sum()
logits_single.iloc[0].mean()
Modeling with XGBoost¶
To avoid requiring any transformations of the data, let's use a tree-based method to quickly create a regression model from these ESM2 embeddings.
We'll see how well we can model fluorescence with about 1,000 samples. With so few, we should make sure to uniformly sample our labels as much as possible.
Requesting embeddings for roughly 1,000 sequences will take several minutes, depending on your download speed and connection bandwidth. Below we reuse the batched BioLM client call; if you write your own asynchronous requests instead, add a slight jitter and a limit of around 5 concurrent connections to prevent thundering herds and to avoid overwhelming your connection. If your requests run in a browser-based environment, you can also open the browser's Developer Tools (aka Inspector) and watch the traffic under the Network tab.
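A minimal sketch of that jitter-plus-semaphore pattern, where request_fn is a hypothetical coroutine standing in for whatever actually performs the HTTP call:
sem = asyncio.Semaphore(5)  # cap at 5 in-flight requests

async def fetch_with_jitter(seq, request_fn):
    # request_fn is a placeholder coroutine that fetches embeddings for one sequence
    async with sem:
        await asyncio.sleep(random.uniform(0.0, 0.5))  # small jitter to avoid thundering herds
        return await request_fn(seq)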
df_orig['bins'] = pd.cut(df_orig.label, bins=100)
df_orig['bins'] = df_orig['bins'].astype('str')
df_orig.bins.value_counts()
sampled = []
df_orig = df_orig.sample(df_orig.shape[0], random_state=42)
for n, grp in df_orig.groupby('bins'):
    # At most 10 rows from each bin
    samp = grp.head(10).reset_index(drop=True)
    sampled.append(samp)
    
sampled_rows = pd.concat(sampled, axis=0).reset_index(drop=True)
df_xgboost = sampled_rows
df_xgboost.shape
# Request embeddings for the sampled sequences, reusing the same encode call and params as above
xgboost_rets = BioLM(entity="esm2-650m", action="encode", type="sequence", items=df_xgboost['seq'].to_list(), params=params)
# Keep the mean final-layer embedding vector for each sequence (same response structure as the single-sequence example)
xgboost_embeddings = [r['embeddings'][0]['embedding'] for r in xgboost_rets]
embeddings = pd.DataFrame(xgboost_embeddings)
embeddings.shape
# Create 80:20 train:test split
train_x, test_x, train_y, test_y = model_selection.train_test_split(
    embeddings,
    df_xgboost.label,
    test_size=0.2,
    random_state=54
)
print("X Train size: {}\nX Test size: {}".format(train_x.shape, test_x.shape))
print("Y Train size: {}\nY Test size: {}".format(train_y.shape, test_y.shape))embeddings.head(3)#Set up cross-validation modeling objective
data_dmatrix = xgb.DMatrix(data=train_x, label=train_y)
params = {
    'booster': 'gbtree',
    "objective": "reg:squarederror",
    'colsample_bytree': 0.40,
    'learning_rate': 0.2,
    'max_depth': 40,
    'eval_metric': 'rmse',
    'alpha': 0.8,
}
# Run CV
cv_results = xgb.cv(
    dtrain=data_dmatrix,
    params=params,
    nfold=5,
    num_boost_round=100,
    early_stopping_rounds=6,
    metrics="rmse",
    as_pandas=True,
    seed=42
)
We can see how the performance of the model training started...
cv_results.head()
cv_results.tail(10)
# Final RMSE on test set
print((cv_results["test-rmse-mean"]).tail(1))
# Final SD on test set
print((cv_results["test-rmse-std"]).tail(1))
We can get context for these values by looking at the Y values that were used in this cross-validation experiment.
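For a rough point of reference (an addition here, not part of the original analysis), the RMSE of a constant model that always predicts the training mean is approximately the standard deviation of the training labels, so that is the floor to beat:
# Baseline: RMSE of always predicting the mean label is (approximately) the label standard deviation
print(train_y.std())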
plt.figure(figsize=(6, 5))
train_y.hist()
Now we can train a model with the cross-validated parameters, using our full training dataset.
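Because the CV above used early stopping, the number of rows in cv_results is a reasonable guide for num_boost_round; we hard-code 80 below, but you could derive it directly:
len(cv_results)  # boosting rounds kept by early stopping in the CV run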
xg_reg = xgb.train(params=params, dtrain=data_dmatrix, num_boost_round=80)
Measuring the predicted values against our test set, let's see how well the model did using just the sequence embedding features.
y_pred = xg_reg.predict(xgb.DMatrix(data=test_x, label=test_y))
sns.regplot(x=test_y, y=y_pred)
Lastly, we can attempt to inspect the model and learn a bit more about it, its fit, and our data.
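Before inspecting the model itself, we can put a number on the fit of the held-out split (a quick addition using scikit-learn's metrics module):
from sklearn import metrics

rmse = metrics.mean_squared_error(test_y, y_pred) ** 0.5  # test-set RMSE
r2 = metrics.r2_score(test_y, y_pred)                     # test-set R^2
print("Test RMSE: {:.3f}\nTest R^2: {:.3f}".format(rmse, r2))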
xgb.plot_importance(xg_reg, max_num_features=20, grid=False)
plt.rcParams['figure.figsize'] = [10, 8]
plt.show()
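If the importance plot is crowded, the raw scores are also available from the booster; each feature name fN corresponds to embedding dimension N (a sketch using Booster.get_score):
# Top 10 embedding dimensions by gain
scores = xg_reg.get_score(importance_type='gain')
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:10])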