Commit 9eeec3ba authored by Zhangkai Wu
Browse files

上传新文件

parent d41610c2
Loading
Loading
Loading
Loading
+155 −0
Original line number Diff line number Diff line
%% Cell type:code id:dcd04598-762c-4b18-b8d8-db8b67ada05c tags:

``` python
# NOTE:
# You may choose to use ChatGPT (or any AI-based tool) to assist with your assignment,
# but you must ensure that you fully understand the entire code.
# You are solely responsible for the work you submit.
# Please keep in mind: ChatGPT will not be available during the exam.
```

%% Cell type:code id:37b9826f tags:

``` python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision as torchvision
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
```

%% Cell type:code id:c28b4277 tags:

``` python
class MNISTCNN(pl.LightningModule):
    """Four-stage convolutional classifier producing 10 class logits.

    The shape comments below assume a 1x28x28 input (e.g. MNIST); the first
    fully connected layer hard-codes ``128 * 4 * 4`` input features, so other
    spatial sizes will not fit without changing ``self.fc``.

    The module is trained with cross-entropy loss and Adam, and logs
    ``train_loss``, ``val_loss``, ``test_loss`` and ``test_acc`` per step, plus
    ``avg_test_loss`` / ``avg_test_acc`` at the end of the test epoch.
    """

    def __init__(self):
        super().__init__()

        # All convolutions are 3x3 without padding, so each one shrinks the
        # spatial size by 2; each 2x2 max-pool halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 3)),  # -> 16 x 26 x 26
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(3, 3)),  # -> 32 x 24 x 24
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # -> 32 x 12 x 12, (24 - 2) / 2 + 1

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3)),  # -> 64 x 10 x 10
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3)),  # -> 128 x 8 x 8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # -> 128 x 4 x 4

        self.flatten = nn.Flatten()

        # NOTE: the nn.Flatten at index 0 duplicates ``self.flatten`` (it is a
        # no-op on already-flat input). It is kept on purpose: removing it
        # would renumber the Linear layers and change the ``fc.*`` state_dict
        # keys of any existing checkpoint.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10))

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x

    def configure_optimizers(self):
        """Return the Adam optimizer (lr=0.01) over all model parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one batch."""
        x, y = batch
        # Call the module rather than ``forward`` directly so that any
        # registered nn.Module hooks run as well.
        predictions = self(x)
        loss = self.cross_entropy_loss(predictions, y)

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the cross-entropy loss of one validation batch as ``val_loss``."""
        x, y = batch
        y_hat = self(x)
        loss = self.cross_entropy_loss(y_hat, y)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log and return loss and accuracy for one test batch.

        Returns:
            dict with ``test_loss`` and ``test_acc`` (0-dim tensors) and
            ``batch_size`` (int), the latter so that ``test_epoch_end`` can
            weight each batch by the number of samples it contains.
        """
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        # Accuracy is computed in torch so the tensors stay on the model's
        # device and no host round-trip through scikit-learn is needed.
        y_hat = logits.argmax(dim=1)
        test_acc = (y_hat == y).float().mean()
        metrics = {"test_loss": loss, "test_acc": test_acc}
        self.log_dict(metrics)
        return {**metrics, "batch_size": y.size(0)}

    def test_epoch_end(self, outputs):
        """Aggregate the per-batch test metrics into epoch-level averages.

        Args:
            outputs: list of the dicts returned by ``test_step``.

        Returns:
            dict with ``avg_test_loss`` and ``avg_test_acc`` plus the legacy
            ``log`` / ``progress_bar`` sub-dicts (kept in the return value for
            backward compatibility only).
        """
        losses = torch.stack([out["test_loss"] for out in outputs])
        accs = torch.stack([out["test_acc"] for out in outputs]).to(losses)
        # Weight every batch by its sample count: a plain mean over batches
        # would over-weight a smaller final batch. Outputs without a
        # ``batch_size`` entry fall back to equal weighting.
        sizes = torch.tensor([out.get("batch_size", 1) for out in outputs]).to(losses)
        weights = sizes / sizes.sum()
        avg_loss = (losses * weights).sum()
        avg_test_acc = (accs * weights).sum()

        # Only the scalar metrics are logged; the nested ``log`` and
        # ``progress_bar`` dicts are not scalar metrics and must not be
        # passed to ``log_dict``.
        summary = {"avg_test_loss": avg_loss, "avg_test_acc": avg_test_acc}
        self.log_dict(summary)

        logs = {"test_loss": avg_loss, "test_acc": avg_test_acc}
        results = {**summary, "log": logs, "progress_bar": logs}
        return results
```

%% Cell type:code id:4ea83383 tags:

``` python
if __name__ == "__main__":
    # Script entry point: download MNIST, train MNISTCNN, then evaluate it on
    # the held-out test split.
    batch_size = 128

    # The same preprocessing is applied to every split: convert to a tensor
    # and normalise with (0.1307, 0.3081) -- presumably the MNIST pixel
    # mean/std; confirm if the dataset changes.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    full_train_set = torchvision.datasets.MNIST('./files/', train=True, download=True,
                                                transform=transform)
    test_set = torchvision.datasets.MNIST('./files/', train=False, download=True,
                                          transform=transform)

    # Hold out 10% of the training data for validation (54000 / 6000 for the
    # standard 60000-image MNIST training set).
    n_val = len(full_train_set) // 10
    train_set, val_set = random_split(full_train_set, [len(full_train_set) - n_val, n_val])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

    model = MNISTCNN()

    # Only 3 epochs are used here to keep the running time short; increase
    # max_epochs for better results. NOTE: with patience=5 the early-stopping
    # callback cannot trigger within 3 epochs -- it only takes effect once
    # max_epochs is raised above the patience.
    early_stopping = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=5, verbose=True)
    trainer = pl.Trainer(max_epochs=3, default_root_dir="./hands_on_custom",
                         callbacks=[early_stopping])

    # The dataloaders are passed positionally: the keyword names
    # (train_dataloader / test_dataloaders) were renamed in later Lightning
    # releases, while the positional order is the same in both.
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)
```

%% Cell type:code id:f658ee2d-af9b-4e8b-aecc-9cf6d02bbfdc tags:

``` python
```