Cluster by Embedding¶
ESM2 embedding-based K-means clustering and PCA visualization of a peptide library.
⚠️ Preview Feature — The biolmai.pipeline module used in this guide is currently in preview and not yet publicly released. Access is available to early users on request. Contact us to get access.
What you'll learn:
- Generating ESM2-8m embeddings via the pipeline SDK
- K-means clustering with silhouette score selection
- PCA visualization of sequence space
- Diversity sampling across clusters
Requirements:
pip install "biolmai[pipeline]" matplotlib scikit-learn
export BIOLMAI_TOKEN=your-token-here
Setup¶
import os
from biolmai.pipeline import (
DataPipeline, DuckDBDataStore,
ThresholdFilter, RankingFilter,
ValidAminoAcidFilter, EmbeddingSpec,
DiversitySamplingFilter,
)
# Read the API token from the environment and fail fast when it is missing
# or empty, pointing the reader at the page where a token can be created.
TOKEN = os.environ.get("BIOLMAI_TOKEN", "")
if not TOKEN:
    raise EnvironmentError(
        "Set BIOLMAI_TOKEN before running.\n"
        "Get one at https://biolm.ai/ui/accounts/user-api-tokens/"
    )
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
Peptide library¶
50 antimicrobial peptides from three families: magainins, defensins, and designed sequences.
# Input library for the pipeline: 50 peptide sequences written in
# single-letter amino-acid codes, one sequence per line.
MY_50_PEPTIDES = [
    "GIGKFLHSAKKFGKAFVGEIMNS",
    "GIGKFLHSAGKFGKAFVGEIMKS",
    "GLFDIIKKIAESF",
    "GLFDIVKKVVGALGSL",
    "FLPLILRKIVTAL",
    "GIKKFLGSIWKFIKAFVKEIMN",
    "GLFDIIKKIAESFLPKV",
    "GIGKFLHSAKKFGKAFV",
    "GIGKFLHSAK",
    "GIGKFLHSAKKFGK",
    "LLGDFFRKSKEKIGKEFKRIVQRIKDFLRNLVPRTES",
    "RLFDKIRQVIRKF",
    "KWKLFKKIPKFLHLAKKF",
    "KWKLFKKIPKFLHLAK",
    "FKRIVQRIKDFL",
    "RWKIFKKIEKVGRNVRDGIIKAGPAVAVVGQATQIAK",
    "KWKLFKKIEKVGQNIRDGIIKAGPAVAVVGQATQIAK",
    "GIGAVLKVLTTGLPALISWIKRKRQQ",
    "VDKGSYLPRPTPPRPIYNRN",
    "GKPRPYSPRPTSHPRPIRV",
    "KLAKLAKKLAKLAK",
    "LKLLKKLLKLLKKL",
    "RRWWRRWWRR",
    "KWKWKWKWKW",
    "RRLCRIVVIRVCR",
    "RRWQWR",
    "RWRWRW",
    "KFLKKAKKFGK",
    "KWKLFKKI",
    "RLFDKIRQ",
    "ACYCRIPACIAGERRYGTCIYQGRLWAFCC",
    "GCSKNKGCSVDKECAAFGSR",
    "DCSCFGGKGETACNKCKTPEGKPCTEGKPCK",
    "GFCKLCCNPACGPNYKGICIDM",
    "AFGKQLAKLAKSKLAKLAKLAK",
    "KLALKLALKALKAALKLA",
    "WKWLKKWLKKLK",
    "RWKKWWRRKK",
    "KFKKLFKKLKK",
    "FLGALFKALSKLL",
    "ALLKFLLKFLLK",
    "GLLDFLK",
    "KKLFKKILKYL",
    "FFKDEL",
    "ILKKWPWWPWRR",
    "GLWSKIKEVGKEAAKAAAKAAGKAALNAVTGGGKPG",
    "KLKLLLLLKLK",
    "WKLFKKIKVWK",
    "KWKWFKKIKVWK",
    "KLFKKIKVWKK",
]
print(f"{len(MY_50_PEPTIDES)} peptides")
Embed with ESM2-8m¶
# Assemble the clustering pipeline on top of a local DuckDB file. Stages are
# chained through depends_on: validate -> embed -> cluster_k3.
ds = DuckDBDataStore("cluster_demo.duckdb")
cp = DataPipeline(
    sequences=MY_50_PEPTIDES,
    datastore=ds,
    verbose=True,
)

# Stage "validate": apply ValidAminoAcidFilter to the input sequences.
cp.add_filter(ValidAminoAcidFilter(), stage_name="validate")

# Stage "embed": call the esm2-8m "encode" action and extract the
# "embeddings" field for layer 6.
esm2_embedding_spec = EmbeddingSpec(key="embeddings", layer=6)
cp.add_prediction(
    "esm2-8m",
    action="encode",
    embedding_extractor=esm2_embedding_spec,
    stage_name="embed",
    depends_on=["validate"],
)

