Last active
September 24, 2024 15:37
-
-
Save mzaks/78f7d38f63fb234dadb1dae11f2ee3ae to your computer and use it in GitHub Desktop.
Mojo String with small string optimisation and unicode support (based on UTF-8)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from algorithm.functional import vectorize | |
from bit import bit_width, byte_swap, count_leading_zeros | |
from collections._index_normalization import normalize_index | |
from memory import memcpy, memset_zero | |
from sys import is_big_endian, sizeof | |
from utils import StringSlice, Span | |
from utils.string_slice import _utf8_byte_type, _StringSliceIter | |
struct CrazyString[
    dt: DType = DType.uint64 if sizeof[UnsafePointer[UInt8]]()
    == DType.uint64.sizeof() else DType.uint32,
    indexed: Bool = True,
](
    CollectionElementNew,
    Comparable,
    Formattable,
    KeyElement,
    Stringable,
    Sized,
):
    """
    A UTF-8 string with inline small string optimization and an optional
    Unicode index for fast codepoint indexing and length computation.

    The index stores the byte offset of every 32nd Unicode codepoint, which
    makes codepoint indexing cheap. The first index entry holds the total
    number of codepoints. The index and the string bytes are stored in the
    same heap region, so a heap string needs only one allocation and one
    pointer field in the struct.

    The size of an instance ranges from 24 to 8 bytes, depending on the
    platform and on the `dt` parameter. Possible layouts:

    64 bit arch:
        [........][........][........] -> 24 bytes
        [....][....][........] -> 16 bytes
    32 bit arch:
        [....][....][....] -> 12 bytes
        [..][..][....] -> 8 bytes

    Parameters:
        dt: The unsigned integer type of the first two fields. It must be as
            wide as a pointer or half as wide (uint32/uint64 on a 64 bit
            arch, uint16/uint32 on a 32 bit arch).
        indexed: Whether to build and maintain the codepoint index for heap
            allocated strings.
    """

    var _flagged_bytes_count: Scalar[dt]
    """
    Always stored in little endian encoding, independent of the platform.
    If the byte count is greater than `inline_capacity`, the count is shifted
    five bits to the left and stored here, so the low five bits are zero.
    If the byte count is less than or equal to `inline_capacity`, the count
    is stored in the first byte. The remaining bytes of this field and of the
    other two fields hold the string value.
    """
    var _capacity: Scalar[dt]
    """
    The number of bytes reserved for the string (including the terminator).
    When `indexed` is True, the heap region is larger than `_capacity`,
    because the index and the string share one allocation. The index size
    is derived from `_capacity` with `_index_size`.
    Do not read this field directly; use `capacity()`. For an inline string
    this field does not hold a capacity but is part of the string value.
    """
    var _pointer: UnsafePointer[UInt8]
    """
    Points to the heap region that stores the index (if any) followed by the
    string bytes. Do not read this field directly; use `unsafe_ptr()`. For an
    inline string this field is part of the string value.
    """

    alias inline_capacity = sizeof[Self]() - 2
    """
    The number of bytes available for an inline string. Two bytes are
    subtracted from the struct size: one holds the length (and acts as the
    inline flag) and one holds the zero terminator.
    """

    fn __init__(inout self, literal: StringLiteral):
        """Construct an instance from a static string literal.

        Args:
            literal: Static string literal.
        """
        self.__init__(literal.unsafe_ptr(), len(literal))

    fn __init__(inout self, reference: StringRef):
        """Construct an instance from a StringRef object.

        Args:
            reference: The StringRef from which to construct this string object.
        """
        self.__init__(reference.unsafe_ptr(), len(reference))

    fn __init__(inout self, pointer: UnsafePointer[UInt8], length: Int):
        """Creates an instance from the buffer. The data in the buffer will be copied.

        Args:
            pointer: The pointer to the buffer.
            length: The length of the buffer, without the null terminator.
        """
        constrained[dt.is_unsigned(), "dt must be an unsigned int"]()
        constrained[
            dt.sizeof() == sizeof[UnsafePointer[UInt8]]()
            or dt.sizeof() == sizeof[UnsafePointer[UInt8]]() >> 1,
            "dt must be equal or half of the pointer size",
        ]()
        if length <= self.inline_capacity:
            # Inline string: the first byte holds the length and the
            # following bytes hold the characters. All other bytes are zero,
            # which also provides the terminator.
            self._flagged_bytes_count = length
            self._capacity = 0
            self._pointer = UnsafePointer[UInt8]()
            self._correct_bytes_count()
            var str_pointer = UnsafePointer.address_of(self).bitcast[
                DType.uint8
            ]().offset(1)
            memcpy(str_pointer, pointer, length)
        else:
            # Heap string: [index (optional)][string bytes][terminator ...]
            self._flagged_bytes_count = length << 5
            self._capacity = _roundup_to_32(length + 1)
            var index_size = Self._index_region_size(int(self._capacity))
            var total_buffer_size = int(self._capacity) + index_size
            self._pointer = UnsafePointer[UInt8].alloc(total_buffer_size)
            memset_zero(self._pointer, total_buffer_size)
            self._correct_bytes_count()
            memcpy(self._pointer.offset(index_size), pointer, length)

            @parameter
            if indexed:
                self._build_index()

    fn __init__(inout self, *, other: Self):
        """Explicitly copy the provided value.

        Args:
            other: The value to copy.
        """
        self.__copyinit__(other)

    @always_inline
    fn __copyinit__(inout self, existing: Self, /):
        """Creates a deep copy of an existing string.

        Args:
            existing: The string to copy.
        """
        self._flagged_bytes_count = existing._flagged_bytes_count
        self._capacity = existing._capacity
        if existing._is_inline_string():
            # The pointer field is part of the inline string value.
            self._pointer = existing._pointer
        else:
            var total_buffer_size = int(
                existing._capacity
            ) + Self._index_region_size(int(existing._capacity))
            self._pointer = UnsafePointer[UInt8].alloc(total_buffer_size)
            memcpy(self._pointer, existing._pointer, total_buffer_size)

    @always_inline
    fn __moveinit__(inout self, owned existing: Self):
        """Move the value of a string.

        Args:
            existing: The string to move.
        """
        self._flagged_bytes_count = existing._flagged_bytes_count
        self._capacity = existing._capacity
        self._pointer = existing._pointer

    @staticmethod
    @always_inline
    fn _index_region_size(capacity: Int) -> Int:
        """Size in bytes of the index stored in front of heap string bytes.

        Args:
            capacity: The heap string capacity.

        Returns:
            `_index_size(capacity)` when `indexed` is True, otherwise 0.
        """

        @parameter
        if indexed:
            return _index_size(capacity)
        else:
            return 0

    @always_inline
    fn _correct_bytes_count(inout self):
        """Keep `_flagged_bytes_count` in little endian representation, so that
        its first byte in memory is the size for an inline string.
        """

        @parameter
        if is_big_endian():
            self._flagged_bytes_count = byte_swap(self._flagged_bytes_count)

    @always_inline
    fn _decoded_count(self) -> Scalar[dt]:
        """Return `_flagged_bytes_count` in native byte order."""
        var value = self._flagged_bytes_count

        @parameter
        if is_big_endian():
            value = byte_swap(value)
        return value

    @always_inline
    fn _is_inline_string(self) -> Bool:
        """Return True if the decoded count is 0 or its low 5 bits are non-zero.

        Heap strings store `byte_count << 5`, so their low 5 bits are zero.
        """
        var value = self._decoded_count()
        return value == 0 or (value & 31) > 0

    fn _build_index(inout self):
        """Computes the index with an entry width derived from the capacity.

        A string with capacity up to 255 uses 1 byte per entry, up to 2^16-1
        uses 2 bytes, up to 2^32-1 uses 4 bytes, and anything larger uses 8.
        """
        var byte_width = _index_byte_width(int(self._capacity))
        if byte_width == 1:
            self._compute_index[DType.uint8]()
        elif byte_width == 2:
            self._compute_index[DType.uint16]()
        elif byte_width == 4:
            self._compute_index[DType.uint32]()
        else:
            self._compute_index[DType.uint64]()

    fn _compute_index[entry_type: DType](self):
        """Writes the total codepoint count and the offset of every 32nd codepoint.

        This is a no-op if `byte_length()` is smaller than 32.
        Entry 0 receives the total codepoint count. Entry k (k > 0) receives
        the byte offset of codepoint `32 * k - 1` (0-based).

        Parameters:
            entry_type: The unsigned type of an index entry.
        """
        var length = self.byte_length()
        if length < 32:
            return
        var entries = self._pointer.bitcast[entry_type]()
        var text_ptr = self.unsafe_ptr()
        var char_count = 1
        var offset = 0
        var slot = 1
        while offset < length:
            var num_bytes = _num_bytes(text_ptr[0])
            if (char_count & 31) == 0:
                entries[slot] = offset
                slot += 1
            char_count += 1
            text_ptr += num_bytes
            offset += num_bytes
        entries[0] = char_count - 1

    fn _lookup_index(self, slot: Int) -> Int:
        """Return the value stored in the index.

        Slot 0 holds the total count of Unicode codepoints. A slot k > 0 holds
        the byte offset of codepoint `32 * k - 1` (0-based).
        """
        var byte_width = _index_byte_width(int(self._capacity))
        if byte_width == 1:
            return int(self._pointer[slot])
        elif byte_width == 2:
            return int(self._pointer.bitcast[DType.uint16]()[slot])
        elif byte_width == 4:
            return int(self._pointer.bitcast[DType.uint32]()[slot])
        else:
            return int(self._pointer.bitcast[DType.uint64]()[slot])

    fn _pointer_at(self, idx: Int) -> UnsafePointer[UInt8]:
        """Return a pointer to the codepoint at index `idx`.

        `idx == len(self)` yields a pointer to the terminator.
        """
        var remaining = idx
        var p = self.unsafe_ptr()

        @parameter
        if indexed:
            # The index only exists for strings of at least 32 bytes (always
            # heap allocated). Only slots whose codepoint actually exists were
            # written, so clamp the slot to the recorded ones.
            if remaining >= 31 and self.byte_length() >= 32:
                var slot = min(
                    (remaining + 1) >> 5, self._lookup_index(0) >> 5
                )
                if slot > 0:
                    p = p.offset(self._lookup_index(slot))
                    remaining -= (slot << 5) - 1
        for _ in range(remaining):
            p += _num_bytes(p[0])
        return p

    fn unsafe_ptr(self) -> UnsafePointer[UInt8]:
        """Retrieves a pointer to the underlying string memory region.

        Returns:
            The pointer to the underlying string memory region.
        """
        if self._is_inline_string():
            return UnsafePointer.address_of(self).bitcast[UInt8]().offset(1)
        return self._pointer.offset(
            Self._index_region_size(int(self._capacity))
        )

    fn byte_length(self) -> Int:
        """Get the string length in bytes.

        Returns:
            The length of this string in bytes, excluding null terminator.

        Notes:
            This does not include the trailing null terminator in the count.
        """
        var value = self._decoded_count()
        if self._is_inline_string():
            return int(value & 31)
        return int(value >> 5)

    fn __del__(owned self):
        """Free the heap region of a heap allocated string."""
        if not self._is_inline_string():
            self._pointer.free()

    fn __str__(self) -> String:
        """Gets the string.

        This method ensures that you can pass a `CrazyString` to a method that
        takes a `Stringable` value.

        Returns:
            An instance of a `String`.
        """
        var l = self.byte_length() + 1
        var p = UnsafePointer[UInt8].alloc(l)
        memcpy(p, self.unsafe_ptr(), l)
        return String(p, l)

    fn __len__(self) -> Int:
        """Gets the string length, in Unicode codepoints.

        Returns:
            The number of Unicode codepoints in the string.
        """
        var p = self.unsafe_ptr()
        var bytes = self.byte_length()
        var result = 0

        @parameter
        fn count[simd_width: Int](offset: Int):
            # Every byte that is not a UTF-8 continuation byte (0b10xxxxxx)
            # starts a codepoint.
            result += int(
                ((p.load[width=simd_width](offset) >> 6) != 0b10)
                .cast[DType.uint8]()
                .reduce_add()
            )

        @parameter
        if not indexed:
            vectorize[count, 16](bytes)
            return result
        else:
            if bytes >= 32:
                # Entry 0 of the index holds the total codepoint count.
                return self._lookup_index(0)
            elif bytes <= self.inline_capacity:
                # Count over the whole struct. The length byte and the
                # trailing zero bytes each count as one non-continuation
                # byte, so subtract them.
                p = UnsafePointer.address_of(self).bitcast[DType.uint8]()
                vectorize[count, 16](sizeof[Self]())
                return result - (sizeof[Self]() - bytes)
            else:
                vectorize[count, 16](bytes)
                return result

    fn __getitem__(self, idx: Int) -> CrazyString:
        """Gets the character at the specified position.

        Args:
            idx: The index value.

        Returns:
            A new string containing the character at the specified position.
        """
        var index = normalize_index["CrazyString"](idx, self)
        var p = self._pointer_at(index)
        return CrazyString(p, _num_bytes(p[0]))

    fn __getitem__(self, span: Slice) -> CrazyString:
        """Gets the sequence of characters at the specified positions.

        Slice bounds are interpreted as Unicode codepoint positions.

        Args:
            span: A slice that specifies positions of the new substring.

        Returns:
            A new string containing the string at the specified positions.
        """
        var start: Int
        var end: Int
        var step: Int
        # Bounds are codepoint positions, so normalize against the codepoint
        # count, not the byte length.
        start, end, step = span.indices(len(self))
        if step == 1:
            if end <= start:
                return CrazyString(self.unsafe_ptr(), 0)
            var first = self._pointer_at(start)
            var last = self._pointer_at(end)
            return CrazyString(first, int(last) - int(first))

        # A substring can never need more bytes than the whole string.
        var tmp = UnsafePointer[UInt8].alloc(max(self.byte_length(), 1))
        var bytes = 0
        var current_step = 0
        var cursor = self._pointer_at(start)
        if step > 1:
            for _ in range(end - start):
                var num_bytes = _num_bytes(cursor[0])
                if current_step % step == 0:
                    memcpy(tmp.offset(bytes), cursor, num_bytes)
                    bytes += num_bytes
                cursor += num_bytes
                current_step += 1
        elif step < 0:
            var remaining = start - end
            while remaining > 0:
                var num_bytes = _num_bytes(cursor[0])
                if current_step % step == 0:
                    memcpy(tmp.offset(bytes), cursor, num_bytes)
                    bytes += num_bytes
                current_step += 1
                remaining -= 1
                if remaining > 0:
                    # Move back to the lead byte of the previous codepoint.
                    # Never step before the first codepoint.
                    cursor -= 1
                    while (cursor[0] >> 6) == 0b10:
                        cursor -= 1
        var result = CrazyString(tmp, bytes)
        tmp.free()
        return result^

    fn __iter__(ref [_]self) -> _StringSliceIter[__lifetime_of(self)]:
        """Iterate over elements of the string, returning immutable references.

        Returns:
            An iterator of references to the string elements.
        """
        return _StringSliceIter[__lifetime_of(self)](
            unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
        )

    @always_inline
    fn __eq__(self, other: Self) -> Bool:
        """Compares two Strings if they have the same values.

        Args:
            other: The rhs of the operation.

        Returns:
            True if the Strings are equal and False otherwise.
        """
        return not (self != other)

    @always_inline
    fn __ne__(self, other: Self) -> Bool:
        """Compares two Strings if they do not have the same values.

        Args:
            other: The rhs of the operation.

        Returns:
            True if the Strings are not equal and False otherwise.
        """
        return self._strref_dangerous() != other._strref_dangerous()

    @always_inline
    fn __lt__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using LT comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True if this String is strictly less than the RHS String and False
            otherwise.
        """
        return self._strref_dangerous() < rhs._strref_dangerous()

    @always_inline
    fn __le__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using LE comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is less than or equal to the RHS String.
        """
        return not (rhs < self)

    @always_inline
    fn __gt__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using GT comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is strictly greater than the RHS String.
        """
        return rhs < self

    @always_inline
    fn __ge__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using GE comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is greater than or equal to the RHS String.
        """
        return not (self < rhs)

    fn __hash__(self) -> UInt:
        """Hash the underlying buffer using builtin hash.

        Returns:
            A 64-bit hash value. This value is _not_ suitable for cryptographic
            uses. Its intended usage is for data structures. See the `hash`
            builtin documentation for more details.
        """
        return hash(self._strref_dangerous())

    @always_inline
    fn __bool__(self) -> Bool:
        """Checks if the string is not empty.

        Returns:
            True if the string length is greater than zero, and False otherwise.
        """
        return self.byte_length() > 0

    fn __iadd__(inout self, other: Self):
        """Appends another string to this string.

        Args:
            other: The string to append.
        """
        if not other:
            return
        if not self:
            self = other
            return
        var self_len = self.byte_length()
        var other_len = other.byte_length()
        var total_len = self_len + other_len
        if total_len <= self.inline_capacity:
            # Still fits inline (so self is inline too). Copy the data
            # together with the terminator, then update the length byte.
            memcpy(
                dest=self.unsafe_ptr() + self_len,
                src=other.unsafe_ptr(),
                count=other_len + 1,
            )
            UnsafePointer.address_of(self).bitcast[UInt8]()[0] = total_len
            return
        # Reserve room for the zero terminator as well.
        self.reserve(_roundup_to_32(total_len + 1))
        memcpy(
            dest=self.unsafe_ptr() + self_len,
            src=other.unsafe_ptr(),
            count=other_len + 1,
        )
        self._flagged_bytes_count = total_len << 5
        self._correct_bytes_count()

        @parameter
        if indexed:
            # TODO: update the index incrementally instead of rebuilding it.
            self._build_index()

    fn _strref_dangerous(self) -> StringRef:
        """
        Returns an inner pointer to the string as a StringRef.
        This functionality is extremely dangerous because Mojo eagerly releases
        strings. Using this requires the use of the _strref_keepalive() method
        to keep the underlying string alive long enough.
        """
        return StringRef(self.unsafe_ptr(), self.byte_length())

    fn _strref_keepalive(self):
        """
        A noop that keeps `self` alive through the call. This
        can be carefully used with `_strref_dangerous()` to wield inner pointers
        without the string getting deallocated early.
        """
        pass

    @always_inline
    fn as_bytes_slice(ref [_]self) -> Span[UInt8, __lifetime_of(self)]:
        """Returns a contiguous slice of the bytes owned by this string.

        Returns:
            A contiguous slice pointing to the bytes owned by this string.

        Notes:
            This does not include the trailing null terminator.
        """
        # Does NOT include the NUL terminator.
        return Span[UInt8, __lifetime_of(self)](
            unsafe_ptr=self.unsafe_ptr(), len=self.byte_length()
        )

    @always_inline
    fn as_string_slice(ref [_]self) -> StringSlice[__lifetime_of(self)]:
        """Returns a string slice of the data owned by this string.

        Returns:
            A string slice pointing to the data owned by this string.
        """
        return StringSlice(unsafe_from_utf8=self.as_bytes_slice())

    fn format_to(self, inout writer: Formatter):
        """
        Formats this string to the provided formatter.

        Args:
            writer: The formatter to write to.
        """
        writer.write_str(self.as_string_slice())

    fn capacity(self) -> Int:
        """Capacity of the string.

        Returns:
            How many bytes the string can hold.
        """
        if self._is_inline_string():
            return self.inline_capacity
        return int(self._capacity)

    fn reserve(inout self, new_capacity: Int):
        """Reserves the requested capacity.

        If the current capacity is greater or equal, this is a no-op.
        Otherwise the storage is reallocated and the data is moved. An inline
        string is moved to the heap. An empty string is always encoded inline,
        so reserving on an empty string is a no-op.

        Args:
            new_capacity: The new capacity (in bytes, including the terminator).
        """
        if self.capacity() >= new_capacity:
            return
        var new_index_size = Self._index_region_size(new_capacity)

        if self._is_inline_string():
            var length = self.byte_length()
            if length == 0:
                return
            var new_pointer = UnsafePointer[UInt8].alloc(
                new_capacity + new_index_size
            )
            # Copy the inline bytes (and terminator) before the fields that
            # hold them are overwritten.
            memcpy(
                new_pointer.offset(new_index_size),
                self.unsafe_ptr(),
                length + 1,
            )
            self._pointer = new_pointer
            self._capacity = new_capacity
            self._flagged_bytes_count = length << 5
            self._correct_bytes_count()

            @parameter
            if indexed:
                self._build_index()
            return

        var current_index_size = Self._index_region_size(int(self._capacity))
        var needs_index_rebuild = False

        @parameter
        if indexed:
            needs_index_rebuild = _index_byte_width(
                int(self._capacity)
            ) != _index_byte_width(new_capacity)

        var current_p = self._pointer
        var new_pointer = UnsafePointer[UInt8].alloc(
            new_capacity + new_index_size
        )
        if not needs_index_rebuild:
            memcpy(new_pointer, current_p, current_index_size)
        memcpy(
            new_pointer.offset(new_index_size),
            current_p.offset(current_index_size),
            int(self._capacity),
        )
        self._pointer = new_pointer
        self._capacity = new_capacity
        if needs_index_rebuild:
            # TODO: widen the existing entries instead of rebuilding.
            self._build_index()
        current_p.free()
