matrix: use a custom low-level C structure to avoid boxing of param types
[nit.git] / lib / matrix / matrix.nit
1 # This file is part of NIT ( http://www.nitlanguage.org ).
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 # Services for matrices of `Float` values
16 module matrix
17
# A rectangular array of `Float`
#
# Require: `width > 0 and height > 0`
class Matrix
	super Cloneable

	# Number of columns
	var width: Int

	# Number of rows
	var height: Int

	# Items of this matrix, rows by rows
	#
	# The item at row `y` and column `x` is stored at index `x + y*width`.
	private var items = new NativeDoubleArray(width*height) is lazy

	# Create a matrix from nested sequences
	#
	# Require: all rows are of the same length
	#
	# An empty `items` produces a matrix of 0 columns and 0 rows.
	#
	# ~~~
	# var matrix = new Matrix.from([[1.0, 2.0],
	#                               [3.0, 4.0]])
	# assert matrix.to_s == """
	# 1.0 2.0
	# 3.0 4.0"""
	# ~~~
	init from(items: SequenceRead[SequenceRead[Float]])
	do
		if items.is_empty then
			init(0, 0)
			return
		end

		init(items.first.length, items.length)

		for j in height.times do assert items[j].length == width

		for j in height.times do
			for i in width.times do
				self[j, i] = items[j][i]
			end
		end
	end

	# Get each row of this matrix in nested arrays
	#
	# ~~~
	# var items = [[1.0, 2.0],
	#              [3.0, 4.0]]
	# var matrix = new Matrix.from(items)
	# assert matrix.to_a == items
	# ~~~
	fun to_a: Array[Array[Float]]
	do
		var a = new Array[Array[Float]]
		for j in height.times do
			var row = new Array[Float]
			for i in width.times do
				row.add self[j, i]
			end
			a.add row
		end
		return a
	end

	# Create a matrix from an `Array[Float]` composed of rows after rows
	#
	# Require: `width > 0 and height > 0`
	# Require: `array.length >= width*height`
	#
	# ~~~
	# var matrix = new Matrix.from_array(2, 2, [1.0, 2.0,
	#                                           3.0, 4.0])
	# assert matrix.to_s == """
	# 1.0 2.0
	# 3.0 4.0"""
	#
	# var rect = new Matrix.from_array(3, 2, [1.0, 2.0, 3.0,
	#                                         4.0, 5.0, 6.0])
	# assert rect.to_a == [[1.0, 2.0, 3.0],
	#                      [4.0, 5.0, 6.0]]
	# ~~~
	init from_array(width, height: Int, array: SequenceRead[Float])
	do
		assert width > 0
		assert height > 0
		assert array.length >= width*height

		init(width, height)

		# `j` is the row and `i` the column, matching the `self[row, column]` order
		# of `[]=` and the row-major layout of `array`.
		for j in height.times do
			for i in width.times do
				self[j, i] = array[i + j*width]
			end
		end
	end

	# Create an identity matrix
	#
	# Require: `size >= 0`
	#
	# ~~~
	# var i = new Matrix.identity(3)
	# assert i.to_s == """
	# 1.0 0.0 0.0
	# 0.0 1.0 0.0
	# 0.0 0.0 1.0"""
	# ~~~
	new identity(size: Int)
	do
		assert size >= 0

		var matrix = new Matrix(size, size)
		for i in size.times do
			for j in size.times do
				matrix[j, i] = if i == j then 1.0 else 0.0
			end
		end
		return matrix
	end

	# Create a new clone of this matrix
	redef fun clone
	do
		var c = new Matrix(width, height)
		for i in [0..width*height[ do c.items[i] = items[i]
		return c
	end

	# Get the value at row `y` and column `x`
	#
	# Require: `x >= 0 and x < width and y >= 0 and y < height`
	#
	# ~~~
	# var matrix = new Matrix.from([[0.0, 0.1],
	#                               [1.0, 1.1]])
	#
	# assert matrix[0, 0] == 0.0
	# assert matrix[0, 1] == 0.1
	# assert matrix[1, 0] == 1.0
	# assert matrix[1, 1] == 1.1
	# ~~~
	fun [](y, x: Int): Float
	do
		assert x >= 0 and x < width
		assert y >= 0 and y < height

		return items[x + y*width]
	end

	# Set the `value` at row `y` and column `x`
	#
	# Require: `x >= 0 and x < width and y >= 0 and y < height`
	#
	# ~~~
	# var matrix = new Matrix.identity(2)
	#
	# matrix[0, 0] = 0.0
	# matrix[0, 1] = 0.1
	# matrix[1, 0] = 1.0
	# matrix[1, 1] = 1.1
	#
	# assert matrix.to_s == """
	# 0.0 0.1
	# 1.0 1.1"""
	# ~~~
	fun []=(y, x: Int, value: Float)
	do
		assert x >= 0 and x < width
		assert y >= 0 and y < height

		items[x + y*width] = value
	end

	# Matrix product (×)
	#
	# Require: `self.width == other.height`
	#
	# ~~~
	# var m = new Matrix.from([[3.0, 4.0],
	#                          [5.0, 6.0]])
	# var i = new Matrix.identity(2)
	#
	# assert m * i == m
	# assert (m * m).to_s == """
	# 29.0 36.0
	# 45.0 56.0"""
	#
	# var a = new Matrix.from([[1.0, 2.0, 3.0],
	#                          [4.0, 5.0, 6.0]])
	# var b = new Matrix.from([[1.0],
	#                          [2.0],
	#                          [3.0]])
	# var c = a * b
	# assert c.to_s == """
	# 14.0
	# 32.0"""
	# ~~~
	fun *(other: Matrix): Matrix
	do
		assert self.width == other.height

		var out = new Matrix(other.width, self.height)
		for j in self.height.times do
			for i in other.width.times do
				# Start from a literal zero: reading `items[0]` would access
				# the native buffer out of bounds on a matrix without items.
				var sum = 0.0
				for k in self.width.times do sum += self[j, k] * other[k, i]
				out[j, i] = sum
			end
		end
		return out
	end

	# Get the transpose of this matrix
	#
	# ~~~
	# var matrix = new Matrix.from([[1.0, 2.0, 3.0],
	#                               [4.0, 5.0, 6.0]])
	# assert matrix.transposed.to_a == [[1.0, 4.0],
	#                                   [2.0, 5.0],
	#                                   [3.0, 6.0]]
	#
	# var i = new Matrix.identity(3)
	# assert i.transposed == i
	# ~~~
	fun transposed: Matrix
	do
		var out = new Matrix(height, width)
		for k, v in self do out[k.x, k.y] = v
		return out
	end

	# Iterate over the values in this matrix, rows after rows
	fun iterator: MapIterator[MatrixCoordinate, Float] do return new MatrixIndexIterator(self)

	# Values separated by a space, rows separated by a new line
	redef fun to_s
	do
		var s = new FlatBuffer
		for y in [0..height[ do
			for x in [0..width[ do
				s.append items[y*width+x].to_s
				if x < width-1 then s.add ' '
			end
			if y < height-1 then s.add '\n'
		end
		return s.to_s
	end

	# Are `self` and `other` matrices of the same size with equal items?
	#
	# ~~~
	# var a = new Matrix.from([[1.0, 2.0]])
	# var b = new Matrix.from([[1.0, 3.0]])
	# assert a != b
	# assert a == a.clone
	# ~~~
	redef fun ==(other) do return other isa Matrix and
		width == other.width and height == other.height and
		items.equal_items(other.items, width*height)

	redef fun hash do return items.hash_items(width*height)
end
267
# Iterator over the items of a `Matrix`, rows after rows
private class MatrixIndexIterator
	super MapIterator[MatrixCoordinate, Float]

	# Matrix being iterated over
	var matrix: Matrix

	# Coordinate of the current item
	#
	# The same instance is updated in place by `next`.
	redef var key = new MatrixCoordinate(0, 0)

	# Is the cursor still inside the bounds of `matrix`?
	#
	# `key.x` is checked too so that a matrix with 0 columns yields no items,
	# instead of `next` incrementing `key.x` forever.
	redef fun is_ok do return key.y < matrix.height and key.x < matrix.width

	redef fun item
	do
		assert is_ok
		return matrix[key.y, key.x]
	end

	# Move to the next column, or to the first column of the next row
	redef fun next
	do
		assert is_ok
		var key = key
		if key.x == matrix.width - 1 then
			key.x = 0
			key.y += 1
		else
			key.x += 1
		end
	end
end
295
# Position key when iterating over the values of a matrix
#
# ~~~
# var c = new MatrixCoordinate(1, 2)
# assert c.to_s == "(1,2)"
# ~~~
class MatrixCoordinate
	# Column index, starting at 0
	var x: Int

	# Row index, starting at 0
	var y: Int

	# Formatted as `(x,y)`
	redef fun to_s
	do
		return "({x},{y})"
	end
end
306
# Specialized native structure to store matrix items and avoid boxing cost
#
# This is a bare C `double*`: it does not store its length,
# callers must track it and stay within bounds.
private extern class NativeDoubleArray `{ double* `}

	# Allocate a buffer for `size` values
	#
	# NOTE(review): assumes a C `double` is 8 bytes wide.
	new(size: Int) do
		var sizeof_double = 8
		var buf = new CString(sizeof_double*size)
		return new NativeDoubleArray.in_buffer(buf)
	end

	# Reinterpret `buffer` as an array of `double`
	new in_buffer(buffer: CString) `{ return (double*)buffer; `}

	# Value at index `i`
	fun [](i: Int): Float `{ return self[i]; `}

	# Set the value at index `i`
	fun []=(i: Int, value: Float) `{ self[i] = value; `}

	# Are the first `len` values of `self` and `other` equal, compared with C `!=`?
	fun equal_items(other: NativeDoubleArray, len: Int): Bool `{
		long i;
		for (i = 0; i < len; i ++)
			if (self[i] != other[i])
				return 0;
		return 1;
	`}

	# Hash of the first `len` values
	#
	# Each value contributes `(long)(value*1024.0)`, so arrays accepted
	# by `equal_items` produce the same hash.
	# NOTE(review): values whose scaled magnitude exceeds the range of `long`
	# (and NaN) make that C conversion undefined.
	fun hash_items(len: Int): Int `{
		// Adapted from `SequenceRead::hash`.
		// Unsigned arithmetic wraps around instead of hitting signed-overflow UB.
		unsigned long r = 17+len;
		long i;
		for (i = 0; i < len; i ++)
			r = r * 3 / 2 + (unsigned long)(long)(self[i]*1024.0);
		return (long)r;
	`}
end