```sh
pip install transfertab
```
TransferTab enables effective transfer learning from models trained on tabular data.

To make use of transfertab, you'll need:
* A PyTorch model which contains some embeddings in a layer group.
* Another model to transfer these embeddings to, along with metadata about the dataset on which this model will be trained.
The whole process takes place in two main steps:
1. Extraction
2. Transfer
## Extraction
This involves storing the embeddings present in the model in a JSON structure. The JSON contains the embeddings for the categorical variables, and can later be transferred to another model that can also benefit from these categories. It is also possible to construct multiple JSON files from various models with different categorical variables and then use them together.
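As a rough illustration of the idea (the key names here are assumptions, not transfertab's actual schema), the stored structure maps each categorical variable to a per-class embedding vector, which serializes naturally to JSON:

```python
import json

# Hedged sketch of a JSON-serializable embedding store: category -> class -> vector.
# The weights and class names below are made up for illustration.
weights = {"old_cat2": [[0.1, 0.2], [0.3, 0.4]]}   # one row per class
classes = {"old_cat2": ["a", "b"]}

embjson = {
    cat: dict(zip(classes[cat], rows))   # pair each class name with its vector
    for cat, rows in weights.items()
}
print(json.dumps(embjson))
```

Because everything is plain lists and strings, such a structure round-trips through `json.dumps`/`json.loads` without loss.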
Here we'll quickly construct a `ModuleList` with a couple of `Embedding` layers, and see how to transfer its embeddings.
```python
import torch.nn as nn

emb_szs1 = ((3, 10), (2, 8))
emb_szs2 = ((2, 10), (2, 8))
embed1 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs2])
embed1
```
We can call the `extractembeds` function to extract the embeddings. Take a look at the documentation for other dispatch methods and details on the parameters.
```python
import pandas as pd

df = pd.DataFrame({
    "old_cat1": [1, 2, 3, 4, 5],
    "old_cat2": ['a', 'b', 'b', 'b', 'a'],
    "old_cat3": ['A', 'B', 'B', 'B', 'A'],
})
cats = ("old_cat2", "old_cat3")
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats)
embdict
```
## Transfer
The transfer process involves reusing trained parameters, either from the extracted weights or from a model directly. We define how this takes place using the `metadict`, which is a mapping of all the categories in the current dataset; for each one, it records the category it maps to (from the previous dataset, used to train the old model) and how the new classes map to the old classes. We can even map multiple old classes to a single new one, in which case the `aggfn` parameter is used to aggregate the embedding vectors.
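To make the shape of this mapping concrete, here is a hypothetical `metadict` for the categories used below. Treat the key names (`mapped_cat`, `classmap`) as assumptions for illustration; consult the transfertab documentation for the actual schema:

```python
# Hypothetical metadict (key names are illustrative assumptions): each new
# category records the old category it maps to, and how its new classes map
# to old classes. A list with several old classes, as for "new_class3",
# would be aggregated into one vector by `aggfn`.
metadict = {
    "new_cat1": {
        "mapped_cat": "old_cat2",
        "classmap": {
            "new_class1": ["a"],
            "new_class2": ["b"],
            "new_class3": ["a", "b"],  # many-to-one mapping, aggregated
        },
    },
    "new_cat2": {
        "mapped_cat": "old_cat3",
        "classmap": {"new_class1": ["A"], "new_class2": ["B"]},
    },
}
```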
```python
import json

json_file_path = "../data/jsons/metadict.json"
with open(json_file_path, 'r') as j:
    metadict = json.loads(j.read())
metadict
```
Let's look at the layer parameters before and after the transfer to check that it worked as expected.
```python
embed1.state_dict()
embed2.state_dict()
```
```python
transfer_cats = ("new_cat1", "new_cat2")
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")
newcatdict = {
    "new_cat1": ["new_class1", "new_class2", "new_class3"],
    "new_cat2": ["new_class1", "new_class2"],
}
oldcatdict = {"old_cat2": ["a", "b"], "old_cat3": ["A", "B"]}

transferembeds_(embed1, embdict, metadict, transfer_cats,
                newcatcols=newcatcols, oldcatcols=oldcatcols,
                newcatdict=newcatdict)
embed1.state_dict()
```
As we can see, the embeddings have been transferred over.
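For a new class mapped to several old classes, we can check the result by hand. This hedged, pure-Python sketch shows what a mean-style `aggfn` would produce for a class mapped to both `"a"` and `"b"` (the vectors and the choice of mean are illustrative assumptions):

```python
# Hedged sketch: a new class mapped to two old classes ("a" and "b") would,
# under a mean aggfn, receive the element-wise average of their vectors.
vec_a = [0.1, 0.2, 0.3]   # made-up embedding vector for old class "a"
vec_b = [0.5, 0.6, 0.7]   # made-up embedding vector for old class "b"
aggregated = [sum(pair) / len(pair) for pair in zip(vec_a, vec_b)]
print(aggregated)
```

Comparing such a hand-computed vector against the corresponding row of `embed1.state_dict()` is a quick sanity check that the many-to-one aggregation behaved as intended.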