PKM achieved faster training than a Top-K Transcoder with the same number of latents: the reduced encoder parameter count effectively gives the encoder a lower expansion ratio.
The input is projected into two small sub-encoders, each of which selects its own Top-K activations. Candidate latents are formed from the combinations (the Cartesian product) of the two Top-K sets, a second Top-K is taken over these candidates, and the winning indices are mapped into the decoder to perform reconstruction in the sparse (latent) space. Because the number of addressable latents grows as the square of the sub-encoder size, the combination step cuts encoder parameters dramatically even though there are now two encoders, which speeds up training, and the natural grouping of latents by shared sub-keys slightly improves interpretability.
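A minimal PyTorch sketch of this two-sub-encoder selection, assuming half-key scores are combined by summation (the class name `ProductKeyTopKEncoder`, the combination rule, and all shapes are illustrative assumptions, not the original implementation):

```python
import torch
import torch.nn as nn


class ProductKeyTopKEncoder(nn.Module):
    """Sketch: two small sub-encoders address n_sub ** 2 latents."""

    def __init__(self, d_model: int, n_sub: int, k: int):
        # n_sub half-keys per sub-encoder -> n_sub ** 2 total latents,
        # but only 2 * n_sub * d_model encoder parameters.
        super().__init__()
        self.k = k
        self.n_sub = n_sub
        self.enc_a = nn.Linear(d_model, n_sub, bias=False)
        self.enc_b = nn.Linear(d_model, n_sub, bias=False)

    def forward(self, x: torch.Tensor):
        # x: (batch, d_model); each sub-encoder keeps its own Top-K.
        scores_a, idx_a = self.enc_a(x).topk(self.k, dim=-1)  # (batch, k)
        scores_b, idx_b = self.enc_b(x).topk(self.k, dim=-1)  # (batch, k)

        # Score all k * k candidate pairs (here: sum of half-key scores),
        # then take a second Top-K over the candidates.
        cand = scores_a.unsqueeze(-1) + scores_b.unsqueeze(-2)      # (batch, k, k)
        cand_vals, cand_idx = cand.flatten(-2).topk(self.k, dim=-1)  # (batch, k)

        # Map each winning pair back to a flat latent index in [0, n_sub ** 2).
        row = idx_a.gather(-1, cand_idx // self.k)
        col = idx_b.gather(-1, cand_idx % self.k)
        latent_idx = row * self.n_sub + col  # (batch, k)
        return cand_vals, latent_idx


# Usage: the k selected values and indices feed a sparse decoder, e.g. by
# gathering k columns from a (d_model, n_sub ** 2) decoder weight matrix.
enc = ProductKeyTopKEncoder(d_model=512, n_sub=256, k=32)  # 65,536 latents
vals, idx = enc(torch.randn(4, 512))
```

For comparison, a dense Top-K encoder over the same 65,536 latents would need `65,536 * d_model` encoder parameters, versus `2 * 256 * d_model` here, which is where the training speedup comes from.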
However, there's no particular reason why it has to be specifically two sub-encoders.