Deep learning с использованием языка R и библиотеки mxnet. Предсказания, итераторы и дополнительные возможности
1. Вступление
Это сообщение является продолжением Deep learning с использованием языка R и библиотеки mxnet. Установка и начало работы. Будет рассмотрено предсказание классов изображений на основе модели, а также работа с итераторами и некоторые другие аспекты.
Полезные ссылки:
End-to-End Deep Learning Tutorial,
https://github.com/dmlc/mxnet/tree/master/docs/tutorials/r,
https://github.com/dmlc/mxnet/tree/master/R-package/vignettes.
По двум последним ссылкам доступна самая актуальная документация и примеры от разработчиков.
2. Создание “словаря синонимов”
В комплекте с готовыми моделями, доступными для скачивания, обычно идет файл synset.txt. Этот файл содержит информацию о соответствии между номером класса и его названием/меткой, например 0 "airplane"
. При создании бинарного файла с изображениями были также созданы файлы в формате .lst, из которых легко получить нужную нам таблицу:
setwd("~/R/cifar10/")
library(mxnet)
library(imager)
library(abind)
labels <- read.table("cifar_train.lst")
# Оставляем только уникальные значения
labels <- labels[!duplicated(labels$V2), ]
# Разделяем имена файлов и имена папок, которые соответствуют меткам классов
tmp <- strsplit(as.character(labels$V3), split = "/", fixed = TRUE)
# Создаем таблицу с метками и номерами классов, сортируем по возрастанию номеров
class_labels <- data.frame(response = labels$V2,
label = sapply(tmp, function(x) x[1]))
class_labels <- class_labels[order(class_labels$response), ]
# Сохраняем в файл "synset.txt" для дальнейшего использования
write.table(class_labels, "synset.txt", row.names = FALSE, col.names = TRUE)
class_labels <- read.table("synset.txt", header = TRUE)
3. Обучение модели
Повторим обучение той же модели с тем же набором данных, что и в прошлый раз. Но теперь укажем размер изображений 32х30, то есть 32 пикселя в высоту и 30 пикселей в ширину (исходные картинки 32х32 будут обрезаться). Это нужно для лучшего понимания того, как правильно указывать размерности на этапе обучения модели и при ее использовании для предсказаний. Значения аргументов kernel
, stride
, pad
задаются всегда в том же формате: сначала высота (y-координата), затем ширина (x-координата); третьим числов в векторе размерностей может быть глубина.
Создаем итераторы:
get_iterator <- function(data_shape,
train_data,
val_data,
batch_size = 128) {
train <- mx.io.ImageRecordIter(
path.imgrec = train_data,
batch.size = batch_size,
data.shape = data_shape,
rand.crop = TRUE,
rand.mirror = TRUE)
val <- mx.io.ImageRecordIter(
path.imgrec = val_data,
batch.size = batch_size,
data.shape = data_shape,
rand.crop = FALSE,
rand.mirror = FALSE
)
return(list(train = train, val = val))
}
data <- get_iterator(data_shape = c(32, 30, 3), # 32 пикселя в высоту
train_data = "/home/andrey/R/cifar10/cifar_train.rec",
val_data = "/home/andrey/R/cifar10/cifar_val.rec",
batch_size = 100)
train <- data$train
val <- data$val
Используем ту же архитектуру Resnet:
conv_factory <- function(data, num_filter, kernel, stride,
pad, act_type = 'relu', conv_type = 0) {
if (conv_type == 0) {
conv = mx.symbol.Convolution(data = data, num_filter = num_filter,
kernel = kernel, stride = stride, pad = pad)
bn = mx.symbol.BatchNorm(data = conv)
act = mx.symbol.Activation(data = bn, act_type = act_type)
return(act)
} else if (conv_type == 1) {
conv = mx.symbol.Convolution(data = data, num_filter = num_filter,
kernel = kernel, stride = stride, pad = pad)
bn = mx.symbol.BatchNorm(data = conv)
return(bn)
}
}
residual_factory <- function(data, num_filter, dim_match) {
if (dim_match) {
identity_data = data
conv1 = conv_factory(data = data, num_filter = num_filter, kernel = c(3, 3),
stride = c(1, 1), pad = c(1, 1), act_type = 'relu', conv_type = 0)
conv2 = conv_factory(data = conv1, num_filter = num_filter, kernel = c(3, 3),
stride = c(1, 1), pad = c(1, 1), conv_type = 1)
new_data = identity_data + conv2
act = mx.symbol.Activation(data = new_data, act_type = 'relu')
return(act)
} else {
conv1 = conv_factory(data = data, num_filter = num_filter, kernel = c(3, 3),
stride = c(2, 2), pad = c(1, 1), act_type = 'relu', conv_type = 0)
conv2 = conv_factory(data = conv1, num_filter = num_filter, kernel = c(3, 3),
stride = c(1, 1), pad = c(1, 1), conv_type = 1)
# adopt project method in the paper when dimension increased
project_data = conv_factory(data = data, num_filter = num_filter, kernel = c(1, 1),
stride = c(2, 2), pad = c(0, 0), conv_type = 1)
new_data = project_data + conv2
act = mx.symbol.Activation(data = new_data, act_type = 'relu')
return(act)
}
}
residual_net <- function(data, n) {
#fisrt 2n layers
for (i in 1:n) {
data = residual_factory(data = data, num_filter = 16, dim_match = TRUE)
}
#second 2n layers
for (i in 1:n) {
if (i == 1) {
data = residual_factory(data = data, num_filter = 32, dim_match = FALSE)
} else {
data = residual_factory(data = data, num_filter = 32, dim_match = TRUE)
}
}
#third 2n layers
for (i in 1:n) {
if (i == 1) {
data = residual_factory(data = data, num_filter = 64, dim_match = FALSE)
} else {
data = residual_factory(data = data, num_filter = 64, dim_match = TRUE)
}
}
return(data)
}
get_symbol <- function(num_classes = 10) {
conv <- conv_factory(data = mx.symbol.Variable(name = 'data'), num_filter = 16,
kernel = c(3, 3), stride = c(1, 1), pad = c(1, 1),
act_type = 'relu', conv_type = 0)
n <- 3 # set n = 3 means get a model with 3*6+2=20 layers, set n = 9 means 9*6+2=56 layers
resnet <- residual_net(conv, n) #
pool <- mx.symbol.Pooling(data = resnet, kernel = c(7, 7), pool_type = 'avg')
flatten <- mx.symbol.Flatten(data = pool, name = 'flatten')
fc <- mx.symbol.FullyConnected(data = flatten, num_hidden = num_classes, name = 'fc1')
softmax <- mx.symbol.SoftmaxOutput(data = fc, name = 'softmax')
return(softmax)
}
# Сеть для 10 классов
resnet <- get_symbol(10)
Обучаем модель в течение 15 эпох:
model <- mx.model.FeedForward.create(
symbol = resnet,
X = train,
eval.data = val,
ctx = mx.gpu(0),
eval.metric = mx.metric.accuracy,
num.round = 15,
learning.rate = 0.05,
momentum = 0.9,
wd = 0.00001,
kvstore = "local",
array.batch.size = 100,
epoch.end.callback = NULL,
batch.end.callback = mx.callback.log.train.metric(150),
initializer = mx.init.Xavier(factor_type = "in", magnitude = 2.34),
optimizer = "sgd"
)
## Start training with 1 devices
## Batch [150] Train-accuracy=0.2868
## Batch [300] Train-accuracy=0.347766666666667
## [1] Train-accuracy=0.38280701754386
## [1] Validation-accuracy=0.4331
## Batch [150] Train-accuracy=0.527066666666666
## Batch [300] Train-accuracy=0.557066666666667
## [2] Train-accuracy=0.5738
## [2] Validation-accuracy=0.5977
## Batch [150] Train-accuracy=0.637066666666666
## Batch [300] Train-accuracy=0.649399999999999
## [3] Train-accuracy=0.660374999999999
## [3] Validation-accuracy=0.6443
## Batch [150] Train-accuracy=0.704533333333333
## Batch [300] Train-accuracy=0.711766666666666
## [4] Train-accuracy=0.718349999999999
## [4] Validation-accuracy=0.7042
## Batch [150] Train-accuracy=0.7422
## Batch [300] Train-accuracy=0.747533333333334
## [5] Train-accuracy=0.752975000000001
## [5] Validation-accuracy=0.7198
## Batch [150] Train-accuracy=0.7704
## Batch [300] Train-accuracy=0.772866666666667
## [6] Train-accuracy=0.777025
## [6] Validation-accuracy=0.7485
## Batch [150] Train-accuracy=0.79
## Batch [300] Train-accuracy=0.790533333333334
## [7] Train-accuracy=0.794425000000001
## [7] Validation-accuracy=0.772699999999999
## Batch [150] Train-accuracy=0.806733333333333
## Batch [300] Train-accuracy=0.807766666666668
## [8] Train-accuracy=0.8109
## [8] Validation-accuracy=0.7733
## Batch [150] Train-accuracy=0.822799999999999
## Batch [300] Train-accuracy=0.821533333333334
## [9] Train-accuracy=0.825025
## [9] Validation-accuracy=0.7679
## Batch [150] Train-accuracy=0.8332
## Batch [300] Train-accuracy=0.833066666666668
## [10] Train-accuracy=0.834800000000001
## [10] Validation-accuracy=0.7856
## Batch [150] Train-accuracy=0.845333333333333
## Batch [300] Train-accuracy=0.842533333333334
## [11] Train-accuracy=0.843475
## [11] Validation-accuracy=0.7886
## Batch [150] Train-accuracy=0.8502
## Batch [300] Train-accuracy=0.850333333333335
## [12] Train-accuracy=0.851950000000001
## [12] Validation-accuracy=0.7801
## Batch [150] Train-accuracy=0.860866666666666
## Batch [300] Train-accuracy=0.858900000000001
## [13] Train-accuracy=0.8601
## [13] Validation-accuracy=0.8068
## Batch [150] Train-accuracy=0.8648
## Batch [300] Train-accuracy=0.865733333333334
## [14] Train-accuracy=0.86715
## [14] Validation-accuracy=0.8171
## Batch [150] Train-accuracy=0.867466666666666
## Batch [300] Train-accuracy=0.869833333333333
## [15] Train-accuracy=0.871949999999999
## [15] Validation-accuracy=0.8096
4. Предсказания на основе модели
Для работы с изображениями в R будем использовать пакет imager.
В общих чертах процесс описан в руководстве Classify Images with a Pretrained Model, но там есть некоторые нюансы и неточности. Примерами будут служить следующие два изображения:
Разберем операции предварительной обработки подробно для первого изображения:
# Скачиваем или загружаем с диска изображение
im <- load.image("http://kindersay.com/files/images/bird.png")
# Image. Width: 445 pix Height: 355 pix Depth: 1 Colour channels: 3
# Размерность: ширина на высоту
# Depth - количество кадров, если это видео; для изображений всегда 1
# Colour channels - 3 цветовых канала (RGB)
shape <- dim(im)
# 445 355 1 3
# Индексация идет сначала по столбцам (ширина), затем по строкам (высота) -
# изображение из линейного вектора формируется именно в таком порядке
# В R матрицы формируются и индексируются в обратном порядке,
# то есть используется так называемый Fortran-style, или порядок column-major:
# a <- 1:4
# dim(a) <- c(2, 2)
# a
# [,1] [,2]
# [1,] 1 3
# [2,] 2 4
# Меняем размер на 32х30, требуемый для нашей модели
# Обрезка (crop) не используется
resized <- resize(im, size_x = 30, size_y = 32)
# Image. Width: 30 pix Height: 32 pix Depth: 1 Colour channels: 3
# Конвертируем в массив
# Если значения для каждого цветового канала заданы в диапазоне [0, 1],
# то нужно умножить на 255. В нашем случае это не требуется
arr <- as.array(resized)
# 30 32 1 3 - 30 строк, а не 32 строки, как в изображении
# Произошло транспонирование: строки стали столбцами
# Средние значения для каждого пикселя не отнимаем
# Задаем нужный формат (width, height, channel, num)
dim(arr) <- c(30, 32, 3, 1)
# Предсказываем вероятности и класс
prob <- predict(model, X = arr)
prob
## [,1]
## [1,] 0.097901896
## [2,] 0.002545726
## [3,] 0.742579162
## [4,] 0.041871570
## [5,] 0.016076958
## [6,] 0.017367113
## [7,] 0.011065663
## [8,] 0.066732869
## [9,] 0.001639599
## [10,] 0.002219437
class_labels$label[prob == max(prob)]
## [1] bird
## Levels: airplane automobile bird cat deer dog frog horse ship truck
Все этапы предварительной обработки можно оформить в виде функции (измененный вариант preproc.image
из https://github.com/dmlc/mxnet/blob/master/docs/tutorials/r/classifyRealImageWithPretrainedModel.md):
preproc_image <- function(src, # URL or file location
height,
width,
num_channels = 3, # 3 for RGB, 1 for grayscale
mult_by = 1, # set to 255 for normalized image
crop = FALSE) { # no crop by default
im <- load.image(src)
if (crop) {
shape <- dim(im)
short_edge <- min(shape[1:2])
xx <- floor((shape[1] - short_edge) / 2)
yy <- floor((shape[2] - short_edge) / 2)
im <- crop.borders(im, xx, yy)
}
resized <- resize(im, size_x = width, size_y = height)
arr <- as.array(resized) * mult_by
dim(arr) <- c(width, height, num_channels, 1)
return(arr)
}
Предсказание для картинки с оленем:
arr <- preproc_image("http://kingofwallpapers.com/deer-images/deer-images-007.jpg",
height = 32,
width = 30)
prob <- predict(model, X = arr)
prob
## [,1]
## [1,] 1.974541e-05
## [2,] 4.638057e-06
## [3,] 3.291356e-02
## [4,] 6.472487e-03
## [5,] 8.688778e-01
## [6,] 1.603282e-02
## [7,] 4.853056e-03
## [8,] 7.081089e-02
## [9,] 2.876292e-06
## [10,] 1.208015e-05
class_labels$label[prob == max(prob)]
## [1] deer
## Levels: airplane automobile bird cat deer dog frog horse ship truck
Рассмотренную особенность с порядком индексации массивов нужно учитывать и при использовании итераторов по файлам в формате .csv, таких как в этом примере. При создании такого файла исходное изображение (или несколько изображений, которые соответствуют одному наблюдению) превращаются в вектор, вектор становится строкой файла, а затем при обучении модели и при использовании модели для предсказаний задается правильная размерность массива. Это позволяет воспроизвести пространственную структуру входных данных, хранящихся в линейном виде.
Вопросы оптимальной реализации этих операций в данном сообщении не рассматриваются, но при работе с большим количество изображений наверняка пригодятся пакеты типа foreach и doParallel для параллельной обработки, также может быть полезен пакет data.table и консольные утилиты.
Простейший пример обработки нескольких изображений с использованием пакета abind:
image_urls <- c(
"http://kindersay.com/files/images/bird.png",
"http://kingofwallpapers.com/deer-images/deer-images-007.jpg"
)
images <- lapply(image_urls,
preproc_image,
height = 32,
width = 30)
images <- do.call(abind, images)
probs <- predict(model, X = images)
probs
## [,1] [,2]
## [1,] 0.097901933 1.974543e-05
## [2,] 0.002545726 4.638062e-06
## [3,] 0.742579103 3.291357e-02
## [4,] 0.041871566 6.472491e-03
## [5,] 0.016076960 8.688778e-01
## [6,] 0.017367108 1.603283e-02
## [7,] 0.011065662 4.853058e-03
## [8,] 0.066732846 7.081092e-02
## [9,] 0.001639599 2.876297e-06
## [10,] 0.002219437 1.208016e-05
class_labels$label[apply(probs, 2, function(x) which(x == max(x)))]
## [1] bird deer
## Levels: airplane automobile bird cat deer dog frog horse ship truck
5. Итераторы
Для работы с любыми данными, которые не помещаются в памяти, можно использовать функцию mx.io.CSVIter()
. Как понятно из названия, она обрабатывает файлы в формате .csv построчно и конструирует из каждой строки массив (тензор) нужной размерности, которая задается аргументами data.shape
для самого набора данных и label.shape
для целевой переменной. См. https://github.com/dmlc/mxnet/tree/master/example/kaggle-ndsb2. За создание .csv-файлов там отвечает код на Python. Аналог на R можно написать, взяв за основу представленную выше функцию preproc_image()
и заменив dim(arr) <- c(width, height, num_channels, 1)
на dim(arr) <- c(width * height * num_channels * 1)
.
Также есть возможность создавать свои собственные итераторы - см. Custom Iterator Tutorial. Чтобы сделать что-то действительно серьезно отличающееся от представленного варианта, понадобятся знания C++.
6. Доступные слои и функции потерь
Список слоев довольно обширен:
apropos("mx.symbol.")
## [1] "mx.symbol.abs"
## [2] "mx.symbol.Activation"
## [3] "mx.symbol.adam_update"
## [4] "mx.symbol.arccos"
## [5] "mx.symbol.arccosh"
## [6] "mx.symbol.arcsin"
## [7] "mx.symbol.arcsinh"
## [8] "mx.symbol.arctan"
## [9] "mx.symbol.arctanh"
## [10] "mx.symbol.argmax"
## [11] "mx.symbol.argmax_channel"
## [12] "mx.symbol.argmin"
## [13] "mx.symbol.argsort"
## [14] "mx.symbol.batch_dot"
## [15] "mx.symbol.BatchNorm"
## [16] "mx.symbol.BlockGrad"
## [17] "mx.symbol.broadcast_add"
## [18] "mx.symbol.broadcast_axis"
## [19] "mx.symbol.broadcast_div"
## [20] "mx.symbol.broadcast_equal"
## [21] "mx.symbol.broadcast_greater"
## [22] "mx.symbol.broadcast_greater_equal"
## [23] "mx.symbol.broadcast_hypot"
## [24] "mx.symbol.broadcast_lesser"
## [25] "mx.symbol.broadcast_lesser_equal"
## [26] "mx.symbol.broadcast_maximum"
## [27] "mx.symbol.broadcast_minimum"
## [28] "mx.symbol.broadcast_minus"
## [29] "mx.symbol.broadcast_mul"
## [30] "mx.symbol.broadcast_not_equal"
## [31] "mx.symbol.broadcast_plus"
## [32] "mx.symbol.broadcast_power"
## [33] "mx.symbol.broadcast_sub"
## [34] "mx.symbol.broadcast_to"
## [35] "mx.symbol.Cast"
## [36] "mx.symbol.ceil"
## [37] "mx.symbol.choose_element_0index"
## [38] "mx.symbol.clip"
## [39] "mx.symbol.Concat"
## [40] "mx.symbol.Convolution"
## [41] "mx.symbol.Correlation"
## [42] "mx.symbol.cos"
## [43] "mx.symbol.cosh"
## [44] "mx.symbol.crop"
## [45] "mx.symbol.Crop"
## [46] "mx.symbol.CuDNNBatchNorm"
## [47] "mx.symbol.Custom"
## [48] "mx.symbol.Deconvolution"
## [49] "mx.symbol.degrees"
## [50] "mx.symbol.dot"
## [51] "mx.symbol.Dropout"
## [52] "mx.symbol.ElementWiseSum"
## [53] "mx.symbol.elemwise_add"
## [54] "mx.symbol.Embedding"
## [55] "mx.symbol.exp"
## [56] "mx.symbol.expand_dims"
## [57] "mx.symbol.expm1"
## [58] "mx.symbol.fill_element_0index"
## [59] "mx.symbol.fix"
## [60] "mx.symbol.Flatten"
## [61] "mx.symbol.flip"
## [62] "mx.symbol.floor"
## [63] "mx.symbol.FullyConnected"
## [64] "mx.symbol.gamma"
## [65] "mx.symbol.gammaln"
## [66] "mx.symbol.Group"
## [67] "mx.symbol.identity"
## [68] "mx.symbol.IdentityAttachKLSparseReg"
## [69] "mx.symbol.infer.shape"
## [70] "mx.symbol.InstanceNorm"
## [71] "mx.symbol.L2Normalization"
## [72] "mx.symbol.LeakyReLU"
## [73] "mx.symbol.LinearRegressionOutput"
## [74] "mx.symbol.load"
## [75] "mx.symbol.load.json"
## [76] "mx.symbol.log"
## [77] "mx.symbol.log10"
## [78] "mx.symbol.log1p"
## [79] "mx.symbol.log2"
## [80] "mx.symbol.LogisticRegressionOutput"
## [81] "mx.symbol.LRN"
## [82] "mx.symbol.MAERegressionOutput"
## [83] "mx.symbol.MakeLoss"
## [84] "mx.symbol.max"
## [85] "mx.symbol.max_axis"
## [86] "mx.symbol.min"
## [87] "mx.symbol.min_axis"
## [88] "mx.symbol.nanprod"
## [89] "mx.symbol.nansum"
## [90] "mx.symbol.negative"
## [91] "mx.symbol.norm"
## [92] "mx.symbol.normal"
## [93] "mx.symbol.Pad"
## [94] "mx.symbol.Pooling"
## [95] "mx.symbol.prod"
## [96] "mx.symbol.radians"
## [97] "mx.symbol.Reshape"
## [98] "mx.symbol.rint"
## [99] "mx.symbol.RNN"
## [100] "mx.symbol.ROIPooling"
## [101] "mx.symbol.round"
## [102] "mx.symbol.rsqrt"
## [103] "mx.symbol.save"
## [104] "mx.symbol.SequenceLast"
## [105] "mx.symbol.SequenceMask"
## [106] "mx.symbol.SequenceReverse"
## [107] "mx.symbol.sgd_mom_update"
## [108] "mx.symbol.sgd_update"
## [109] "mx.symbol.sign"
## [110] "mx.symbol.sin"
## [111] "mx.symbol.sinh"
## [112] "mx.symbol.slice_axis"
## [113] "mx.symbol.SliceChannel"
## [114] "mx.symbol.smooth_l1"
## [115] "mx.symbol.Softmax"
## [116] "mx.symbol.SoftmaxActivation"
## [117] "mx.symbol.softmax_cross_entropy"
## [118] "mx.symbol.SoftmaxOutput"
## [119] "mx.symbol.sort"
## [120] "mx.symbol.SpatialTransformer"
## [121] "mx.symbol.sqrt"
## [122] "mx.symbol.square"
## [123] "mx.symbol.sum"
## [124] "mx.symbol.sum_axis"
## [125] "mx.symbol.SVMOutput"
## [126] "mx.symbol.SwapAxis"
## [127] "mx.symbol.tan"
## [128] "mx.symbol.tanh"
## [129] "mx.symbol.topk"
## [130] "mx.symbol.transpose"
## [131] "mx.symbol.uniform"
## [132] "mx.symbol.UpSampling"
## [133] "mx.symbol.Variable"
Если в названии функции последнее слово начинается с большой буквы - эта функция создает слой; если в конце имеется слово Output, то в этом “символе” есть не только выходной слой нейросети, но и функция потерь вместе со всем необходимым для операций обратного распространения ошибки. Например, функция mx.symbol.SoftmaxActivation()
создает слой с активацией softmax, после которого можно указать свою собственную функцию потерь. А если использовать mx.symbol.SoftmaxOutput
, то перекрестная энтропия (cross-entropy, в данном случае это синоним logloss) сразу будет использоваться в качестве функции потерь.
Остальные функции отвечают за операции над “символами”, которые являются аналогами соответствующих операций над значениями. Для арифметических операторов +
, -
, *
и /
добавлены соответствующие методы.
В руководстве Customized loss function рассматривается создание собственной функции потерь с помощью mx.symbol.MakeLoss()
, также полезные материалы есть по этой ссылке. Выглядит это вот так:
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=1)
lro <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc1, shape = 0) - label))
# Аналог mx.symbol.LinearRegressionOutput()
7. Использование активаций скрытых слоев
Активации скрытых слоев нейросети можно использовать как для визуализации ее работы, так и в качестве признаков для других алгоритмов машинного обучения. Особенно это полезно, когда данных для обучения с нуля своей нейросети не хватает. В таком случае можно взять предобученную сеть для того же класса задач и получить не только финальные предсказания (которые скорее всего будут бесполезными), но и активации, например, последнего полносвязного слоя. Такие активации могут содержать высокоуровневые признаки, релевантные для решаемой задачи.
Руководства от разработчиков на эту тему пока нет, но в обсуждении приводится готовое решение. Модель обучается обычным образом, дополнительно создается executor
- объект, параметры которого можно обновить, используя параметры обученной модели. В конце выполняем forward pass и получаем значения на нужных нам слоях:
# Group some output layers for visual analysis
out <- mx.symbol.Group(c(convAct1, poolLayer1, convAct2, poolLayer2, LeNet1))
# Create an executor
executor <- mx.simple.bind(symbol = out,
data = dim(test.array),
ctx = mx.cpu())
# Update parameters
mx.exec.update.arg.arrays(executor,
model$arg.params,
match.name = TRUE)
mx.exec.update.aux.arrays(executor,
model$aux.params,
match.name = TRUE)
# Select data to use
mx.exec.update.arg.arrays(executor,
list(data = mx.nd.array(test.array)),
match.name = TRUE)
# Do a forward pass with the current parameters and data
mx.exec.forward(executor,
is.train = FALSE)
names(executor$ref.outputs)
Предобученную модель можно использовать и другим образом: можно дообучить ее на своих данных, подобно тому, как мы продолжаем обучение своей собственной модели с “контрольной точки”. Также можно при этом поменять конфигурацию сети - см. руководство для Python. Наверное, в R проще всего отредактировать .json-файл, содержащий описание архитектуры сети: сохраняем resnet$as.json()
, редактируем, загружаем файл с помощью mx.symbol.load()
.
Продолжение, надеюсь, следует, но уже немного в другом формате.
Комментариев нет:
Отправить комментарий