Contains helper functions for extracting embeddings and for preparing the data needed to extract them.
import os
emb_szs = ((3, 10), (4, 8))
embed = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])
embed
ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(4, 8)
)

class JSONizerWithBool[source]

JSONizerWithBool(skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None) :: JSONEncoder

Extensible JSON <http://json.org> encoder for Python data structures.

Supports the following objects and types by default:

+-------------------+---------------+
| Python            | JSON          |
+===================+===============+
| dict              | object        |
+-------------------+---------------+
| list, tuple       | array         |
+-------------------+---------------+
| str               | string        |
+-------------------+---------------+
| int, float        | number        |
+-------------------+---------------+
| True              | true          |
+-------------------+---------------+
| False             | false         |
+-------------------+---------------+
| None              | null          |
+-------------------+---------------+

To extend this to recognize other objects, subclass and implement a
``.default()`` method that returns a serializable object for ``o`` if
possible; otherwise it should call the superclass implementation (to
raise ``TypeError``).
df = pd.DataFrame({"cat1": [1, 2, 3, 4, 5], "cat2": ['a', 'b', 'c', 'b', 'a'], "cat3": ['A', 'B', 'C', 'D', 'A']})
df
cat1 cat2 cat3
0 1 a A
1 2 b B
2 3 c C
3 4 b D
4 5 a A
catdict = getcatdict(df, ("cat2", "cat3"))
cats = ("cat2", "cat3")
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtbson", kind="bson")
embdict
{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.31570491194725037,
    -0.07632226496934891,
    1.5683248043060303,
    -0.417350172996521,
    -0.10798821598291397,
    1.4268646240234375,
    -0.22982962429523468,
    -0.16915012896060944,
    0.002859442261978984,
    -0.4939035475254059],
   [0.6530274748802185,
    -0.5577511191368103,
    -0.9275949001312256,
    -0.06805138289928436,
    -2.2739336490631104,
    0.1566399186849594,
    -0.0531904362142086,
    -0.43463948369026184,
    -0.0794961154460907,
    0.4645240008831024],
   [1.0870261192321777,
    -0.22893156111240387,
    -0.253396600484848,
    -0.3393022119998932,
    -2.0341274738311768,
    -0.31127995252609253,
    0.3499477803707123,
    -1.9891204833984375,
    0.674164891242981,
    -1.3391718864440918]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[1.3585036993026733,
    0.024397719651460648,
    0.4804745614528656,
    1.1160022020339966,
    0.8734705448150635,
    0.784949004650116,
    -0.5678505897521973,
    0.33350786566734314],
   [0.3660596013069153,
    -0.9798707962036133,
    -0.2037343829870224,
    -0.22464202344417572,
    -1.0697559118270874,
    -0.6113787889480591,
    -0.9179865717887878,
    -0.8937533497810364],
   [1.278954029083252,
    0.18886688351631165,
    1.2901731729507446,
    0.5247588157653809,
    1.2530524730682373,
    -0.898102343082428,
    1.1512700319290161,
    1.5226550102233887],
   [-0.6129401922225952,
    -0.38670244812965393,
    0.7002972364425659,
    -1.2176426649093628,
    0.5013972520828247,
    -1.5657707452774048,
    -2.1267337799072266,
    0.5773623585700989]]}}
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtjson", kind="json")
embdict
{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.31570491194725037,
    -0.07632226496934891,
    1.5683248043060303,
    -0.417350172996521,
    -0.10798821598291397,
    1.4268646240234375,
    -0.22982962429523468,
    -0.16915012896060944,
    0.002859442261978984,
    -0.4939035475254059],
   [0.6530274748802185,
    -0.5577511191368103,
    -0.9275949001312256,
    -0.06805138289928436,
    -2.2739336490631104,
    0.1566399186849594,
    -0.0531904362142086,
    -0.43463948369026184,
    -0.0794961154460907,
    0.4645240008831024],
   [1.0870261192321777,
    -0.22893156111240387,
    -0.253396600484848,
    -0.3393022119998932,
    -2.0341274738311768,
    -0.31127995252609253,
    0.3499477803707123,
    -1.9891204833984375,
    0.674164891242981,
    -1.3391718864440918]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[1.3585036993026733,
    0.024397719651460648,
    0.4804745614528656,
    1.1160022020339966,
    0.8734705448150635,
    0.784949004650116,
    -0.5678505897521973,
    0.33350786566734314],
   [0.3660596013069153,
    -0.9798707962036133,
    -0.2037343829870224,
    -0.22464202344417572,
    -1.0697559118270874,
    -0.6113787889480591,
    -0.9179865717887878,
    -0.8937533497810364],
   [1.278954029083252,
    0.18886688351631165,
    1.2901731729507446,
    0.5247588157653809,
    1.2530524730682373,
    -0.898102343082428,
    1.1512700319290161,
    1.5226550102233887],
   [-0.6129401922225952,
    -0.38670244812965393,
    0.7002972364425659,
    -1.2176426649093628,
    0.5013972520828247,
    -1.5657707452774048,
    -2.1267337799072266,
    0.5773623585700989]]}}
load_bson("tempwtbson") == embdict
True
os.remove("tempwtbson")
os.remove("tempwtjson")

Export