Contains methods for transferring embedding weights from one model to another.
import os

We'll create collections of Embedding layers, which will be used to test our transfer methods.

emb_szs1 = ((3, 10), (2, 8))
emb_szs2 = ((2, 10), (2, 8))
embed1 = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs2])
embed1
ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

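Next, we'll define the metadata the transfer needs: the categorical column names and the classes of each column, for both the new (destination) and the old (source) data.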

newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

newcatdict = {"new_cat1" : ["new_class1", "new_class2", "new_class3"], "new_cat2" : ["new_class1", "new_class2"]}
oldcatdict = {"old_cat2" : ["a", "b"], "old_cat3" : ["A", "B"]}
json_file_path = "../data/jsons/metadict.json"

with open(json_file_path, 'r') as j:
     metadict = json.loads(j.read())

metadict is a dict whose keys are the categorical columns in the destination model's data. Each value is another dict in which mapped_cat names the corresponding categorical column in the source model's data, and classes_info maps every class of the destination column to the list of source classes it corresponds to.

metadict
{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}
df = pd.DataFrame({"old_cat1": [1, 2, 3, 4, 5], "old_cat2": ['a', 'b', 'b', 'b', 'a'], "old_cat3": ['A', 'B', 'B', 'B', 'A']})
cats = ("old_cat2", "old_cat3")
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats, path="tempwtbson")

get_metadict_skeleton[source]

get_metadict_skeleton(df:DataFrame, catcols=None, path=None)

get_metadict_skeleton(df)
{'old_cat2': {'mapped_cat': '', 'classes_info': {'a': [], 'b': []}},
 'old_cat3': {'mapped_cat': '', 'classes_info': {'A': [], 'B': []}}}
transferembeds_
(Module,Module) -> transferembeds_
(Module,dict) -> transferembeds_
(Module,PosixPath) -> transferembeds_
(Module,str) -> transferembeds_

Embeddings before transfer:

embed1.state_dict()
OrderedDict([('0.weight',
              tensor([[ 0.1291,  0.9989, -0.1258,  0.4697, -1.8180, -0.4062,  0.7807, -2.4058,
                        2.2032,  1.8388],
                      [ 0.1456,  0.2293,  0.2135,  0.4504, -1.4981, -0.2788,  0.9045,  0.1295,
                       -0.9927, -0.0125],
                      [-1.6132,  0.8939, -0.2192, -0.7470, -0.5318, -2.4357, -0.0404, -0.8680,
                        0.2412,  1.8898]])),
             ('1.weight',
              tensor([[-0.0732, -1.5366,  0.6748,  1.9617, -0.7229,  1.9168, -0.2036, -0.5741],
                      [ 0.6987, -1.2535, -0.2394, -0.3216,  0.9821,  1.1238,  2.2877, -0.7127]]))])
embed2.state_dict()
OrderedDict([('0.weight',
              tensor([[-1.7019, -0.6068,  1.3590,  3.2759,  1.5049,  1.1870,  0.5087,  0.6172,
                        0.0863, -0.5930],
                      [-1.8109, -0.6033,  1.1796,  0.6103,  0.6482,  1.4825, -1.3552,  1.0069,
                        0.1493, -1.3304]])),
             ('1.weight',
              tensor([[ 2.3631,  0.1950, -1.3559,  0.0663,  0.1289,  0.5940, -0.0549,  0.2415],
                      [ 0.6715, -1.1929, -0.2372, -1.3345, -1.2651, -0.0468, -0.0934, -2.2118]]))])
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, embdict, metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, embed2, metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, oldcatdict=oldcatdict, newcatdict=newcatdict)
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, pathlib.Path("tempwtbson"), metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)

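Embeddings after transfer (in the output below, each row of embed1 equals the mean of the embed2 rows for the source classes it is mapped to; a class with an empty mapping receives the mean over all source rows, and embed2 itself is unchanged):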

embed1.state_dict()
OrderedDict([('0.weight',
              tensor([[-1.7564, -0.6051,  1.2693,  1.9431,  1.0765,  1.3347, -0.4232,  0.8120,
                        0.1178, -0.9617],
                      [-1.8109, -0.6033,  1.1796,  0.6103,  0.6482,  1.4825, -1.3552,  1.0069,
                        0.1493, -1.3304],
                      [-1.7564, -0.6051,  1.2693,  1.9431,  1.0765,  1.3347, -0.4232,  0.8120,
                        0.1178, -0.9617]])),
             ('1.weight',
              tensor([[ 2.3631,  0.1950, -1.3559,  0.0663,  0.1289,  0.5940, -0.0549,  0.2415],
                      [ 1.5173, -0.4990, -0.7966, -0.6341, -0.5681,  0.2736, -0.0741, -0.9852]]))])
embed2.state_dict()
OrderedDict([('0.weight',
              tensor([[-1.7019, -0.6068,  1.3590,  3.2759,  1.5049,  1.1870,  0.5087,  0.6172,
                        0.0863, -0.5930],
                      [-1.8109, -0.6033,  1.1796,  0.6103,  0.6482,  1.4825, -1.3552,  1.0069,
                        0.1493, -1.3304]])),
             ('1.weight',
              tensor([[ 2.3631,  0.1950, -1.3559,  0.0663,  0.1289,  0.5940, -0.0549,  0.2415],
                      [ 0.6715, -1.1929, -0.2372, -1.3345, -1.2651, -0.0468, -0.0934, -2.2118]]))])
os.remove("tempwtbson")

Export