2a3c32fa8d6f35b98da2dbc5b73d0af4674ae832
[nit.git] / lib / websocket / websocket.nit
1 # This file is part of NIT ( http://www.nitlanguage.org ).
2 #
3 # Copyright 2014 Lucas Bajolet <r4pass@hotmail.com>
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 # Adds support for a websocket connection in Nit
18 # Uses standard sockets
19 module websocket
20
21 import socket
22 import sha1
23 import base64
24
25 intrude import core::stream
26 intrude import core::bytes
27
# Websocket compatible listener
#
# Wraps a `TCPServer` and turns every accepted TCP client into a
# `WebsocketConnection` (which performs the websocket handshake on creation).
class WebSocketListener
	super Socket

	# Underlying server socket on which clients connect
	var listener: TCPServer

	# Creates a server bound to `port` and starts listening on it,
	# passing `max_clients` to `TCPServer::listen`
	init(port: Int, max_clients: Int)
	do
		var server = new TCPServer(port)
		server.listen(max_clients)
		listener = server
	end

	# Accepts the next incoming client and wraps it in a `WebsocketConnection`
	#
	# Aborts if the listener was already closed or if no client stream is obtained.
	fun accept: WebsocketConnection
	do
		assert not listener.closed
		var stream = listener.accept
		assert stream != null
		var port = listener.port
		return new WebsocketConnection(port, "", stream)
	end

	# Stops accepting new connections by closing the server socket
	fun close do listener.close
