Skip to content

Commit 2f47541

Browse files
committed
revert pattern
When learnPattern is set, use learnPattern. inferPattern is set, use inferPattern
1 parent 27aec54 commit 2f47541

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

src/htm/regions/ClassifierRegion.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ namespace htm {
7979
inputs: {
8080
bucket: { description: "The quantized value of the current sample, one from each encoder if more than one, for the learn step",
8181
type: Real64, count: 0},
82+
pattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM",
83+
type: SDR, count: 0},
8284
inferPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: predictiveCells from TM",
8385
type: SDR, count: 0},
8486
learnPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM",
@@ -140,8 +142,8 @@ Dimensions ClassifierRegion::askImplForOutputDimensions(const std::string &name)
140142

141143

142144
void ClassifierRegion::compute() {
145+
SDR &pattern = getInput("pattern")->getData().getSDR();
143146
if (learn_) {
144-
SDR &learnPattern = getInput("learnPattern")->getData().getSDR();
145147
Array &b = getInput("bucket")->getData();
146148
// 'bucket' is a list of quantized samples being processed for this iteration.
147149
// There are one of these for each encoder (or value being encoded).
@@ -165,14 +167,25 @@ void ClassifierRegion::compute() {
165167
}
166168
categoryIdxList.push_back(c);
167169
}
168-
classifier_->learn(learnPattern, categoryIdxList);
170+
171+
SDR &learnPattern = getInput("learnPattern")->getData().getSDR();
172+
if (learnPattern.size == 0) {
173+
classifier_->learn(pattern, categoryIdxList);
174+
} else {
175+
classifier_->learn(learnPattern, categoryIdxList);
176+
}
169177
}
170178

171179
SDR &inferPattern = getInput("inferPattern")->getData().getSDR();
172180
// Note: if there is no link to 'inferPattern' input, the 'inferPattern' SDR length is 0
173181
// and SDRClassifier::infer() will throw an exception.
174-
175-
PDF pdf = classifier_->infer(inferPattern);
182+
//
183+
PDF pdf;
184+
if (inferPattern.size == 0) {
185+
pdf = classifier_->infer(pattern);
186+
} else {
187+
pdf = classifier_->infer(inferPattern);
188+
}
176189

177190
// Adjust the buffer size to match the pdf.
178191
if (getOutput("pdf")->getData().getCount() < pdf.size()) {

0 commit comments

Comments
 (0)