# Stage "cluster_k3": K-means with three clusters on the esm2-8m embeddings.
cp.add_clustering(
    method="kmeans",
    n_clusters=3,
    similarity_metric="embedding",
    embedding_model="esm2-8m",
    stage_name="cluster_k3",
    depends_on=["embed"],
)
cp.run()
Choose k with silhouette scores¶
Because embeddings are cached, re-clustering at different k values costs nothing.
# Fetch the sequence IDs in a fixed order. Without ORDER BY the row order of
# a SQL result is unspecified, and the rows of `mat` must line up with the
# cluster labels, which are read back ordered by sequence_id for plotting.
sequence_ids = [
    row[0]
    for row in ds.conn.execute(
        "SELECT sequence_id FROM sequences ORDER BY sequence_id"
    ).fetchall()
]
emb_map = ds.get_embeddings_bulk(sequence_ids, model_name="esm2-8m")
# One embedding per sequence, stacked in the same order as sequence_ids.
mat = np.stack([emb_map[sid] for sid in sequence_ids])

# Score K-means for k = 2..8; the fixed random_state keeps runs reproducible.
scores = {}
for k in range(2, 9):
    labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(mat)
    scores[k] = silhouette_score(mat, labels)
    print(f"k={k}: silhouette={scores[k]:.3f}")
# Keep the k with the highest silhouette score.
best_k = max(scores, key=scores.get)
print(f"\nBest k: {best_k}")
PCA visualization¶
# Read the stored cluster assignments back as integers, ordered by
# sequence_id. NOTE(review): presumably these rows were written by the
# "cluster_k3" pipeline stage — confirm against the predictions schema.
label_rows = ds.conn.execute(
    "SELECT CAST(value AS INTEGER) FROM predictions WHERE prediction_type='cluster_k3' ORDER BY sequence_id"
).fetchall()
cluster_labels = [int(row[0]) for row in label_rows]

# Project the embedding matrix onto its first two principal components.
pca = PCA(n_components=2)
coords = pca.fit_transform(mat)
pc1_share = pca.explained_variance_ratio_[0]
pc2_share = pca.explained_variance_ratio_[1]

# Scatter plot of the 2-D projection, coloured by cluster label.
fig, ax = plt.subplots(figsize=(8, 6))
marker_style = dict(
    cmap="Set1",
    edgecolors="black",
    linewidth=0.5,
    alpha=0.8,
    s=60,
)
scatter = ax.scatter(coords[:, 0], coords[:, 1], c=cluster_labels, **marker_style)
plt.colorbar(scatter, ax=ax, label="Cluster")

# Axis labels report how much variance each component explains.
ax.set_xlabel(f"PC1 ({pc1_share:.1%} variance)")
ax.set_ylabel(f"PC2 ({pc2_share:.1%} variance)")
ax.set_title("ESM2-8m embedding space — K-means clusters (k=3)")
plt.tight_layout()
plt.show()
Diversity sampling¶
Select 20 diverse representatives spanning all clusters. This example uses the filter's "random" sampling method.
# Second pipeline, backed by its own DuckDB file: the same validate -> embed
# -> cluster stages as before, followed by a diversity-sampling filter.
diverse_ds = DuckDBDataStore("diverse_demo.duckdb")
dp = DataPipeline(
    sequences=MY_50_PEPTIDES,
    datastore=diverse_ds,
    verbose=True,
)

dp.add_filter(ValidAminoAcidFilter(), stage_name="validate")

layer6_spec = EmbeddingSpec(key="embeddings", layer=6)
dp.add_prediction(
    "esm2-8m",
    action="encode",
    embedding_extractor=layer6_spec,
    stage_name="embed",
    depends_on=["validate"],
)

dp.add_clustering(
    method="kmeans",
    n_clusters=3,
    similarity_metric="embedding",
    embedding_model="esm2-8m",
    stage_name="cluster",
    depends_on=["embed"],
)

# Keep 20 sequences.
# NOTE(review): method="random" is used here — confirm it gives the intended
# diversity across clusters. Unlike the embed/cluster stages, this stage
# declares no depends_on; verify it is scheduled after "cluster".
sampler = DiversitySamplingFilter(n_samples=20, method="random")
dp.add_filter(sampler, stage_name="diverse_20")

dp.run()
dp.summary()
Cleanup¶
import os

# Close both datastores before deleting the files they were opened on.
ds.close()
diverse_ds.close()

# Delete the demo databases. A file that is already gone is not an error, so
# only FileNotFoundError is ignored; any other failure still surfaces.
for db_file in ("cluster_demo.duckdb", "diverse_demo.duckdb"):
    try:
        os.remove(db_file)
    except FileNotFoundError:
        pass