blob: dce2682dbb1bcadf7e2e33b8f86b4674e692f050 [file] [log] [blame]
Piotr Krysik9e2e8352018-02-27 12:16:25 +01001/*! \file conv_acc.c
2 * Accelerated Viterbi decoder implementation. */
3/*
4 * Copyright (C) 2013, 2014 Thomas Tsou <tom@tsou.cc>
5 *
6 * All Rights Reserved
7 *
8 * SPDX-License-Identifier: GPL-2.0+
9 *
10 * This program is free software; you can redistribute it and/or modify
11 * it under the terms of the GNU General Public License as published by
12 * the Free Software Foundation; either version 2 of the License, or
13 * (at your option) any later version.
14 *
15 * This program is distributed in the hope that it will be useful,
16 * but WITHOUT ANY WARRANTY; without even the implied warranty of
17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 * GNU General Public License for more details.
19 *
20 * You should have received a copy of the GNU General Public License along
21 * with this program; if not, write to the Free Software Foundation, Inc.,
22 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23 */
24
25#include <stdlib.h>
26#include <string.h>
27#include <errno.h>
28
29#ifdef HAVE_CONFIG_H
30#include "config.h"
31#endif
32
/* Compilers that lack GCC's __attribute__ extension (MSVC, for instance)
 * get a no-op replacement. GCC-compatible compilers keep the real thing:
 * redefining it unconditionally would silently drop every attribute in this
 * file and in the headers below (e.g. the hidden visibility of the CPU
 * feature flags), and __attribute__ is a reserved identifier there. */
#ifndef __GNUC__
#define __attribute__(_arg_)
#endif
34
35#include <osmocom/core/conv.h>
36
/* Map bit N of REG to an NRZ symbol: 0 -> +1, 1 -> -1.
 * All arguments are parenthesized so expressions such as BIT2NRZ(a | b, i)
 * expand correctly. The result is always a signed int, even for unsigned
 * REG, so no unsigned wrap-around value leaks out. */
#define BIT2NRZ(REG, N) (1 - 2 * (int) ((((REG) >> (N))) & 0x01))

/* Trellis size for the supported constraint lengths: K=7 -> 64, else 16 */
#define NUM_STATES(K) ((K) == 7 ? 64 : 16)

/* Point the metric kernels and the trellis allocator hooks at one
 * implementation flavour (gen, sse, sse_avx). Wrapped in do/while(0) so the
 * macro behaves like a single statement wherever it is used. */
#define INIT_POINTERS(simd) \
do { \
	osmo_conv_metrics_k5_n2 = osmo_conv_##simd##_metrics_k5_n2; \
	osmo_conv_metrics_k5_n3 = osmo_conv_##simd##_metrics_k5_n3; \
	osmo_conv_metrics_k5_n4 = osmo_conv_##simd##_metrics_k5_n4; \
	osmo_conv_metrics_k7_n2 = osmo_conv_##simd##_metrics_k7_n2; \
	osmo_conv_metrics_k7_n3 = osmo_conv_##simd##_metrics_k7_n3; \
	osmo_conv_metrics_k7_n4 = osmo_conv_##simd##_metrics_k7_n4; \
	vdec_malloc = &osmo_conv_##simd##_vdec_malloc; \
	vdec_free = &osmo_conv_##simd##_vdec_free; \
} while (0)
51
/* Non-zero once osmo_conv_init() has run */
static int init_complete;

/* CPU feature flags, filled in by osmo_conv_init(). They are non-static,
 * presumably so the SIMD implementations can read them (not visible here -
 * verify), and are marked hidden-visibility where __attribute__ is honoured. */
__attribute__ ((visibility("hidden"))) int avx2_supported = 0;
__attribute__ ((visibility("hidden"))) int ssse3_supported = 0;
__attribute__ ((visibility("hidden"))) int sse41_supported = 0;

/* Signature shared by every branch/path metric kernel */
typedef void (*conv_acc_metric_fn)(const int8_t *seq, const int16_t *out,
				   int16_t *sums, int16_t *paths, int norm);

/**
 * Implementation hooks. They start out NULL and are assigned at runtime by
 * osmo_conv_init() according to the SIMD extensions the CPU supports.
 */
