rename `NativeString` to `CString`
[nit.git] / lib / websocket / websocket.nit
index 9066391..8a97421 100644 (file)
@@ -22,53 +22,67 @@ import socket
 import sha1
 import base64
 
-intrude import standard::stream
+intrude import core::stream
+intrude import core::bytes
 
-# Websocket compatible server, works as an extra layer to the original Sockets
-class WebSocket
-       super BufferedIStream
-       super OStream
-       super PollableIStream
-
-       # Client connection to the server
-       var client: Socket
+# Websocket listener built on top of a `TCPServer`
+#
+# Each call to `accept` yields a `WebsocketConnection`, which answers the handshake itself
+class WebSocketListener
+       super Socket
 
        # Socket listening to connections on a defined port
-       var listener: Socket
+       var listener: TCPServer
 
        # Creates a new Websocket server listening on given port with `max_clients` slots available
        init(port: Int, max_clients: Int)
        do
-               _buffer = new FlatBuffer
-               _buffer_pos = 0
-               listener = new Socket.server(port, max_clients)
+               listener = new TCPServer(port)
+               listener.listen max_clients
+       end
+
+       # Accepts the next client from `listener` and wraps it in a `WebsocketConnection`
+       fun accept: WebsocketConnection
+       do
+               assert not listener.closed
+
+               var client = listener.accept
+               assert client != null # a failed accept aborts through this assert
+
+               return new WebsocketConnection(listener.port, "", client)
        end
 
-       # Accept an incoming connection and initializes the handshake
-       fun accept
+       # Stops listening by closing `listener`; accepted connections are not closed here
+       fun close
        do
-               assert not listener.eof
+               listener.close
+       end
+end
 
-               client = listener.accept
+# Server-side connection to a Websocket client
+#
+# The handshake is answered on creation; writes are framed and reads are unmasked
+class WebsocketConnection
+       super TCPStream
 
+       init do
+               _buffer = new CString(1024)
+               _buffer_pos = 0
+               _buffer_capacity = 1024
+               _buffer_length = 0
                var headers = parse_handshake
                var resp = handshake_response(headers)
 
                client.write(resp)
        end
 
-       # Disconnect from a client
-       fun disconnect_client
-       do
-               client.close
-       end
+       # Underlying TCP stream to the client; reads and writes are delegated to it
+       var client: TCPStream
 
-       # Disconnects the client if one is connected
-       # And stops the server
+       # Disconnects from the client by closing `client`
        redef fun close
        do
                client.close
-               listener.close
        end
 
        # Parses the input handshake sent by the client
@@ -100,7 +114,7 @@ class WebSocket
                resp_map["Connection:"] = "Upgrade"
                var key = heads["Sec-WebSocket-Key"]
                key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-               key = key.sha1.encode_base64
+               key = key.sha1.encode_base64.to_s # accept key = base64(sha1(key + GUID)); `.to_s` turns the encoded result back into text
                resp_map["Sec-WebSocket-Accept:"] = key
                var resp = resp_map.join("\r\n", " ")
                resp += "\r\n\r\n"
@@ -108,39 +122,63 @@ class WebSocket
        end
 
        # Frames a text message to be sent to a client
+       #
+       # Payload lengths are counted in bytes, as RFC 6455 requires, so
+       # non-ASCII text is framed correctly; payloads longer than 65535
+       # bytes use the 64-bit extended length.
-       private fun frame_message(msg: String): String
+       private fun frame_message(msg: String): Bytes
        do
-               var ans_buffer = new FlatBuffer
+               var len = msg.byte_length
+               var ans_buffer = new Bytes.with_capacity(len + 10)
                # Flag for final frame set to 1
                # opcode set to 1 (for text)
