Examples

01_create_new_dataset.py

Creates a new dataset of MMS plasma distribution data using Olshevsky labels.

from spacephyml.datasets.creator import create_dataset

# Build the dataset listing for a ten-hour MMS1 interval on 2017-12-04,
# requesting the fast ion distribution variable and Olshevsky labels.
create_dataset(
    './mms_region.csv',
    trange=[
        '2017-12-04/05:00:00',  # interval start
        '2017-12-04/15:00:00',  # interval end
    ],
    clean=False,
    label_source='Olshevsky',
    var_list=['mms1_dis_dist_fast'],
)

02_load_existing_dataset.py

Load an existing dataset and model, then classify the data using the model.

import torch
from torch.utils.data import DataLoader

from spacephyml.datasets.mms import MMS1IonDistLabeled
from spacephyml.models.mms import PCReduced

# Pick the compute backend: CUDA if available, otherwise Apple MPS,
# otherwise fall back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

dataset = MMS1IonDistLabeled('SCNov2017')
model = PCReduced('s42').to(device)

# Per-sample record of the labels shipped with the dataset ('human') and
# the labels predicted by the model ('classifier').
labels = {'human': [], 'classifier': []}
correct = 0
# Pure inference: no gradients are needed, so do not track them.
with torch.no_grad():
    for x, l in DataLoader(dataset, batch_size=32):
        # Evaluate on `device`, then move the scores to the CPU so they can
        # be compared with the labels yielded by the loader.
        lc = model(x.to(device)).to('cpu')
        # Predicted class is the index of the largest score along dim 1.
        lc = torch.argmax(lc, dim=1)
        # .item() keeps the running count a plain Python int, so the
        # accuracy below prints as a number instead of `tensor(...)`.
        correct += torch.sum(lc == l).item()
        labels['human'].extend(l)
        labels['classifier'].extend(lc)

# Double quotes outside, single quotes inside: reusing the same quote type
# inside an f-string is a SyntaxError on Python versions before 3.12.
total = len(labels['human'])
accuracy = correct / total if total else float('nan')
print(f"Accuracy: {accuracy}")

03_classify_mms_region.py

Classifying MMS dayside space plasma regions

import torch
from torch.utils.data import DataLoader
from spacephyml.datasets.general.mms import ExternalMMSData
from spacephyml.models.mms import PCReduced
from spacephyml.datasets.creator import create_dataset
from spacephyml.transforms import MMS1IonDistLabeled_Transform

# Pick the compute backend: CUDA if available, otherwise Apple MPS,
# otherwise fall back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Create the dataset listing for the interval of interest; the same path is
# handed to ExternalMMSData below.
create_dataset('./mms_region.csv',
               trange=['2017-12-04/05:00:00', '2017-12-04/15:00:00'],
               clean=False, label_source='Olshevsky',
               var_list=['mms1_dis_dist_fast'])

dataset = ExternalMMSData('./mms_region.csv',
                          transform=MMS1IonDistLabeled_Transform())

model = PCReduced('s42').to(device)

# Per-sample record of the dataset label ('human'), the label predicted by
# the model ('classifier') and the epoch value yielded with each sample.
labels = {'human': [], 'classifier': [], 'epoch': []}
# Pure inference: without no_grad every stored output would keep its
# autograd graph alive for the lifetime of `labels`.
with torch.no_grad():
    for x, l, epoch in DataLoader(dataset, batch_size=32):
        lc = model(x.to(device)).to('cpu')
        # Reduce the per-class scores to a class index, as in the
        # 02_load_existing_dataset.py example, so 'classifier' holds labels
        # that are directly comparable with 'human'.
        lc = torch.argmax(lc, dim=1)
        labels['human'].extend(l)
        labels['classifier'].extend(lc)
        labels['epoch'].extend(epoch)

04_train_new_models.py

Train a new PCBaseline model using the SCNov2017 dataset.


import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from spacephyml.datasets.mms import MMS1IonDistLabeled
from spacephyml.models.arcs.mms import PCReduced_arc, PCBaseline_arc

# When True, train_loop prints the running loss every 100 batches.
_VERBOSE = True
# Number of passes over the training data made by train_model.
_EPOCHS = 5
# Learning rate handed to the Adam optimizer in train_model.
_LEARNING_RATE = 1e-5
# Mini-batch size used for training.
_BATCH_SIZE = 32


def train_loop(dataloader, model, loss_fn, optimizer, batch_size, device):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: Iterable yielding ``(inputs, targets)`` batches; its
            ``dataset`` attribute must support ``len()``.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable mapping ``(predictions, targets)`` to a loss.
        optimizer: Optimizer that owns the parameters of ``model``.
        batch_size: Nominal batch size, used only for the progress report.
        device: Device the inputs and targets are moved to.
    """
    n_samples = len(dataloader.dataset)
    model.train()

    for step, (inputs, targets) in enumerate(dataloader):
        # Forward pass and loss on the target device.
        outputs = model(inputs.to(device))
        batch_loss = loss_fn(outputs, targets.to(device))

        # Backward pass, parameter update, then clear the gradients so they
        # do not accumulate into the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Progress is reported on every 100th batch, and only when verbose.
        if not _VERBOSE or step % 100 != 0:
            continue
        loss_value = batch_loss.item()
        seen = step * batch_size + len(inputs)
        print(f"loss: {loss_value:>7f}  [{seen:>5d}/{n_samples:>5d}]")


def train_model(model, dataset, batch_size = _BATCH_SIZE, device="cpu"):
    """Train ``model`` on ``dataset`` for ``_EPOCHS`` epochs and return it.

    Uses cross-entropy loss and the Adam optimizer with learning rate
    ``_LEARNING_RATE``. The data is reshuffled every epoch.

    Args:
        model: Module to train; it is moved to ``device``.
        dataset: Dataset yielding ``(input, label)`` pairs.
        batch_size: Number of samples per mini-batch.
        device: Device to train on.

    Returns:
        The trained model, located on ``device``.
    """
    # Build the loader from the caller-supplied batch size. It used to be
    # hard-coded to _BATCH_SIZE, which silently ignored the argument and
    # made train_loop's progress counter disagree with the real batches.
    dataloader_train = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=_LEARNING_RATE)

    for t in range(_EPOCHS):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(dataloader_train, model, loss_fn, optimizer, batch_size, device)

    print("Done!")

    return model


def main(seed):
    """Train a ``PCBaseline_arc`` model on ``SCNov2017`` and save its classifier.

    The state dict of the trained model's ``classifier`` attribute is written
    to ``./model_PCBaseline_s{seed}.ptk``.

    Args:
        seed: Value passed to ``torch.manual_seed`` before the dataset and
            model are created; also embedded in the output file name.
    """
    # Backend preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    torch.manual_seed(seed)
    print(f"Using {device} device")

    # The dataset is created before the model, matching the seeded order.
    training_data = MMS1IonDistLabeled('SCNov2017')
    network = train_model(PCBaseline_arc(), training_data, device=device)

    classifier_weights = network.classifier.state_dict()
    torch.save(classifier_weights, f"./model_PCBaseline_s{seed}.ptk")

# Script entry point: train with a fixed seed, which is passed to
# torch.manual_seed and embedded in the saved file name by main().
if __name__ == "__main__":
    seed = 42
    main(seed)