static int16_t *(*vdec_malloc)(size_t n);
static void (*vdec_free)(int16_t *ptr);

conv_acc_metric_fn osmo_conv_metrics_k5_n2;
conv_acc_metric_fn osmo_conv_metrics_k5_n3;
conv_acc_metric_fn osmo_conv_metrics_k5_n4;
conv_acc_metric_fn osmo_conv_metrics_k7_n2;
conv_acc_metric_fn osmo_conv_metrics_k7_n3;
conv_acc_metric_fn osmo_conv_metrics_k7_n4;
77
/*
 * Entry points of the generic and SIMD implementations, defined in other
 * translation units. For one flavour 'simd' this declares the
 * osmo_conv_<simd>_vdec_malloc / _vdec_free allocator pair and the six
 * osmo_conv_<simd>_metrics_k{5,7}_n{2,3,4} metric kernels.
 */
#define CONV_ACC_DECLARE_IMPL(simd) \
int16_t *osmo_conv_##simd##_vdec_malloc(size_t n); \
void osmo_conv_##simd##_vdec_free(int16_t *ptr); \
void osmo_conv_##simd##_metrics_k5_n2(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm); \
void osmo_conv_##simd##_metrics_k5_n3(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm); \
void osmo_conv_##simd##_metrics_k5_n4(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm); \
void osmo_conv_##simd##_metrics_k7_n2(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm); \
void osmo_conv_##simd##_metrics_k7_n3(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm); \
void osmo_conv_##simd##_metrics_k7_n4(const int8_t *seq, \
	const int16_t *out, int16_t *sums, int16_t *paths, int norm);

/* Generic implementation: the fallback in every build configuration */
CONV_ACC_DECLARE_IMPL(gen)

#if defined(HAVE_SSSE3)
CONV_ACC_DECLARE_IMPL(sse)
#endif

#if defined(HAVE_SSSE3) && defined(HAVE_AVX2)
CONV_ACC_DECLARE_IMPL(sse_avx)
#endif

#undef CONV_ACC_DECLARE_IMPL
135
/* Trellis state as seen by the trellis generator helpers: the register
 * value itself plus the two registers reached from it by
 * vstate_lshift() with a 0 and with a 1. */
struct vstate {
	unsigned state;		/* internal left-shift register value */
	unsigned prev[2];	/* register values for shifted-in 0 / 1 */
};

/* Trellis object shared by every step of the forward recursion */
struct vtrellis {
	int num_states;		/* number of states (NUM_STATES(K)) */
	int16_t *sums;		/* accumulated path metric per state */
	int16_t *outputs;	/* NRZ branch outputs, 2 or 4 slots per state */
	uint8_t *vals;		/* input value that led to each state */
};

/* Viterbi decoder instance */
struct vdecoder {
	int n;			/* code order (outputs per input bit) */
	int k;			/* constraint length */
	int len;		/* horizontal length of the trellis in steps */
	int recursive;		/* 1 if the code is recursive, else 0 */
	int intrvl;		/* path metric normalization interval */
	struct vtrellis trellis;
	int16_t **paths;	/* survivor decisions: one row per step */

	/* combined branch/path metric kernel selected for (k, n) */
	void (*metric_func)(const int8_t *, const int16_t *,
			    int16_t *, int16_t *, int);
};
179
180/* Accessor calls */
181static inline int conv_code_recursive(const struct osmo_conv_code *code)
182{
183 return code->next_term_output ? 1 : 0;
184}
185
/* Left shift and mask for finding the previous state.
 * Shifts reg left by one, keeps only the bits that remain inside the
 * register for K=5 (mask 0x0e) or K=7 (mask 0x3e), and inserts val at
 * bit 0. Any other k uses a zero mask, so the result is just val. */
static unsigned vstate_lshift(unsigned reg, int k, int val)
{
	unsigned keep;

	switch (k) {
	case 5:
		keep = 0x0e;
		break;
	case 7:
		keep = 0x3e;
		break;
	default:
		keep = 0;
		break;
	}

	return (keep & (reg << 1)) | val;
}
200
/* Bit endian manipulators */

