
solstice

Solstice, a library for creating and scaling experiments in JAX.


Whole API

Abstract

This is all of Solstice. Everything is accessible through the solstice.* namespace.

solstice.__all__ = ('Experiment', 'Metrics', 'ClassificationMetrics', 'Callback', 'CheckpointingCallback', 'EarlyStoppingCallback', 'LoggingCallback', 'ProfilingCallback', 'train', 'test', 'replace', 'EarlyStoppingException') module-attribute



Experiments

The Experiment is at the heart of Solstice. The API is similar to the pl.LightningModule loved by PyTorch-Lightning users, but we do less 'magic' to keep it as transparent as possible. If in doubt, just read the source code - it's really short!

Experiment

Bases: eqx.Module, ABC

Base class for Solstice experiments.

An Experiment holds all stateful models, optimizers, etc... for a run and implements this interface. To make your own experiments, subclass this class and implement the logic for initialisation, training, and evaluating.

Tip

This is a subclass of equinox.Module, so you are free to use pure JAX transformations such as jax.jit and jax.pmap, as long as you remember to filter out static PyTree fields (e.g. with eqx.filter_jit).

Example

Pseudocode for typical Experiment usage:

exp = MyExperiment(...)  # initialise experiment state

for step in range(num_steps):
    exp, outs = exp.train_step(batch)
    # do anything with the outputs here

# exp is just a pytree, so we can save and restore checkpoints like so...
equinox.tree_serialise_leaves("checkpoint_0.eqx", exp)
exp = equinox.tree_deserialise_leaves("checkpoint_0.eqx", exp)  # `exp` is the template

This class just specifies a recommended interface for experiment code. Experiments implementing this interface will automatically work with the Solstice training loops. You can always create or override methods as you wish and no methods are special-cased. For example it is common to define a __call__ method to perform inference on a batch of data.
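The interface above is built around pure updates: `train_step` returns a new experiment instead of mutating `self`. The dependency-free sketch below illustrates that pattern in plain Python, without JAX or equinox. It assumes a frozen `dataclasses.dataclass` and `dataclasses.replace` as stand-ins for `eqx.Module` and `solstice.replace`; the `ToyExperiment` class, its fields and the scalar least-squares problem are hypothetical.

```python
import dataclasses
from typing import Any, Tuple


@dataclasses.dataclass(frozen=True)
class ToyExperiment:
    """Hypothetical stand-in for a solstice.Experiment subclass.

    Frozen so that, like an eqx.Module, it cannot be mutated in place.
    Fits y = w * x by gradient descent on the mean squared error.
    """

    w: float
    learning_rate: float

    def train_step(self, batch: Tuple[list, list]) -> Tuple["ToyExperiment", Any]:
        xs, ys = batch
        n = len(xs)
        # Mean squared error and its gradient with respect to w.
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        new_w = self.w - self.learning_rate * grad
        # Return a *new* experiment with the updated state, plus auxiliary outputs.
        return dataclasses.replace(self, w=new_w), {"loss": loss}

    def eval_step(self, batch: Tuple[list, list]) -> Tuple["ToyExperiment", Any]:
        xs, ys = batch
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
        # Evaluation leaves the state untouched, so return self.
        return self, {"loss": loss}


exp = ToyExperiment(w=0.0, learning_rate=0.1)
batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # data generated by y = 2x
for _ in range(50):
    exp, outs = exp.train_step(batch)
print(f"w = {exp.w:.4f}")
```

Because each step produces a fresh object, the old experiment is still available unchanged, which is what makes checkpointing and JAX transformations straightforward.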

Source code in solstice/experiment.py
class Experiment(eqx.Module, ABC):
    """Base class for Solstice experiments.

    An Experiment holds all stateful models, optimizers, etc... for a run and
    implements this interface. To make your own experiments, subclass this class and
    implement the logic for initialisation, training, and evaluating.

    !!! tip
        This is a subclass of `equinox.Module`, so you are free to use pure JAX
        transformations such as `jax.jit` and `jax.pmap`, as long as you remember to
        filter out static PyTree fields (e.g. with `eqx.filter_jit`).

    !!! example
        Pseudocode for typical `Experiment` usage:
        ```python

        exp = MyExperiment(...)  # initialise experiment state

        for step in range(num_steps):
            exp, outs = exp.train_step(batch)
            # do anything with the outputs here

        # exp is just a pytree, so we can save and restore checkpoints like so...
        equinox.tree_serialise_leaves("checkpoint_0.eqx", exp)
        exp = equinox.tree_deserialise_leaves("checkpoint_0.eqx", exp)  # `exp` is the template


        ```

    This class just specifies a recommended interface for experiment code. Experiments
    implementing this interface will automatically work with the Solstice training
    loops. You can always create or override methods as you wish and no methods are
    special-cased. For example it is common to define a `__call__` method to perform
    inference on a batch of data.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        """Initialise the experiment.
        !!! example
            Pseudocode implementation for initialising an MNIST classifier with flax
            and optax:
            ```python
            class MNISTExperiment(Experiment):
                params: Any
                opt_state: Any
                opt_apply: Callable
                model_apply: Callable
                num_classes: int

                def __init__(self, rng: int, model: flax.linen.Module,
                    optimizer: optax.GradientTransformation
                ) -> None:
                    key = jax.random.PRNGKey(rng)
                    dummy_batch = jnp.zeros((32, 784))
                    self.params = model.init(key, dummy_batch)
                    self.model_apply = model.apply
                    self.opt_apply = optimizer.update
                    self.opt_state = optimizer.init(self.params)
                    self.num_classes = 10
            ```
        """
        raise NotImplementedError()

    @abstractmethod
    def train_step(self, batch: Any) -> Tuple[Experiment, Any]:
        """A training step takes a batch of data and returns the updated experiment and
        any auxiliary outputs (usually a `solstice.Metrics` object).

        !!! tip
            You will typically want to use `jax.jit`, `jax.pmap`, `eqx.filter_jit`, or
            `eqx.filter_pmap` on this method. See the
            [solstice primer](https://charl-ai.github.io/Solstice/primer/)
            for more info on filtered transformations. You can also read the tutorial on
            different [parallelism strategies](https://charl-ai.github.io/Solstice/parallelism_strategies/).

        !!! example
            Pseudocode implementation of a training step:
            ```python
            class MNISTExperiment(Experiment):
                @eqx.filter_jit
                def train_step(self, batch: Tuple[np.ndarray, ...]
                ) -> Tuple[Experiment, solstice.Metrics]:
                    imgs, labels = batch

                    def loss_fn(params, x, y):
                        logits = self.model_apply(params, x)
                        loss = optax.softmax_cross_entropy_with_integer_labels(
                            logits, y
                        ).mean()
                        return loss, logits

                    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(
                        self.params, imgs, labels
                    )

                    # apply the gradients with the optimizer and update params
                    updates, new_opt_state = self.opt_apply(
                        grads, self.opt_state, self.params
                    )
                    new_params = optax.apply_updates(self.params, updates)
                    preds = jnp.argmax(logits, axis=-1)
                    metrics = solstice.ClassificationMetrics(
                        preds, labels, loss, self.num_classes
                    )

                    return (
                        solstice.replace(self, params=new_params, opt_state=new_opt_state),
                        metrics,
                    )
            ```

        !!! tip
            You can use the `solstice.replace` function as a way of returning an
            experiment instance with modified state.

        Args:
            batch (Any): Batch of data. Usually, this will be either a tuple of
                (input, target) arrays or a dictionary mapping keys to arrays.

        Returns:
            Tuple[Experiment, Any]: A new instance of the Experiment with the updated
                state and any auxiliary outputs, such as metrics.
        """
        raise NotImplementedError()

    @abstractmethod
    def eval_step(self, batch: Any) -> Tuple[Experiment, Any]:
        """An evaluation step (e.g. for validation or testing) takes a batch of data and
        returns the updated experiment and any auxiliary outputs. Usually, this will be
        a `solstice.Metrics` object. Like `train_step()`, you should probably JIT this
        method.

        !!! tip
            In most evaluation cases, the experiment returned will be unchanged;
            the main reason why you would want to modify it is to advance PRNG state.

        !!! example
            Pseudocode implementation of an evaluation step:
            ```python
            class MNISTExperiment(Experiment):
                @eqx.filter_jit
                def eval_step(self, batch: Tuple[np.ndarray, ...]
                ) -> Tuple[Experiment, Any]:
                    imgs, labels = batch

                    logits = self.model_apply(self.params, imgs)
                    loss = optax.softmax_cross_entropy_with_integer_labels(
                        logits, labels
                    ).mean()
                    preds = jnp.argmax(logits, axis=-1)
                    metrics = solstice.ClassificationMetrics(
                        preds, labels, loss, self.num_classes
                    )
                    return self, metrics  # state is unchanged during evaluation
            ```

        Args:
            batch (Any): Batch of data. Usually, this will be either a tuple of
                (input, target) arrays or a dictionary mapping keys to arrays.

        Returns:
            Tuple[Experiment, Any]: A new instance of the Experiment with the updated
                state and any auxiliary outputs, such as metrics.

        """
        raise NotImplementedError()

__init__(*args, **kwargs) -> None abstractmethod

Initialise the experiment.

Example

Pseudocode implementation for initialising an MNIST classifier with flax and optax:

class MNISTExperiment(Experiment):
    params: Any
    opt_state: Any
    opt_apply: Callable
    model_apply: Callable
    num_classes: int

    def __init__(self, rng: int, model: flax.linen.Module,
        optimizer: optax.GradientTransformation
    ) -> None:
        key = jax.random.PRNGKey(rng)
        dummy_batch = jnp.zeros((32, 784))
        self.params = model.init(key, dummy_batch)
        self.model_apply = model.apply
        self.opt_apply = optimizer.update
        self.opt_state = optimizer.init(self.params)
        self.num_classes = 10


train_step(batch: Any) -> Tuple[Experiment, Any] abstractmethod

A training step takes a batch of data and returns the updated experiment and any auxiliary outputs (usually a solstice.Metrics object).

Tip

You will typically want to use jax.jit, jax.pmap, eqx.filter_jit, or eqx.filter_pmap on this method. See the solstice primer for more info on filtered transformations. You can also read the tutorial on different parallelism strategies.

Example

Pseudocode implementation of a training step:

class MNISTExperiment(Experiment):
    @eqx.filter_jit
    def train_step(self, batch: Tuple[np.ndarray, ...]
    ) -> Tuple[Experiment, solstice.Metrics]:
        imgs, labels = batch

        def loss_fn(params, x, y):
            logits = self.model_apply(params, x)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            self.params, imgs, labels
        )

        # apply the gradients with the optimizer and update params
        updates, new_opt_state = self.opt_apply(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)
        preds = jnp.argmax(logits, axis=-1)
        metrics = solstice.ClassificationMetrics(preds, labels, loss, self.num_classes)

        return (
            solstice.replace(self, params=new_params, opt_state=new_opt_state),
            metrics,
        )

Tip

You can use the solstice.replace function as a way of returning an experiment instance with modified state.

Parameters:

  • batch (Any) –

    Batch of data. Usually, this will be either a tuple of (input, target) arrays or a dictionary mapping keys to arrays.

Returns:

  • Tuple[Experiment, Any]

    A new instance of the Experiment with the updated state, and any auxiliary outputs such as metrics.


eval_step(batch: Any) -> Tuple[Experiment, Any] abstractmethod

An evaluation step (e.g. for validation or testing) takes a batch of data and returns the updated experiment and any auxiliary outputs. Usually, this will be a solstice.Metrics object. Like train_step(), you should probably JIT this method.

Tip

In most evaluation cases, the experiment returned will be unchanged; the main reason why you would want to modify it is to advance PRNG state.

Example

Pseudocode implementation of an evaluation step:

class MNISTExperiment(Experiment):
    @eqx.filter_jit
    def eval_step(self, batch: Tuple[np.ndarray, ...]
    ) -> Tuple[Experiment, Any]:
        imgs, labels = batch

        logits = self.model_apply(self.params, imgs)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        preds = jnp.argmax(logits, axis=-1)
        metrics = solstice.ClassificationMetrics(preds, labels, loss, self.num_classes)
        return self, metrics  # state is unchanged during evaluation

Parameters:

  • batch (Any) –

    Batch of data. Usually, this will be either a tuple of (input, target) arrays or a dictionary mapping keys to arrays.

Returns:

  • Tuple[Experiment, Any]

    A new instance of the Experiment with the updated state, and any auxiliary outputs such as metrics.
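When an evaluation step is stochastic, the only state it should change is the PRNG state. In JAX this is usually done by splitting a key with `jax.random.split` and storing the new key on the returned experiment. The stdlib-only sketch below imitates that pattern with integer seeds; the `_split` helper and the `StochasticEvalExperiment` class are hypothetical stand-ins, not Solstice or JAX APIs.

```python
import dataclasses
import random
from typing import Any, List, Tuple


def _split(key: int) -> Tuple[int, int]:
    """Hypothetical, deterministic stand-in for jax.random.split.

    The same input key always yields the same pair of new keys.
    """
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


@dataclasses.dataclass(frozen=True)
class StochasticEvalExperiment:
    key: int
    noise_scale: float

    def eval_step(self, batch: List[float]) -> Tuple["StochasticEvalExperiment", Any]:
        # Split the key: one half is carried forward, the other is consumed now.
        new_key, subkey = _split(self.key)
        rng = random.Random(subkey)
        noisy = [x + self.noise_scale * rng.gauss(0.0, 1.0) for x in batch]
        outs = {"mean": sum(noisy) / len(noisy)}
        # Only the PRNG state changes; every other field is left as-is.
        return dataclasses.replace(self, key=new_key), outs


exp = StochasticEvalExperiment(key=0, noise_scale=0.1)
exp1, outs1 = exp.eval_step([1.0, 2.0, 3.0])
exp2, outs2 = exp1.eval_step([1.0, 2.0, 3.0])
```

Re-running `exp.eval_step` on the original experiment reproduces `outs1` exactly, while chaining through `exp1` draws fresh randomness: the same reproducibility property that explicit key threading gives in JAX.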


Metrics

Our Metrics API is similar to the one in CLU, although sexier because we use equinox :) We favour defining a single object that handles all metrics for an experiment instead of composing multiple objects into a collection. This is more efficient because a whole battery of metrics can often be calculated from the same intermediate results. It is also simpler and easier to reason about.

Metrics

Bases: eqx.Module, ABC

Base class for metrics. A Metrics object handles calculating intermediate metrics from model outputs, accumulating them over batches, then calculating final metrics from accumulated metrics. Subclass this class and implement the interface for initialisation, accumulation, and finalisation.

Tip

This class doesn't have to handle 'metrics' in the strictest sense. You could implement a Metrics class to collect output images for plotting for example.

Example

Pseudocode for typical Metrics usage:

metrics = None
for batch in dataset:
    batch_metrics = step(batch)  # step returns a Metrics object
    metrics = metrics.merge(batch_metrics) if metrics else batch_metrics

    if time_to_log:
        metrics_dict = metrics.compute()
        ... # log your metrics here
        metrics = None  # reset the object
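A concrete, runnable version of the loop above in plain Python. `AccuracyMetrics` is a hypothetical stand-in for a `solstice.Metrics` subclass: it stores only cheap, summable intermediates (`count`, `num_correct`), so `merge` is a simple sum and `compute` performs the final division. The hard-coded `(preds, targets)` batches stand in for the outputs of `step(batch)`.

```python
import copy
from typing import List, Mapping, Optional, Sequence


class AccuracyMetrics:
    """Hypothetical stand-in for a solstice.Metrics subclass."""

    def __init__(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        # Cheap, summable intermediates for one batch.
        self.count = len(preds)
        self.num_correct = sum(int(p == t) for p, t in zip(preds, targets))

    def merge(self, other: "AccuracyMetrics") -> "AccuracyMetrics":
        # Return a new object rather than mutating self (stand-in for solstice.replace).
        merged = copy.copy(self)
        merged.count = self.count + other.count
        merged.num_correct = self.num_correct + other.num_correct
        return merged

    def compute(self) -> Mapping[str, float]:
        # Final metrics are derived only when it is time to log.
        return {"accuracy": self.num_correct / self.count}


# (preds, targets) pairs standing in for the outputs of `step(batch)`.
batches = [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 1]), ([1, 0, 0], [1, 0, 0])]
log_every_n_steps = 2
logged: List[Mapping[str, float]] = []

metrics: Optional[AccuracyMetrics] = None
for i, (preds, targets) in enumerate(batches):
    batch_metrics = AccuracyMetrics(preds, targets)
    metrics = metrics.merge(batch_metrics) if metrics is not None else batch_metrics

    if (i + 1) % log_every_n_steps == 0:
        logged.append(metrics.compute())
        metrics = None  # reset the accumulator
```

The first two batches (5 predictions, 3 correct) are logged together; the third batch is left accumulating in `metrics` until the next logging point.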
Source code in solstice/metrics.py
class Metrics(eqx.Module, ABC):
    """Base class for metrics. A Metrics object handles calculating intermediate
    metrics from model outputs, accumulating them over batches, then
    calculating final metrics from accumulated metrics. Subclass this class and
    implement the interface for initialisation, accumulation, and finalisation.

    !!! tip
        This class doesn't have to handle 'metrics' in the strictest sense. You could
        implement a `Metrics` class to collect output images for plotting for example.

    !!! example
        Pseudocode for typical `Metrics` usage:

        ```python
        metrics = None
        for batch in dataset:
            batch_metrics = step(batch)  # step returns a Metrics object
            metrics = metrics.merge(batch_metrics) if metrics else batch_metrics

            if time_to_log:
                metrics_dict = metrics.compute()
                ... # log your metrics here
                metrics = None  # reset the object
        ```
    """

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        """Initialise a metrics object, typically with predictions and targets.

        !!! example
            Pseudocode for typical `Metrics` initialisation. This example object will
            keep track of the number of correct predictions and the total number of
            predictions:
            ```python
            class MyMetrics(Metrics):
                count: int
                num_correct: int
                def __init__(self, preds: jnp.ndarray, targets: jnp.ndarray) -> None:
                    self.count = preds.shape[0]  # assumes batch is first dim
                    self.num_correct = jnp.sum(preds == targets)
            ```

        !!! tip
            In classification settings, the confusion matrix is a useful intermediate
            result to calculate during initialisation.
        """
        raise NotImplementedError

    @abstractmethod
    def merge(self, other: Metrics) -> Metrics:
        """Merge two metrics objects, returning a new metrics object.

        !!! example
            Pseudocode for typical `Metrics` merging. In this example, we can simply
            sum the number of correct predictions and the total number of predictions:
            ```python
            class MyMetrics(Metrics):
                def merge(self, other: Metrics) -> Metrics:
                    new_num_correct = self.num_correct + other.num_correct
                    new_count = self.count + other.count
                    return solstice.replace(self,
                        num_correct=new_num_correct, count=new_count)
            ```
        """
        raise NotImplementedError

    @abstractmethod
    def compute(self) -> Any:
        """Compute final metrics from accumulated metrics.

        !!! example
            Pseudocode for typical `Metrics` finalisation. Here we calculate accuracy
            from the number of correct predictions and the total number of predictions:
            ```python
            class MyMetrics(Metrics):
                def compute(self) -> Mapping[str, float]:
                    return {'accuracy': self.num_correct / self.count}
            ```

        !!! tip
            Typically, you will want to return a dictionary of metrics. Try to put any
            expensive computations here, not in `__init__`.
        """

        raise NotImplementedError

__init__(*args, **kwargs) -> None abstractmethod

Initialise a metrics object, typically with predictions and targets.

Example

Pseudocode for typical Metrics initialisation. This example object will keep track of the number of correct predictions and the total number of predictions:

class MyMetrics(Metrics):
    count: int
    num_correct: int
    def __init__(self, preds: jnp.ndarray, targets: jnp.ndarray) -> None:
        self.count = preds.shape[0]  # assumes batch is first dim
        self.num_correct = jnp.sum(preds == targets)

Tip

In classification settings, the confusion matrix is a useful intermediate result to calculate during initialisation.
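To illustrate the tip: a confusion matrix is a good intermediate because it is cheap to build per batch, merges by element-wise addition, and contains everything needed for accuracy, precision, recall and F1. The plain-Python sketch below assumes integer class indices and the convention that rows index the target class and columns the predicted class; the function names are hypothetical, not Solstice APIs.

```python
from typing import List, Sequence


def confusion_matrix(
    preds: Sequence[int], targets: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Hypothetical helper: cm[t][p] counts samples of target t predicted as p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        cm[t][p] += 1
    return cm


def merge_confusion_matrices(a: List[List[int]], b: List[List[int]]) -> List[List[int]]:
    # Merging two batches is just element-wise addition.
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


cm1 = confusion_matrix([0, 1, 1], [0, 1, 0], num_classes=3)
cm2 = confusion_matrix([2, 2], [2, 1], num_classes=3)
cm_total = merge_confusion_matrices(cm1, cm2)
```

Merging the per-batch matrices gives the same result as building one matrix over all five predictions, and the diagonal sum is the number of correct predictions.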


merge(other: Metrics) -> Metrics abstractmethod

Merge two metrics objects, returning a new metrics object.

Example

Pseudocode for typical Metrics merging. In this example, we can simply sum the number of correct predictions and the total number of predictions:

class MyMetrics(Metrics):
    def merge(self, other: Metrics) -> Metrics:
        new_num_correct = self.num_correct + other.num_correct
        new_count = self.count + other.count
        return solstice.replace(self,
            num_correct=new_num_correct, count=new_count)


compute() -> Any abstractmethod

Compute final metrics from accumulated metrics.

Example

Pseudocode for typical Metrics finalisation. Here we calculate accuracy from the number of correct predictions and the total number of predictions:

class MyMetrics(Metrics):
    def compute(self) -> Mapping[str, float]:
        return {'accuracy': self.num_correct / self.count}

Tip

Typically, you will want to return a dictionary of metrics. Try to put any expensive computations here, not in __init__.


ClassificationMetrics

Bases: Metrics

Basic metrics for multiclass classification tasks.

Metrics included:

  • Average Loss

  • Accuracy

  • Prevalence

  • F1 score

  • Sensitivity (TPR, recall)

  • Positive predictive value (PPV, precision)

Accuracy is reported as Top-1 accuracy which is equal to the micro-average of precision/recall/f1. Prevalence is reported on a per-class basis. Precision, Recall and F1 are reported three times: per-class, macro-average, and weighted average (by prevalence).

Not for multi-label classification.

Info

See https://en.wikipedia.org/wiki/Confusion_matrix for more on confusion matrices and classification metrics. See https://scikit-learn.org/stable/modules/model_evaluation.html#from-binary-to-multiclass-and-multilabel for more on multiclass micro/macro/weighted averaging.
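The averaging schemes described above can all be derived from one confusion matrix. The plain-Python sketch below computes per-class precision, recall and F1, their macro average (unweighted mean over classes) and their prevalence-weighted average, and shows that top-1 accuracy equals the micro-averaged precision and recall. It is not the Solstice implementation (the internal `_compute_metrics_from_cm` is not shown on this page); the helper names, the row/column convention (rows are targets) and treating 0/0 as 0 are assumptions.

```python
from typing import Dict, List


def _safe_div(num: float, den: float) -> float:
    # Assumption: treat 0/0 as 0 (e.g. a class that is never predicted).
    return num / den if den else 0.0


def metrics_from_confusion_matrix(cm: List[List[int]]) -> Dict[str, object]:
    """Hypothetical helper; cm[t][p] counts samples of target t predicted as p."""
    k = len(cm)
    total = sum(sum(row) for row in cm)
    tp = [cm[i][i] for i in range(k)]
    support = [sum(cm[i]) for i in range(k)]  # row sums: true samples per class
    predicted = [sum(cm[t][i] for t in range(k)) for i in range(k)]  # column sums

    prevalence = [s / total for s in support]
    precision = [_safe_div(tp[i], predicted[i]) for i in range(k)]
    recall = [_safe_div(tp[i], support[i]) for i in range(k)]
    f1 = [_safe_div(2 * p * r, p + r) for p, r in zip(precision, recall)]

    out: Dict[str, object] = {
        "accuracy": sum(tp) / total,
        # Micro-averaging pools counts over classes. With exactly one label per
        # sample, every error is one FP and one FN, so both equal accuracy.
        "micro_precision": _safe_div(sum(tp), sum(predicted)),
        "micro_recall": _safe_div(sum(tp), sum(support)),
        "prevalence": prevalence,
    }
    for name, values in (("precision", precision), ("recall", recall), ("f1", f1)):
        out[name] = values  # per-class
        out[f"macro_{name}"] = sum(values) / k  # unweighted mean over classes
        out[f"weighted_{name}"] = sum(w * v for w, v in zip(prevalence, values))
    return out


cm = [[1, 1, 0],
      [0, 1, 1],
      [0, 0, 1]]
m = metrics_from_confusion_matrix(cm)
```

For this matrix, accuracy, micro precision and micro recall are all 3/5, while the macro recall is 2/3 because every class counts equally regardless of how many samples it has.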

Source code in solstice/metrics.py
class ClassificationMetrics(Metrics):
    """Basic metrics for multiclass classification tasks.
    !!! summary "Metrics included:"
            - Average Loss

            - Accuracy

            - Prevalence

            - F1 score

            - Sensitivity (TPR, recall)

            - Positive predictive value (PPV, precision)

    Accuracy is reported as Top-1 accuracy which is equal to the micro-average of
    precision/recall/f1. Prevalence is reported on a per-class basis. Precision, Recall
    and F1 are reported three times: per-class, macro-average, and weighted average (by
    prevalence).

    *Not* for multi-label classification.

    !!! info
        See https://en.wikipedia.org/wiki/Confusion_matrix for more on confusion
        matrices and classification metrics.
        See https://scikit-learn.org/stable/modules/model_evaluation.html#from-binary-to-multiclass-and-multilabel
        for more on multiclass micro/macro/weighted averaging.

    """

    _confusion_matrix: jnp.ndarray
    _average_loss: float
    _count: int
    _num_classes: int

    def __init__(
        self, preds: jnp.ndarray, targets: jnp.ndarray, loss: float, num_classes: int
    ) -> None:
        """
        Create a ClassificationMetrics object from model predictions and targets.

        Args:
            preds (jnp.ndarray): Integer class predictions (not one-hot encoded),
                shape: (batch_size,).
            targets (jnp.ndarray): Integer class targets (not one-hot encoded),
                shape: (batch_size,).
            loss (float): Average loss over the batch (scalar).
            num_classes (int): Number of classes in classification problem.
        """
        self._confusion_matrix = _compute_confusion_matrix(preds, targets, num_classes)
        self._average_loss = loss
        self._count = preds.shape[0]
        self._num_classes = num_classes

    def merge(self, other: ClassificationMetrics) -> ClassificationMetrics:
        assert isinstance(other, ClassificationMetrics), (
            "Can only merge ClassificationMetrics object with another"
            f" ClassificationMetrics object, got {type(other)}"
        )
        assert self._num_classes == other._num_classes, (
            "Can only merge metrics with same num_classes, got"
            f" {self._num_classes} and {other._num_classes}"
        )
        # can simply sum confusion matrices and count
        new_cm = self._confusion_matrix + other._confusion_matrix
        new_count = self._count + other._count

        # average loss is weighted by count from each object
        new_loss = (
            self._average_loss * self._count + other._average_loss * other._count
        ) / (self._count + other._count)

        return replace(
            self, _confusion_matrix=new_cm, _average_loss=new_loss, _count=new_count
        )

    def compute(self) -> Mapping[str, float]:

        metrics = _compute_metrics_from_cm(self._confusion_matrix)
        metrics["average_loss"] = self._average_loss

        return metrics
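`merge` combines the two average losses weighted by their sample counts, so merging per-batch objects gives the same average loss as averaging over all samples at once, even when batch sizes differ. A quick plain-Python check of that identity; the `merge_average_loss` helper is hypothetical and only mirrors the arithmetic in `ClassificationMetrics.merge` above.

```python
from typing import List, Optional, Tuple


def merge_average_loss(a: Tuple[float, int], b: Tuple[float, int]) -> Tuple[float, int]:
    """Merge (average_loss, count) pairs with the count-weighted formula used in merge()."""
    loss_a, count_a = a
    loss_b, count_b = b
    new_count = count_a + count_b
    new_loss = (loss_a * count_a + loss_b * count_b) / new_count
    return new_loss, new_count


per_sample_losses: List[float] = [0.5, 1.0, 1.5, 2.0, 4.0]
batches = [per_sample_losses[:4], per_sample_losses[4:]]  # uneven sizes: 4 and 1

merged: Optional[Tuple[float, int]] = None
for batch in batches:
    stats = (sum(batch) / len(batch), len(batch))
    merged = stats if merged is None else merge_average_loss(merged, stats)

# An unweighted mean of the batch means would give (1.25 + 4.0) / 2 = 2.625 instead.
```

The weighted merge recovers the true mean of 1.8, whereas naively averaging the two batch means would over-weight the single-sample batch.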

__init__(preds: jnp.ndarray, targets: jnp.ndarray, loss: float, num_classes: int) -> None

Create a ClassificationMetrics object from model predictions and targets.

Parameters:

  • preds (jnp.ndarray) –

    Integer class predictions (not one-hot encoded), shape: (batch_size,).

  • targets (jnp.ndarray) –

    Integer class targets (not one-hot encoded), shape: (batch_size,).

  • loss (float) –

    Average loss over the batch (scalar).

  • num_classes (int) –

    Number of classes in classification problem.

Source code in solstice/metrics.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def __init__(
    self, preds: jnp.ndarray, targets: jnp.ndarray, loss: float, num_classes: int
) -> None:
    """
    Create a ClassificationMetrics object from model predictions and targets.

    Args:
        preds (jnp.ndarray): Non OH encoded predictions, shape: (batch_size,).
        targets (jnp.ndarray): Non OH encoded targets, shape: (batch_size,).
        loss (float): Average loss over the batch (scalar).
        num_classes (int): Number of classes in classification problem.
    """
    self._confusion_matrix = _compute_confusion_matrix(preds, targets, num_classes)
    self._average_loss = loss
    self._count = preds.shape[0]
    self._num_classes = num_classes
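The merge logic shown in the source above (sum the confusion matrices and counts, weight the average loss by count) can be checked in isolation. The sketch below is a NumPy stand-in, not the Solstice implementation: BatchStats and confusion_matrix are hypothetical names, and the rows-are-targets layout of the confusion matrix is an assumption.

```python
from dataclasses import dataclass

import numpy as np


def confusion_matrix(preds: np.ndarray, targets: np.ndarray, num_classes: int) -> np.ndarray:
    # Assumed convention for this sketch: rows index targets, columns index preds.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (targets, preds), 1)  # unbuffered add handles repeated pairs
    return cm


@dataclass(frozen=True)
class BatchStats:
    """Hypothetical stand-in for the fields ClassificationMetrics tracks."""

    cm: np.ndarray
    average_loss: float
    count: int

    def merge(self, other: "BatchStats") -> "BatchStats":
        # Confusion matrices and counts simply add; the loss is count-weighted.
        total = self.count + other.count
        loss = (self.average_loss * self.count + other.average_loss * other.count) / total
        return BatchStats(self.cm + other.cm, loss, total)


a = BatchStats(confusion_matrix(np.array([0, 1, 1]), np.array([0, 1, 0]), 2), 0.3, 3)
b = BatchStats(confusion_matrix(np.array([1]), np.array([1]), 2), 0.7, 1)
merged = a.merge(b)
print(merged.cm.tolist())  # [[1, 1], [0, 2]]
print(merged.count, round(merged.average_loss, 6))  # 4 0.4
```

Weighting by count makes the merged loss equal the mean over all four samples (0.4), rather than the unweighted mean of the two batch averages (0.5).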

Training¤

Training loops are usually boilerplate code that has little to do with your research. We provide training and testing loops which integrate with a simple and flexible callback system. Any solstice.Experiment can be passed to the loops, but you can always write your own if necessary. We provide a handful of pre-implemented callbacks, but if they do not suit your needs, you can use them as inspiration to write your own.

Callback ¤

Bases: ABC

Base class for callbacks to solstice.train() and solstice.test(). Subclass and implement this interface to inject arbitrary functionality into the training and testing loops.

Tip

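All callback hooks return None, so they cannot change the experiment state. Use callbacks to execute side effects like logging, checkpointing or profiling. The one way a callback can influence training is by raising an exception, e.g. EarlyStoppingException to stop training early.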

Example

Pseudocode callback implementation for logging with solstice.Metrics:

import solstice

class MyLoggingCallback(solstice.Callback):
    def __init__(self, log_every_n_steps):
        self.metrics = None
        self.log_every_n_steps = log_every_n_steps
        ... # set up logging, e.g. wandb.init(...)

    def on_step_end(self, outs, exp, global_step, mode, batch):
        assert isinstance(outs, solstice.Metrics)
        self.metrics = (
            outs.merge(self.metrics) if self.metrics is not None else outs
        )
        if (global_step + 1) % self.log_every_n_steps == 0:
            metrics_dict = self.metrics.compute()
            ... # do logging, e.g. wandb.log(metrics_dict)
            self.metrics = None
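As a further illustration of a side-effect-only callback, here is a sketch that records wall-clock time per step. StepTimerCallback is hypothetical; it does not subclass solstice.Callback so the snippet runs on its own, but in practice you would inherit from it (and implement __init__, which the base class marks abstract).

```python
import time


class StepTimerCallback:
    """Hypothetical callback recording wall-clock seconds per step, per mode."""

    def __init__(self) -> None:
        self._start = 0.0
        self.durations: dict[str, list[float]] = {"train": [], "val": [], "test": []}

    def on_step_start(self, exp, global_step, mode, batch) -> None:
        del exp, global_step, mode, batch  # unused
        self._start = time.perf_counter()

    def on_step_end(self, outs, exp, global_step, mode, batch) -> None:
        del exp, global_step, batch  # unused
        # JAX dispatches work asynchronously, so for accurate timings you may need
        # to block on the outputs (e.g. jax.block_until_ready(outs)) before this.
        del outs
        self.durations[mode].append(time.perf_counter() - self._start)


# Drive the hooks by hand, in the order a training loop would call them.
timer = StepTimerCallback()
for step in range(3):
    timer.on_step_start(None, step, "train", None)
    timer.on_step_end(None, None, step, "train", None)
print(len(timer.durations["train"]))  # 3
```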

Source code in solstice/trainer.py
class Callback(ABC):
    """Base class for callbacks to `solstice.train()` and `solstice.test()`. Subclass
    and implement this interface to inject arbitrary functionality into the training
    and testing loops.

    !!! tip
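        All callback hooks return `None`, so they cannot change the experiment state.
        Use callbacks to execute side effects like logging, checkpointing or profiling.
        The one way a callback can influence training is by raising an exception,
        e.g. `EarlyStoppingException` to stop training early.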

    !!! example
        Pseudocode callback implementation for logging with `solstice.Metrics`:
        ```python

        import solstice

        class MyLoggingCallback(solstice.Callback):
            def __init__(self, log_every_n_steps):
                self.metrics = None
                self.log_every_n_steps = log_every_n_steps
                ... # set up logging, e.g. wandb.init(...)

            def on_step_end(self, outs, exp, global_step, mode, batch):
                assert isinstance(outs, solstice.Metrics)
                self.metrics = (
                    outs.merge(self.metrics) if self.metrics is not None else outs
                )
                if (global_step + 1) % self.log_every_n_steps == 0:
                    metrics_dict = self.metrics.compute()
                    ... # do logging, e.g. wandb.log(metrics_dict)
                    self.metrics = None
        ```
    """

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        """Initialize the callback."""
        raise NotImplementedError

    def on_epoch_start(
        self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
    ) -> None:
        """Called at the start of each epoch, i.e. before the model has seen any data
        for that epoch.

        Args:
            exp (Experiment): Current Experiment state.
            epoch (int): Current epoch number.
            mode (Literal["train", "val", "test"]): String representing whether this is
                a training, validation or testing epoch.
        """
        pass

    def on_epoch_end(
        self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
    ) -> None:
        """Called at the end of each epoch, i.e. after the model has seen the full
        dataset for that epoch.

        Args:
            exp (Experiment): Current Experiment state.
            epoch (int): Current epoch number.
            mode (Literal["train", "val", "test"]): String representing whether this is
                a training, validation or testing epoch.
        """
        pass

    def on_step_start(
        self,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch: Any,
    ) -> None:
        """Called at the start of each training, validation or testing step, i.e.
        before the batch has been seen.

        Args:
            exp (Experiment): Current Experiment state.
            global_step (int): Current step number. This is the global step, i.e. a
                running count of steps in the current mode across all epochs.
                Training and validation keep separate step counts, so the same step
                number can occur in both modes; pair it with `mode` to identify a
                step uniquely.
            mode (Literal["train", "val", "test"]): String representing whether this is
                a training, validation or testing step.
            batch (Any): Current batch of data for this step.
        """
        pass

    def on_step_end(
        self,
        outs: Any,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch: Any,
    ) -> None:
        """Called at the end of each training, validation or testing step, i.e. after
        the batch has been seen.

        Args:
            exp (Experiment): Current Experiment state.
            global_step (int): Current step number. This is the global step, i.e. a
                running count of steps in the current mode across all epochs.
                Training and validation keep separate step counts, so the same step
                number can occur in both modes; pair it with `mode` to identify a
                step uniquely.
            mode (Literal["train", "val", "test"]): String representing whether this is
                a training, validation or testing step.
            batch (Any): Current batch of data for this step.
            outs (Any): Auxiliary outputs from the experiment train/eval step. Usually,
                this should be a `solstice.Metrics` object.
        """
        pass

__init__(*args, **kwargs) -> None abstractmethod ¤

Initialize the callback.

Source code in solstice/trainer.py
@abstractmethod
def __init__(self, *args, **kwargs) -> None:
    """Initialize the callback."""
    raise NotImplementedError

on_epoch_start(exp: Experiment, epoch: int, mode: Literal[train, val, test]) -> None ¤

Called at the start of each epoch, i.e. before the model has seen any data for that epoch.

Parameters:

  • exp (Experiment) –

    Current Experiment state.

  • epoch (int) –

    Current epoch number.

  • mode (Literal[train, val, test]) –

    String representing whether this is a training, validation or testing epoch.

Source code in solstice/trainer.py
def on_epoch_start(
    self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
) -> None:
    """Called at the start of each epoch, i.e. before the model has seen any data
    for that epoch.

    Args:
        exp (Experiment): Current Experiment state.
        epoch (int): Current epoch number.
        mode (Literal["train", "val", "test"]): String representing whether this is
            a training, validation or testing epoch.
    """
    pass

on_epoch_end(exp: Experiment, epoch: int, mode: Literal[train, val, test]) -> None ¤

Called at the end of each epoch, i.e. after the model has seen the full dataset for that epoch.

Parameters:

  • exp (Experiment) –

    Current Experiment state.

  • epoch (int) –

    Current epoch number.

  • mode (Literal[train, val, test]) –

    String representing whether this is a training, validation or testing epoch.

Source code in solstice/trainer.py
def on_epoch_end(
    self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
) -> None:
    """Called at the end of each epoch, i.e. after the model has seen the full
    dataset for that epoch.

    Args:
        exp (Experiment): Current Experiment state.
        epoch (int): Current epoch number.
        mode (Literal["train", "val", "test"]): String representing whether this is
            a training, validation or testing epoch.
    """
    pass

on_step_start(exp: Experiment, global_step: int, mode: Literal[train, val, test], batch: Any) -> None ¤

Called at the start of each training, validation or testing step, i.e. before the batch has been seen.

Parameters:

  • exp (Experiment) –

    Current Experiment state.

  • global_step (int) –

    Current step number. This is the global step, i.e. a running count of steps in the current mode across all epochs. Training and validation keep separate step counts, so the same step number can occur in both modes; pair it with mode to identify a step uniquely.

  • mode (Literal[train, val, test]) –

    String representing whether this is a training, validation or testing step.

  • batch (Any) –

    Current batch of data for this step.
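To make the point about separate counters concrete, this sketch computes the step numbers each mode would see, assuming each mode's counter starts each epoch at epoch * steps_per_epoch and counts zero-based steps within the epoch (global_steps is a hypothetical helper):

```python
def global_steps(num_epochs: int, steps_per_epoch: int) -> list[int]:
    # Hypothetical helper: each mode's counter starts at epoch * steps_per_epoch,
    # with zero-based steps inside the epoch (an assumption of this sketch).
    return [
        epoch * steps_per_epoch + i
        for epoch in range(num_epochs)
        for i in range(steps_per_epoch)
    ]


train_steps = global_steps(num_epochs=2, steps_per_epoch=4)  # 0..7
val_steps = global_steps(num_epochs=2, steps_per_epoch=2)    # 0..3
# The counters overlap, so a step is only unique together with its mode.
print(sorted(set(train_steps) & set(val_steps)))  # [0, 1, 2, 3]
```

Because the same numbers appear in both modes, a callback that keys anything by step (log entries, profiler traces) should also include mode.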

Source code in solstice/trainer.py
def on_step_start(
    self,
    exp: Experiment,
    global_step: int,
    mode: Literal["train", "val", "test"],
    batch: Any,
) -> None:
    """Called at the start of each training, validation or testing step, i.e.
    before the batch has been seen.

    Args:
        exp (Experiment): Current Experiment state.
        global_step (int): Current step number. This is the global step, i.e. a
            running count of steps in the current mode across all epochs.
            Training and validation keep separate step counts, so the same step
            number can occur in both modes; pair it with `mode` to identify a
            step uniquely.
        mode (Literal["train", "val", "test"]): String representing whether this is
            a training, validation or testing step.
        batch (Any): Current batch of data for this step.
    """
    pass

on_step_end(outs: Any, exp: Experiment, global_step: int, mode: Literal[train, val, test], batch: Any) -> None ¤

Called at the end of each training, validation or testing step, i.e. after the batch has been seen.

Parameters:

  • exp (Experiment) –

    Current Experiment state.

  • global_step (int) –

    Current step number. This is the global step, i.e. a running count of steps in the current mode across all epochs. Training and validation keep separate step counts, so the same step number can occur in both modes; pair it with mode to identify a step uniquely.

  • mode (Literal[train, val, test]) –

    String representing whether this is a training, validation or testing step.

  • batch (Any) –

    Current batch of data for this step.

  • outs (Any) –

    Auxiliary outputs from the experiment train/eval step. Usually, this should be a solstice.Metrics object.

Source code in solstice/trainer.py
def on_step_end(
    self,
    outs: Any,
    exp: Experiment,
    global_step: int,
    mode: Literal["train", "val", "test"],
    batch: Any,
) -> None:
    """Called at the end of each training, validation or testing step, i.e. after
    the batch has been seen.

    Args:
        exp (Experiment): Current Experiment state.
        global_step (int): Current step number. This is the global step, i.e. a
            running count of steps in the current mode across all epochs.
            Training and validation keep separate step counts, so the same step
            number can occur in both modes; pair it with `mode` to identify a
            step uniquely.
        mode (Literal["train", "val", "test"]): String representing whether this is
            a training, validation or testing step.
        batch (Any): Current batch of data for this step.
        outs (Any): Auxiliary outputs from the experiment train/eval step. Usually,
            this should be a `solstice.Metrics` object.
    """
    pass

LoggingCallback ¤

Bases: Callback

Logs auxiliary outputs from training or evaluation steps (either periodically every n steps, or at the end of the epoch). Internally, this accumulates metrics with metrics.merge(), computes them with metrics.compute(), and then passes the final results to the given logging function.

Warning

Auxiliary outputs from the train and eval steps must be a solstice.Metrics instance for this callback to work properly. We raise an AssertionError if this is not the case.

Note

There are many different libraries you can use for writing logs (e.g. wandb, TensorBoard(X), ...). We offer no opinion on which one you should use. Pass in a logging function to use any arbitrary logger.
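The accumulate-then-log pattern described above can be sketched without Solstice. MeanLoss is a hypothetical, minimal stand-in for a solstice.Metrics subclass (it only implements merge() and compute()), and log_every_n mirrors the (global_step + 1) % n check used by on_step_end, assuming zero-based steps:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MeanLoss:
    """Hypothetical minimal Metrics-like object with merge() and compute()."""

    total: float
    count: int

    def merge(self, other: "MeanLoss") -> "MeanLoss":
        return MeanLoss(self.total + other.total, self.count + other.count)

    def compute(self) -> dict[str, float]:
        return {"loss": self.total / self.count}


def log_every_n(step_outputs: list[MeanLoss], n: int) -> list[tuple[int, dict]]:
    """Accumulate per-step outputs and 'log' (collect) them every n steps."""
    logged, acc = [], None
    for step, outs in enumerate(step_outputs):
        acc = outs.merge(acc) if acc is not None else outs
        if (step + 1) % n == 0:
            logged.append((step, acc.compute()))
            acc = None  # reset so each log covers only the last n steps
    return logged


outs = [MeanLoss(total=float(x), count=1) for x in [4, 2, 3, 1]]
print(log_every_n(outs, n=2))  # [(1, {'loss': 3.0}), (3, {'loss': 2.0})]
```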

Source code in solstice/trainer.py
class LoggingCallback(Callback):
    """Logs auxiliary outputs from training or evaluation steps (either periodically
    every n steps, or at the end of the epoch). Internally, this accumulates metrics
    with `metrics.merge()`, computes them with `metrics.compute()`, and then passes
    the final results to the given logging function.

    !!! warning
        Auxiliary outputs from the train and eval steps must be a `solstice.Metrics`
        instance for this callback to work properly. We raise an AssertionError if this
        is not the case.

    !!! note
        There are many different libraries you can use for writing logs (e.g. wandb,
        TensorBoard(X), ...). We offer no opinion on which one you should use. Pass in
        a logging function to use any arbitrary logger.
    """

    def __init__(
        self,
        log_every_n_steps: int | None = None,
        logging_fn: Callable[[Any, int, Literal["train", "val", "test"]], None]
        | None = None,
    ) -> None:
        """Initialize the logging callback.

        Args:
            log_every_n_steps (int | None, optional): If given, accumulate metrics over
                n steps before logging. If None, log at end of epoch. Defaults to None.
            logging_fn (Callable[[Any, int, Literal['train', 'val', 'test']], None] | None, optional):
                Logging function. Takes the outputs of `metrics.compute()`, the current
                step or epoch number, and a string representing whether training,
                validating, or testing. The function should return nothing. If no
                logging_fn is given, the default behaviour is to log with the built-in
                Python logger (INFO level). Defaults to None.

        !!! example
            The default logging function (used if None is given) logs using the
            built-in Python logger, with name "solstice" and INFO level
            (notice that the output of `metrics.compute()` must be printable):
            ```python
            import logging

            logger = logging.getLogger("solstice")

            default_logger = lambda metrics, step, mode: logger.info(
                f"{mode} step {step}: {metrics}"
            )
            ```

            If the logs aren't showing, you might need to put these lines at the top
            of your script. Setting the level alone is not enough if no handler is
            configured, because Python's fallback handler only prints warnings and
            above:
            ```python
            import logging
            logging.basicConfig()  # attach a stderr handler to the root logger
            logging.getLogger("solstice").setLevel(logging.INFO)
            ```
        """
        default_logger = lambda metrics, step, mode: logger.info(
            f"{mode} step {step}: {metrics}"
        )
        self.logging_fn = logging_fn if logging_fn else default_logger
        self.log_every_n_steps = log_every_n_steps
        self.metrics = None

    def on_step_end(
        self,
        outs: Any,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch: Any,
    ) -> None:
        del exp, batch
        assert isinstance(outs, Metrics)
        self.metrics = outs.merge(self.metrics) if self.metrics else outs

        if self.log_every_n_steps and (global_step + 1) % self.log_every_n_steps == 0:
            final_metrics = self.metrics.compute()
            self.logging_fn(final_metrics, global_step, mode)
            self.metrics = None

    def on_epoch_end(
        self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
    ) -> None:
        del exp
        # if not logging every n steps, we just log at the end of the epoch
        if not self.log_every_n_steps:
            assert self.metrics is not None
            final_metrics = self.metrics.compute()
            self.logging_fn(final_metrics, epoch, mode)
        # reset the metrics object to prevent train/val metrics from being mixed
        self.metrics = None

__init__(log_every_n_steps: int | None = None, logging_fn: Callable[[Any, int, Literal[train, val, test]], None] | None = None) -> None ¤

Initialize the logging callback.

Parameters:

  • log_every_n_steps (int | None) –

    If given, accumulate metrics over n steps before logging. If None, log at end of epoch. Defaults to None.

  • logging_fn (Callable[[Any, int, Literal[train, val, test]], None] | None) –

    Logging function. Takes the outputs of metrics.compute(), the current step or epoch number, and a string representing whether training, validating, or testing. The function should return nothing. If no logging_fn is given, the default behaviour is to log with the built-in Python logger (INFO level). Defaults to None.

Example

The default logging function (used if None is given) logs using the built-in Python logger, with name "solstice" and INFO level (notice that the output of metrics.compute() must be printable):

import logging

logger = logging.getLogger("solstice")

default_logger = lambda metrics, step, mode: logger.info(
    f"{mode} step {step}: {metrics}"
)

If the logs aren't showing, you might need to put these lines at the top of your script. Setting the level alone is not enough if no handler is configured, because Python's fallback handler only prints warnings and above:

import logging
logging.basicConfig()  # attach a stderr handler to the root logger
logging.getLogger("solstice").setLevel(logging.INFO)
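Any callable with the (metrics, step, mode) signature can be passed as logging_fn. As a sketch, the hypothetical factory below prefixes each metric name with the mode, a common convention for keeping training and validation curves apart. It appends to a list so the example is self-contained; you could swap that line for a call into your logging backend (e.g. wandb.log(record, step=step)).

```python
from typing import Any, Callable, Literal

Mode = Literal["train", "val", "test"]


def make_prefixing_logger(sink: list) -> Callable[[Any, int, Mode], None]:
    """Hypothetical factory returning a logging_fn(metrics, step, mode)."""

    def logging_fn(metrics: Any, step: int, mode: Mode) -> None:
        # metrics is the mapping returned by metrics.compute(); float() also
        # converts scalar JAX/NumPy arrays to plain Python numbers.
        record = {f"{mode}/{name}": float(value) for name, value in metrics.items()}
        sink.append((step, record))  # e.g. wandb.log(record, step=step) instead

    return logging_fn


records: list = []
logging_fn = make_prefixing_logger(records)
logging_fn({"accuracy": 0.75, "average_loss": 0.5}, 10, "val")
print(records)  # [(10, {'val/accuracy': 0.75, 'val/average_loss': 0.5})]
```

You would then pass it as LoggingCallback(logging_fn=make_prefixing_logger(records)).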

Source code in solstice/trainer.py
def __init__(
    self,
    log_every_n_steps: int | None = None,
    logging_fn: Callable[[Any, int, Literal["train", "val", "test"]], None]
    | None = None,
) -> None:
    """Initialize the logging callback.

    Args:
        log_every_n_steps (int | None, optional): If given, accumulate metrics over
            n steps before logging. If None, log at end of epoch. Defaults to None.
        logging_fn (Callable[[Any, int, Literal['train', 'val', 'test']], None] | None, optional):
            Logging function. Takes the outputs of `metrics.compute()`, the current
            step or epoch number, and a string representing whether training,
            validating, or testing. The function should return nothing. If no
            logging_fn is given, the default behaviour is to log with the built-in
            Python logger (INFO level). Defaults to None.

    !!! example
        The default logging function (used if None is given) logs using the
        built-in Python logger, with name "solstice" and INFO level
        (notice that the output of `metrics.compute()` must be printable):
        ```python
        import logging

        logger = logging.getLogger("solstice")

        default_logger = lambda metrics, step, mode: logger.info(
            f"{mode} step {step}: {metrics}"
        )
        ```

        If the logs aren't showing, you might need to put these lines at the top
        of your script. Setting the level alone is not enough if no handler is
        configured, because Python's fallback handler only prints warnings and
        above:
        ```python
        import logging
        logging.basicConfig()  # attach a stderr handler to the root logger
        logging.getLogger("solstice").setLevel(logging.INFO)
        ```
    """
    default_logger = lambda metrics, step, mode: logger.info(
        f"{mode} step {step}: {metrics}"
    )
    self.logging_fn = logging_fn if logging_fn else default_logger
    self.log_every_n_steps = log_every_n_steps
    self.metrics = None

CheckpointingCallback ¤

Bases: Callback

Checkpoint the experiment state at the end of each epoch.

Todo

Implement this. Consider adding asynchronous checkpointing.

Source code in solstice/trainer.py
class CheckpointingCallback(Callback):
    """Checkpoint the experiment state at the end of each epoch.

    !!! todo
        Implement this. Consider adding asynchronous checkpointing."""

    pass

ProfilingCallback ¤

Bases: Callback

Uses the built-in JAX (TensorBoard) profiler to profile training and evaluation steps.

Note

To view the traces, ensure TensorBoard is installed. Then run tensorboard --logdir=<log_dir>. See https://jax.readthedocs.io/en/latest/profiling.html for more information.

Source code in solstice/trainer.py
class ProfilingCallback(Callback):
    """Uses the built-in JAX (TensorBoard) profiler to profile training and evaluation
    steps.

    !!! note
        To view the traces, ensure TensorBoard is installed. Then run
        `tensorboard --logdir=<log_dir>`. See
        https://jax.readthedocs.io/en/latest/profiling.html for more information."""

    def __init__(self, log_dir: str, steps_to_profile: list[int] | None = None) -> None:
        """Initialize the Profiler callback.

        !!! tip
            You can use the `steps_to_profile` argument to profile only a subset of the
            steps. Usually, step 0 will be slowest due to JIT compilation, so you might
            want to profile steps 0 and 1.

        Args:
            log_dir (str): Directory to write the profiler trace files to.
            steps_to_profile (list[int] | None, optional): If given, only profile these
                steps, else profile every step. Defaults to None.
        """
        self.log_dir = log_dir
        self.steps_to_profile = steps_to_profile

    def on_step_start(
        self,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch,
    ) -> None:
        del exp, mode, batch
        if self.steps_to_profile is None or global_step in self.steps_to_profile:
            jax.profiler.start_trace(self.log_dir)

    def on_step_end(
        self,
        outs: Any,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch,
    ) -> None:
        del exp, mode, batch, outs
        if self.steps_to_profile is None or global_step in self.steps_to_profile:
            jax.profiler.stop_trace()

__init__(log_dir: str, steps_to_profile: list[int] | None = None) -> None ¤

Initialize the Profiler callback.

Tip

You can use the steps_to_profile argument to profile only a subset of the steps. Usually, step 0 will be slowest due to JIT compilation, so you might want to profile steps 0 and 1.

Parameters:

  • log_dir (str) –

    Directory to write the profiler trace files to.

  • steps_to_profile (list[int] | None) –

    If given, only profile these steps, else profile every step. Defaults to None.

Source code in solstice/trainer.py
def __init__(self, log_dir: str, steps_to_profile: list[int] | None = None) -> None:
    """Initialize the Profiler callback.

    !!! tip
        You can use the `steps_to_profile` argument to profile only a subset of the
        steps. Usually, step 0 will be slowest due to JIT compilation, so you might
        want to profile steps 0 and 1.

    Args:
        log_dir (str): Directory to write the profiler trace files to.
        steps_to_profile (list[int] | None, optional): If given, only profile these
            steps, else profile every step. Defaults to None.
    """
    self.log_dir = log_dir
    self.steps_to_profile = steps_to_profile

EarlyStoppingCallback ¤

Bases: Callback

Stops training early if a criterion is met. Checks once per validation epoch (at the end). This callback accumulates auxiliary outputs from each validation step into a list and passes them to the criterion function which determines whether to stop training.

Tip

If this callback doesn't suit your needs, you can implement your own early stopping callback by raising an EarlyStoppingException in the on_epoch_end hook (solstice.train() only catches it there).

Source code in solstice/trainer.py
class EarlyStoppingCallback(Callback):
    """Stops training early if a criterion is met. Checks once per validation epoch
    (at the end). This callback accumulates auxiliary outputs from each validation step
    into a list and passes them to the criterion function which determines whether to
    stop training.

    !!! tip
        If this callback doesn't suit your needs, you can implement your own early
        stopping callback by raising an `EarlyStoppingException` in the `on_epoch_end`
        hook (`solstice.train()` only catches it there).
    """

    def __init__(
        self,
        criterion_fn: Callable[[list[Any]], bool],
        accumulate_every_n_steps: int = 1,
    ) -> None:
        """Initialize the EarlyStoppingCallback.

        Args:
            criterion_fn (Callable[[list[Any]], bool]): Function that takes a list of
                the accumulated auxiliary outputs from each step and returns a boolean
                indicating whether to stop training.
            accumulate_every_n_steps (int, optional): Accumulate auxiliary outputs every
                nth step. Set to 2 to only keep half, 3 for keeping 1/3, etc. This
                effectively downsamples the signal (so beware it is losing information).
                Defaults to 1.

        !!! example
            Example criterion function that merges the list of accumulated metrics
            objects into one, calls `.compute()` on it to get a dictionary, and stops
            training if accuracy is > 0.9.
            TODO: update example when `solstice.reduce` is implemented
            ```python
            import functools

            def criterion_fn(outs):
                metrics = functools.reduce(lambda a, b: a.merge(b), outs)
                return metrics.compute()["accuracy"] > 0.9
            ```
        """
        self.accumulated_outs = []
        self.criterion_fn = criterion_fn
        self.accumulate_every_n_steps = accumulate_every_n_steps

    def on_step_end(
        self,
        outs: Any,
        exp: Experiment,
        global_step: int,
        mode: Literal["train", "val", "test"],
        batch,
    ) -> None:
        del exp, batch

        if mode == "val" and global_step % self.accumulate_every_n_steps == 0:
            self.accumulated_outs.append(outs)

    def on_epoch_end(
        self, exp: Experiment, epoch: int, mode: Literal["train", "val", "test"]
    ) -> None:
        del exp, epoch
        if mode == "val":
            if self.criterion_fn(self.accumulated_outs):
                raise EarlyStoppingException()
        self.accumulated_outs = []  # reset for next epoch

__init__(criterion_fn: Callable[[list[Any]], bool], accumulate_every_n_steps: int = 1) -> None ¤

Initialize the EarlyStoppingCallback.

Parameters:

  • criterion_fn (Callable[[list[Any]], bool]) –

    Function that takes a list of the accumulated auxiliary outputs from each step and returns a boolean indicating whether to stop training.

  • accumulate_every_n_steps (int) –

    Accumulate auxiliary outputs every nth step. Set to 2 to only keep half, 3 for keeping 1/3, etc. This effectively downsamples the signal (so beware it is losing information). Defaults to 1.

Example

Example criterion function that merges the list of accumulated metrics objects into one, calls .compute() on it to get a dictionary, and stops training if accuracy is > 0.9.

TODO: update example when solstice.reduce is implemented

import functools

def criterion_fn(outs):
    metrics = functools.reduce(lambda a, b: a.merge(b), outs)
    return metrics.compute()["accuracy"] > 0.9
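Because accumulated_outs is reset after every epoch, criterion_fn only ever sees the current validation epoch's outputs. A criterion that needs history, such as patience-based stopping, has to keep its own state, e.g. in a closure. A sketch follows; make_patience_criterion is hypothetical, and get_loss is whatever reduces your outputs to one number (for ClassificationMetrics outputs, merging them and reading compute()["average_loss"] would work).

```python
from typing import Any, Callable


def make_patience_criterion(
    get_loss: Callable[[list[Any]], float], patience: int
) -> Callable[[list[Any]], bool]:
    """Hypothetical factory: stop once the validation loss has failed to
    improve for `patience` consecutive validation epochs."""
    best = float("inf")
    epochs_without_improvement = 0

    def criterion_fn(outs: list[Any]) -> bool:
        nonlocal best, epochs_without_improvement
        loss = get_loss(outs)
        if loss < best:
            best, epochs_without_improvement = loss, 0
        else:
            epochs_without_improvement += 1
        return epochs_without_improvement >= patience

    return criterion_fn


# Toy usage: each "out" is just a per-step loss; get_loss averages them.
criterion = make_patience_criterion(lambda outs: sum(outs) / len(outs), patience=2)
epochs = ([1.0, 0.8], [0.7], [0.9], [0.75])
decisions = [criterion(epoch_outs) for epoch_outs in epochs]
print(decisions)  # [False, False, False, True]
```

The result would be passed as EarlyStoppingCallback(criterion_fn=make_patience_criterion(get_loss, patience=2)).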

Source code in solstice/trainer.py
def __init__(
    self,
    criterion_fn: Callable[[list[Any]], bool],
    accumulate_every_n_steps: int = 1,
) -> None:
    """Initialize the EarlyStoppingCallback.

    Args:
        criterion_fn (Callable[[list[Any]], bool]): Function that takes a list of
            the accumulated auxiliary outputs from each step and returns a boolean
            indicating whether to stop training.
        accumulate_every_n_steps (int, optional): Accumulate auxiliary outputs every
            nth step. Set to 2 to only keep half, 3 for keeping 1/3, etc. This
            effectively downsamples the signal (so beware it is losing information).
            Defaults to 1.

    !!! example
        Example criterion function that merges the list of accumulated metrics
        objects into one, calls `.compute()` on it to get a dictionary, and stops
        training if accuracy is > 0.9.
        TODO: update example when `solstice.reduce` is implemented
        ```python
        import functools

        def criterion_fn(outs):
            metrics = functools.reduce(lambda a, b: a.merge(b), outs)
            return metrics.compute()["accuracy"] > 0.9
        ```
    """
    self.accumulated_outs = []
    self.criterion_fn = criterion_fn
    self.accumulate_every_n_steps = accumulate_every_n_steps

train(exp: ExperimentType, num_epochs: int, train_ds: tf.data.Dataset, val_ds: tf.data.Dataset | None = None, callbacks: list[Callback] | None = None) -> ExperimentType ¤

Train a solstice.Experiment, using tf.data.Dataset for data loading. Supply solstice.Callbacks to add any additional functionality.

Parameters:

  • exp (Experiment) –

    Solstice experiment to train.

  • num_epochs (int) –

    Number of epochs to train for.

  • train_ds (tf.data.Dataset) –

    TensorFlow dataset of training data.

  • val_ds (tf.data.Dataset | None) –

    TensorFlow dataset of validation data. If None is given, validation is skipped. Defaults to None.

  • callbacks (list[Callback] | None) –

    List of Solstice callbacks. These can execute arbitrary code on certain events, usually for side effects like logging and checkpointing. See solstice.Callback. Defaults to None.

Returns:

  • Experiment( ExperimentType ) –

    Trained experiment.

Source code in solstice/trainer.py
def train(
    exp: ExperimentType,
    num_epochs: int,
    train_ds: tf.data.Dataset,
    val_ds: tf.data.Dataset | None = None,
    callbacks: list[Callback] | None = None,
) -> ExperimentType:
    """Train a `solstice.Experiment`, using `tf.data.Dataset` for data loading.
    Supply `solstice.Callback`s to add any additional functionality.

    Args:
        exp (Experiment): Solstice experiment to train.
        num_epochs (int): Number of epochs to train for.
        train_ds (tf.data.Dataset): TensorFlow dataset of training data.
        val_ds (tf.data.Dataset | None, optional): TensorFlow dataset of validation
            data. If None is given, validation is skipped. Defaults to None.
        callbacks (list[Callback] | None, optional): List of Solstice callbacks. These
            can execute arbitrary code on certain events, usually for side effects like
            logging and checkpointing. See `solstice.Callback`. Defaults to None.

    Returns:
        Experiment: Trained experiment.
    """

    for epoch in tqdm(range(num_epochs), desc="Training", unit="epoch"):
        assert isinstance(epoch, int)  # just for mypy for type narrowing

        for mode, ds in zip(["train", "val"], [train_ds, val_ds]):
            assert _is_valid_mode(mode)  # type narrowing
            if ds is None:
                continue

            [
                cb.on_epoch_start(exp, epoch, mode) for cb in callbacks
            ] if callbacks is not None else None

            global_step = epoch * len(ds)  # nb: separate step counts for train and val
            for batch in tqdm(
                ds.as_numpy_iterator(),
                total=len(ds),
                desc=f"{mode}",
                leave=False,
                unit="step",
            ):
                global_step += 1

                [
                    cb.on_step_start(exp, global_step, mode, batch) for cb in callbacks
                ] if callbacks is not None else None

                exp, outs = (
                    exp.train_step(batch) if mode == "train" else exp.eval_step(batch)
                )

                [
                    cb.on_step_end(outs, exp, global_step, mode, batch)
                    for cb in callbacks
                ] if callbacks is not None else None

            try:
                [
                    cb.on_epoch_end(exp, epoch, mode) for cb in callbacks
                ] if callbacks is not None else None
            except EarlyStoppingException:
                logging.info(f"Early stopping at epoch {epoch}")
                return exp
    return exp
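The control flow of train can be mirrored in plain Python to see the order in which callback hooks fire and how global_step is numbered. This is a minimal sketch, not Solstice itself. Plain lists stand in for tf.data.Dataset and tqdm is dropped. CounterExperiment, RecordingCallback, run_train_sketch and EarlyStop are hypothetical names made up for this illustration.

```python
from typing import Optional, Sequence


class EarlyStop(Exception):
    """Hypothetical stand-in for solstice.EarlyStoppingException."""


class CounterExperiment:
    """Hypothetical experiment: an immutable counter of training steps taken."""

    def __init__(self, n_train_steps: int = 0):
        self.n_train_steps = n_train_steps

    def train_step(self, batch: Sequence[float]) -> tuple["CounterExperiment", dict]:
        # Return a *new* experiment, mirroring Solstice's functional style.
        return CounterExperiment(self.n_train_steps + 1), {"loss": float(sum(batch))}

    def eval_step(self, batch: Sequence[float]) -> tuple["CounterExperiment", dict]:
        return self, {"loss": float(sum(batch))}


class RecordingCallback:
    """Hypothetical callback that records which hooks fired, and at which step."""

    def __init__(self) -> None:
        self.events: list[tuple[str, int, str]] = []

    def on_epoch_start(self, exp, epoch, mode):
        self.events.append(("epoch_start", epoch, mode))

    def on_step_start(self, exp, global_step, mode, batch):
        pass

    def on_step_end(self, outs, exp, global_step, mode, batch):
        self.events.append(("step_end", global_step, mode))

    def on_epoch_end(self, exp, epoch, mode):
        self.events.append(("epoch_end", epoch, mode))


def run_train_sketch(
    exp,
    num_epochs: int,
    train_ds: list,
    val_ds: Optional[list] = None,
    callbacks: Optional[list] = None,
):
    callbacks = callbacks or []
    for epoch in range(num_epochs):
        for mode, ds in (("train", train_ds), ("val", val_ds)):
            if ds is None:
                continue  # no validation data: skip the val phase
            for cb in callbacks:
                cb.on_epoch_start(exp, epoch, mode)
            global_step = epoch * len(ds)  # train and val keep separate step counts
            for batch in ds:
                global_step += 1
                for cb in callbacks:
                    cb.on_step_start(exp, global_step, mode, batch)
                if mode == "train":
                    exp, outs = exp.train_step(batch)
                else:
                    exp, outs = exp.eval_step(batch)
                for cb in callbacks:
                    cb.on_step_end(outs, exp, global_step, mode, batch)
            try:
                for cb in callbacks:
                    cb.on_epoch_end(exp, epoch, mode)
            except EarlyStop:
                return exp  # stop early, returning the experiment as trained so far
    return exp


cb = RecordingCallback()
trained = run_train_sketch(
    CounterExperiment(), 2, train_ds=[[1, 2], [3, 4], [5, 6]], val_ds=[[0, 1]], callbacks=[cb]
)
train_steps = [s for kind, s, mode in cb.events if kind == "step_end" and mode == "train"]
val_steps = [s for kind, s, mode in cb.events if kind == "step_end" and mode == "val"]
print(trained.n_train_steps, train_steps, val_steps)  # 6 [1, 2, 3, 4, 5, 6] [1, 2]
```

With 3 training batches and 1 validation batch per epoch, training steps run from 1 to 6 across two epochs, while validation steps start from epoch * len(val_ds) and so come out as 1 and 2. This matches the "separate step counts for train and val" comment in the source.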

test(exp: Experiment, test_ds: tf.data.Dataset, callbacks: list[Callback] | None = None, return_outs: bool = False) -> list[Any] | None

Test a solstice.Experiment, using tf.data.Dataset for data loading. Supply solstice.Callbacks to add any additional functionality.

Parameters:

  • exp (Experiment) –

    Experiment to test.

  • test_ds (tf.data.Dataset) –

    TensorFlow dataset of test data.

  • callbacks (list[Callback] | None) –

    List of Solstice callbacks. These can execute arbitrary code on certain events, usually for side effects like logging. See solstice.Callback. Defaults to None.

  • return_outs (bool) –

    If True, the auxiliary outputs from exp.eval_step() are accumulated into a list and returned; otherwise the function returns None. Defaults to False.

Tip

Testing simply involves running through the test_ds for a single epoch. Thus the on_epoch_start() and on_epoch_end() callback hooks are executed once each, before testing starts and after testing ends.

Returns:

  • list[Any] | None –

    List of auxiliary outputs from exp.eval_step() if return_outs is True, else None.

Source code in solstice/trainer.py
def test(
    exp: Experiment,
    test_ds: tf.data.Dataset,
    callbacks: list[Callback] | None = None,
    return_outs: bool = False,
) -> list[Any] | None:
    """Test a `solstice.Experiment`, using `tf.data.Dataset` for data loading. Supply
    `solstice.Callback`s to add any additional functionality.

    Args:
        exp (Experiment): Experiment to test.
        test_ds (tf.data.Dataset): TensorFlow dataset of test data.
        callbacks (list[Callback] | None, optional): List of Solstice callbacks. These
            can execute arbitrary code on certain events, usually for side effects like
            logging. See `solstice.Callback`. Defaults to None.
        return_outs (bool, optional): If True, the auxiliary outputs from
            `exp.eval_step()` are accumulated into a list and returned, else this
            function returns nothing. Defaults to False.

    !!! tip
        Testing simply involves running through the test_ds for a single epoch. Thus
        the `on_epoch_start()` and `on_epoch_end()` callback hooks are executed once
        each, before testing starts and after testing ends.

    Returns:
        list[Any] | None: List of auxiliary outputs from `exp.eval_step()` if
            return_outs is True, else None.
    """
    assert callbacks is not None or return_outs is True, (
        "No callbacks were provided and return_outs is False. This function thus has no"
        " return values or side effects. All it does is heat up the planet :("
    )

    mode: Literal["test"] = "test"
    [
        cb.on_epoch_start(exp, 0, mode) for cb in callbacks
    ] if callbacks is not None else None

    global_step = 0
    outputs_list = []

    for batch in tqdm(
        test_ds.as_numpy_iterator(), total=len(test_ds), desc="Testing", unit="step"
    ):
        global_step += 1

        [
            cb.on_step_start(exp, global_step, mode, batch) for cb in callbacks
        ] if callbacks is not None else None

        exp, outs = exp.eval_step(batch)
        outputs_list.append(outs) if return_outs else None

        [
            cb.on_step_end(outs, exp, global_step, mode, batch) for cb in callbacks
        ] if callbacks is not None else None

    [
        cb.on_epoch_end(exp, 0, mode) for cb in callbacks
    ] if callbacks is not None else None

    return outputs_list if return_outs else None
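A stdlib-only sketch of the same control flow helps check what return_outs collects and how often each hook fires. Lists stand in for tf.data.Dataset, and EchoExperiment, HookCounter and run_test_sketch are hypothetical names, not part of Solstice.

```python
from typing import Any, Optional


class EchoExperiment:
    """Hypothetical experiment whose eval_step output is just the batch sum."""

    def eval_step(self, batch: list) -> tuple["EchoExperiment", int]:
        return self, sum(batch)


class HookCounter:
    """Hypothetical callback that counts how many times each hook is called."""

    def __init__(self) -> None:
        self.calls = {"epoch_start": 0, "step_start": 0, "step_end": 0, "epoch_end": 0}

    def on_epoch_start(self, exp, epoch, mode):
        self.calls["epoch_start"] += 1

    def on_step_start(self, exp, global_step, mode, batch):
        self.calls["step_start"] += 1

    def on_step_end(self, outs, exp, global_step, mode, batch):
        self.calls["step_end"] += 1

    def on_epoch_end(self, exp, epoch, mode):
        self.calls["epoch_end"] += 1


def run_test_sketch(
    exp, test_ds: list, callbacks: Optional[list] = None, return_outs: bool = False
) -> Optional[list[Any]]:
    # Same guard as solstice.test: with no callbacks and no outputs, nothing is observable.
    assert callbacks is not None or return_outs, "nothing to do"
    callbacks = callbacks or []
    mode = "test"
    for cb in callbacks:
        cb.on_epoch_start(exp, 0, mode)  # testing is a single epoch with index 0
    outputs: list[Any] = []
    for global_step, batch in enumerate(test_ds, start=1):
        for cb in callbacks:
            cb.on_step_start(exp, global_step, mode, batch)
        exp, outs = exp.eval_step(batch)
        if return_outs:
            outputs.append(outs)
        for cb in callbacks:
            cb.on_step_end(outs, exp, global_step, mode, batch)
    for cb in callbacks:
        cb.on_epoch_end(exp, 0, mode)
    return outputs if return_outs else None


counter = HookCounter()
outs = run_test_sketch(EchoExperiment(), [[1, 2], [3], [4, 5, 6]], [counter], return_outs=True)
print(outs)           # [3, 3, 15]
print(counter.calls)  # {'epoch_start': 1, 'step_start': 3, 'step_end': 3, 'epoch_end': 1}
```

The sketch writes the original's side-effect-only conditional expressions (such as outputs_list.append(outs) if return_outs else None) as plain if statements and for loops; the behaviour is the same, but the intent is easier to read.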

Utilities

Miscellaneous utilities for Solstice.

EarlyStoppingException

Bases: Exception

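A callback can raise this exception from on_epoch_end to stop the training loop early; solstice.train catches it, logs the epoch, and returns the experiment as trained so far. This is useful if you want to write a custom alternative to EarlyStoppingCallback.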

Source code in solstice/utils.py
class EarlyStoppingException(Exception):
    """A callback can raise this exception `on_epoch_end` to break the training loop
    early. Useful if you want to write a custom alternative to `EarlyStoppingCallback`."""

    pass
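As a sketch of such a custom alternative, the callback below stops training once the mean validation loss has failed to improve for patience consecutive epochs. It assumes that eval_step returns a dict with a "loss" entry and that hooks are called with the positional arguments used in solstice.train. PatienceCallback and run_val_epochs are hypothetical, and a local EarlyStoppingException class stands in for solstice.EarlyStoppingException so the example runs without Solstice installed. In real code you would import the Solstice exception (and presumably subclass solstice.Callback).

```python
class EarlyStoppingException(Exception):
    """Local stand-in for solstice.EarlyStoppingException."""


class PatienceCallback:
    """Hypothetical callback: raise EarlyStoppingException when the mean validation
    loss has not improved for `patience` consecutive validation epochs."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0
        self._val_losses: list[float] = []

    def on_epoch_start(self, exp, epoch, mode):
        if mode == "val":
            self._val_losses = []  # reset the per-epoch accumulator

    def on_step_start(self, exp, global_step, mode, batch):
        pass

    def on_step_end(self, outs, exp, global_step, mode, batch):
        if mode == "val":
            self._val_losses.append(outs["loss"])  # assumes outs is a dict with "loss"

    def on_epoch_end(self, exp, epoch, mode):
        if mode != "val" or not self._val_losses:
            return  # only judge progress at the end of a validation epoch
        mean_loss = sum(self._val_losses) / len(self._val_losses)
        if mean_loss < self.best:
            self.best, self.bad_epochs = mean_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                raise EarlyStoppingException(f"no improvement for {self.bad_epochs} epochs")


def run_val_epochs(cb, losses_per_epoch):
    """Hypothetical driver: feeds fake validation losses through the hooks, catches the
    exception around on_epoch_end as solstice.train does, and returns the stop epoch."""
    for epoch, losses in enumerate(losses_per_epoch):
        cb.on_epoch_start(None, epoch, "val")
        for step, loss in enumerate(losses, start=1):
            cb.on_step_end({"loss": loss}, None, step, "val", None)
        try:
            cb.on_epoch_end(None, epoch, "val")
        except EarlyStoppingException:
            return epoch
    return None


stopped_at = run_val_epochs(PatienceCallback(patience=2), [[1.0], [0.5], [0.6], [0.7], [0.4]])
print(stopped_at)  # 3
```

Raise the exception only from on_epoch_end: in solstice.train the try/except wraps just the on_epoch_end calls, so raising it from another hook would propagate out of train.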

replace(obj: Module, **changes: Any) -> Module

Make out-of-place changes to a Module, returning a new module with the changes applied. This is a thin wrapper around equinox.tree_at.

Example

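You can use this in the same way as dataclasses.replace, but it only works with eqx.Modules. The advantage is that it also works when a custom __init__ constructor is defined: dataclasses.replace re-runs __init__ with every field passed as a keyword argument, which fails if __init__ takes different arguments. For more info, see https://github.com/patrick-kidger/equinox/issues/120.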

import equinox as eqx
import solstice

class Counter(eqx.Module):
    x: int

    def __init__(self, z: int):
        # 'smart' constructor inits x by calculating from z
        self.x = 2 * z

    def increment(self):
        return solstice.replace(self, x=self.x+1)

C1 = Counter(z=0)
assert C1.x == 0
C2 = C1.increment()
assert C2.x == 1

Parameters:

  • obj (Module) –

    Module to make changes to (subclass of eqx.Module).

  • **changes (Any) –

    Keyword arguments to replace in the module.

Returns:

  • Module( Module ) –

    New instance of obj with the changes applied.

Source code in solstice/utils.py
def replace(obj: Module, **changes: Any) -> Module:
    """Make out-of-place changes to a Module, returning a new module with changes
    applied. Just a wrapper around `equinox.tree_at`.

    !!! example
        You can use this in the same way as `dataclasses.replace`, but it only works
        with `eqx.Module`s. The advantage is that it can be used when custom `__init__`
        constructors are defined.
        For more info, see https://github.com/patrick-kidger/equinox/issues/120.

        ```python

        import equinox as eqx
        import solstice

        class Counter(eqx.Module):
            x: int

            def __init__(self, z: int):
                # 'smart' constructor inits x by calculating from z
                self.x = 2 * z

            def increment(self):
                return solstice.replace(self, x=self.x+1)

        C1 = Counter(z=0)
        assert C1.x == 0
        C2 = C1.increment()
        assert C2.x == 1
        ```

    Args:
        obj (Module): Module to make changes to (subclass of `eqx.Module`).
        **changes (Any): Keyword arguments to replace in the module.

    Returns:
        Module: New instance of `obj` with the changes applied.
    """

    keys, vals = zip(*changes.items())
    return eqx.tree_at(lambda c: [getattr(c, key) for key in keys], obj, vals)  # type: ignore
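The Counter example shows that equinox.tree_at rebuilds the module without re-running its custom __init__, while dataclasses.replace re-invokes the constructor with the field names. The stdlib-only sketch below illustrates the same idea on a frozen dataclass: copy the object, then overwrite fields directly, never calling __init__. It is an analogy, not Solstice's implementation; replace_sketch is a hypothetical helper, and copy.copy plus object.__setattr__ stand in for equinox's pytree flatten/unflatten.

```python
import copy
import dataclasses
from typing import Any, TypeVar

T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class Counter:
    x: int

    def __init__(self, z: int):
        # 'smart' constructor: x is derived from z, so x is not an __init__ argument.
        # object.__setattr__ is needed because the dataclass is frozen.
        object.__setattr__(self, "x", 2 * z)


def replace_sketch(obj: T, **changes: Any) -> T:
    """Hypothetical out-of-place update that never calls obj.__init__."""
    for key in changes:
        if not hasattr(obj, key):
            raise AttributeError(f"{type(obj).__name__} has no attribute {key!r}")
    new = copy.copy(obj)  # shallow copy; __init__ is not called
    for key, val in changes.items():
        object.__setattr__(new, key, val)  # bypass the frozen __setattr__
    return new


c1 = Counter(z=0)
c2 = replace_sketch(c1, x=c1.x + 1)
print(c1.x, c2.x)  # 0 1

try:
    dataclasses.replace(c1, x=1)  # calls Counter(x=1), which the custom __init__ rejects
except TypeError as err:
    print("dataclasses.replace failed:", err)
```

Two behaviours of solstice.replace itself follow directly from its source: an unknown keyword raises AttributeError (from the getattr call inside the lambda), and calling it with no keyword arguments raises ValueError, because zip(*changes.items()) then yields nothing to unpack into keys, vals.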