Jascha Sohl-Dickstein is a senior staff research scientist at Google Brain, where he leads a research team with interests spanning machine learning, physics, and neuroscience. Recent projects have focused on the theory of overparameterized neural networks, meta-training of learned optimizers, and understanding the capabilities of large language models. Jascha was previously a visiting scholar in Surya Ganguli’s lab at Stanford, and an academic resident at Khan Academy. He earned his PhD in 2012 in Bruno Olshausen’s lab at the Redwood Center for Theoretical Neuroscience at UC Berkeley. Prior to his PhD, he spent several years working for NASA on the Mars Exploration Rover mission.
The success of deep learning has hinged on learned functions dramatically outperforming hand-designed functions for many tasks. However, we still train models using hand-designed optimizers acting on hand-designed loss functions. Jascha will argue that these hand-designed components are typically mismatched to the desired behavior, and that we can expect meta-learned optimizers to perform much better. He will discuss the challenges and pathologies that make meta-training learned optimizers difficult, including: chaotic and high-variance meta-loss landscapes; extreme computational costs of meta-training; the lack of comprehensive meta-training datasets; the difficulty of designing learned optimizers with the right inductive biases; and the difficulty of interpreting the method of action of learned optimizers. He will share solutions to some of these challenges, and will show experimental results in which learned optimizers outperform hand-designed optimizers in several contexts. He will also discuss novel capabilities that can be achieved by meta-training learned optimizers to target downstream performance rather than training loss, and will end with a demo of an open-source JAX library for training, testing, and applying learned optimizers.
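To make the idea of meta-training concrete, here is a minimal toy sketch (not the library or method from the talk; every name and number below is invented for illustration): an "optimizer" whose update rule has two meta-parameters, a log learning rate and a momentum coefficient, meta-trained by gradient descent on the loss reached after a short unrolled inner training run. A finite-difference meta-gradient with clipping stands in for the more sophisticated techniques the talk covers.

```python
import math

def inner_loss(w):
    # Toy inner task: minimize (w - 3)^2.
    return (w - 3.0) ** 2

def inner_grad(w):
    return 2.0 * (w - 3.0)

def unroll(theta, steps=10):
    """Run the meta-parameterized update rule for `steps` inner steps
    from a fixed initialization; return the final inner loss (meta-loss)."""
    log_lr, mom = theta
    w, v = 0.0, 0.0
    for _ in range(steps):
        g = inner_grad(w)
        v = mom * v + g                  # momentum-style accumulator
        w = w - math.exp(log_lr) * v     # meta-parameterized step size
    return inner_loss(w)

def meta_grad(theta, eps=1e-4):
    """Central finite-difference gradient of the meta-loss w.r.t. theta."""
    grads = []
    for i in range(len(theta)):
        up = list(theta); up[i] += eps
        dn = list(theta); dn[i] -= eps
        grads.append((unroll(up) - unroll(dn)) / (2 * eps))
    return grads

theta0 = [math.log(0.05), 0.5]  # initial meta-parameters (lr=0.05, momentum=0.5)
theta = list(theta0)
for _ in range(300):            # meta-training loop
    g = meta_grad(theta)
    # Clip the meta-gradient: unrolled meta-loss landscapes can become
    # extremely sharp near the chaotic/divergent regime mentioned in the talk.
    theta = [t - 0.05 * max(-1.0, min(1.0, gi)) for t, gi in zip(theta, g)]
```

After meta-training, `unroll(theta)` reaches a lower inner loss than `unroll(theta0)`: the update rule itself has been learned. Real learned optimizers replace the two scalars with a neural network over per-parameter features, and meta-train across distributions of tasks rather than a single quadratic.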