Helper functions
- cebra.integrations.sklearn.helpers.align_embeddings(embeddings, labels, normalize=True, n_bins=100)
Align the embeddings in the `embeddings` list to the `labels`.
Each embedding has an associated set of labels. During alignment, the labels are digitized so that all sets of digitized labels contain the same set of values. The embeddings are then quantized based on these new digitized labels.
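The digitization step can be illustrated with NumPy. This is a minimal sketch, assuming equal-width bins that span the global minimum and maximum of all label sets; the label arrays, the bin count, and the helper name `digitize` are hypothetical, and CEBRA's internal binning may differ. It shows how two label sets with different ranges end up as indices drawn from one common set of values.

```python
import numpy as np

# Two hypothetical label sets (e.g. positions from two recordings)
# covering slightly different ranges.
labels_a = np.array([0.0, 0.2, 0.5, 0.9])
labels_b = np.array([0.1, 0.4, 0.6, 1.0])

n_bins = 4
# Assumption: shared, equal-width bin edges over the global label range.
lo = min(labels_a.min(), labels_b.min())
hi = max(labels_a.max(), labels_b.max())
edges = np.linspace(lo, hi, n_bins + 1)


def digitize(labels):
    # np.digitize returns 1-based bin indices; clipping keeps the global
    # maximum in the last bin instead of an overflow bin, and subtracting 1
    # gives indices in range(n_bins).
    return np.clip(np.digitize(labels, edges), 1, n_bins) - 1


print(digitize(labels_a))  # [0 0 2 3]
print(digitize(labels_b))  # [0 1 2 3]
```

Both label sets now take values in {0, 1, 2, 3}, so samples from different embeddings that share a bin index can be treated as corresponding.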
- Parameters:
  - embeddings (List[Union[numpy.ndarray, torch.Tensor]]) – List of embeddings to align on the labels.
  - labels (List[Union[numpy.ndarray, torch.Tensor]]) – List of labels, one per embedding, used to align the embeddings with each other.
  - normalize (bool) – If True, samples of the embeddings are normalized across dimensions.
  - n_bins (int) – Number of values for the digitized common labels.
- Return type:
  List[Union[numpy.ndarray, torch.Tensor]]
- Returns:
  The embeddings aligned on the labels.
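For intuition, the whole procedure can be approximated in a few lines of NumPy. This is a hypothetical sketch, not CEBRA's implementation: `align_embeddings_sketch` is an illustrative name, and it assumes equal-width shared bins over the global label range, averaging of the embedding samples that fall in each bin, NaN rows for empty bins, and (as one reading of "normalized across dimensions") unit L2 normalization of each binned sample when `normalize=True`.

```python
import numpy as np


def align_embeddings_sketch(embeddings, labels, normalize=True, n_bins=100):
    """Hypothetical illustration of label-based alignment; not CEBRA's code."""
    if len(embeddings) != len(labels):
        raise ValueError("Expected exactly one label array per embedding.")
    labels = [np.asarray(lab, dtype=float).ravel() for lab in labels]

    # Assumption: shared, equal-width bins over the global label range.
    lo = min(lab.min() for lab in labels)
    hi = max(lab.max() for lab in labels)
    edges = np.linspace(lo, hi, n_bins + 1)

    aligned = []
    for emb, lab in zip(embeddings, labels):
        emb = np.asarray(emb, dtype=float)
        # Digitize into 0-based bin indices shared by all label sets.
        bins = np.clip(np.digitize(lab, edges), 1, n_bins) - 1
        # Quantize: average the samples in each bin -> shape (n_bins, dim).
        # Assumption: bins with no samples stay NaN.
        out = np.full((n_bins, emb.shape[1]), np.nan)
        for b in np.unique(bins):
            out[b] = emb[bins == b].mean(axis=0)
        if normalize:
            # Assumed reading: scale each binned sample (row) to unit L2 norm.
            norms = np.linalg.norm(out, axis=1, keepdims=True)
            out = out / np.where(norms > 0, norms, 1.0)
        aligned.append(out)
    return aligned


# Usage: two hypothetical recordings of different lengths.
rng = np.random.default_rng(0)
emb_a, emb_b = rng.normal(size=(500, 3)), rng.normal(size=(800, 3))
lab_a, lab_b = rng.uniform(0, 1, 500), rng.uniform(0, 1, 800)
a, b = align_embeddings_sketch([emb_a, emb_b], [lab_a, lab_b], n_bins=10)
print(a.shape, b.shape)  # (10, 3) (10, 3)
```

Because every aligned embedding has one row per common bin, embeddings recorded with different numbers of samples end up with matching shapes and can be compared row by row.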