Last updated:
0 purchases
atai 0.0.6
atai
Atomic AI is a flexible, minimalist deep neural network training
framework based on Jeremy Howard’s
miniai from the
fast.ai 2022 course.
Install
pip install atai
How to use
from atai.core import *
The following example demonstrates how the Atomic AI training framework
can be used to train a custom model that predicts protein solubility.
Imports
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import init
from torch import optim
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torcheval.metrics.functional import binary_auroc, binary_accuracy
from torchmetrics.classification import BinaryMatthewsCorrCoef
from torchmetrics.functional.classification import binary_matthews_corrcoef
import fastcore.all as fc
from functools import partial
Load Protein Solubility
This example uses the dataset from the
DeepSol paper by
Khurana et al. which was obtained at
https://zenodo.org/records/1162886. It consists of amino acid
sequences of peptides along with solubility labels that are 1 if the
peptide is soluble and 0 if the peptide is insoluble.
# Read every split inside a context manager so each file handle is closed
# as soon as its contents are in memory, instead of being left open until
# the garbage collector happens to reclaim it.
with open('sol_data/train_src', 'r') as f:
    train_sqs = f.read().splitlines()  # one amino acid sequence per line
with open('sol_data/train_tgt', 'r') as f:
    train_tgs = [int(label) for label in f.read().splitlines()]  # one 0/1 label per line
with open('sol_data/val_src', 'r') as f:
    valid_sqs = f.read().splitlines()
with open('sol_data/val_tgt', 'r') as f:
    valid_tgs = [int(label) for label in f.read().splitlines()]
train_sqs[:2], train_tgs[:2]
(['GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK',
'MAHHHHHHMSFFRMKRRLNFVVKRGIEELWENSFLDNNVDMKKIEYSKTGDAWPCVLLRKKSFEDLHKLYYICLKEKNKLLGEQYFHLQNSTKMLQHGRLKKVKLTMKRILTVLSRRAIHDQCLRAKDMLKKQEEREFYEIQKFKLNEQLLCLKHKMNILKKYNSFSLEQISLTFSIKKIENKIQQIDIILNPLRKETMYLLIPHFKYQRKYSDLPGFISWKKQNIIALRNNMSKLHRLY'],
[1, 0])
len(train_sqs), len(train_tgs), len(valid_sqs), len(valid_tgs)
(62478, 62478, 6942, 6942)
Data Preparation
Create a sorted list aas of the amino acid letters that occur in the
training sequences, including an empty string that serves as the padding
token, and determine the size of the vocabulary.
# Collect every distinct character of the training sequences, add the empty
# string (it sorts first, so it gets index 0 and acts as the padding token),
# and sort the result.
aas = sorted(set("".join(train_sqs)) | {""})
vocab_size = len(aas)
aas, vocab_size
(['',
'A',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'K',
'L',
'M',
'N',
'P',
'Q',
'R',
'S',
'T',
'V',
'W',
'Y'],
21)
Create dictionaries that translate between string and integer
representations of amino acids and define the corresponding encode and
decode functions.
# Lookup tables between amino acid letters and their integer ids
# (the position of each letter in the sorted vocabulary ``aas``).
str2int = {aa: i for i, aa in enumerate(aas)}
int2str = dict(enumerate(aas))


def encode(s):
    """Translate a string of amino acid letters into a list of integer ids.

    Raises ``KeyError`` if ``s`` contains a letter that is not in ``aas``.
    """
    return [str2int[aa] for aa in s]


def decode(l):
    """Translate an iterable of integer ids back into an amino acid string.

    Raises ``KeyError`` if an id is not a valid index into ``aas``.
    """
    return ''.join(int2str[i] for i in l)
