1 #include "All.h"
2 #include "GlobalFunctions.h"
3 #include "NNFilter.h"
4 #include "Assembly.h"
5
CNNFilter(int nOrder,int nShift,int nVersion)6 CNNFilter::CNNFilter(int nOrder, int nShift, int nVersion)
7 {
8 if ((nOrder <= 0) || ((nOrder % 16) != 0)) throw(1);
9 m_nOrder = nOrder;
10 m_nShift = nShift;
11 m_nVersion = nVersion;
12
13 m_bMMXAvailable = GetMMXAvailable();
14
15 m_rbInput.Create(NN_WINDOW_ELEMENTS, m_nOrder);
16 m_rbDeltaM.Create(NN_WINDOW_ELEMENTS, m_nOrder);
17 m_paryM = new short [m_nOrder];
18
19 #ifdef NN_TEST_MMX
20 srand(GetTickCount());
21 #endif
22 }
23
~CNNFilter()24 CNNFilter::~CNNFilter()
25 {
26 SAFE_ARRAY_DELETE(m_paryM)
27 }
28
Flush()29 void CNNFilter::Flush()
30 {
31 memset(&m_paryM[0], 0, m_nOrder * sizeof(short));
32 m_rbInput.Flush();
33 m_rbDeltaM.Flush();
34 m_nRunningAverage = 0;
35 }
36
Compress(int nInput)37 int CNNFilter::Compress(int nInput)
38 {
39 // convert the input to a short and store it
40 m_rbInput[0] = GetSaturatedShortFromInt(nInput);
41
42 // figure a dot product
43 int nDotProduct;
44 if (m_bMMXAvailable)
45 nDotProduct = CalculateDotProduct(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
46 else
47 nDotProduct = CalculateDotProductNoMMX(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
48
49 // calculate the output
50 int nOutput = nInput - ((nDotProduct + (1 << (m_nShift - 1))) >> m_nShift);
51
52 // adapt
53 if (m_bMMXAvailable)
54 Adapt(&m_paryM[0], &m_rbDeltaM[-m_nOrder], -nOutput, m_nOrder);
55 else
56 AdaptNoMMX(&m_paryM[0], &m_rbDeltaM[-m_nOrder], nOutput, m_nOrder);
57
58 int nTempABS = abs(nInput);
59
60 if (nTempABS > (m_nRunningAverage * 3))
61 m_rbDeltaM[0] = ((nInput >> 25) & 64) - 32;
62 else if (nTempABS > (m_nRunningAverage * 4) / 3)
63 m_rbDeltaM[0] = ((nInput >> 26) & 32) - 16;
64 else if (nTempABS > 0)
65 m_rbDeltaM[0] = ((nInput >> 27) & 16) - 8;
66 else
67 m_rbDeltaM[0] = 0;
68
69 m_nRunningAverage += (nTempABS - m_nRunningAverage) / 16;
70
71 m_rbDeltaM[-1] >>= 1;
72 m_rbDeltaM[-2] >>= 1;
73 m_rbDeltaM[-8] >>= 1;
74
75 // increment and roll if necessary
76 m_rbInput.IncrementSafe();
77 m_rbDeltaM.IncrementSafe();
78
79 return nOutput;
80 }
81
Decompress(int nInput)82 int CNNFilter::Decompress(int nInput)
83 {
84 // figure a dot product
85 int nDotProduct;
86
87 if (m_bMMXAvailable)
88 nDotProduct = CalculateDotProduct(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
89 else
90 nDotProduct = CalculateDotProductNoMMX(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
91
92 // adapt
93 if (m_bMMXAvailable)
94 Adapt(&m_paryM[0], &m_rbDeltaM[-m_nOrder], -nInput, m_nOrder);
95 else
96 AdaptNoMMX(&m_paryM[0], &m_rbDeltaM[-m_nOrder], nInput, m_nOrder);
97
98 // store the output value
99 int nOutput = nInput + ((nDotProduct + (1 << (m_nShift - 1))) >> m_nShift);
100
101 // update the input buffer
102 m_rbInput[0] = GetSaturatedShortFromInt(nOutput);
103
104 if (m_nVersion >= 3980)
105 {
106 int nTempABS = abs(nOutput);
107
108 if (nTempABS > (m_nRunningAverage * 3))
109 m_rbDeltaM[0] = ((nOutput >> 25) & 64) - 32;
110 else if (nTempABS > (m_nRunningAverage * 4) / 3)
111 m_rbDeltaM[0] = ((nOutput >> 26) & 32) - 16;
112 else if (nTempABS > 0)
113 m_rbDeltaM[0] = ((nOutput >> 27) & 16) - 8;
114 else
115 m_rbDeltaM[0] = 0;
116
117 m_nRunningAverage += (nTempABS - m_nRunningAverage) / 16;
118
119 m_rbDeltaM[-1] >>= 1;
120 m_rbDeltaM[-2] >>= 1;
121 m_rbDeltaM[-8] >>= 1;
122 }
123 else
124 {
125 m_rbDeltaM[0] = (nOutput == 0) ? 0 : ((nOutput >> 28) & 8) - 4;
126 m_rbDeltaM[-4] >>= 1;
127 m_rbDeltaM[-8] >>= 1;
128 }
129
130 // increment and roll if necessary
131 m_rbInput.IncrementSafe();
132 m_rbDeltaM.IncrementSafe();
133
134 return nOutput;
135 }
136
AdaptNoMMX(short * pM,short * pAdapt,int nDirection,int nOrder)137 void CNNFilter::AdaptNoMMX(short * pM, short * pAdapt, int nDirection, int nOrder)
138 {
139 nOrder >>= 4;
140
141 if (nDirection < 0)
142 {
143 while (nOrder--)
144 {
145 EXPAND_16_TIMES(*pM++ += *pAdapt++;)
146 }
147 }
148 else if (nDirection > 0)
149 {
150 while (nOrder--)
151 {
152 EXPAND_16_TIMES(*pM++ -= *pAdapt++;)
153 }
154 }
155 }
156
CalculateDotProductNoMMX(short * pA,short * pB,int nOrder)157 int CNNFilter::CalculateDotProductNoMMX(short * pA, short * pB, int nOrder)
158 {
159 int nDotProduct = 0;
160 nOrder >>= 4;
161
162 while (nOrder--)
163 {
164 EXPAND_16_TIMES(nDotProduct += *pA++ * *pB++;)
165 }
166
167 return nDotProduct;
168 }
169