Loading...
Note: File does not exist in v4.6.
1// SPDX-License-Identifier: GPL-2.0
2#include <kunit/test.h>
3
4#include "mean_and_variance.h"
5
6#define MAX_SQR (SQRT_U64_MAX*SQRT_U64_MAX)
7
8static void mean_and_variance_basic_test(struct kunit *test)
9{
10 struct mean_and_variance s = {};
11
12 mean_and_variance_update(&s, 2);
13 mean_and_variance_update(&s, 2);
14
15 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 2);
16 KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 0);
17 KUNIT_EXPECT_EQ(test, s.n, 2);
18
19 mean_and_variance_update(&s, 4);
20 mean_and_variance_update(&s, 4);
21
22 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 3);
23 KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 1);
24 KUNIT_EXPECT_EQ(test, s.n, 4);
25}
26
27/*
28 * Test values computed using a spreadsheet from the psuedocode at the bottom:
29 * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
30 */
31
32static void mean_and_variance_weighted_test(struct kunit *test)
33{
34 struct mean_and_variance_weighted s = { .weight = 2 };
35
36 mean_and_variance_weighted_update(&s, 10);
37 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 10);
38 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
39
40 mean_and_variance_weighted_update(&s, 20);
41 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 12);
42 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
43
44 mean_and_variance_weighted_update(&s, 30);
45 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 16);
46 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
47
48 s = (struct mean_and_variance_weighted) { .weight = 2 };
49
50 mean_and_variance_weighted_update(&s, -10);
51 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -10);
52 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
53
54 mean_and_variance_weighted_update(&s, -20);
55 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -12);
56 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
57
58 mean_and_variance_weighted_update(&s, -30);
59 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -16);
60 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
61}
62
63static void mean_and_variance_weighted_advanced_test(struct kunit *test)
64{
65 struct mean_and_variance_weighted s = { .weight = 8 };
66 s64 i;
67
68 for (i = 10; i <= 100; i += 10)
69 mean_and_variance_weighted_update(&s, i);
70
71 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 11);
72 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
73
74 s = (struct mean_and_variance_weighted) { .weight = 8 };
75
76 for (i = -10; i >= -100; i -= 10)
77 mean_and_variance_weighted_update(&s, i);
78
79 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -11);
80 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
81}
82
83static void do_mean_and_variance_test(struct kunit *test,
84 s64 initial_value,
85 s64 initial_n,
86 s64 n,
87 unsigned weight,
88 s64 *data,
89 s64 *mean,
90 s64 *stddev,
91 s64 *weighted_mean,
92 s64 *weighted_stddev)
93{
94 struct mean_and_variance mv = {};
95 struct mean_and_variance_weighted vw = { .weight = weight };
96
97 for (unsigned i = 0; i < initial_n; i++) {
98 mean_and_variance_update(&mv, initial_value);
99 mean_and_variance_weighted_update(&vw, initial_value);
100
101 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), initial_value);
102 KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), 0);
103 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw), initial_value);
104 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw),0);
105 }
106
107 for (unsigned i = 0; i < n; i++) {
108 mean_and_variance_update(&mv, data[i]);
109 mean_and_variance_weighted_update(&vw, data[i]);
110
111 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), mean[i]);
112 KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), stddev[i]);
113 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw), weighted_mean[i]);
114 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw),weighted_stddev[i]);
115 }
116
117 KUNIT_EXPECT_EQ(test, mv.n, initial_n + n);
118}
119
120/* Test behaviour with a single outlier, then back to steady state: */
121static void mean_and_variance_test_1(struct kunit *test)
122{
123 s64 d[] = { 100, 10, 10, 10, 10, 10, 10 };
124 s64 mean[] = { 22, 21, 20, 19, 18, 17, 16 };
125 s64 stddev[] = { 32, 29, 28, 27, 26, 25, 24 };
126 s64 weighted_mean[] = { 32, 27, 22, 19, 17, 15, 14 };
127 s64 weighted_stddev[] = { 38, 35, 31, 27, 24, 21, 18 };
128
129 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
130 d, mean, stddev, weighted_mean, weighted_stddev);
131}
132
133static void mean_and_variance_test_2(struct kunit *test)
134{
135 s64 d[] = { 100, 10, 10, 10, 10, 10, 10 };
136 s64 mean[] = { 10, 10, 10, 10, 10, 10, 10 };
137 s64 stddev[] = { 9, 9, 9, 9, 9, 9, 9 };
138 s64 weighted_mean[] = { 32, 27, 22, 19, 17, 15, 14 };
139 s64 weighted_stddev[] = { 38, 35, 31, 27, 24, 21, 18 };
140
141 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
142 d, mean, stddev, weighted_mean, weighted_stddev);
143}
144
145/* Test behaviour where we switch from one steady state to another: */
146static void mean_and_variance_test_3(struct kunit *test)
147{
148 s64 d[] = { 100, 100, 100, 100, 100 };
149 s64 mean[] = { 22, 32, 40, 46, 50 };
150 s64 stddev[] = { 32, 39, 42, 44, 45 };
151 s64 weighted_mean[] = { 32, 49, 61, 71, 78 };
152 s64 weighted_stddev[] = { 38, 44, 44, 41, 38 };
153
154 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
155 d, mean, stddev, weighted_mean, weighted_stddev);
156}
157
158static void mean_and_variance_test_4(struct kunit *test)
159{
160 s64 d[] = { 100, 100, 100, 100, 100 };
161 s64 mean[] = { 10, 11, 12, 13, 14 };
162 s64 stddev[] = { 9, 13, 15, 17, 19 };
163 s64 weighted_mean[] = { 32, 49, 61, 71, 78 };
164 s64 weighted_stddev[] = { 38, 44, 44, 41, 38 };
165
166 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
167 d, mean, stddev, weighted_mean, weighted_stddev);
168}
169
170static void mean_and_variance_fast_divpow2(struct kunit *test)
171{
172 s64 i;
173 u8 d;
174
175 for (i = 0; i < 100; i++) {
176 d = 0;
177 KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d));
178 KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d));
179 for (d = 1; d < 32; d++) {
180 KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)),
181 div_u64(i, 1 << d), "%lld %u", i, d);
182 KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)),
183 div_u64(i, 1 << d), "%lld %u", -i, d);
184 }
185 }
186}
187
188static void mean_and_variance_u128_basic_test(struct kunit *test)
189{
190 u128_u a = u64s_to_u128(0, U64_MAX);
191 u128_u a1 = u64s_to_u128(0, 1);
192 u128_u b = u64s_to_u128(1, 0);
193 u128_u c = u64s_to_u128(0, 1LLU << 63);
194 u128_u c2 = u64s_to_u128(U64_MAX, U64_MAX);
195
196 KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a, a1)), 1);
197 KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a, a1)), 0);
198 KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a1, a)), 1);
199 KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a1, a)), 0);
200
201 KUNIT_EXPECT_EQ(test, u128_lo(u128_sub(b, a1)), U64_MAX);
202 KUNIT_EXPECT_EQ(test, u128_hi(u128_sub(b, a1)), 0);
203
204 KUNIT_EXPECT_EQ(test, u128_hi(u128_shl(c, 1)), 1);
205 KUNIT_EXPECT_EQ(test, u128_lo(u128_shl(c, 1)), 0);
206
207 KUNIT_EXPECT_EQ(test, u128_hi(u128_square(U64_MAX)), U64_MAX - 1);
208 KUNIT_EXPECT_EQ(test, u128_lo(u128_square(U64_MAX)), 1);
209
210 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(b, 2)), 1LLU << 63);
211
212 KUNIT_EXPECT_EQ(test, u128_hi(u128_div(c2, 2)), U64_MAX >> 1);
213 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(c2, 2)), U64_MAX);
214
215 KUNIT_EXPECT_EQ(test, u128_hi(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U32_MAX >> 1);
216 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U64_MAX << 31);
217}
218
219static struct kunit_case mean_and_variance_test_cases[] = {
220 KUNIT_CASE(mean_and_variance_fast_divpow2),
221 KUNIT_CASE(mean_and_variance_u128_basic_test),
222 KUNIT_CASE(mean_and_variance_basic_test),
223 KUNIT_CASE(mean_and_variance_weighted_test),
224 KUNIT_CASE(mean_and_variance_weighted_advanced_test),
225 KUNIT_CASE(mean_and_variance_test_1),
226 KUNIT_CASE(mean_and_variance_test_2),
227 KUNIT_CASE(mean_and_variance_test_3),
228 KUNIT_CASE(mean_and_variance_test_4),
229 {}
230};
231
232static struct kunit_suite mean_and_variance_test_suite = {
233 .name = "mean and variance tests",
234 .test_cases = mean_and_variance_test_cases
235};
236
237kunit_test_suite(mean_and_variance_test_suite);
238
239MODULE_AUTHOR("Daniel B. Hill");
240MODULE_LICENSE("GPL");