fn _index_size(capacity: Int) -> Int:
    """Compute the size in bytes of the codepoint index for `capacity`.

    One extra entry is reserved in front of the offset entries; it stores
    the total codepoint count.
    """
    var entry_width = _index_byte_width(capacity)
    var entry_count = _index_count(capacity) + 1
    return entry_count * entry_width
fn _index_count(capacity: Int) -> Int:
    """Upper bound of offset entries needed: `capacity / 32`, rounded up."""
    var whole_blocks = capacity >> 5
    var has_partial_block = (capacity & 31) != 0
    if has_partial_block:
        return whole_blocks + 1
    return whole_blocks
fn _index_byte_width(capacity: Int) -> Int:
    """Return the byte width of an index entry able to hold values up to `capacity`.

    The result is always 1, 2, 4 or 8, so that it matches one of the
    unsigned DTypes (uint8/uint16/uint32/uint64) used to read and write
    index entries. Returning the raw minimal byte count (3, 5, 6, 7) would
    make the computed index region smaller than the uint64 entries written
    into it.
    """
    var bits = bit_width(capacity)
    if bits <= 8:
        return 1
    if bits <= 16:
        return 2
    if bits <= 32:
        return 4
    return 8
@always_inline
fn _num_bytes(value: UInt8) -> Int:
    """Number of bytes of the UTF-8 sequence introduced by `value`.

    This is the count of leading one bits of `value`. A byte with no leading
    one bits (ASCII) yields 1. A continuation byte (one leading one bit) also
    yields 1.
    """
    var leading_ones = int(count_leading_zeros(~value))
    if leading_ones == 0:
        return 1
    return leading_ones
