Skip to content

Commit 89cab31

Browse files
authoredApr 6, 2025··
Merge pull request #60 from clulab/kwalcock/nonlinearilty
Implement nonLin slightly differently
2 parents 6f45c7d + 1dd8069 commit 89cab31

File tree

5 files changed

+20
-17
lines changed

5 files changed

+20
-17
lines changed
 

Diff for: ‎encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala

+4-12
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package org.clulab.scala_transformers.encoder
22

33
import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
4-
import org.clulab.scala_transformers.encoder.math.Mathematics
54
import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix}
65

76
import java.io.DataInputStream
87
import java.util.{HashMap => JHashMap}
98

10-
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) {
9+
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLinOpt: Option[NonLinearity] = None) {
1110
/**
1211
* Runs the inference using a transformer encoder over a batch of sentences
1312
*
@@ -21,18 +20,11 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
2120
val result: OrtSession.Result = encoderSession.run(inputs)
2221
val outputs = Math.fromResult(result)
2322

24-
if(nonLin.isDefined) {
25-
for (matrix <- outputs) {
26-
for (i <- 0 until Math.rows(matrix)) {
27-
val row = Math.row(matrix, i)
28-
for (j <- 0 until Math.cols(matrix)) {
29-
val orig = Math.get(row, j)
30-
Math.set(row, j, nonLin.get.compute(orig))
31-
}
32-
}
23+
nonLinOpt.foreach { nonLin =>
24+
outputs.foreach { matrix =>
25+
Math.map(matrix, nonLin.compute)
3326
}
3427
}
35-
3628
outputs
3729
}
3830

Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
package org.clulab.scala_transformers.encoder
22

3-
import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue
4-
5-
import java.lang
3+
import org.clulab.scala_transformers.encoder.math.Mathematics.MathValue
64

75
trait NonLinearity {
86
def compute(input: MathValue): MathValue
97
}
108

11-
class ReLU extends NonLinearity {
9+
object ReLU extends NonLinearity {
1210
override def compute(input: MathValue): MathValue = {
13-
lang.Float.max(0, input)
11+
scala.math.max(0, input)
1412
}
1513
}

Diff for: ‎encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala

+11
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ object EjmlMath extends Math {
100100
new FMatrixRMaj(rows, cols)
101101
}
102102

103+
override def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit = {
104+
val iterator = matrix.iterator(true, 0, 0, matrix.getNumRows, matrix.getNumCols)
105+
106+
while (iterator.hasNext) {
107+
val oldValue = iterator.next()
108+
val newValue = f(oldValue)
109+
110+
iterator.set(newValue)
111+
}
112+
}
113+
103114
def row(matrix: MathRowMatrix, index: Int): MathRowVector = {
104115
val result = SimpleMatrix.wrap(matrix).rows(index, index + 1).getMatrix[FMatrixRMaj]
105116

Diff for: ‎encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ trait Math {
1313
def inplaceMatrixAddition(matrix: MathRowMatrix, colVector: MathColVector): Unit
1414
def inplaceMatrixAddition(matrix: MathRowMatrix, rowIndex: Int, rowVector: MathRowVector): Unit
1515
// def rowVectorAddition(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector
16+
def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit
1617
def mul(leftMatrix: MathRowMatrix, rightMatrix: MathRowMatrix): MathRowMatrix
1718
def rows(matrix: MathRowMatrix): Int
1819
def cols(matrix: MathRowMatrix): Int

Diff for: ‎encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ object Mathematics {
77
// val Math = CommonsMath
88
// val Math = CluMath
99

10+
type MathValue = Math.MathValue
1011
type MathMatrix = Math.MathRowMatrix
1112
type MathColVector = Math.MathColVector
1213
type MathRowVector = Math.MathRowVector

0 commit comments

Comments
 (0)
Please sign in to comment.