/* Reverse the two low-order bits of v; all higher bits are dropped. */
static inline unsigned bitswap2(unsigned v)
{
	unsigned lo = v & 0x01;
	unsigned hi = (v >> 1) & 0x01;

	return (lo << 1) | hi;
}
206
/* Reverse the three low-order bits of v; all higher bits are dropped. */
static inline unsigned bitswap3(unsigned v)
{
	return ((v & 0x01) << 2) |
	       (v & 0x02) |
	       ((v >> 2) & 0x01);
}
212
/* Reverse the four low-order bits of v; all higher bits are dropped. */
static inline unsigned bitswap4(unsigned v)
{
	return ((v & 0x01) << 3) |
	       ((v & 0x02) << 1) |
	       ((v >> 1) & 0x02) |
	       ((v >> 3) & 0x01);
}
218
/* Reverse the five low-order bits of v; all higher bits are dropped. */
static inline unsigned bitswap5(unsigned v)
{
	return ((v & 0x01) << 4) |
	       ((v & 0x02) << 2) |
	       (v & 0x04) |
	       ((v >> 2) & 0x02) |
	       ((v >> 4) & 0x01);
}
224
/* Reverse the six low-order bits of v; all higher bits are dropped. */
static inline unsigned bitswap6(unsigned v)
{
	return ((v & 0x01) << 5) |
	       ((v & 0x02) << 3) |
	       ((v & 0x04) << 1) |
	       ((v >> 1) & 0x04) |
	       ((v >> 3) & 0x02) |
	       ((v >> 5) & 0x01);
}
230
/* Reverse the order of the n low-order bits of v.
 * n == 1 returns v unchanged (no masking); n == 2..6 returns the reversed
 * low n bits with all higher bits dropped; any other n returns 0. */
