91adb80f3075322baa4ec3199ea0422f71e7a09b
[nit.git] / lib / matrix / matrix.nit
1 # This file is part of NIT ( http://www.nitlanguage.org ).
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 # Services for matrices of `Float` values
16 module matrix
17
# A rectangular array of `Float`
#
# Values are stored in `items` row after row: the value at row `y`
# and column `x` is `items[x + y*width]`.
#
# Require: `width >= 0 and height >= 0`
class Matrix
	super Cloneable

	# Number of columns
	var width: Int

	# Number of rows
	var height: Int

	# Items of this matrix, rows by rows
	#
	# Lazily created as `width*height` zeros, so a new matrix is a zero matrix.
	var items: Array[Float] is lazy do
		return new Array[Float].filled_with(0.0, width*height)
	end

	# Create a matrix from nested sequences
	#
	# Each element of `items` is one row; the length of the first row sets
	# the width. An empty `items` produces a 0×0 matrix.
	#
	# Require: all rows are of the same length
	#
	# ~~~
	# var matrix = new Matrix.from([[1.0, 2.0],
	#                               [3.0, 4.0]])
	# assert matrix.to_s == """
	# 1.0 2.0
	# 3.0 4.0"""
	# ~~~
	init from(items: SequenceRead[SequenceRead[Float]])
	do
		if items.is_empty then
			init(0, 0)
			return
		end

		init(items.first.length, items.length)

		for j in height.times do assert items[j].length == width

		for j in height.times do
			for i in width.times do
				self[j, i] = items[j][i]
			end
		end
	end

	# Get each row of this matrix in nested arrays
	#
	# ~~~
	# var items = [[1.0, 2.0],
	#              [3.0, 4.0]]
	# var matrix = new Matrix.from(items)
	# assert matrix.to_a == items
	# ~~~
	fun to_a: Array[Array[Float]]
	do
		var a = new Array[Array[Float]]
		for j in height.times do
			var row = new Array[Float]
			for i in width.times do
				row.add self[j, i]
			end
			a.add row
		end
		return a
	end

	# Create a matrix from an `Array[Float]` composed of rows after rows
	#
	# The first `width*height` values of `array` are copied; `array` itself
	# is not retained.
	#
	# Require: `width >= 0 and height >= 0`
	# Require: `array.length >= width*height`
	#
	# ~~~
	# var matrix = new Matrix.from_array(2, 2, [1.0, 2.0,
	#                                           3.0, 4.0])
	# assert matrix.to_s == """
	# 1.0 2.0
	# 3.0 4.0"""
	#
	# var wide = new Matrix.from_array(3, 2, [1.0, 2.0, 3.0,
	#                                         4.0, 5.0, 6.0])
	# assert wide.to_a == [[1.0, 2.0, 3.0],
	#                      [4.0, 5.0, 6.0]]
	# ~~~
	init from_array(width, height: Int, array: SequenceRead[Float])
	do
		assert width >= 0
		assert height >= 0
		assert array.length >= width*height

		init(width, height)

		# `j` walks the rows and `i` the columns, matching `self[row, column]`
		for j in height.times do
			for i in width.times do
				self[j, i] = array[i + j*width]
			end
		end
	end

	# Create an identity matrix
	#
	# Require: `size >= 0`
	#
	# ~~~
	# var i = new Matrix.identity(3)
	# assert i.to_s == """
	# 1.0 0.0 0.0
	# 0.0 1.0 0.0
	# 0.0 0.0 1.0"""
	# ~~~
	new identity(size: Int)
	do
		assert size >= 0

		var matrix = new Matrix(size, size)
		# `items` starts filled with `0.0`, only the diagonal needs setting
		for i in size.times do matrix[i, i] = 1.0
		return matrix
	end

	# Create a new clone of this matrix
	#
	# The clone owns a separate copy of the values, so later writes to
	# either matrix do not affect the other.
	#
	# ~~~
	# var m = new Matrix.from([[1.0, 2.0, 3.0]])
	# var c = m.clone
	# assert c == m
	# c[0, 0] = 9.0
	# assert m[0, 0] == 1.0
	# ~~~
	redef fun clone do return new Matrix.from_array(width, height, items)

	# Get the value at row `y` and column `x`
	#
	# Require: `x >= 0 and x < width and y >= 0 and y < height`
	#
	# ~~~
	# var matrix = new Matrix.from([[0.0, 0.1],
	#                               [1.0, 1.1]])
	#
	# assert matrix[0, 0] == 0.0
	# assert matrix[0, 1] == 0.1
	# assert matrix[1, 0] == 1.0
	# assert matrix[1, 1] == 1.1
	# ~~~
	fun [](y, x: Int): Float
	do
		assert x >= 0 and x < width
		assert y >= 0 and y < height

		return items[x + y*width]
	end

	# Set the `value` at row `y` and column `x`
	#
	# Require: `x >= 0 and x < width and y >= 0 and y < height`
	#
	# ~~~
	# var matrix = new Matrix.identity(2)
	#
	# matrix[0, 0] = 0.0
	# matrix[0, 1] = 0.1
	# matrix[1, 0] = 1.0
	# matrix[1, 1] = 1.1
	#
	# assert matrix.to_s == """
	# 0.0 0.1
	# 1.0 1.1"""
	# ~~~
	fun []=(y, x: Int, value: Float)
	do
		assert x >= 0 and x < width
		assert y >= 0 and y < height

		items[x + y*width] = value
	end

	# Matrix product (×)
	#
	# The result has `self.height` rows and `other.width` columns.
	#
	# Require: `self.width == other.height`
	#
	# ~~~
	# var m = new Matrix.from([[3.0, 4.0],
	#                          [5.0, 6.0]])
	# var i = new Matrix.identity(2)
	#
	# assert m * i == m
	# assert (m * m).to_s == """
	# 29.0 36.0
	# 45.0 56.0"""
	#
	# var a = new Matrix.from([[1.0, 2.0, 3.0],
	#                          [4.0, 5.0, 6.0]])
	# var b = new Matrix.from([[1.0],
	#                          [2.0],
	#                          [3.0]])
	# var c = a * b
	# assert c.to_s == """
	# 14.0
	# 32.0"""
	# ~~~
	fun *(other: Matrix): Matrix
	do
		assert self.width == other.height

		var out = new Matrix(other.width, self.height)
		for j in self.height.times do
			for i in other.width.times do
				# Start from a literal zero: `items` may be empty when `self.width == 0`
				var sum = 0.0
				for k in self.width.times do sum += self[j, k] * other[k, i]
				out[j, i] = sum
			end
		end
		return out
	end

	# Get the transpose of this matrix
	#
	# The result has `height` columns and `width` rows.
	#
	# ~~~
	# var matrix = new Matrix.from([[1.0, 2.0, 3.0],
	#                               [4.0, 5.0, 6.0]])
	# assert matrix.transposed.to_a == [[1.0, 4.0],
	#                                   [2.0, 5.0],
	#                                   [3.0, 6.0]]
	#
	# var i = new Matrix.identity(3)
	# assert i.transposed == i
	# ~~~
	fun transposed: Matrix
	do
		var out = new Matrix(height, width)
		for j in height.times do
			for i in width.times do
				out[i, j] = self[j, i]
			end
		end
		return out
	end

	# Iterate over the values in this matrix with their coordinates
	fun iterator: MapIterator[MatrixCoordinate, Float] do return new MatrixIndexIterator(self)

	# One line per row, values separated by a single space
	redef fun to_s
	do
		var lines = new Array[String]
		for y in height.times do
			lines.add items.subarray(y*width, width).join(" ")
		end
		return lines.join("\n")
	end

	# Matrices are equal when they have the same dimensions and the same items
	#
	# ~~~
	# var a = new Matrix.from_array(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
	# var b = new Matrix.from_array(2, 3, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
	# assert a != b
	# assert a == a.clone
	# ~~~
	redef fun ==(other)
	do
		return other isa Matrix and other.width == width and other.height == height and other.items == items
	end

	# Equal matrices have equal `items`, so hashing `items` is consistent with `==`
	redef fun hash do return items.hash
end
257
# Iterator over the values of a `Matrix`, row after row
#
# A fresh `MatrixCoordinate` is created at each step, so keys kept by a
# client are not modified by later calls to `next`.
private class MatrixIndexIterator
	super MapIterator[MatrixCoordinate, Float]

	# Matrix being iterated over
	var matrix: Matrix

	redef var key = new MatrixCoordinate(0, 0)

	# Check both bounds so that a matrix with rows but no columns
	# (`width == 0`) yields nothing instead of failing in `item`
	redef fun is_ok do return key.x < matrix.width and key.y < matrix.height

	redef fun item
	do
		assert is_ok
		return matrix[key.y, key.x]
	end

	redef fun next
	do
		assert is_ok
		var x = key.x + 1
		var y = key.y
		if x == matrix.width then
			# End of the current row, wrap to the start of the next one
			x = 0
			y += 1
		end
		key = new MatrixCoordinate(x, y)
	end
end
285
# Location of a value in a `Matrix`, used as the key by `Matrix::iterator`
class MatrixCoordinate
	# Column index
	var x: Int

	# Row index
	var y: Int

	# Text of the form `(x,y)`
	redef fun to_s do return "(" + x.to_s + "," + y.to_s + ")"
end