
avfilter: add nnedi filter

Port of the nnedi3 VapourSynth filter.

Signed-off-by: Paul B Mahol <onemda@gmail.com>

Paul B Mahol authored on 2016/01/24 01:15:53
... ...
@@ -63,6 +63,7 @@ version <next>:
63 63
 - Cineform HD decoder
64 64
 - new DCA decoder with full support for DTS-HD extensions
65 65
 - significant performance improvements in Windows Television (WTV) demuxer
66
+- nnedi deinterlacer
66 67
 
67 68
 
68 69
 version 2.8:
... ...
@@ -2873,6 +2873,7 @@ mpdecimate_filter_deps="gpl"
2873 2873
 mpdecimate_filter_select="pixelutils"
2874 2874
 mptestsrc_filter_deps="gpl"
2875 2875
 negate_filter_deps="lut_filter"
2876
+nnedi_filter_deps="gpl"
2876 2877
 ocr_filter_deps="libtesseract"
2877 2878
 ocv_filter_deps="libopencv"
2878 2879
 owdenoise_filter_deps="gpl"
... ...
@@ -8490,6 +8490,115 @@ Negate input video.
8490 8490
 It accepts an integer in input; if non-zero it negates the
8491 8491
 alpha component (if available). The default value in input is 0.
8492 8492
 
8493
+@section nnedi
8494
+
8495
+Deinterlace video using neural network edge directed interpolation.
8496
+
8497
+This filter accepts the following options:
8498
+
8499
+@table @option
8500
+@item weights
8501
+Mandatory option; without the binary weights file the filter cannot work
+(see the example below for a typical invocation).
8502
+Currently the file can be found here:
8503
+https://github.com/dubhater/vapoursynth-nnedi3/blob/master/src/nnedi3_weights.bin
8504
+
8505
+@item deint
8506
+Set which frames to deinterlace. The default is @code{all}.
8507
+Can be @code{all} or @code{interlaced}.
8508
+
8509
+@item field
8510
+Set mode of operation.
8511
+
8512
+Can be one of the following:
8513
+
8514
+@table @samp
8515
+@item af
8516
+Use frame flags, both fields.
8517
+@item a
8518
+Use frame flags, single field.
8519
+@item t
8520
+Use top field only.
8521
+@item b
8522
+Use bottom field only.
8523
+@item tf
8524
+Use both fields, top first.
8525
+@item bf
8526
+Use both fields, bottom first.
8527
+@end table
8528
+
8529
+@item planes
8530
+Set which planes to process. By default the filter processes all planes.
8531
+
8532
+@item nsize
8533
+Set size of local neighborhood around each pixel, used by the predictor neural
8534
+network.
8535
+
8536
+Can be one of the following:
8537
+
8538
+@table @samp
8539
+@item s8x6
8540
+@item s16x6
8541
+@item s32x6
8542
+@item s48x6
8543
+@item s8x4
8544
+@item s16x4
8545
+@item s32x4
8546
+@end table
8547
+
8548
+@item nns
8549
+Set the number of neurons in the predictor neural network.
8550
+Can be one of the following:
8551
+
8552
+@table @samp
8553
+@item n16
8554
+@item n32
8555
+@item n64
8556
+@item n128
8557
+@item n256
8558
+@end table
8559
+
8560
+@item qual
8561
+Controls the number of different neural network predictions that are blended
8562
+together to compute the final output value. Can be @code{fast} (default) or
8563
+@code{slow}.
8564
+
8565
+@item etype
8566
+Set which set of weights to use in the predictor.
8567
+Can be one of the following:
8568
+
8569
+@table @samp
8570
+@item a
8571
+weights trained to minimize absolute error
8572
+@item s
8573
+weights trained to minimize squared error
8574
+@end table
8575
+
8576
+@item pscrn
8577
+Controls whether or not the prescreener neural network is used to decide
8578
+which pixels should be processed by the predictor neural network and which
8579
+can be handled by simple cubic interpolation.
8580
+The prescreener is trained to know whether cubic interpolation will be
8581
+sufficient for a pixel or whether it should be predicted by the predictor
8581
+neural network. The computational complexity of the prescreener is much lower
8582
+than that of the predictor. Since most pixels can be handled by cubic interpolation,
8584
+using the prescreener generally results in much faster processing.
8585
+The prescreener is pretty accurate, so the difference between using it and not
8586
+using it is almost always unnoticeable.
8587
+
8588
+Can be one of the following:
8589
+
8590
+@table @samp
8591
+@item none
8592
+@item original
8593
+@item new
8594
+@end table
8595
+
8596
+Default is @code{new}.
8597
+
8598
+@item fapprox
8599
+Set which parts of the computation use faster, lower-precision int16
+approximations. This is a bitmask: bit 0 enables int16 dot products in the
+original prescreener, bit 1 enables them in the predictor. Default is @code{0}.
8600
+@end table
8601
+
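+For example, the following command (with illustrative file names, and assuming
+the weights file sits in the current directory) deinterlaces using the frame
+flags for both fields, doubling the frame rate:
+@example
+ffmpeg -i interlaced.mkv -vf nnedi=weights=nnedi3_weights.bin:field=af progressive.mkv
+@end example
+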
8493 8602
 @section noformat
8494 8603
 
8495 8604
 Force libavfilter not to use any of the specified pixel formats for the
... ...
@@ -187,6 +187,7 @@ OBJS-$(CONFIG_MCDEINT_FILTER)                += vf_mcdeint.o
187 187
 OBJS-$(CONFIG_MERGEPLANES_FILTER)            += vf_mergeplanes.o framesync.o
188 188
 OBJS-$(CONFIG_MPDECIMATE_FILTER)             += vf_mpdecimate.o
189 189
 OBJS-$(CONFIG_NEGATE_FILTER)                 += vf_lut.o
190
+OBJS-$(CONFIG_NNEDI_FILTER)                  += vf_nnedi.o
190 191
 OBJS-$(CONFIG_NOFORMAT_FILTER)               += vf_format.o
191 192
 OBJS-$(CONFIG_NOISE_FILTER)                  += vf_noise.o
192 193
 OBJS-$(CONFIG_NULL_FILTER)                   += vf_null.o
... ...
@@ -208,6 +208,7 @@ void avfilter_register_all(void)
208 208
     REGISTER_FILTER(MERGEPLANES,    mergeplanes,    vf);
209 209
     REGISTER_FILTER(MPDECIMATE,     mpdecimate,     vf);
210 210
     REGISTER_FILTER(NEGATE,         negate,         vf);
211
+    REGISTER_FILTER(NNEDI,          nnedi,          vf);
211 212
     REGISTER_FILTER(NOFORMAT,       noformat,       vf);
212 213
     REGISTER_FILTER(NOISE,          noise,          vf);
213 214
     REGISTER_FILTER(NULL,           null,           vf);
... ...
@@ -30,7 +30,7 @@
30 30
 #include "libavutil/version.h"
31 31
 
32 32
 #define LIBAVFILTER_VERSION_MAJOR   6
33
-#define LIBAVFILTER_VERSION_MINOR  27
33
+#define LIBAVFILTER_VERSION_MINOR  28
34 34
 #define LIBAVFILTER_VERSION_MICRO 100
35 35
 
36 36
 #define LIBAVFILTER_VERSION_INT AV_VERSION_INT(LIBAVFILTER_VERSION_MAJOR, \
37 37
new file mode 100644
... ...
@@ -0,0 +1,1211 @@
0
+/*
1
+ * Copyright (C) 2010-2011 Kevin Stone
2
+ * Copyright (C) 2016 Paul B Mahol
3
+ *
4
+ * This file is part of FFmpeg.
5
+ *
6
+ * FFmpeg is free software; you can redistribute it and/or modify
7
+ * it under the terms of the GNU General Public License as published by
8
+ * the Free Software Foundation; either version 2 of the License, or
9
+ * (at your option) any later version.
10
+ *
11
+ * FFmpeg is distributed in the hope that it will be useful,
12
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
+ * GNU General Public License for more details.
15
+ *
16
+ * You should have received a copy of the GNU General Public License along
17
+ * with FFmpeg; if not, write to the Free Software Foundation, Inc.,
18
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19
+ */
20
+
21
+#include <float.h>
22
+
23
+#include "libavutil/common.h"
24
+#include "libavutil/float_dsp.h"
25
+#include "libavutil/imgutils.h"
26
+#include "libavutil/opt.h"
27
+#include "libavutil/pixdesc.h"
28
+#include "avfilter.h"
29
+#include "formats.h"
30
+#include "internal.h"
31
+#include "video.h"
32
+
33
+typedef struct FrameData {
34
+    uint8_t *paddedp[3];
35
+    int padded_stride[3];
36
+    int padded_width[3];
37
+    int padded_height[3];
38
+
39
+    uint8_t *dstp[3];
40
+    int dst_stride[3];
41
+
42
+    int field[3];
43
+
44
+    int32_t *lcount[3];
45
+    float *input;
46
+    float *temp;
47
+} FrameData;
48
+
49
+typedef struct NNEDIContext {
50
+    const AVClass *class;
51
+
52
+    char *weights_file;
53
+
54
+    AVFrame *src;
55
+    AVFrame *second;
56
+    AVFrame *dst;
57
+    int eof;
58
+    int64_t cur_pts;
59
+
60
+    AVFloatDSPContext *fdsp;
61
+    int nb_planes;
62
+    int linesize[4];
63
+    int planeheight[4];
64
+
65
+    float *weights0;
66
+    float *weights1[2];
67
+    int asize;
68
+    int nns;
69
+    int xdia;
70
+    int ydia;
71
+
72
+    // Parameters
73
+    int deint;
74
+    int field;
75
+    int process_plane;
76
+    int nsize;
77
+    int nnsparam;
78
+    int qual;
79
+    int etype;
80
+    int pscrn;
81
+    int fapprox;
82
+
83
+    int max_value;
84
+
85
+    void (*copy_pad)(const AVFrame *, FrameData *, struct NNEDIContext *, int);
86
+    void (*evalfunc_0)(struct NNEDIContext *, FrameData *);
87
+    void (*evalfunc_1)(struct NNEDIContext *, FrameData *);
88
+
89
+    // Functions used in evalfunc_0
90
+    void (*readpixels)(const uint8_t *, const int, float *);
91
+    void (*compute_network0)(struct NNEDIContext *s, const float *, const float *, uint8_t *);
92
+    int32_t (*process_line0)(const uint8_t *, int, uint8_t *, const uint8_t *, const int, const int, const int);
93
+
94
+    // Functions used in evalfunc_1
95
+    void (*extract)(const uint8_t *, const int, const int, const int, float *, float *);
96
+    void (*dot_prod)(struct NNEDIContext *, const float *, const float *, float *, const int, const int, const float *);
97
+    void (*expfunc)(float *, const int);
98
+    void (*wae5)(const float *, const int, float *);
99
+
100
+    FrameData frame_data;
101
+} NNEDIContext;
102
+
103
+#define OFFSET(x) offsetof(NNEDIContext, x)
104
+#define FLAGS AV_OPT_FLAG_VIDEO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
105
+
106
+static const AVOption nnedi_options[] = {
107
+    {"weights",  "set weights file", OFFSET(weights_file),  AV_OPT_TYPE_STRING, {.str="nnedi3_weights.bin"}, 0, 0, FLAGS },
108
+    {"deint",         "set which frames to deinterlace", OFFSET(deint),         AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "deint" },
109
+        {"all",        "deinterlace all frames",                       0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "deint" },
110
+        {"interlaced", "only deinterlace frames marked as interlaced", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "deint" },
111
+    {"field",  "set mode of operation", OFFSET(field),         AV_OPT_TYPE_INT, {.i64=-1}, -2, 3, FLAGS, "field" },
112
+        {"af", "use frame flags, both fields",  0, AV_OPT_TYPE_CONST, {.i64=-2}, 0, 0, FLAGS, "field" },
113
+        {"a",  "use frame flags, single field", 0, AV_OPT_TYPE_CONST, {.i64=-1}, 0, 0, FLAGS, "field" },
114
+        {"t",  "use top field only",            0, AV_OPT_TYPE_CONST, {.i64=0},  0, 0, FLAGS, "field" },
115
+        {"b",  "use bottom field only",         0, AV_OPT_TYPE_CONST, {.i64=1},  0, 0, FLAGS, "field" },
116
+        {"tf", "use both fields, top first",    0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "field" },
117
+        {"bf", "use both fields, bottom first", 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "field" },
118
+    {"planes", "set which planes to process", OFFSET(process_plane), AV_OPT_TYPE_INT, {.i64=7}, 0, 7, FLAGS },
119
+    {"nsize",  "set size of local neighborhood around each pixel, used by the predictor neural network", OFFSET(nsize), AV_OPT_TYPE_INT, {.i64=6}, 0, 6, FLAGS, "nsize" },
120
+        {"s8x6",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nsize" },
121
+        {"s16x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nsize" },
122
+        {"s32x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nsize" },
123
+        {"s48x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nsize" },
124
+        {"s8x4",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nsize" },
125
+        {"s16x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=5}, 0, 0, FLAGS, "nsize" },
126
+        {"s32x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=6}, 0, 0, FLAGS, "nsize" },
127
+    {"nns",    "set number of neurons in predictor neural network", OFFSET(nnsparam), AV_OPT_TYPE_INT, {.i64=1}, 0, 4, FLAGS, "nns" },
128
+        {"n16",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nns" },
129
+        {"n32",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nns" },
130
+        {"n64",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nns" },
131
+        {"n128",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nns" },
132
+        {"n256",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nns" },
133
+    {"qual",  "set quality", OFFSET(qual), AV_OPT_TYPE_INT, {.i64=1}, 1, 2, FLAGS, "qual" },
134
+        {"fast", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "qual" },
135
+        {"slow", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "qual" },
136
+    {"etype", "set which set of weights to use in the predictor", OFFSET(etype), AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "etype" },
137
+        {"a",  "weights trained to minimize absolute error", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "etype" },
138
+        {"s",  "weights trained to minimize squared error",  0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "etype" },
139
+    {"pscrn", "set prescreening", OFFSET(pscrn), AV_OPT_TYPE_INT, {.i64=2}, 0, 2, FLAGS, "pscrn" },
140
+        {"none",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "pscrn" },
141
+        {"original",  NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "pscrn" },
142
+        {"new",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "pscrn" },
143
+    {"fapprox",       NULL, OFFSET(fapprox),       AV_OPT_TYPE_INT, {.i64=0}, 0, 3, FLAGS },
144
+    { NULL }
145
+};
146
+
147
+AVFILTER_DEFINE_CLASS(nnedi);
148
+
149
+static int config_input(AVFilterLink *inlink)
150
+{
151
+    AVFilterContext *ctx = inlink->dst;
152
+    NNEDIContext *s = ctx->priv;
153
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
154
+    int ret;
155
+
156
+    s->nb_planes = av_pix_fmt_count_planes(inlink->format);
157
+    if ((ret = av_image_fill_linesizes(s->linesize, inlink->format, inlink->w)) < 0)
158
+        return ret;
159
+
160
+    s->planeheight[1] = s->planeheight[2] = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
161
+    s->planeheight[0] = s->planeheight[3] = inlink->h;
162
+
163
+    return 0;
164
+}
165
+
166
+static int config_output(AVFilterLink *outlink)
167
+{
168
+    AVFilterContext *ctx = outlink->src;
169
+    NNEDIContext *s = ctx->priv;
170
+
171
+    outlink->time_base.num = ctx->inputs[0]->time_base.num;
172
+    outlink->time_base.den = ctx->inputs[0]->time_base.den * 2;
173
+    outlink->w             = ctx->inputs[0]->w;
174
+    outlink->h             = ctx->inputs[0]->h;
175
+
176
+    if (s->field > 1 || s->field == -2)
177
+        outlink->frame_rate = av_mul_q(ctx->inputs[0]->frame_rate,
178
+                                       (AVRational){2, 1});
179
+
180
+    return 0;
181
+}
182
+
183
+static int query_formats(AVFilterContext *ctx)
184
+{
185
+    static const enum AVPixelFormat pix_fmts[] = {
186
+        AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P,
187
+        AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P,
188
+        AV_PIX_FMT_YUV440P, AV_PIX_FMT_YUV444P,
189
+        AV_PIX_FMT_YUVJ444P, AV_PIX_FMT_YUVJ440P,
190
+        AV_PIX_FMT_YUVJ422P, AV_PIX_FMT_YUVJ420P,
191
+        AV_PIX_FMT_YUVJ411P,
192
+        AV_PIX_FMT_GBRP,
193
+        AV_PIX_FMT_GRAY8,
194
+        AV_PIX_FMT_NONE
195
+    };
196
+
197
+    AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
198
+    if (!fmts_list)
199
+        return AVERROR(ENOMEM);
200
+    return ff_set_common_formats(ctx, fmts_list);
201
+}
202
+
203
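+// Copy the source field used for prediction into a padded plane and mirror the
+// edges (32 pixels left/right, 6 lines top/bottom) so the neural network
+// windows never read outside the image.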
+static void copy_pad(const AVFrame *src, FrameData *frame_data, NNEDIContext *s, int fn)
204
+{
205
+    const int off = 1 - fn;
206
+    int plane, y, x;
207
+
208
+    for (plane = 0; plane < s->nb_planes; plane++) {
209
+        const uint8_t *srcp = (const uint8_t *)src->data[plane];
210
+        uint8_t *dstp = (uint8_t *)frame_data->paddedp[plane];
211
+
212
+        const int src_stride = src->linesize[plane];
213
+        const int dst_stride = frame_data->padded_stride[plane];
214
+
215
+        const int src_height = s->planeheight[plane];
216
+        const int dst_height = frame_data->padded_height[plane];
217
+
218
+        const int src_width = s->linesize[plane];
219
+        const int dst_width = frame_data->padded_width[plane];
220
+
221
+        int c = 4;
222
+
223
+        if (!(s->process_plane & (1 << plane)))
224
+            continue;
225
+
226
+        // Copy.
227
+        for (y = off; y < src_height; y += 2)
228
+            memcpy(dstp + 32 + (6 + y) * dst_stride,
229
+                   srcp + y * src_stride,
230
+                   src_width * sizeof(uint8_t));
231
+
232
+        // And pad.
233
+        dstp += (6 + off) * dst_stride;
234
+        for (y = 6 + off; y < dst_height - 6; y += 2) {
235
+            int c = 2;
236
+
237
+            for (x = 0; x < 32; x++)
238
+                dstp[x] = dstp[64 - x];
239
+
240
+            for (x = dst_width - 32; x < dst_width; x++, c += 2)
241
+                dstp[x] = dstp[x - c];
242
+
243
+            dstp += dst_stride * 2;
244
+        }
245
+
246
+        dstp = (uint8_t *)frame_data->paddedp[plane];
247
+        for (y = off; y < 6; y += 2)
248
+            memcpy(dstp + y * dst_stride,
249
+                   dstp + (12 + 2 * off - y) * dst_stride,
250
+                   dst_width * sizeof(uint8_t));
251
+
252
+        for (y = dst_height - 6 + off; y < dst_height; y += 2, c += 4)
253
+            memcpy(dstp + y * dst_stride,
254
+                   dstp + (y - c) * dst_stride,
255
+                   dst_width * sizeof(uint8_t));
256
+    }
257
+}
258
+
259
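+// Elliott activation: x / (1 + |x|), a cheap sigmoid-like nonlinearity.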
+static void elliott(float *data, const int n)
260
+{
261
+    int i;
262
+
263
+    for (i = 0; i < n; i++)
264
+        data[i] = data[i] / (1.0f + FFABS(data[i]));
265
+}
266
+
267
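+// Fully connected layer, float weights: vals[i] = dot(data, weights row i) * scale
+// plus the per-neuron bias stored after the n*len weights.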
+static void dot_prod(NNEDIContext *s, const float *data, const float *weights, float *vals, const int n, const int len, const float *scale)
268
+{
269
+    int i;
270
+
271
+    for (i = 0; i < n; i++) {
272
+        float sum;
273
+
274
+        sum = s->fdsp->scalarproduct_float(data, &weights[i * len], len);
275
+
276
+        vals[i] = sum * scale[0] + weights[n * len + i];
277
+    }
278
+}
279
+
280
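+// Fully connected layer, int16 weights: per-neuron scale and bias floats are
+// stored after the quantized weights, interleaved in groups of four neurons
+// (four scales followed by four biases per group).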
+static void dot_prods(NNEDIContext *s, const float *dataf, const float *weightsf, float *vals, const int n, const int len, const float *scale)
281
+{
282
+    const int16_t *data = (int16_t *)dataf;
283
+    const int16_t *weights = (int16_t *)weightsf;
284
+    const float *wf = (float *)&weights[n * len];
285
+    int i, j;
286
+
287
+    for (i = 0; i < n; i++) {
288
+        int sum = 0, off = ((i >> 2) << 3) + (i & 3);
289
+        for (j = 0; j < len; j++)
290
+            sum += data[j] * weights[i * len + j];
291
+
292
+        vals[i] = sum * wf[off] * scale[0] + wf[off + 4];
293
+    }
294
+}
295
+
296
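+// Original prescreener network: 48-pixel input, three tiny layers. Writes 1 to
+// d[0] when cubic interpolation is judged sufficient for the pixel, 0 when the
+// predictor network should handle it.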
+static void compute_network0(NNEDIContext *s, const float *input, const float *weights, uint8_t *d)
297
+{
298
+    float t, temp[12], scale = 1.0f;
299
+
300
+    dot_prod(s, input, weights, temp, 4, 48, &scale);
301
+    t = temp[0];
302
+    elliott(temp, 4);
303
+    temp[0] = t;
304
+    dot_prod(s, temp, weights + 4 * 49, temp + 4, 4, 4, &scale);
305
+    elliott(temp + 4, 4);
306
+    dot_prod(s, temp, weights + 4 * 49 + 4 * 5, temp + 8, 4, 8, &scale);
307
+    if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9]))
308
+        d[0] = 1;
309
+    else
310
+        d[0] = 0;
311
+}
312
+
313
+static void compute_network0_i16(NNEDIContext *s, const float *inputf, const float *weightsf, uint8_t *d)
314
+{
315
+    const float *wf = weightsf + 2 * 48;
316
+    float t, temp[12], scale = 1.0f;
317
+
318
+    dot_prods(s, inputf, weightsf, temp, 4, 48, &scale);
319
+    t = temp[0];
320
+    elliott(temp, 4);
321
+    temp[0] = t;
322
+    dot_prod(s, temp, wf + 8, temp + 4, 4, 4, &scale);
323
+    elliott(temp + 4, 4);
324
+    dot_prod(s, temp, wf + 8 + 4 * 5, temp + 8, 4, 8, &scale);
325
+    if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9]))
326
+        d[0] = 1;
327
+    else
328
+        d[0] = 0;
329
+}
330
+
331
+static void pixel2float48(const uint8_t *t8, const int pitch, float *p)
332
+{
333
+    const uint8_t *t = (const uint8_t *)t8;
334
+    int y, x;
335
+
336
+    for (y = 0; y < 4; y++)
337
+        for (x = 0; x < 12; x++)
338
+            p[y * 12 + x] = t[y * pitch * 2 + x];
339
+}
340
+
341
+static void byte2word48(const uint8_t *t, const int pitch, float *pf)
342
+{
343
+    int16_t *p = (int16_t *)pf;
344
+    int y, x;
345
+
346
+    for (y = 0; y < 4; y++)
347
+        for (x = 0; x < 12; x++)
348
+            p[y * 12 + x] = t[y * pitch * 2 + x];
349
+}
350
+
351
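+// Cubic interpolation ((19 * (b + c) - 3 * (a + d)) / 32) for pixels accepted by
+// the prescreener; pixels left to the predictor are marked with 255. Returns the
+// number of marked pixels.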
+static int32_t process_line0(const uint8_t *tempu, int width, uint8_t *dstp8, const uint8_t *src3p8, const int src_pitch, const int max_value, const int chroma)
352
+{
353
+    uint8_t *dstp = (uint8_t *)dstp8;
354
+    const uint8_t *src3p = (const uint8_t *)src3p8;
355
+    int minimum = 0;
356
+    int maximum = max_value - 1; // Technically the -1 is only needed for 8 and 16 bit input.
357
+    int count = 0, x;
358
+    for (x = 0; x < width; x++) {
359
+        if (tempu[x]) {
360
+            int tmp = 19 * (src3p[x + src_pitch * 2] + src3p[x + src_pitch * 4]) - 3 * (src3p[x] + src3p[x + src_pitch * 6]);
361
+            tmp /= 32;
362
+            dstp[x] = FFMAX(FFMIN(tmp, maximum), minimum);
363
+        } else {
364
+            memset(dstp + x, 255, sizeof(uint8_t));
365
+            count++;
366
+        }
367
+    }
368
+    return count;
369
+}
370
+
371
+// new prescreener functions
372
+static void byte2word64(const uint8_t *t, const int pitch, float *p)
373
+{
374
+    int16_t *ps = (int16_t *)p;
375
+    int y, x;
376
+
377
+    for (y = 0; y < 4; y++)
378
+        for (x = 0; x < 16; x++)
379
+            ps[y * 16 + x] = t[y * pitch * 2 + x];
380
+}
381
+
382
+static void compute_network0new(NNEDIContext *s, const float *datai, const float *weights, uint8_t *d)
383
+{
384
+    int16_t *data = (int16_t *)datai;
385
+    int16_t *ws = (int16_t *)weights;
386
+    float *wf = (float *)&ws[4 * 64];
387
+    float vals[8];
388
+    int mask, i, j;
389
+
390
+    for (i = 0; i < 4; i++) {
391
+        int sum = 0;
392
+        float t;
393
+
394
+        for (j = 0; j < 64; j++)
395
+            sum += data[j] * ws[(i << 3) + ((j >> 3) << 5) + (j & 7)];
396
+        t = sum * wf[i] + wf[4 + i];
397
+        vals[i] = t / (1.0f + FFABS(t));
398
+    }
399
+
400
+    for (i = 0; i < 4; i++) {
401
+        float sum = 0.0f;
402
+
403
+        for (j = 0; j < 4; j++)
404
+            sum += vals[j] * wf[8 + i + (j << 2)];
405
+        vals[4 + i] = sum + wf[8 + 16 + i];
406
+    }
407
+
408
+    mask = 0;
409
+    for (i = 0; i < 4; i++) {
410
+        if (vals[4 + i] > 0.0f)
411
+            mask |= (0x1 << (i << 3));
412
+    }
413
+
414
+    ((int *)d)[0] = mask;
415
+}
416
+
417
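+// First pass: copy the existing field to the output, then run the prescreener on
+// the missing lines; pixels it cannot interpolate cubically are marked with 255
+// and counted in lcount.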
+static void evalfunc_0(NNEDIContext *s, FrameData *frame_data)
418
+{
419
+    float *input = frame_data->input;
420
+    const float *weights0 = s->weights0;
421
+    float *temp = frame_data->temp;
422
+    uint8_t *tempu = (uint8_t *)temp;
423
+    int plane, x, y;
424
+
425
+    // And now the actual work.
426
+    for (plane = 0; plane < s->nb_planes; plane++) {
427
+        const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane];
428
+        const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t);
429
+
430
+        const int width = frame_data->padded_width[plane];
431
+        const int height = frame_data->padded_height[plane];
432
+
433
+        uint8_t *dstp = (uint8_t *)frame_data->dstp[plane];
434
+        const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t);
435
+        const uint8_t *src3p;
436
+        int ystart, ystop;
437
+        int32_t *lcount;
438
+
439
+        if (!(s->process_plane & (1 << plane)))
440
+            continue;
441
+
442
+        for (y = 1 - frame_data->field[plane]; y < height - 12; y += 2) {
443
+            memcpy(dstp + y * dst_stride,
444
+                   srcp + 32 + (6 + y) * src_stride,
445
+                   (width - 64) * sizeof(uint8_t));
446
+
447
+        }
448
+
449
+        ystart = 6 + frame_data->field[plane];
450
+        ystop = height - 6;
451
+        srcp += ystart * src_stride;
452
+        dstp += (ystart - 6) * dst_stride - 32;
453
+        src3p = srcp - src_stride * 3;
454
+        lcount = frame_data->lcount[plane] - 6;
455
+
456
+        if (s->pscrn == 1) { // original
457
+            for (y = ystart; y < ystop; y += 2) {
458
+                for (x = 32; x < width - 32; x++) {
459
+                    s->readpixels((const uint8_t *)(src3p + x - 5), src_stride, input);
460
+                    s->compute_network0(s, input, weights0, tempu+x);
461
+                }
462
+                lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane);
463
+                src3p += src_stride * 2;
464
+                dstp += dst_stride * 2;
465
+            }
466
+        } else if (s->pscrn > 1) { // new
467
+            for (y = ystart; y < ystop; y += 2) {
468
+                for (x = 32; x < width - 32; x += 4) {
469
+                    s->readpixels((const uint8_t *)(src3p + x - 6), src_stride, input);
470
+                    s->compute_network0(s, input, weights0, tempu + x);
471
+                }
472
+                lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane);
473
+                src3p += src_stride * 2;
474
+                dstp += dst_stride * 2;
475
+            }
476
+        } else { // no prescreening
477
+            for (y = ystart; y < ystop; y += 2) {
478
+                memset(dstp + 32, 255, (width - 64) * sizeof(uint8_t));
479
+                lcount[y] += width - 64;
480
+                dstp += dst_stride * 2;
481
+            }
482
+        }
483
+    }
484
+}
485
+
486
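+// Gather the xdia*ydia neighbourhood into input[] and store its mean, standard
+// deviation and reciprocal standard deviation in mstd[0..2]; mstd[3] is reset
+// and later accumulates the prediction.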
+static void extract_m8(const uint8_t *srcp8, const int stride, const int xdia, const int ydia, float *mstd, float *input)
487
+{
488
+    // uint8_t or uint16_t or float
489
+    const uint8_t *srcp = (const uint8_t *)srcp8;
490
+
491
+    // int32_t or int64_t or double
492
+    int64_t sum = 0, sumsq = 0;
493
+    int y, x;
494
+
495
+    for (y = 0; y < ydia; y++) {
496
+        const uint8_t *srcpT = srcp + y * stride * 2;
497
+
498
+        for (x = 0; x < xdia; x++) {
499
+            sum += srcpT[x];
500
+            sumsq += (uint32_t)srcpT[x] * (uint32_t)srcpT[x];
501
+            input[x] = srcpT[x];
502
+        }
503
+        input += xdia;
504
+    }
505
+    const float scale = 1.0f / (xdia * ydia);
506
+    mstd[0] = sum * scale;
507
+    const double tmp = (double)sumsq * scale - (double)mstd[0] * mstd[0];
508
+    mstd[3] = 0.0f;
509
+    if (tmp <= FLT_EPSILON)
510
+        mstd[1] = mstd[2] = 0.0f;
511
+    else {
512
+        mstd[1] = sqrt(tmp);
513
+        mstd[2] = 1.0f / mstd[1];
514
+    }
515
+}
516
+
517
+static void extract_m8_i16(const uint8_t *srcp, const int stride, const int xdia, const int ydia, float *mstd, float *inputf)
518
+{
519
+    int16_t *input = (int16_t *)inputf;
520
+    int sum = 0, sumsq = 0;
521
+    int y, x;
522
+
523
+    for (y = 0; y < ydia; y++) {
524
+        const uint8_t *srcpT = srcp + y * stride * 2;
525
+        for (x = 0; x < xdia; x++) {
526
+            sum += srcpT[x];
527
+            sumsq += srcpT[x] * srcpT[x];
528
+            input[x] = srcpT[x];
529
+        }
530
+        input += xdia;
531
+    }
532
+    const float scale = 1.0f / (float)(xdia * ydia);
533
+    mstd[0] = sum * scale;
534
+    mstd[1] = sumsq * scale - mstd[0] * mstd[0];
535
+    mstd[3] = 0.0f;
536
+    if (mstd[1] <= FLT_EPSILON)
537
+        mstd[1] = mstd[2] = 0.0f;
538
+    else {
539
+        mstd[1] = sqrt(mstd[1]);
540
+        mstd[2] = 1.0f / mstd[1];
541
+    }
542
+}
543
+
544
+
545
+static const float exp_lo = -80.0f;
546
+static const float exp_hi = +80.0f;
547
+
548
+static void e2_m16(float *s, const int n)
549
+{
550
+    int i;
551
+
552
+    for (i = 0; i < n; i++)
553
+        s[i] = exp(av_clipf(s[i], exp_lo, exp_hi));
554
+}
555
+
556
+static const float min_weight_sum = 1e-10f;
557
+
558
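+// Softmax-weighted average of the elliott-activated outputs, scaled by 5 and
+// denormalized with the neighbourhood mean/stddev, accumulated into mstd[3].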
+static void weighted_avg_elliott_mul5_m16(const float *w, const int n, float *mstd)
559
+{
560
+    float vsum = 0.0f, wsum = 0.0f;
561
+    int i;
562
+
563
+    for (i = 0; i < n; i++) {
564
+        vsum += w[i] * (w[n + i] / (1.0f + FFABS(w[n + i])));
565
+        wsum += w[i];
566
+    }
567
+    if (wsum > min_weight_sum)
568
+        mstd[3] += ((5.0f * vsum) / wsum) * mstd[1] + mstd[0];
569
+    else
570
+        mstd[3] += mstd[0];
571
+}
572
+
573
+
574
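+// Second pass: for every pixel still marked 255, run the predictor network(s) on
+// its neighbourhood and blend the qual predictions into the final pixel value.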
+static void evalfunc_1(NNEDIContext *s, FrameData *frame_data)
575
+{
576
+    float *input = frame_data->input;
577
+    float *temp = frame_data->temp;
578
+    float **weights1 = s->weights1;
579
+    const int qual = s->qual;
580
+    const int asize = s->asize;
581
+    const int nns = s->nns;
582
+    const int xdia = s->xdia;
583
+    const int xdiad2m1 = (xdia / 2) - 1;
584
+    const int ydia = s->ydia;
585
+    const float scale = 1.0f / (float)qual;
586
+    int plane, y, x, i;
587
+
588
+    for (plane = 0; plane < s->nb_planes; plane++) {
589
+        if (!(s->process_plane & (1 << plane)))
590
+            continue;
591
+
592
+        const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane];
593
+        const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t);
594
+
595
+        const int width = frame_data->padded_width[plane];
596
+        const int height = frame_data->padded_height[plane];
597
+
598
+        uint8_t *dstp = (uint8_t *)frame_data->dstp[plane];
599
+        const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t);
600
+
601
+        const int ystart = frame_data->field[plane];
602
+        const int ystop = height - 12;
603
+
604
+        srcp += (ystart + 6) * src_stride;
605
+        dstp += ystart * dst_stride - 32;
606
+        const uint8_t *srcpp = srcp - (ydia - 1) * src_stride - xdiad2m1;
607
+
608
+        for (y = ystart; y < ystop; y += 2) {
609
+            for (x = 32; x < width - 32; x++) {
610
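+                // Only pixels left marked with 255 by the first pass are predicted here.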
+                uint32_t pixel = 0;
611
+                memcpy(&pixel, dstp + x, sizeof(uint8_t));
612
+
613
+                uint32_t all_ones = 0;
614
+                memset(&all_ones, 255, sizeof(uint8_t));
615
+
616
+                if (pixel != all_ones)
617
+                    continue;
618
+
619
+                float mstd[4];
620
+                s->extract((const uint8_t *)(srcpp + x), src_stride, xdia, ydia, mstd, input);
621
+                for (i = 0; i < qual; i++) {
622
+                    s->dot_prod(s, input, weights1[i], temp, nns * 2, asize, mstd + 2);
623
+                    s->expfunc(temp, nns);
624
+                    s->wae5(temp, nns, mstd);
625
+                }
626
+
627
+                dstp[x] = FFMIN(FFMAX((int)(mstd[3] * scale + 0.5f), 0), s->max_value);
628
+            }
629
+            srcpp += src_stride * 2;
630
+            dstp += dst_stride * 2;
631
+        }
632
+    }
633
+}
634
+
635
+#define NUM_NSIZE 7
636
+#define NUM_NNS 5
637
+
638
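+// Round to the nearest integer and clamp to the int16 range.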
+static int roundds(const double f)
639
+{
640
+    if (f - floor(f) >= 0.5)
641
+        return FFMIN((int)ceil(f), 32767);
642
+    return FFMAX((int)floor(f), -32768);
643
+}
644
+
645
+static void select_functions(NNEDIContext *s)
646
+{
647
+    s->copy_pad = copy_pad;
648
+    s->evalfunc_0 = evalfunc_0;
649
+    s->evalfunc_1 = evalfunc_1;
650
+
651
+    // evalfunc_0
652
+    s->process_line0 = process_line0;
653
+
654
+    if (s->pscrn < 2) { // original prescreener
655
+        if (s->fapprox & 1) { // int16 dot products
656
+            s->readpixels = byte2word48;
657
+            s->compute_network0 = compute_network0_i16;
658
+        } else {
659
+            s->readpixels = pixel2float48;
660
+            s->compute_network0 = compute_network0;
661
+        }
662
+    } else { // new prescreener
663
+        // only int16 dot products
664
+        s->readpixels = byte2word64;
665
+        s->compute_network0 = compute_network0new;
666
+    }
667
+
668
+    // evalfunc_1
669
+    s->wae5 = weighted_avg_elliott_mul5_m16;
670
+
671
+    if (s->fapprox & 2) { // use int16 dot products
672
+        s->extract = extract_m8_i16;
673
+        s->dot_prod = dot_prods;
674
+    } else { // use float dot products
675
+        s->extract = extract_m8;
676
+        s->dot_prod = dot_prod;
677
+    }
678
+
679
+    s->expfunc = e2_m16;
680
+}
681
+
682
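+// Round m up to the next multiple of n.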
+static int modnpf(const int m, const int n)
683
+{
684
+    if ((m % n) == 0)
685
+        return m;
686
+    return m + n - (m % n);
687
+}
688
+
689
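+// Produce one output frame from s->src: decide which field to use, set up the
+// padded planes and per-plane scratch buffers, then run the two interpolation passes.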
+static int get_frame(AVFilterContext *ctx, int is_second)
690
+{
691
+    NNEDIContext *s = ctx->priv;
692
+    AVFilterLink *outlink = ctx->outputs[0];
693
+    AVFrame *src = s->src;
694
+    FrameData *frame_data;
695
+    int effective_field = s->field;
696
+    size_t temp_size;
697
+    int field_n;
698
+    int plane;
699
+
700
+    if (effective_field > 1)
701
+        effective_field -= 2;
702
+    else if (effective_field < 0)
703
+        effective_field += 2;
704
+
705
+    if (s->field < 0 && src->interlaced_frame && src->top_field_first == 0)
706
+        effective_field = 0;
707
+    else if (s->field < 0 && src->interlaced_frame && src->top_field_first == 1)
708
+        effective_field = 1;
709
+    else
710
+        effective_field = !effective_field;
711
+
712
+    if (s->field > 1 || s->field == -2) {
713
+        if (is_second) {
714
+            field_n = (effective_field == 0);
715
+        } else {
716
+            field_n = (effective_field == 1);
717
+        }
718
+    } else {
719
+        field_n = effective_field;
720
+    }
721
+
722
+    s->dst = ff_get_video_buffer(outlink, outlink->w, outlink->h);
723
+    if (!s->dst)
724
+        return AVERROR(ENOMEM);
725
+    av_frame_copy_props(s->dst, src);
726
+    s->dst->interlaced_frame = 0;
727
+
728
+    frame_data = &s->frame_data;
729
+
730
+    for (plane = 0; plane < s->nb_planes; plane++) {
731
+        int dst_height = s->planeheight[plane];
732
+        int dst_width = s->linesize[plane];
733
+
734
+        const int min_alignment = 16;
735
+        const int min_pad = 10;
736
+
737
+        if (!(s->process_plane & (1 << plane))) {
738
+            av_image_copy_plane(s->dst->data[plane], s->dst->linesize[plane],
739
+                                src->data[plane], src->linesize[plane],
740
+                                s->linesize[plane],
741
+                                s->planeheight[plane]);
742
+            continue;
743
+        }
744
+
745
+        frame_data->padded_width[plane]  = dst_width + 64;
746
+        frame_data->padded_height[plane] = dst_height + 12;
747
+        frame_data->padded_stride[plane] = modnpf(frame_data->padded_width[plane] + min_pad, min_alignment); // TODO: maybe min_pad is in pixels too?
748
+        if (!frame_data->paddedp[plane]) {
749
+            frame_data->paddedp[plane] = av_malloc_array(frame_data->padded_stride[plane], frame_data->padded_height[plane]);
750
+            if (!frame_data->paddedp[plane])
751
+                return AVERROR(ENOMEM);
752
+        }
753
+
754
+        frame_data->dstp[plane] = s->dst->data[plane];
755
+        frame_data->dst_stride[plane] = s->dst->linesize[plane];
756
+
757
+        if (!frame_data->lcount[plane]) {
758
+            frame_data->lcount[plane] = av_calloc(dst_height, sizeof(int32_t) * 16);
759
+            if (!frame_data->lcount[plane])
760
+                return AVERROR(ENOMEM);
761
+        } else {
762
+            memset(frame_data->lcount[plane], 0, dst_height * sizeof(int32_t) * 16);
763
+        }
764
+
765
+        frame_data->field[plane] = field_n;
766
+    }
767
+
768
+    if (!frame_data->input) {
769
+        frame_data->input = av_malloc(512 * sizeof(float));
770
+        if (!frame_data->input)
771
+            return AVERROR(ENOMEM);
772
+    }
773
+    // evalfunc_0 requires at least padded_width[0] bytes.
774
+    // evalfunc_1 requires at least 512 floats.
775
+    if (!frame_data->temp) {
776
+        temp_size = FFMAX(frame_data->padded_width[0], 512 * sizeof(float));
777
+        frame_data->temp = av_malloc(temp_size);
778
+        if (!frame_data->temp)
779
+            return AVERROR(ENOMEM);
780
+    }
781
+
782
+    // Copy src to a padded "frame" in frame_data and mirror the edges.
783
+    s->copy_pad(src, frame_data, s, field_n);
784
+
785
+    // Handles prescreening and the cubic interpolation.
786
+    s->evalfunc_0(s, frame_data);
787
+
788
+    // The rest.
789
+    s->evalfunc_1(s, frame_data);
790
+
791
+    return 0;
792
+}
793
+
794
+static int filter_frame(AVFilterLink *inlink, AVFrame *src)
795
+{
796
+    AVFilterContext *ctx = inlink->dst;
797
+    AVFilterLink *outlink = ctx->outputs[0];
798
+    NNEDIContext *s = ctx->priv;
799
+    int ret;
800
+
801
+    if ((s->field > 1 ||
802
+         s->field == -2) && !s->second) {
803
+        goto second;
804
+    } else if (s->field > 1 ||
805
+               s->field == -2) {
806
+        AVFrame *dst;
807
+
808
+        s->src = s->second;
809
+        ret = get_frame(ctx, 1);
810
+        if (ret < 0) {
811
+            av_frame_free(&s->dst);
812
+            av_frame_free(&s->src);
813
+            av_frame_free(&s->second);
814
+            return ret;
815
+        }
816
+        dst = s->dst;
817
+
818
+        if (src->pts != AV_NOPTS_VALUE &&
819
+            dst->pts != AV_NOPTS_VALUE)
820
+            dst->pts += src->pts;
821
+        else
822
+            dst->pts = AV_NOPTS_VALUE;
823
+
824
+        ret = ff_filter_frame(outlink, dst);
825
+        if (ret < 0)
826
+            return ret;
827
+        if (s->eof)
828
+            return 0;
829
+        s->cur_pts = s->second->pts;
830
+        av_frame_free(&s->second);
831
+second:
832
+        if ((s->deint && src->interlaced_frame &&
833
+             !ctx->is_disabled) ||
834
+            (!s->deint && !ctx->is_disabled)) {
835
+            s->second = src;
836
+        }
837
+    }
838
+
839
+    if ((s->deint && !src->interlaced_frame) || ctx->is_disabled) {
840
+        AVFrame *dst = av_frame_clone(src);
841
+        if (!dst) {
842
+            av_frame_free(&src);
843
+            av_frame_free(&s->second);
844
+            return AVERROR(ENOMEM);
845
+        }
846
+
847
+        if (s->field > 1 || s->field == -2) {
848
+            av_frame_free(&s->second);
849
+            if ((s->deint && src->interlaced_frame) ||
850
+                (!s->deint))
851
+                s->second = src;
852
+        } else {
853
+            av_frame_free(&src);
854
+        }
855
+        if (dst->pts != AV_NOPTS_VALUE)
856
+            dst->pts *= 2;
857
+        return ff_filter_frame(outlink, dst);
858
+    }
859
+
860
+    s->src = src;
861
+    ret = get_frame(ctx, 0);
862
+    if (ret < 0) {
863
+        av_frame_free(&s->dst);
864
+        av_frame_free(&s->src);
865
+        av_frame_free(&s->second);
866
+        return ret;
867
+    }
868
+
869
+    if (src->pts != AV_NOPTS_VALUE)
870
+        s->dst->pts = src->pts * 2;
871
+    if (s->field <= 1 && s->field > -2) {
872
+        av_frame_free(&src);
873
+        s->src = NULL;
874
+    }
875
+
876
+    return ff_filter_frame(outlink, s->dst);
877
+}
878
+
879
+static int request_frame(AVFilterLink *link)
880
+{
881
+    AVFilterContext *ctx = link->src;
882
+    NNEDIContext *s = ctx->priv;
883
+    int ret;
884
+
885
+    if (s->eof)
886
+        return AVERROR_EOF;
887
+
888
+    ret  = ff_request_frame(ctx->inputs[0]);
889
+
890
+    if (ret == AVERROR_EOF && s->second) {
891
+        AVFrame *next = av_frame_clone(s->second);
892
+
893
+        if (!next)
894
+            return AVERROR(ENOMEM);
895
+
896
+        next->pts = s->second->pts * 2 - s->cur_pts;
897
+        s->eof = 1;
898
+
899
+        filter_frame(ctx->inputs[0], next);
900
+    } else if (ret < 0) {
901
+        return ret;
902
+    }
903
+
904
+    return 0;
905
+}
906
+
907
+static av_cold int init(AVFilterContext *ctx)
908
+{
909
+    NNEDIContext *s = ctx->priv;
910
+    FILE *weights_file = NULL;
911
+    int64_t expected_size = 13574928;
912
+    int64_t weights_size;
913
+    float *bdata;
914
+    size_t bytes_read;
915
+    const int xdia_table[NUM_NSIZE] = { 8, 16, 32, 48, 8, 16, 32 };
916
+    const int ydia_table[NUM_NSIZE] = { 6, 6, 6, 6, 4, 4, 4 };
917
+    const int nns_table[NUM_NNS] = { 16, 32, 64, 128, 256 };
918
+    const int dims0 = 49 * 4 + 5 * 4 + 9 * 4;
919
+    const int dims0new = 4 * 65 + 4 * 5;
920
+    const int dims1 = nns_table[s->nnsparam] * 2 * (xdia_table[s->nsize] * ydia_table[s->nsize] + 1);
921
+    int dims1tsize = 0;
922
+    int dims1offset = 0;
923
+    int ret = 0, i, j, k;
924
+
925
+    weights_file = fopen(s->weights_file, "rb");
926
+    if (!weights_file) {
927
+        av_log(ctx, AV_LOG_ERROR, "No weights file provided, aborting!\n");
928
+        return AVERROR(EINVAL);
929
+    }
930
+
931
+    if (fseek(weights_file, 0, SEEK_END)) {
932
+        av_log(ctx, AV_LOG_ERROR, "Couldn't seek to the end of weights file.\n");
933
+        fclose(weights_file);
934
+        return AVERROR(EINVAL);
935
+    }
936
+
937
+    weights_size = ftell(weights_file);
938
+
939
+    if (weights_size == -1) {
940
+        fclose(weights_file);
941
+        av_log(ctx, AV_LOG_ERROR, "Couldn't get size of weights file.\n");
942
+        return AVERROR(EINVAL);
943
+    } else if (weights_size != expected_size) {
944
+        fclose(weights_file);
945
+        av_log(ctx, AV_LOG_ERROR, "Unexpected weights file size.\n");
946
+        return AVERROR(EINVAL);
947
+    }
948
+
949
+    if (fseek(weights_file, 0, SEEK_SET)) {
950
+        fclose(weights_file);
951
+        av_log(ctx, AV_LOG_ERROR, "Couldn't seek to the start of weights file.\n");
952
+        return AVERROR(EINVAL);
953
+    }
954
+
955
+    bdata = (float *)av_malloc(expected_size);
956
+    if (!bdata) {
957
+        fclose(weights_file);
958
+        return AVERROR(ENOMEM);
959
+    }
960
+
961
+    bytes_read = fread(bdata, 1, expected_size, weights_file);
962
+
963
+    if (bytes_read != (size_t)expected_size) {
964
+        fclose(weights_file);
965
+        ret = AVERROR_INVALIDDATA;
966
+        av_log(ctx, AV_LOG_ERROR, "Couldn't read weights file.\n");
967
+        goto fail;
968
+    }
969
+
970
+    fclose(weights_file);
971
+
972
+    for (j = 0; j < NUM_NNS; j++) {
973
+        for (i = 0; i < NUM_NSIZE; i++) {
974
+            if (i == s->nsize && j == s->nnsparam)
975
+                dims1offset = dims1tsize;
976
+            dims1tsize += nns_table[j] * 2 * (xdia_table[i] * ydia_table[i] + 1) * 2;
977
+        }
978
+    }
979
+
980
+    s->weights0 = av_malloc_array(FFMAX(dims0, dims0new), sizeof(float));
981
+    if (!s->weights0) {
982
+        ret = AVERROR(ENOMEM);
983
+        goto fail;
984
+    }
985
+
986
+    for (i = 0; i < 2; i++) {
987
+        s->weights1[i] = av_malloc_array(dims1, sizeof(float));
988
+        if (!s->weights1[i]) {
989
+            ret = AVERROR(ENOMEM);
990
+            goto fail;
991
+        }
992
+    }
993
+
994
+    // Adjust prescreener weights
995
+    if (s->pscrn >= 2) {// using new prescreener
996
+        const float *bdw;
997
+        int16_t *ws;
998
+        float *wf;
999
+        double mean[4] = { 0.0, 0.0, 0.0, 0.0 };
1000
+        int *offt = av_calloc(4 * 64, sizeof(int));
1001
+
1002
+        if (!offt) {
1003
+            ret = AVERROR(ENOMEM);
1004
+            goto fail;
1005
+        }
1006
+
1007
+        for (j = 0; j < 4; j++)
1008
+            for (k = 0; k < 64; k++)
1009
+                offt[j * 64 + k] = ((k >> 3) << 5) + ((j & 3) << 3) + (k & 7);
1010
+
1011
+        bdw = bdata + dims0 + dims0new * (s->pscrn - 2);
1012
+        ws = (int16_t *)s->weights0;
1013
+        wf = (float *)&ws[4 * 64];
1014
+        // Calculate mean weight of each first layer neuron
1015
+        for (j = 0; j < 4; j++) {
1016
+            double cmean = 0.0;
1017
+            for (k = 0; k < 64; k++)
1018
+                cmean += bdw[offt[j * 64 + k]];
1019
+            mean[j] = cmean / 64.0;
1020
+        }
1021
+        // Factor mean removal and 1.0/127.5 scaling
1022
+        // into first layer weights. scale to int16 range
1023
+        for (j = 0; j < 4; j++) {
1024
+            double scale, mval = 0.0;
1025
+
1026
+            for (k = 0; k < 64; k++)
1027
+                mval = FFMAX(mval, FFABS((bdw[offt[j * 64 + k]] - mean[j]) / 127.5));
1028
+            scale = 32767.0 / mval;
1029
+            for (k = 0; k < 64; k++)
1030
+                ws[offt[j * 64 + k]] = roundds(((bdw[offt[j * 64 + k]] - mean[j]) / 127.5) * scale);
1031
+            wf[j] = (float)(mval / 32767.0);
1032
+        }
1033
+        memcpy(wf + 4, bdw + 4 * 64, (dims0new - 4 * 64) * sizeof(float));
1034
+        av_free(offt);
1035
+    } else { // using old prescreener
1036
+        double mean[4] = { 0.0, 0.0, 0.0, 0.0 };
1037
+        // Calculate mean weight of each first layer neuron
1038
+        for (j = 0; j < 4; j++) {
1039
+            double cmean = 0.0;
1040
+            for (k = 0; k < 48; k++)
1041
+                cmean += bdata[j * 48 + k];
1042
+            mean[j] = cmean / 48.0;
1043
+        }
1044
+        if (s->fapprox & 1) {// use int16 dot products in first layer
1045
+            int16_t *ws = (int16_t *)s->weights0;
1046
+            float *wf = (float *)&ws[4 * 48];
1047
+            // Factor mean removal and 1.0/127.5 scaling
1048
+            // into first layer weights. scale to int16 range
1049
+            for (j = 0; j < 4; j++) {
1050
+                double mval = 0.0;
1051
+                for (k = 0; k < 48; k++)
1052
+                    mval = FFMAX(mval, FFABS((bdata[j * 48 + k] - mean[j]) / 127.5));
1053
+                const double scale = 32767.0 / mval;
1054
+                for (k = 0; k < 48; k++)
1055
+                    ws[j * 48 + k] = roundds(((bdata[j * 48 + k] - mean[j]) / 127.5) * scale);
1056
+                wf[j] = (float)(mval / 32767.0);
1057
+            }
1058
+            memcpy(wf + 4, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float));
1059
+        } else {// use float dot products in first layer
1060
+            double half = (1 << 8) - 1;
1061
+
1062
+            half /= 2;
1063
+
1064
+            // Factor mean removal and 1.0/half scaling
1065
+            // into first layer weights.
1066
+            for (j = 0; j < 4; j++)
1067
+                for (k = 0; k < 48; k++)
1068
+                    s->weights0[j * 48 + k] = (float)((bdata[j * 48 + k] - mean[j]) / half);
1069
+            memcpy(s->weights0 + 4 * 48, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float));
1070
+        }
1071
+    }
1072
+
1073
+    // Adjust prediction weights
1074
+    for (i = 0; i < 2; i++) {
1075
+        const float *bdataT = bdata + dims0 + dims0new * 3 + dims1tsize * s->etype + dims1offset + i * dims1;
1076
+        const int nnst = nns_table[s->nnsparam];
1077
+        const int asize = xdia_table[s->nsize] * ydia_table[s->nsize];
1078
+        const int boff = nnst * 2 * asize;
1079
+        double *mean = (double *)av_calloc(asize + 1 + nnst * 2, sizeof(double));
1080
+
1081
+        if (!mean) {
1082
+            ret = AVERROR(ENOMEM);
1083
+            goto fail;
1084
+        }
1085
+
1086
+        // Calculate mean weight of each neuron (ignore bias)
1087
+        for (j = 0; j < nnst * 2; j++) {
1088
+            double cmean = 0.0;
1089
+            for (k = 0; k < asize; k++)
1090
+                cmean += bdataT[j * asize + k];
1091
+            mean[asize + 1 + j] = cmean / (double)asize;
1092
+        }
1093
+        // Calculate mean softmax neuron
1094
+        for (j = 0; j < nnst; j++) {
1095
+            for (k = 0; k < asize; k++)
1096
+                mean[k] += bdataT[j * asize + k] - mean[asize + 1 + j];
1097
+            mean[asize] += bdataT[boff + j];
1098
+        }
1099
+        for (j = 0; j < asize + 1; j++)
1100
+            mean[j] /= (double)(nnst);
1101
+
1102
+        if (s->fapprox & 2) { // use int16 dot products
1103
+            int16_t *ws = (int16_t *)s->weights1[i];
1104
+            float *wf = (float *)&ws[nnst * 2 * asize];
1105
+            // Factor mean removal into weights, remove global offset from
1106
+            // softmax neurons, and scale weights to int16 range.
1107
+            for (j = 0; j < nnst; j++) { // softmax neurons
1108
+                double scale, mval = 0.0;
1109
+                for (k = 0; k < asize; k++)
1110
+                    mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k]));
1111
+                scale = 32767.0 / mval;
1112
+                for (k = 0; k < asize; k++)
1113
+                    ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k]) * scale);
1114
+                wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0);
1115
+                wf[(j >> 2) * 8 + (j & 3) + 4] = (float)(bdataT[boff + j] - mean[asize]);
1116
+            }
1117
+            for (j = nnst; j < nnst * 2; j++) { // elliott neurons
1118
+                double scale, mval = 0.0;
1119
+                for (k = 0; k < asize; k++)
1120
+                    mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j]));
1121
+                scale = 32767.0 / mval;
1122
+                for (k = 0; k < asize; k++)
1123
+                    ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j]) * scale);
1124
+                wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0);
1125
+                wf[(j >> 2) * 8 + (j & 3) + 4] = bdataT[boff + j];
1126
+            }
1127
+        } else { // use float dot products
1128
+            // Factor mean removal into weights, and remove global
1129
+            // offset from softmax neurons.
1130
+            for (j = 0; j < nnst * 2; j++) {
1131
+                for (k = 0; k < asize; k++) {
1132
+                    const double q = j < nnst ? mean[k] : 0.0;
1133
+                    s->weights1[i][j * asize + k] = (float)(bdataT[j * asize + k] - mean[asize + 1 + j] - q);
1134
+                }
1135
+                s->weights1[i][boff + j] = (float)(bdataT[boff + j] - (j < nnst ? mean[asize] : 0.0));
1136
+            }
1137
+        }
1138
+        av_free(mean);
1139
+    }
1140
+
1141
+    s->nns = nns_table[s->nnsparam];
1142
+    s->xdia = xdia_table[s->nsize];
1143
+    s->ydia = ydia_table[s->nsize];
1144
+    s->asize = xdia_table[s->nsize] * ydia_table[s->nsize];
1145
+
1146
+    s->max_value = 65535 >> 8;
1147
+
1148
+    select_functions(s);
1149
+
1150
+    s->fdsp = avpriv_float_dsp_alloc(0);
1151
+    if (!s->fdsp)
1152
+        ret = AVERROR(ENOMEM);
1153
+
1154
+fail:
1155
+    av_free(bdata);
1156
+    return ret;
1157
+}
1158
+
1159
+static av_cold void uninit(AVFilterContext *ctx)
1160
+{
1161
+    NNEDIContext *s = ctx->priv;
1162
+    int i;
1163
+
1164
+    av_freep(&s->weights0);
1165
+
1166
+    for (i = 0; i < 2; i++)
1167
+        av_freep(&s->weights1[i]);
1168
+
1169
+    for (i = 0; i < s->nb_planes; i++) {
1170
+        av_freep(&s->frame_data.paddedp[i]);
1171
+        av_freep(&s->frame_data.lcount[i]);
1172
+    }
1173
+
1174
+    av_freep(&s->frame_data.input);
1175
+    av_freep(&s->frame_data.temp);
1176
+    av_frame_free(&s->second);
1177
+}
1178
+
1179
+static const AVFilterPad inputs[] = {
1180
+    {
1181
+        .name          = "default",
1182
+        .type          = AVMEDIA_TYPE_VIDEO,
1183
+        .filter_frame  = filter_frame,
1184
+        .config_props  = config_input,
1185
+    },
1186
+    { NULL }
1187
+};
1188
+
1189
+static const AVFilterPad outputs[] = {
1190
+    {
1191
+        .name          = "default",
1192
+        .type          = AVMEDIA_TYPE_VIDEO,
1193
+        .config_props  = config_output,
1194
+        .request_frame = request_frame,
1195
+    },
1196
+    { NULL }
1197
+};
1198
+
1199
+AVFilter ff_vf_nnedi = {
1200
+    .name          = "nnedi",
1201
+    .description   = NULL_IF_CONFIG_SMALL("Apply neural network edge directed interpolation intra-only deinterlacer."),
1202
+    .priv_size     = sizeof(NNEDIContext),
1203
+    .priv_class    = &nnedi_class,
1204
+    .init          = init,
1205
+    .uninit        = uninit,
1206
+    .query_formats = query_formats,
1207
+    .inputs        = inputs,
1208
+    .outputs       = outputs,
1209
+    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL,
1210
+};