Halide  13.0.4
Halide compiler and libraries
simd_op_check.h
Go to the documentation of this file.
1 #ifndef SIMD_OP_CHECK_H
2 #define SIMD_OP_CHECK_H
3 
4 #include "Halide.h"
5 #include "halide_test_dirs.h"
6 
7 #include <fstream>
8 
9 namespace Halide {
// Outcome of one op check, as returned by SimdOpCheckTest::check_one().
struct TestResult {
    // The op pattern the check was run for.
    std::string op;
    // Accumulated failure text. test_all() treats an empty string as a pass.
    std::string error_msg;
};
14 
15 struct Task {
16  std::string op;
17  std::string name;
20 };
21 
// Public state of the test harness class.
// NOTE(review): the listing's own numbering skips 22, 25 and 29 around here, so
// the class header and two member declarations are not visible in this view.
// `output_directory` and `target` are used by the methods below but are not
// declared in the lines shown.
23 public:
// Wildcard pattern; check() drops any op for which wildcard_match(filter, op) is false.
24  std::string filter{"*"};
// Work list appended to by check() and consumed by test_all().
26  std::vector<Task> tasks;
// Random source seeded via set_seed() and used by check_one() to fill inputs with noise.
27  std::mt19937 rng;
28 
30 
// One 1-dimensional input image per element type.
31  ImageParam in_f32{Float(32), 1, "in_f32"};
32  ImageParam in_f64{Float(64), 1, "in_f64"};
33  ImageParam in_f16{Float(16), 1, "in_f16"};
34  ImageParam in_bf16{BFloat(16), 1, "in_bf16"};
35  ImageParam in_i8{Int(8), 1, "in_i8"};
36  ImageParam in_u8{UInt(8), 1, "in_u8"};
37  ImageParam in_i16{Int(16), 1, "in_i16"};
38  ImageParam in_u16{UInt(16), 1, "in_u16"};
39  ImageParam in_i32{Int(32), 1, "in_i32"};
40  ImageParam in_u32{UInt(32), 1, "in_u32"};
41  ImageParam in_i64{Int(64), 1, "in_i64"};
42  ImageParam in_u64{UInt(64), 1, "in_u64"};
43 
// The same twelve inputs, once as ImageParams (iterated by setup_images() and
// check_one()) and once as Arguments (passed to compile_to / compile_to_assembly).
44  const std::vector<ImageParam> image_params{in_f32, in_f64, in_f16, in_bf16, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
45  const std::vector<Argument> arg_types{in_f32, in_f64, in_f16, in_bf16, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
// Extents of the comparison domain in check_one(): RDom(0, W, 0, H); W is also the bound on x.
46  int W;
47  int H;
48 
// Stores the target to generate code for and the W x H extents used when
// comparing the vector and scalar results.
// NOTE(review): the listing's numbering jumps from 51 to 57, so the
// `target = target` statement below is cut off and the rest of the constructor
// body is not visible here.
49  SimdOpCheckTest(const Target t, int w, int h)
50  : target(t), W(w), H(h) {
51  target = target
57  }
// Defaulted virtual destructor; the class declares virtual methods meant to be overridden.
58  virtual ~SimdOpCheckTest() = default;
59 
60  void set_seed(int seed) {
61  rng.seed(seed);
62  }
63 
64  size_t get_num_threads() const {
65  return num_threads;
66  }
67 
68  void set_num_threads(size_t n) {
69  num_threads = n;
70  }
71 
// Reports whether code compiled for `target` can also be executed on this
// machine: arch, bits and os must equal the host target's, and a set of
// feature flags must agree between the two.
// NOTE(review): the listing's numbering skips 75 and 86-91, so the condition
// guarding the early `return true` and the header of the loop over features
// `f` are missing from this view; the braces below do not balance as shown.
72  virtual bool can_run_code() const {
73  // Assume we are configured to run wasm if requested
74  // (we'll fail further downstream if not)
76  return true;
77  }
78  // If we can (target matches host), run the error checking Halide::Func.
79  Target host_target = get_host_target();
80  bool can_run_the_code =
81  (target.arch == host_target.arch &&
82  target.bits == host_target.bits &&
83  target.os == host_target.os);
84  // A bunch of feature flags also need to match between the
85  // compiled code and the host in order to run the code.
// Any single feature that differs between target and host disables running.
92  if (target.has_feature(f) != host_target.has_feature(f)) {
93  can_run_the_code = false;
94  }
95  }
96  return can_run_the_code;
97  }
98 
99  virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, std::ostringstream &error_msg) {
100  std::string fn_name = "test_" + name;
101  std::string file_name = output_directory + fn_name;
102 
103  auto ext = Internal::get_output_info(target);
104  std::map<Output, std::string> outputs = {
105  {Output::c_header, file_name + ext.at(Output::c_header).extension},
106  {Output::object, file_name + ext.at(Output::object).extension},
107  {Output::assembly, file_name + ".s"},
108  };
109  error.compile_to(outputs, arg_types, fn_name, target);
110 
111  std::ifstream asm_file;
112  asm_file.open(file_name + ".s");
113 
114  bool found_it = false;
115 
116  std::ostringstream msg;
117  msg << op << " did not generate for target=" << target.to_string() << " vector_width=" << vector_width << ". Instead we got:\n";
118 
119  std::string line;
120  while (getline(asm_file, line)) {
121  msg << line << "\n";
122 
123  // Check for the op in question
124  found_it |= wildcard_search(op, line) && !wildcard_search("_" + op, line);
125  }
126 
127  if (!found_it) {
128  error_msg << "Failed: " << msg.str() << "\n";
129  }
130 
131  asm_file.close();
132  }
133 
134  // Check if pattern p matches str, allowing for wildcards (*).
135  bool wildcard_match(const char *p, const char *str) const {
136  // Match all non-wildcard characters.
137  while (*p && *str && *p == *str && *p != '*') {
138  str++;
139  p++;
140  }
141 
142  if (!*p) {
143  return *str == 0;
144  } else if (*p == '*') {
145  p++;
146  do {
147  if (wildcard_match(p, str)) {
148  return true;
149  }
150  } while (*str++);
151  } else if (*p == ' ') { // ignore whitespace in pattern
152  p++;
153  if (wildcard_match(p, str)) {
154  return true;
155  }
156  } else if (*str == ' ') { // ignore whitespace in string
157  str++;
158  if (wildcard_match(p, str)) {
159  return true;
160  }
161  }
162  return !*p;
163  }
164 
165  bool wildcard_match(const std::string &p, const std::string &str) const {
166  return wildcard_match(p.c_str(), str.c_str());
167  }
168 
169  // Check if a substring of str matches a pattern p.
170  bool wildcard_search(const std::string &p, const std::string &str) const {
171  return wildcard_match("*" + p + "*", str);
172  }
173 
// Runs one op check and returns {op, accumulated error text}.
//
// Builds a vectorized Func and a scalar Func from the same expression `e`,
// defines `error` as the maximum absolute difference between them over the
// W x H domain, calls setup_images() and compile_and_check(), and, when
// can_run_code() says so, fills the inputs with noise, realizes `error`
// and reports a disagreement if the result exceeds 0.001.
//
// NOTE(review): test_all() invokes this from thread-pool jobs, and it touches
// members shared across jobs (rng, the image params via setup_images()) --
// confirm this is safe when num_threads > 1.
// NOTE(review): the listing's numbering skips 178 and 235-237, so one line of
// the visitor class and the rest of the `run_target` initialiser are missing
// from this view.
174  TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) {
175  std::ostringstream error_msg;
176 
// Visitor that records the Function behind any call to a Halide Func that has
// an update definition (i.e. an inline reduction inside `e`). If several are
// present, the last one visited wins.
177  class HasInlineReduction : public Internal::IRVisitor {
179  void visit(const Internal::Call *op) override {
180  if (op->call_type == Internal::Call::Halide) {
181  Internal::Function f(op->func);
182  if (f.has_update_definition()) {
183  inline_reduction = f;
184  result = true;
185  }
186  }
// Keep walking into the call's children.
187  IRVisitor::visit(op);
188  }
189 
190  public:
191  Internal::Function inline_reduction;
192  bool result = false;
193  } has_inline_reduction;
194  e.accept(&has_inline_reduction);
195 
196  // Define a vectorized Halide::Func that uses the pattern.
197  Halide::Func f(name);
198  f(x, y) = e;
199  f.bound(x, 0, W).vectorize(x, vector_width);
200  f.compute_root();
201 
202  // Include a scalar version
203  Halide::Func f_scalar("scalar_" + name);
204  f_scalar(x, y) = e;
205 
206  if (has_inline_reduction.result) {
207  // If there's an inline reduction, we want to vectorize it
208  // over the RVar.
209  Var xo, xi;
// rxi is declared but not referenced anywhere in the lines shown.
210  RVar rxi;
211  Func g{has_inline_reduction.inline_reduction};
212 
213  // Do the reduction separately in f_scalar
214  g.clone_in(f_scalar);
215 
// Schedule the reduction inside f's x loop; its update stage is split by
// vector_width, marked atomic, and vectorized over both its first RVar and xi.
216  g.compute_at(f, x)
217  .update()
218  .split(x, xo, xi, vector_width)
219  .atomic(true)
220  .vectorize(g.rvars()[0])
221  .vectorize(xi);
222  }
223 
224  // The output to the pipeline is the maximum absolute difference as a double.
225  RDom r_check(0, W, 0, H);
226  Halide::Func error("error_" + name);
227  error() = Halide::cast<double>(maximum(absd(f(r_check.x, r_check.y), f_scalar(r_check.x, r_check.y))));
228 
// Reset the inputs, then compile and look for `op` in the assembly; any
// failure text lands in error_msg.
229  setup_images();
230  compile_and_check(error, op, name, vector_width, error_msg);
231 
232  bool can_run_the_code = can_run_code();
233  if (can_run_the_code) {
234  Target run_target = target
238 
239  error.infer_input_bounds({}, run_target);
240  // Fill the inputs with noise
241  for (auto p : image_params) {
242  Halide::Buffer<> buf = p.get();
// Inputs without a defined buffer after bounds inference are skipped.
243  if (!buf.defined()) continue;
244  assert(buf.data());
245  Type t = buf.type();
246  // For floats/doubles, we only use values that aren't
247  // subject to rounding error that may differ between
248  // vectorized and non-vectorized versions
249  if (t == Float(32)) {
// 12 random bits scaled by 1/8 and shifted by 255: multiples of 0.125 in [-255, 256.875].
250  buf.as<float>().for_each_value([&](float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; });
251  } else if (t == Float(64)) {
252  buf.as<double>().for_each_value([&](double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; });
253  } else if (t == Float(16)) {
// 8 random bits scaled by 1/8 and shifted by 15: multiples of 0.125 in [-15, 16.875].
254  buf.as<float16_t>().for_each_value([&](float16_t &f) { f = float16_t((rng() & 0xff) / 8.0f - 0xf); });
255  } else {
256  // Random bits is fine
// The buffer is written one 32-bit word at a time regardless of element type.
// NOTE(review): size_in_bytes() / 4 truncates, so trailing bytes of a buffer
// whose size is not a multiple of 4 are not written by this loop.
257  for (uint32_t *ptr = (uint32_t *)buf.data();
258  ptr != (uint32_t *)buf.data() + buf.size_in_bytes() / 4;
259  ptr++) {
260  // Never use the top four bits, to avoid
261  // signed integer overflow.
262  *ptr = ((uint32_t)rng()) & 0x0fffffff;
263  }
264  }
265  }
266  Realization r = error.realize();
// This local `e` (the measured error) shadows the Expr parameter `e`.
267  double e = Buffer<double>(r[0])();
268  // Use a very loose tolerance for floating point tests. The
269  // kinds of bugs we're looking for are codegen bugs that
270  // return the wrong value entirely, not floating point
271  // accuracy differences between vectors and scalars.
272  if (e > 0.001) {
273  error_msg << "The vector and scalar versions of " << name << " disagree. Maximum error: " << e << "\n";
274 
// Dump the assembly of the error pipeline into the report to aid debugging.
275  std::string error_filename = output_directory + "error_" + name + ".s";
276  error.compile_to_assembly(error_filename, arg_types, target);
277 
278  std::ifstream error_file;
279  error_file.open(error_filename);
280 
281  error_msg << "Error assembly: \n";
282  std::string line;
283  while (getline(error_file, line)) {
284  error_msg << line << "\n";
285  }
286 
287  error_file.close();
288  }
289  }
290 
// An empty error string means the check passed (see test_all()).
291  return {op, error_msg.str()};
292  }
293 
294  void check(std::string op, int vector_width, Expr e) {
295  // Make a name for the test by uniquing then sanitizing the op name
296  std::string name = "op_" + op;
297  for (size_t i = 0; i < name.size(); i++) {
298  if (!isalnum(name[i])) name[i] = '_';
299  }
300 
301  name += "_" + std::to_string(tasks.size());
302 
303  // Bail out after generating the unique_name, so that names are
304  // unique across different processes and don't depend on filter
305  // settings.
306  if (!wildcard_match(filter, op)) return;
307 
308  tasks.emplace_back(Task{op, name, vector_width, e});
309  }
// Hook that subclasses must implement; test_all() calls it first, before it
// schedules the entries of `tasks`. Presumably implementations register their
// cases through check() -- confirm against the subclasses.
310  virtual void add_tests() = 0;
311  virtual void setup_images() {
312  for (auto p : image_params) {
313  p.reset();
314 
315  const int alignment_bytes = 16;
316  p.set_host_alignment(alignment_bytes);
317  const int alignment = alignment_bytes / p.type().bytes();
318  p.dim(0).set_min((p.dim(0).min() / alignment) * alignment);
319  }
320  }
321  virtual bool test_all() {
322  /* First add some tests based on the target */
323  add_tests();
324  Internal::ThreadPool<TestResult> pool(num_threads);
325  std::vector<std::future<TestResult>> futures;
326  for (const Task &task : tasks) {
327  futures.push_back(pool.async([this, task]() {
328  return check_one(task.op, task.name, task.vector_width, task.expr);
329  }));
330  }
331 
332  bool success = true;
333  for (auto &f : futures) {
334  const TestResult &result = f.get();
335  std::cout << result.op << "\n";
336  if (!result.error_msg.empty()) {
337  std::cerr << result.error_msg;
338  success = false;
339  }
340  }
341 
342  return success;
343  }
344 
345 private:
// Worker count handed to the thread pool in test_all(); read and written
// through get_num_threads() / set_num_threads().
346  size_t num_threads;
// Pure variables shared by the Funcs that check_one() defines.
347  const Halide::Var x{"x"}, y{"y"};
348 };
349 } // namespace Halide
350 #endif // SIMD_OP_CHECK_H
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Buffer.h:115
A halide function.
Definition: Func.h:698
void compile_to_assembly(const std::string &filename, const std::vector< Argument > &, const std::string &fn_name, const Target &target=get_target_from_environment())
Statically compile this function to text assembly equivalent to the object file generated by compile_...
Func & compute_root()
Compute all of this function once ahead of time.
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
void infer_input_bounds(const std::vector< int32_t > &sizes, const Target &target=get_jit_target_from_environment(), const ParamMap &param_map=ParamMap::empty_map())
For a given size of output, or a given output buffer, determine the bounds required of all unbound Im...
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
Realization realize(std::vector< int32_t > sizes={}, const Target &target=Target(), const ParamMap &param_map=ParamMap::empty_map())
Evaluate this function over some rectangular domain and return the resulting buffer or buffers.
Func & bound(const Var &var, Expr min, Expr extent)
Statically declare that the range over which a function should be evaluated is given by the second an...
void compile_to(const std::map< Output, std::string > &output_files, const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment())
Compile and generate multiple target files with single call.
Func clone_in(const Func &f)
Similar to Func::in; however, instead of replacing the call to this Func with an identity Func that r...
An Image parameter to a halide pipeline.
Definition: ImageParam.h:23
A reference-counted handle to Halide's internal representation of a function.
Definition: Function.h:38
bool has_update_definition() const
Does this function have an update definition?
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
virtual void visit(const IntImm *)
std::future< T > async(Func func, Args... args)
Definition: ThreadPool.h:117
static size_t num_processors_online()
Definition: ThreadPool.h:79
A multi-dimensional domain over which to iterate.
Definition: RDom.h:193
RVar x
Direct access to the first four dimensions of the reduction domain.
Definition: RDom.h:337
RVar y
Definition: RDom.h:337
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition: RDom.h:29
A Realization is a vector of references to existing Buffer objects.
Definition: Realization.h:21
size_t get_num_threads() const
Definition: simd_op_check.h:64
virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, std::ostringstream &error_msg)
Definition: simd_op_check.h:99
const std::vector< Argument > arg_types
Definition: simd_op_check.h:45
std::string output_directory
Definition: simd_op_check.h:25
virtual void setup_images()
void set_seed(int seed)
Definition: simd_op_check.h:60
virtual void add_tests()=0
bool wildcard_match(const std::string &p, const std::string &str) const
virtual bool test_all()
bool wildcard_search(const std::string &p, const std::string &str) const
bool wildcard_match(const char *p, const char *str) const
virtual ~SimdOpCheckTest()=default
SimdOpCheckTest(const Target t, int w, int h)
Definition: simd_op_check.h:49
void check(std::string op, int vector_width, Expr e)
const std::vector< ImageParam > image_params
Definition: simd_op_check.h:44
TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e)
void set_num_threads(size_t n)
Definition: simd_op_check.h:68
virtual bool can_run_code() const
Definition: simd_op_check.h:72
std::vector< Task > tasks
Definition: simd_op_check.h:26
A Halide variable, to be used when defining functions.
Definition: Var.h:19
std::map< Output, const OutputInfo > get_output_info(const Target &target)
std::string get_test_tmp_dir()
Return the path to a directory that can be safely written to when running tests; the contents directo...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Target get_host_target()
Return the target corresponding to the host machine.
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition: Type.h:509
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition: Type.h:499
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition: Type.h:504
Expr maximum(Expr, const std::string &s="maximum")
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition: Type.h:494
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
char * buf
Definition: printer.h:32
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:256
A function call.
Definition: IR.h:466
@ Halide
A call to a Func.
Definition: IR.h:473
FunctionPtr func
Definition: IR.h:599
CallType call_type
Definition: IR.h:477
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:190
A struct representing a target machine and os to generate code for.
Definition: Target.h:19
enum Halide::Target::Arch arch
bool has_feature(Feature f) const
int bits
The bit-width of the target machine.
Definition: Target.h:51
enum Halide::Target::OS os
std::string to_string() const
Convert the Target into a string form that can be reconstituted by merge_string(),...
Target without_feature(Feature f) const
Return a copy of the target with the given feature cleared.
Feature
Optional features a target can have.
Definition: Target.h:57
@ NoBoundsQuery
Definition: Target.h:61
@ DisableLLVMLoopOpt
Definition: Target.h:120
@ POWER_ARCH_2_07
Definition: Target.h:71
Target with_feature(Feature f) const
Return a copy of the target with the given feature set.
std::string op
Definition: simd_op_check.h:16
std::string name
Definition: simd_op_check.h:17
std::string error_msg
Definition: simd_op_check.h:12
Types in the halide type system.
Definition: Type.h:265
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition: Float16.h:17