fn _roundup_to_32(value: Int) -> Int:
    """Round `value` up to the next multiple of 32 (unchanged if already one)."""
    var remainder = value & 31
    if remainder == 0:
        return value
    return value - remainder + 32
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from crazy_string import * | |
from testing import * | |
def test_inline_string():
    var source = "hello world this is Me"
    var byte_count = source.byte_length()
    var s = CrazyString(source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 22)
    assert_equal(len(s), 22)
    assert_equal(s, source)
    # Positive positions.
    assert_equal(s[0], "h")
    assert_equal(s[1], "e")
    assert_equal(s[20], "M")
    assert_equal(s[21], "e")
    # Negative positions count from the end.
    assert_equal(s[-1], "e")
    assert_equal(s[-22], "h")
def test_inline_string_no_index():
    var source = "hello world this is Me"
    var byte_count = source.byte_length()
    var s = CrazyString[indexed=False](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 22)
    assert_equal(len(s), 22)
    assert_equal(s, source)
    # Positive positions.
    assert_equal(s[0], "h")
    assert_equal(s[1], "e")
    assert_equal(s[20], "M")
    assert_equal(s[21], "e")
    # Negative positions count from the end.
    assert_equal(s[-1], "e")
    assert_equal(s[-22], "h")
def test_short_inline_string():
    var source = "hello 🔥!"
    var byte_count = source.byte_length()
    var s = CrazyString(source.unsafe_ptr(), byte_count)
    # The emoji takes 4 bytes but counts as a single codepoint.
    assert_equal(s.byte_length(), 11)
    assert_equal(len(s), 8)
    assert_equal(s, source)
    assert_equal(s[0], "h")
    assert_equal(s[1], "e")
    assert_equal(s[6], "🔥")
    assert_equal(s[7], "!")
    assert_equal(s[-1], "!")
    assert_equal(s[-8], "h")