print(encode("AYWCCCGGGHH"))
print(decode(encode("AYWCCCGGGHH")))
[1, 20, 19, 2, 2, 2, 6, 6, 6, 7, 7]
AYWCCCGGGHH
Figure out what the range of lengths of amino acid sequences in the
dataset is.
train_lens = list(map(len, train_sqs))
min(train_lens), max(train_lens)
(19, 1691)
Create a function that drops all sequences above a chosen threshold and
also returns a list of indices of the sequences that meet the threshold
that can be used to obtain the correct labels.
def drop_long_sqs(sqs, threshold=1200):
    """Keep only the sequences whose length does not exceed ``threshold``.

    Args:
        sqs: iterable of sized items (e.g. amino acid strings).
        threshold: maximum length (inclusive) a sequence may have to be kept.

    Returns:
        A tuple ``(new_sqs, idx)`` where ``new_sqs`` is the list of kept
        sequences in their original order and ``idx`` is the list of their
        positions in ``sqs``; ``idx`` can be used to select the matching
        labels.
    """
    # Filter once, remembering each survivor together with its position.
    kept = [(i, sq) for i, sq in enumerate(sqs) if len(sq) <= threshold]
    new_sqs = [sq for _, sq in kept]
    idx = [i for i, _ in kept]
    return new_sqs, idx
Drop all sequences above your chosen threshold.
trnsqs, trnidx = drop_long_sqs(train_sqs, threshold=200)
vldsqs, vldidx = drop_long_sqs(valid_sqs, threshold=200)
len(trnidx), len(vldidx)
(18066, 1971)
max(map(len, trnsqs))
200
Create a function for zero padding all sequences.
def zero_pad(sq, length=1200):
    """Return a copy of ``sq`` right-padded with zeros up to ``length`` items.

    Args:
        sq: sequence of integer ids (a list, tuple, or any other iterable).
        length: target length after padding.

    Returns:
        A new list. Inputs that already have ``length`` or more items are
        copied unchanged (they are never truncated). The input itself is
        not modified.
    """
    # list(...) copies lists just like .copy() did, and additionally accepts
    # tuples and other iterables.
    new_sq = list(sq)
    missing = length - len(new_sq)
    if missing > 0:
        new_sq.extend([0] * missing)
    return new_sq