static unsigned bitswap(unsigned v, unsigned n)
{
	unsigned rev = 0;
	unsigned bit;

	if (n == 1)
		return v;
	if (n < 2 || n > 6)
		return 0;

	/* Bit 0 of v ends up as the most significant of the n result bits */
	for (bit = 0; bit < n; bit++)
		rev = (rev << 1) | ((v >> bit) & 0x01);

	return rev;
}
250
251/* Generate non-recursive state output from generator state table
252 * Note that the shift register moves right (i.e. the most recent bit is
253 * shifted into the register at k-1 bit of the register), which is typical
254 * textbook representation. The API transition table expects the most recent
255 * bit in the low order bit, or left shift. A bitswap operation is required
256 * to accommodate the difference.
257 */
258static unsigned gen_output(struct vstate *state, int val,
259 const struct osmo_conv_code *code)
260{
261 unsigned out, prev;
262
263 prev = bitswap(state->prev[0], code->K - 1);
264 out = code->next_output[prev][val];
265 out = bitswap(out, code->N);
266
267 return out;
268}
269
270/* Populate non-recursive trellis state
271 * For a given state defined by the k-1 length shift register, find the
272 * value of the input bit that drove the trellis to that state. Also
273 * generate the N outputs of the generator polynomial at that state.
274 */
275static int gen_state_info(uint8_t *val, unsigned reg,
276 int16_t *output, const struct osmo_conv_code *code)
277{
278 int i;
279 unsigned out;
280 struct vstate state;
281
282 /* Previous '0' state */
283 state.state = reg;
284 state.prev[0] = vstate_lshift(reg, code->K, 0);
285 state.prev[1] = vstate_lshift(reg, code->K, 1);
286
287 *val = (reg >> (code->K - 2)) & 0x01;
288
289 /* Transition output */
290 out = gen_output(&state, *val, code);
291
292 /* Unpack to NRZ */
293 for (i = 0; i < code->N; i++)
294 output[i] = BIT2NRZ(out, i);
295
296 return 0;
297}
298
299/* Generate recursive state output from generator state table */
300static unsigned gen_recursive_output(struct vstate *state,
301 uint8_t *val, unsigned reg,
302 const struct osmo_conv_code *code, int pos)
303{
304 int val0, val1;
305 unsigned out, prev;
306
307 /* Previous '0' state */
308 prev = vstate_lshift(reg, code->K, 0);
309 prev = bitswap(prev, code->K - 1);
310
311 /* Input value */
312 val0 = (reg >> (code->K - 2)) & 0x01;
313 val1 = (code->next_term_output[prev] >> pos) & 0x01;
314 *val = val0 == val1 ? 0 : 1;
315
316 /* Wrapper for osmocom state access */
317 prev = bitswap(state->prev[0], code->K - 1);
318
319 /* Compute the transition output */
320 out = code->next_output[prev][*val];
321 out = bitswap(out, code->N);
322
323 return out;
324}
325
326/* Populate recursive trellis state
327 * The bit position of the systematic bit is not explicitly marked by the
328 * API, so it must be extracted from the generator table. Otherwise,
329 * populate the trellis similar to the non-recursive version.
330 * Non-systematic recursive codes are not supported.
331 */
332static int gen_recursive_state_info(uint8_t *val,
333 unsigned reg, int16_t *output, const struct osmo_conv_code *code)
334{
335 int i, j, pos = -1;
336 int ns = NUM_STATES(code->K);
337 unsigned out;
338 struct vstate state;
339
340 /* Previous '0' and '1' states */
341 state.state = reg;
342 state.prev[0] = vstate_lshift(reg, code->K, 0);
343 state.prev[1] = vstate_lshift(reg, code->K, 1);
344
345 /* Find recursive bit location */
346 for (i = 0; i < code->N; i++) {
347 for (j = 0; j < ns; j++) {
348 if ((code->next_output[j][0] >> i) & 0x01)
349 break;
350 }
351
352 if (j == ns) {
353 pos = i;
354 break;
355 }
356 }
357
358 /* Non-systematic recursive code not supported */
359 if (pos < 0)
360 return -EPROTO;
361
362 /* Transition output */
363 out = gen_recursive_output(&state, val, reg, code, pos);
364
365 /* Unpack to NRZ */
366 for (i = 0; i < code->N; i++)
367 output[i] = BIT2NRZ(out, i);
368
369 return 0;
370}
371
372/* Release the trellis */
373static void free_trellis(struct vtrellis *trellis)
374{
375 if (!trellis)
376 return;
377
378 vdec_free(trellis->outputs);
379 vdec_free(trellis->sums);
380 free(trellis->vals);
381}
382
383/* Initialize the trellis object
384 * Initialization consists of generating the outputs and output value of a
385 * given state. Due to trellis symmetry and anti-symmetry, only one of the
386 * transition paths is utilized by the butterfly operation in the forward
387 * recursion, so only one set of N outputs is required per state variable.
388 */
389static int generate_trellis(struct vdecoder *dec,
390 const struct osmo_conv_code *code)
391{
392 struct vtrellis *trellis = &dec->trellis;
393 int16_t *outputs;
394 int i, rc;
395
396 int ns = NUM_STATES(code->K);
397 int olen = (code->N == 2) ? 2 : 4;
398
399 trellis->num_states = ns;
400 trellis->sums = vdec_malloc(ns);
401 trellis->outputs = vdec_malloc(ns * olen);
402 trellis->vals = (uint8_t *) malloc(ns * sizeof(uint8_t));
403
404 if (!trellis->sums || !trellis->outputs || !trellis->vals) {
405 rc = -ENOMEM;
406 goto fail;
407 }
408
409 /* Populate the trellis state objects */
410 for (i = 0; i < ns; i++) {
411 outputs = &trellis->outputs[olen * i];
412 if (dec->recursive) {
413 rc = gen_recursive_state_info(&trellis->vals[i],
414 i, outputs, code);
415 } else {
416 rc = gen_state_info(&trellis->vals[i],
417 i, outputs, code);
418 }
419
420 if (rc < 0)
421 goto fail;
422
423 /* Set accumulated path metrics to zero */
424 trellis->sums[i] = 0;
425 }
426
427 /**
428 * For termination other than tail-biting, initialize the zero state
429 * as the encoder starting state. Initialize with the maximum
430 * accumulated sum at length equal to the constraint length.
431 */
432 if (code->term != CONV_TERM_TAIL_BITING)
433 trellis->sums[0] = INT8_MAX * code->N * code->K;
434
435 return 0;
436
437fail:
438 free_trellis(trellis);
439 return rc;
440}
441
442static void _traceback(struct vdecoder *dec,
443 unsigned state, uint8_t *out, int len)
444{
445 int i;
446 unsigned path;
447
448 for (i = len - 1; i >= 0; i--) {
449 path = dec->paths[i][state] + 1;
450 out[i] = dec->trellis.vals[state];
451 state = vstate_lshift(state, dec->k, path);
452 }
453}
454
455static void _traceback_rec(struct vdecoder *dec,
456 unsigned state, uint8_t *out, int len)
457{
458 int i;
459 unsigned path;
460
461 for (i = len - 1; i >= 0; i--) {
462 path = dec->paths[i][state] + 1;
463 out[i] = path ^ dec->trellis.vals[state];
464 state = vstate_lshift(state, dec->k, path);
465 }
466}
467
468/* Traceback and generate decoded output
469 * Find the largest accumulated path metric at the final state except for
470 * the zero terminated case, where we assume the final state is always zero.
471 */
472static int traceback(struct vdecoder *dec, uint8_t *out, int term, int len)
473{
474 int i, sum, max = -1;
475 unsigned path, state = 0;
476
477 if (term != CONV_TERM_FLUSH) {
478 for (i = 0; i < dec->trellis.num_states; i++) {
479 sum = dec->trellis.sums[i];
480 if (sum > max) {
481 max = sum;
482 state = i;
483 }
484 }
485
486 if (max < 0)
487 return -EPROTO;
488 }
489
490 for (i = dec->len - 1; i >= len; i--) {
491 path = dec->paths[i][state] + 1;
492 state = vstate_lshift(state, dec->k, path);
493 }
494
495 if (dec->recursive)
496 _traceback_rec(dec, state, out, len);
497 else
498 _traceback(dec, state, out, len);
499
500 return 0;
501}
502
503/* Release decoder object */
504static void vdec_deinit(struct vdecoder *dec)
505{
506 if (!dec)
507 return;
508
509 free_trellis(&dec->trellis);
510
511 if (dec->paths != NULL) {
512 vdec_free(dec->paths[0]);
513 free(dec->paths);
514 }
515}
516
517/* Initialize decoder object with code specific params
518 * Subtract the constraint length K on the normalization interval to
519 * accommodate the initialization path metric at state zero.
520 */
521static int vdec_init(struct vdecoder *dec, const struct osmo_conv_code *code)
522{
523 int i, ns, rc;
524
525 ns = NUM_STATES(code->K);
526
527 dec->n = code->N;
528 dec->k = code->K;
529 dec->recursive = conv_code_recursive(code);
530 dec->intrvl = INT16_MAX / (dec->n * INT8_MAX) - dec->k;
531
532 if (dec->k == 5) {
533 switch (dec->n) {
534 case 2:
535 dec->metric_func = osmo_conv_metrics_k5_n2;
536 break;
537 case 3:
538 dec->metric_func = osmo_conv_metrics_k5_n3;
539 break;
540 case 4:
541 dec->metric_func = osmo_conv_metrics_k5_n4;
542 break;
543 default:
544 return -EINVAL;
545 }
546 } else if (dec->k == 7) {
547 switch (dec->n) {
548 case 2:
549 dec->metric_func = osmo_conv_metrics_k7_n2;
550 break;
551 case 3:
552 dec->metric_func = osmo_conv_metrics_k7_n3;
553 break;
554 case 4:
555 dec->metric_func = osmo_conv_metrics_k7_n4;
556 break;
557 default:
558 return -EINVAL;
559 }
560 } else {
561 return -EINVAL;
562 }
563
564 if (code->term == CONV_TERM_FLUSH)
565 dec->len = code->len + code->K - 1;
566 else
567 dec->len = code->len;
568
569 rc = generate_trellis(dec, code);
570 if (rc)
571 return rc;
572
573 dec->paths = (int16_t **) malloc(sizeof(int16_t *) * dec->len);
574 if (!dec->paths)
575 goto enomem;
576
577 dec->paths[0] = vdec_malloc(ns * dec->len);
578 if (!dec->paths[0])
579 goto enomem;
580
581 for (i = 1; i < dec->len; i++)
582 dec->paths[i] = &dec->paths[0][i * ns];
583
584 return 0;
585
586enomem:
587 vdec_deinit(dec);
588 return -ENOMEM;
589}
590
/* Depuncture a sequence using a puncturing list terminated by a negative
 * value. 'punc' holds, in ascending order, the output positions that were
 * punctured. Those positions of out[] are set to 0; all other positions are
 * filled from 'in' in order. out[] receives exactly len symbols.
 * Always returns 0.
 */
static int depuncture(const int8_t *in, const int *punc, int8_t *out, int len)
{
	const int8_t *src = in;
	int pos;

	for (pos = 0; pos < len; pos++) {
		if (pos == *punc) {
			out[pos] = 0;
			punc++;
		} else {
			out[pos] = *src++;
		}
	}

	return 0;
}
608
609/* Forward trellis recursion
610 * Generate branch metrics and path metrics with a combined function. Only
611 * accumulated path metric sums and path selections are stored. Normalize on
612 * the interval specified by the decoder.
613 */
614static void forward_traverse(struct vdecoder *dec, const int8_t *seq)
615{
616 int i;
617
618 for (i = 0; i < dec->len; i++) {
619 dec->metric_func(&seq[dec->n * i],
620 dec->trellis.outputs,
621 dec->trellis.sums,
622 dec->paths[i],
623 !(i % dec->intrvl));
624 }
625}
626
627/* Convolutional decode with a decoder object
628 * Initial puncturing run if necessary followed by the forward recursion.
629 * For tail-biting perform a second pass before running the backward
630 * traceback operation.
631 */
632static int conv_decode(struct vdecoder *dec, const int8_t *seq,
633 const int *punc, uint8_t *out, int len, int term)
634{
635 //int8_t depunc[dec->len * dec->n]; //!! this isn't portable, in strict C you can't use size of an array that is not known at compile time
636 int8_t * depunc = malloc(sizeof(int8_t)*dec->len * dec->n);
637
638
639 if (punc) {
640 depuncture(seq, punc, depunc, dec->len * dec->n);
641 seq = depunc;
642 }
643
644 /* Propagate through the trellis with interval normalization */
645 forward_traverse(dec, seq);
646
647 if (term == CONV_TERM_TAIL_BITING)
648 forward_traverse(dec, seq);
649
650 free(depunc);
651 return traceback(dec, out, term, len);
652}
653
654static void osmo_conv_init(void)
655{
656 init_complete = 1;
657
658#ifdef HAVE___BUILTIN_CPU_SUPPORTS
659 /* Detect CPU capabilities */
660 #ifdef HAVE_AVX2
661 avx2_supported = __builtin_cpu_supports("avx2");
662 #endif
663
664 #ifdef HAVE_SSSE3
665 ssse3_supported = __builtin_cpu_supports("ssse3");
666 #endif
667
668 #ifdef HAVE_SSE4_1
669 sse41_supported = __builtin_cpu_supports("sse4.1");
670 #endif
671#endif
672
673/**
674 * Usage of curly braces is mandatory,
675 * because we use multi-line define.
676 */
677#if defined(HAVE_SSSE3) && defined(HAVE_AVX2)
678 if (ssse3_supported && avx2_supported) {
679 INIT_POINTERS(sse_avx);
680 } else if (ssse3_supported) {
681 INIT_POINTERS(sse);
682 } else {
683 INIT_POINTERS(gen);
684 }
685#elif defined(HAVE_SSSE3)
686 if (ssse3_supported) {
687 INIT_POINTERS(sse);
688 } else {
689 INIT_POINTERS(gen);
690 }
691#else
692 INIT_POINTERS(gen);
693#endif
694}
695
696/* All-in-one Viterbi decoding */
697int osmo_conv_decode_acc(const struct osmo_conv_code *code,
698 const sbit_t *input, ubit_t *output)
699{
700 int rc;
701 struct vdecoder dec;
702
703 if (!init_complete)
704 osmo_conv_init();
705
706 if ((code->N < 2) || (code->N > 4) || (code->len < 1) ||
707 ((code->K != 5) && (code->K != 7)))
708 return -EINVAL;
709
710 rc = vdec_init(&dec, code);
711 if (rc)
712 return rc;
713
714 rc = conv_decode(&dec, input, code->puncture,
715 output, code->len, code->term);
716
717 vdec_deinit(&dec);
718
719 return rc;
720}