-               ans_buffer.add(129.ascii)
+               ans_buffer.add(129u8)
-               if msg.length < 126 then
-                       ans_buffer.add(msg.length.ascii)
-               end
-               if msg.length >= 126 and msg.length <= 65535 then
-                       ans_buffer.add(126.ascii)
-                       ans_buffer.add(msg.length.rshift(8).ascii)
-                       ans_buffer.add(msg.length.ascii)
+               if len < 126 then
+                       ans_buffer.add(len.to_b)
+               else if len <= 65535 then
+                       ans_buffer.add(126u8)
+                       ans_buffer.add((len >> 8).to_b)
+                       ans_buffer.add(len.to_b)
+               else
+                       # 64-bit extended length, most significant byte first
+                       ans_buffer.add(127u8)
+                       for i in [0 .. 8[ do
+                               ans_buffer.add((len >> (8 * (7 - i))).to_b)
+                       end
+               end
+               if msg isa FlatString then
+                       ans_buffer.append_ns_from(msg.items, len, msg.first_byte)
+               else
+                       for i in msg.substrings do
+                               var sub = i.as(FlatString)
+                               ans_buffer.append_ns_from(sub.items, sub.byte_length, sub.first_byte)
+                       end
                end
-               ans_buffer.append(msg)
-               return ans_buffer.to_s
+               return ans_buffer
        end
 
        # Reads an HTTP frame
        protected fun read_http_frame(buf: Buffer): String
        do
-               client.append_line_to(buf)
-               buf.chars.add('\n')
-               if buf.has_substring("\r\n\r\n", buf.length - 4) then return buf.to_s
+               var ln = client.read_line
+               buf.append ln
+               buf.append "\r\n"
+               if buf.has_suffix("\r\n\r\n") then return buf.to_s
                return read_http_frame(buf)
        end
 
        # Gets the message from the client, unpads it and reconstitutes the message
        private fun unpad_message do
                var fin = false
+               var bf = new Bytes.empty
                while not fin do
-                       var fst_char = client.read_char
-                       var snd_char = client.read_char
+                       var fst_byte = client.read_byte
+                       var snd_byte = client.read_byte
+                       if fst_byte == null or snd_byte == null then
+                               last_error = new IOError("Error: bad frame")
+                               client.close
+                               return
+                       end
                        # First byte in msg is formatted this way :
                        # |(fin - 1bit)|(RSV1 - 1bit)|(RSV2 - 1bit)|(RSV3 - 1bit)|(opcode - 4bits)
                        # fin = Flag indicating if current frame is the last one
@@ -154,14 +181,21 @@ class WebSocket
                        #       %x9 denotes a ping
                        #       %xA denotes a pong
                        #       %xB-F are reserved for further control frames
-                       var fin_flag = fst_char.bin_and(128)
+                       var fin_flag = fst_byte & 0b1000_0000u8
                        if fin_flag != 0 then fin = true
-                       var opcode = fst_char.bin_and(15)
+                       var opcode = fst_byte & 0b0000_1111u8
                        if opcode == 9 then
-                               _buffer.add(138.ascii)
-                               _buffer.add('\0')
-                               client.write(_buffer.to_s)
-                               _buffer_pos += 2
+                               # Reply with an empty, unfragmented pong frame (0x8A, length 0).
+                               # A separate buffer keeps fragments already gathered in `bf`
+                               # out of the reply, and `write_bytes` sends 0x8A as-is instead
+                               # of routing it through a String (it is not valid UTF-8).
+                               # NOTE(review): the ping's mask and payload are left unread
+                               # here, and RFC 6455 expects the pong to echo that payload.
+                               var pong = new Bytes.with_capacity(2)
+                               pong.add(138u8)
+                               pong.add(0u8)
+                               client.write_bytes(pong)
+                               _buffer_pos = _buffer_length
                                return
                        end
                        if opcode == 8 then
@@ -172,70 +199,89 @@ class WebSocket
                        # |(mask - 1bit)|(payload length - 7 bits)
                        # As specified, if the payload length is 126 or 127
                        # The next 16 or 64 bits contain an extended payload length
-                       var mask_flag = snd_char.bin_and(128)
-                       var len = snd_char.bin_and(127)
+                       var mask_flag = snd_byte & 0b1000_0000u8
+                       var len = (snd_byte & 0b0111_1111u8).to_i
                        var payload_ext_len = 0
                        if len == 126 then
-                               payload_ext_len = client.read_char.lshift(8)
-                               payload_ext_len += client.read_char
+                               var tmp = client.read_bytes(2)
+                               if tmp.length != 2 then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               payload_ext_len += tmp[0].to_i << 8
+                               payload_ext_len += tmp[1].to_i
                        else if len == 127 then
-                               # 64 bits for length are not supported,
-                               # only the last 32 will be interpreted as a Nit Integer
-                               for i in [0..4[ do client.read_char
-                               payload_ext_len = client.read_char.lshift(24)
-                               payload_ext_len += client.read_char.lshift(16)
-                               payload_ext_len += client.read_char.lshift(8)
-                               payload_ext_len += client.read_char
+                               var tmp = client.read_bytes(8)
+                               if tmp.length != 8 then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               for i in [0 .. 8[ do
+                                       payload_ext_len += tmp[i].to_i << (8 * (7 - i))
+                               end
                        end
                        if mask_flag != 0 then
+                               # Reject short reads: unmask_message reads 4 mask bytes and `len` payload bytes
+                               var mask_bytes = client.read_bytes(4)
+                               if mask_bytes.length != 4 then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               var mask = mask_bytes.items
                                if payload_ext_len != 0 then
-                                       var msg = client.read(payload_ext_len+4)
-                                       var mask = msg.substring(0,4)
-                                       _buffer.append(unmask_message(mask, msg.substring(4, payload_ext_len)))
-                               else
-                                       if len == 0 then
-                                               return
-                                       end
-                                       var msg = client.read(len+4)
-                                       var mask = msg.substring(0,4)
-                                       _buffer.append(unmask_message(mask, msg.substring(4, len)))
+                                       len = payload_ext_len
                                end
+                               var payload = client.read_bytes(len)
+                               if payload.length != len then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               bf.append_ns(unmask_message(mask, payload.items, len), len)
+                       else
+                               # RFC 6455 section 5.1: clients must mask every frame and the
+                               # server must close the connection when one is not masked
+                               last_error = new IOError("Error: received unmasked frame")
+                               client.close
+                               return
                        end
                end
+               _buffer = bf.items
+               _buffer_length = bf.length
        end
 
        # Unmasks a message sent by a client
-       private fun unmask_message(key: String, message: String): String
+       private fun unmask_message(key: CString, message: CString, len: Int): CString
        do
-               var return_message = new FlatBuffer.with_capacity(message.length)
-               var msg_iter = message.chars.iterator
+               var return_message = new CString(len)
 
-               while msg_iter.is_ok do
-                       return_message.chars[msg_iter.index] = msg_iter.item.ascii.bin_xor(key.chars[msg_iter.index%4].ascii).ascii
-                       msg_iter.next
+               for i in [0 .. len[ do
+                       return_message[i] = message[i] ^ key[i % 4]
                end
 
-               return return_message.to_s
+               return return_message
        end
 
        # Checks if a connection to a client is available
-       fun connected: Bool do return client.connected
+       redef fun connected do return client.connected
 
-       redef fun write(msg: Text)
-       do
-               client.write(frame_message(msg.to_s))
-       end
+       redef fun write_bytes(s) do client.write_bytes(frame_message(s.to_s))
+
+       # Sends `msg` as one text frame; raw bytes are written since the frame header is not valid UTF-8
+       redef fun write(msg) do client.write_bytes(frame_message(msg.to_s))
 
        redef fun is_writable do return client.connected
 
        redef fun fill_buffer
        do
-               _buffer.clear
-               _buffer_pos = 0
+               buffer_reset
                unpad_message
        end
 
-       redef fun end_reached do return _buffer_pos >= _buffer.length and client.eof
+       redef fun end_reached do return client._buffer_pos >= client._buffer_length and client.end_reached
 
        # Is there some data available to be read ?
        fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout)