enumeration.h
Go to the documentation of this file.
1/*
2MIT License
3
4Copyright (c) 2016 Marc Stevens
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23*/
24
25#ifndef ENUMLIB_ENUMERATION_HPP
26#define ENUMLIB_ENUMERATION_HPP
27
28#include "fplll_types.h"
29#include <fplll/threadpool.h>
30
31#include <algorithm>
32#include <array>
33#include <atomic>
34#include <chrono>
35#include <cmath>
36#include <cstdint>
37#include <fstream>
38#include <functional>
39#include <iostream>
40#include <memory>
41#include <mutex>
42#include <sstream>
43#include <stdexcept>
44#include <string>
45#include <thread>
46#include <vector>
47
48#include <fplll/defs.h>
49
51
52namespace enumlib
53{
54
// Compile-time switches for this header.
// SINGLE_THREADED (disabled here): when defined, enumerate_recursive() compiles its
// `#ifdef SINGLE_THREADED` branch instead of the swirly / thread-pool branch.
55//#define SINGLE_THREADED
// NOCOUNTS (disabled here): when defined, the per-level `++_counts[i]` statements are compiled out.
56//#define NOCOUNTS
// NOLOCALUPDATE (enabled): compiles out the `#ifndef NOLOCALUPDATE` sections inside enumerate_recur.
57#define NOLOCALUPDATE
58
// Makes std names (array, atomic, pair, vector, sort, ...) usable unqualified inside namespace
// enumlib; the declarations below rely on it.
59using namespace std;
60
// NOTE(review): float_type is used below but not declared in the visible lines (the embedded
// numbering skips 61-62). The cross-reference index at the end of this listing gives
// `fplll_float float_type` (61) and `std::mutex mutex_type` (62) - confirm against upstream.
// Scoped lock; the visible code constructs it on globals.mutex.
63typedef std::lock_guard<std::mutex> lock_type;
64
// Atomic float_type; presumably the type of the shared bound globals_t::A - confirm.
65typedef atomic<float_type> global_A_t;
// 256 atomic flags. The visible code indexes them by threadid and sets all of them to 1 when
// the shared bound changes, so 256 is an upper limit on distinct thread ids.
66typedef array<atomic_int_fast8_t, 256> global_signal_t;
67
// State shared between the enumeration object and the per-task copies made in
// enumerate_recursive() (each copy keeps a reference to the same globals_t).
// N is the enumeration dimension.
68template <int N> struct globals_t
69{
    // One int coefficient per basis index.
70 typedef array<int, N> introw_t;
    // Buffered subtree: (coefficient vector, (partial length at the buffering level,
    // estimated partial length one level lower)). The second float is the sort key
    // used by enumerate_recursive().
71 typedef pair<introw_t, pair<float_type, float_type>> swirl_item_t;
72
    // NOTE(review): the embedded numbering skips 73-75. The code in this listing uses
    // globals.mutex, globals.A and globals.signal, and the cross-reference index lists
    // `mutex_type mutex` (73), `global_A_t A` (74) and `global_signal_t signal` (75) -
    // confirm against upstream.
76
    // Solution callback: invoked as process_sol(length, coefficients); its return value is
    // stored as the new shared bound A.
77 std::function<extenum_cb_process_sol> process_sol;
    // Sub-solution callback; not invoked anywhere in the visible code.
78 std::function<extenum_cb_process_subsol> process_subsol;
79
    // Buffers of pending subtrees, indexed by swirl id; enumerate_recursive() resizes this to 2.
80 vector<vector<swirl_item_t>> swirlys;
81};
82
84
85template <int N, int SWIRLY, int SWIRLY2BUF, int SWIRLY1FRACTION, bool findsubsols = false>
87{
// Row types: one float_type / one int per basis index.
88 typedef array<float_type, N> fltrow_t;
89 typedef array<int, N> introw_t;
// Same layout as globals_t<N>::swirl_item_t: (coefficients, (partial length, sort key)).
90 typedef pair<introw_t, pair<float_type, float_type>> swirl_item_t;
91
// NOTE(review): several member declarations are absent from this listing (the embedded
// numbering skips 93-96, 99, 102-103, 107, 109-110, 114, 116). The cross-reference index at
// the end of the listing names them as muT, risq, pr, pr2 (93-96), activeswirly (99),
// threadid (102), globals (103), _x/_Dx/_D2x (107), _c (109), _r (110), _sigT (114) and
// _subsolL (116) - confirm against upstream.
92 /* inputs */
97
98 /* config */
100
101 /* internals */
104
105 float_type _A; // overall enumeration bound
// _AA[k] = _A * pr[k] and _AA2[k] = _A * pr2[k] (see _update_AA). enumerate_recur tests the
// first candidate at a level against _AA and every later candidate against _AA2.
106 fltrow_t _AA, _AA2; // enumeration pruning bounds
// Float copy of _x handed to globals.process_sol.
108 fltrow_t _sol; // to pass to fplll
// Partial squared lengths: _l[i] = _l[i + 1] + (c_i - x_i)^2 * risq[i]; _l[N] is set to 0.
111 array<float_type, N + 1> _l;
// Per-level visit counters, incremented on entry to each level unless NOCOUNTS is defined.
112 array<uint64_t, N + 1> _counts;
113
115
// _subsol[i] holds the coefficients (indices i..N-1) recorded when a new smallest non-zero
// partial length was seen at level i; only written when findsubsols is true.
117 array<fltrow_t, N> _subsol;
118
// Set to system_clock::now() by the constructor; not read in the visible code.
119 std::chrono::system_clock::time_point starttime;
120
// Constructor: member-initializer list and (empty) body. Starts with activeswirly off,
// initialises globals from globals_ and stamps starttime with the current system clock.
// All other members are left for enumerate_recursive() / the caller to fill in.
// NOTE(review): the line carrying the constructor name and parameter list (embedded number
// 121) is absent from this listing; the cross-reference index gives it as
// `lattice_enum_t(globals_t<N> &globals_)` - confirm against upstream.
122 : activeswirly(false), globals(globals_), starttime(std::chrono::system_clock::now())
123 {
124 }
125
// Round-to-nearest helpers, one overload per floating-point type so that the matching
// round / roundf / roundl is called without an implicit conversion of the argument.
// The result is cast to int with no range check.
126 inline int myround(double a) { return (int)(round(a)); }
127 inline int myround(float a) { return (int)(roundf(a)); }
128 inline int myround(long double a) { return (int)(roundl(a)); }
129
// Picks up a changed shared bound: if this thread's signal flag is set, clear it, reload
// _A from globals.A and recompute the per-level pruning bounds.
// NOTE(review): the signature line (embedded number 130) is absent from this listing; the
// cross-reference index gives it as `void _thread_local_update()` - confirm against upstream.
131 {
    // The flag is indexed by this object's threadid.
132 if (globals.signal[threadid])
133 {
134 globals.signal[threadid] = 0;
135 _A = globals.A;
136 _update_AA();
137 }
138 }
139
// Recomputes the per-level pruning bounds from the current overall bound _A:
// _AA[k] = _A * pr[k] and _AA2[k] = _A * pr2[k] for every k in [0, N).
// Must be called after every change to _A.
140 inline void _update_AA()
141 {
142 for (int k = 0; k < N; ++k)
143 _AA[k] = _A * pr[k];
144 for (int k = 0; k < N; ++k)
145 _AA2[k] = _A * pr2[k];
146 }
147
148 // compile time parameters for enumerate_recur (without ANY runtime overhead)
149 // allows specialization for certain specific cases, e.g., i=0, or i=swirl
    // Empty tag type: enumerate_recur is overloaded on i_tag<...>, so the level index i,
    // the svp flag, the swirl level and the swirl buffer id are carried purely in the type.
150 template <int i, bool svp, int swirl, int swirlid> struct i_tag
151 {
152 };
153
// Generic enumeration level i (compile-time index).
// Computes the centre ci = _sigT[i][i], its nearest integer xi and the partial length
// li = _l[i + 1] + (ci - xi)^2 * risq[i]. Returns at once if li exceeds _AA[i]; otherwise
// walks through candidate values of _x[i] until the partial length exceeds _AA2[i].
// NOTE(review): the signature line (embedded number 155) is absent from this listing; the
// cross-reference index gives it as `void enumerate_recur(i_tag<i, svp, swirl, swirlid>)`.
// Lines 187 and 192 are absent as well (see the notes below) - confirm against upstream.
154 template <int i, bool svp, int swirl, int swirlid>
156 {
    // Propagate the refresh marker downwards: row i - 1 of _sigT must be recomputed from at
    // least as high a column as row i was.
157 if (_r[i] > _r[i - 1])
158 _r[i - 1] = _r[i];
159 float_type ci = _sigT[i][i];
160 float_type yi = round(ci);
161 int xi = (int)(yi);
    // From here on yi is the signed offset of the centre from its nearest integer.
162 yi = ci - yi;
163 float_type li = _l[i + 1] + (yi * yi * risq[i]);
164#ifndef NOCOUNTS
165 ++_counts[i];
166#endif
167
    // Record the best non-zero partial length seen at this level (done before pruning).
168 if (findsubsols && li < _subsolL[i] && li != 0.0)
169 {
170 _subsolL[i] = li;
171 _subsol[i][i] = xi;
172 for (int j = i + 1; j < N; ++j)
173 _subsol[i][j] = _x[j];
174 }
    // Even the closest integer is out of bound: nothing to enumerate below this node.
175 if (li > _AA[i])
176 return;
177
    // Initial step direction: +1 if the centre lies at or above xi, otherwise -1.
178 _Dx[i] = _D2x[i] = (((int)(yi >= 0) & 1) << 1) - 1;
179 _c[i] = ci;
180 _x[i] = xi;
181 _l[i] = li;
182
    // Refresh row i - 1 of _sigT from column _r[i - 1] down to column i, subtracting
    // _x[j] * muT[i - 1][j] at each step; the result lands in _sigT[i - 1][i - 1].
183 for (int j = _r[i - 1]; j > i - 1; --j)
184 _sigT[i - 1][j - 1] = _sigT[i - 1][j] - _x[j] * muT[i - 1][j];
185
    // NOTE(review): embedded line 187 (the body of this guard) is absent from the listing;
    // with NOLOCALUPDATE defined above it is compiled out in any case.
186#ifndef NOLOCALUPDATE
188#endif
189
190 while (true)
191 {
    // NOTE(review): embedded line 192 is absent here; presumably the descent into level
    // i - 1 for the current _x[i] - confirm against upstream.
193
    // All higher levels contribute zero length: only step upwards (x, x+1, x+2, ...).
194 if (_l[i + 1] == 0.0)
195 {
196 ++_x[i];
    // Only _x[i] changed, so row i - 1 needs refreshing from column i only.
197 _r[i - 1] = i;
198 float_type yi2 = _c[i] - _x[i];
199 float_type li2 = _l[i + 1] + (yi2 * yi2 * risq[i]);
200 if (li2 > _AA2[i])
201 return;
202 _l[i] = li2;
203 _sigT[i - 1][i - 1] = _sigT[i - 1][i] - _x[i] * muT[i - 1][i];
204 }
205 else
206 {
    // Zigzag around the centre: offsets +s, -s, +2s, -2s, ... from the initial xi,
    // where s is the initial direction stored in _Dx[i].
207 _x[i] += _Dx[i];
208 _D2x[i] = -_D2x[i];
209 _Dx[i] = _D2x[i] - _Dx[i];
210 _r[i - 1] = i;
211 float_type yi2 = _c[i] - _x[i];
212 float_type li2 = _l[i + 1] + (yi2 * yi2 * risq[i]);
213 if (li2 > _AA2[i])
214 return;
215 _l[i] = li2;
216 _sigT[i - 1][i - 1] = _sigT[i - 1][i] - _x[i] * muT[i - 1][i];
217 }
218 }
219 }
220
// Enumeration level 0. Same candidate walk as the generic level, but with no row below to
// maintain: _r and _sigT are not updated here.
// NOTE(review): the signature line (embedded number 222) is absent from this listing; the
// cross-reference index gives it as `void enumerate_recur(i_tag<0, svp, swirl, swirlid>)`.
// Lines 250 and 255 are absent as well (see the notes below) - confirm against upstream.
221 template <bool svp, int swirl, int swirlid>
223 {
224 static const int i = 0;
225 float_type ci = _sigT[i][i];
226 float_type yi = round(ci);
227 int xi = (int)(yi);
    // From here on yi is the signed offset of the centre from its nearest integer.
228 yi = ci - yi;
229 float_type li = _l[i + 1] + (yi * yi * risq[i]);
230#ifndef NOCOUNTS
231 ++_counts[i];
232#endif
233
    // Record the best non-zero partial length seen at this level (done before pruning).
234 if (findsubsols && li < _subsolL[i] && li != 0.0)
235 {
236 _subsolL[i] = li;
237 _subsol[i][i] = xi;
238 for (int j = i + 1; j < N; ++j)
239 _subsol[i][j] = _x[j];
240 }
241 if (li > _AA[i])
242 return;
243
    // Initial step direction: +1 if the centre lies at or above xi, otherwise -1.
244 _Dx[i] = _D2x[i] = (((int)(yi >= 0) & 1) << 1) - 1;
245 _c[i] = ci;
246 _x[i] = xi;
247 _l[i] = li;
248
    // NOTE(review): embedded line 250 (the body of this guard) is absent from the listing;
    // with NOLOCALUPDATE defined above it is compiled out in any case.
249#ifndef NOLOCALUPDATE
251#endif
252
253 while (true)
254 {
    // NOTE(review): embedded line 255 is absent here; presumably the call that hands the
    // current coefficient vector to the next overload - confirm against upstream.
256
    // All higher levels contribute zero length: only step upwards.
257 if (_l[i + 1] == 0.0)
258 {
259 ++_x[i];
260 float_type yi2 = _c[i] - _x[i];
261 float_type li2 = _l[i + 1] + (yi2 * yi2 * risq[i]);
262 if (li2 > _AA2[i])
263 return;
264 _l[i] = li2;
265 }
266 else
267 {
    // Zigzag around the centre: offsets +s, -s, +2s, -2s, ... from the initial xi.
268 _x[i] += _Dx[i];
269 _D2x[i] = -_D2x[i];
270 _Dx[i] = _D2x[i] - _Dx[i];
271 float_type yi2 = _c[i] - _x[i];
272 float_type li2 = _l[i + 1] + (yi2 * yi2 * risq[i]);
273 if (li2 > _AA2[i])
274 return;
275 _l[i] = li2;
276 }
277 }
278 }
279
// Solution handler: reports the current coefficient vector through globals.process_sol and
// publishes the bound it returns.
// NOTE(review): the signature line (embedded number 281) is absent from this listing; the
// cross-reference index gives it as `void enumerate_recur(i_tag<-1, svp, swirl, swirlid>)`
// - confirm against upstream.
280 template <bool svp, int swirl, int swirlid>
282 {
    // Ignore vectors that are out of bound under the current _AA[0] and the zero vector.
283 if (_l[0] > _AA[0] || _l[0] == 0.0)
284 return;
285
    // Held until the function returns: covers the callback and the signalling below.
286 lock_type lock(globals.mutex);
287
    // The callback takes float_type coefficients, so convert the int vector.
288 for (int j = 0; j < N; ++j)
289 _sol[j] = _x[j];
290 globals.A = globals.process_sol(_l[0], &_sol[0]);
291
292 // if it has changed then signal all threads to update and update ourselves
293 if (globals.A != _A)
294 {
295 for (size_t j = 0; j < globals.signal.size(); ++j)
296 globals.signal[j] = 1;
297
    // NOTE(review): embedded line 298 is absent here; given the comment above it is presumably
    // the call that refreshes this object's own _A / _AA - confirm against upstream.
299 }
300 }
301
// Intentionally empty overload.
// NOTE(review): the signature line (embedded number 303) is absent from this listing; the
// cross-reference index gives it as `void enumerate_recur(i_tag<N+1, svp, swirl, swirlid>)`.
// Presumably a no-op terminator for an out-of-range level index - confirm against upstream.
302 template <bool svp, int swirl, int swirlid>
304 {
305 }
// Intentionally empty overload.
// NOTE(review): the signature line (embedded number 307) is absent from this listing; the
// cross-reference index gives it as `void enumerate_recur(i_tag<N+2, svp, swirl, swirlid>)`.
// Presumably a no-op terminator for an out-of-range level index - confirm against upstream.
306 template <bool svp, int swirl, int swirlid>
308 {
309 }
310
// Swirl level: this overload is chosen when the level index equals the `swirl` argument of
// the tag (i_tag<i, svp, i, swirlid>). It walks the candidates for _x[i] exactly like the
// generic level, but instead of descending it appends each surviving candidate to
// globals.swirlys[swirlid] so the subtree can be processed later.
// No lock is taken around the appends in this function.
311 template <int i, bool svp, int swirlid> inline void enumerate_recur(i_tag<i, svp, i, swirlid>)
312 {
    // Propagate the refresh marker downwards (same as the generic level).
313 if (_r[i] > _r[i - 1])
314 _r[i - 1] = _r[i];
315
316 float_type ci = _sigT[i][i];
317 float_type yi = round(ci);
318 int xi = (int)(yi);
    // From here on yi is the signed offset of the centre from its nearest integer.
319 yi = ci - yi;
320 float_type li = _l[i + 1] + (yi * yi * risq[i]);
321#ifndef NOCOUNTS
322 ++_counts[i];
323#endif
324
    // Record the best non-zero partial length seen at this level (done before pruning).
325 if (findsubsols && li < _subsolL[i] && li != 0.0)
326 {
327 _subsolL[i] = li;
328 _subsol[i][i] = xi;
329 for (int j = i + 1; j < N; ++j)
330 _subsol[i][j] = _x[j];
331 }
332 if (li > _AA[i])
333 return;
334 _c[i] = ci;
335 _x[i] = xi;
336 _l[i] = li;
    // Initial step direction: +1 if the centre lies at or above xi, otherwise -1.
337 _Dx[i] = _D2x[i] = (((int)(yi >= 0) & 1) << 1) - 1;
338
    // Refresh row i - 1 of _sigT from column _r[i - 1] down to column i.
339 for (int j = _r[i - 1]; j > i - 1; --j)
340 _sigT[i - 1][j - 1] = _sigT[i - 1][j] - _x[j] * muT[i - 1][j];
341
342 while (true)
343 {
    // Look one level ahead: li2 is the partial length level i - 1 would have at its
    // nearest-integer choice. It is stored as the item's sort key.
344 float_type ci2 = _sigT[i - 1][i - 1];
345 int xi2 = myround(ci2);
346 float_type yi2 = ci2 - xi2;
347 float_type li2 = _l[i] + (yi2 * yi2 * risq[i - 1]);
348
    // Buffer the subtree. Only entries i..N-1 of the coefficient vector are written;
    // the lower entries keep whatever emplace_back() default-constructed.
349 globals.swirlys[swirlid].emplace_back();
350 for (int j = i; j < N; ++j)
351 globals.swirlys[swirlid].back().first[j] = _x[j];
352 globals.swirlys[swirlid].back().second.first = _l[i];
353 globals.swirlys[swirlid].back().second.second = li2;
354
    // Advance to the next candidate. Note that yi2 and li declared inside the branches
    // shadow the outer variables of the same names.
355 if (_l[i + 1] == 0.0)
356 {
    // All higher levels contribute zero length: only step upwards.
357 ++_x[i];
358 _r[i - 1] = i;
359 float_type yi2 = _c[i] - _x[i];
360 float_type li = _l[i + 1] + (yi2 * yi2 * risq[i]);
361 if (li > _AA2[i])
362 return;
363 _l[i] = li;
364 _sigT[i - 1][i - 1] = _sigT[i - 1][i] - _x[i] * muT[i - 1][i];
365 }
366 else
367 {
    // Zigzag around the centre: offsets +s, -s, +2s, -2s, ... from the initial xi.
368 _x[i] += _Dx[i];
369 _D2x[i] = -_D2x[i];
370 _Dx[i] = _D2x[i] - _Dx[i];
371 _r[i - 1] = i;
372 float_type yi2 = _c[i] - _x[i];
373 float_type li = _l[i + 1] + (yi2 * yi2 * risq[i]);
374 if (li > _AA2[i])
375 return;
376 _l[i] = li;
377 _sigT[i - 1][i - 1] = _sigT[i - 1][i] - _x[i] * muT[i - 1][i];
378 }
379 }
380 }
381
// Entry point of the enumeration. Resets all per-run state, then (in the non-SINGLE_THREADED
// build) processes the tree in stages through the two buffers globals.swirlys[0] and
// globals.swirlys[1], handing the entries of swirlys[1] to tasks pushed onto the thread pool.
// NOTE(review): several statements are absent from this listing (the embedded numbering skips
// 386, 410, 415, 439 and 502); the notes below mark each gap - confirm against upstream.
382 template <bool svp = true> void enumerate_recursive()
383 {
    // Clear every thread's "bound changed" flag.
384 for (size_t i = 0; i < globals.signal.size(); ++i)
385 globals.signal[i] = 0;
    // NOTE(review): embedded line 386 is absent here.
387
    // Take the current shared bound and derive the per-level pruning bounds from it.
388 _A = globals.A;
389 _update_AA();
390
    // Reset coefficients, step state, partial lengths, sub-solution records, the _sigT
    // rows, the refresh markers and the counters.
391 for (int j = 0; j < N; ++j)
392 {
393 _x[j] = _Dx[j] = 0;
394 _D2x[j] = -1;
395 _sol[j] = 0;
396 _c[j] = _l[j] = 0.0;
    // Initial "best so far" per level; sub-solutions are only recorded below this value.
397 _subsolL[j] = risq[j];
398 for (int k = 0; k < N; ++k)
399 {
400 _sigT[j][k] = 0.0;
401 _subsol[j][k] = 0;
402 }
    // N - 1 marks every row as needing a full refresh.
403 _r[j] = N - 1;
404 _counts[j] = 0;
405 }
406 _l[N] = 0.0;
407 _counts[N] = 0;
408
    // NOTE(review): embedded line 410 (the whole single-threaded branch) is absent from the listing.
409#ifdef SINGLE_THREADED
411#else
412 auto &swirlys = globals.swirlys;
413 swirlys.resize(2);
414 swirlys[0].clear();
    // NOTE(review): embedded line 415 is absent here; presumably the top-level call that
    // fills swirlys[0], since nothing visible adds to it before it is read below.
416
    // Orders buffered subtrees by their look-ahead length (smallest first).
417 const auto swirl_less = [](const swirl_item_t &l, const swirl_item_t &r)
418 { return l.second.second < r.second.second; };
419 if (activeswirly)
420 {
421 sort(swirlys[0].begin(), swirlys[0].end(), swirl_less);
422 }
423
424 size_t swirly0idx = 0;
425 swirlys[1].clear();
    // Outer loop: runs until every entry of swirlys[0] has been consumed.
426 while (swirly0idx < swirlys[0].size())
427 {
428 int swirly1newstart = (int)(swirlys[1].size());
    // Refill stage: consume swirlys[0] entries until swirlys[1] holds at least
    // SWIRLY2BUF items (the int template argument is compared against a size_t).
429 while (swirly0idx < swirlys[0].size() && swirlys[1].size() < SWIRLY2BUF)
430 {
    // Restore the state saved with the buffered subtree at level N - SWIRLY.
431 const int i = N - SWIRLY;
432 _x = swirlys[0][swirly0idx].first;
433 _l[i] = swirlys[0][swirly0idx].second.first;
434 for (int j = 0; j < N; ++j)
435 _r[j] = N - 1;
    // Rebuild row i - 1 of _sigT from column N - 1 down to column i.
436 for (int j = N - 1; j > i - 1; --j)
437 _sigT[i - 1][j - 1] = _sigT[i - 1][j] - _x[j] * muT[i - 1][j];
438
    // NOTE(review): embedded line 439 is absent here; presumably the call that expands
    // this subtree into swirlys[1].
440
441 ++swirly0idx;
442 }
443
    // By default process everything currently buffered in swirlys[1].
444 size_t swirly1end = (int)(swirlys[1].size());
445 if (activeswirly)
446 {
447 // sort the new additions to swirly1
448 sort(swirlys[1].begin() + swirly1newstart, swirlys[1].end(), swirl_less);
449 // merge with previous elms in swirly1
450 inplace_merge(swirlys[1].begin(), swirlys[1].begin() + swirly1newstart, swirlys[1].end(),
451 swirl_less);
452
453 // process portion of swirly[1] and then add more
    // i.e. at most SWIRLY2BUF / 2^SWIRLY1FRACTION entries, capped at the buffer size.
454 swirly1end = (SWIRLY2BUF >> SWIRLY1FRACTION);
455 if (swirly1end > swirlys[1].size())
456 swirly1end = swirlys[1].size();
457 }
458
459 auto &swirly_ref = swirlys[1];
    // Shared work index: each task claims the next entry with an atomic post-increment.
460 std::atomic<std::size_t> swirly_i(0);
    // Next thread id to hand out; incremented under globals.mutex inside the task.
461 unsigned threadid = 0;
    // Task body. swirly_ref, swirly_i and threadid are captured by reference, so they
    // must stay alive until every pushed task has finished.
462 auto f = [this, &swirly_ref, &swirly_i, swirly1end, &threadid]()
463 {
    // Private working copy of the whole enumeration object for this task.
464 auto mylat = *this;
465 {
466 lock_type lock(globals.mutex);
467 mylat.threadid = threadid++;
468 }
    // Count only this task's own visits for the levels below N - SWIRLY.
469 for (int j = 0; j < N - SWIRLY; ++j)
470 mylat._counts[j] = 0;
471 while (true)
472 {
473 std::size_t idx = swirly_i++;
474 if (idx >= swirly1end)
475 break;
476
    // Restore the state saved with the buffered subtree at level N - 2 * SWIRLY.
477 const int i = N - 2 * SWIRLY;
478 mylat._x = swirly_ref[idx].first;
479 mylat._l[i] = swirly_ref[idx].second.first;
480 for (int j = 0; j < N; ++j)
481 mylat._r[j] = N - 1;
482 for (int j = N - 1; j > i - 1; --j)
483 mylat._sigT[i - 1][j - 1] = mylat._sigT[i - 1][j] - mylat._x[j] * mylat.muT[i - 1][j];
484
    // Pick up any bound change signalled since the previous entry.
485 mylat._thread_local_update();
486
    // Enumerate the subtree below the restored level; swirl = -2 matches no level
    // index reached from here, so no further buffering takes place.
487 mylat.enumerate_recur(i_tag<N - 2 * SWIRLY - 1, svp, -2, -1>());
488 }
489
    // Merge this task's counters and better sub-solutions back into the owner, under the lock.
490 lock_type lock(globals.mutex);
491 for (int j = 0; j < N - SWIRLY; ++j)
492 this->_counts[j] += mylat._counts[j];
493 for (int j = 0; j < N; ++j)
494 if (mylat._subsolL[j] < this->_subsolL[j])
495 {
496 this->_subsolL[j] = mylat._subsolL[j];
497 this->_subsol[j] = mylat._subsol[j];
498 }
499 };
    // One task per configured thread.
500 for (int i = 0; i < ::fplll::get_threads(); ++i)
501 threadpool.push(f);
    // NOTE(review): embedded line 502 is absent here; the cross-reference index lists
    // `void wait_work()`, so presumably the pool is waited on before the erase below.
503
    // Drop the entries that were just processed; any remainder stays for the next round.
504 swirlys[1].erase(swirlys[1].begin(), swirlys[1].begin() + swirly1end);
505 }
506#ifndef NOCOUNTS
507// if (enumlib_loglevel >= 1) cout << "[enumlib] counts: " << _counts << endl;
508#endif
509#endif
510 }
511};
512
513} // namespace enumlib
514
516
517#endif // ENUMLIB_ENUMERATION_HPP
void push(const std::function< void()> &f)
Definition: thread_pool.hpp:214
void wait_work()
Definition: thread_pool.hpp:200
#define FPLLL_END_NAMESPACE
Definition: defs.h:117
#define FPLLL_BEGIN_NAMESPACE
Definition: defs.h:114
fplll_extenum_enumf std::function< extenum_cb_set_config > std::function< extenum_cb_process_sol > std::function< extenum_cb_process_subsol > bool bool findsubsols
Definition: enumerate_ext_api.h:92
Definition: enumeration.h:53
mutex_type global_mutex
std::lock_guard< std::mutex > lock_type
Definition: enumeration.h:63
array< atomic_int_fast8_t, 256 > global_signal_t
Definition: enumeration.h:66
::fplll::enumf fplll_float
Definition: fplll_types.h:36
fplll_float float_type
Definition: enumeration.h:61
std::mutex mutex_type
Definition: enumeration.h:62
atomic< float_type > global_A_t
Definition: enumeration.h:65
Definition: enumeration.h:69
std::function< extenum_cb_process_subsol > process_subsol
Definition: enumeration.h:78
mutex_type mutex
Definition: enumeration.h:73
global_A_t A
Definition: enumeration.h:74
pair< introw_t, pair< float_type, float_type > > swirl_item_t
Definition: enumeration.h:71
vector< vector< swirl_item_t > > swirlys
Definition: enumeration.h:80
global_signal_t signal
Definition: enumeration.h:75
std::function< extenum_cb_process_sol > process_sol
Definition: enumeration.h:77
array< int, N > introw_t
Definition: enumeration.h:70
Definition: enumeration.h:151
Definition: enumeration.h:87
fltrow_t risq
Definition: enumeration.h:94
introw_t _D2x
Definition: enumeration.h:107
fltrow_t _sol
Definition: enumeration.h:108
std::chrono::system_clock::time_point starttime
Definition: enumeration.h:119
fltrow_t pr2
Definition: enumeration.h:96
void enumerate_recur(i_tag< N+1, svp, swirl, swirlid >)
Definition: enumeration.h:303
lattice_enum_t(globals_t< N > &globals_)
Definition: enumeration.h:121
introw_t _x
Definition: enumeration.h:107
int myround(float a)
Definition: enumeration.h:127
fltrow_t _AA
Definition: enumeration.h:106
void enumerate_recur(i_tag<-1, svp, swirl, swirlid >)
Definition: enumeration.h:281
array< float_type, N+1 > _l
Definition: enumeration.h:111
void _update_AA()
Definition: enumeration.h:140
fltrow_t _AA2
Definition: enumeration.h:106
void _thread_local_update()
Definition: enumeration.h:130
fltrow_t pr
Definition: enumeration.h:95
float_type _A
Definition: enumeration.h:105
introw_t _r
Definition: enumeration.h:110
globals_t< N > & globals
Definition: enumeration.h:103
bool activeswirly
Definition: enumeration.h:99
pair< introw_t, pair< float_type, float_type > > swirl_item_t
Definition: enumeration.h:90
float_type muT[N][N]
Definition: enumeration.h:93
void enumerate_recur(i_tag< 0, svp, swirl, swirlid >)
Definition: enumeration.h:222
int threadid
Definition: enumeration.h:102
fltrow_t _c
Definition: enumeration.h:109
float_type _sigT[N][N]
Definition: enumeration.h:114
int myround(double a)
Definition: enumeration.h:126
fltrow_t _subsolL
Definition: enumeration.h:116
void enumerate_recur(i_tag< N+2, svp, swirl, swirlid >)
Definition: enumeration.h:307
int myround(long double a)
Definition: enumeration.h:128
void enumerate_recur(i_tag< i, svp, swirl, swirlid >)
Definition: enumeration.h:155
void enumerate_recursive()
Definition: enumeration.h:382
void enumerate_recur(i_tag< i, svp, i, swirlid >)
Definition: enumeration.h:311
array< int, N > introw_t
Definition: enumeration.h:89
array< fltrow_t, N > _subsol
Definition: enumeration.h:117
array< uint64_t, N+1 > _counts
Definition: enumeration.h:112
introw_t _Dx
Definition: enumeration.h:107
array< float_type, N > fltrow_t
Definition: enumeration.h:88
#define N
Read matrix from input_filename.
Definition: test_pruner.cpp:34
FPLLL_BEGIN_NAMESPACE thread_pool::thread_pool threadpool
Definition: threadpool.cpp:20
int get_threads()
Definition: threadpool.cpp:23
FPLLL_BEGIN_NAMESPACE typedef std::mutex mutex
Definition: threadpool.h:25