def test_short_inline_string_no_index():
    """Same checks as `test_short_inline_string`, without the codepoint index.

    Includes the negative-index assertions so that both variants cover the
    same behavior.
    """
    var text = "hello 🔥!"
    var cs = CrazyString[indexed=False](text.unsafe_ptr(), text.byte_length())
    assert_equal(cs.byte_length(), 11)
    assert_equal(len(cs), 8)
    assert_equal(cs, text)
    assert_equal(cs[0], "h")
    assert_equal(cs[1], "e")
    assert_equal(cs[6], "🔥")
    assert_equal(cs[7], "!")
    assert_equal(cs[-1], "!")
    assert_equal(cs[-8], "h")
def test_not_inline_string():
    # One byte more than fits inline in the default 64-bit layout.
    var source = "hello world this is Max"
    var byte_count = source.byte_length()
    var s = CrazyString(source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 23)
    assert_equal(len(s), 23)
    assert_equal(s, source)
    assert_equal(s[22], "x")
    assert_equal(s[-1], "x")
    assert_equal(s[-23], "h")
def test_not_inline_string_no_index():
    var source = "hello world this is Max"
    var byte_count = source.byte_length()
    var s = CrazyString[indexed=False](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 23)
    assert_equal(len(s), 23)
    assert_equal(s, source)
    assert_equal(s[22], "x")
    assert_equal(s[-1], "x")
    assert_equal(s[-23], "h")
