Commit a866b080 authored by Zhangkai Wu's avatar Zhangkai Wu
Browse files

上传新文件

parent 49a0b361
Loading
Loading
Loading
Loading
+91 −0
Original line number Diff line number Diff line
%% Cell type:code id: tags:

``` python
# You are encouraged to follow this template as a guide, but feel free to make reasonable modifications as needed.
# The key requirement is that your code runs successfully and produces the expected results.
```

%% Cell type:code id: tags:

``` python
# NOTE:
# You may choose to use ChatGPT (or any AI-based tool) to assist with your assignment,
# but you must ensure that you fully understand the entire code.
# You are solely responsible for the work you submit.
# Please keep in mind: ChatGPT will not be available during the exam.
```

%% Cell type:code id: tags:

``` python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision as torchvision
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split


class MNISTClassification(pl.LightningModule):
    def __init__(self):
        super(MNISTClassification, self).__init__()
        # TODO: Implement this function

    def forward(self, x):
        # TODO: Implement this function
        return 0

    def configure_optimizers(self):
        # TODO: Implement this function
        return 0

    def training_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def validation_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def test_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def test_epoch_end(self, outputs):
        # TODO: Implement this function
        return 0


if __name__ == "__main__":
    # Main function of script
    train_set = torchvision.datasets.MNIST(
        "./files/",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                torchvision.transforms.Lambda(lambda x: torch.reshape(x, (784,))),
            ]
        ),
    )

    test_set = torchvision.datasets.MNIST(
        "./files/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                torchvision.transforms.Lambda(lambda x: torch.reshape(x, (784,))),
            ]
        ),
    )

    train_set, val_set = random_split(train_set, [54000, 6000])

    # TODO: Implement data loader, learning and prediction
```