base64/websocket: Rewrite for clarity and compliance with Byte-oriented NativeString
author: Lucas Bajolet <r4pass@hotmail.com>
Wed, 8 Jul 2015 20:45:43 +0000 (16:45 -0400)
committer: Lucas Bajolet <r4pass@hotmail.com>
Thu, 9 Jul 2015 15:45:40 +0000 (11:45 -0400)
Signed-off-by: Lucas Bajolet <r4pass@hotmail.com>

lib/base64.nit
lib/websocket/websocket.nit

index 7b78d78..5f1c920 100644 (file)
@@ -24,121 +24,121 @@ redef class String
        do
                return "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
        end
-       private fun inverted_base64_chars : HashMap[Char,Int]
+       private fun inverted_base64_chars : HashMap[Byte, Byte]
        do
-               var inv_base64_chars = new HashMap[Char,Int]
-               for k in [0..base64_chars.length[ do
-                       inv_base64_chars[ base64_chars.chars[k] ] = k
+               var inv_base64_chars = new HashMap[Byte, Byte]
+               for k in [0..base64_chars.bytelen[ do
+                       inv_base64_chars[base64_chars.bytes[k]] = k.to_b
                end
                return inv_base64_chars
        end
 
        # Encodes the receiver string to base64.
        # By default, uses "=" for padding.
-       fun encode_base64 : String do return encode_base64_custom_padding( '=' )
+       fun encode_base64 : String do return encode_base64_custom_padding('='.ascii.to_b)
 
        # Encodes the receiver string to base64 using a custom padding character.
        #
        # If using the default padding character `=`, see `encode_base64`.
-       fun encode_base64_custom_padding( padding : Char ) : String
+       fun encode_base64_custom_padding(padding : Byte) : String
        do
-               var base64_chars = once base64_chars
-               var length = length
+               var base64_bytes = once base64_chars.bytes
+               var length = bytelen
 
                var steps = length / 3
-               var chars_in_last_step = length % 3
-               var result_length = steps*4
-               if chars_in_last_step > 0 then result_length += 4
-               var result = (padding.to_s*result_length).to_cstring
+               var bytes_in_last_step = length % 3
+               var result_length = steps * 4
+               if bytes_in_last_step > 0 then result_length += 4
+               var result = new NativeString(result_length + 1)
+               var bytes = self.bytes
+               result[result_length] = 0u8
 
-               var mask_6bit = 63
+               var mask_6bit = 0b0011_1111
 
-               for s in [0..steps[ do
+               for s in [0 .. steps[ do
                        var e = 0
-                       for ss in [0..3[ do
-                               e += self.chars[s*3+ss].ascii.lshift((2-ss)*8)
+                       for ss in [0 .. 3[ do
+                               e += bytes[s * 3 + ss].to_i << ((2 - ss) * 8)
                        end
                        for ss in [0..4[ do
-                               result[s*4+3-ss] = base64_chars.chars[ e.rshift(ss*6).bin_and( mask_6bit ) ]
+                               result[s * 4 + 3 - ss] = base64_bytes[(e >> (ss * 6)) & mask_6bit]
                        end
                end
 
-               if chars_in_last_step == 1 then
-                       var e = self.chars[length-1].ascii.lshift(16)
-                       for ss in [0..2[ do
-                               result[steps*4+1-ss] = base64_chars.chars[ e.rshift((ss+2)*6).bin_and( mask_6bit ) ]
-                       end
-               else if chars_in_last_step == 2 then
-                       var e = self.chars[length-2].ascii.lshift(16) +
-                               self.chars[length-1].ascii.lshift(8)
-                       for ss in [0..3[ do
-                               result[steps*4+2-ss] = base64_chars.chars[ e.rshift((ss+1)*6).bin_and( mask_6bit ) ]
-                       end
+               var out_off = result_length - 4
+               var in_off = length - bytes_in_last_step
+               if bytes_in_last_step == 1 then
+                       result[out_off] = base64_bytes[((bytes[in_off] & 0b1111_1100u8) >> 2).to_i]
+                       result[out_off + 1] = base64_bytes[((bytes[in_off] & 0b0000_0011u8) << 4).to_i]
+                       out_off += 2
+               else if bytes_in_last_step == 2 then
+                       result[out_off] = base64_bytes[((bytes[in_off] & 0b1111_1100u8) >> 2).to_i]
+                       result[out_off + 1] = base64_bytes[(((bytes[in_off] & 0b0000_0011u8) << 4) | ((bytes[in_off + 1] & 0b1111_0000u8) >> 4)).to_i]
+                       result[out_off + 2] = base64_bytes[((bytes[in_off + 1] & 0b0000_1111u8) << 2).to_i]
+                       out_off += 3
+               end
+               if bytes_in_last_step > 0 then
+                       for i in [out_off .. result_length[ do result[i] = padding
                end
 
-               return result.to_s
+               return result.to_s_with_length(result_length)
        end
 
        # Decodes the receiver string from base64.
        # By default, uses "=" for padding.
-       fun decode_base64 : String do return decode_base64_custom_padding( '=' )
+       fun decode_base64 : String do return decode_base64_custom_padding('='.ascii.to_b)
 
        # Decodes the receiver string to base64 using a custom padding character.
        #
        # If using the default padding character `=`, see `decode_base64`.
-       fun decode_base64_custom_padding( padding : Char ) : String
+       fun decode_base64_custom_padding(padding : Byte) : String
        do
-               var inverted_base64_chars = once inverted_base64_chars
-               var length = length
+               var inv = once inverted_base64_chars
+               var length = bytelen
+               if length == 0 then return ""
                assert length % 4 == 0 else print "base64::decode_base64 only supports strings of length multiple of 4"
 
+               var bytes = self.bytes
                var steps = length / 4
-               var result_length = steps*3
-
-               var padding_begin = self.search(padding)
-               var padding_count : Int
-               if padding_begin == null then
-                       padding_count = 0
-               else
-                       padding_count = length - padding_begin.from
-                       steps -= 1
-                       result_length -= padding_count
-               end
-
-               var result = ("#"*result_length).to_cstring
-
-               var mask_8bit = 255
-
-               for s in [0..steps[ do
-                       var e = 0
-                       for ss in [0..4[ do
-                               e += inverted_base64_chars[self.chars[s*4+ss]].lshift((3-ss)*6)
-                       end
+               var result_length = steps * 3
 
-                       for ss in [0..3[ do
-                               result[s*3+ss] = e.rshift((2-ss)*8).bin_and( mask_8bit ).ascii
-                       end
+               var epos = length - 1
+               var padding_len = 0
+               while epos >= 0 and bytes[epos] == padding do
+                       epos -= 1
+                       padding_len += 1
                end
 
-               var s = steps
-               if padding_count == 1 then
-                       var e = 0
-                       for ss in [0..3[ do
-                               e += inverted_base64_chars[self.chars[s*4+ss]].lshift((3-ss)*6)
-                       end
-
-                       for ss in [0..2[ do
-                               result[s*3+ss] = e.rshift((2-ss)*8).bin_and( mask_8bit ).ascii
-                       end
-               else if padding_count == 2 then
-                       var e = 0
-                       for ss in [0..2[ do
-                               e += inverted_base64_chars[self.chars[s*4+ss]].lshift((3-ss)*6)
-                       end
+               if padding_len != 0 then steps -= 1
+               if padding_len == 1 then result_length -= 1
+               if padding_len == 2 then result_length -= 2
+
+               var result = new NativeString(result_length + 1)
+               result[result_length] = 0u8
+
+               for s in [0 .. steps[ do
+                       var c0 = inv[bytes[s * 4]]
+                       var c1 = inv[bytes[s * 4 + 1]]
+                       var c2 = inv[bytes[s * 4 + 2]]
+                       var c3 = inv[bytes[s * 4 + 3]]
+                       result[s * 3] = ((c0 & 0b0011_1111u8) << 2) | ((c1 & 0b0011_0000u8) >> 4)
+                       result[s * 3 + 1] = ((c1 & 0b0000_1111u8) << 4) | ((c2 & 0b0011_1100u8) >> 2)
+                       result[s * 3 + 2] = ((c2 & 0b0000_0011u8) << 6) | (c3 & 0b0011_1111u8)
+               end
 
-                       result[s*3] = e.rshift(2*8).bin_and( mask_8bit ).ascii
+               var last_start = steps * 4
+               if padding_len == 1 then
+                       var c0 = inv[bytes[last_start]]
+                       var c1 = inv[bytes[last_start + 1]]
+                       var c2 = inv[bytes[last_start + 2]]
+                       result[result_length - 2] = ((c0 & 0b0011_1111u8) << 2) | ((c1 & 0b0011_0000u8) >> 4)
+                       result[result_length - 1] = ((c1 & 0b0000_1111u8) << 4) | ((c2 & 0b0011_1100u8) >> 2)
+               else if padding_len == 2 then
+                       var c0 = inv[bytes[last_start]]
+                       var c1 = inv[bytes[last_start + 1]]
+                       result[result_length - 1] = ((c0 & 0b0011_1111u8) << 2) | ((c1 & 0b0011_0000u8) >> 4)
                end
 
-               return result.to_s
+               return result.to_s_with_length(result_length)
        end
 end
index 6eac618..d384512 100644 (file)
@@ -127,14 +127,14 @@ class WebsocketConnection
                var ans_buffer = new Bytes.with_capacity(msg.length)
                # Flag for final frame set to 1
                # opcode set to 1 (for text)
-               ans_buffer.add(129)
+               ans_buffer.add(129u8)
                if msg.length < 126 then
-                       ans_buffer.add(msg.length)
+                       ans_buffer.add(msg.length.to_b)
                end
                if msg.length >= 126 and msg.length <= 65535 then
-                       ans_buffer.add(126)
-                       ans_buffer.add(msg.length.rshift(8))
-                       ans_buffer.add(msg.length)
+                       ans_buffer.add(126u8)
+                       ans_buffer.add((msg.length >> 8).to_b)
+                       ans_buffer.add(msg.length.to_b)
                end
                if msg isa FlatString then
                        ans_buffer.append_ns_from(msg.items, msg.length, msg.index_from)
@@ -149,8 +149,9 @@ class WebsocketConnection
        # Reads an HTTP frame
        protected fun read_http_frame(buf: Buffer): String
        do
-               buf.append client.read_line
-               buf.append("\r\n")
+               var ln = client.read_line
+               buf.append ln
+               buf.append "\r\n"
                if buf.has_suffix("\r\n\r\n") then return buf.to_s
                return read_http_frame(buf)
        end
@@ -180,12 +181,12 @@ class WebsocketConnection
                        #       %x9 denotes a ping
                        #       %xA denotes a pong
                        #       %xB-F are reserved for further control frames
-                       var fin_flag = fst_byte.bin_and(128)
+                       var fin_flag = fst_byte & 0b1000_0000u8
                        if fin_flag != 0 then fin = true
-                       var opcode = fst_byte.bin_and(15)
+                       var opcode = fst_byte & 0b0000_1111u8
                        if opcode == 9 then
-                               bf.add(138)
-                               bf.add(0)
+                               bf.add(138u8)
+                               bf.add(0u8)
                                client.write(bf.to_s)
                                _buffer_pos = _buffer_length
                                return
@@ -198,8 +199,8 @@ class WebsocketConnection
                        # |(mask - 1bit)|(payload length - 7 bits)
                        # As specified, if the payload length is 126 or 127
                        # The next 16 or 64 bits contain an extended payload length
-                       var mask_flag = snd_byte.bin_and(128)
-                       var len = snd_byte.bin_and(127)
+                       var mask_flag = snd_byte & 0b1000_0000u8
+                       var len = (snd_byte & 0b0111_1111u8).to_i
                        var payload_ext_len = 0
                        if len == 126 then
                                var tmp = client.read_bytes(2)
@@ -208,19 +209,17 @@ class WebsocketConnection
                                        client.close
                                        return
                                end
-                               payload_ext_len = tmp[1] + tmp[0].lshift(8)
+                               payload_ext_len += tmp[0].to_i << 8
+                               payload_ext_len += tmp[1].to_i
                        else if len == 127 then
-                               # 64 bits for length are not supported,
-                               # only the last 32 will be interpreted as a Nit Integer
                                var tmp = client.read_bytes(8)
                                if tmp.length != 8 then
                                        last_error = new IOError("Error: received interrupted frame")
                                        client.close
                                        return
                                end
-                               for pos in [0 .. tmp.length[ do
-                                       var i = tmp[pos]
-                                       payload_ext_len += i.lshift(8 * (7 - pos))
+                               for i in [0 .. 8[ do
+                                       payload_ext_len += tmp[i].to_i << (8 * (7 - i))
                                end
                        end
                        if mask_flag != 0 then
@@ -242,7 +241,7 @@ class WebsocketConnection
                var return_message = new NativeString(len)
 
                for i in [0 .. len[ do
-                       return_message[i] = message[i].ascii.bin_xor(key[i%4].ascii).ascii
+                       return_message[i] = message[i] ^ key[i % 4]
                end
 
                return return_message