Home | Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages
itkStochasticGradientDescentOptimizer.h
Go to the documentation of this file.
1/*=========================================================================
2 *
3 * Copyright UMC Utrecht and contributors
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18#ifndef itkStochasticGradientDescentOptimizer_h
19#define itkStochasticGradientDescentOptimizer_h
20
22#include "itkPlatformMultiThreader.h"
23
24namespace itk
25{
54{
55public:
57
61 using Pointer = SmartPointer<Self>;
62 using ConstPointer = SmartPointer<const Self>;
63
65 itkNewMacro(Self);
66
69
71 using Superclass::MeasureType;
72 using Superclass::ParametersType;
73 using Superclass::DerivativeType;
74 using Superclass::CostFunctionType;
78
83 {
90 };
91
93 virtual void
95
97 void
99
102 virtual void
104
106 virtual void
107 MetricErrorResponse(ExceptionObject & err);
108
111 virtual void
113
// Public accessor declarations, written with the ITK set/get macros. Each macro
// is given a property name and a type; presumably it expands to a Set<Name> or
// Get<Name> member working on m_<Name> (the macro definitions are not part of
// this listing -- confirm against ITK's itkMacro.h).
115 itkSetMacro(LearningRate, double);
116
118 itkGetConstReferenceMacro(LearningRate, double);
119
121 itkSetMacro(NumberOfIterations, unsigned long);
122
124 itkGetConstMacro(LBFGSMemory, unsigned int); // NOTE(review): m_LBFGSMemory is declared unsigned long; getter type is unsigned int -- confirm intended
125
127 itkGetConstReferenceMacro(NumberOfIterations, unsigned long);
128
130 itkGetConstReferenceMacro(NumberOfInnerIterations, unsigned long); // NOTE(review): backing member not visible in this listing
131
133 itkGetConstMacro(CurrentIteration, unsigned int); // NOTE(review): m_CurrentIteration is declared unsigned long; getter type is unsigned int -- confirm intended
134
136 itkGetConstMacro(CurrentInnerIteration, unsigned int); // NOTE(review): backing member not visible in this listing
137
139 itkGetConstReferenceMacro(Value, double);
140
142 itkGetConstReferenceMacro(StopCondition, StopConditionType); // NOTE(review): StopConditionType and its backing member are not visible in this listing
143
145 itkGetConstReferenceMacro(Gradient, DerivativeType);
146
148 itkGetConstReferenceMacro(SearchDir, DerivativeType); // NOTE(review): m_SearchDir is declared as ParametersType, not DerivativeType -- confirm the two types are compatible
149
151 itkSetMacro(PreviousPosition, ParametersType);
152
154 itkGetConstReferenceMacro(PreviousPosition, ParametersType);
155
157 itkSetMacro(PreviousGradient, DerivativeType);
158
160 itkGetConstReferenceMacro(PreviousGradient, DerivativeType);
161
// Forwards the requested count to the threader held in m_Threader by calling
// its SetNumberOfWorkUnits(); nothing is stored in this class itself.
// NOTE(review): the parameter is still named numberOfThreads although it is
// passed on as a number of work units.
163 void
164 SetNumberOfWorkUnits(ThreadIdType numberOfThreads)
165 {
166 this->m_Threader->SetNumberOfWorkUnits(numberOfThreads);
167 }
// The getter on the next line is commented out in the original; no
// NumberOfThreads getter is declared in this listing.
168 // itkGetConstReferenceMacro( NumberOfThreads, ThreadIdType );
// Setter declarations for the boolean flags m_UseMultiThread, m_UseOpenMP and
// m_UseEigen (all declared in the private section with a default of false).
// No matching getters are declared in this listing.
169 itkSetMacro(UseMultiThread, bool);
170
171 itkSetMacro(UseOpenMP, bool);
172 itkSetMacro(UseEigen, bool);
173
174protected:
177 void
178 PrintSelf(std::ostream & os, Indent indent) const override;
179
181 using ThreaderType = itk::PlatformMultiThreader;
182 using ThreadInfoType = ThreaderType::WorkUnitInfo;
183
184 // Made protected so that subclasses can access these members.
// Members written with a brace initializer carry an in-class default; the
// others have no initializer here and rely on their type's default construction.
185 double m_Value{ 0.0 };
186 DerivativeType m_Gradient;
187 ParametersType m_SearchDir;
188 ParametersType m_PreviousSearchDir;
190 ParametersType m_MeanSearchDir;
191 double m_LearningRate{ 1.0 };
193 DerivativeType m_PreviousGradient;
194 DerivativeType m_PrePreviousGradient;
195 ParametersType m_PreviousPosition;
196 ThreaderType::Pointer m_Threader{ ThreaderType::New() }; // threader instance created with the object; SetNumberOfWorkUnits() forwards to it
197
198 bool m_Stop{ false };
199 unsigned long m_NumberOfIterations{ 100 };
201 unsigned long m_CurrentIteration{ 0 };
203 unsigned long m_LBFGSMemory{ 0 };
204
205private:
206 // multi-threaded AdvanceOneStep:
207 bool m_UseMultiThread{ false };
209 {
210 ParametersType * t_NewPosition;
212 };
213
214 bool m_UseOpenMP{ false };
215 bool m_UseEigen{ false };
216
218 static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
220
222 inline void
223 ThreadedAdvanceOneStep(ThreadIdType threadId, ParametersType & newPosition);
224};
225
226} // end namespace itk
227
228
229#endif
A cost function that applies a scaling to another cost function.
static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION AdvanceOneStepThreaderCallback(void *arg)
void PrintSelf(std::ostream &os, Indent indent) const override
virtual void MetricErrorResponse(ExceptionObject &err)
~StochasticGradientDescentOptimizer() override=default
ITK_DISALLOW_COPY_AND_MOVE(StochasticGradientDescentOptimizer)
void ThreadedAdvanceOneStep(ThreadIdType threadId, ParametersType &newPosition)


Generated on 2023-01-13 for elastix by doxygen 1.9.6