Commit d6e85b0a authored by Zhangkai Wu's avatar Zhangkai Wu
Browse files

删除template_assignment_8_4_custom_cnn.ipynb

parent 636a93d3
Loading
Loading
Loading
Loading
+0 −76
Original line number Diff line number Diff line
%% Cell type:code id: tags:

``` python
# You are encouraged to follow this template as a guide, but feel free to make reasonable modifications as needed.
# The key requirement is that your code runs successfully and produces the expected results.
```

%% Cell type:code id: tags:

``` python
# NOTE:
# You may choose to use ChatGPT (or any AI-based tool) to assist with your assignment,
# but you must ensure that you fully understand the entire code.
# You are solely responsible for the work you submit.
# Please keep in mind: ChatGPT will not be available during the exam.
```

%% Cell type:code id: tags:

``` python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as torchvision
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split


class PLANTSCNN(pl.LightningModule):
    def __init__(self):
        super(PLANTSCNN, self).__init__()
        # TODO: Implement this function

    def forward(self, x):
        # TODO: Implement this function
        return 0

    def configure_optimizers(self):
        # TODO: Implement this function
        return 0

    def training_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def validation_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def test_step(self, batch, batch_idx):
        # TODO: Implement this function
        return 0

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        logs = {"test_loss": avg_loss, "test_acc": avg_test_acc}
        results = {
            "avg_test_loss": avg_loss,
            "avg_test_acc": avg_test_acc,
            "log": logs,
            "progress_bar": logs,
        }
        self.log_dict(results)
        return results


if __name__ == "__main__":
    # Main function of script
    # TODO: Implement data loading, learning and prediction
    # data path:
    # train "/home/jovyan/data/IntroML/Chapter8_data/Assignment_4/train_crop_images"
    # test "/home/jovyan/data/IntroML/Chapter8_data/Assignment_4/test_crop_images"
```