def test_not_inline_string_becuase_of_dt():
    # A narrower dt shrinks the struct, so 22 bytes no longer fit inline.
    var source = "hello world this is Me"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 22)
    assert_equal(len(s), 22)
    assert_equal(s, source)
    assert_equal(s[21], "e")
    assert_equal(s[-1], "e")
    assert_equal(s[-22], "h")
def test_not_inline_string_becuase_of_dt_no_index():
    var source = "hello world this is Me"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 22)
    assert_equal(len(s), 22)
    assert_equal(s, source)
    assert_equal(s[21], "e")
    assert_equal(s[-1], "e")
    assert_equal(s[-22], "h")
def test_ascii_string_at_32_byte_boundary():
    var source = "hello world this is Me and Maxim"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 32)
    assert_equal(len(s), 32)
    assert_equal(s, source)
def test_ascii_string_at_32_byte_boundary_no_index():
    var source = "hello world this is Me and Maxim"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 32)
    assert_equal(len(s), 32)
    assert_equal(s, source)
def test_ascii_string_over_32_byte_boundary():
    var source = "hello world this is Me and Maxim!"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 33)
    assert_equal(len(s), 33)
    assert_equal(s, source)
def test_ascii_string_over_32_byte_boundary_no_index():
    var source = "hello world this is Me and Maxim!"
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 33)
    assert_equal(len(s), 33)
    assert_equal(s, source)
