Merge: Native Types
[nit.git] / lib / websocket / websocket.nit
index a22c5b2..6eac618 100644 (file)
@@ -23,6 +23,7 @@ import sha1
 import base64
 
 intrude import standard::stream
+intrude import standard::bytes
 
 # Websocket compatible listener
 #
@@ -65,8 +66,10 @@ class WebsocketConnection
        super TCPStream
 
        init do
-               _buffer = new FlatBuffer
+               _buffer = new NativeString(1024) # initial read buffer; unpad_message later swaps in the reassembled payload
                _buffer_pos = 0
+               _buffer_capacity = 1024 # records the size of the NativeString allocated above
+               _buffer_length = 0 # the buffer starts with no readable bytes
                var headers = parse_handshake
                var resp = handshake_response(headers)
 
@@ -119,22 +122,41 @@ class WebsocketConnection
        end
 
        # Frames a text message to be sent to a client
-       private fun frame_message(msg: String): String
+       #
+       # The returned frame holds one byte with the FIN flag and the text opcode,
+       # the payload length (7-bit, or 16-bit / 64-bit big-endian extended form),
+       # then the bytes of `msg`, unmasked.
+       private fun frame_message(msg: String): Bytes
        do
-               var ans_buffer = new FlatBuffer
+               var len = msg.length
+               # At most 10 header bytes precede the payload
+               var ans_buffer = new Bytes.with_capacity(len + 10)
                # Flag for final frame set to 1
                # opcode set to 1 (for text)