Now encode and zero pad all sequences and make sure that it worked out
correctly.
# Integer-encode every kept sequence.
trn = [encode(sq) for sq in trnsqs]
vld = [encode(sq) for sq in vldsqs]
print(f"Length of the first two sequences before zero padding: {len(trn[0])}, {len(trn[1])}")
# Right-pad every encoded sequence with zeros to the common length of 200.
trn = [zero_pad(sq, length=200) for sq in trn]
vld = [zero_pad(sq, length=200) for sq in vld]
print(f"Length of the first two sequences after zero padding: {len(trn[0])}, {len(trn[1])}")
Length of the first two sequences before zero padding: 116, 135
Length of the first two sequences after zero padding: 200, 200
Convert the data to torch.tensors using dtype=torch.int64 and check
for correctness.
trntns = torch.tensor(trn, dtype=torch.int64)
vldtns = torch.tensor(vld, dtype=torch.int64)
trntns.shape, trntns[0]
(torch.Size([18066, 200]),
tensor([11, 9, 1, 10, 2, 10, 10, 10, 10, 13, 18, 10, 6, 10, 10, 18, 16, 16, 9, 17, 10, 2, 16, 11, 4, 4, 1, 8, 12, 4, 15, 8, 14,
4, 18, 1, 6, 16, 10, 8, 5, 15, 1, 8, 16, 16, 8, 6, 10, 4, 2, 14, 16, 18, 17, 16, 15, 6, 3, 10, 1, 17, 2, 13, 15, 6,
5, 1, 18, 17, 6, 2, 17, 2, 6, 16, 1, 2, 6, 16, 19, 3, 18, 15, 1, 4, 17, 17, 2, 7, 2, 14, 2, 1, 6, 11, 3, 19, 17,
6, 1, 15, 2, 2, 15, 18, 14, 13, 10, 4, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]))
trntns.shape, vldtns.shape
(torch.Size([18066, 200]), torch.Size([1971, 200]))
Obtain the correct labels using the lists of indices obtained from the
drop_long_sqs function and convert the lists of labels to tensors in
torch.float32 format.
trnlbs = torch.tensor(train_tgs, dtype=torch.float32)[trnidx]
vldlbs = torch.tensor(valid_tgs, dtype=torch.float32)[vldidx]
trnlbs.shape, vldlbs.shape
(torch.Size([18066]), torch.Size([1971]))
Calculate the ratios of soluble peptides in the train and valid data.
trnlbs.sum().item()/trnlbs.shape[0], vldlbs.sum().item()/vldlbs.shape[0]
(0.4722129967895494, 0.4657534246575342)
These ratios tell us that slightly fewer than half of the proteins in the
training and validation data are soluble, and slightly more than half
in the test set.
Dataset and DataLoaders
Turn train and valid data into datasets using the
Dataset class.
trnds = Dataset(trntns, trnlbs)
vldds = Dataset(vldtns, vldlbs)
trnds[0]
(tensor([11, 9, 1, 10, 2, 10, 10, 10, 10, 13, 18, 10, 6, 10, 10, 18, 16, 16, 9, 17, 10, 2, 16, 11, 4, 4, 1, 8, 12, 4, 15, 8, 14,
4, 18, 1, 6, 16, 10, 8, 5, 15, 1, 8, 16, 16, 8, 6, 10, 4, 2, 14, 16, 18, 17, 16, 15, 6, 3, 10, 1, 17, 2, 13, 15, 6,
5, 1, 18, 17, 6, 2, 17, 2, 6, 16, 1, 2, 6, 16, 19, 3, 18, 15, 1, 4, 17, 17, 2, 7, 2, 14, 2, 1, 6, 11, 3, 19, 17,
6, 1, 15, 2, 2, 15, 18, 14, 13, 10, 4, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]),
tensor(0.))
Use the get_dls
function to obtain the dataloaders from the train and valid datasets.
dls = get_dls(trnds, vldds, bs=32)
# Draw ONE batch and slice both tensors from it. Calling
# next(iter(dls.train)) twice creates two independent iterators, so the
# inputs and the labels shown would come from different batches and would
# not belong together.
batch = next(iter(dls.train))
batch[0][:2], batch[1][:2]
(tensor([[11, 1, 7, 7, 7, 7, 7, 7, 11, 11, 13, 15, 16, 4, 12, 14, 9, 4, 4, 4, 4, 19, 4, 10, 7, 13, 10, 5, 10, 16, 9, 8, 13,
12, 9, 9, 3, 8, 3, 9, 12, 13, 1, 10, 16, 1, 10, 8, 17, 10, 8, 12, 4, 4, 4, 4, 9, 4, 18, 5, 16, 20, 4, 13, 15, 15,
9, 12, 10, 9, 9, 9, 8, 17, 4, 9, 6, 14, 8, 8, 20, 9, 9, 3, 15, 15, 12, 10, 20, 4, 13, 20, 14, 14, 12, 9, 12, 3, 9,
8, 12, 20, 5, 4, 9, 9, 12, 5, 6, 6, 12, 1, 3, 8, 16, 9, 4, 4, 4, 18, 10, 3, 18, 4, 11, 3, 4, 4, 6, 17, 17, 18,
17, 17, 1, 4, 14, 6, 6, 3, 7, 16, 16, 14, 12, 18, 16, 12, 12, 14, 4, 1, 17, 3, 14, 17, 16, 8, 6, 4, 18, 10, 18, 2, 10,
16, 11, 11, 12, 10, 9, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[11, 10, 1, 11, 18, 12, 9, 18, 10, 3, 19, 8, 15, 16, 10, 5, 19, 9, 4, 4, 11, 4, 10, 17, 10, 18, 6, 10, 14, 12, 16, 6, 9,
17, 17, 5, 18, 12, 18, 8, 1, 16, 6, 14, 5, 17, 4, 3, 11, 8, 13, 17, 18, 6, 5, 12, 11, 15, 9, 8, 17, 9, 6, 12, 18, 17,
8, 9, 10, 19, 3, 8, 6, 6, 14, 13, 15, 5, 15, 16, 11, 19, 4, 15, 20, 2, 15, 6, 18, 12, 1, 8, 18, 5, 11, 18, 3, 1, 1,
3, 4, 4, 9, 10, 4, 1, 16, 15, 12, 4, 10, 11, 14, 10, 10, 3, 9, 13, 14, 10, 3, 1, 8, 13, 18, 10, 18, 10, 6, 12, 9, 9,
3, 10, 13, 6, 1, 10, 3, 4, 15, 14, 10, 8, 4, 15, 11, 12, 10, 16, 16, 8, 14, 12, 15, 4, 8, 2, 2, 20, 16, 8, 16, 2, 9,
4, 9, 4, 12, 8, 3, 8, 17, 10, 14, 19, 10, 8, 3, 7, 16, 9, 1, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]]),
tensor([1., 0.]))
Design Your Model
Let’s create a tiny model (~10k parameters) that uses a sequence of
1-dimensional convolutional layers with skip connections, Kaiming He
initialization, leaky ReLUs, batch normalization, and dropout.
First, obtain a single batch from dls to help design the model.
idx = next(iter(dls.train))[0] ## a single batch
idx, idx.shape
(tensor([[11, 9, 17, ..., 0, 0, 0],
[16, 12, 1, ..., 0, 0, 0],
[16, 1, 14, ..., 0, 0, 0],
...,
[11, 1, 7, ..., 0, 0, 0],
[11, 3, 13, ..., 0, 0, 0],
[16, 17, 12, ..., 0, 0, 0]]),
torch.Size([32, 200]))
Custom Modules
def _is_batchnorm(norm):
    """Return True if ``norm`` is a BatchNorm class, partial, or instance."""
    batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    # Unwrap functools.partial(nn.BatchNorm1d, ...) to the underlying class.
    target = getattr(norm, "func", norm)
    if isinstance(target, type):
        return issubclass(target, batchnorms)
    return isinstance(target, batchnorms)


def conv1d(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
    """Build a ``Conv1d -> [norm] -> [activation]`` stack.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size; the convolution is padded with ``ks // 2``.
        stride: stride of the convolution.
        act: zero-argument factory for the activation module, or ``None``
            to omit the activation.
        norm: factory called as ``norm(nf)`` to create the normalization
            module, or ``None`` to omit it.
        bias: whether the convolution has a bias term. ``None`` (default)
            enables the bias unless ``norm`` is a BatchNorm layer, whose own
            shift parameter makes a preceding conv bias redundant.

    Returns:
        An ``nn.Sequential`` holding the created layers in order.
    """
    if bias is None:
        # ``norm`` is passed as a class/factory, not as an instance, so a
        # plain isinstance() check against the BatchNorm classes never
        # matches; _is_batchnorm handles classes and partials as well.
        bias = not _is_batchnorm(norm)
    layers = [nn.Conv1d(ni, nf, stride=stride, kernel_size=ks, padding=ks // 2, bias=bias)]
    if norm:
        layers.append(norm(nf))
    if act:
        layers.append(act())
    return nn.Sequential(*layers)
def _conv1d_block(ni, nf, stride, act=nn.ReLU, norm=None, ks=3):
    """Stack two ``conv1d`` layers that form the main path of a residual block.

    The first layer maps ``ni -> nf`` channels with stride 1 and keeps the
    activation; the second maps ``nf -> nf`` with the requested ``stride``
    and has no activation. Both use the same ``norm`` and kernel size ``ks``.
    """
    first = conv1d(ni, nf, stride=1, act=act, norm=norm, ks=ks)
    second = conv1d(nf, nf, stride=stride, act=None, norm=norm, ks=ks)
    return nn.Sequential(first, second)
class ResBlock1d(nn.Module):
    """1-d residual block: ``act(convs(x) + pool(idconv(x)))``.

    Attributes:
        convs: main path built by ``_conv1d_block`` (two convolutions).
        idconv: shortcut channel adapter; a kernel-size-1 ``conv1d`` without
            activation when ``ni != nf``, otherwise ``fc.noop``.
        pool: shortcut downsampler; ``nn.AvgPool1d(stride, ceil_mode=True)``
            when ``stride != 1``, otherwise ``fc.noop``.
        act: activation applied to the sum of both paths.
    """

    def __init__(self, ni, nf, stride=1, ks=3, act=nn.ReLU, norm=None):
        super().__init__()
        self.convs = _conv1d_block(ni, nf, stride=stride, ks=ks, act=act, norm=norm)
        # The shortcut only needs a projection when the channel count changes.
        if ni == nf:
            self.idconv = fc.noop
        else:
            self.idconv = conv1d(ni, nf, stride=1, ks=1, act=None)
        # The shortcut only needs pooling when the main path is strided.
        if stride == 1:
            self.pool = fc.noop
        else:
            self.pool = nn.AvgPool1d(stride, ceil_mode=True)
        self.act = act()

    def forward(self, x):
        main = self.convs(x)
        shortcut = self.pool(self.idconv(x))
        return self.act(main + shortcut)
The following module switches the rank order from BLC to BCL.
class Reshape(nn.Module):
    """Swap the last two axes of a rank-3 tensor: ``(B, L, C) -> (B, C, L)``.

    ``Tensor.view(B, C, L)`` must not be used for this: it only reinterprets
    the existing memory layout under a new shape and would mix values from
    different positions and channels. ``transpose`` actually moves the
    axes, so ``out[b, c, l] == x[b, l, c]``.
    """

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError(f"expected a rank-3 (B, L, C) tensor, got shape {tuple(x.shape)}")
        return x.transpose(1, 2)
Model Architecture
# Training hyperparameters and embedding width.
lr = 1e-2
epochs = 30
n_embd = 16
dls = get_dls(trnds, vldds, bs=32)
act_genrelu = partial(GeneralRelu, leak=0.1, sub=0.4)

# (input channels, output channels, kernel size) of each residual stage.
# Every stage uses stride 2, so the eight stages shrink the padded
# length-200 input down to length 1, leaving the 32 features that the
# linear head expects.
stage_specs = [
    (n_embd, 2, 15),
    (2, 4, 13),
    (4, 4, 11),
    (4, 4, 9),
    (4, 8, 7),
    (8, 8, 5),
    (8, 16, 3),
    (16, 32, 3),
]

# Modules are created in the same order in which they appear in the model.
layers = [nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape()]
for c_in, c_out, k in stage_specs:
    layers.append(ResBlock1d(c_in, c_out, ks=k, stride=2, norm=nn.BatchNorm1d, act=act_genrelu))
    layers.append(nn.Dropout(0.1))
layers.extend([
    nn.Flatten(1, -1),
    nn.Linear(32, 1),
    nn.Flatten(0, -1),  # (B, 1) -> (B,)
    nn.Sigmoid(),
])
model = nn.Sequential(*layers)
model(idx).shape
torch.Size([32])
# Re-initialize the model weights with init_weights.
# NOTE(review): leaky=0.1 presumably mirrors the leak=0.1 of GeneralRelu — confirm.
iw = partial(init_weights, leaky=0.1)
model = model.apply(iw)
# Metrics reported per epoch: accuracy, Matthews correlation coefficient, AUROC.
metrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())
# Collect activation statistics on every GeneralRelu module.
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), ProgressCB(plot=False), metrics, astats]
# The model already ends in nn.Sigmoid, hence F.binary_cross_entropy
# (which takes probabilities) rather than a logits-based loss.
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
# Learning rate finder; see the prose below for the choice of av_over and gamma.
learn.lr_find(start_lr=1e-4, gamma=1.05, av_over=3, max_mult=5)
Parameters total: 10175
<style>
/* Turns off some styling */
progress {
/* gets rid of default border in Firefox and Opera. */
border: none;
/* Needs to be in here for Safari polyfill so background images work as expected. */
background-size: auto;
}
progress:not([value]), progress:not([value])::-webkit-progress-bar {
background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
}
.progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
background: #F44336;
}
</style>
<div>
<progress value='0' class='' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>
0.00% [0/10 00:00<?]
</div>
88.85% [502/565 00:32<00:04 0.919]
This is a pretty noisy training set, so the learning rate finder does
not work very well. Yet it is possible to get a somewhat informative
result using the av_over keyword argument that tells
lr_find to average
over the specified number of batches for each learning rate tested. It
also helps to dial the gamma value down from its default value of
1.3.
Training
learn.fit(epochs)
<style>
/* Turns off some styling */
progress {
/* gets rid of default border in Firefox and Opera. */
border: none;
/* Needs to be in here for Safari polyfill so background images work as expected. */
background-size: auto;
}
progress:not([value]), progress:not([value])::-webkit-progress-bar {
background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
}
.progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
background: #F44336;
}
</style>
BinaryAccuracy
BinaryMatthewsCorrCoef
BinaryAUROC
loss
epoch
train
0.510
0.008
0.507
0.722
0
train
0.527
-0.033
0.494
0.691
0
eval
0.515
0.004
0.506
0.695
1
train
0.534
0.003
0.520
0.691
1
eval
0.537
0.054
0.526
0.692
2
train
0.530
0.058
0.550
0.693
2
eval
0.553
0.090
0.551
0.688
3
train
0.562
0.108
0.575
0.684
3
eval
0.589
0.170
0.607
0.671
4
train
0.619
0.234
0.663
0.647
4
eval
0.622
0.247
0.630
0.653
5
train
0.643
0.302
0.676
0.637
5
eval
0.627
0.260
0.642
0.650
6
train
0.649
0.313
0.675
0.628
6
eval
0.633
0.272
0.647
0.644
7
train
0.650
0.309
0.648
0.637
7
eval
0.634
0.276
0.652
0.640
8
train
0.650
0.314
0.683
0.620
8
eval
0.637
0.284
0.646
0.641
9
train
0.651
0.320
0.685
0.617
9
eval
0.639
0.287
0.661
0.636
10
train
0.645
0.308
0.671
0.647
10
eval
0.638
0.281
0.663
0.636
11
train
0.646
0.298
0.690
0.620
11
eval
0.641
0.286
0.670
0.632
12
train
0.648
0.296
0.685
0.619
12
eval
0.641
0.290
0.668
0.632
13
train
0.662
0.323
0.691
0.618
13
eval
0.643
0.290
0.675
0.630
14
train
0.641
0.286
0.677
0.627
14
eval
0.643
0.289
0.676
0.630
15
train
0.659
0.311
0.699
0.616
15
eval
0.644
0.291
0.677
0.629
16
train
0.652
0.314
0.690
0.613
16
eval
0.646
0.296
0.675
0.626
17
train
0.645
0.282
0.694
0.622
17
eval
0.642
0.288
0.678
0.626
18
train
0.648
0.332
0.677
0.639
18
eval
0.645
0.292
0.685
0.625
19
train
0.634
0.260
0.698
0.615
19
eval
0.649
0.302
0.689
0.621
20
train
0.651
0.344
0.710
0.617
20
eval
0.648
0.299
0.685
0.624
21
train
0.660
0.315
0.700
0.614
21
eval
0.648
0.297
0.691
0.620
22
train
0.563
0.168
0.679
0.672
22
eval
0.651
0.303
0.690
0.620
23
train
0.654
0.330
0.710
0.611
23
eval
0.650
0.304
0.691
0.620
24
train
0.668
0.344
0.711
0.599
24
eval
0.654
0.311
0.692
0.617
25
train
0.649
0.294
0.698
0.620
25
eval
0.650
0.301
0.690
0.617
26
train
0.642
0.320
0.697
0.611
26
eval
0.650
0.303
0.688
0.620
27
train
0.663
0.334
0.708
0.625
27
eval
0.652
0.308
0.694
0.616
28
train
0.672
0.356
0.711
0.597
28
eval
0.651
0.304
0.695
0.617
29
train
0.659
0.322
0.702
0.603
29
eval
Inspect Activations
# Larger batches for the activation-inspection run.
dls = get_dls(trnds, vldds, bs=256)

# Rebuild a fresh, untrained copy of the architecture:
# (input channels, output channels, kernel size) per stride-2 residual stage.
inspect_specs = [
    (n_embd, 2, 15),
    (2, 4, 13),
    (4, 4, 11),
    (4, 4, 9),
    (4, 8, 7),
    (8, 8, 5),
    (8, 16, 3),
    (16, 32, 3),
]
inspect_layers = [nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape()]
for c_in, c_out, k in inspect_specs:
    inspect_layers.append(ResBlock1d(c_in, c_out, ks=k, stride=2, norm=nn.BatchNorm1d, act=act_genrelu))
    inspect_layers.append(nn.Dropout(0.1))
inspect_layers.extend([
    nn.Flatten(1, -1),
    nn.Linear(32, 1),
    nn.Flatten(0, -1),
    nn.Sigmoid(),
])
model = nn.Sequential(*inspect_layers)
model = model.apply(iw)

# Only the device and activation-statistics callbacks are used here.
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), astats]
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
learn.fit(1)
Parameters total: 10175
astats.color_dim()
astats.plot_stats()
astats.dead_chart()
For personal and professional use. You cannot resell or redistribute these repositories in their original state.
There are no reviews.