Vowpal Wabbit
primitives.cc
Go to the documentation of this file.
1 #include "primitives.h"
2 #include <intrin.h>
3 #include <array>
4 #include <bitset>
5 
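// NOTE(review): incomplete sketch of runtime CPU-feature dispatch; it does not compile as written
// (missing ';' after f_7_EBX_, nIds_ and data_ are never declared, and nothing is returned).
// The AVX2 flag it refers to lives in CPUID leaf 7 (EBX bit 5), but only leaves 0 and 1 are queried.
6 // sum_of_squares_func get_sum_of_squares()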
7 //{
8 // std::array<int, 4> cpui;
9 //
10 // // Calling __cpuid with 0x0 as the function_id argument
11 // // gets the number of the highest valid function ID.
12 // __cpuid(cpui.data(), 0);
13 // __cpuidex(cpui.data(), 1, 0);
14 //
15 // std::bitset<32> f_1_ECX_ = cpui[2];
16 // std::bitset<32> f_7_EBX_
17 // if (nIds_ >= 7)
18 // {
19 // f_7_EBX_ = data_[7][1];
20 //
21 // // static bool AVX2(void) { return CPU_Rep.f_7_EBX_[5]; }
22 // // static bool AVX(void) { return CPU_Rep.f_1_ECX_[28]; }
23 //}
24 
/// Sum of the squares of the floats in the half-open range [begin, end).
///
/// Elements are visited in order from `begin` towards `end`, and the running
/// total is kept in single precision. An empty range (begin == end) yields 0.
///
/// @param begin pointer to the first element of the range
/// @param end   pointer one past the last element; must be reachable from `begin`
/// @return      the float sum of x * x over every x in the range
float sum_of_squares(float* begin, float* end)
{
  float total = 0.f;

  float* cursor = begin;
  while (cursor != end)
  {
    const float value = *cursor;
    total += value * value;
    ++cursor;
  }

  return total;
}
33 
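34 /*
 NOTE(review): the two SIMD sketches below are non-functional and are kept for reference only:
 - the return values of _mm_fnmadd_ps/_mm256_fnmadd_ps and _mm_hadd_ps/_mm256_hadd_ps are
   discarded, so dest is never updated and stays zero;
 - fnmadd computes -(a*b)+c, which would negate the accumulated sum; fmadd is what is wanted
   (and both require the FMA extension, not just AVX/AVX2);
 - sum_of_squares_avx actually uses 128-bit SSE types (__m128), not 256-bit AVX;
 - _mm256_hadd_ps adds pairs within each 128-bit lane only, so repeated hadds do not produce
   the total of all eight lanes;
 - m128_f32/m256_f32 are MSVC-specific members; _mm_cvtss_f32 is the portable extraction.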
35 float sum_of_squares_avx(float* begin, float* end)
36 {
37  size_t length = (end - begin);
38  size_t remainder = length % (sizeof(__m128) / sizeof(float));
39 
40  float* sseEnd = end - remainder;
41  __m128 source;
42  __m128 dest = _mm_setzero_ps();
43 
44  for (; begin != sseEnd; begin += (sizeof(__m128) / sizeof(float)))
45  {
46  source = _mm_loadu_ps(begin);
47  _mm_fnmadd_ps(source, source, dest);
48  }
49 
50  _mm_hadd_ps(dest, dest);
51  _mm_hadd_ps(dest, dest);
52  _mm_hadd_ps(dest, dest);
53 
54  float sum = dest.m128_f32[0];
55 
56  for (; begin != end; begin++)
57  sum += *begin * *begin;
58 
59  return sum;
60 }
61 
62 float sum_of_squares_avx2(float* begin, float* end)
63 {
64  size_t length = (end - begin);
65  size_t remainder = length % (sizeof(__m256) / sizeof(float));
66 
67  float* sseEnd = end - remainder;
68  __m256 source;
69  __m256 dest = _mm256_setzero_ps();
70 
71  for (; begin != sseEnd; begin += (sizeof(__m256) / sizeof(float)))
72  {
73  source = _mm256_loadu_ps(begin);
74  _mm256_fnmadd_ps(source, source, dest);
75  }
76 
77  _mm256_hadd_ps(dest, dest);
78  _mm256_hadd_ps(dest, dest);
79  _mm256_hadd_ps(dest, dest);
80 
81  float sum = dest.m256_f32[0];
82 
83  for (; begin != end; begin++)
84  sum += *begin * *begin;
85 
86  return sum;
87 }
88 */
float sum_of_squares(float *begin, float *end)
Definition: primitives.cc:25