def test_non_ascii_string_at_32_byte_boundary_below_32_chars():
    var source = "hello world this is Me and 🔥."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    # 32 bytes, but the 4-byte emoji makes it only 29 codepoints.
    assert_equal(s.byte_length(), 32)
    assert_equal(len(s), 29)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    assert_equal(s[27], "🔥")
    assert_equal(s[28], ".")
def test_non_ascii_string_at_32_byte_boundary_below_32_chars_no_index():
    var source = "hello world this is Me and 🔥."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 32)
    assert_equal(len(s), 29)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    assert_equal(s[27], "🔥")
    assert_equal(s[28], ".")
def test_non_ascii_string_over_32_byte_boundary_at_32_chars():
    var source = "hello world this is Me and 🔥🔥🔥.."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 41)
    assert_equal(len(s), 32)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    # Three consecutive multi-byte codepoints.
    assert_equal(s[27], "🔥")
    assert_equal(s[28], "🔥")
    assert_equal(s[29], "🔥")
    # The trailing ASCII characters, the last one is codepoint 31.
    assert_equal(s[30], ".")
    assert_equal(s[31], ".")
def test_non_ascii_string_over_32_byte_boundary_at_32_chars_no_index():
    var source = "hello world this is Me and 🔥🔥🔥.."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 41)
    assert_equal(len(s), 32)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    assert_equal(s[27], "🔥")
    assert_equal(s[28], "🔥")
    assert_equal(s[29], "🔥")
    assert_equal(s[30], ".")
    assert_equal(s[31], ".")
