API: embeddings

The framework's compute_graph_embeddings dispatches to one of the classes below, depending on the requested algorithm. You can also instantiate these classes directly.
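A minimal sketch of that dispatch, assuming (hypothetically) that the class is chosen from the algorithm name alone. The name sets are copied from the signatures on this page; choose_embedding_class is an illustrative helper, not part of NEExT, and the real framework logic may differ.

```python
# Hypothetical sketch: NEExT's actual dispatch logic may differ.
VECTORIZER_ALGORITHMS = {"approx_wasserstein", "wasserstein", "sinkhornvectorizer"}
GNN_ARCHITECTURES = {"GCN", "GraphSAGE", "GIN"}


def choose_embedding_class(algorithm: str) -> str:
    """Return the name of the class that would handle `algorithm` (illustrative only)."""
    if algorithm in VECTORIZER_ALGORITHMS:
        return "GraphEmbeddings"  # distribution-based vectorizers
    if algorithm in GNN_ARCHITECTURES:
        return "GNNEmbeddings"  # PyTorch GNN encoders
    raise ValueError(f"unsupported embedding algorithm: {algorithm!r}")


print(choose_embedding_class("approx_wasserstein"))  # GraphEmbeddings
print(choose_embedding_class("GIN"))                 # GNNEmbeddings
```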

GraphEmbeddings

NEExT.embeddings.GraphEmbeddings — distribution-based embeddings (vectorizers).

GraphEmbeddings(
  graph_collection,
  features,
  embedding_algorithm,          # "approx_wasserstein" | "wasserstein" | "sinkhornvectorizer"
  embedding_dimension,
  feature_columns=None,
  random_state=42,
  memory_size="4G",
  suffix="",
)

compute() -> Embeddings
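To make "distribution-based" concrete: each graph contributes a distribution of node-feature values, and a vectorizer turns that distribution into a fixed-length vector. The sketch below is not NEExT's approx_wasserstein code; it swaps in a simple quantile-function embedding, which works because for one-dimensional distributions the Wasserstein-2 distance equals the L2 distance between quantile functions. quantile_embedding, approx_w2, and the sample data are hypothetical.

```python
import math


def quantile_embedding(values, k):
    """Embed a 1-D sample as k quantiles taken at levels (i + 0.5) / k."""
    xs = sorted(values)
    n = len(xs)
    # Map each quantile level to an index into the sorted sample (clamped to n - 1).
    return [xs[min(n - 1, int((i + 0.5) / k * n))] for i in range(k)]


def approx_w2(u, v):
    """Root-mean-square difference of two quantile vectors, approximating W2."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)) / len(u))


# Hypothetical per-node feature values (e.g. degrees) for two graphs.
node_feature = {0: [1, 2, 2, 3], 1: [4, 3, 3, 2]}
emb = {gid: quantile_embedding(vals, k=4) for gid, vals in node_feature.items()}
print(emb)                        # {0: [1, 2, 2, 3], 1: [2, 3, 3, 4]}
print(approx_w2(emb[0], emb[1]))  # 1.0
```

Because each graph yields exactly k values per feature regardless of its node count, graphs of different sizes land in the same vector space. When k equals the sample size (as here), approx_w2 is the exact W2 distance between the two empirical distributions.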

GNNEmbeddings

NEExT.embeddings.GNNEmbeddings — pure-PyTorch GNN embeddings (requires the gnn extra). Used directly, it exposes a few parameters that the framework wrapper does not: train_ratio, val_ratio, device, and verbose.

GNNEmbeddings(
  graph_collection,
  features,
  architecture="GCN",           # "GCN" | "GraphSAGE" | "GIN"
  embedding_dimension=16,
  random_state=42,
  hidden_dims=None,             # default [64, 32]
  epochs=100,
  learning_rate=0.01,
  weight_decay=5e-4,
  dropout=0.0,
  early_stopping_patience=10,
  train_ratio=0.8,
  val_ratio=0.1,
  pooling="mean",               # "mean" | "sum" | "max"
  device="cpu",                 # "cpu" | "cuda"
  verbose=False,
)

compute() -> Embeddings
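The pooling parameter controls how per-node embeddings are collapsed into one vector per graph. The sketch below shows the three readouts on plain Python lists; it assumes node embeddings arrive as equal-length lists and only illustrates the arithmetic, not GNNEmbeddings' PyTorch implementation (pool is a hypothetical helper).

```python
def pool(node_embeddings, pooling="mean"):
    """Collapse a graph's node embeddings (equal-length lists) into one graph vector."""
    if not node_embeddings:
        raise ValueError("graph has no nodes")
    columns = list(zip(*node_embeddings))  # one tuple per embedding dimension
    if pooling == "mean":
        return [sum(c) / len(c) for c in columns]
    if pooling == "sum":
        return [sum(c) for c in columns]
    if pooling == "max":
        return [max(c) for c in columns]
    raise ValueError(f"unknown pooling: {pooling!r}")


nodes = [[1.0, 4.0], [3.0, 0.0]]  # two nodes, 2-dimensional embeddings
print(pool(nodes, "mean"))  # [2.0, 2.0]
print(pool(nodes, "sum"))   # [4.0, 4.0]
print(pool(nodes, "max"))   # [3.0, 4.0]
```

sum keeps information about graph size (duplicating every node doubles the vector), whereas mean and max are unchanged by such duplication.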

Embeddings

NEExT.embeddings.Embeddings — a container around the embeddings DataFrame.

Embeddings(embeddings_df, embedding_name, embedding_columns)

# attributes
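embeddings.embeddings_df       # DataFrame with columns graph_id, emb_0, ..., emb_{D-1} (D = embedding dimension)
embeddings.embedding_name      # name of the algorithm that produced the embeddings
embeddings.embedding_columns   # List[str]: names of the embedding columns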

# methods
emb_a + emb_b                  # merge on graph_id, prefixing columns by algorithm -> Embeddings
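A minimal sketch of what emb_a + emb_b does, using plain dicts in place of DataFrames. It assumes an inner join on graph_id and a column prefix of the form "<embedding_name>_<column>"; the exact join type and prefix format in NEExT may differ, and merge_embeddings is a hypothetical helper.

```python
def merge_embeddings(a, name_a, b, name_b):
    """Inner-join two {graph_id: {column: value}} tables on graph_id.

    Columns are prefixed with the producing algorithm's name so that
    emb_0 from one algorithm cannot collide with emb_0 from another.
    """
    merged = {}
    for graph_id in sorted(a.keys() & b.keys()):  # keep only graph_ids present in both
        row = {f"{name_a}_{col}": val for col, val in a[graph_id].items()}
        row.update({f"{name_b}_{col}": val for col, val in b[graph_id].items()})
        merged[graph_id] = row
    return merged


wass = {0: {"emb_0": 0.1, "emb_1": 0.2}, 1: {"emb_0": 0.3, "emb_1": 0.4}}
gcn = {0: {"emb_0": 1.5}, 2: {"emb_0": 2.5}}
print(merge_embeddings(wass, "wasserstein", gcn, "GCN"))
# {0: {'wasserstein_emb_0': 0.1, 'wasserstein_emb_1': 0.2, 'GCN_emb_0': 1.5}}
```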

See Embeddings.