1 /*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18 #include "RenderScript.h"
19 #include "rsCppInternal.h"
20
21 #define NELEM(m) (sizeof(m) / sizeof((m)[0]))
22
23 using android::RSC::Allocation;
24 using android::RSC::Element;
25 using android::RSC::RS;
26 using android::RSC::RS_ERROR_INVALID_ELEMENT;
27 using android::RSC::RS_ERROR_INVALID_PARAMETER;
28 using android::RSC::RS_SUCCESS;
29 using android::RSC::ScriptIntrinsicBLAS;
30 using android::RSC::sp;
31
32 // ScriptIntrinsicBLAS APIS
ScriptIntrinsicBLAS(sp<RS> rs,sp<const Element> e)33 ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
34 : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
35
36 }
37
// Factory: builds a BLAS intrinsic bound to `rs`, using the U32 element as the
// element handed to the constructor.
sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) {
    const sp<const Element> scriptElement = Element::U32(rs);
    return new ScriptIntrinsicBLAS(rs, scriptElement);
}
41
// Tells setUpBLASCall which pair of alpha/beta slots to populate.
enum RsBlasDataType {
    SINGLE = 0,          // float scalars
    DOUBLE = 1,          // double scalars
    SINGLE_COMPLEX = 2,  // float real/imaginary pairs
    DOUBLE_COMPLEX = 3   // double real/imaginary pairs
};
48
49 static RsBlasCall
setUpBLASCall(RsBlasDataType dataType,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,int incX,int incY,int KL,int KU,float alphaF,float betaF,double alphaD,double betaD,float alphaCX,float alphaCY,float betaCX,float betaCY,double alphaZX,double alphaZY,double betaZX,double betaZY)50 setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
51 int TransA, int TransB, int Side, int Uplo, int Diag,
52 int M, int N, int K, int incX, int incY, int KL, int KU,
53 float alphaF, float betaF, double alphaD, double betaD,
54 float alphaCX, float alphaCY, float betaCX, float betaCY,
55 double alphaZX, double alphaZY, double betaZX, double betaZY
56 ) {
57 RsBlasCall call;
58 memset(&call, 0, sizeof(call));
59 call.func = func;
60 call.transA = (RsBlasTranspose)TransA;
61 call.transB = (RsBlasTranspose)TransB;
62 call.side = (RsBlasSide)Side;
63 call.uplo = (RsBlasUplo)Uplo;
64 call.diag = (RsBlasDiag)Diag;
65 call.M = M;
66 call.N = N;
67 call.K = K;
68
69 switch (dataType) {
70 case SINGLE:
71 // For Single-precision BLAS.
72 call.alpha.f = alphaF;
73 call.beta.f = betaF;
74 break;
75 case DOUBLE:
76 // For Double-precision BLAS.
77 call.alpha.d = alphaD;
78 call.beta.d = betaD;
79 break;
80 case SINGLE_COMPLEX:
81 // For Single-precision complex BLAS.
82 call.alpha.c.r = alphaCX;
83 call.alpha.c.i = alphaCY;
84 call.beta.c.r = betaCX;
85 call.beta.c.i = betaCY;
86 break;
87 case DOUBLE_COMPLEX:
88 // For Double-precision complex BLAS.
89 call.alpha.z.r = alphaZX;
90 call.alpha.z.i = alphaZY;
91 call.beta.z.r = betaZX;
92 call.beta.z.i = betaZY;
93 break;
94 default:
95 break;
96 }
97
98 call.incX = incX;
99 call.incY = incY;
100 call.KL = KL;
101 call.KU = KU;
102
103 return call;
104 }
105
106 static void
nScriptIntrinsicBLAS_Single(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alpha,RsAllocation A,RsAllocation B,float beta,RsAllocation C,int incX,int incY,int KL,int KU)107 nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
108 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
109 float alpha, RsAllocation A, RsAllocation B,
110 float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
111 RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
112 M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
113 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
114 RsAllocation in_allocs[3] = {A, B, C};
115 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
116 &call, sizeof(call), nullptr, 0));
117 }
118
119
120 static void
nScriptIntrinsicBLAS_Double(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alpha,RsAllocation A,RsAllocation B,double beta,RsAllocation C,int incX,int incY,int KL,int KU)121 nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
122 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
123 double alpha, RsAllocation A, RsAllocation B,
124 double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
125 RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
126 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
127 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
128 RsAllocation in_allocs[3] = {A, B, C};
129 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
130 &call, sizeof(call), nullptr, 0));
131 }
132
133 static void
nScriptIntrinsicBLAS_Complex(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alphaX,float alphaY,RsAllocation A,RsAllocation B,float betaX,float betaY,RsAllocation C,int incX,int incY,int KL,int KU)134 nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
135 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
136 float alphaX, float alphaY, RsAllocation A, RsAllocation B,
137 float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
138 RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
139 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
140 alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
141 RsAllocation in_allocs[3] = {A, B, C};
142 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
143 &call, sizeof(call), nullptr, 0));
144 }
145
146 static void
nScriptIntrinsicBLAS_Z(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alphaX,double alphaY,RsAllocation A,RsAllocation B,double betaX,double betaY,RsAllocation C,int incX,int incY,int KL,int KU)147 nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
148 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
149 double alphaX, double alphaY, RsAllocation A, RsAllocation B,
150 double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
151 RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
152 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
153 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
154 RsAllocation in_allocs[3] = {A, B, C};
155 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
156 &call, sizeof(call), nullptr, 0));
157 }
158
159
160 static void
nScriptIntrinsicBLAS_BNNM(RS * mRS,RsContext con,RsScript id,int M,int N,int K,RsAllocation A,int a_offset,RsAllocation B,int b_offset,RsAllocation C,int c_offset,int c_mult_int)161 nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
162 RsAllocation A, int a_offset, RsAllocation B, int b_offset,
163 RsAllocation C, int c_offset, int c_mult_int) {
164 RsBlasCall call;
165 memset(&call, 0, sizeof(call));
166 call.func = RsBlas_bnnm;
167 call.M = M;
168 call.N = N;
169 call.K = K;
170 call.a_offset = a_offset & 0xFF;
171 call.b_offset = b_offset & 0xFF;
172 call.c_offset = c_offset;
173 call.c_mult_int = c_mult_int;
174
175 RsAllocation in_allocs[3] = {A, B, C};
176 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
177 &call, sizeof(call), nullptr, 0));
178 }
179
180 /**
181 * Level 2 BLAS
182 */
validateGEMV(RS * mRS,const sp<const Element> & e,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY)183 static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA, const sp<Allocation>& A,
184 const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
185 int M = A->getType()->getY();
186 int N = A->getType()->getX();
187 if (!A->getType()->getElement()->isCompatible(e) ||
188 !X->getType()->getElement()->isCompatible(e) ||
189 !Y->getType()->getElement()->isCompatible(e)) {
190 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
191 }
192 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
193 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
194 }
195
196 if (incX <= 0 || incY <= 0) {
197 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
198 }
199 int expectedXDim = -1, expectedYDim = -1;
200 if (TransA == RsBlasNoTrans) {
201 expectedXDim = 1 + (N - 1) * incX;
202 expectedYDim = 1 + (M - 1) * incY;
203 } else {
204 expectedXDim = 1 + (M - 1) * incX;
205 expectedYDim = 1 + (N - 1) * incY;
206 }
207 if ((int)X->getType()->getX() != expectedXDim ||
208 (int)Y->getType()->getX() != expectedYDim) {
209 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
210 }
211 }
212
SGEMV(RsBlasTranspose TransA,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)213 void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
214 int incX, float beta, const sp<Allocation>& Y, int incY) {
215 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
216 int M = A->getType()->getY();
217 int N = A->getType()->getX();
218 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
219 TransA, 0, 0, 0, 0, M, N, 0,
220 alpha, A->getID(), X->getID(),
221 beta, Y->getID(), incX, incY, 0, 0);
222 }
223
DGEMV(RsBlasTranspose TransA,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)224 void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
225 int incX, double beta, const sp<Allocation>& Y, int incY) {
226 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
227 int M = A->getType()->getY();
228 int N = A->getType()->getX();
229 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
230 TransA, 0, 0, 0, 0, M, N, 0,
231 alpha, A->getID(), X->getID(),
232 beta, Y->getID(), incX, incY, 0, 0);
233 }
234
CGEMV(RsBlasTranspose TransA,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)235 void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
236 int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
237 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
238 int M = A->getType()->getY();
239 int N = A->getType()->getX();
240 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
241 TransA, 0, 0, 0, 0, M, N, 0,
242 alpha.x, alpha.y, A->getID(), X->getID(),
243 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
244 }
245
ZGEMV(RsBlasTranspose TransA,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)246 void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
247 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
248 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
249 int M = A->getType()->getY();
250 int N = A->getType()->getX();
251 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
252 TransA, 0, 0, 0, 0, M, N, 0,
253 alpha.x, alpha.y, A->getID(), X->getID(),
254 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
255 }
256
SGBMV(RsBlasTranspose TransA,int KL,int KU,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)257 void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, const sp<Allocation>& A,
258 const sp<Allocation>& X, int incX, float beta, const sp<Allocation>& Y, int incY) {
259 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
260 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
261 if (KL < 0 || KU < 0) {
262 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
263 }
264 int M = A->getType()->getY();
265 int N = A->getType()->getX();
266
267 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
268 TransA, 0, 0, 0, 0, M, N, 0,
269 alpha, A->getID(), X->getID(),
270 beta, Y->getID(), incX, incY, KL, KU);
271 }
272
DGBMV(RsBlasTranspose TransA,int KL,int KU,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)273 void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, const sp<Allocation>& A,
274 const sp<Allocation>& X, int incX, double beta, const sp<Allocation>& Y, int incY) {
275 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
276 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
277 if (KL < 0 || KU < 0) {
278 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
279 }
280 int M = A->getType()->getY();
281 int N = A->getType()->getX();
282
283 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
284 TransA, 0, 0, 0, 0, M, N, 0,
285 alpha, A->getID(), X->getID(),
286 beta, Y->getID(), incX, incY, KL, KU);
287 }
288
CGBMV(RsBlasTranspose TransA,int KL,int KU,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)289 void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, const sp<Allocation>& A,
290 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
291 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
292 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
293 if (KL < 0 || KU < 0) {
294 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
295 }
296 int M = A->getType()->getY();
297 int N = A->getType()->getX();
298
299 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
300 TransA, 0, 0, 0, 0, M, N, 0,
301 alpha.x, alpha.y, A->getID(), X->getID(),
302 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
303 }
304
ZGBMV(RsBlasTranspose TransA,int KL,int KU,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)305 void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, const sp<Allocation>& A,
306 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
307 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
308 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
309 if (KL < 0 || KU < 0) {
310 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
311 }
312 int M = A->getType()->getY();
313 int N = A->getType()->getX();
314
315 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
316 TransA, 0, 0, 0, 0, M, N, 0,
317 alpha.x, alpha.y, A->getID(), X->getID(),
318 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
319 }
320
validateTRMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)321 static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA,
322 RsBlasDiag Diag, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
323 int N = A->getType()->getY();
324 if ((int)A->getType()->getX() != N) {
325 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
326 }
327 if (!A->getType()->getElement()->isCompatible(e) ||
328 !X->getType()->getElement()->isCompatible(e)) {
329 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
330 }
331 if (X->getType()->getY() > 1) {
332 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
333 }
334
335 if (incX <= 0) {
336 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
337 }
338 int expectedXDim = 1 + (N - 1) * incX;
339 if ((int)X->getType()->getX() != expectedXDim) {
340 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
341 }
342 }
343
validateTPMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)344 static int validateTPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA,
345 RsBlasDiag Diag, const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
346 if (!Ap->getType()->getElement()->isCompatible(e) ||
347 !X->getType()->getElement()->isCompatible(e)) {
348 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
349 }
350 if (X->getType()->getY() > 1) {
351 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
352 }
353
354 if (Ap->getType()->getY() > 1) {
355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
356 }
357
358 int N = sqrt((double)Ap->getType()->getX() * 2);
359 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
360 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
361 }
362 if (incX <= 0) {
363 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
364 }
365 int expectedXDim = 1 + (N - 1) * incX;
366 if ((int)X->getType()->getX() != expectedXDim) {
367 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
368 }
369
370 return N;
371 }
372
373
STRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)374 void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
375 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
376 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
377 int N = A->getType()->getY();
378 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
379 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
380 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
381 }
382
DTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)383 void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
384 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
385 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
386 int N = A->getType()->getY();
387 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
388 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
389 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
390 }
391
CTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)392 void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
393 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
394 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
395 int N = A->getType()->getY();
396 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
397 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
398 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
399 }
400
ZTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)401 void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
402 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
403 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
404 int N = A->getType()->getY();
405 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
406 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
407 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
408 }
409
STBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)410 void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
411 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
412 // TBMV has the same requirements as TRMV + K >= 0
413 if (K < 0) {
414 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
415 }
416 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
417 int N = A->getType()->getY();
418 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
419 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
420 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
421 }
422
DTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)423 void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
424 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
425 // TBMV has the same requirements as TRMV + K >= 0
426 if (K < 0) {
427 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
428 }
429 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
430 int N = A->getType()->getY();
431 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
432 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
433 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
434 }
435
CTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)436 void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
437 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
438 // TBMV has the same requirements as TRMV + K >= 0
439 if (K < 0) {
440 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
441 }
442 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
443 int N = A->getType()->getY();
444 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
445 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
446 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
447 }
448
ZTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)449 void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
450 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
451 // TBMV has the same requirements as TRMV + K >= 0
452 if (K < 0) {
453 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
454 }
455 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
456 int N = A->getType()->getY();
457 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
458 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
459 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
460 }
461
STPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)462 void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
464 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
465 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
466 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468 }
469
DTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)470 void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
472 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
473 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
474 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
475 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
476 }
477
CTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)478 void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
480 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
482 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484 }
485
ZTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)486 void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
488 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
489 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
490 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
491 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
492 }
493
STRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)494 void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
495 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
496 // TRSV is the same as TRMV
497 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
498 int N = A->getType()->getY();
499 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
500 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
501 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
502 }
503
DTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)504 void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
505 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
506 // TRSV is the same as TRMV
507 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
508 int N = A->getType()->getY();
509 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
510 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
511 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
512
513 }
514
CTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)515 void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
516 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
517 // TRSV is the same as TRMV
518 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
519 int N = A->getType()->getY();
520 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
521 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
522 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
523
524 }
525
ZTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)526 void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
527 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
528 // TRSV is the same as TRMV
529 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
530 int N = A->getType()->getY();
531 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
532 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
533 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
534
535 }
536
STBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)537 void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
538 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
539 // TBSV is the same as TRMV + K >= 0
540 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
541 int N = A->getType()->getY();
542 if (K < 0) {
543 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
544 }
545 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
546 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
547 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
548 }
549
DTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)550 void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
551 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
552 // TBSV is the same as TRMV + K >= 0
553 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
554 int N = A->getType()->getY();
555 if (K < 0) {
556 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
557 }
558 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
559 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
560 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
561 }
562
CTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)563 void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
564 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
565 // TBSV is the same as TRMV + K >= 0
566 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
567 int N = A->getType()->getY();
568 if (K < 0) {
569 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
570 }
571 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
572 TransA, 0, 0, Uplo, Diag, 0, N, K,
573 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
574 }
575
ZTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)576 void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
577 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
578 // TBSV is the same as TRMV + K >= 0
579 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
580 int N = A->getType()->getY();
581 if (K < 0) {
582 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
583 }
584 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
585 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
586 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
587 }
588
STPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)589 void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
590 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
591 // TPSV is same as TPMV
592 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
593 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
594 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
595 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
596 }
597
DTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)598 void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
599 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
600 // TPSV is same as TPMV
601 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
602 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
603 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
604 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
605 }
606
CTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)607 void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
608 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
609 // TPSV is same as TPMV
610 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
611 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
612 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
613 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
614 }
615
ZTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)616 void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
617 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
618 // TPSV is same as TPMV
619 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
620 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
621 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
622 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
623 }
624
625 /**
626 * Level 2, S and D only
627 */
validateSYMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & A,const sp<Allocation> & X,const sp<Allocation> & Y,int incX,int incY)628 static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& A,
629 const sp<Allocation>& X, const sp<Allocation>& Y, int incX, int incY) {
630 int N = A->getType()->getY();
631 if ((int)A->getType()->getX() != N) {
632 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
633 }
634 if (!A->getType()->getElement()->isCompatible(e) ||
635 !X->getType()->getElement()->isCompatible(e) ||
636 !Y->getType()->getElement()->isCompatible(e) ) {
637 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
638 }
639 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
640 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
641 }
642
643 if (incX <= 0 || incY <= 0) {
644 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
645 }
646 int expectedXDim = 1 + (N - 1) * incX;
647 if ((int)X->getType()->getX() != expectedXDim) {
648 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
649 }
650 int expectedYDim = 1 + (N - 1) * incY;
651 if ((int)Y->getType()->getX() != expectedYDim) {
652 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
653 }
654 return N;
655 }
validateSPMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY)656 static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& Ap,
657 const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
658 if (!Ap->getType()->getElement()->isCompatible(e) ||
659 !X->getType()->getElement()->isCompatible(e) ||
660 !Y->getType()->getElement()->isCompatible(e)) {
661 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
662 }
663 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
664 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
665 }
666
667 if (Ap->getType()->getY() > 1) {
668 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
669 }
670
671 int N = sqrt((double)Ap->getType()->getX() * 2);
672 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
673 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
674 }
675 if (incX <= 0 || incY <= 0) {
676 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
677 }
678 int expectedXDim = 1 + (N - 1) * incX;
679 if ((int)X->getType()->getX() != expectedXDim) {
680 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
681 }
682 int expectedYDim = 1 + (N - 1) * incY;
683 if ((int)Y->getType()->getX() != expectedYDim) {
684 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
685 }
686
687 return N;
688 }
validateGER(RS * mRS,const sp<const Element> & e,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)689 static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
690 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
691 if (!A->getType()->getElement()->isCompatible(e) ||
692 !X->getType()->getElement()->isCompatible(e) ||
693 !Y->getType()->getElement()->isCompatible(e) ) {
694 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
695 }
696
697 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
698 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
699 }
700
701 int M = A->getType()->getY();
702 int N = A->getType()->getX();
703
704 if (N < 1 || M < 1) {
705 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
706 }
707 if (incX <= 0 || incY <= 0) {
708 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
709 }
710 int expectedXDim = 1 + (M - 1) * incX;
711 if ((int)X->getType()->getX() != expectedXDim) {
712 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
713 }
714 int expectedYDim = 1 + (N - 1) * incY;
715 if ((int)Y->getType()->getX() != expectedYDim) {
716 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
717 }
718
719
720 }
validateSYR(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & A)721 static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
722 const sp<Allocation>& X, int incX, const sp<Allocation>& A) {
723 if (!A->getType()->getElement()->isCompatible(e) ||
724 !X->getType()->getElement()->isCompatible(e)) {
725 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
726 }
727
728 int N = A->getType()->getX();
729
730 if (X->getType()->getY() > 1) {
731 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
732 }
733 if (N != (int)A->getType()->getY()) {
734 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
735 }
736 if (incX <= 0) {
737 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
738 }
739 int expectedXDim = 1 + (N - 1) * incX;
740 if ((int)X->getType()->getX() != expectedXDim) {
741 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
742 }
743 return N;
744 }
validateSPR(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)745 static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
746 const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) {
747 if (!Ap->getType()->getElement()->isCompatible(e) ||
748 !X->getType()->getElement()->isCompatible(e)) {
749 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
750 }
751 if (X->getType()->getY() > 1) {
752 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
753 }
754
755 if (Ap->getType()->getY() > 1) {
756 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
757 }
758
759 int N = sqrt((double)Ap->getType()->getX() * 2);
760 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
761 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
762 }
763 if (incX <= 0) {
764 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
765 }
766 int expectedXDim = 1 + (N - 1) * incX;
767 if ((int)X->getType()->getX() != expectedXDim) {
768 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
769 }
770
771 return N;
772 }
773
validateSYR2(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)774 static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
775 int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
776 if (!A->getType()->getElement()->isCompatible(e) ||
777 !X->getType()->getElement()->isCompatible(e) ||
778 !Y->getType()->getElement()->isCompatible(e)) {
779 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
780 }
781
782 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
783 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
784 }
785
786 int N = A->getType()->getX();
787
788 if (N != (int)A->getType()->getY()) {
789 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
790 }
791 if (incX <= 0 || incY <= 0) {
792 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
793 }
794 int expectedXDim = 1 + (N - 1) * incX;
795 int expectedYDim = 1 + (N - 1) * incY;
796 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
797 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
798 }
799 return N;
800
801 }
validateSPR2(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)802 static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
803 int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
804 if (!Ap->getType()->getElement()->isCompatible(e) ||
805 !X->getType()->getElement()->isCompatible(e) ||
806 !Y->getType()->getElement()->isCompatible(e)) {
807 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
808 }
809 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
810 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
811 }
812
813 if (Ap->getType()->getY() > 1) {
814 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
815 }
816
817 int N = sqrt((double)Ap->getType()->getX() * 2);
818 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
819 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
820 }
821 if (incX <= 0 || incY <= 0) {
822 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
823 }
824 int expectedXDim = 1 + (N - 1) * incX;
825 int expectedYDim = 1 + (N - 1) * incY;
826 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
827 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
828 }
829
830 return N;
831 }
832
SSYMV(RsBlasUplo Uplo,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)833 void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
834 int incX, float beta, const sp<Allocation>& Y, int incY) {
835 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
836 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
837 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
838 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
839 }
840
SSBMV(RsBlasUplo Uplo,int K,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)841 void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
842 int incX, float beta, const sp<Allocation>& Y, int incY) {
843 // SBMV is the same as SYMV + K >= 0
844 if (K < 0) {
845 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
846 }
847 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
848 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
849 0, 0, 0, Uplo, 0, 0, N, K, alpha,
850 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851 }
852
SSPMV(RsBlasUplo Uplo,float alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)853 void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
854 int incX, float beta, const sp<Allocation>& Y, int incY) {
855 int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
856 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
857 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
858 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
859 }
860
SGER(float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)861 void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX,
862 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
863 int M = A->getType()->getY();
864 int N = A->getType()->getX();
865 validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
866 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
867 0, 0, 0, 0, 0, M, N, 0, alpha,
868 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
869 }
870
SSYR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)871 void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
872 int incX, const sp<Allocation>& A) {
873 int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
874 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
875 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
876 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
877 }
878
SSPR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)879 void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
880 int incX, const sp<Allocation>& Ap) {
881 int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
882 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
883 0, 0, 0, Uplo, 0, 0, N, 0,
884 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
885 }
886
SSYR2(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)887 void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
888 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
889 int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
890 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
891 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
893 }
894
SSPR2(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)895 void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
896 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
897 int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
898 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
899 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
901 }
902
DSYMV(RsBlasUplo Uplo,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)903 void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
904 int incX, double beta, const sp<Allocation>& Y, int incY) {
905 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
906 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
907 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
908 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
909 }
910
DSBMV(RsBlasUplo Uplo,int K,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)911 void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
912 int incX, double beta, const sp<Allocation>& Y, int incY) {
913 // SBMV is the same as SYMV + K >= 0
914 if (K < 0) {
915 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
916 }
917 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
918 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
919 0, 0, 0, Uplo, 0, 0, N, K, alpha,
920 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921 }
922
DSPMV(RsBlasUplo Uplo,double alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)923 void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
924 int incX, double beta, const sp<Allocation>& Y, int incY) {
925 int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
926 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
927 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
928 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
929 }
930
DGER(double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)931 void ScriptIntrinsicBLAS::DGER(double alpha, const sp<Allocation>& X, int incX, const sp<Allocation>& Y,
932 int incY, const sp<Allocation>& A) {
933 int M = A->getType()->getY();
934 int N = A->getType()->getX();
935 validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
936 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
937 0, 0, 0, 0, 0, M, N, 0, alpha,
938 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
939 }
940
DSYR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)941 void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
942 int incX, const sp<Allocation>& A) {
943 int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
944 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
945 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
947 }
948
DSPR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)949 void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
950 int incX, const sp<Allocation>& Ap) {
951 int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
952 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
953 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
955 }
956
DSYR2(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)957 void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
958 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
959 int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
960 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
961 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
963 }
964
DSPR2(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)965 void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
966 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
967 int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
968 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
969 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
970 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
971 }
972
973
974 /**
975 * Level 2, C and Z only
976 */
977
validateGERU(RS * mRS,const sp<const Element> & e,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)978 static void validateGERU(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
979 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
980 if (!A->getType()->getElement()->isCompatible(e) ||
981 !X->getType()->getElement()->isCompatible(e) ||
982 !Y->getType()->getElement()->isCompatible(e)) {
983 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
984 }
985 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
986 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
987 }
988
989 int M = A->getType()->getY();
990 int N = A->getType()->getX();
991 if (incX <= 0 || incY <= 0) {
992 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
993 }
994 int expectedXDim = 1 + (M - 1) * incX;
995 if ((int)X->getType()->getX() != expectedXDim) {
996 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
997 }
998 int expectedYDim = 1 + (N - 1) * incY;
999 if ((int)Y->getType()->getX() != expectedYDim) {
1000 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
1001 }
1002
1003 }
1004
CHEMV(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1005 void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& A,
1006 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1007 // HEMV is the same as SYR2 validation-wise
1008 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1009 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1010 0, 0, 0, Uplo, 0, 0, N, 0,
1011 alpha.x, alpha.y, A->getID(), X->getID(),
1012 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1013 }
1014
CHBMV(RsBlasUplo Uplo,int K,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1015 void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp<Allocation>& A,
1016 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1017 // HBMV is the same as SYR2 validation-wise
1018 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1019 if (K < 0) {
1020 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1021 }
1022 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1023 0, 0, 0, Uplo, 0, 0, N, K,
1024 alpha.x, alpha.y, A->getID(), X->getID(),
1025 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1026 }
1027
CHPMV(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1028 void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& Ap,
1029 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1030 // HPMV is the same as SPR2
1031 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1032 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1033 0, 0, 0, Uplo, 0, 0, N, 0,
1034 alpha.x, alpha.y, Ap->getID(), X->getID(),
1035 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1036 }
1037
CGERU(Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1038 void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp<Allocation>& X, int incX,
1039 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1040 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1041 int M = A->getType()->getY();
1042 int N = A->getType()->getX();
1043 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1044 0, 0, 0, 0, 0, M, N, 0,
1045 alpha.x, alpha.y, X->getID(), Y->getID(),
1046 0, 0, A->getID(), incX, incY, 0, 0);
1047 }
1048
CGERC(Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1049 void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp<Allocation>& X, int incX,
1050 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1051 // Same as GERU
1052 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1053 int M = A->getType()->getY();
1054 int N = A->getType()->getX();
1055 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1056 0, 0, 0, 0, 0, M, N, 0,
1057 alpha.x, alpha.y, X->getID(), Y->getID(),
1058 0, 0, A->getID(), incX, incY, 0, 0);
1059 }
1060
CHER(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)1061 void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1062 int incX, const sp<Allocation>& A) {
1063 // Same as SYR
1064 int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1065 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1066 0, 0, 0, Uplo, 0, 0, N, 0,
1067 alpha, 0, X->getID(), 0,
1068 0, 0, A->getID(), incX, 0, 0, 0);
1069 }
1070
CHPR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)1071 void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1072 int incX, const sp<Allocation>& Ap) {
1073 // Equivalent to SPR for validation
1074 int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1075 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1076 0, 0, 0, Uplo, 0, 0, N, 0,
1077 alpha, 0, X->getID(), 0,
1078 0, 0, Ap->getID(), incX, 0, 0, 0);
1079 }
1080
CHER2(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1081 void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1082 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1083 // Same as SYR2
1084 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1085 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1086 0, 0, 0, Uplo, 0, 0, N, 0,
1087 alpha.x, alpha.y, X->getID(), Y->getID(),
1088 0, 0, A->getID(), incX, incY, 0, 0);
1089 }
1090
CHPR2(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)1091 void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1092 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1093 // Same as SPR2
1094 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1095 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1096 0, 0, 0, Uplo, 0, 0, N, 0,
1097 alpha.x, alpha.y, X->getID(), Y->getID(),
1098 0, 0, Ap->getID(), incX, incY, 0, 0);
1099 }
1100
ZHEMV(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1101 void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& A,
1102 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1103 // HEMV is the same as SYR2 validation-wise
1104 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1105 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1106 0, 0, 0, Uplo, 0, 0, N, 0,
1107 alpha.x, alpha.y, A->getID(), X->getID(),
1108 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1109 }
1110
ZHBMV(RsBlasUplo Uplo,int K,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1111 void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
1112 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1113 // HBMV is the same as SYR2 validation-wise
1114 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1115 if (K < 0) {
1116 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1117 }
1118 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1119 0, 0, 0, Uplo, 0, 0, N, K,
1120 alpha.x, alpha.y, A->getID(), X->getID(),
1121 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1122 }
1123
ZHPMV(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1124 void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
1125 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1126 // HPMV is the same as SPR2
1127 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1128 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1129 0, 0, 0, Uplo, 0, 0, N, 0,
1130 alpha.x, alpha.y, Ap->getID(), X->getID(),
1131 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1132 }
1133
ZGERU(Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1134 void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp<Allocation>& X, int incX,
1135 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1136 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1137 int M = A->getType()->getY();
1138 int N = A->getType()->getX();
1139 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1140 0, 0, 0, 0, 0, M, N, 0,
1141 alpha.x, alpha.y, X->getID(), Y->getID(),
1142 0, 0, A->getID(), incX, incY, 0, 0);
1143 }
1144
ZGERC(Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1145 void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp<Allocation>& X, int incX,
1146 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1147 // Same as GERU
1148 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1149 int M = A->getType()->getY();
1150 int N = A->getType()->getX();
1151 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1152 0, 0, 0, 0, 0, M, N, 0,
1153 alpha.x, alpha.y, X->getID(), Y->getID(),
1154 0, 0, A->getID(), incX, incY, 0, 0);
1155 }
1156
ZHER(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)1157 void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1158 int incX, const sp<Allocation>& A) {
1159 // Same as SYR
1160 int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1161 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1162 0, 0, 0, Uplo, 0, 0, N, 0,
1163 alpha, 0, X->getID(), 0,
1164 0, 0, A->getID(), incX, 0, 0, 0);
1165 }
1166
ZHPR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)1167 void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1168 int incX, const sp<Allocation>& Ap) {
1169 // Equivalent to SPR for validation
1170 int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1171 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1172 0, 0, 0, Uplo, 0, 0, N, 0,
1173 alpha, 0, X->getID(), 0,
1174 0, 0, Ap->getID(), incX, 0, 0, 0);
1175 }
1176
ZHER2(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1177 void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1178 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1179 // Same as SYR2
1180 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1181 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1182 0, 0, 0, Uplo, 0, 0, N, 0,
1183 alpha.x, alpha.y, X->getID(), Y->getID(),
1184 0, 0, A->getID(), incX, incY, 0, 0);
1185 }
1186
ZHPR2(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)1187 void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1188 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1189 // Same as SPR2
1190 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1191 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1192 0, 0, 0, Uplo, 0, 0, N, 0,
1193 alpha.x, alpha.y, X->getID(), Y->getID(),
1194 0, 0, Ap->getID(), incX, incY, 0, 0);
1195 }
1196
1197
1198 /**
1199 * Level 3 BLAS
1200 */
1201
validateL3(RS * mRS,const sp<const Element> & e,int TransA,int TransB,int Side,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1202 static void validateL3(RS* mRS, const sp<const Element>& e, int TransA, int TransB, int Side,
1203 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1204 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1205 if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1206 (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1207 (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1208 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1209 }
1210 if (C == nullptr) {
1211 // Since matrix C is used to store the result, it cannot be null.
1212 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1213 }
1214 cM = C->getType()->getY();
1215 cN = C->getType()->getX();
1216
1217 if (Side == RsBlasRight) {
1218 if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1219 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1220 }
1221 if (B != nullptr) {
1222 bM = A->getType()->getY();
1223 bN = A->getType()->getX();
1224 }
1225 if (A != nullptr) {
1226 aM = B->getType()->getY();
1227 aN = B->getType()->getX();
1228 }
1229 } else {
1230 if (A != nullptr) {
1231 if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1232 aN = A->getType()->getY();
1233 aM = A->getType()->getX();
1234 } else {
1235 aM = A->getType()->getY();
1236 aN = A->getType()->getX();
1237 }
1238 }
1239 if (B != nullptr) {
1240 if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1241 bN = B->getType()->getY();
1242 bM = B->getType()->getX();
1243 } else {
1244 bM = B->getType()->getY();
1245 bN = B->getType()->getX();
1246 }
1247 }
1248 }
1249 if (A != nullptr && B != nullptr && C != nullptr) {
1250 if (aN != bM || aM != cM || bN != cN) {
1251 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252 }
1253 } else if (A != nullptr && C != nullptr) {
1254 // A and C only, for SYRK
1255 if (cM != cN) {
1256 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1257 }
1258 if (aM != cM) {
1259 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1260 }
1261 } else if (A != nullptr && B != nullptr) {
1262 // A and B only
1263 if (aN != bM) {
1264 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1265 }
1266 }
1267
1268 }
1269
SGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1270 void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1271 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1272 validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1273
1274 int M = -1, N = -1, K = -1;
1275 if (TransA != RsBlasNoTrans) {
1276 M = A->getType()->getX();
1277 K = A->getType()->getY();
1278 } else {
1279 M = A->getType()->getY();
1280 K = A->getType()->getX();
1281 }
1282 if (TransB != RsBlasNoTrans) {
1283 N = B->getType()->getY();
1284 } else {
1285 N = B->getType()->getX();
1286 }
1287 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1288 TransA, TransB, 0, 0, 0, M, N, K,
1289 alpha, A->getID(), B->getID(),
1290 beta, C->getID(), 0, 0, 0, 0);
1291 }
1292
DGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1293 void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1294 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1295 validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1296 int M = -1, N = -1, K = -1;
1297 if (TransA != RsBlasNoTrans) {
1298 M = A->getType()->getX();
1299 K = A->getType()->getY();
1300 } else {
1301 M = A->getType()->getY();
1302 K = A->getType()->getX();
1303 }
1304 if (TransB != RsBlasNoTrans) {
1305 N = B->getType()->getY();
1306 } else {
1307 N = B->getType()->getX();
1308 }
1309 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1310 TransA, TransB, 0, 0, 0, M, N, K,
1311 alpha, A->getID(), B->getID(),
1312 beta, C->getID(), 0, 0, 0, 0);
1313 }
1314
CGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1315 void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1316 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1317 validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1318 int M = -1, N = -1, K = -1;
1319 if (TransA != RsBlasNoTrans) {
1320 M = A->getType()->getX();
1321 K = A->getType()->getY();
1322 } else {
1323 M = A->getType()->getY();
1324 K = A->getType()->getX();
1325 }
1326 if (TransB != RsBlasNoTrans) {
1327 N = B->getType()->getY();
1328 } else {
1329 N = B->getType()->getX();
1330 }
1331 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1332 TransA, TransB, 0, 0, 0, M, N, K,
1333 alpha.x, alpha.y, A->getID(), B->getID(),
1334 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1335 }
1336
ZGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1337 void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1338 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1339 validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1340 int M = -1, N = -1, K = -1;
1341 if (TransA != RsBlasNoTrans) {
1342 M = A->getType()->getX();
1343 K = A->getType()->getY();
1344 } else {
1345 M = A->getType()->getY();
1346 K = A->getType()->getX();
1347 }
1348 if (TransB != RsBlasNoTrans) {
1349 N = B->getType()->getY();
1350 } else {
1351 N = B->getType()->getX();
1352 }
1353 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1354 TransA, TransB, 0, 0, 0, M, N, K,
1355 alpha.x, alpha.y, A->getID(), B->getID(),
1356 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1357 }
1358
SSYMM(RsBlasSide Side,RsBlasUplo Uplo,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1359 void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1360 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1361 //For SYMM, Matrix A should be symmetric
1362 if (A->getType()->getX() != A->getType()->getY()) {
1363 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1364 }
1365 validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1366 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1367 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1368 alpha, A->getID(), B->getID(),
1369 beta, C->getID(), 0, 0, 0, 0);
1370 }
1371
DSYMM(RsBlasSide Side,RsBlasUplo Uplo,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1372 void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1373 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1374 if (A->getType()->getX() != A->getType()->getY()) {
1375 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1376 }
1377 validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1378 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1379 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1380 alpha, A->getID(), B->getID(),
1381 beta, C->getID(), 0, 0, 0, 0);
1382 }
1383
CSYMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1384 void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1385 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1386 if (A->getType()->getX() != A->getType()->getY()) {
1387 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1388 }
1389 validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1390 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1391 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1392 alpha.x, alpha.y, A->getID(), B->getID(),
1393 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1394 }
1395
ZSYMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1396 void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1397 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1398 if (A->getType()->getX() != A->getType()->getY()) {
1399 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1400 }
1401 validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1402 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1403 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1404 alpha.x, alpha.y, A->getID(), B->getID(),
1405 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1406 }
1407
SSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,float beta,const sp<Allocation> & C)1408 void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1409 const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1410 validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1411 int K = -1;
1412 if (Trans != RsBlasNoTrans) {
1413 K = A->getType()->getY();
1414 } else {
1415 K = A->getType()->getX();
1416 }
1417 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1418 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1419 alpha, A->getID(), 0,
1420 beta, C->getID(), 0, 0, 0, 0);
1421 }
1422
DSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,double beta,const sp<Allocation> & C)1423 void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1424 const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1425 validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1426 int K = -1;
1427 if (Trans != RsBlasNoTrans) {
1428 K = A->getType()->getY();
1429 } else {
1430 K = A->getType()->getX();
1431 }
1432 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1433 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1434 alpha, A->getID(), 0,
1435 beta, C->getID(), 0, 0, 0, 0);
1436 }
1437
CSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,Float2 beta,const sp<Allocation> & C)1438 void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1439 const sp<Allocation>& A, Float2 beta, const sp<Allocation>& C) {
1440 validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1441 int K = -1;
1442 if (Trans != RsBlasNoTrans) {
1443 K = A->getType()->getY();
1444 } else {
1445 K = A->getType()->getX();
1446 }
1447 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1448 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1449 alpha.x, alpha.y, A->getID(), 0,
1450 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1451 }
1452
ZSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,Double2 beta,const sp<Allocation> & C)1453 void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1454 const sp<Allocation>& A, Double2 beta, const sp<Allocation>& C) {
1455 validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1456 int K = -1;
1457 if (Trans != RsBlasNoTrans) {
1458 K = A->getType()->getY();
1459 } else {
1460 K = A->getType()->getX();
1461 }
1462 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1463 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1464 alpha.x, alpha.y, A->getID(), 0,
1465 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1466 }
1467
validateSYR2K(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1468 static void validateSYR2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1469 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1470 if (!A->getType()->getElement()->isCompatible(e) ||
1471 !B->getType()->getElement()->isCompatible(e) ||
1472 !C->getType()->getElement()->isCompatible(e)) {
1473 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1474 }
1475 int Cdim = -1;
1476 // A is n x k if no transpose, k x n if transpose
1477 // C is n x n
1478 if (Trans == RsBlasTrans) {
1479 // check columns versus C
1480 Cdim = A->getType()->getX();
1481 } else {
1482 // check rows versus C
1483 Cdim = A->getType()->getY();
1484 }
1485 if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1486 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1487 }
1488 // A dims == B dims
1489 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1490 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1491 }
1492 }
1493
SSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1494 void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1495 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1496 validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1497 int K = -1;
1498 if (Trans != RsBlasNoTrans) {
1499 K = A->getType()->getY();
1500 } else {
1501 K = A->getType()->getX();
1502 }
1503 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1504 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1505 alpha, A->getID(), B->getID(),
1506 beta, C->getID(), 0, 0, 0, 0);
1507 }
1508
DSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1509 void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1510 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1511 validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1512 int K = -1;
1513 if (Trans != RsBlasNoTrans) {
1514 K = A->getType()->getY();
1515 } else {
1516 K = A->getType()->getX();
1517 }
1518 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1519 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1520 alpha, A->getID(), B->getID(),
1521 beta, C->getID(), 0, 0, 0, 0);
1522 }
1523
CSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1524 void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1525 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1526 validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1527 int K = -1;
1528 if (Trans != RsBlasNoTrans) {
1529 K = A->getType()->getY();
1530 } else {
1531 K = A->getType()->getX();
1532 }
1533 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1534 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1535 alpha.x, alpha.y, A->getID(), B->getID(),
1536 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1537 }
1538
ZSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1539 void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1540 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1541 validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1542 int K = -1;
1543 if (Trans != RsBlasNoTrans) {
1544 K = A->getType()->getY();
1545 } else {
1546 K = A->getType()->getX();
1547 }
1548 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1549 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1550 alpha.x, alpha.y, A->getID(), B->getID(),
1551 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1552 }
1553
validateTRMM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & B)1554 static void validateTRMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1555 const sp<Allocation>& A, const sp<Allocation>& B) {
1556 int aM = -1, aN = -1, bM = -1, bN = -1;
1557 if (!A->getType()->getElement()->isCompatible(e) ||
1558 !B->getType()->getElement()->isCompatible(e)) {
1559 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1560 }
1561
1562 aM = A->getType()->getY();
1563 aN = A->getType()->getX();
1564 if (aM != aN) {
1565 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1566 }
1567
1568 bM = B->getType()->getY();
1569 bN = B->getType()->getX();
1570 if (Side == RsBlasLeft) {
1571 if (aN != bM) {
1572 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1573 }
1574 } else {
1575 if (bN != aM) {
1576 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1577 }
1578 }
1579 }
1580
STRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,const sp<Allocation> & A,const sp<Allocation> & B)1581 void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1582 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1583 validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1584 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1585 TransA, 0, Side, Uplo, Diag,\
1586 B->getType()->getY(), B->getType()->getX(), 0,
1587 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1588 }
1589
DTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,const sp<Allocation> & A,const sp<Allocation> & B)1590 void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1591 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1592 validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1593 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1594 TransA, 0, Side, Uplo, Diag,
1595 B->getType()->getY(), B->getType()->getX(), 0,
1596 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1597 }
1598
CTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1599 void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1600 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1601 validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1602 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1603 TransA, 0, Side, Uplo, Diag,
1604 B->getType()->getY(), B->getType()->getX(), 0,
1605 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1606 }
1607
ZTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1608 void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1609 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1610 validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1611 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1612 TransA, 0, Side, Uplo, Diag,
1613 B->getType()->getY(), B->getType()->getX(), 0,
1614 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1615 }
1616
validateTRSM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & B)1617 static void validateTRSM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1618 const sp<Allocation>& A, const sp<Allocation>& B) {
1619 int adim = -1, bM = -1, bN = -1;
1620 if (!A->getType()->getElement()->isCompatible(e) ||
1621 !B->getType()->getElement()->isCompatible(e)) {
1622 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1623 }
1624 adim = A->getType()->getX();
1625 if (adim != (int)A->getType()->getY()) {
1626 // This may be unnecessary, the restriction could potentially be relaxed.
1627 // Allocation A needs to contain at least that symmetric matrix but could theoretically
1628 // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1629 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1630 }
1631 bM = B->getType()->getY();
1632 bN = B->getType()->getX();
1633 if (Side == RsBlasLeft) {
1634 // A is M*M
1635 if (adim != bM) {
1636 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1637 }
1638 } else {
1639 // A is N*N
1640 if (adim != bN) {
1641 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1642 }
1643 }
1644 }
1645
STRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,const sp<Allocation> & A,const sp<Allocation> & B)1646 void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1647 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1648 validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1649 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1650 TransA, 0, Side, Uplo, Diag,
1651 B->getType()->getY(), B->getType()->getX(), 0,
1652 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1653 }
1654
DTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,const sp<Allocation> & A,const sp<Allocation> & B)1655 void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1656 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1657 validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1658 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1659 TransA, 0, Side, Uplo, Diag,
1660 B->getType()->getY(), B->getType()->getX(), 0,
1661 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1662 }
1663
CTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1664 void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1665 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1666 validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1667 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1668 TransA, 0, Side, Uplo, Diag,
1669 B->getType()->getY(), B->getType()->getX(), 0,
1670 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1671 }
1672
ZTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1673 void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1674 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1675 validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1676 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1677 TransA, 0, Side, Uplo, Diag,
1678 B->getType()->getY(), B->getType()->getX(), 0,
1679 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1680 }
1681
validateHEMM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1682 static void validateHEMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side,
1683 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1684 if (!A->getType()->getElement()->isCompatible(e) ||
1685 !B->getType()->getElement()->isCompatible(e) ||
1686 !C->getType()->getElement()->isCompatible(e)) {
1687 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1688 }
1689
1690 // A must be square; can potentially be relaxed similar to TRSM
1691 int adim = A->getType()->getX();
1692 if (adim != (int)A->getType()->getY()) {
1693 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1694 }
1695 if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1696 (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1697 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1698 }
1699 if (B->getType()->getX() != C->getType()->getX() ||
1700 B->getType()->getY() != C->getType()->getY()) {
1701 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1702 }
1703 }
1704
CHEMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1705 void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1706 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1707 validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1708 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1709 0, 0, Side, Uplo, 0,
1710 C->getType()->getY(), C->getType()->getX(), 0,
1711 alpha.x, alpha.y, A->getID(), B->getID(),
1712 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1713 }
1714
ZHEMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1715 void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1716 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1717 validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1718 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1719 0, 0, Side, Uplo, 0,
1720 C->getType()->getY(), C->getType()->getX(), 0,
1721 alpha.x, alpha.y, A->getID(), B->getID(),
1722 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1723 }
1724
validateHERK(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & C)1725 static void validateHERK(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1726 const sp<Allocation>& A, const sp<Allocation>& C) {
1727 if (!A->getType()->getElement()->isCompatible(e) ||
1728 !C->getType()->getElement()->isCompatible(e)) {
1729 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1730 }
1731 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1732 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1733 }
1734 int cdim = C->getType()->getX();
1735 if (cdim != (int)C->getType()->getY()) {
1736 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1737 }
1738 if (Trans == RsBlasNoTrans) {
1739 if (cdim != (int)A->getType()->getY()) {
1740 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1741 }
1742 } else {
1743 if (cdim != (int)A->getType()->getX()) {
1744 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1745 }
1746 }
1747 }
1748
CHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,float beta,const sp<Allocation> & C)1749 void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1750 const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1751 validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1752 int k = 0;
1753 if (Trans == RsBlasConjTrans) {
1754 k = A->getType()->getY();
1755 } else {
1756 k = A->getType()->getX();
1757 }
1758 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1759 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1760 alpha, 0, A->getID(), 0,
1761 beta, 0, C->getID(), 0, 0, 0, 0);
1762 }
1763
ZHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,double beta,const sp<Allocation> & C)1764 void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1765 const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1766 validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1767 int k = 0;
1768 if (Trans == RsBlasConjTrans) {
1769 k = A->getType()->getY();
1770 } else {
1771 k = A->getType()->getX();
1772 }
1773 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1774 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1775 alpha, 0, A->getID(), 0,
1776 beta, 0, C->getID(), 0, 0, 0, 0);
1777 }
1778
validateHER2K(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1779 static void validateHER2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1780 const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1781 if (!A->getType()->getElement()->isCompatible(e) ||
1782 !B->getType()->getElement()->isCompatible(e) ||
1783 !C->getType()->getElement()->isCompatible(e)) {
1784 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1785 }
1786 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1787 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1788 }
1789 int cdim = C->getType()->getX();
1790 if (cdim != (int)C->getType()->getY()) {
1791 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1792 }
1793 if (Trans == RsBlasNoTrans) {
1794 if ((int)A->getType()->getY() != cdim) {
1795 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1796 }
1797 } else {
1798 if ((int)A->getType()->getX() != cdim) {
1799 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1800 }
1801 }
1802 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1803 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1804 }
1805 }
1806
CHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1807 void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1808 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1809 validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1810 int k = 0;
1811 if (Trans == RsBlasNoTrans) {
1812 k = A->getType()->getX();
1813 } else {
1814 k = A->getType()->getY();
1815 }
1816 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1817 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1818 alpha.x, alpha.y, A->getID(), B->getID(),
1819 beta, 0, C->getID(), 0, 0, 0, 0);
1820 }
1821
ZHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1822 void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1823 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1824 validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1825 int k = 0;
1826 if (Trans == RsBlasNoTrans) {
1827 k = A->getType()->getX();
1828 } else {
1829 k = A->getType()->getY();
1830 }
1831 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1832 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1833 alpha.x, alpha.y, A->getID(), B->getID(),
1834 beta, 0, C->getID(), 0, 0, 0, 0);
1835 }
1836
1837
1838
BNNM(const sp<Allocation> & A,int a_offset,const sp<Allocation> & B,int b_offset,const sp<Allocation> & C,int c_offset,int c_mult)1839 void ScriptIntrinsicBLAS::BNNM(const sp<Allocation>& A, int a_offset, const sp<Allocation>& B, int b_offset,
1840 const sp<Allocation>& C, int c_offset, int c_mult) {
1841 validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1842
1843 if (a_offset < 0 || a_offset > 255) {
1844 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1845 }
1846 if (b_offset < 0 || b_offset > 255) {
1847 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1848 }
1849 int M = -1, N = -1, K = -1;
1850 M = A->getType()->getY();
1851 N = B->getType()->getY();
1852 K = A->getType()->getX();
1853
1854 nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1855 B->getID(), b_offset, C->getID(), c_offset, c_mult);
1856 }
1857