# From Flax to Solstice
This notebook starts with an MNIST classification project and demonstrates how to incrementally buy into Solstice in three steps:
- Organise training code with `solstice.Experiment`
- Implement `solstice.Metrics` for tracking metrics
- Use the premade `solstice.train()` loop with `solstice.Callback`s
## Housekeeping: Colab imports
```python
# if solstice isn't available, we're in Colab, so install the extra packages;
# else we assume you've set up the devcontainer, so install no extras
try:
    import solstice
except ImportError:
    %pip install solstice-jax
    %pip install flax
    %pip install optax
```
## MNIST in pure Flax
First, set up the dataset:
```python
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import tensorflow as tf
import tensorflow_datasets as tfds

# stop tensorflow grabbing GPU memory
tf.config.experimental.set_visible_devices([], "GPU")

train_ds = tfds.load(name="mnist", split="train", as_supervised=True, data_dir="/tmp/data")
assert isinstance(train_ds, tf.data.Dataset)

preprocess_mnist = lambda x, y: (
    tf.reshape(tf.cast(x, tf.float32) / 255, (784,)),
    tf.cast(y, tf.float32),
)

train_ds = train_ds.map(preprocess_mnist).batch(32).prefetch(1)
```
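As an optional sanity check (our addition, not part of the original notebook; the variable names are just for illustration), you can pull one batch through the pipeline and confirm the shapes match what the model will expect:

```python
# Optional sanity check: one preprocessed batch should be (32, 784) images and (32,) labels.
example_imgs, example_labels = next(iter(train_ds.as_numpy_iterator()))
print(example_imgs.shape, example_labels.shape)  # expected: (32, 784) (32,)
```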
Now, create the Flax model:
```python
from typing import Sequence, Any
import flax.linen as nn
import jax.numpy as jnp


class MLP(nn.Module):
    features: Sequence[int]
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, dtype=self.dtype)(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x
```
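If you want to check the architecture before wiring up training, a quick smoke test (again our addition; the underscore-prefixed names are throwaway) is to initialise the module and inspect the output shape:

```python
import jax

# A [300, 300, 10] MLP should map a (1, 784) batch of flattened images to (1, 10) logits.
_mlp = MLP(features=[300, 300, 10])
_variables = _mlp.init(jax.random.PRNGKey(0), jnp.ones([1, 784]))
print(_mlp.apply(_variables, jnp.ones([1, 784])).shape)  # (1, 10)
```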
Now, define a `TrainState` object and training step (notice how this is already quite similar to `solstice.Experiment`):
```python
from typing import Callable, Tuple
import jax
import optax
import dataclasses
from flax import struct


@struct.dataclass
class TrainState:
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    apply_fn: Callable = struct.field(pytree_node=False)

    @classmethod
    def create(cls, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        params = model.init(key, jnp.ones([1, 784]))['params']
        tx = optax.sgd(learning_rate)
        opt_state = tx.init(params)
        return cls(params, opt_state, tx, model.apply)


@jax.jit
def train_step(
    state: TrainState, batch: Tuple[jnp.ndarray, jnp.ndarray]
) -> Tuple[TrainState, Any]:
    imgs, labels = batch

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, imgs)
        loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    preds = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(preds == labels)
    metrics = {'accuracy': accuracy, 'loss': loss}
    return dataclasses.replace(state, params=new_params, opt_state=new_opt_state), metrics
```
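Before running full epochs, it can help to smoke-test a single step in isolation. This snippet (our addition, not from the original notebook) runs one update on one batch and shows that the new state is returned out-of-place while the original is untouched:

```python
# One optimisation step on a single batch; `_state` itself is not mutated.
_state = TrainState.create(rng=0, learning_rate=0.1)
_batch = next(iter(train_ds.as_numpy_iterator()))
_new_state, _metrics = train_step(_state, _batch)
print(_metrics)  # {'accuracy': ..., 'loss': ...} as JAX scalars
```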
Finally, make a training loop and train the model:
```python
from tqdm import tqdm


def flax_train(state: TrainState, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = []
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            state, batch_metrics = train_step(state, batch)
            metrics.append(batch_metrics)

        metrics = jax.tree_util.tree_map(lambda *ms: jnp.mean(jnp.array(ms)), *metrics)
        print(f"Epoch {epoch}, {metrics}")
        metrics = []
    return state


state = TrainState.create(rng=0, learning_rate=0.1)
trained_state = flax_train(state, train_ds, num_epochs=3)
```
## Introducing `solstice.Experiment`
Here, we introduce `solstice.Experiment`, a better way to organise your deep learning code. When converting from Flax to Solstice, notice that a few things happened:
- We replaced `TrainState` with `solstice.Experiment`, using `__init__` instead of `.create`.
- We encapsulated the `train_step()` function into a `train_step()` method.
- All mentions of `state` became mentions of `self`.
- You can (optionally) use filtered transformations instead of specifying fields as static up-front (see the Solstice Primer for more info; a brief sketch follows below).
Notice that `self` is just a PyTree, and the `train_step` method is still a pure function. Like `TrainState`, `Experiment`s are immutable, so all updates are performed out-of-place by returning a new `Experiment` from the step.
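As a brief aside on the filtered-transformations bullet above (this sketch is our own illustration, not code from the Solstice docs): Equinox provides `eqx.filter_jit`, which traces array leaves and treats everything else as static automatically, so you don't have to declare static fields up-front. A minimal, self-contained toy example, with `apply_twice` being a made-up function:

```python
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit  # array leaves are traced; non-array leaves (e.g. Python callables) are static
def apply_twice(fn, x):
    return fn(fn(x))


apply_twice(jnp.tanh, jnp.ones(3))  # the function argument is handled as a static leaf
```

Roughly, the same idea lets you wrap an experiment's step with `eqx.filter_jit` instead of marking `tx` and `apply_fn` as `static_field()`s; see the Solstice Primer for the recommended pattern.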
```python
import equinox as eqx
import solstice


class MNISTClassifier(solstice.Experiment):
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = eqx.static_field()
    apply_fn: Callable = eqx.static_field()

    def __init__(self, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        self.params = model.init(key, jnp.ones([1, 784]))['params']
        self.tx = optax.sgd(learning_rate)
        self.opt_state = self.tx.init(self.params)
        self.apply_fn = model.apply

    @jax.jit
    def train_step(self, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple["MNISTClassifier", Any]:
        imgs, labels = batch

        def loss_fn(params):
            logits = self.apply_fn({'params': params}, imgs)
            loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.params)
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        preds = jnp.argmax(logits, axis=-1)
        accuracy = jnp.mean(preds == labels)
        metrics = {'accuracy': accuracy, 'loss': loss}
        return solstice.replace(self, params=new_params, opt_state=new_opt_state), metrics

    def eval_step(self, batch):
        raise NotImplementedError("not bothering with eval in this example")


def solstice_train(exp: solstice.Experiment, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = []
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            exp, batch_metrics = exp.train_step(batch)
            metrics.append(batch_metrics)

        metrics = jax.tree_util.tree_map(lambda *ms: jnp.mean(jnp.array(ms)), *metrics)
        print(f"Epoch {epoch}, {metrics}")
        metrics = []
    return exp


exp = MNISTClassifier(rng=0, learning_rate=0.1)
trained_exp = solstice_train(exp, train_ds, num_epochs=3)
```
Notice that none of the logic has changed (in fact all the computations and results are identical), just the organisation. Even without the rest of Solstice, this has a few advantages over the pure Flax code:
- Better ergonomics due to creating experiments with `__init__` instead of custom classmethods.
- Explicitly keeping related training code together in one place.
- The Flax code had implicit coupling between `train_step()` and `TrainState`; encapsulating them in one class makes the dependency explicit.
- It is now easier to define different `Experiment` classes for different experiments and sweep across them with your favourite tools, such as Hydra or wandb (see the sketch below).
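For instance (a hypothetical sketch, not part of the original example), because each experiment is just a class with a plain `__init__`, selecting one from a config produced by Hydra or a wandb sweep can be as simple as a lookup table; the `build_experiment` helper and the config keys below are made up for illustration:

```python
# Hypothetical helper: pick and build an Experiment from a plain config dict,
# the kind of thing a Hydra config or a wandb sweep would hand you.
EXPERIMENTS = {"mnist_sgd": MNISTClassifier}


def build_experiment(cfg: dict) -> solstice.Experiment:
    return EXPERIMENTS[cfg["experiment"]](rng=cfg["rng"], learning_rate=cfg["learning_rate"])


_exp = build_experiment({"experiment": "mnist_sgd", "rng": 0, "learning_rate": 0.1})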
## Introducing `solstice.Metrics`
Did you notice the subtle gotcha in the metrics calculation above? The dataset size needed to be perfectly divisible by the batch size; otherwise, the last batch would have had a different size, and naively averaging the loss and accuracy over all batches would have been wrong. Accumulating and calculating metrics gets even harder when you are using metrics that are not 'averageable', such as precision. We provide `solstice.Metrics`, an API for keeping track of metrics scalably and without these headaches.
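To make the gotcha concrete, here is a tiny worked example with made-up numbers (our addition): an epoch that ends with a full batch followed by a small remainder batch.

```python
# 100 samples at accuracy 0.90, then a final ragged batch of 10 samples at accuracy 0.50
naive_mean = (0.90 + 0.50) / 2                  # 0.70: wrong, weights both batches equally
weighted_mean = (0.90 * 100 + 0.50 * 10) / 110  # ~0.864: correct, weights by batch size
```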
A `solstice.Metrics` object knows how to do three things:

- Calculate intermediate results from model outputs with `__init__`.
- Accumulate results with other `solstice.Metrics` objects with `merge()`.
- Calculate final metrics with `compute()`.
Below, we integrate this into our current MNIST experiment. Notice that the results are still the same, but the code is cleaner and more extensible:
```python
from typing import Mapping


class MyMetrics(solstice.Metrics):
    """Custom Metrics class for calculating accuracy and average loss. Included for
    didactic purposes, in practice `solstice.ClassificationMetrics` is better."""

    average_loss: float
    count: int  # number of samples seen
    num_correct: int

    def __init__(self, preds: jnp.ndarray, targets: jnp.ndarray, loss: float) -> None:
        self.average_loss = loss
        self.count = preds.shape[0]  # assumes batch is first dim
        self.num_correct = jnp.sum(preds == targets)

    def merge(self, other: "MyMetrics") -> "MyMetrics":
        # can simply sum num_correct and count
        new_num_correct = self.num_correct + other.num_correct
        new_count = self.count + other.count

        # average loss is weighted by count from each object
        new_loss = (
            self.average_loss * self.count + other.average_loss * other.count
        ) / (self.count + other.count)

        return solstice.replace(
            self, num_correct=new_num_correct, count=new_count, average_loss=new_loss
        )

    def compute(self) -> Mapping[str, float]:
        return {
            "accuracy": self.num_correct / self.count,
            "average_loss": self.average_loss,
        }
```
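As a quick illustration with made-up values (our addition), the three steps compose like this: construct per-batch metrics, `merge()` them, then `compute()` once at the end.

```python
# batch 1: 3 samples, 2 correct; batch 2: 1 sample, 1 correct
m1 = MyMetrics(preds=jnp.array([0, 1, 1]), targets=jnp.array([0, 1, 0]), loss=0.6)
m2 = MyMetrics(preds=jnp.array([1]), targets=jnp.array([1]), loss=0.2)
print(m1.merge(m2).compute())  # accuracy 3/4 = 0.75, count-weighted loss = 0.5
```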
```python
class MNISTClassifierWithMetrics(solstice.Experiment):
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = eqx.static_field()
    apply_fn: Callable = eqx.static_field()

    def __init__(self, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        self.params = model.init(key, jnp.ones([1, 784]))['params']
        self.tx = optax.sgd(learning_rate)
        self.opt_state = self.tx.init(self.params)
        self.apply_fn = model.apply

    @jax.jit
    def train_step(
        self, batch: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> Tuple["MNISTClassifierWithMetrics", solstice.Metrics]:
        imgs, labels = batch

        def loss_fn(params):
            logits = self.apply_fn({'params': params}, imgs)
            loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.params)
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        preds = jnp.argmax(logits, axis=-1)
        metrics = MyMetrics(preds, labels, loss)
        return solstice.replace(self, params=new_params, opt_state=new_opt_state), metrics

    def eval_step(self, batch):
        raise NotImplementedError("not bothering with eval in this example")


def solstice_train_with_metrics(exp: solstice.Experiment, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = None
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            exp, batch_metrics = exp.train_step(batch)
            metrics = batch_metrics if metrics is None else batch_metrics.merge(metrics)

        assert metrics is not None
        metrics = metrics.compute()
        print(f"Epoch {epoch}, {metrics}")
        metrics = None
    return exp


exp = MNISTClassifierWithMetrics(rng=0, learning_rate=0.1)
trained_exp = solstice_train_with_metrics(exp, train_ds, num_epochs=3)
```
Solstice also provides some pre-made metrics classes, such as `solstice.ClassificationMetrics`, for common use cases.
## Introducing `solstice.train()` and `solstice.Callback`s
Often, training loops are boilerplate code. In general, they tend to have two parts: the loops that advance the training state, and the side effects such as logging and checkpointing. Solstice comes with `solstice.train()`, a standard training loop that integrates with a flexible callback system for injecting side effects.
Below, we use the built-in `solstice.LoggingCallback` with `solstice.train()` to cut down on boilerplate code.
```python
import logging

logging.getLogger("solstice").setLevel(logging.INFO)

# by default, `solstice.LoggingCallback` logs to the built-in Python logging system
# with name 'solstice' and level INFO. You can also use this callback with TensorBoard etc...
logging_callback = solstice.LoggingCallback()

exp = MNISTClassifierWithMetrics(rng=0, learning_rate=0.1)
trained_exp = solstice.train(exp, train_ds=train_ds, num_epochs=3, callbacks=[logging_callback])
```
Notice that the results are still identical to the ones from the initial Flax code. All Solstice does is provide user-facing utilities for creating and scaling deep learning experiments in JAX. We encourage people to create their own `Callback`s to do more interesting things.