Neural Network Building Blocks
DN3 provides a variety of ready-made networks and building blocks for training, although any PyTorch module will work as a model.
Models
Classes
Classifier | A generic Classifier container.
DN3BaseModel | A base model used by the models provided in the library, meant to make them as powerful and multi-purpose as is reasonable.
EEGNet | The DN3 re-implementation of Lawhern et al.’s EEGNet.
EEGNetStrided | The DN3 re-implementation of Lawhern et al.’s EEGNet.
LogRegNetwork | In effect, simply an implementation of linear-kernel (multi)logistic regression.
StrideClassifier |
TIDNet | The Thinker Invariant Densenet from Kostas & Rudzicz 2020, https://doi.org/10.1088/1741-2552/abb7a7
class dn3.trainable.models.Classifier(targets, samples, channels, return_features=True)
A generic Classifier container. This container splits operations into feature extraction and feature classification, which makes transfer learning (among other things) more convenient.
Methods
forward(*x) | Defines the computation performed at every call.
freeze_features([unfreeze, freeze_classifier]) | In many cases, the features learned by a model in one domain can be applied to another.
from_dataset(dataset, **modelargs) | Create a classifier from a dataset.
make_new_classification_layer() | This allows for a distinction between the classification layer(s) and the rest of the network.
forward(*x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
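The note above is inherited from PyTorch's nn.Module documentation and applies to every model and layer on this page. The minimal demonstration below uses a plain nn.Linear (not a DN3 class) with a forward hook: calling the module instance runs the hook, while calling forward() directly skips it.

```python
import torch
from torch import nn

calls = []

layer = nn.Linear(4, 2)
# A forward hook that records each time it fires; returning None leaves the output unchanged.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
layer(x)          # __call__ runs forward() and then the registered hooks
layer.forward(x)  # calling forward() directly bypasses the hooks

print(calls)  # ['hook']
```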
freeze_features(unfreeze=False, freeze_classifier=False)
In many cases, the features learned by a model in one domain can be applied to another. This method freezes (or unfreezes) all but the classifier layer, so that any further training does not (or, if unfreeze=True, does) affect these weights.
Parameters
unfreeze (bool) – Unfreeze the weights after a previous call to this method.
freeze_classifier (bool) – By default, the classifier layer is not frozen. Setting this to True freezes this layer too.
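As a rough illustration of the behaviour described above, the sketch below freezes parameters by toggling requires_grad. TwoStageNet, its layer sizes, and this freeze_features body are hypothetical; they mimic the documented semantics and are not DN3's implementation.

```python
import torch
from torch import nn


class TwoStageNet(nn.Module):
    """Hypothetical model split into a feature extractor and a classifier."""

    def __init__(self, channels: int, samples: int, targets: int):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(channels, 8, kernel_size=5, padding=2),  # keeps the time length
            nn.LeakyReLU(),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(8 * samples, targets))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.feature_extractor(x))

    def freeze_features(self, unfreeze: bool = False, freeze_classifier: bool = False) -> None:
        # Parameters with requires_grad=False receive no gradients, so training leaves them unchanged.
        for param in self.feature_extractor.parameters():
            param.requires_grad = unfreeze
        # The classifier is only touched when explicitly requested.
        if freeze_classifier:
            for param in self.classifier.parameters():
                param.requires_grad = unfreeze


model = TwoStageNet(channels=4, samples=32, targets=2)
model.freeze_features()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.1.weight', 'classifier.1.bias']
```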
classmethod from_dataset(dataset: dn3.data.dataset.DN3ataset, **modelargs)
Create a classifier from a dataset.
Parameters
dataset – The dataset to create the classifier from.
modelargs (dict) – Options to construct the model. If the dataset does not list its targets, targets must be specified in the keyword arguments, or it will fall back to 2.
Returns
model – A new Classifier ready to classify data from dataset.
Return type
Classifier
make_new_classification_layer()
This allows for a distinction between the classification layer(s) and the rest of the network, using a basic formulation of a network as being composed of two parts: feature_extractor and classifier.
This method implements the classification side, so that methods like freeze_features() work as intended. Anything other than a layer that simply flattens its input to a vector and linearly weights it to the targets should override this method, and the model should have a variable called self.classifier.
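The default head described above (flatten, then a linear map to the targets) needs to know how many features the extractor produces. One common way to find that out is a dummy forward pass, sketched below; make_flatten_linear_classifier is a hypothetical helper, not part of DN3, and the extractor is an arbitrary example.

```python
import torch
from torch import nn


def make_flatten_linear_classifier(feature_extractor: nn.Module, channels: int,
                                   samples: int, targets: int) -> nn.Sequential:
    """Hypothetical helper: build a 'flatten then linear' head sized to the features."""
    with torch.no_grad():
        # Probe the feature extractor with a dummy batch holding one trial.
        features = feature_extractor(torch.zeros(1, channels, samples))
    n_features = features.flatten(start_dim=1).shape[1]
    return nn.Sequential(nn.Flatten(), nn.Linear(n_features, targets))


extractor = nn.Sequential(nn.Conv1d(4, 8, kernel_size=5), nn.LeakyReLU())
classifier = make_flatten_linear_classifier(extractor, channels=4, samples=32, targets=3)
print(classifier[1])  # Linear(in_features=224, out_features=3, bias=True)
```

Here the unpadded convolution shortens 32 samples to 28, so the head sees 8 × 28 = 224 features.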
class dn3.trainable.models.DN3BaseModel(samples, channels, return_features=True)
This is a base model used by the models provided in the library, meant to make them as powerful and multi-purpose as is reasonable.
It is not strictly necessary for new modules to inherit from this class (any nn.Module should suffice), but it provides some integrated conveniences.
The premise of this model is that deep learning models can be understood as learned pipelines. DN3BaseModel objects are re-interpreted as a two-stage pipeline, the two stages being feature extraction and classification.
Methods
clone() | This provides a standard way to copy models, weights and all.
forward(x) | Defines the computation performed at every call.
clone()
This provides a standard way to copy models, weights and all.
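How clone() is implemented is not stated here. In plain PyTorch, copy.deepcopy gives the same kind of result (architecture plus weights, in independent storage), as this small sketch with an nn.Linear shows.

```python
import copy

import torch
from torch import nn

original = nn.Linear(4, 2)
duplicate = copy.deepcopy(original)  # copies the architecture and the weights

# Same values, but independent storage: changing one does not affect the other.
assert torch.equal(original.weight, duplicate.weight)
with torch.no_grad():
    duplicate.weight.zero_()
print(torch.equal(original.weight, duplicate.weight))  # False
```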
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class dn3.trainable.models.EEGNet(targets, samples, channels, do=0.25, pooling=8, F1=8, D=2, t_len=65, F2=16, return_features=False)
This is the DN3 re-implementation of Lawhern et al.’s EEGNet from: https://iopscience.iop.org/article/10.1088/1741-2552/aace8c
Notes
This implementation is in no way officially sanctioned by the original authors. In fact, it is missing the constraints the original authors place on the convolution kernels, and may or may not be missing more…
That being said, in our own experience, this implementation has fared no worse than implementations that include this constraint (albeit those were also not written by the original authors).
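For orientation, the sketch below assembles an EEGNet-style network in plain PyTorch following the block structure of Lawhern et al. (temporal convolution, depthwise spatial convolution, separable convolution). EEGNetSketch is hypothetical, not DN3's code. Like the DN3 version it omits the max-norm kernel constraints, and mapping the first pooling to pooling // 2 (so the defaults give the paper's (1, 4) and (1, 8) pools) is an assumption.

```python
import torch
from torch import nn


class EEGNetSketch(nn.Module):
    """Hypothetical EEGNet-style network (no kernel max-norm constraints)."""

    def __init__(self, targets, samples, channels, do=0.25, pooling=8, F1=8, D=2, t_len=65, F2=16):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: temporal convolution, then a depthwise convolution spanning all channels.
            nn.Conv2d(1, F1, (1, t_len), padding=(0, t_len // 2), bias=False),
            nn.BatchNorm2d(F1),
            nn.Conv2d(F1, D * F1, (channels, 1), groups=F1, bias=False),
            nn.BatchNorm2d(D * F1),
            nn.ELU(),
            nn.AvgPool2d((1, pooling // 2)),
            nn.Dropout(do),
            # Block 2: separable convolution (depthwise temporal followed by pointwise).
            nn.Conv2d(D * F1, D * F1, (1, 16), padding=(0, 8), groups=D * F1, bias=False),
            nn.Conv2d(D * F1, F2, (1, 1), bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, pooling)),
            nn.Dropout(do),
        )
        # Size the linear classifier with a dummy pass in eval mode (leaves BN statistics untouched).
        self.features.eval()
        with torch.no_grad():
            n_features = self.features(torch.zeros(1, 1, channels, samples)).numel()
        self.features.train()
        self.classifier = nn.Linear(n_features, targets)

    def forward(self, x):
        # x: (batch, channels, samples); add a singleton "image" channel for Conv2d.
        x = self.features(x.unsqueeze(1))
        return self.classifier(x.flatten(start_dim=1))


model = EEGNetSketch(targets=4, samples=256, channels=22)
out = model(torch.randn(2, 22, 256))
print(out.shape)  # torch.Size([2, 4])
```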
class dn3.trainable.models.EEGNetStrided(targets, samples, channels, do=0.25, pooling=8, F1=8, D=2, t_len=65, F2=16, return_features=False, stride_width=2)
This is the DN3 re-implementation of Lawhern et al.’s EEGNet from: https://iopscience.iop.org/article/10.1088/1741-2552/aace8c
Notes
This implementation is in no way officially sanctioned by the original authors. In fact, it is missing the constraints the original authors place on the convolution kernels, and may or may not be missing more…
That being said, in our own experience, this implementation has fared no worse than implementations that include this constraint (albeit those were also not written by the original authors).
class dn3.trainable.models.LogRegNetwork(targets, samples, channels, return_features=True)
In effect, simply an implementation of linear-kernel (multi)logistic regression.
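Linear-kernel multinomial logistic regression amounts to one linear map from the flattened trial to class logits, trained with a softmax cross-entropy loss. The sketch below shows that idea in plain PyTorch; the dimensions are hypothetical and this is not LogRegNetwork's actual code.

```python
import torch
from torch import nn

# Hypothetical dimensions for illustration.
channels, samples, targets = 8, 64, 3

# Flatten each trial and apply a single linear map; CrossEntropyLoss applies the
# (log-)softmax internally, which makes this multinomial logistic regression.
logreg = nn.Sequential(nn.Flatten(), nn.Linear(channels * samples, targets))
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(5, channels, samples)
y = torch.randint(0, targets, (5,))
logits = logreg(x)
loss = loss_fn(logits, y)
loss.backward()
print(logits.shape)  # torch.Size([5, 3])
```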
class dn3.trainable.models.StrideClassifier(targets, samples, channels, stride_width=2, return_features=False)
Methods
make_new_classification_layer() | This allows for a distinction between the classification layer(s) and the rest of the network.
make_new_classification_layer()
This allows for a distinction between the classification layer(s) and the rest of the network, using a basic formulation of a network as being composed of two parts: feature_extractor and classifier.
This method implements the classification side, so that methods like freeze_features() work as intended. Anything other than a layer that simply flattens its input to a vector and linearly weights it to the targets should override this method, and the model should have a variable called self.classifier.
class dn3.trainable.models.TIDNet(targets, samples, channels, s_growth=24, t_filters=32, do=0.4, pooling=20, activation=torch.nn.LeakyReLU, temp_layers=2, spat_layers=2, temp_span=0.05, bottleneck=3, summary=-1, return_features=False)
The Thinker Invariant Densenet from Kostas & Rudzicz 2020, https://doi.org/10.1088/1741-2552/abb7a7
This alone is not strictly “thinker invariant”, but on average it outperforms shallower models at inter-subject prediction.
Layers
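Classes
Concatenate |
ConvBlock2D | Implements a complete convolution block with order: convolution, spatial dropout, activation, batch-norm, (optional) residual reconnection.
DenseFilter |
DenseSpatialFilter |
Expand |
Flatten |
IndexSelect |
Permute |
SpatialFilter |
Squeeze |
TemporalFilter |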
class dn3.trainable.layers.Concatenate(axis=-1)
Methods
forward(*x) | Defines the computation performed at every call.
forward(*x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
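Judging by its signature, Concatenate wraps torch.cat so that concatenation can sit inside a module graph. The sketch below makes that assumption explicit; ConcatenateSketch is hypothetical, and accepting a single tuple argument is an added convenience that may not match DN3.

```python
import torch
from torch import nn


class ConcatenateSketch(nn.Module):
    """Hypothetical stand-in: concatenate any number of tensors along one axis."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def forward(self, *x):
        # Accept either forward(a, b, ...) or forward((a, b, ...)).
        if len(x) == 1 and isinstance(x[0], (tuple, list)):
            x = x[0]
        return torch.cat(x, dim=self.axis)


a, b = torch.zeros(2, 3, 4), torch.ones(2, 3, 5)
print(ConcatenateSketch()(a, b).shape)  # torch.Size([2, 3, 9])
```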
class dn3.trainable.layers.ConvBlock2D(in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1, groups=1, do_rate=0.5, batch_norm=True, activation=torch.nn.LeakyReLU, residual=False)
Implements a complete convolution block with order:
Convolution
dropout (spatial)
activation
batch-norm
(optional) residual reconnection
Methods
forward(input, **kwargs) | Defines the computation performed at every call.
forward(input, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
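The layer order listed for ConvBlock2D can be sketched directly in PyTorch. ConvBlock2DSketch below is hypothetical and only illustrates that order; in particular, the residual connection here simply adds the input and therefore requires the block to preserve the input shape, which is an assumption about how DN3 handles it.

```python
import torch
from torch import nn


class ConvBlock2DSketch(nn.Module):
    """Hypothetical block: conv -> spatial dropout -> activation -> batch-norm (+ residual)."""

    def __init__(self, in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1,
                 groups=1, do_rate=0.5, batch_norm=True, activation=nn.LeakyReLU, residual=False):
        super().__init__()
        self.conv = nn.Conv2d(in_filters, out_filters, kernel, stride=stride, padding=padding,
                              dilation=dilation, groups=groups)
        self.dropout = nn.Dropout2d(do_rate)  # drops whole feature maps ("spatial" dropout)
        self.activation = activation()
        self.batch_norm = nn.BatchNorm2d(out_filters) if batch_norm else nn.Identity()
        self.residual = residual

    def forward(self, x):
        out = self.batch_norm(self.activation(self.dropout(self.conv(x))))
        if self.residual:
            # Only valid when the block preserves the input shape.
            out = out + x
        return out


block = ConvBlock2DSketch(8, 8, kernel=(1, 5), padding=(0, 2), residual=True)
print(block(torch.randn(2, 8, 4, 32)).shape)  # torch.Size([2, 8, 4, 32])
```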
class dn3.trainable.layers.DenseFilter(in_features, growth_rate, filter_len=5, do=0.5, bottleneck=2, activation=torch.nn.LeakyReLU, dim=-2)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
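Only the signature of DenseFilter is given. Its parameter names (growth_rate, bottleneck, filter_len, dim) suggest a DenseNet-style layer, so the sketch below shows a generic DenseNet bottleneck layer under that assumption: a 1×1 bottleneck convolution, a 1-D filter along one spatial axis, and concatenation of the new maps onto the input. DenseLayerSketch is hypothetical and is not DN3's implementation.

```python
import torch
from torch import nn


class DenseLayerSketch(nn.Module):
    """Hypothetical DenseNet-style layer: 1x1 bottleneck, then a 1-D filter along one axis."""

    def __init__(self, in_features, growth_rate, filter_len=5, do=0.5, bottleneck=2,
                 activation=nn.LeakyReLU, dim=-2):
        super().__init__()
        if dim not in (-2, -1):
            raise ValueError("dim must be -2 (height) or -1 (width)")
        # Filter along height (dim=-2) or width (dim=-1); odd filter_len keeps the size unchanged.
        kernel = (filter_len, 1) if dim == -2 else (1, filter_len)
        padding = (filter_len // 2, 0) if dim == -2 else (0, filter_len // 2)
        self.net = nn.Sequential(
            nn.BatchNorm2d(in_features),
            activation(),
            nn.Conv2d(in_features, bottleneck * growth_rate, 1),
            nn.BatchNorm2d(bottleneck * growth_rate),
            activation(),
            nn.Conv2d(bottleneck * growth_rate, growth_rate, kernel, padding=padding),
            nn.Dropout2d(do),
        )

    def forward(self, x):
        # Dense connectivity: append the new feature maps to the input instead of replacing it.
        return torch.cat((x, self.net(x)), dim=1)


layer = DenseLayerSketch(in_features=16, growth_rate=24)
print(layer(torch.randn(2, 16, 10, 20)).shape)  # torch.Size([2, 40, 10, 20])
```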
class dn3.trainable.layers.DenseSpatialFilter(channels, growth, depth, in_ch=1, bottleneck=4, dropout_rate=0.0, activation=torch.nn.LeakyReLU, collapse=True)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class dn3.trainable.layers.Expand(axis=-1)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
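By its name and signature, Expand most likely inserts a singleton dimension, i.e. wraps Tensor.unsqueeze so it can be used inside nn.Sequential. ExpandSketch below is a hypothetical stand-in built on that assumption.

```python
import torch
from torch import nn


class ExpandSketch(nn.Module):
    """Hypothetical stand-in: insert a singleton dimension at `axis`."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def forward(self, x):
        return x.unsqueeze(self.axis)


# e.g. turn (batch, channels, samples) into (batch, 1, channels, samples) for a Conv2d.
x = torch.randn(2, 22, 256)
print(ExpandSketch(axis=1)(x).shape)  # torch.Size([2, 1, 22, 256])
print(ExpandSketch()(x).shape)        # torch.Size([2, 22, 256, 1])
```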
class dn3.trainable.layers.Flatten
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
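A Flatten layer conventionally keeps the batch dimension and collapses everything else into one vector per example. The hypothetical FlattenSketch below assumes that behaviour and checks it against PyTorch's built-in nn.Flatten.

```python
import torch
from torch import nn


class FlattenSketch(nn.Module):
    """Hypothetical stand-in: flatten everything except the batch dimension."""

    def forward(self, x):
        return x.contiguous().view(x.size(0), -1)


x = torch.randn(4, 16, 1, 8)
print(FlattenSketch()(x).shape)  # torch.Size([4, 128])
# Same result as PyTorch's built-in layer with its default start_dim=1.
assert torch.equal(FlattenSketch()(x), nn.Flatten()(x))
```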
class dn3.trainable.layers.IndexSelect(indices)
Methods
forward(*x) | Defines the computation performed at every call.
forward(*x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class dn3.trainable.layers.Permute(axes)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
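Permute presumably wraps Tensor.permute so a reordering of dimensions can live inside nn.Sequential. The hypothetical PermuteSketch below assumes that behaviour.

```python
import torch
from torch import nn


class PermuteSketch(nn.Module):
    """Hypothetical stand-in: reorder tensor dimensions as a module."""

    def __init__(self, axes):
        super().__init__()
        self.axes = tuple(axes)

    def forward(self, x):
        return x.permute(*self.axes)


x = torch.randn(2, 22, 256)  # (batch, channels, samples)
print(PermuteSketch((0, 2, 1))(x).shape)  # torch.Size([2, 256, 22])
```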
class dn3.trainable.layers.SpatialFilter(channels, filters, depth, in_ch=1, dropout_rate=0.0, activation=torch.nn.LeakyReLU, batch_norm=True, residual=False)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class dn3.trainable.layers.Squeeze(axis=-1)
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
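Squeeze is the natural inverse of Expand: it presumably wraps Tensor.squeeze to drop a size-1 dimension. SqueezeSketch below is a hypothetical stand-in under that assumption.

```python
import torch
from torch import nn


class SqueezeSketch(nn.Module):
    """Hypothetical stand-in: remove a size-1 dimension at `axis`."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def forward(self, x):
        return x.squeeze(self.axis)


# e.g. drop the height-1 axis left after a convolution that spans all EEG channels.
x = torch.randn(2, 16, 1, 64)
print(SqueezeSketch(axis=2)(x).shape)  # torch.Size([2, 16, 64])
```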
class dn3.trainable.layers.TemporalFilter(channels, filters, depth, temp_len, dropout=0.0, activation=torch.nn.LeakyReLU, residual='netwise')
Methods
forward(x) | Defines the computation performed at every call.
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.