-               ans_buffer.add(129.ascii)
-               if msg.length < 126 then
-                       ans_buffer.add(msg.length.ascii)
-               end
-               if msg.length >= 126 and msg.length <= 65535 then
-                       ans_buffer.add(126.ascii)
-                       ans_buffer.add(msg.length.rshift(8).ascii)
-                       ans_buffer.add(msg.length.ascii)
-               end
-               ans_buffer.append(msg)
-               return ans_buffer.to_s
+               ans_buffer.add(129)
+               if len < 126 then
+                       ans_buffer.add(len)
+               else if len <= 65535 then
+                       # 16-bit extended payload length, most significant byte first
+                       ans_buffer.add(126)
+                       ans_buffer.add(len.rshift(8).bin_and(255))
+                       ans_buffer.add(len.bin_and(255))
+               else
+                       # 64-bit extended payload length, most significant byte first
+                       ans_buffer.add(127)
+                       for k in [0 .. 8[ do
+                               ans_buffer.add(len.rshift(8 * (7 - k)).bin_and(255))
+                       end
+               end
+               if msg isa FlatString then
+                       ans_buffer.append_ns_from(msg.items, len, msg.index_from)
+               else
+                       for sub in msg.substrings do
+                               var flat = sub.as(FlatString)
+                               ans_buffer.append_ns_from(flat.items, flat.length, flat.index_from)
+                       end
+               end
+               return ans_buffer
        end
 
        # Reads an HTTP frame
@@ -149,6 +158,7 @@ class WebsocketConnection
        # Gets the message from the client, unpads it and reconstitutes the message
        private fun unpad_message do
                var fin = false
+               var bf = new Bytes.empty # unmasked payload of every frame read so far; copied into _buffer after the loop
                while not fin do
                        var fst_byte = client.read_byte
                        var snd_byte = client.read_byte
@@ -174,10 +184,14 @@ class WebsocketConnection
                        if fin_flag != 0 then fin = true
                        var opcode = fst_byte.bin_and(15)
                        if opcode == 9 then
-                               _buffer.add(138.ascii)
-                               _buffer.add('\0')
-                               client.write(_buffer.to_s)
-                               _buffer_pos += 2
+                               # Answer the ping with an empty pong (0x8A: FIN + opcode 0xA) built in
+                               # its own buffer: `bf` may already hold payload of earlier fragments
+                               # NOTE(review): any application data carried by the ping is neither read nor echoed
+                               var pong = new Bytes.with_capacity(2)
+                               pong.add(138)
+                               pong.add(0)
+                               client.write_bytes(pong)
+                               _buffer_pos = _buffer_length
                                return
                        end
                        if opcode == 8 then
@@ -192,76 +202,93 @@ class WebsocketConnection
                        var len = snd_byte.bin_and(127)
                        var payload_ext_len = 0
                        if len == 126 then
-                               var tmp = client.read(2)
+                               var tmp = client.read_bytes(2)
                                if tmp.length != 2 then
                                        last_error = new IOError("Error: received interrupted frame")
                                        client.close
                                        return
                                end
-                               payload_ext_len = tmp[1].ascii + tmp[0].ascii.lshift(8)
+                               payload_ext_len = tmp[1] + tmp[0].lshift(8)
                        else if len == 127 then
                                # 64 bits for length are not supported,
                                # only the last 32 will be interpreted as a Nit Integer
-                               var tmp = client.read(8)
+                               var tmp = client.read_bytes(8)
                                if tmp.length != 8 then
                                        last_error = new IOError("Error: received interrupted frame")
                                        client.close
                                        return
                                end
                                for pos in [0 .. tmp.length[ do
-                                       var i = tmp[pos].ascii
+                                       var i = tmp[pos]
                                        payload_ext_len += i.lshift(8 * (7 - pos))
                                end
                        end
                        if mask_flag != 0 then
-                               if payload_ext_len != 0 then
-                                       var msg = client.read(payload_ext_len+4)
-                                       var mask = msg.substring(0,4)
-                                       _buffer.append(unmask_message(mask, msg.substring(4, payload_ext_len)))
-                               else
-                                       if len == 0 then
-                                               return
-                                       end
-                                       var msg = client.read(len+4)
-                                       var mask = msg.substring(0,4)
-                                       _buffer.append(unmask_message(mask, msg.substring(4, len)))
+                               # Every read is checked: `unmask_message` indexes 4 bytes of the
+                               # mask and `len` bytes of the payload without bounds checks
+                               var mask = client.read_bytes(4)
+                               if mask.length != 4 then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               # 126 and 127 announce an extended payload length
+                               if len >= 126 then
+                                       len = payload_ext_len
                                end
+                               var msg = client.read_bytes(len)
+                               if msg.length != len then
+                                       last_error = new IOError("Error: received interrupted frame")
+                                       client.close
+                                       return
+                               end
+                               bf.append_ns(unmask_message(mask.items, msg.items, len), len)
+                       else
+                               # RFC 6455 section 5.1: a server must close the connection
+                               # upon receiving an unmasked frame
+                               last_error = new IOError("Error: received unmasked frame")
+                               client.close
+                               return
                        end
                end
+               _buffer = bf.items
+               _buffer_length = bf.length
+               _buffer_capacity = bf.capacity
        end
 
        # Unmasks a message sent by a client
-       private fun unmask_message(key: String, message: String): String
+       #
+       # Returns a new NativeString of `len` bytes: byte `i` of `message` XORed
+       # with byte `i % 4` of the 4-byte `key`.
+       private fun unmask_message(key: NativeString, message: NativeString, len: Int): NativeString
        do
-               var return_message = new FlatBuffer.with_capacity(message.length)
-               var msg_iter = message.chars.iterator
+               var return_message = new NativeString(len)
 
-               while msg_iter.is_ok do
-                       return_message.chars[msg_iter.index] = msg_iter.item.ascii.bin_xor(key.chars[msg_iter.index%4].ascii).ascii
-                       msg_iter.next
+               for i in [0 .. len[ do
+                       return_message[i] = message[i].ascii.bin_xor(key[i%4].ascii).ascii
                end
 
-               return return_message.to_s
+               return return_message
        end
 
        # Checks if a connection to a client is available
        redef fun connected do return client.connected
 
-       redef fun write(msg)
-       do
-               client.write(frame_message(msg.to_s))
-       end
+       # Sends `s` to the client as a single text frame
+       redef fun write_bytes(s) do client.write_bytes(frame_message(s.to_s))
+
+       # Sends `msg` to the client as a single text frame
+       redef fun write(msg) do client.write_bytes(frame_message(msg.to_s))
 
        redef fun is_writable do return client.connected
 
        redef fun fill_buffer
        do
-               _buffer.clear
-               _buffer_pos = 0
+               buffer_reset
                unpad_message
        end
 
-       redef fun end_reached do return client._buffer_pos >= client._buffer.length and client.end_reached
+       redef fun end_reached do return client._buffer_pos >= client._buffer_length and client.end_reached
 
        # Is there some data available to be read ?
        fun can_read(timeout: Int): Bool do return client.ready_to_read(timeout)