Browse code

aacenc_pred: rework the way prediction is done

This commit completely alters the algorithm of prediction.
The original commit which introduced prediction was completely
incorrect to even remotely care about what the actual coefficients
contain or whether any options were enabled. Not my actual fault.

This commit treats prediction the way the decoder does and expects
to do: like lossy encryption. Everything related to prediction now
happens at the very end but just before quantization and encoding
of coefficients. On the decoder side, prediction happens before
anything has had a chance to even access the coefficients.

Also the original implementation had problems because it actually
touched the band_type of special bands which already had their
scalefactor indices marked and it's a wonder the asserion wasn't
triggered when transmitting those.

Overall, this now drastically increases audio quality and you should
think about enabling it if you don't plan on playing anything encoded
on really old low power ultra-embedded devices since they might not
support decoding of prediction or AAC-Main. Though the specifications
were written ages ago and as times change so do the FLOPS.

Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>

Rostislav Pehlivanov authored on 2015/08/29 14:34:08
Showing 6 changed files
... ...
@@ -247,7 +247,7 @@ typedef struct SingleChannelElement {
247 247
     TemporalNoiseShaping tns;
248 248
     Pulse pulse;
249 249
     enum BandType band_type[128];                   ///< band types
250
-    enum BandType orig_band_type[128];              ///< band type backups for undoing prediction
250
+    enum BandType band_alt[128];                    ///< alternative band type (used by encoder)
251 251
     int band_type_run_end[120];                     ///< band type run end points
252 252
     INTFLOAT sf[120];                               ///< scalefactors
253 253
     int sf_idx[128];                                ///< scalefactor indices (used by encoder)
... ...
@@ -964,7 +964,6 @@ AACCoefficientsEncoder ff_aac_coders[AAC_CODER_NB] = {
964 964
         ff_aac_encode_main_pred,
965 965
         ff_aac_adjust_common_prediction,
966 966
         ff_aac_apply_main_pred,
967
-        ff_aac_update_main_pred,
968 967
         set_special_band_scalefactors,
969 968
         search_for_pns,
970 969
         ff_aac_search_for_tns,
... ...
@@ -980,7 +979,6 @@ AACCoefficientsEncoder ff_aac_coders[AAC_CODER_NB] = {
980 980
         ff_aac_encode_main_pred,
981 981
         ff_aac_adjust_common_prediction,
982 982
         ff_aac_apply_main_pred,
983
-        ff_aac_update_main_pred,
984 983
         set_special_band_scalefactors,
985 984
         search_for_pns,
986 985
         ff_aac_search_for_tns,
... ...
@@ -996,7 +994,6 @@ AACCoefficientsEncoder ff_aac_coders[AAC_CODER_NB] = {
996 996
         ff_aac_encode_main_pred,
997 997
         ff_aac_adjust_common_prediction,
998 998
         ff_aac_apply_main_pred,
999
-        ff_aac_update_main_pred,
1000 999
         set_special_band_scalefactors,
1001 1000
         search_for_pns,
1002 1001
         ff_aac_search_for_tns,
... ...
@@ -1012,7 +1009,6 @@ AACCoefficientsEncoder ff_aac_coders[AAC_CODER_NB] = {
1012 1012
         ff_aac_encode_main_pred,
1013 1013
         ff_aac_adjust_common_prediction,
1014 1014
         ff_aac_apply_main_pred,
1015
-        ff_aac_update_main_pred,
1016 1015
         set_special_band_scalefactors,
1017 1016
         search_for_pns,
1018 1017
         ff_aac_search_for_tns,
... ...
@@ -354,15 +354,15 @@ static void encode_spectral_coeffs(AACEncContext *s, SingleChannelElement *sce)
354 354
                 start += sce->ics.swb_sizes[i];
355 355
                 continue;
356 356
             }
357
-            for (w2 = w; w2 < w + sce->ics.group_len[w]; w2++)
357
+            for (w2 = w; w2 < w + sce->ics.group_len[w]; w2++) {
358 358
                 s->coder->quantize_and_encode_band(s, &s->pb,
359 359
                                                    &sce->coeffs[start + w2*128],
360
-                                                   &sce->pqcoeffs[start + w2*128],
361
-                                                   sce->ics.swb_sizes[i],
360
+                                                   NULL, sce->ics.swb_sizes[i],
362 361
                                                    sce->sf_idx[w*16 + i],
363 362
                                                    sce->band_type[w*16 + i],
364 363
                                                    s->lambda,
365 364
                                                    sce->ics.window_clipping[w]);
365
+            }
366 366
             start += sce->ics.swb_sizes[i];
367 367
         }
368 368
     }
... ...
@@ -609,12 +609,8 @@ static int aac_encode_frame(AVCodecContext *avctx, AVPacket *avpkt,
609 609
                     s->coder->search_for_pns(s, avctx, sce);
610 610
                 if (s->options.tns && s->coder->search_for_tns)
611 611
                     s->coder->search_for_tns(s, sce);
612
-                if (s->options.pred && s->coder->search_for_pred)
613
-                    s->coder->search_for_pred(s, sce);
614 612
                 if (sce->tns.present)
615 613
                     tns_mode = 1;
616
-                if (sce->ics.predictor_present)
617
-                    pred_mode = 1;
618 614
             }
619 615
             s->cur_channel = start_ch;
620 616
             if (s->options.stereo_mode && cpe->common_window) {
... ...
@@ -631,15 +627,26 @@ static int aac_encode_frame(AVCodecContext *avctx, AVPacket *avpkt,
631 631
                 s->coder->search_for_is(s, avctx, cpe);
632 632
                 if (cpe->is_mode) is_mode = 1;
633 633
             }
634
-            if (s->options.pred && s->coder->adjust_common_prediction)
635
-                s->coder->adjust_common_prediction(s, cpe);
636 634
             if (s->coder->set_special_band_scalefactors)
637 635
                 for (ch = 0; ch < chans; ch++)
638 636
                     s->coder->set_special_band_scalefactors(s, &cpe->ch[ch]);
639
-            if (s->options.pred && s->coder->apply_main_pred)
640
-                for (ch = 0; ch < chans; ch++)
641
-                    s->coder->apply_main_pred(s, &cpe->ch[ch]);
642 637
             adjust_frame_information(cpe, chans);
638
+            for (ch = 0; ch < chans; ch++) {
639
+                sce = &cpe->ch[ch];
640
+                s->cur_channel = start_ch + ch;
641
+                if (s->options.pred && s->coder->search_for_pred)
642
+                    s->coder->search_for_pred(s, sce);
643
+                if (cpe->ch[ch].ics.predictor_present) pred_mode = 1;
644
+            }
645
+            if (s->options.pred && s->coder->adjust_common_prediction)
646
+                s->coder->adjust_common_prediction(s, cpe);
647
+            for (ch = 0; ch < chans; ch++) {
648
+                sce = &cpe->ch[ch];
649
+                s->cur_channel = start_ch + ch;
650
+                if (s->options.pred && s->coder->apply_main_pred)
651
+                    s->coder->apply_main_pred(s, sce);
652
+            }
653
+            s->cur_channel = start_ch;
643 654
             if (chans == 2) {
644 655
                 put_bits(&s->pb, 1, cpe->common_window);
645 656
                 if (cpe->common_window) {
... ...
@@ -676,16 +683,6 @@ static int aac_encode_frame(AVCodecContext *avctx, AVPacket *avpkt,
676 676
 
677 677
     } while (1);
678 678
 
679
-    // update predictor state
680
-    if (s->options.pred && s->coder->update_main_pred) {
681
-        for (i = 0; i < s->chan_map[0]; i++) {
682
-            cpe = &s->cpe[i];
683
-            for (ch = 0; ch < chans; ch++)
684
-                s->coder->update_main_pred(s, &cpe->ch[ch],
685
-                                           (cpe->common_window && !ch) ? cpe : NULL);
686
-        }
687
-    }
688
-
689 679
     put_bits(&s->pb, 3, TYPE_END);
690 680
     flush_put_bits(&s->pb);
691 681
     avctx->frame_bits = put_bits_count(&s->pb);
... ...
@@ -63,7 +63,6 @@ typedef struct AACCoefficientsEncoder {
63 63
     void (*encode_main_pred)(struct AACEncContext *s, SingleChannelElement *sce);
64 64
     void (*adjust_common_prediction)(struct AACEncContext *s, ChannelElement *cpe);
65 65
     void (*apply_main_pred)(struct AACEncContext *s, SingleChannelElement *sce);
66
-    void (*update_main_pred)(struct AACEncContext *s, SingleChannelElement *sce, ChannelElement *cpe);
67 66
     void (*set_special_band_scalefactors)(struct AACEncContext *s, SingleChannelElement *sce);
68 67
     void (*search_for_pns)(struct AACEncContext *s, AVCodecContext *avctx, SingleChannelElement *sce);
69 68
     void (*search_for_tns)(struct AACEncContext *s, SingleChannelElement *sce);
... ...
@@ -21,15 +21,22 @@
21 21
 
22 22
 /**
23 23
  * @file
24
- * AAC encoder main prediction
24
+ * AAC encoder Intensity Stereo
25 25
  * @author Rostislav Pehlivanov ( atomnuker gmail com )
26 26
  */
27 27
 
28 28
 #include "aactab.h"
29 29
 #include "aacenc_pred.h"
30 30
 #include "aacenc_utils.h"
31
+#include "aacenc_is.h"            /* <- Needed for common window distortions */
31 32
 #include "aacenc_quantization.h"
32 33
 
34
+#define RESTORE_PRED(sce, sfb) \
35
+        if (sce->ics.prediction_used[sfb]) {\
36
+            sce->ics.prediction_used[sfb] = 0;\
37
+            sce->band_type[sfb] = sce->band_alt[sfb];\
38
+        }
39
+
33 40
 static inline float flt16_round(float pf)
34 41
 {
35 42
     union av_intfloat32 tmp;
... ...
@@ -54,73 +61,57 @@ static inline float flt16_trunc(float pf)
54 54
     return pun.f;
55 55
 }
56 56
 
57
-static inline void predict(PredictorState *ps, float *coef, float *rcoef,
58
-                           int output_enable)
57
+static inline void predict(PredictorState *ps, float *coef, float *rcoef, int set)
59 58
 {
60
-    const float a     = 0.953125; // 61.0 / 64
61 59
     float k2;
62
-    float   r0 = ps->r0,     r1 = ps->r1;
63
-    float cor0 = ps->cor0, cor1 = ps->cor1;
64
-    float var0 = ps->var0, var1 = ps->var1;
65
-
66
-    ps->k1 = var0 > 1 ? cor0 * flt16_even(a / var0) : 0;
67
-        k2 = var1 > 1 ? cor1 * flt16_even(a / var1) : 0;
68
-
69
-    ps->x_est = flt16_round(ps->k1*r0 + k2*r1);
70
-
71
-    if (output_enable)
72
-        *coef -= ps->x_est;
73
-    else
74
-        *rcoef = *coef - ps->x_est;
75
-}
76
-
77
-static inline void update_predictor(PredictorState *ps, float qcoef)
78
-{
79
-    const float alpha = 0.90625;  // 29.0 / 32
80 60
     const float a     = 0.953125; // 61.0 / 64
81
-    float k1 = ps->k1;
82
-    float r0 = ps->r0;
83
-    float r1 = ps->r1;
84
-    float e0 = qcoef + ps->x_est;
85
-    float e1 = e0 - k1 * r0;
86
-    float cor0 = ps->cor0, cor1 = ps->cor1;
87
-    float var0 = ps->var0, var1 = ps->var1;
61
+    const float alpha = 0.90625;  // 29.0 / 32
62
+    const float   k1 = ps->k1;
63
+    const float   r0 = ps->r0,     r1 = ps->r1;
64
+    const float cor0 = ps->cor0, cor1 = ps->cor1;
65
+    const float var0 = ps->var0, var1 = ps->var1;
66
+    const float e0 = *coef - ps->x_est;
67
+    const float e1 = e0 - k1 * r0;
68
+
69
+    if (set)
70
+        *coef = e0;
88 71
 
89 72
     ps->cor1 = flt16_trunc(alpha * cor1 + r1 * e1);
90 73
     ps->var1 = flt16_trunc(alpha * var1 + 0.5f * (r1 * r1 + e1 * e1));
91 74
     ps->cor0 = flt16_trunc(alpha * cor0 + r0 * e0);
92 75
     ps->var0 = flt16_trunc(alpha * var0 + 0.5f * (r0 * r0 + e0 * e0));
76
+    ps->r1   = flt16_trunc(a * (r0 - k1 * e0));
77
+    ps->r0   = flt16_trunc(a * e0);
93 78
 
94
-    ps->r1 = flt16_trunc(a * (r0 - k1 * e0));
95
-    ps->r0 = flt16_trunc(a * e0);
79
+    /* Prediction for next frame */
80
+    ps->k1   = ps->var0 > 1 ? ps->cor0 * flt16_even(a / ps->var0) : 0;
81
+    k2       = ps->var1 > 1 ? ps->cor1 * flt16_even(a / ps->var1) : 0;
82
+    *rcoef   = ps->x_est = flt16_round(ps->k1*ps->r0 + k2*ps->r1);
96 83
 }
97 84
 
98 85
 static inline void reset_predict_state(PredictorState *ps)
99 86
 {
100
-    ps->r0   = 0.0f;
101
-    ps->r1   = 0.0f;
102
-    ps->cor0 = 0.0f;
103
-    ps->cor1 = 0.0f;
104
-    ps->var0 = 1.0f;
105
-    ps->var1 = 1.0f;
106
-    ps->k1   = 0.0f;
107
-    ps->x_est= 0.0f;
87
+    ps->r0    = 0.0f;
88
+    ps->r1    = 0.0f;
89
+    ps->k1    = 0.0f;
90
+    ps->cor0  = 0.0f;
91
+    ps->cor1  = 0.0f;
92
+    ps->var0  = 1.0f;
93
+    ps->var1  = 1.0f;
94
+    ps->x_est = 0.0f;
108 95
 }
109 96
 
110
-static inline void reset_all_predictors(SingleChannelElement *sce)
97
+static inline void reset_all_predictors(PredictorState *ps)
111 98
 {
112 99
     int i;
113 100
     for (i = 0; i < MAX_PREDICTORS; i++)
114
-        reset_predict_state(&sce->predictor_state[i]);
115
-    for (i = 1; i < 31; i++)
116
-        sce->ics.predictor_reset_count[i] = 0;
101
+        reset_predict_state(&ps[i]);
117 102
 }
118 103
 
119 104
 static inline void reset_predictor_group(SingleChannelElement *sce, int group_num)
120 105
 {
121 106
     int i;
122 107
     PredictorState *ps = sce->predictor_state;
123
-    sce->ics.predictor_reset_count[group_num] = 0;
124 108
     for (i = group_num - 1; i < MAX_PREDICTORS; i += 30)
125 109
         reset_predict_state(&ps[i]);
126 110
 }
... ...
@@ -128,136 +119,89 @@ static inline void reset_predictor_group(SingleChannelElement *sce, int group_nu
128 128
 void ff_aac_apply_main_pred(AACEncContext *s, SingleChannelElement *sce)
129 129
 {
130 130
     int sfb, k;
131
+    const int pmax = FFMIN(sce->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
131 132
 
132 133
     if (sce->ics.window_sequence[0] != EIGHT_SHORT_SEQUENCE) {
133
-        for (sfb = 0; sfb < ff_aac_pred_sfb_max[s->samplerate_index]; sfb++) {
134
-            for (k = sce->ics.swb_offset[sfb]; k < sce->ics.swb_offset[sfb + 1]; k++)
134
+        for (sfb = 0; sfb < pmax; sfb++) {
135
+            for (k = sce->ics.swb_offset[sfb]; k < sce->ics.swb_offset[sfb + 1]; k++) {
135 136
                 predict(&sce->predictor_state[k], &sce->coeffs[k], &sce->prcoeffs[k],
136
-                        (sce->ics.predictor_present && sce->ics.prediction_used[sfb]));
137
-        }
138
-    }
139
-}
140
-
141
-static void decode_joint_stereo(ChannelElement *cpe)
142
-{
143
-    int i, w, w2, g;
144
-    SingleChannelElement *sce0 = &cpe->ch[0];
145
-    SingleChannelElement *sce1 = &cpe->ch[1];
146
-    IndividualChannelStream *ics;
147
-
148
-    for (i = 0; i < MAX_PREDICTORS; i++)
149
-        sce0->prcoeffs[i] = sce0->predictor_state[i].x_est;
150
-
151
-    ics = &sce0->ics;
152
-    for (w = 0; w < ics->num_windows; w += ics->group_len[w]) {
153
-        for (w2 =  0; w2 < ics->group_len[w]; w2++) {
154
-            int start = (w+w2) * 128;
155
-            for (g = 0; g < ics->num_swb; g++) {
156
-                int sfb = w*16 + g;
157
-                //apply Intensity stereo coeffs transformation
158
-                if (cpe->is_mask[sfb]) {
159
-                    int p = -1 + 2 * (sce1->band_type[sfb] - 14);
160
-                    float rscale = ff_aac_pow2sf_tab[-sce1->sf_idx[sfb] + POW_SF2_ZERO];
161
-                    p *= 1 - 2 * cpe->ms_mask[sfb];
162
-                    for (i = 0; i < ics->swb_sizes[g]; i++) {
163
-                        sce0->pqcoeffs[start+i] = (sce0->prcoeffs[start+i] + p*sce0->pqcoeffs[start+i]) * rscale;
164
-                    }
165
-                } else if (cpe->ms_mask[sfb] &&
166
-                           sce0->band_type[sfb] < NOISE_BT &&
167
-                           sce1->band_type[sfb] < NOISE_BT) {
168
-                    for (i = 0; i < ics->swb_sizes[g]; i++) {
169
-                        float L = sce0->pqcoeffs[start+i] + sce1->pqcoeffs[start+i];
170
-                        float R = sce0->pqcoeffs[start+i] - sce1->pqcoeffs[start+i];
171
-                        sce0->pqcoeffs[start+i] = L;
172
-                        sce1->pqcoeffs[start+i] = R;
173
-                    }
174
-                }
175
-                start += ics->swb_sizes[g];
137
+                        sce->ics.predictor_present && sce->ics.prediction_used[sfb]);
176 138
             }
177 139
         }
140
+        if (sce->ics.predictor_reset_group) {
141
+            reset_predictor_group(sce, sce->ics.predictor_reset_group);
142
+        }
143
+    } else {
144
+        reset_all_predictors(sce->predictor_state);
178 145
     }
179 146
 }
180 147
 
181
-static inline void prepare_predictors(SingleChannelElement *sce)
182
-{
183
-    int k;
184
-    for (k = 0; k < MAX_PREDICTORS; k++)
185
-        predict(&sce->predictor_state[k], &sce->coeffs[k], &sce->prcoeffs[k], 0);
186
-}
187
-
188
-void ff_aac_update_main_pred(AACEncContext *s, SingleChannelElement *sce, ChannelElement *cpe)
189
-{
190
-    int k;
191
-
192
-    if (sce->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE)
193
-        return;
194
-
195
-    if (cpe && cpe->common_window)
196
-        decode_joint_stereo(cpe);
197
-
198
-    for (k = 0; k < MAX_PREDICTORS; k++)
199
-        update_predictor(&sce->predictor_state[k], sce->pqcoeffs[k]);
200
-
201
-    if (sce->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE) {
202
-        reset_all_predictors(sce);
203
-    }
204
-
205
-    if (sce->ics.predictor_reset_group)
206
-        reset_predictor_group(sce, sce->ics.predictor_reset_group);
207
-}
208
-
209
-/* If inc == 0 check if it returns 0 to see if you can reset freely */
148
+/* If inc = 0 you can check if this returns 0 to see if you can reset freely */
210 149
 static inline int update_counters(IndividualChannelStream *ics, int inc)
211 150
 {
212
-    int i, rg = 0;
151
+    int i;
213 152
     for (i = 1; i < 31; i++) {
214 153
         ics->predictor_reset_count[i] += inc;
215
-        if (!rg && ics->predictor_reset_count[i] > PRED_RESET_FRAME_MIN)
216
-            rg = i; /* Reset this immediately */
154
+        if (ics->predictor_reset_count[i] > PRED_RESET_FRAME_MIN)
155
+            return i; /* Reset this immediately */
217 156
     }
218
-    return rg;
157
+    return 0;
219 158
 }
220 159
 
221 160
 void ff_aac_adjust_common_prediction(AACEncContext *s, ChannelElement *cpe)
222 161
 {
223
-    int start, w, g, count = 0;
162
+    int start, w, w2, g, i, count = 0;
224 163
     SingleChannelElement *sce0 = &cpe->ch[0];
225 164
     SingleChannelElement *sce1 = &cpe->ch[1];
165
+    const int pmax0 = FFMIN(sce0->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
166
+    const int pmax1 = FFMIN(sce1->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
167
+    const int pmax  = FFMIN(pmax0, pmax1);
226 168
 
227
-    if (!cpe->common_window || sce0->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE)
169
+    if (!cpe->common_window ||
170
+        sce0->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE ||
171
+        sce1->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE)
228 172
         return;
229 173
 
230
-    /* Predict if IS or MS is on and at least one channel is marked or when both are */
231 174
     for (w = 0; w < sce0->ics.num_windows; w += sce0->ics.group_len[w]) {
232 175
         start = 0;
233 176
         for (g = 0; g < sce0->ics.num_swb; g++) {
234 177
             int sfb = w*16+g;
235
-            if (sfb < PRED_SFB_START || sfb > ff_aac_pred_sfb_max[s->samplerate_index]) {
236
-                ;
237
-            } else if ((cpe->is_mask[sfb] || cpe->ms_mask[sfb]) &&
238
-                (sce0->ics.prediction_used[sfb] || sce1->ics.prediction_used[sfb])) {
239
-                sce0->ics.prediction_used[sfb] = sce1->ics.prediction_used[sfb] = 1;
240
-                count++;
241
-            } else if (sce0->ics.prediction_used[sfb] && sce1->ics.prediction_used[sfb]) {
178
+            int sum = sce0->ics.prediction_used[sfb] + sce1->ics.prediction_used[sfb];
179
+            float ener0 = 0.0f, ener1 = 0.0f, ener01 = 0.0f;
180
+            struct AACISError ph_err1, ph_err2, *erf;
181
+            if (sfb < PRED_SFB_START || sfb > pmax || sum != 2) {
182
+                RESTORE_PRED(sce0, sfb);
183
+                RESTORE_PRED(sce1, sfb);
184
+                start += sce0->ics.swb_sizes[g];
185
+                continue;
186
+            }
187
+            for (w2 = 0; w2 < sce0->ics.group_len[w]; w2++) {
188
+                for (i = 0; i < sce0->ics.swb_sizes[g]; i++) {
189
+                    float coef0 = sce0->pcoeffs[start+(w+w2)*128+i];
190
+                    float coef1 = sce1->pcoeffs[start+(w+w2)*128+i];
191
+                    ener0 += coef0*coef0;
192
+                    ener1 += coef1*coef1;
193
+                    ener01 += (coef0 + coef1)*(coef0 + coef1);
194
+                }
195
+            }
196
+            ph_err1 = ff_aac_is_encoding_err(s, cpe, start, w, g,
197
+                                             ener0, ener1, ener01, -1);
198
+            ph_err2 = ff_aac_is_encoding_err(s, cpe, start, w, g,
199
+                                             ener0, ener1, ener01, +1);
200
+            erf = ph_err1.error < ph_err2.error ? &ph_err1 : &ph_err2;
201
+            if (erf->pass) {
202
+                sce0->ics.prediction_used[sfb] = 1;
203
+                sce1->ics.prediction_used[sfb] = 1;
242 204
                 count++;
243 205
             } else {
244
-                /* Restore band types, if changed - prediction never sets > RESERVED_BT */
245
-                if (sce0->ics.prediction_used[sfb] && sce0->band_type[sfb] < RESERVED_BT)
246
-                    sce0->band_type[sfb] = sce0->orig_band_type[sfb];
247
-                if (sce1->ics.prediction_used[sfb] && sce1->band_type[sfb] < RESERVED_BT)
248
-                    sce1->band_type[sfb] = sce1->orig_band_type[sfb];
249
-                sce0->ics.prediction_used[sfb] = sce1->ics.prediction_used[sfb] = 0;
206
+                RESTORE_PRED(sce0, sfb);
207
+                RESTORE_PRED(sce1, sfb);
250 208
             }
251 209
             start += sce0->ics.swb_sizes[g];
252 210
         }
253 211
     }
254 212
 
255 213
     sce1->ics.predictor_present = sce0->ics.predictor_present = !!count;
256
-
257
-    if (!count)
258
-        return;
259
-
260
-    sce1->ics.predictor_reset_group = sce0->ics.predictor_reset_group;
261 214
 }
262 215
 
263 216
 static void update_pred_resets(SingleChannelElement *sce)
... ...
@@ -266,14 +210,12 @@ static void update_pred_resets(SingleChannelElement *sce)
266 266
     float avg_frame = 0.0f;
267 267
     IndividualChannelStream *ics = &sce->ics;
268 268
 
269
-    /* Some other code probably chose the reset group */
270
-    if (ics->predictor_reset_group)
271
-        return;
272
-
269
+    /* Update the counters and immediately update any frame behind schedule */
273 270
     if ((ics->predictor_reset_group = update_counters(&sce->ics, 1)))
274 271
         return;
275 272
 
276 273
     for (i = 1; i < 31; i++) {
274
+        /* Count-based */
277 275
         if (ics->predictor_reset_count[i] > max_frame) {
278 276
             max_group_id_c = i;
279 277
             max_frame = ics->predictor_reset_count[i];
... ...
@@ -281,8 +223,7 @@ static void update_pred_resets(SingleChannelElement *sce)
281 281
         avg_frame = (ics->predictor_reset_count[i] + avg_frame)/2;
282 282
     }
283 283
 
284
-    if (avg_frame*2 > max_frame && max_frame > PRED_RESET_MIN ||
285
-        max_frame > (2*PRED_RESET_MIN)/3) {
284
+    if (max_frame > PRED_RESET_MIN) {
286 285
         ics->predictor_reset_group = max_group_id_c;
287 286
     } else {
288 287
         ics->predictor_reset_group = 0;
... ...
@@ -291,56 +232,91 @@ static void update_pred_resets(SingleChannelElement *sce)
291 291
 
292 292
 void ff_aac_search_for_pred(AACEncContext *s, SingleChannelElement *sce)
293 293
 {
294
-    int sfb, i, count = 0;
295
-    float *O34  = &s->scoefs[256*0], *P34  = &s->scoefs[256*1];
296
-    int cost_coeffs = PRICE_OFFSET;
297
-    int cost_pred = 1+(sce->ics.predictor_reset_group ? 5 : 0) +
298
-                  FFMIN(sce->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
294
+    int sfb, i, count = 0, cost_coeffs = 0, cost_pred = 0;
295
+    const int pmax = FFMIN(sce->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
296
+    float *O34  = &s->scoefs[128*0], *P34 = &s->scoefs[128*1];
297
+    float *SENT = &s->scoefs[128*2], *S34 = &s->scoefs[128*3];
298
+    float *QERR = &s->scoefs[128*4];
299 299
 
300
-    memcpy(sce->orig_band_type, sce->band_type, 128*sizeof(enum BandType));
300
+    if (sce->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE) {
301
+        sce->ics.predictor_present = 0;
302
+        return;
303
+    }
301 304
 
302
-    if (!sce->ics.predictor_initialized ||
303
-        sce->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE) {
304
-        reset_all_predictors(sce);
305
+    if (!sce->ics.predictor_initialized) {
306
+        reset_all_predictors(sce->predictor_state);
307
+        sce->ics.predictor_initialized = 1;
308
+        memcpy(sce->prcoeffs, sce->coeffs, 1024*sizeof(float));
305 309
         for (i = 1; i < 31; i++)
306 310
             sce->ics.predictor_reset_count[i] = i;
307
-        sce->ics.predictor_initialized = 1;
308 311
     }
309 312
 
310 313
     update_pred_resets(sce);
311
-    prepare_predictors(sce);
312
-    sce->ics.predictor_reset_group = 0;
313
-
314
-    for (sfb = PRED_SFB_START; sfb < ff_aac_pred_sfb_max[s->samplerate_index]; sfb++) {
315
-        float dist1 = 0.0f, dist2 = 0.0f;
316
-        int swb_start = sce->ics.swb_offset[sfb];
317
-        int swb_len = sce->ics.swb_offset[sfb + 1] - swb_start;
318
-        int cb1 = sce->band_type[sfb], cb2, bits1 = 0, bits2 = 0;
319
-        FFPsyBand *band = &s->psy.ch[s->cur_channel].psy_bands[sfb];
320
-        abs_pow34_v(O34, &sce->coeffs[swb_start], swb_len);
321
-        abs_pow34_v(P34, &sce->prcoeffs[swb_start], swb_len);
322
-        cb2 = find_min_book(find_max_val(1, swb_len, P34), sce->sf_idx[sfb]);
323
-        if (cb2 <= cb1) {
324
-            dist1 += quantize_band_cost(s, &sce->coeffs[swb_start],   O34, swb_len,
325
-                                        sce->sf_idx[sfb], cb1, s->lambda / band->threshold,
326
-                                        INFINITY, &bits1, 0);
327
-            dist2 += quantize_band_cost(s, &sce->prcoeffs[swb_start], P34, swb_len,
328
-                                        sce->sf_idx[sfb], cb2, s->lambda / band->threshold,
329
-                                        INFINITY, &bits2, 0);
330
-            if (dist2 <= dist1) {
331
-                sce->ics.prediction_used[sfb] = 1;
332
-                sce->band_type[sfb] = cb2;
333
-                count++;
334
-            }
335
-            cost_coeffs += bits1;
336
-            cost_pred   += bits2;
314
+    memcpy(sce->band_alt, sce->band_type, sizeof(sce->band_type));
315
+
316
+    for (sfb = PRED_SFB_START; sfb < pmax; sfb++) {
317
+        int cost1, cost2, cb_p;
318
+        float dist1, dist2, dist_spec_err = 0.0f;
319
+        const int cb_n = sce->band_type[sfb];
320
+        const int start_coef = sce->ics.swb_offset[sfb];
321
+        const int num_coeffs = sce->ics.swb_offset[sfb + 1] - start_coef;
322
+        const FFPsyBand *band = &s->psy.ch[s->cur_channel].psy_bands[sfb];
323
+
324
+        if (start_coef + num_coeffs > MAX_PREDICTORS)
325
+            continue;
326
+
327
+        /* Normal coefficients */
328
+        abs_pow34_v(O34, &sce->coeffs[start_coef], num_coeffs);
329
+        dist1 = quantize_and_encode_band_cost(s, NULL, &sce->coeffs[start_coef], NULL,
330
+                                              O34, num_coeffs, sce->sf_idx[sfb],
331
+                                              cb_n, s->lambda / band->threshold, INFINITY, &cost1, 0);
332
+        cost_coeffs += cost1;
333
+
334
+        /* Encoded coefficients - needed for #bits, band type and quant. error */
335
+        for (i = 0; i < num_coeffs; i++)
336
+            SENT[i] = sce->coeffs[start_coef + i] - sce->prcoeffs[start_coef + i];
337
+        abs_pow34_v(S34, SENT, num_coeffs);
338
+        if (cb_n < RESERVED_BT)
339
+            cb_p = find_min_book(find_max_val(1, num_coeffs, S34), sce->sf_idx[sfb]);
340
+        else
341
+            cb_p = cb_n;
342
+        quantize_and_encode_band_cost(s, NULL, SENT, QERR, S34, num_coeffs,
343
+                                      sce->sf_idx[sfb], cb_p, s->lambda / band->threshold, INFINITY,
344
+                                      &cost2, 0);
345
+
346
+        /* Reconstructed coefficients - needed for distortion measurements */
347
+        for (i = 0; i < num_coeffs; i++)
348
+            sce->prcoeffs[start_coef + i] += QERR[i] != 0.0f ? (sce->prcoeffs[start_coef + i] - QERR[i]) : 0.0f;
349
+        abs_pow34_v(P34, &sce->prcoeffs[start_coef], num_coeffs);
350
+        if (cb_n < RESERVED_BT)
351
+            cb_p = find_min_book(find_max_val(1, num_coeffs, P34), sce->sf_idx[sfb]);
352
+        else
353
+            cb_p = cb_n;
354
+        dist2 = quantize_and_encode_band_cost(s, NULL, &sce->prcoeffs[start_coef], NULL,
355
+                                              P34, num_coeffs, sce->sf_idx[sfb],
356
+                                              cb_p, s->lambda / band->threshold, INFINITY, NULL, 0);
357
+        for (i = 0; i < num_coeffs; i++)
358
+            dist_spec_err += (O34[i] - P34[i])*(O34[i] - P34[i]);
359
+        dist_spec_err *= s->lambda / band->threshold;
360
+        dist2 += dist_spec_err;
361
+
362
+        if (dist2 <= dist1 && cb_p <= cb_n) {
363
+            cost_pred += cost2;
364
+            sce->ics.prediction_used[sfb] = 1;
365
+            sce->band_alt[sfb]  = cb_n;
366
+            sce->band_type[sfb] = cb_p;
367
+            count++;
368
+        } else {
369
+            cost_pred += cost1;
370
+            sce->band_alt[sfb] = cb_p;
337 371
         }
338 372
     }
339 373
 
340
-    if (count && cost_pred > cost_coeffs) {
341
-        memset(sce->ics.prediction_used, 0, sizeof(sce->ics.prediction_used));
342
-        memcpy(sce->band_type, sce->orig_band_type, sizeof(sce->band_type));
374
+    if (count && cost_coeffs < cost_pred) {
343 375
         count = 0;
376
+        for (sfb = PRED_SFB_START; sfb < pmax; sfb++)
377
+            RESTORE_PRED(sce, sfb);
378
+        memset(&sce->ics.prediction_used, 0, sizeof(sce->ics.prediction_used));
344 379
     }
345 380
 
346 381
     sce->ics.predictor_present = !!count;
... ...
@@ -352,14 +328,15 @@ void ff_aac_search_for_pred(AACEncContext *s, SingleChannelElement *sce)
352 352
 void ff_aac_encode_main_pred(AACEncContext *s, SingleChannelElement *sce)
353 353
 {
354 354
     int sfb;
355
+    IndividualChannelStream *ics = &sce->ics;
356
+    const int pmax = FFMIN(ics->max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]);
355 357
 
356
-    if (!sce->ics.predictor_present ||
357
-        sce->ics.window_sequence[0] == EIGHT_SHORT_SEQUENCE)
358
+    if (!ics->predictor_present)
358 359
         return;
359 360
 
360
-    put_bits(&s->pb, 1, !!sce->ics.predictor_reset_group);
361
-    if (sce->ics.predictor_reset_group)
362
-        put_bits(&s->pb, 5, sce->ics.predictor_reset_group);
363
-    for (sfb = 0; sfb < FFMIN(sce->ics.max_sfb, ff_aac_pred_sfb_max[s->samplerate_index]); sfb++)
364
-        put_bits(&s->pb, 1, sce->ics.prediction_used[sfb]);
361
+    put_bits(&s->pb, 1, !!ics->predictor_reset_group);
362
+    if (ics->predictor_reset_group)
363
+        put_bits(&s->pb, 5, ics->predictor_reset_group);
364
+    for (sfb = 0; sfb < pmax; sfb++)
365
+        put_bits(&s->pb, 1, ics->prediction_used[sfb]);
365 366
 }
... ...
@@ -34,16 +34,12 @@
34 34
 #define PRED_RESET_FRAME_MIN 240
35 35
 
36 36
 /* Any frame with less than this amount of frames since last reset is ok */
37
-#define PRED_RESET_MIN 128
37
+#define PRED_RESET_MIN 64
38 38
 
39 39
 /* Raise to filter any low frequency artifacts due to prediction */
40 40
 #define PRED_SFB_START 10
41 41
 
42
-/* Offset for the number of bits to encode normal coefficients */
43
-#define PRICE_OFFSET 440
44
-
45 42
 void ff_aac_apply_main_pred(AACEncContext *s, SingleChannelElement *sce);
46
-void ff_aac_update_main_pred(AACEncContext *s, SingleChannelElement *sce, ChannelElement *cpe);
47 43
 void ff_aac_adjust_common_prediction(AACEncContext *s, ChannelElement *cpe);
48 44
 void ff_aac_search_for_pred(AACEncContext *s, SingleChannelElement *sce);
49 45
 void ff_aac_encode_main_pred(AACEncContext *s, SingleChannelElement *sce);