Commit cb907a5e authored by Zhangkai Wu's avatar Zhangkai Wu
Browse files

上传新文件

parent c2d997cb
Loading
Loading
Loading
Loading
+130 −0
Original line number Diff line number Diff line
%% Cell type:code id:a139b79a-05ea-4e35-aa93-5061e7285d40 tags:

``` python
# NOTE:
# You may choose to use ChatGPT (or any AI-based tool) to assist with your assignment,
# but you must ensure that you fully understand the entire code.
# You are solely responsible for the work you submit.
# Please keep in mind: ChatGPT will not be available during the exam.
```

%% Cell type:code id:551d437a tags:

``` python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision as torchvision
import torchvision.models as models
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
```

%% Cell type:code id:9ae7f21b tags:

``` python
class TransferCNN(pl.LightningModule):
    """Transfer-learning image classifier.

    Wraps an ImageNet-pretrained ResNet-18 whose backbone is frozen;
    only a freshly initialized 5-class linear head is trained.
    """

    def __init__(self):
        super(TransferCNN, self).__init__()

        # Load the pretrained backbone and freeze all of its weights.
        self.model_ft = models.resnet18(pretrained=True)
        for param in self.model_ft.parameters():
            param.requires_grad = False
        # Replace the final fully-connected layer with a new 5-class head.
        # nn.Linear parameters default to requires_grad=True, so the head
        # is the only trainable part of the network.
        num_ftrs = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(num_ftrs, 5)

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw class logits of shape (batch, 5)."""
        return self.model_ft(x)

    def configure_optimizers(self):
        # Pass only trainable parameters (the new fc head) to Adam.
        # The original passed self.parameters() wholesale, allocating
        # optimizer state for the frozen backbone for no benefit.
        trainable_params = (p for p in self.parameters() if p.requires_grad)
        optimizer = torch.optim.Adam(trainable_params, lr=0.01)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        predictions = self.forward(x)
        loss = self.cross_entropy_loss(predictions, y)

        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.cross_entropy_loss(y_hat, y)
        # EarlyStopping in the Trainer monitors this key.
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        # argmax over the class dimension; the max values themselves are
        # not needed (the original bound them to an unused variable `a`).
        y_hat = torch.argmax(logits, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), y.cpu())
        metrics = {"test_loss": loss, "test_acc": torch.tensor(test_acc)}
        self.log_dict(metrics)
        return metrics

    def test_epoch_end(self, outputs):
        """Aggregate per-batch test metrics into epoch-level averages."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        # Log only scalar tensors: self.log_dict cannot accept nested
        # dicts, so the original's "log"/"progress_bar" sub-dicts (a
        # pre-1.0 Lightning return convention) would make log_dict raise.
        results = {
            "avg_test_loss": avg_loss,
            "avg_test_acc": avg_test_acc,
        }
        self.log_dict(results)
        return results

%% Cell type:code id:a024efa3 tags:

``` python
if __name__ == "__main__":
    # Entry point: build the dataloaders, train TransferCNN with early
    # stopping, then evaluate on the held-out test set.

    batch_size = 16

    # Training-time preprocessing: random resized crops augment the data.
    # NOTE(review): pretrained torchvision models are normally normalized
    # with the ImageNet channel statistics; (0.5, 0.5, 0.5) is kept here
    # unchanged to preserve the original behavior — confirm intent.
    train_transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.RandomResizedCrop(224),
         torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                          (0.5, 0.5, 0.5))])

    # Test-time preprocessing must be deterministic. The original applied
    # RandomResizedCrop to the test set as well, which makes test metrics
    # noisy and non-reproducible; use a fixed resize + center crop.
    test_transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.Resize(256),
         torchvision.transforms.CenterCrop(224),
         torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                          (0.5, 0.5, 0.5))])

    train_set = torchvision.datasets.ImageFolder(
        root="/home/jovyan/data/IntroML/Chapter8_data/Assignment_4/train_crop_images",
        transform=train_transform)

    test_set = torchvision.datasets.ImageFolder(
        root="/home/jovyan/data/IntroML/Chapter8_data/Assignment_4/test_crop_images",
        transform=test_transform)

    # Derive the train/val split from the actual dataset size instead of
    # the hard-coded [770, 84], which raises if the folder contents ever
    # change. ~10% of the images go to validation.
    val_size = len(train_set) // 10
    train_size = len(train_set) - val_size
    train_set, val_set = random_split(train_set, [train_size, val_size])

    # DataLoader is already imported at the top of the file; use it
    # directly rather than via the fully-qualified torch.utils.data path.
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=12)

    val_loader = DataLoader(val_set, batch_size=batch_size,
                            shuffle=False, num_workers=12)

    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, num_workers=12)

    model = TransferCNN()

    # 3 epochs keeps the demo fast; raise max_epochs for better results.
    trainer = pl.Trainer(max_epochs=3, gpus=0,
                         default_root_dir="./hands_on_transfer",
                         callbacks=EarlyStopping(monitor="val_loss",
                                                 min_delta=0.00,
                                                 patience=15,
                                                 verbose=True))

    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    trainer.test(model=model, test_dataloaders=test_loader)

%% Cell type:code id:64c5b677-9db8-4d23-97bd-e2c4ec177421 tags:

``` python
```