
From Flax to Solstice¤


This notebook starts with an MNIST classification project and demonstrates how to incrementally buy in to Solstice in 3 steps:

  1. Organise training code with solstice.Experiment
  2. Implement solstice.Metrics for tracking metrics
  3. Use the premade solstice.train() loop with solstice.Callbacks

Housekeeping: Colab imports¤

# if solstice isn't available, we're in Colab, so install the extra packages
# else we assume you've set up the devcontainer, so install no extras
try:
    import solstice
except ImportError:
    %pip install solstice-jax
    %pip install flax
    %pip install optax

MNIST in pure Flax¤

First, set up the dataset:

%env XLA_PYTHON_CLIENT_PREALLOCATE=false

import tensorflow as tf
import tensorflow_datasets as tfds

# stop tensorflow grabbing GPU memory
tf.config.experimental.set_visible_devices([], "GPU")


train_ds = tfds.load(name="mnist", split="train", as_supervised=True, data_dir="/tmp/data")
assert isinstance(train_ds, tf.data.Dataset)
preprocess_mnist = lambda x, y: (
    tf.reshape(tf.cast(x, tf.float32) / 255, (784,)),
    tf.cast(y, tf.float32),
)
train_ds = train_ds.map(preprocess_mnist).batch(32).prefetch(1)
env: XLA_PYTHON_CLIENT_PREALLOCATE=false


Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/data/mnist/3.0.1...

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.57 file/s]
Dataset mnist downloaded and prepared to /tmp/data/mnist/3.0.1. Subsequent calls will reuse this data.
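
If you want to sanity-check the pipeline, you can pull one batch out of it and look at the shapes and dtypes. This snippet is optional and purely illustrative; the training code below does not depend on it.

# optional sanity check: one batch of 32 flattened, normalised images
imgs, labels = next(train_ds.as_numpy_iterator())
print(imgs.shape, imgs.dtype)      # (32, 784) float32
print(labels.shape, labels.dtype)  # (32,) float32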



Now, create the Flax model:

from typing import Sequence, Any
import flax.linen as nn
import jax.numpy as jnp

class MLP(nn.Module):
    features: Sequence[int]
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, dtype=self.dtype)(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x
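
As a quick check that the model is wired up as intended, you can initialise it with a dummy batch and inspect the output shape (illustrative only; the rest of the tutorial does not need this cell):

import jax

# a single flattened image should map to 10 logits
model = MLP(features=[300, 300, 10])
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 784]))
print(model.apply(variables, jnp.ones([1, 784])).shape)  # (1, 10)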

Now, define a TrainState object and training step (notice how this is already quite similar to solstice.Experiment):

from typing import Callable, Tuple
import jax
import optax
import dataclasses
from flax import struct

@struct.dataclass
class TrainState:
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    apply_fn: Callable = struct.field(pytree_node=False)

    @classmethod
    def create(cls, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        params = model.init(key, jnp.ones([1, 784]))['params']
        tx = optax.sgd(learning_rate)
        opt_state = tx.init(params)
        return cls(params, opt_state, tx, model.apply)

@jax.jit
def train_step(
    state: TrainState, batch: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> Tuple[TrainState, Any]:
    imgs, labels = batch

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, imgs)
        loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    preds = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(preds == labels)
    metrics = {'accuracy': accuracy, 'loss': loss}
    return dataclasses.replace(state, params=new_params, opt_state=new_opt_state), metrics

Finally, make a training loop and train the model:

from tqdm import tqdm

def flax_train(state: TrainState, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = []
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            state, batch_metrics = train_step(state, batch)
            metrics.append(batch_metrics)
        metrics = jax.tree_util.tree_map(lambda *ms: jnp.mean(jnp.array(ms)), *metrics)
        print(f"Epoch {epoch}, {metrics}")
        metrics = []
    return state

state = TrainState.create(rng=0, learning_rate=0.1)
trained_state = flax_train(state, train_ds, num_epochs=3)
100%|██████████| 1875/1875 [00:03<00:00, 478.66it/s]

Epoch 0, {'accuracy': DeviceArray(0.9185667, dtype=float32), 'loss': DeviceArray(0.27219555, dtype=float32)}

100%|██████████| 1875/1875 [00:01<00:00, 1332.18it/s]

Epoch 1, {'accuracy': DeviceArray(0.96816665, dtype=float32), 'loss': DeviceArray(0.10662512, dtype=float32)}

100%|██████████| 1875/1875 [00:01<00:00, 1247.04it/s]

Epoch 2, {'accuracy': DeviceArray(0.97900003, dtype=float32), 'loss': DeviceArray(0.07093416, dtype=float32)}

Introducing solstice.Experiment¤

Here, we introduce the solstice.Experiment, a better way to organise your deep learning code. When converting from Flax to Solstice, notice that a few things happened:

  • We replaced TrainState with solstice.Experiment, using __init__ instead of .create
  • We encapsulated the train_step() function into a train_step() method.
  • All mentions of state became mentions of self.
  • You can (optionally) use filtered transformations instead of specifying fields as static up-front (see the Solstice Primer for more info).

Notice that self is just a PyTree, and the train_step method is still a pure function. Like TrainState, Experiments are immutable, so all updates are performed out-of-place by returning a new Experiment from the step.

import equinox as eqx
import solstice

class MNISTClassifier(solstice.Experiment):
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = eqx.static_field()
    apply_fn: Callable = eqx.static_field()

    def __init__(self, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        self.params = model.init(key, jnp.ones([1, 784]))['params']
        self.tx = optax.sgd(learning_rate)
        self.opt_state = self.tx.init(self.params)
        self.apply_fn = model.apply

    @jax.jit
    def train_step(self, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple["MNISTClassifier", Any]:
        imgs, labels = batch

        def loss_fn(params):
            logits = self.apply_fn({'params': params}, imgs)
            loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.params)
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        preds = jnp.argmax(logits, axis=-1)
        accuracy = jnp.mean(preds == labels)
        metrics = {'accuracy': accuracy, 'loss': loss}
        return solstice.replace(self, params=new_params, opt_state=new_opt_state), metrics

    def eval_step(self, batch):
        raise NotImplementedError("not bothering with eval in this example")


def solstice_train(exp: solstice.Experiment, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = []
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            exp, batch_metrics = exp.train_step(batch)
            metrics.append(batch_metrics)
        metrics = jax.tree_util.tree_map(lambda *ms: jnp.mean(jnp.array(ms)), *metrics)
        print(f"Epoch {epoch}, {metrics}")
        metrics = []
    return exp

exp = MNISTClassifier(rng=0, learning_rate=0.1)
trained_exp = solstice_train(exp, train_ds, num_epochs=3)
100%|██████████| 1875/1875 [00:01<00:00, 1108.96it/s]

Epoch 0, {'accuracy': DeviceArray(0.9185667, dtype=float32), 'loss': DeviceArray(0.27219555, dtype=float32)}

100%|██████████| 1875/1875 [00:01<00:00, 1098.72it/s]

Epoch 1, {'accuracy': DeviceArray(0.96816665, dtype=float32), 'loss': DeviceArray(0.10662512, dtype=float32)}

100%|██████████| 1875/1875 [00:01<00:00, 1216.60it/s]

Epoch 2, {'accuracy': DeviceArray(0.97900003, dtype=float32), 'loss': DeviceArray(0.07093416, dtype=float32)}

Notice that none of the logic has changed (in fact all the computations and results are identical), just the organisation. Even without the rest of Solstice, this has a few advantages over the pure Flax code:

  • Better ergonomics due to creating experiments with __init__ instead of custom classmethods.
  • Explicitly keeping related training code together in one place.
  • The Flax code had implicit coupling between train_step() and TrainState; encapsulating them in one class makes the dependency explicit.
  • It is now easier to define different Experiment classes for different experiments and sweep across them with your favourite tools (such as hydra or wandb).

Introducing solstice.Metrics¤

Did you notice the subtle gotcha in the metrics calculation above? Unless the dataset size is perfectly divisible by the batch size, the last batch will be smaller than the rest, so naively averaging the per-batch loss and accuracy gives the wrong answer. Accumulating and calculating metrics gets even harder when you are using metrics that are not 'averageable', such as precision. We provide solstice.Metrics, an API for keeping track of metrics scalably and without these headaches.
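
To see the problem concretely, here is a toy illustration with made-up numbers: a full batch of 32 samples with mean loss 1.0, followed by a final partial batch of 8 samples with mean loss 2.0.

# naive approach: average the per-batch means, weighting both batches equally
naive_average = (1.0 + 2.0) / 2                 # 1.5
# correct approach: weight each batch mean by its sample count
true_average = (32 * 1.0 + 8 * 2.0) / (32 + 8)  # 1.2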

A solstice.Metrics object knows how to do three things:

  • Calculate intermediate results from model outputs with __init__.
  • Accumulate results with other solstice.Metrics objects with merge().
  • Calculate final metrics with compute().

Below, we integrate this into our current MNIST experiment. Notice that the results are still the same, but the code is cleaner and more extensible:

from typing import Mapping

class MyMetrics(solstice.Metrics):
    """Custom Metrics class for calculating accuracy and average loss. Included for
    didactic purposes, in practice `solstice.ClassificationMetrics` is better."""

    average_loss: float
    count: int  # number of samples seen
    num_correct: int

    def __init__(self, preds: jnp.ndarray, targets: jnp.ndarray, loss: float) -> None:
        self.average_loss = loss
        self.count = preds.shape[0]  # assumes batch is first dim
        self.num_correct = jnp.sum(preds == targets)

    def merge(self, other: "MyMetrics") -> "MyMetrics":
        # can simply sum num_correct and count
        new_num_correct = self.num_correct + other.num_correct
        new_count = self.count + other.count

        # average loss is weighted by count from each object
        new_loss = (
            self.average_loss * self.count + other.average_loss * other.count
        ) / (self.count + other.count)

        return solstice.replace(
            self, num_correct=new_num_correct, count=new_count, average_loss=new_loss
        )

    def compute(self) -> Mapping[str, float]:
        return {
            "accuracy": self.num_correct / self.count,
            "average_loss": self.average_loss,
        }
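
To make the accumulate-then-compute pattern concrete, here is how two batches of different sizes would be folded together with the class above (the predictions, targets, and losses are dummy values, not model outputs):

# dummy metrics for a batch of 4 samples and a batch of 2 samples
batch1 = MyMetrics(preds=jnp.array([0, 1, 2, 3]), targets=jnp.array([0, 1, 2, 0]), loss=0.5)
batch2 = MyMetrics(preds=jnp.array([7, 7]), targets=jnp.array([7, 1]), loss=1.1)
print(batch2.merge(batch1).compute())
# accuracy = 4/6, average_loss = (4 * 0.5 + 2 * 1.1) / 6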

class MNISTClassifierWithMetrics(solstice.Experiment):
    params: optax.Params
    opt_state: optax.OptState
    tx: optax.GradientTransformation = eqx.static_field()
    apply_fn: Callable = eqx.static_field()

    def __init__(self, rng: int, learning_rate: float):
        key = jax.random.PRNGKey(rng)
        model = MLP(features=[300, 300, 10])
        self.params = model.init(key, jnp.ones([1, 784]))['params']
        self.tx = optax.sgd(learning_rate)
        self.opt_state = self.tx.init(self.params)
        self.apply_fn = model.apply

    @jax.jit
    def train_step(self, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple["MNISTClassifierWithMetrics", solstice.Metrics]:
        imgs, labels = batch

        def loss_fn(params):
            logits = self.apply_fn({'params': params}, imgs)
            loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 10)))
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.params)
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        preds = jnp.argmax(logits, axis=-1)
        metrics = MyMetrics(preds, labels, loss)
        return solstice.replace(self, params=new_params, opt_state=new_opt_state), metrics

    def eval_step(self, batch):
        raise NotImplementedError("not bothering with eval in this example")


def solstice_train_with_metrics(exp: solstice.Experiment, train_ds: tf.data.Dataset, num_epochs: int):
    metrics = None
    for epoch in range(num_epochs):
        for batch in tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)):
            exp, batch_metrics = exp.train_step(batch)
            metrics = batch_metrics if metrics is None else batch_metrics.merge(metrics)
        assert metrics is not None
        metrics = metrics.compute()
        print(f"Epoch {epoch}, {metrics}")
        metrics = None
    return exp


exp = MNISTClassifierWithMetrics(rng=0, learning_rate=0.1)
trained_exp = solstice_train_with_metrics(exp, train_ds, num_epochs=3)
100%|██████████| 1875/1875 [00:02<00:00, 629.13it/s]

Epoch 0, {'accuracy': DeviceArray(0.9185667, dtype=float32), 'average_loss': DeviceArray(0.27219722, dtype=float32)}

100%|██████████| 1875/1875 [00:02<00:00, 848.14it/s]

Epoch 1, {'accuracy': DeviceArray(0.96816665, dtype=float32), 'average_loss': DeviceArray(0.10662578, dtype=float32)}

100%|██████████| 1875/1875 [00:02<00:00, 770.72it/s]
Epoch 2, {'accuracy': DeviceArray(0.97900003, dtype=float32), 'average_loss': DeviceArray(0.07093462, dtype=float32)}



Solstice also provides pre-made Metrics classes for common use cases, such as solstice.ClassificationMetrics.

Introducing solstice.train() and solstice.Callbacks¤

Training loops are often boilerplate code. In general, they have two parts: the loop that advances the training state, and the side effects such as logging and checkpointing. Solstice comes with solstice.train(), a standard training loop which integrates with a flexible callback system for injecting side effects.

Below, we use the built-in solstice.LoggingCallback with solstice.train() to cut down on boilerplate code.

import logging
logging.getLogger("solstice").setLevel(logging.INFO)

# by default, `solstice.LoggingCallback` logs to the built-in Python logging system
# with name 'solstice' and level INFO. You can also use this callback with TensorBoard etc...
logging_callback = solstice.LoggingCallback()

exp = MNISTClassifierWithMetrics(rng=0, learning_rate=0.1)
trained_exp = solstice.train(exp, train_ds=train_ds, num_epochs=3, callbacks=[logging_callback])
Training:   0%|          | 0/3 [00:00<?, ?epoch/s]INFO:solstice:train step 0: {'accuracy': DeviceArray(0.9185667, dtype=float32), 'average_loss': DeviceArray(0.27219722, dtype=float32)}
Training:  33%|███▎      | 1/3 [00:02<00:05,  2.61s/epoch]INFO:solstice:train step 1: {'accuracy': DeviceArray(0.96816665, dtype=float32), 'average_loss': DeviceArray(0.10662578, dtype=float32)}
Training:  67%|██████▋   | 2/3 [00:05<00:02,  2.60s/epoch]INFO:solstice:train step 2: {'accuracy': DeviceArray(0.97900003, dtype=float32), 'average_loss': DeviceArray(0.07093462, dtype=float32)}
Training: 100%|██████████| 3/3 [00:07<00:00,  2.51s/epoch]

Notice that the results are still identical to the ones from the initial Flax code. All Solstice does is provide user-facing utilities for creating and scaling deep learning experiments in JAX. We encourage people to create their own Callbacks to do more interesting things.