"""Helper functions for extracting embeddings and for preparing the data they are extracted from."""
import os
# Smoke test: build a small embedding stack and a categorical frame, export the
# embedding weights in both supported formats ("bson" and "json"), and check
# that the BSON file loads back equal to the extracted embedding dict.

# (num_embeddings, embedding_dim) for each categorical column, in order.
emb_szs = ((3, 10), (4, 8))
embed = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])

df = pd.DataFrame(
    {
        "cat1": [1, 2, 3, 4, 5],
        "cat2": ["a", "b", "c", "b", "a"],
        "cat3": ["A", "B", "C", "D", "A"],
    }
)

# Define the categorical columns once and reuse them everywhere below.
cats = ("cat2", "cat3")
catdict = getcatdict(df, cats)

bson_path, json_path = "tempwtbson", "tempwtjson"
try:
    embdict = extractembeds(
        embed, df, transfercats=cats, allcats=cats, path=bson_path, kind="bson"
    )
    embdict = extractembeds(
        embed, df, transfercats=cats, allcats=cats, path=json_path, kind="json"
    )
    # Assert instead of merely evaluating the comparison, so a broken BSON
    # round trip is actually reported.
    assert load_bson(bson_path) == embdict, "BSON round trip changed the embeddings"
finally:
    # Always clean up the temp files, even if a step above raised. A file that
    # was never written is not an error here, and it must not mask the
    # original exception.
    for tmp_path in (bson_path, json_path):
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass