Using a TPU VM instance w/ the pre-alpha timm bits setup as per: https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits#readme
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet --config hparams.yaml
Note that the config yaml files contain args that are unused or inactive, either because they are overridden elsewhere or because of the state of the current training code. The bits
code is under heavy development, so these configs will likely need revision for a specific commit (currently https://github.com/rwightman/pytorch-image-models/commit/5e95ced5a7763541f7219f35fd155e3fbfe66e8b).
The gMLP hparams are the last (latest) in the series and will likely produce better results than the earlier gmixer / resmlp variants.
Note, for adapting the LR to a different batch size: AdamW is being used here, and I use sqrt scaling of the learning rate w.r.t. the (global) batch size. I typically use linear LR scaling w/ SGD or RMSProp for most from-scratch training.
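A minimal sketch of that scaling, assuming you have a reference LR tuned at some reference global batch size (the helper and the example numbers below are hypothetical, not taken from these configs):

```python
import math

def scale_lr(base_lr, base_batch_size, global_batch_size, how="sqrt"):
    """Scale a reference LR to a new global batch size.

    "sqrt" scaling is what is described above for AdamW; "linear" is the
    usual choice for SGD / RMSProp from-scratch training.
    """
    ratio = global_batch_size / base_batch_size
    if how == "sqrt":
        return base_lr * math.sqrt(ratio)
    if how == "linear":
        return base_lr * ratio
    raise ValueError(f"unknown scaling: {how}")

# e.g. a config tuned at lr=1e-3 / batch 512, run at global batch 2048:
# scale_lr(1e-3, 512, 2048, "sqrt") -> 2e-3  (linear would give 4e-3)
```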