一组脚本,使使用 pytorch 进行神经网络训练更快
项目描述
Scorch:使用 PyTorch 进行网络训练的实用程序
这个包是一组代码,可以经常用于不同的网络训练。
先决条件
你需要 Python 3 才能使用这个包。
您将需要安装以下软件包:
- 火炬
- 麻木的
- scikit 图像
- scikit-学习
要使用笔记本测试模型和数据集,您需要安装 Jupyter Notebook 或 JupyterLab。
用法
这是使用 DATASET_FILE 中指定的数据集训练 MODEL_FILE 中指定的模型的最小命令,数据位于 DATASET_PATH 中。
python train.py --model MODEL_FILE --dataset DATASET_FILE --dataset-path DATASET_PATH
这是脚本的参数列表(它将很快更新):
-b BATCH_SIZE, --batch-size BATCH_SIZE
Batch size to train or validate your model
-w WORKERS, --workers WORKERS
Number of workers in a dataloader
--pretraining Pretraining mode
-lr LEARNING_RATE, --learning-rate LEARNING_RATE
Learning rate
-d DUMP_PERIOD, --dump-period DUMP_PERIOD
Dump period
-e EPOCHS, --epochs EPOCHS
Number of epochs to perform
-c CHECKPOINT, --checkpoint CHECKPOINT
Checkpoint to load from
--use-cuda Use cuda for training
--validate-on-train Flag showing that you want to perform validation on training dataset
along with the validation on the validation set
--model MODEL File with a model specification
--dataset DATASET File with a dataset sepcification
--max-train-iterations MAX_TRAIN_ITERATIONS
Maximum training iterations
--max-valid-iterations MAX_VALID_ITERATIONS
Maximum validation iterations
-dp DATASET_PATH, --dataset-path DATASET_PATH
Path to the dataset
-v VERBOSITY, --verbosity VERBOSITY
-1 for no output, 0 for epoch output, positive number
is printout frequency during the training
-cp CHECKPOINT_PREFIX, --checkpoint-prefix CHECKPOINT_PREFIX
Prefix to the checkpoint name
模型模块语法
模型文件的语法如下:
class Network(torch.nn.Module):
def __init__(self):
super(Network, self).__init__()
pass
def forward(self, input):
return [output1, output2]
def __call__(self, input):
return self.forward(input)
class Socket:
def __init__(self, model):
self.model = model
def criterion(self, pred, target):
pass
def metrics(self, pred, target):
pass
要求如下:
- 网络应该(显然)有构造函数
- 网络应具有转发功能:
- 将输入列表作为输入。
- 输出输出列表。
- 还应该
__call__指定作为转发函数代理的函数。 - 应该
Socket定义一个类来指定如何处理模型,它应该包含:criterion该方法将带有预测的张量列表和带有目标的张量列表作为输入。输出应该是一个数字。metrics指定您的实验感兴趣的指标的方法。它应该将带有预测的张量列表和带有目标的张量列表作为输入,并返回一个指标列表。
到处都有列表的原因如下:网络可能有多个输入和多个输出。我们必须足够聪明地处理这个事实以重用代码。因此,最好的方法是在列表中传递兴趣值。
数据集模块语法
这是 Dataset 模块的语法:
class DataSetIndex():
def __init__(self, path):
pass
class DataSet():
def __init__(self, ds_index, mode='train'):
self.ds_index = ds_index
def __len__(self):
if self.mode == 'test':
pass
elif self.mode == 'valid':
pass
else:
pass
def __getitem__(self, index):
img = None
target = None
if self.mode == 'test':
pass
elif self.mode == 'valid':
pass
else:
pass
return [img1, img2], [target1, target2]
数据集脚本应至少具有应指定以下内容的类 DataSet:
__init__,定义数据集的所有三个部分的构造函数。应在此处定义数据集的模式。__len__返回数据集长度的函数__getitem__返回输入张量列表和目标张量列表的函数
尽管仅指定 DataSet 就足够了,但建议还指定包含有关数据集数据的信息的 DataSetIndex 类。建议在不同模式的DataSet的所有实例之间共享一个DataSetIndex实例,以避免存储该索引的内存增加一倍或三倍,也避免多次收集数据集索引。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
pytorch-scorch-0.0.7.tar.gz
(18.5 kB
查看哈希)
内置分布
pytorch_scorch-0.0.7-py3-none-any.whl
(17.7 kB
查看哈希)
关
pytorch_scorch -0.0.7-py3-none-any.whl 的哈希值
| 算法 | 哈希摘要 | |
|---|---|---|
| SHA256 | 6a2136a9302a4a19d7ab9be5528e283b7610490f3371c31ec3ab09dd6e316646 |
|
| MD5 | 07e998a6fd7fdba54475ac7f50ad0c26 |
|
| 布莱克2-256 | abb942250110fdecb7c6d54ba4f89a44830b101e3dd6383d17b71fea30cd1023 |