def test_non_ascii_string_over_32_byte_boundary_over_32_chars():
    var source = "hello world this is Me and 🔥🔥🔥..."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 42)
    assert_equal(len(s), 33)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    assert_equal(s[27], "🔥")
    assert_equal(s[28], "🔥")
    assert_equal(s[29], "🔥")
    assert_equal(s[30], ".")
    assert_equal(s[31], ".")
    # Codepoint 32 is the first one past the first index block.
    assert_equal(s[32], ".")
def test_non_ascii_string_over_32_byte_boundary_over_32_chars_no_index():
    var source = "hello world this is Me and 🔥🔥🔥..."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 42)
    assert_equal(len(s), 33)
    assert_equal(s, source)
    assert_equal(s[2], "l")
    assert_equal(s[27], "🔥")
    assert_equal(s[28], "🔥")
    assert_equal(s[29], "🔥")
    assert_equal(s[30], ".")
    assert_equal(s[31], ".")
    assert_equal(s[32], ".")
def long_mixed_string():
    var source = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32](source.unsafe_ptr(), byte_count)
    assert_equal(s.byte_length(), 354)
    assert_equal(len(s), 304)
    assert_equal(s, source)
    # Rebuild the text one codepoint at a time via indexing.
    var rebuilt = String("")
    var codepoints = len(s)
    for position in range(codepoints):
        rebuilt += str(s[position])
    assert_equal(source, rebuilt)
def long_mixed_string_no_index():
    var source = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var byte_count = source.byte_length()
    var s = CrazyString[DType.uint32, indexed=False](
        source.unsafe_ptr(), byte_count
    )
    assert_equal(s.byte_length(), 354)
    assert_equal(len(s), 304)
    assert_equal(s, source)
    # Rebuild the text one codepoint at a time via indexing.
    var rebuilt = String("")
    var codepoints = len(s)
    for position in range(codepoints):
        rebuilt += str(s[position])
    assert_equal(source, rebuilt)
def test_from_literal():
    # Implicit conversion from a string literal, then reassignment.
    var s: CrazyString = "hello"
    assert_equal(s, "hello")
    s = "hello 🔥"
    assert_equal(s, "hello 🔥")
def test_from_reference():
    # Implicit conversion from a StringRef, then reassignment.
    var s: CrazyString = StringRef("hello")
    assert_equal(s, "hello")
    s = StringRef("hello 🔥")
    assert_equal(s, "hello 🔥")
def test_iterator():
    var lorem: CrazyString = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var joined: String = ""
    # Concatenating every iterated element must reproduce the text.
    for piece in lorem:
        joined += piece
    assert_equal(joined, lorem.as_string_slice())

    # Same check for a short inline string.
    lorem = "hello 🔥"
    joined = ""
    for piece in lorem:
        joined += piece
    assert_equal(joined, lorem.as_string_slice())
