Entity Embedding in fastai 2
This post aims to introduce how to use fastai v2 to implement entity embedding for categorical variables in tabular data. Entity embedding is a powerful technique that can sometimes boost the performance of various machine learning methods and reveal the intrinsic properties of categorical variables. See this paper and this post for more details.
import seaborn as sns
from fastai.tabular.all import *
We will use the California housing dataset for this post. We want to develop a model to predict the median_house_value based on other variables in the dataset. The following code shows the first 5 rows of the dataset.
df = pd.read_csv('housing.csv')
df.head()
We rely on various functionalities provided by fastai to preprocess the data. We will not discuss them in detail here. For more information, please check this tutorial. One important note is that when dealing with categorical variables, instead of using one-hot encoding, we map each category into a distinct integer.
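To make the difference concrete, here is a minimal sketch (using a toy column, not the housing data) of one-hot encoding versus the kind of integer mapping that Categorify performs:

```python
import pandas as pd

# Toy stand-in for a categorical column (hypothetical values)
col = pd.Series(['INLAND', 'NEAR BAY', 'INLAND', 'ISLAND'], dtype='category')

# One-hot encoding: one binary column per distinct category
one_hot = pd.get_dummies(col)

# Integer encoding, roughly what Categorify does: one integer code per row
codes = col.cat.codes
print(one_hot.shape)  # (4, 3)
print(list(codes))    # [0, 2, 0, 1]
```

The integer codes are what the embedding layer consumes: each code indexes one row of the embedding matrix.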
cont, cat = cont_cat_split(df, dep_var='median_house_value')
splits = RandomSplitter(valid_pct=0.2)(range_of(df))
to = TabularPandas(df, procs=[Categorify, FillMissing, Normalize],
                   cat_names=cat,
                   cont_names=cont,
                   y_names='median_house_value',
                   splits=splits)
dls = to.dataloaders(bs=64)
We will train a deep learning model to predict the median house value and thus obtain trained embeddings for the categorical variables. Again, see this tutorial for details about the meaning of the code here.
learn = tabular_learner(dls, metrics=rmse)
early_stop_cb = EarlyStoppingCallback(patience=2)
learn.fit_one_cycle(10, cbs=early_stop_cb)
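The EarlyStoppingCallback halts training once the monitored metric stops improving. A rough sketch of that logic in plain Python (a simplification, not fastai's actual implementation):

```python
def stop_epoch(valid_losses, patience=2):
    """Return the epoch at which training would stop: the first epoch after
    the validation loss has failed to improve for `patience` epochs."""
    best, wait = float('inf'), 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(valid_losses) - 1

# The loss improves through epoch 1, then stalls; with patience=2 we stop at epoch 3
print(stop_epoch([3.0, 2.0, 2.5, 2.6, 2.7], patience=2))  # 3
```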
We will retrieve the trained embedding matrix from the model.
embs = [param for param in learn.model.embeds.parameters()]
The list has two elements, one embedding matrix per categorical variable. Besides ocean_proximity, FillMissing added a total_bedrooms_na indicator column (total_bedrooms contains missing values), which is also treated as categorical.
len(embs)
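fastai chooses the width of each embedding with a rule of thumb based on the number of categories; to the best of my knowledge, fastai v2 uses min(600, round(1.6 * n**0.56)):

```python
def emb_sz_rule(n_cat):
    # fastai's default heuristic for the embedding width of a variable
    # with n_cat categories (including the #na# placeholder)
    return min(600, round(1.6 * n_cat ** 0.56))

# ocean_proximity has 5 categories plus #na#, so 6 rows in its embedding matrix
print(emb_sz_rule(6))  # 4
print(emb_sz_rule(2))  # 2
```

You can always confirm the actual sizes in your run by inspecting the shapes, e.g. `[e.shape for e in embs]`.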
To check what each element corresponds to, we can use the cat_names attribute of the TabularPandas object. The list indicates that the first element in embs is the embedding matrix for the variable ocean_proximity.
to.cat_names
Let's see this matrix. Note that we convert it from a PyTorch tensor to a NumPy array to make later operations easier.
ocean_emb = embs[0].detach().numpy()
ocean_emb
Each row in the matrix above corresponds to one category in the categorical variable. The categories for ocean_proximity are the following.
cat = to.procs.categorify
ocean_cat = cat['ocean_proximity']
ocean_cat
We can use the o2i attribute to see how each category is mapped to an integer. The integer corresponds to the row index in the embedding matrix.
ocean_cat.o2i
For example, the embedding for 'NEAR BAY' is:
ocean_emb[4]
Let's create a dictionary to map each category to its corresponding embedding.
ocean_emb_map = {ocean_cat[i]:ocean_emb[i] for i in range(len(ocean_cat))}
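With the dictionary in place, looking up a category returns its row of the embedding matrix. Here is a self-contained toy version, with random numbers standing in for the trained weights and a hypothetical category list (Categorify puts the '#na#' placeholder first):

```python
import numpy as np

# Hypothetical category order mimicking Categorify's output
categories = ['#na#', '<1H OCEAN', 'INLAND', 'ISLAND', 'NEAR BAY', 'NEAR OCEAN']
# Random stand-in for a trained 6 x 4 embedding matrix
emb = np.random.default_rng(0).normal(size=(len(categories), 4))

emb_map = {c: emb[i] for i, c in enumerate(categories)}
# Looking up a category returns the matching row of the matrix
print(np.array_equal(emb_map['NEAR BAY'], emb[4]))  # True
```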
This section shows how we can apply the embedding to the original dataset; that is, replace the categorical variable with numeric vectors.
emb_dim = ocean_emb.shape[1]
col_name = [f'ocean_emb_{i}' for i in range(1,emb_dim+1)]
df_emb = pd.DataFrame(df['ocean_proximity'].map(ocean_emb_map).to_list(), columns=col_name)
df_emb.head()
df_new = pd.concat([df, df_emb],axis=1)
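One detail worth noting: df_new still contains the raw ocean_proximity column, so you may want to drop it before feeding the frame to models that expect numeric input only. A toy sketch of the same expansion with hypothetical 2-dimensional embeddings:

```python
import numpy as np
import pandas as pd

# Hypothetical 2-dim embeddings for two categories
emb_map = {'INLAND': np.array([0.1, 0.2]), 'NEAR BAY': np.array([0.3, 0.4])}
df_toy = pd.DataFrame({'ocean_proximity': ['INLAND', 'NEAR BAY', 'INLAND']})

cols = [f'ocean_emb_{i}' for i in range(1, 3)]
df_emb = pd.DataFrame(df_toy['ocean_proximity'].map(emb_map).to_list(),
                      columns=cols)

# Concatenate the embedding columns and drop the original categorical column
df_num = pd.concat([df_toy, df_emb], axis=1).drop(columns='ocean_proximity')
print(df_num.shape)  # (3, 2)
```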
Another way to explore the embedding matrix is to visualize it and see what it has learned. We will use principal component analysis (PCA) to visualize the embedding.
from sklearn import decomposition
pca = decomposition.PCA(n_components=2)
pca_result = pca.fit_transform(ocean_emb)
df_visualize = pd.DataFrame({'name': ocean_cat, 'dim1': pca_result[:,0],
                             'dim2': pca_result[:,1]})
sns.scatterplot(data=df_visualize, x='dim1', y='dim2', hue='name');
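Before reading too much into the 2-D plot, it is worth checking how much of the embedding's variance the two principal components actually retain; PCA's explained_variance_ratio_ reports this. A self-contained sketch, with a random matrix standing in for ocean_emb:

```python
import numpy as np
from sklearn import decomposition

# Random stand-in for a 6 x 4 trained embedding matrix
emb = np.random.default_rng(42).normal(size=(6, 4))

pca = decomposition.PCA(n_components=2)
proj = pca.fit_transform(emb)
print(proj.shape)                           # (6, 2)
# Fraction of the total variance kept by the 2-D projection
print(pca.explained_variance_ratio_.sum())
```

If this fraction is low, distances in the scatter plot may be misleading.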
If we compare this visualization to the map of these houses, we will find that the embedding matrix does provide some useful insights. In particular, the relative location of the category INLAND corresponds to the inland area on the map. The code to make the map comes from this notebook.
This is the end of this post. You can see the full code by clicking the View on GitHub or Open in Colab link at the top of this post.