Neural Network Building Blocks

DN3 provides a variety of ready-made networks and building blocks for training, though any PyTorch module will suffice.
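For orientation, a minimal sketch of constructing and calling one of these models follows; the sizes are illustrative, and the (batch, channels, samples) input layout and the behaviour of return_features are assumptions rather than verified facts:

    import torch
    from dn3.trainable.models import LogRegNetwork

    # 2 targets, 22 EEG channels, 256 samples per trial -- illustrative values
    model = LogRegNetwork(targets=2, samples=256, channels=22, return_features=False)
    trials = torch.randn(8, 22, 256)  # assumed (batch, channels, samples) layout
    scores = model(trials)            # with return_features=False, class scores only (assumed)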

Models

Classes

Classifier(targets, samples, channels[, …])

A generic Classifier container.

DN3BaseModel(samples, channels[, …])

A base model used by the networks provided in this library, meant to make them as powerful and multi-purpose as is reasonable.

EEGNet(targets, samples, channels[, do, …])

The DN3 re-implementation of Lawhern et al.'s EEGNet.

EEGNetStrided(targets, samples, channels[, …])

The DN3 re-implementation of Lawhern et al.'s EEGNet, in a strided variant.

LogRegNetwork(targets, samples, channels[, …])

In effect, simply an implementation of linear-kernel (multi-class) logistic regression.

StrideClassifier(targets, samples, channels)

TIDNet(targets, samples, channels[, …])

The Thinker Invariant Densenet from Kostas & Rudzicz 2020, https://doi.org/10.1088/1741-2552/abb7a7

class dn3.trainable.models.Classifier(targets, samples, channels, return_features=True)

A generic Classifier container. This container breaks operation up into feature extraction and feature classification stages, making transfer learning and similar workflows more convenient.

Methods

forward(*x)

Defines the computation performed at every call.

freeze_features([unfreeze, freeze_classifier])

In many cases, the features learned by a model in one domain can be applied to another case.

from_dataset(dataset, **modelargs)

Create a classifier from a dataset.

make_new_classification_layer()

This allows for a distinction between the classification layer(s) and the rest of the network.

forward(*x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

freeze_features(unfreeze=False, freeze_classifier=False)

In many cases, the features learned by a model in one domain can be applied to another case.

This method freezes (or un-freezes) all but the classifier layer, so that further training does not affect those weights (or does, if unfreeze=True).

Parameters
  • unfreeze (bool) – If True, un-freezes weights frozen by a previous call to this method.

  • freeze_classifier (bool) – Commonly, the classifier layer will not be frozen (default). Setting this to True will freeze this layer too.
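A hedged sketch of the intended fine-tuning pattern (pretrained_model is a hypothetical, already-constructed Classifier; only the signature documented above is relied upon):

    # freeze the feature extractor; further training updates only the classifier
    pretrained_model.freeze_features()
    # ... fine-tune on the new domain ...

    # make every layer trainable again
    pretrained_model.freeze_features(unfreeze=True)

    # or freeze the classifier as well, e.g. for pure feature extraction
    pretrained_model.freeze_features(freeze_classifier=True)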

classmethod from_dataset(dataset: dn3.data.dataset.DN3ataset, **modelargs)

Create a classifier from a dataset.

Parameters
  • dataset

  • modelargs (dict) – Options used to construct the model. If dataset does not have listed targets, targets must be specified in these keyword arguments, or it will fall back to 2.

Returns

model – A new Classifier ready to classify data from dataset

Return type

Classifier
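For example (the dataset variable is hypothetical, standing in for any previously constructed DN3ataset, e.g. one loaded through the configuratron):

    from dn3.trainable.models import EEGNet

    # from_dataset reads channel/sample dimensions (and the target count,
    # when the dataset lists one), so they need not be restated by hand
    model = EEGNet.from_dataset(dataset, do=0.1)  # extra kwargs pass through to EEGNet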

make_new_classification_layer()

This allows for a distinction between the classification layer(s) and the rest of the network, using the basic formulation of a network as composed of two parts: feature_extractor and classifier.

This method implements the classification side, so that methods like freeze_features() work as intended.

Any network whose classification stage is more than a layer that simply flattens the incoming features and linearly weights them to the targets should override this method, and the result should be assigned to a variable called self.classifier.
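A sketch of such an override, subclassing one of the provided models; the num_features_for_classification attribute name is an assumption inferred from the default implementation, so verify it against the source:

    import torch.nn as nn
    from dn3.trainable.models import TIDNet

    class TwoLayerHeadTIDNet(TIDNet):
        def make_new_classification_layer(self):
            # anything fancier than flatten -> linear belongs in self.classifier
            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Linear(self.num_features_for_classification, 128),  # attribute name assumed
                nn.ReLU(),
                nn.Linear(128, self.targets),  # attribute name assumed
            )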

class dn3.trainable.models.DN3BaseModel(samples, channels, return_features=True)

This is a base model used by the networks provided in this library, meant to make them as powerful and multi-purpose as is reasonable.

It is not strictly necessary for new modules to inherit from this (any nn.Module should suffice), but it provides some integrated conveniences…

The premise of this model is that deep learning models can be understood as learned pipelines. These DN3BaseModel objects are re-interpreted as a two-stage pipeline, the two stages being feature extraction and classification.

Methods

clone()

This provides a standard way to copy models, weights and all.

forward(x)

Defines the computation performed at every call.

clone()

This provides a standard way to copy models, weights and all.
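For instance, to branch off an independent copy from a common initialization (sizes illustrative):

    from dn3.trainable.models import EEGNet

    model = EEGNet(targets=2, samples=256, channels=22)
    sibling = model.clone()
    # identical weights at this point, but training one leaves the other untouched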

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.models.EEGNet(targets, samples, channels, do=0.25, pooling=8, F1=8, D=2, t_len=65, F2=16, return_features=False)

This is the DN3 re-implementation of Lawhern et al.'s EEGNet from: https://iopscience.iop.org/article/10.1088/1741-2552/aace8c

Notes

The implementation below is in no way officially sanctioned by the original authors; in fact, it is missing the constraints the original authors place on the convolution kernels, and may or may not be missing more…

That being said, in our own experience this implementation has fared no worse than implementations that include this constraint (though those were also not written by the original authors).
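An instantiation sketch using the signature above; the values are illustrative, and the reading of F1/D/F2 as temporal filters, depth multiplier, and pointwise filters follows the original paper rather than anything verified against this implementation:

    from dn3.trainable.models import EEGNet

    model = EEGNet(
        targets=4,          # e.g. four-class motor imagery
        samples=384,        # trial length in samples
        channels=22,
        do=0.5,             # dropout rate
        F1=8, D=2, F2=16,   # filter counts, named as in the paper (assumed meaning)
        t_len=65,           # temporal kernel length (assumed meaning)
    )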

class dn3.trainable.models.EEGNetStrided(targets, samples, channels, do=0.25, pooling=8, F1=8, D=2, t_len=65, F2=16, return_features=False, stride_width=2)

This is the DN3 re-implementation of Lawhern et al.'s EEGNet from: https://iopscience.iop.org/article/10.1088/1741-2552/aace8c

Notes

The implementation below is in no way officially sanctioned by the original authors; in fact, it is missing the constraints the original authors place on the convolution kernels, and may or may not be missing more…

That being said, in our own experience this implementation has fared no worse than implementations that include this constraint (though those were also not written by the original authors).

class dn3.trainable.models.LogRegNetwork(targets, samples, channels, return_features=True)

In effect, simply an implementation of linear-kernel (multi-class) logistic regression.

class dn3.trainable.models.StrideClassifier(targets, samples, channels, stride_width=2, return_features=False)

Methods

make_new_classification_layer()

This allows for a distinction between the classification layer(s) and the rest of the network.

make_new_classification_layer()

This allows for a distinction between the classification layer(s) and the rest of the network, using the basic formulation of a network as composed of two parts: feature_extractor and classifier.

This method implements the classification side, so that methods like freeze_features() work as intended.

Any network whose classification stage is more than a layer that simply flattens the incoming features and linearly weights them to the targets should override this method, and the result should be assigned to a variable called self.classifier.

class dn3.trainable.models.TIDNet(targets, samples, channels, s_growth=24, t_filters=32, do=0.4, pooling=20, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, temp_layers=2, spat_layers=2, temp_span=0.05, bottleneck=3, summary=-1, return_features=False)

The Thinker Invariant Densenet from Kostas & Rudzicz 2020, https://doi.org/10.1088/1741-2552/abb7a7

This alone is not strictly “thinker invariant”, but on average it outperforms shallower models in inter-subject prediction.
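A sketch against the signature above; keyword interpretations are assumptions drawn from the paper:

    from dn3.trainable.models import TIDNet

    model = TIDNet(
        targets=2, samples=768, channels=20,  # illustrative dimensions
        t_filters=32, temp_layers=2,          # temporal stage (assumed meaning)
        s_growth=24, spat_layers=2,           # dense spatial stage (assumed meaning)
        do=0.4, pooling=20,
    )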

Layers

Classes

Concatenate([axis])

ConvBlock2D(in_filters, out_filters, kernel)

Implements a complete convolution block (convolution, spatial dropout, activation, batch-norm, optional residual).

DenseFilter(in_features, growth_rate[, …])

DenseSpatialFilter(channels, growth, depth)

Expand([axis])

Flatten()

IndexSelect(indices)

Permute(axes)

SpatialFilter(channels, filters, depth[, …])

Squeeze([axis])

TemporalFilter(channels, filters, depth, …)
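Several of these are thin wrappers that make tensor bookkeeping composable inside nn.Sequential. A minimal sketch, assuming each layer mirrors the like-named torch.Tensor operation:

    import torch
    import torch.nn as nn
    from dn3.trainable.layers import Expand, Flatten, Permute, Squeeze

    pipe = nn.Sequential(
        Expand(axis=1),         # (batch, channels, samples) -> (batch, 1, channels, samples), assumed
        Permute((0, 1, 3, 2)),  # assumed to follow torch.Tensor.permute over all axes
        Squeeze(axis=1),        # drop the singleton axis again, assumed
        Flatten(),              # collapse everything after the batch axis
    )
    print(pipe(torch.randn(8, 22, 256)).shape)  # expected: torch.Size([8, 5632])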

class dn3.trainable.layers.Concatenate(axis=-1)

Methods

forward(*x)

Defines the computation performed at every call.

forward(*x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.ConvBlock2D(in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1, groups=1, do_rate=0.5, batch_norm=True, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, residual=False)
Implements a complete convolution block, applying in order (a wiring sketch follows the list):
  • convolution

  • dropout (spatial)

  • activation

  • batch-norm

  • (optional) residual connection
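In the sketch below, the 4-D (batch, filters, channels, samples) layout and the tuple-valued kernel/padding pass-through are assumptions:

    import torch
    from dn3.trainable.layers import ConvBlock2D

    block = ConvBlock2D(in_filters=1, out_filters=16, kernel=(1, 15),
                        padding=(0, 7), do_rate=0.25)
    y = block(torch.randn(8, 1, 22, 256))  # assumed (batch, filters, channels, samples)
    # kernel (1, 15) convolves along the sample axis only; padding (0, 7) preserves its length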

Methods

forward(input, **kwargs)

Defines the computation performed at every call.

forward(input, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.DenseFilter(in_features, growth_rate, filter_len=5, do=0.5, bottleneck=2, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, dim=-2)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.DenseSpatialFilter(channels, growth, depth, in_ch=1, bottleneck=4, dropout_rate=0.0, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, collapse=True)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.Expand(axis=-1)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.Flatten

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.IndexSelect(indices)

Methods

forward(*x)

Defines the computation performed at every call.

forward(*x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.Permute(axes)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.SpatialFilter(channels, filters, depth, in_ch=1, dropout_rate=0.0, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, batch_norm=True, residual=False)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.Squeeze(axis=-1)

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dn3.trainable.layers.TemporalFilter(channels, filters, depth, temp_len, dropout=0.0, activation=<class 'torch.nn.modules.activation.LeakyReLU'>, residual='netwise')

Methods

forward(x)

Defines the computation performed at every call.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
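Finally, a hedged sketch of running a TemporalFilter standalone; the 4-D (batch, 1, channels, samples) input layout is an assumption based on how TIDNet-style models stage their inputs, not verified behaviour:

    import torch
    from dn3.trainable.layers import TemporalFilter

    tf = TemporalFilter(channels=22, filters=32, depth=2, temp_len=13)
    x = torch.randn(8, 1, 22, 256)  # assumed (batch, 1, channels, samples)
    y = tf(x)                       # temporal convolutions along the sample axis (assumed)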