def test_get_slice():
    var lorem: CrazyString = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."

    # Contiguous forward slice.
    var contiguous = lorem[12:64]
    assert_equal(len(contiguous), 52)
    assert_equal(
        "dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêa",
        contiguous.as_string_slice(),
    )

    # Forward slices with a stride.
    var stride_2 = lorem[12:64:2]
    assert_equal("dlrstae,i du oetàîâ ï,e eê", stride_2.as_string_slice())
    var stride_3 = lorem[12:64:3]
    assert_equal("dos e q ltnâv,ùea", stride_3.as_string_slice())
    var stride_7 = lorem[12:64:7]
    assert_equal("di,qem,r", stride_7.as_string_slice())

    # Backward slices.
    var reversed_1 = lorem[120:64:-1]
    assert_equal(
        "ès oN .nï rëp éàsœîpœc tnùcid ,sah sùbisnerôf etatupluv ",
        reversed_1.as_string_slice(),
    )
    var reversed_2 = lorem[120:64:-2]
    assert_equal("è N.ïrpésîœ ncd,a ùinrfeaulv", reversed_2.as_string_slice())
    var reversed_3 = lorem[120:64:-3]
    assert_equal("èo. pàîcni,hùsr apv", reversed_3.as_string_slice())
    var reversed_5 = lorem[120:64:-5]
    assert_equal("è ràœù,sn u ", reversed_5.as_string_slice())
def test_iadd_inline_strings():
    # Both parts together still fit into the inline buffer.
    var greeting: CrazyString = "hello"
    greeting += " Maxim 🔥"
    assert_equal(greeting, "hello Maxim 🔥")
def test_iadd_non_inline_strings_but_keep_in_capacity():
    # A heap string whose existing capacity is large enough for the append.
    var greeting: CrazyString = "hello my good old friend"
    greeting += " Maxim"
    assert_equal(greeting, "hello my good old friend Maxim")
def test_iadd_non_inline_strings_but_keep_over_capacity():
    # A heap string that has to grow its buffer to hold the appended text.
    var greeting: CrazyString = "hello my good old friend"
    greeting += " Maxim 🔥. I think we need much more text now."
    var expected = "hello my good old friend Maxim 🔥. I think we need much more text now."
    assert_equal(greeting, expected)
def test_strided_store():
    """Check that a strided store scatters bytes into every other byte.

    The four bytes of `a` are written with stride 2 into the byte view of
    `b`'s buffer. The even bytes must equal `a` and the odd bytes must stay
    zero. The check works on raw bytes, so it is independent of endianness.
    """
    var a = List[UInt8](1, 2, 3, 4)
    var b = List[UInt16](0, 0, 0, 0)
    b.unsafe_ptr().bitcast[DType.uint8]().strided_store[width=4](
        a.unsafe_ptr().load[width=4](), 2
    )
    var b_bytes = b.unsafe_ptr().bitcast[DType.uint8]()
    for i in range(4):
        assert_equal(int(b_bytes[2 * i]), int(a[i]))
        assert_equal(int(b_bytes[2 * i + 1]), 0)
def main():
    # Standalone check of SIMD strided stores.
    test_strided_store()

    # Strings short enough to be stored inline.
    test_inline_string()
    test_inline_string_no_index()
    test_short_inline_string()
    test_short_inline_string_no_index()

    # Strings that end up on the heap (by length or by a narrower dt).
    test_not_inline_string()
    test_not_inline_string_no_index()
    test_not_inline_string_becuase_of_dt()
    test_not_inline_string_becuase_of_dt_no_index()

    # ASCII strings around the 32 byte index block boundary.
    test_ascii_string_at_32_byte_boundary()
    test_ascii_string_at_32_byte_boundary_no_index()
    test_ascii_string_over_32_byte_boundary()
    test_ascii_string_over_32_byte_boundary_no_index()

    # Multi-byte strings around the 32 byte / 32 codepoint boundaries.
    test_non_ascii_string_at_32_byte_boundary_below_32_chars()
    test_non_ascii_string_at_32_byte_boundary_below_32_chars_no_index()
    test_non_ascii_string_over_32_byte_boundary_at_32_chars()
    test_non_ascii_string_over_32_byte_boundary_at_32_chars_no_index()
    test_non_ascii_string_over_32_byte_boundary_over_32_chars()
    test_non_ascii_string_over_32_byte_boundary_over_32_chars_no_index()

    # Long mixed text.
    long_mixed_string()
    long_mixed_string_no_index()

    # Conversions, iteration and slicing.
    test_from_literal()
    test_from_reference()
    test_iterator()
    test_get_slice()

    # Appending.
    test_iadd_inline_strings()
    test_iadd_non_inline_strings_but_keep_in_capacity()
    test_iadd_non_inline_strings_but_keep_over_capacity()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment