1
1
package org .clulab .scala_transformers .encoder
2
2
3
3
import ai .onnxruntime .{OnnxTensor , OrtEnvironment , OrtSession }
4
- import org .clulab .scala_transformers .encoder .math .Mathematics
5
4
import org .clulab .scala_transformers .encoder .math .Mathematics .{Math , MathMatrix }
6
5
7
6
import java .io .DataInputStream
8
7
import java .util .{HashMap => JHashMap }
9
8
10
- class Encoder (val encoderEnvironment : OrtEnvironment , val encoderSession : OrtSession , nonLin : Option [NonLinearity ] = None ) {
9
+ class Encoder (val encoderEnvironment : OrtEnvironment , val encoderSession : OrtSession , nonLinOpt : Option [NonLinearity ] = None ) {
11
10
/**
12
11
* Runs the inference using a transformer encoder over a batch of sentences
13
12
*
@@ -21,18 +20,11 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
21
20
val result : OrtSession .Result = encoderSession.run(inputs)
22
21
val outputs = Math .fromResult(result)
23
22
24
- if (nonLin.isDefined) {
25
- for (matrix <- outputs) {
26
- for (i <- 0 until Math .rows(matrix)) {
27
- val row = Math .row(matrix, i)
28
- for (j <- 0 until Math .cols(matrix)) {
29
- val orig = Math .get(row, j)
30
- Math .set(row, j, nonLin.get.compute(orig))
31
- }
32
- }
23
+ nonLinOpt.foreach { nonLin =>
24
+ outputs.foreach { matrix =>
25
+ Math .map(matrix, nonLin.compute)
33
26
}
34
27
}
35
-
36
28
outputs
37
29
}
38
30
0 commit comments