end
61
# Connection to a websocket client
#
# Can be used to communicate with a client.
# Reading from this stream yields the reassembled payload of the data frames
# sent by the client; writing to it sends text frames to the client.
class WebsocketConnection
	super TCPStream

	init do
		_buffer = new NativeString(1024)
		_buffer_pos = 0
		_buffer_capacity = 1024
		_buffer_length = 0
		var headers = parse_handshake
		var resp = handshake_response(headers)

		client.write(resp)
	end

	# Client connection to the server
	var client: TCPStream

	# Disconnect from a client
	redef fun close
	do
		client.close
	end

	# Parses the input handshake sent by the client
	#
	# Each header line is split on its first space: the first token (minus a
	# trailing `:`) is the key, the rest of the line is the value.
	# See RFC 6455 for information
	private fun parse_handshake: Map[String,String]
	do
		var recved = read_http_frame(new FlatBuffer)
		var headers = recved.split("\r\n")
		var headmap = new HashMap[String,String]
		for i in headers do
			var temp_head = i.split(" ")
			var head = temp_head.shift
			if head.is_empty or head.length == 1 then continue
			if head.chars.last == ':' then
				head = head.substring(0, head.length - 1)
			end
			var body = temp_head.join(" ")
			headmap[head] = body
		end
		return headmap
	end

	# Generates the server side of the opening handshake (RFC 6455 §4.2.2)
	#
	# The `Sec-WebSocket-Key` header is looked up case-insensitively, since
	# HTTP header names are case-insensitive. Aborts if the client sent none.
	private fun handshake_response(heads: Map[String,String]): String
	do
		var key: nullable String = null
		for name, value in heads do
			if name.to_lower == "sec-websocket-key" then key = value
		end
		assert key != null
		var accept = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").sha1.encode_base64.to_s
		var resp = "HTTP/1.1 101 Switching Protocols\r\n"
		resp += "Upgrade: websocket\r\n"
		resp += "Connection: Upgrade\r\n"
		resp += "Sec-WebSocket-Accept: " + accept
		resp += "\r\n\r\n"
		return resp
	end

	# Frames a text message to be sent to a client
	#
	# The payload length written in the header is the UTF-8 byte length of
	# `msg`, using the 7-bit, 16-bit or 64-bit length encoding as required.
	# Server frames are never masked.
	private fun frame_message(msg: String): Bytes
	do
		var payload_len = msg.bytelen
		# At most 10 header bytes: opcode byte, length byte, 8 extended-length bytes
		var ans_buffer = new Bytes.with_capacity(payload_len + 10)
		# Flag for final frame set to 1
		# opcode set to 1 (for text)
		ans_buffer.add(129u8)
		if payload_len < 126 then
			ans_buffer.add(payload_len.to_b)
		else if payload_len <= 65535 then
			ans_buffer.add(126u8)
			ans_buffer.add((payload_len >> 8).to_b)
			ans_buffer.add(payload_len.to_b)
		else
			ans_buffer.add(127u8)
			# 64-bit big-endian length; `to_b` keeps the low 8 bits of each shift
			for i in [0 .. 8[ do
				ans_buffer.add((payload_len >> (8 * (7 - i))).to_b)
			end
		end
		if msg isa FlatString then
			ans_buffer.append_ns_from(msg.items, msg.bytelen, msg.first_byte)
		else
			for i in msg.substrings do
				var flat = i.as(FlatString)
				ans_buffer.append_ns_from(flat.items, flat.bytelen, flat.first_byte)
			end
		end
		return ans_buffer
	end

	# Reads an HTTP frame
	#
	# Appends lines read from `client` to `buf` (each terminated by `\r\n`)
	# until `buf` ends with an empty line, then returns its content.
	protected fun read_http_frame(buf: Buffer): String
	do
		var done = false
		while not done do
			buf.append client.read_line
			buf.append "\r\n"
			done = buf.has_suffix("\r\n\r\n")
		end
		return buf.to_s
	end

	# Gets the message from the client, unpads it and reconstitutes the message
	#
	# Frames are read until a data frame with the FIN flag is received; their
	# (unmasked) payloads are concatenated into the read buffer.
	# Control frames are handled here: a ping is answered with a pong echoing
	# its data, a pong is ignored and a close frame closes the connection.
	# If a control frame arrives before any data frame, the method returns with
	# an empty buffer. On malformed or truncated input, `last_error` is set and
	# the connection is closed.
	private fun unpad_message do
		var fin = false
		var got_data = false
		var bf = new Bytes.empty
		while not fin do
			var fst_byte = client.read_byte
			var snd_byte = client.read_byte
			if fst_byte == null or snd_byte == null then
				last_error = new IOError("Error: bad frame")
				client.close
				return
			end
			# First byte in msg is formatted this way :
			# |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits)
			# fin = Flag indicating if current frame is the last one
			# RSV1/2/3 = Extension flags, unsupported
			# Opcode values :
			#	%x0 denotes a continuation frame
			#	%x1 denotes a text frame
			#	%x2 denotes a binary frame
			#	%x3-7 are reserved for further non-control frames
			#	%x8 denotes a connection close
			#	%x9 denotes a ping
			#	%xA denotes a pong
			#	%xB-F are reserved for further control frames
			#
			# NOTE: comparisons use `u8` literals so both operands are `Byte`s;
			# comparing a `Byte` with an `Int` literal does not compare numerically.
			var is_final = (fst_byte & 0b1000_0000u8) != 0u8
			var opcode = fst_byte & 0b0000_1111u8
			# Second byte is formatted this way :
			# |(mask - 1bit)|(payload length - 7 bits)
			# As specified, if the payload length is 126 or 127
			# The next 16 or 64 bits contain an extended payload length
			var masked = (snd_byte & 0b1000_0000u8) != 0u8
			var len = (snd_byte & 0b0111_1111u8).to_i
			if len == 126 then
				var ext = read_frame_chunk(2)
				if ext == null then return
				len = (ext[0].to_i << 8) + ext[1].to_i
			else if len == 127 then
				var ext = read_frame_chunk(8)
				if ext == null then return
				len = 0
				for i in [0 .. 8[ do
					len += ext[i].to_i << (8 * (7 - i))
				end
				if len < 0 then
					last_error = new IOError("Error: frame too large")
					client.close
					return
				end
			end
			# The mask key, when present, precedes the payload
			var mask: nullable Bytes = null
			if masked then
				mask = read_frame_chunk(4)
				if mask == null then return
			end
			# Always consume the whole payload so the next frame starts at the
			# right position in the stream, whatever the opcode or mask flag
			var payload = read_frame_chunk(len)
			if payload == null then return
			var data = payload.items
			if mask != null then data = unmask_message(mask.items, data, len)
			if opcode == 8u8 then
				client.close
				return
			end
			if opcode == 9u8 or opcode == 10u8 then
				# Control frames may be interleaved with the fragments of a message
				if opcode == 9u8 then send_pong(data, len)
				if not got_data then return
				continue
			end
			got_data = true
			bf.append_ns(data, len)
			if is_final then fin = true
		end
		_buffer = bf.items
		_buffer_length = bf.length
	end

	# Reads exactly `count` bytes of the current frame from `client`
	#
	# On a short read, sets `last_error`, closes the connection and returns `null`.
	private fun read_frame_chunk(count: Int): nullable Bytes
	do
		var chunk = client.read_bytes(count)
		if chunk.length != count then
			last_error = new IOError("Error: received interrupted frame")
			client.close
			return null
		end
		return chunk
	end

	# Answers a ping with a pong frame echoing the first `len` bytes of `data`
	#
	# Control frame payloads are limited to 125 bytes (RFC 6455 §5.5), so longer
	# data is truncated to keep the length byte valid.
	private fun send_pong(data: NativeString, len: Int)
	do
		var n = len
		if n > 125 then n = 125
		var frame = new Bytes.with_capacity(n + 2)
		# FIN set, opcode 0xA (pong)
		frame.add(138u8)
		# Server frames are not masked: mask bit stays 0
		frame.add(n.to_b)
		frame.append_ns(data, n)
		client.write_bytes(frame)
	end

	# Unmasks a message sent by a client
	#
	# XORs each of the `len` bytes of `message` with the 4-byte `key`, cycling
	# through the key, and returns the result in a new `NativeString`.
	private fun unmask_message(key: NativeString, message: NativeString, len: Int): NativeString
	do
		var return_message = new NativeString(len)

		for i in [0 .. len[ do
			return_message[i] = message[i] ^ key[i % 4]
		end

		return return_message
	end

	# Checks if a connection to a client is available
	redef fun connected do return client.connected

	# Sends `s` to the client as a single text frame
	redef fun write_bytes(s) do client.write_bytes(frame_message(s.to_s))

	# Sends `msg` to the client as a single text frame
	#
	# The frame is written as raw bytes: it must not round-trip through a
	# `String`, as its header bytes are not valid UTF-8.
	redef fun write(msg) do client.write_bytes(frame_message(msg.to_s))

	redef fun is_writable do return client.connected

	redef fun fill_buffer
	do
		buffer_reset
		unpad_message
	end

	redef fun end_reached do return client._buffer_pos >= client._buffer_length and client.end_reached

	# Is there some data available to be read ?
	fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout)

	redef fun poll_in do return client.poll_in
end