X-Git-Url: http://nitlanguage.org diff --git a/lib/websocket/websocket.nit b/lib/websocket/websocket.nit index 441df14..f1a559a 100644 --- a/lib/websocket/websocket.nit +++ b/lib/websocket/websocket.nit @@ -1,7 +1,5 @@ # This file is part of NIT ( http://www.nitlanguage.org ). # -# Copyright 2014 Lucas Bajolet -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,24 +19,26 @@ module websocket import socket import sha1 import base64 +import crypto -intrude import standard::stream -intrude import standard::bytes - -# Websocket compatible listener +# Websocket compatible server # # Produces Websocket client-server connections -class WebSocketListener - super Socket +class WebsocketServer - # Socket listening to connections on a defined port + # Socket listening for incoming Websocket connections var listener: TCPServer - # Creates a new Websocket server listening on given port with `max_clients` slots available - init(port: Int, max_clients: Int) + # Is `self` closed? + var closed = false + + # Creates a new Websocket server listening on given port + # with `max_clients` slots available + init with_infos(port: Int, max_clients: Int) do - listener = new TCPServer(port) + var listener = new TCPServer(port) listener.listen max_clients + init(listener) end # Accepts an incoming connection @@ -49,13 +49,14 @@ class WebSocketListener var client = listener.accept assert client != null - return new WebsocketConnection(listener.port, "", client) + return new WebsocketConnection(client) end - # Stop listening for incoming connections + # Close the server and the socket below it fun close do listener.close + closed = true end end @@ -63,29 +64,49 @@ end # # Can be used to communicate with a client class WebsocketConnection - super TCPStream + super DuplexProtocol + super PollableReader + + redef type STREAM: TCPStream + + # Does the current frame have a mask? + private var has_mask = false + + # Mask with which to XOR input data + private var mask = new CString(4) + + # Offset of the mask to use when decoding input data + private var mask_offset = -1 + + # Length of the current frame + private var frame_length = -1 + + # Position in current frame + private var frame_cursor = -1 + + # Type of the current frame + var frame_type = -1 + + # Is `self` closed? + var closed = false init do - _buffer = new NativeString(1024) - _buffer_pos = 0 - _buffer_capacity = 1024 - _buffer_length = 0 var headers = parse_handshake var resp = handshake_response(headers) - client.write(resp) + origin.write(resp) end - # Client connection to the server - var client: TCPStream - # Disconnect from a client - redef fun close - do - client.close + redef fun close do + origin.close + closed = true end - # Parses the input handshake sent by the client + # Ping response message + private fun pong_msg: Bytes do return once b"\x8a\x00" + + # Parse the input handshake sent by the client # See RFC 6455 for information private fun parse_handshake: Map[String,String] do @@ -105,7 +126,7 @@ class WebsocketConnection return headmap end - # Generates the handshake + # Generate a handshake response private fun handshake_response(heads: Map[String,String]): String do var resp_map = new HashMap[String,String] @@ -114,158 +135,160 @@ class WebsocketConnection resp_map["Connection:"] = "Upgrade" var key = heads["Sec-WebSocket-Key"] key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - key = key.sha1.encode_base64 + key = key.sha1.encode_base64.to_s resp_map["Sec-WebSocket-Accept:"] = key var resp = resp_map.join("\r\n", " ") resp += "\r\n\r\n" return resp end - # Frames a text message to be sent to a client - private fun frame_message(msg: String): Bytes + # Frame a text message to be sent to a client + private fun frame_message(msg: Text): Bytes do - var ans_buffer = new Bytes.with_capacity(msg.length) + var ans_buffer = new Bytes.with_capacity(msg.byte_length + 2) # Flag for final frame set to 1 # opcode set to 1 (for text) - ans_buffer.add(129u8) + ans_buffer.add(129) if msg.length < 126 then - ans_buffer.add(msg.length.to_b) + ans_buffer.add(msg.length) end if msg.length >= 126 and msg.length <= 65535 then - ans_buffer.add(126u8) - ans_buffer.add((msg.length >> 8).to_b) - ans_buffer.add(msg.length.to_b) - end - if msg isa FlatString then - ans_buffer.append_ns_from(msg.items, msg.length, msg.first_byte) - else - for i in msg.substrings do - ans_buffer.append_ns_from(i.as(FlatString).items, i.length, i.as(FlatString).first_byte) - end + ans_buffer.add(126) + ans_buffer.add(msg.length >> 8) + ans_buffer.add(msg.length) end + msg.append_to_bytes(ans_buffer) return ans_buffer end - # Reads an HTTP frame + # Read an HTTP frame protected fun read_http_frame(buf: Buffer): String do - var ln = client.read_line + var ln = origin.read_line buf.append ln buf.append "\r\n" if buf.has_suffix("\r\n\r\n") then return buf.to_s return read_http_frame(buf) end - # Gets the message from the client, unpads it and reconstitutes the message - private fun unpad_message do - var fin = false - var bf = new Bytes.empty - while not fin do - var fst_byte = client.read_byte - var snd_byte = client.read_byte - if fst_byte == null or snd_byte == null then - last_error = new IOError("Error: bad frame") - client.close - return - end - # First byte in msg is formatted this way : - # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits) - # fin = Flag indicating if current frame is the last one - # RSV1/2/3 = Extension flags, unsupported - # Opcode values : - # %x0 denotes a continuation frame - # %x1 denotes a text frame - # %x2 denotes a binary frame - # %x3-7 are reserved for further non-control frames - # %x8 denotes a connection close - # %x9 denotes a ping - # %xA denotes a pong - # %xB-F are reserved for further control frames - var fin_flag = fst_byte & 0b1000_0000u8 - if fin_flag != 0 then fin = true - var opcode = fst_byte & 0b0000_1111u8 - if opcode == 9 then - bf.add(138u8) - bf.add(0u8) - client.write(bf.to_s) - _buffer_pos = _buffer_length + # Get a frame's information + private fun read_frame_info do + var fst_byte = origin.read_byte + var snd_byte = origin.read_byte + if fst_byte < 0 or snd_byte < 0 then + last_error = new IOError("Error: bad frame") + close + return + end + # First byte in msg is formatted this way : + # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits) + # fin = Flag indicating if current frame is the last one for the current message + # RSV1/2/3 = Extension flags, unsupported + # Opcode values : + # %x0 denotes a continuation frame + # %x1 denotes a text frame + # %x2 denotes a binary frame + # %x3-7 are reserved for further non-control frames + # %x8 denotes a connection close + # %x9 denotes a ping + # %xA denotes a pong + # %xB-F are reserved for further control frames + var opcode = fst_byte & 0b0000_1111 + if opcode == 9 then + origin.write_bytes(pong_msg) + return + end + if opcode == 8 then + close + return + end + frame_type = opcode + # Second byte is formatted this way : + # |(mask - 1bit)|(payload length - 7 bits) + # As specified, if the payload length is 126 or 127 + # The next 16 or 64 bits contain an extended payload length + var mask_flag = snd_byte & 0b1000_0000 + var len = snd_byte & 0b0111_1111 + var payload_ext_len = 0 + if len == 126 then + var tmp = origin.read_bytes(2) + if tmp.length != 2 then + last_error = new IOError("Error: received interrupted frame") + origin.close return end - if opcode == 8 then - self.client.close + payload_ext_len += tmp[0].to_i << 8 + payload_ext_len += tmp[1].to_i + else if len == 127 then + var tmp = origin.read_bytes(8) + if tmp.length != 8 then + last_error = new IOError("Error: received interrupted frame") + origin.close return end - # Second byte is formatted this way : - # |(mask - 1bit)|(payload length - 7 bits) - # As specified, if the payload length is 126 or 127 - # The next 16 or 64 bits contain an extended payload length - var mask_flag = snd_byte & 0b1000_0000u8 - var len = (snd_byte & 0b0111_1111u8).to_i - var payload_ext_len = 0 - if len == 126 then - var tmp = client.read_bytes(2) - if tmp.length != 2 then - last_error = new IOError("Error: received interrupted frame") - client.close - return - end - payload_ext_len += tmp[0].to_i << 8 - payload_ext_len += tmp[1].to_i - else if len == 127 then - var tmp = client.read_bytes(8) - if tmp.length != 8 then - last_error = new IOError("Error: received interrupted frame") - client.close - return - end - for i in [0 .. 8[ do - payload_ext_len += tmp[i].to_i << (8 * (7 - i)) - end - end - if mask_flag != 0 then - var mask = client.read_bytes(4).items - if payload_ext_len != 0 then - len = payload_ext_len - end - var msg = client.read_bytes(len).items - bf.append_ns(unmask_message(mask, msg, len), len) + for i in [0 .. 8[ do + payload_ext_len += tmp[i].to_i << (8 * (7 - i)) end end - _buffer = bf.items - _buffer_length = bf.length + if mask_flag != 0 then + origin.read_bytes_to_cstring(mask, 4) + has_mask = true + else + mask.memset(0, 4) + has_mask = false + end + if payload_ext_len != 0 then + len = payload_ext_len + end + frame_length = len + frame_cursor = 0 end - # Unmasks a message sent by a client - private fun unmask_message(key: NativeString, message: NativeString, len: Int): NativeString - do - var return_message = new NativeString(len) - - for i in [0 .. len[ do - return_message[i] = message[i] ^ key[i % 4] + redef fun raw_read_byte do + while not closed and frame_cursor >= frame_length do + read_frame_info + end + if closed then return -1 + var b = origin.read_byte + if b >= 0 then + frame_cursor += 1 end + return b + end - return return_message + redef fun raw_read_bytes(ns, len) do + while not closed and frame_cursor >= frame_length do + read_frame_info + end + if closed then return -1 + var available = frame_length - frame_cursor + var to_rd = len.min(available) + var rd = origin.read_bytes_to_cstring(ns, to_rd) + if rd < 0 then + close + return 0 + end + if has_mask then + ns.xor(mask, rd, 4, mask_offset) + mask_offset = rd % 4 + end + frame_cursor += rd + return rd end # Checks if a connection to a client is available - redef fun connected do return client.connected - - redef fun write_bytes(s) do client.write_bytes(frame_message(s.to_s)) + fun connected: Bool do return not closed and origin.connected - redef fun write(msg) do client.write(frame_message(msg.to_s).to_s) - - redef fun is_writable do return client.connected - - redef fun fill_buffer - do - buffer_reset - unpad_message + redef fun write_bytes_from_cstring(ns, len) do + origin.write_bytes(frame_message(ns.to_s_unsafe(len))) end - redef fun end_reached do return client._buffer_pos >= client._buffer_length and client.end_reached + redef fun write(msg) do origin.write_bytes(frame_message(msg)) + + redef fun is_writable do return origin.connected # Is there some data available to be read ? - fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout) + fun can_read(timeout: Int): Bool do return not closed and origin.ready_to_read(timeout) - redef fun poll_in do return client.poll_in + redef fun poll_in do return origin.poll_in end