Skip to content

Commit f29344d

Browse files
committed
Feedback from Andrew.
1 parent 70debb1 commit f29344d

3 files changed

Lines changed: 30 additions & 19 deletions

File tree

src/LegalizeVectors.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "Util.h"
99

1010
#include <optional>
11+
#include <unordered_set>
12+
#include <vector>
1113

1214
namespace Halide {
1315
namespace Internal {
@@ -16,13 +18,19 @@ namespace {
1618

1719
using namespace std;
1820

19-
const char *legalization_error_guide = "\n(This issue can most likely be resolved by reducing lane count for vectorize() calls in the schedule, or disabling it.)";
21+
const char *legalization_error_guide = "\n"
22+
"(This is an implemenation limitation in Halide right now. This issue can most likely be \n"
23+
" worked around by reducing lane count for vectorize() calls in GPU schedules, or disabling it.)";
2024

2125
int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) {
26+
// The environment variable below (HL_FORCE_VECTOR_LEGALIZATION) is here solely for testing purposes.
27+
// It is useful to "stress-test" this lowering pass by forcing a shorter maximal vector size across
28+
// all codegen across the entire test suite. This should not be used in real uses of Halide.
2229
std::string envvar = Halide::Internal::get_env_variable("HL_FORCE_VECTOR_LEGALIZATION");
2330
if (!envvar.empty()) {
2431
return std::atoi(envvar.c_str());
2532
}
33+
// The remainder of this function correctly determines the number of lanes the device API supports.
2634
switch (api) {
2735
case DeviceAPI::Metal:
2836
case DeviceAPI::WebGPU:
@@ -53,13 +61,13 @@ std::string vec_name(const string &name, int lane_start, int lane_count) {
5361
class LiftLetToLetStmt : public IRMutator {
5462
using IRMutator::visit;
5563

64+
unordered_set<string> lifted_let_names;
5665
vector<const Let *> lets;
5766
Expr visit(const Let *op) override {
58-
for (const Let *existing : lets) {
59-
internal_assert(existing->name != op->name)
60-
<< "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique.";
61-
}
67+
internal_assert(lifted_let_names.count(op->name) == 0)
68+
<< "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique.";
6269
lets.push_back(op);
70+
lifted_let_names.insert(op->name);
6371
return mutate(op->body);
6472
}
6573

@@ -124,8 +132,7 @@ class ExtractLanes : public IRMutator {
124132
return result;
125133
}
126134

127-
internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide << "\n"
128-
<< "Please report this error on GitHub." << legalization_error_guide;
135+
internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide;
129136
return Expr(0);
130137
}
131138

@@ -332,7 +339,7 @@ class LiftExceedingVectors : public IRMutator {
332339
just_in_let_definition = false;
333340
Stmt mutated = IRMutator::mutate(s);
334341
for (auto &let : reverse_view(lets)) {
335-
// There is no recurse into let.second. This is handled by repeatedly calling this tranform.
342+
// There is no recurse into let.second. This is handled by repeatedly calling this transform.
336343
mutated = LetStmt::make(let.first, let.second, mutated);
337344
}
338345
return mutated;
@@ -576,17 +583,12 @@ Stmt legalize_vectors_in_device_loop(const For *op) {
576583
}
577584

578585
Stmt legalize_vectors(const Stmt &s) {
579-
class LegalizeDeviceLoops : public IRMutator {
580-
using IRMutator::visit;
581-
Stmt visit(const For *op) override {
582-
if (max_lanes_for_device(op->device_api, 0)) {
583-
return legalize_vectors_in_device_loop(op);
584-
} else {
585-
return IRMutator::visit(op);
586-
}
586+
return mutate_with(s, [&](auto *self, const For *op) {
587+
if (max_lanes_for_device(op->device_api, 0)) {
588+
return legalize_vectors_in_device_loop(op);
587589
}
588-
} mutator;
589-
return mutator.mutate(s);
590+
return self->visit_base(op);
591+
});
590592
}
591593
} // namespace Internal
592594
} // namespace Halide

src/Simplify_Let.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) {
187187
// pure operations from _all_ arguments to the Shuffle, we will
188188
// instead substitute all of the vars that go in the shuffle, and
189189
// instead guard against side effects by checking with `is_pure()`.
190+
//
191+
// Also, it is safe to substitute in without combinatorial
192+
// blow-up, because deeply nested concats implies a
193+
// combinatorially-large number of vector lanes, which we can't
194+
// express in the type system anyway.
190195
replacement = substitute(f.new_name, shuffle, replacement);
191196
f.new_value = Expr();
192197
break;

src/Simplify_Shuffle.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
191191
}
192192
} else {
193193
// We can't... Leave it as a Shuffle of Loads.
194-
// Note: don't proceed down.
194+
// Note: no mutate-recursion as we are dealing here with a
195+
// Shuffle of Loads, which have already undergone mutation
196+
// early in this function (new_vectors).
195197
return result;
196198
}
197199
}
@@ -362,13 +364,15 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
362364
}
363365
}
364366

367+
#if 0 // Not sure what this was for. Disabling for now, and will run tests to see what's up.
365368
for (size_t i = 0; i < new_vectors.size() && can_collapse; i++) {
366369
if (new_vectors[i].as<Load>()) {
367370
// Don't create a Ramp of a Load, like:
368371
// ramp(buf[x], buf[x + 1] - buf[x], ...)
369372
can_collapse = false;
370373
}
371374
}
375+
#endif
372376

373377
if (can_collapse) {
374378
return Ramp::make(new_vectors[0], stride, op->indices.size());

0 commit comments

Comments
 (0)