X-Git-Url: http://nitlanguage.org diff --git a/lib/websocket/websocket.nit b/lib/websocket/websocket.nit index 200c1f5..8a97421 100644 --- a/lib/websocket/websocket.nit +++ b/lib/websocket/websocket.nit @@ -22,16 +22,14 @@ import socket import sha1 import base64 -intrude import standard::stream +intrude import core::stream +intrude import core::bytes -# Websocket compatible server, works as an extra layer to the original Sockets -class WebSocket - super BufferedIStream - super OStream - super PollableIStream - - # Client connection to the server - var client: TCPStream +# Websocket compatible listener +# +# Produces Websocket client-server connections +class WebSocketListener + super Socket # Socket listening to connections on a defined port var listener: TCPServer @@ -39,39 +37,52 @@ class WebSocket # Creates a new Websocket server listening on given port with `max_clients` slots available init(port: Int, max_clients: Int) do - _buffer = new FlatBuffer - _buffer_pos = 0 listener = new TCPServer(port) listener.listen max_clients end - # Accept an incoming connection and initializes the handshake - fun accept + # Accepts an incoming connection + fun accept: WebsocketConnection do assert not listener.closed var client = listener.accept assert client != null - self.client = client + return new WebsocketConnection(listener.port, "", client) + end + + # Stop listening for incoming connections + fun close + do + listener.close + end +end + +# Connection to a websocket client +# +# Can be used to communicate with a client +class WebsocketConnection + super TCPStream + + init do + _buffer = new CString(1024) + _buffer_pos = 0 + _buffer_capacity = 1024 + _buffer_length = 0 var headers = parse_handshake var resp = handshake_response(headers) client.write(resp) end - # Disconnect from a client - fun disconnect_client - do - client.close - end + # Client connection to the server + var client: TCPStream - # Disconnects the client if one is connected - # And stops the server + # Disconnect from a client redef fun close do client.close - listener.close end # Parses the input handshake sent by the client @@ -103,7 +114,7 @@ class WebSocket resp_map["Connection:"] = "Upgrade" var key = heads["Sec-WebSocket-Key"] key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - key = key.sha1.encode_base64 + key = key.sha1.encode_base64.to_s resp_map["Sec-WebSocket-Accept:"] = key var resp = resp_map.join("\r\n", " ") resp += "\r\n\r\n" @@ -111,39 +122,52 @@ class WebSocket end # Frames a text message to be sent to a client - private fun frame_message(msg: String): String + private fun frame_message(msg: String): Bytes do - var ans_buffer = new FlatBuffer + var ans_buffer = new Bytes.with_capacity(msg.length) # Flag for final frame set to 1 # opcode set to 1 (for text) - ans_buffer.add(129.ascii) + ans_buffer.add(129u8) if msg.length < 126 then - ans_buffer.add(msg.length.ascii) + ans_buffer.add(msg.length.to_b) end if msg.length >= 126 and msg.length <= 65535 then - ans_buffer.add(126.ascii) - ans_buffer.add(msg.length.rshift(8).ascii) - ans_buffer.add(msg.length.ascii) + ans_buffer.add(126u8) + ans_buffer.add((msg.length >> 8).to_b) + ans_buffer.add(msg.length.to_b) + end + if msg isa FlatString then + ans_buffer.append_ns_from(msg.items, msg.length, msg.first_byte) + else + for i in msg.substrings do + ans_buffer.append_ns_from(i.as(FlatString).items, i.length, i.as(FlatString).first_byte) + end end - ans_buffer.append(msg) - return ans_buffer.to_s + return ans_buffer end # Reads an HTTP frame protected fun read_http_frame(buf: Buffer): String do - buf.append client.read_line - buf.append("\r\n") - if buf.has_substring("\r\n\r\n", buf.length - 4) then return buf.to_s + var ln = client.read_line + buf.append ln + buf.append "\r\n" + if buf.has_suffix("\r\n\r\n") then return buf.to_s return read_http_frame(buf) end # Gets the message from the client, unpads it and reconstitutes the message private fun unpad_message do var fin = false + var bf = new Bytes.empty while not fin do - var fst_char = client.read_char - var snd_char = client.read_char + var fst_byte = client.read_byte + var snd_byte = client.read_byte + if fst_byte == null or snd_byte == null then + last_error = new IOError("Error: bad frame") + client.close + return + end # First byte in msg is formatted this way : # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits) # fin = Flag indicating if current frame is the last one @@ -157,14 +181,14 @@ class WebSocket # %x9 denotes a ping # %xA denotes a pong # %xB-F are reserved for further control frames - var fin_flag = fst_char.bin_and(128) + var fin_flag = fst_byte & 0b1000_0000u8 if fin_flag != 0 then fin = true - var opcode = fst_char.bin_and(15) + var opcode = fst_byte & 0b0000_1111u8 if opcode == 9 then - _buffer.add(138.ascii) - _buffer.add('\0') - client.write(_buffer.to_s) - _buffer_pos += 2 + bf.add(138u8) + bf.add(0u8) + client.write(bf.to_s) + _buffer_pos = _buffer_length return end if opcode == 8 then @@ -175,70 +199,70 @@ class WebSocket # |(mask - 1bit)|(payload length - 7 bits) # As specified, if the payload length is 126 or 127 # The next 16 or 64 bits contain an extended payload length - var mask_flag = snd_char.bin_and(128) - var len = snd_char.bin_and(127) + var mask_flag = snd_byte & 0b1000_0000u8 + var len = (snd_byte & 0b0111_1111u8).to_i var payload_ext_len = 0 if len == 126 then - payload_ext_len = client.read_char.lshift(8) - payload_ext_len += client.read_char + var tmp = client.read_bytes(2) + if tmp.length != 2 then + last_error = new IOError("Error: received interrupted frame") + client.close + return + end + payload_ext_len += tmp[0].to_i << 8 + payload_ext_len += tmp[1].to_i else if len == 127 then - # 64 bits for length are not supported, - # only the last 32 will be interpreted as a Nit Integer - for i in [0..4[ do client.read_char - payload_ext_len = client.read_char.lshift(24) - payload_ext_len += client.read_char.lshift(16) - payload_ext_len += client.read_char.lshift(8) - payload_ext_len += client.read_char + var tmp = client.read_bytes(8) + if tmp.length != 8 then + last_error = new IOError("Error: received interrupted frame") + client.close + return + end + for i in [0 .. 8[ do + payload_ext_len += tmp[i].to_i << (8 * (7 - i)) + end end if mask_flag != 0 then + var mask = client.read_bytes(4).items if payload_ext_len != 0 then - var msg = client.read(payload_ext_len+4) - var mask = msg.substring(0,4) - _buffer.append(unmask_message(mask, msg.substring(4, payload_ext_len))) - else - if len == 0 then - return - end - var msg = client.read(len+4) - var mask = msg.substring(0,4) - _buffer.append(unmask_message(mask, msg.substring(4, len))) + len = payload_ext_len end + var msg = client.read_bytes(len).items + bf.append_ns(unmask_message(mask, msg, len), len) end end + _buffer = bf.items + _buffer_length = bf.length end # Unmasks a message sent by a client - private fun unmask_message(key: String, message: String): String + private fun unmask_message(key: CString, message: CString, len: Int): CString do - var return_message = new FlatBuffer.with_capacity(message.length) - var msg_iter = message.chars.iterator + var return_message = new CString(len) - while msg_iter.is_ok do - return_message.chars[msg_iter.index] = msg_iter.item.ascii.bin_xor(key.chars[msg_iter.index%4].ascii).ascii - msg_iter.next + for i in [0 .. len[ do + return_message[i] = message[i] ^ key[i % 4] end - return return_message.to_s + return return_message end # Checks if a connection to a client is available - fun connected: Bool do return client.connected + redef fun connected do return client.connected - redef fun write(msg: Text) - do - client.write(frame_message(msg.to_s)) - end + redef fun write_bytes(s) do client.write_bytes(frame_message(s.to_s)) + + redef fun write(msg) do client.write(frame_message(msg.to_s).to_s) redef fun is_writable do return client.connected redef fun fill_buffer do - _buffer.clear - _buffer_pos = 0 + buffer_reset unpad_message end - redef fun end_reached do return client._buffer_pos >= client._buffer.length and client.end_reached + redef fun end_reached do return client._buffer_pos >= client._buffer_length and client.end_reached # Is there some data available to be read ? fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout)