The project code currently lives in Colab, since I want to practice implementing things in JAX.

  • Batches are just for efficiency: if we train on only one sample, the data takes up very little space on the GPU. By adding more samples to process “in parallel”, we are effectively filling the available GPU space.
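To make the batching point concrete, here is a minimal JAX sketch. The toy linear model and the names `predict` and `batched_predict` are assumptions made up for illustration, not the project's actual code. It shows one common way to process several samples in a single call, using `jax.vmap` to map a per-sample function over a leading batch axis.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical per-sample model: x has shape (features,), output is a scalar.
    return jnp.dot(w, x)


# Map predict over axis 0 of x while sharing the same weights w.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones(4)
single = jnp.arange(4.0)                              # one sample, shape (4,)
batch = jnp.stack([single, single * 2, single * 3])   # three samples, shape (3, 4)

one_out = predict(w, single)           # scalar result for a single sample
batch_out = batched_predict(w, batch)  # shape (3,), one result per sample
```

The batched call computes the same per-sample results as calling `predict` three times, but as one array operation, which is what lets the accelerator stay busy.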