Script implementing the terrain neural net
You must create the images with CreateTerrainImages before running this script. The trained network is saved in the file LunarNet.mat.
See also FindDirectory
%--------------------------------------------------------------------------
% Copyright (c) 2021 Princeton Satellite Systems, Inc.
% All rights reserved.
%--------------------------------------------------------------------------
% Since version 2021.1
%--------------------------------------------------------------------------

v = ver;
any(strcmp('Deep Learning Toolbox', {v.Name}))
ans = logical 1
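The check above only displays its result, so the script carries on even when the toolbox is missing. A minimal sketch of capturing the result and reporting it follows. The variable name hasDL and the message text are illustrative and not part of the original script.

```matlab
% Sketch: capture the toolbox check in a flag instead of only displaying it.
% hasDL and the message text are illustrative, not from the original script.
v     = ver;
hasDL = any(strcmp('Deep Learning Toolbox',{v.Name}));
if( ~hasDL )
  % In the script, a return placed here would stop before any
  % Deep Learning Toolbox function is called.
  fprintf('Deep Learning Toolbox not found. Training cannot run.\n');
end
```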
Get the images
pImages = FindDirectory('LROImages');
c0 = cd;
cd(pImages);
label = load('Label');
cd(c0);

t = categorical(label.t);
nClasses = max(label.t);
imds = imageDatastore(pImages,'labels',t);
labelCount = countEachLabel(imds);

% Display a few snapshots
NewFig('Lunar Snapshots');
n = 4;
m = 5;
ks = sort(randi(length(label.t),1,n*m)); % random selection
for i = 1:n*m
  subplot(n,m,i);
  imshow(imds.Files{ks(i)});
  title(sprintf('Image %d: %d',ks(i),label.t(ks(i))))
end

% We need the size of the images for the input layer
img = readimage(imds,1);

% Split into training and testing sets
fracTraining = 0.8;
[imdsTrain,imdsTest] = splitEachLabel(imds,fracTraining,'randomized');
Error using imageDatastore (line 138)
Input folders or files contain non-standard file extensions.
Use FileExtensions Name-Value pair to include the non-standard file extensions.

Error in TerrainNeuralNet (line 26)
imds = imageDatastore(pImages,'labels',t);
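One possible cause of this error, consistent with the note at the top of the script, is that the image folder does not yet contain the images. This is a guess, not a confirmed diagnosis. The sketch below checks for image files before a datastore is built. The '.jpg' extension is an assumption, and an empty temporary folder stands in for the folder returned by FindDirectory so that the example runs on its own.

```matlab
% Sketch: confirm the image folder holds images before building the datastore.
% Assumption: the images are JPEG files. Change imgExt to match the
% format written by CreateTerrainImages.
imgExt  = '.jpg';

% An empty temporary folder stands in for FindDirectory('LROImages')
pImages = tempname;
mkdir(pImages);

d       = dir(fullfile(pImages,['*' imgExt]));
nImages = numel(d);
if( nImages == 0 )
  fprintf('No %s files in %s. Run CreateTerrainImages first.\n',imgExt,pImages);
else
  fprintf('Found %d image files.\n',nImages);
end

% Remove the demonstration folder
rmdir(pImages);
```

In the script itself, pImages would come from FindDirectory and the rmdir call would be dropped. If the images really do use a non-standard extension, the FileExtensions name-value pair named in the error message is the other route.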
Training
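The layer array defines the structure of the convolutional neural net: two stages of convolution, batch normalization, ReLU, and max pooling, followed by a fully connected layer, a softmax layer, and the classification output.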
layers = [
  imageInputLayer(size(img))
  convolution2dLayer(3,8,'Padding','same')
  batchNormalizationLayer
  reluLayer
  maxPooling2dLayer(2,'Stride',2)
  convolution2dLayer(3,32,'Padding','same')
  batchNormalizationLayer
  reluLayer
  maxPooling2dLayer(2,'Stride',2)
  fullyConnectedLayer(nClasses)
  softmaxLayer
  classificationLayer
];
disp(layers)
options = trainingOptions('sgdm', ...
  'InitialLearnRate',0.01, ...
  'MaxEpochs',6, ...
  'MiniBatchSize',100, ...
  'ValidationData',imdsTest, ...
  'ValidationFrequency',10, ...
  'ValidationPatience',inf, ...
  'Shuffle','every-epoch', ...
  'Verbose',false, ...
  'Plots','training-progress');
disp(options)
fprintf('Fraction for training %8.2f%%\n',fracTraining*100);
terrainNet = trainNetwork(imdsTrain,layers,options);
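To see what the layer stack does to the image dimensions, the sketch below tracks the feature-map size by hand. The 64-by-64 grayscale input is a hypothetical value chosen for illustration. The script takes the real size from size(img). Convolutions with 'same' padding keep the height and width, and each 2-by-2 max pooling layer with stride 2 halves them. Batch normalization and ReLU do not change the size.

```matlab
% Sketch: track the feature-map size through the layer stack.
% Assumption: 64-by-64 grayscale images. The real size comes from size(img).
sz = [64 64 1];                  % [height width channels], hypothetical

% convolution2dLayer(3,8,'Padding','same'): same height and width, 8 channels
sz = [sz(1:2) 8];
% maxPooling2dLayer(2,'Stride',2): halves height and width
sz = [floor(sz(1:2)/2) sz(3)];
% convolution2dLayer(3,32,'Padding','same'): same height and width, 32 channels
sz = [sz(1:2) 32];
% second maxPooling2dLayer(2,'Stride',2)
sz = [floor(sz(1:2)/2) sz(3)];

% Number of inputs seen by fullyConnectedLayer(nClasses)
nFeatures = prod(sz);
fprintf('Final feature map %d x %d x %d, %d inputs to the fully connected layer\n',...
        sz,nFeatures);
```

For the assumed input this gives a 16-by-16-by-32 feature map, or 8192 inputs to the fully connected layer.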
Test the neural net
predLabels = classify(terrainNet,imdsTest);
testLabels = imdsTest.Labels;
accuracy = sum(predLabels == testLabels)/numel(testLabels);
fprintf('Accuracy is %8.2f%%\n',accuracy*100)

save('LunarNet','terrainNet')

%--------------------------------------
% $Date$
% $Revision$
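The overall accuracy can hide a class that is classified poorly. The sketch below applies the same accuracy formula per class. The label values are made up for illustration and are not results from the terrain network. The variable names classes and accClass are also illustrative.

```matlab
% Sketch: overall and per-class accuracy from categorical labels.
% The label values are made up for illustration only.
testLabels = categorical([1 1 2 2 2 3 3 3 3 1]);
predLabels = categorical([1 2 2 2 3 3 3 3 1 1]);

% Same formula as in the script
accuracy = sum(predLabels == testLabels)/numel(testLabels);
fprintf('Accuracy is %8.2f%%\n',accuracy*100);

% Repeat the calculation for the test images of each class
classes  = categories(testLabels);
accClass = zeros(1,numel(classes));
for k = 1:numel(classes)
  j           = testLabels == classes{k};   % test images in class k
  accClass(k) = sum(predLabels(j) == testLabels(j))/sum(j);
  fprintf('Class %s: %8.2f%% of %d images\n',classes{k},accClass(k)*100,sum(j));
end
```

In the script, testLabels and predLabels would be the arrays already computed from imdsTest.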