Skip to main content

一组脚本,使使用 pytorch 进行神经网络训练更快

项目描述

Scorch:使用 PyTorch 进行网络训练的实用程序

这个包是一组代码,可以经常用于不同的网络训练。

先决条件

你需要 Python 3 才能使用这个包。

您将需要安装以下软件包:

  • 火炬
  • 麻木的
  • scikit 图像
  • scikit-学习

要使用笔记本测试模型和数据集,您需要安装 Jupyter Notebook 或 JupyterLab。

用法

这是使用 DATASET_FILE 中指定的数据集训练 MODEL_FILE 中指定的模型的最小命令,数据位于 DATASET_PATH 中。

python train.py --model MODEL_FILE --dataset DATASET_FILE --dataset-path DATASET_PATH

这是脚本的参数列表(它将很快更新):

  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        Batch size to train or validate your model
  -w WORKERS, --workers WORKERS
                        Number of workers in a dataloader
  --pretraining         Pretraining mode
  -lr LEARNING_RATE, --learning-rate LEARNING_RATE
                        Learning rate
  -d DUMP_PERIOD, --dump-period DUMP_PERIOD
                        Dump period
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs to perform
  -c CHECKPOINT, --checkpoint CHECKPOINT
                        Checkpoint to load from
  --use-cuda            Use cuda for training
  --validate-on-train   Flag showing that you want to perform validation on training dataset 
                        along with the validation on the validation set

  --model MODEL         File with a model specification
  --dataset DATASET     File with a dataset sepcification
  --max-train-iterations MAX_TRAIN_ITERATIONS
                        Maximum training iterations
  --max-valid-iterations MAX_VALID_ITERATIONS
                        Maximum validation iterations
  -dp DATASET_PATH, --dataset-path DATASET_PATH
                        Path to the dataset
  -v VERBOSITY, --verbosity VERBOSITY
                        -1 for no output, 0 for epoch output, positive number
                        is printout frequency during the training
  -cp CHECKPOINT_PREFIX, --checkpoint-prefix CHECKPOINT_PREFIX
                        Prefix to the checkpoint name

模型模块语法

模型文件的语法如下:

class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        pass

    def forward(self, input):
        return [output1, output2]

    def __call__(self, input):
        return self.forward(input)

class Socket:
    def __init__(self, model):
        self.model = model

    def criterion(self, pred, target):
        pass

    def metrics(self, pred, target):
        pass

要求如下:

  • 网络应该(显然)有构造函数
  • 网络应具有转发功能:
    • 将输入列表作为输入。
    • 输出输出列表
  • 还应该__call__指定作为转发函数代理的函数。
  • 应该Socket定义一个类来指定如何处理模型,它应该包含:
    • criterion该方法将带有预测的张量列表和带有目标的张量列表作为输入。输出应该是一个数字。
    • metrics指定您的实验感兴趣的指标的方法。它应该将带有预测的张量列表和带有目标的张量列表作为输入,并返回一个指标列表。

到处都有列表的原因如下:网络可能有多个输入和多个输出。我们必须足够聪明地处理这个事实以重用代码。因此,最好的方法是在列表中传递兴趣值。

数据集模块语法

这是 Dataset 模块的语法:

class DataSetIndex():
    def __init__(self, path):
        pass

class DataSet():
    def __init__(self, ds_index, mode='train'):
        self.ds_index = ds_index

    def __len__(self):
        if self.mode == 'test':
            pass

        elif self.mode == 'valid':
            pass

        else:
            pass


    def __getitem__(self, index):
        img = None
        target = None

        if self.mode == 'test':
            pass

        elif self.mode == 'valid':
            pass

        else:
            pass

        return [img1, img2], [target1, target2]

数据集脚本应至少具有应指定以下内容的类 DataSet:

  • __init__,定义数据集的所有三个部分的构造函数。应在此处定义数据集的模式。
  • __len__返回数据集长度的函数
  • __getitem__返回输入张量列表和目标张量列表的函数

尽管仅指定 DataSet 就足够了,但建议还指定包含有关数据集数据的信息的 DataSetIndex 类。建议在不同模式的DataSet的所有实例之间共享一个DataSetIndex实例,以避免存储该索引的内存增加一倍或三倍,也避免多次收集数据集索引。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

pytorch-scorch-0.0.7.ta​​r.gz (18.5 kB 查看哈希

已上传 source

内置分布

pytorch_scorch-0.0.7-py3-none-any.whl (17.7 kB 查看哈希

已上传 py3