Skip to content

Instantly share code, notes, and snippets.

@mzaks
Last active September 24, 2024 15:37
Show Gist options
  • Save mzaks/78f7d38f63fb234dadb1dae11f2ee3ae to your computer and use it in GitHub Desktop.
Save mzaks/78f7d38f63fb234dadb1dae11f2ee3ae to your computer and use it in GitHub Desktop.
Mojo String with small string optimisation and Unicode support (based on UTF-8)
from algorithm.functional import vectorize
from bit import bit_width, byte_swap, count_leading_zeros
from collections._index_normalization import normalize_index
from memory import memcpy, memset_zero
from sys import is_big_endian, sizeof
from utils import StringSlice, Span
from utils.string_slice import _utf8_byte_type, _StringSliceIter
struct CrazyString[
    dt: DType = DType.uint64 if sizeof[UnsafePointer[UInt8]]()
    == DType.uint64.sizeof() else DType.uint32,
    indexed: Bool = True,
](
    CollectionElementNew,
    Comparable,
    Formattable,
    KeyElement,
    Stringable,
    Sized,
):
    """
    A UTF-8 string with inline small string optimization and an optional
    Unicode codepoint index for fast indexing and length computation.

    The index stores the byte offset of every 32nd Unicode codepoint, which
    makes codepoint indexing cheap. The first value of the index holds the
    total Unicode codepoint count. The index and the string bytes live in the
    same heap region, which saves a heap allocation and a second pointer field.
    The size of an instance is between 8 and 24 bytes, depending on the
    platform and the chosen parameters.

    Possible layouts:
    64 bit arch:
        [........][........][........] -> 24 bytes
        [....][....][........] -> 16 bytes
    32 bit arch:
        [....][....][....] -> 12 bytes
        [..][..][....] -> 8 bytes

    Parameters:
        dt: The integer type of the first two fields.
            The value can be uint32/uint64 on 64 bit arch and uint16/uint32
            on 32 bit arch. Heap strings store `byte_count << 5` in a field of
            this type, so a narrow `dt` limits the maximum heap string length.
        indexed: Whether heap strings build and maintain the codepoint index.
    """

    var _flagged_bytes_count: Scalar[dt]
    """
    The value is always stored in little endian order (independent of platform endianness).
    If the actual byte count is greater than inline_capacity,
    the value is shifted five bits to the left and stored here.
    If the count is smaller than or equal to inline_capacity, it is stored in the first byte
    and the rest of this field and the other two fields hold the string bytes.
    """
    var _capacity: Scalar[dt]
    """
    The number of bytes reserved for the string.
    The heap region may be larger than _capacity when the `indexed` parameter is True,
    because a single allocation holds both the index and the string.
    `_index_size(capacity)` gives the size of the index part of that allocation.
    Users should call `capacity()` instead of reading this field: for an inline
    string the field does not hold a capacity but part of the string value.
    """
    var _pointer: UnsafePointer[UInt8]
    """
    Points to the heap region storing the index and the string value.
    Users should call `unsafe_ptr()` instead of reading this field: for an
    inline string the field is part of the string value.
    """

    alias inline_capacity = sizeof[Self]() - 2
    """
    How many bytes are available for an inline string.
    Two bytes are subtracted from the size of Self: one holds the length
    (and acts as the inline flag), the other the zero terminator.
    """

    fn __init__(inout self, literal: StringLiteral):
        """Construct an instance from a static string literal.

        Args:
            literal: Static string literal.
        """
        self.__init__(literal.unsafe_ptr(), len(literal))

    fn __init__(inout self, reference: StringRef):
        """Construct an instance from a StringRef object.

        Args:
            reference: The StringRef from which to construct this string object.
        """
        self.__init__(reference.unsafe_ptr(), len(reference))

    fn __init__(inout self, pointer: UnsafePointer[UInt8], length: Int):
        """Creates an instance from the buffer. The data in the buffer will be copied.

        Strings of up to `inline_capacity` bytes are stored inside the struct.
        Longer ones get a zeroed heap buffer whose capacity is `length + 1`
        rounded up to a multiple of 32.

        Args:
            pointer: The pointer to the buffer.
            length: The length of the buffer, without the null terminator.
        """
        constrained[dt.is_unsigned(), "dt must be an unsigned int"]()
        constrained[
            dt.sizeof() == sizeof[UnsafePointer[UInt8]]()
            or dt.sizeof() == sizeof[UnsafePointer[UInt8]]() >> 1,
            "dt must be equal or half of the pointer size",
        ]()
        if length <= self.inline_capacity:
            # Inline: zero every field first so the bytes after the string
            # act as the null terminator; the count lands in the first byte.
            self._flagged_bytes_count = length
            self._capacity = 0
            self._pointer = UnsafePointer[UInt8]()
            var str_pointer = UnsafePointer.address_of(self).bitcast[
                DType.uint8
            ]().offset(1)
            self._correct_bytes_count()
            memcpy(str_pointer, pointer, length)
        else:
            # Heap: the low five bits of the flagged count stay zero, which is
            # how `_is_inline_string` tells heap strings apart.
            self._flagged_bytes_count = length << 5
            self._capacity = _roundup_to_32(length + 1)
            var index_size = 0

            @parameter
            if indexed:
                index_size = _index_size(int(self._capacity))
            var total_buffer_size = int(self._capacity + index_size)
            self._pointer = UnsafePointer[UInt8].alloc(total_buffer_size)
            memset_zero(self._pointer, total_buffer_size)
            var str_pointer = self._pointer.offset(index_size)
            self._correct_bytes_count()
            memcpy(str_pointer, pointer, length)

            @parameter
            if indexed:
                self._build_index()

    fn __init__(inout self, *, other: Self):
        """Explicitly copy the provided value.

        Args:
            other: The value to copy.
        """
        self.__copyinit__(other)

    @always_inline
    fn __copyinit__(inout self, existing: Self, /):
        """Creates a deep copy of an existing string.

        Inline strings are copied field by field; heap strings get a new
        allocation holding a copy of the index and the string bytes.

        Args:
            existing: The string to copy.
        """
        self._flagged_bytes_count = existing._flagged_bytes_count
        self._capacity = existing._capacity
        if existing._is_inline_string():
            self._pointer = existing._pointer
        else:
            var index_size = 0

            @parameter
            if indexed:
                index_size = _index_size(int(self._capacity))
            var total_buffer_size = int(self._capacity + index_size)
            self._pointer = UnsafePointer[UInt8].alloc(total_buffer_size)
            memcpy(self._pointer, existing._pointer, total_buffer_size)

    @always_inline
    fn __moveinit__(inout self, owned existing: Self):
        """Move the value of a string.

        Args:
            existing: The string to move.
        """
        self._flagged_bytes_count = existing._flagged_bytes_count
        self._capacity = existing._capacity
        self._pointer = existing._pointer

    @always_inline
    fn _correct_bytes_count(inout self):
        """Keeps `_flagged_bytes_count` in little endian representation.

        The leading byte must hold the size in case of an inline string, so
        the value is byte swapped on big endian targets.
        """

        @parameter
        if is_big_endian():
            self._flagged_bytes_count = byte_swap(self._flagged_bytes_count)

    @always_inline
    fn _is_inline_string(self) -> Bool:
        """Return True if `_flagged_bytes_count` is 0 or its low 5 bits are non-zero.
        """
        var value = self._flagged_bytes_count

        @parameter
        if is_big_endian():
            value = byte_swap(value)
        return value == 0 or (value & 31) > 0

    fn _build_index(inout self):
        """Computes the index based on the capacity.

        The capacity dictates how wide the index entries are
        (see `_index_byte_width`): widths 1, 2 and 4 use uint8, uint16 and
        uint32 entries, any other width uses uint64 entries.
        """
        var byte_width = _index_byte_width(int(self._capacity))
        if byte_width == 1:
            self._compute_index[DType.uint8]()
        elif byte_width == 2:
            self._compute_index[DType.uint16]()
        elif byte_width == 4:
            self._compute_index[DType.uint32]()
        else:
            self._compute_index[DType.uint64]()

    fn _compute_index[entry_dt: DType](self):
        """Fills the index with the codepoint count and codepoint offsets.

        No-op if `byte_length` is smaller than 32. Otherwise entry 0 receives
        the total codepoint count and entry k (k >= 1) the byte offset of the
        codepoint at index `32 * k - 1`.

        Parameters:
            entry_dt: The integer type of an index entry.
        """
        var length = self.byte_length()
        if length < 32:
            return
        var char_count = 1
        var text_ptr = self.unsafe_ptr()
        var p = self._pointer.bitcast[entry_dt]()
        var index = 0
        var p_index = 1
        while index < length:
            var num_bytes = _num_bytes(text_ptr[0])
            # Branch-free update: only every 32nd codepoint writes its offset.
            var should_set = (char_count & 31) == 0
            p[p_index] = should_set * index + int(not should_set) * p[p_index]
            p_index += should_set
            char_count += 1
            text_ptr += num_bytes
            index += num_bytes
        p[0] = char_count - 1

    fn _lookup_index(self, slot: Int) -> Int:
        """Return the value stored in the index.

        Slot 0 holds the total count of Unicode codepoints. Slot k > 0 holds
        the byte offset of the codepoint at index `32 * k - 1`.
        """
        var byte_width = _index_byte_width(int(self._capacity))
        if byte_width == 1:
            return int(self._pointer[slot])
        elif byte_width == 2:
            return int(self._pointer.bitcast[DType.uint16]()[slot])
        elif byte_width == 4:
            return int(self._pointer.bitcast[DType.uint32]()[slot])
        else:
            return int(self._pointer.bitcast[DType.uint64]()[slot])

    fn _pointer_at(self, idx: Int) -> UnsafePointer[UInt8]:
        """Return a pointer to the first byte of the codepoint at index `idx`.

        No bounds check is performed. With `indexed`, the walk starts from the
        nearest indexed offset instead of the beginning of the string.
        """
        var index = idx
        var p = self.unsafe_ptr()

        @parameter
        if indexed:
            if index >= 31:
                var slot = (index + 1) >> 5
                index = (index + 1) & 31
                var offset = self._lookup_index(slot)
                p = p.offset(offset)
        for _ in range(index):
            var num_bytes = _num_bytes(p[0])
            p += num_bytes
        return p

    fn unsafe_ptr(self) -> UnsafePointer[UInt8]:
        """Retrieves a pointer to the underlying string memory region.

        Returns:
            The pointer to the underlying string memory region.
        """
        if self._is_inline_string():
            return UnsafePointer.address_of(self).bitcast[UInt8]().offset(1)
        else:

            @parameter
            if indexed:
                return self._pointer.offset(_index_size(int(self._capacity)))
            else:
                return self._pointer

    fn byte_length(self) -> Int:
        """Get the string length in bytes.

        Returns:
            The length of this string in bytes, excluding null terminator.

        Notes:
            This does not include the trailing null terminator in the count.
        """
        var value = self._flagged_bytes_count

        @parameter
        if is_big_endian():
            value = byte_swap(value)
        if self._is_inline_string():
            return int(value & 31)
        return int(value >> 5)

    fn __del__(owned self):
        """Frees the heap buffer of a non-inline string."""
        if not self._is_inline_string():
            self._pointer.free()

    fn __str__(self) -> String:
        """Gets the string.

        This method ensures that you can pass a `CrazyString` to a method that
        takes a `Stringable` value.

        Returns:
            An instance of a `String`.
        """
        var l = self.byte_length() + 1
        var p = UnsafePointer[UInt8].alloc(l)
        memcpy(p, self.unsafe_ptr(), l)
        return String(p, l)

    fn __len__(self) -> Int:
        """Gets the string length, in Unicode codepoints.

        Returns:
            The number of Unicode codepoints in the string.
        """
        var p = self.unsafe_ptr()
        var bytes = self.byte_length()
        var result = 0

        @parameter
        fn count[simd_width: Int](offset: Int):
            # Every byte that is not a continuation byte (0b10xxxxxx)
            # starts a codepoint.
            result += int(
                ((p.load[width=simd_width](offset) >> 6) != 0b10)
                .cast[DType.uint8]()
                .reduce_add()
            )

        @parameter
        if not indexed:
            vectorize[count, 16](bytes)
            return result
        else:
            if self._is_inline_string():
                # Count over the whole struct: the length byte and the zero
                # bytes after the string each add one, so subtract them.
                p = UnsafePointer.address_of(self).bitcast[DType.uint8]()
                vectorize[count, 16](sizeof[Self]())
                return result - (sizeof[Self]() - bytes)
            elif bytes >= 32:
                return self._lookup_index(0)
            else:
                vectorize[count, 16](bytes)
                return result

    fn __getitem__(self, idx: Int) -> CrazyString:
        """Gets the character at the specified position.

        Args:
            idx: The index value.

        Returns:
            A new string containing the character at the specified position.
        """
        var index = normalize_index["CrazyString"](idx, self)
        var p = self._pointer_at(index)
        var num_bytes = _num_bytes(p[0])
        return CrazyString(p, int(num_bytes))

    fn __getitem__(self, span: Slice) -> CrazyString:
        """Gets the sequence of characters at the specified positions.

        Slice bounds are Unicode codepoint indices, like integer indexing.
        Empty ranges produce an empty string.

        Args:
            span: A slice that specifies positions of the new substring.

        Returns:
            A new string containing the string at the specified positions.
        """
        var start: Int
        var end: Int
        var step: Int
        # Clamp against the codepoint count; clamping against the byte length
        # would let `_pointer_at` walk past the end of the string.
        start, end, step = span.indices(len(self))
        if step > 0:
            if end <= start:
                return CrazyString(self.unsafe_ptr(), 0)
            var p = self._pointer_at(start)
            var p_end = self._pointer_at(end)
            var span_bytes = int(p_end) - int(p)
            if step == 1:
                return CrazyString(p, span_bytes)
            var tmp = UnsafePointer[UInt8].alloc(span_bytes)
            var bytes = 0
            for i in range(end - start):
                var num_bytes = _num_bytes(p[0])
                if i % step == 0:
                    memcpy(tmp.offset(bytes), p, num_bytes)
                    bytes += num_bytes
                p += num_bytes
            var result = CrazyString(tmp, bytes)
            tmp.free()
            return result^

        # Negative step: visit codepoints start, start - 1, ..., end + 1.
        if start <= end:
            return CrazyString(self.unsafe_ptr(), 0)
        var lowest = self._pointer_at(end + 1)
        var p = self._pointer_at(start)
        # The visited bytes end after the last byte of the codepoint at `start`.
        var span_bytes = int(p) + _num_bytes(p[0]) - int(lowest)
        var tmp = UnsafePointer[UInt8].alloc(span_bytes)
        var bytes = 0
        var count = start - end
        var stride = -step
        for i in range(count):
            var num_bytes = _num_bytes(p[0])
            if i % stride == 0:
                memcpy(tmp.offset(bytes), p, num_bytes)
                bytes += num_bytes
            if i + 1 < count:
                # Step back to the lead byte of the previous codepoint; this
                # never moves below `lowest`, so no byte before it is read.
                p -= 1
                while (p[0] >> 6) == 0b10:
                    p -= 1
        var result = CrazyString(tmp, bytes)
        tmp.free()
        return result^

    fn __iter__(ref [_]self) -> _StringSliceIter[__lifetime_of(self)]:
        """Iterate over elements of the string, returning immutable references.

        Returns:
            An iterator of references to the string elements.
        """
        return _StringSliceIter[__lifetime_of(self)](
            unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
        )

    @always_inline
    fn __eq__(self, other: Self) -> Bool:
        """Compares two Strings if they have the same values.

        Args:
            other: The rhs of the operation.

        Returns:
            True if the Strings are equal and False otherwise.
        """
        return not (self != other)

    @always_inline
    fn __ne__(self, other: Self) -> Bool:
        """Compares two Strings if they do not have the same values.

        Args:
            other: The rhs of the operation.

        Returns:
            True if the Strings are not equal and False otherwise.
        """
        return self._strref_dangerous() != other._strref_dangerous()

    @always_inline
    fn __lt__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using LT comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True if this String is strictly less than the RHS String and False
            otherwise.
        """
        return self._strref_dangerous() < rhs._strref_dangerous()

    @always_inline
    fn __le__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using LE comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is less than or equal to the RHS String.
        """
        return not (rhs < self)

    @always_inline
    fn __gt__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using GT comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is strictly greater than the RHS String.
        """
        return rhs < self

    @always_inline
    fn __ge__(self, rhs: Self) -> Bool:
        """Compare this String to the RHS using GE comparison.

        Args:
            rhs: The other String to compare against.

        Returns:
            True iff this String is greater than or equal to the RHS String.
        """
        return not (self < rhs)

    fn __hash__(self) -> UInt:
        """Hash the underlying buffer using builtin hash.

        Returns:
            A 64-bit hash value. This value is _not_ suitable for cryptographic
            uses. Its intended usage is for data structures. See the `hash`
            builtin documentation for more details.
        """
        return hash(self._strref_dangerous())

    @always_inline
    fn __bool__(self) -> Bool:
        """Checks if the string is not empty.

        Returns:
            True if the string length is greater than zero, and False otherwise.
        """
        return self.byte_length() > 0

    fn __iadd__(inout self, other: Self):
        """Appends another string to this string.

        An inline string stays inline while the result fits into
        `inline_capacity`; otherwise the data is (re)allocated on the heap.

        Args:
            other: The string to append.
        """
        if not other:
            return
        if not self:
            self = other
            return
        var self_len = self.byte_length()
        var other_len = other.byte_length()
        var total_len = self_len + other_len
        if self._is_inline_string() and total_len <= self.inline_capacity:
            # Copy the data alongside the terminator.
            memcpy(
                dest=self.unsafe_ptr() + self_len,
                src=other.unsafe_ptr(),
                count=other_len + 1,
            )
            UnsafePointer.address_of(self).bitcast[UInt8]()[0] = total_len
        else:
            # Reserve room for the null terminator as well (as `__init__` does);
            # the memcpy below writes `total_len + 1` bytes.
            self.reserve(_roundup_to_32(total_len + 1))
            memcpy(
                dest=self.unsafe_ptr() + self_len,
                src=other.unsafe_ptr(),
                count=other_len + 1,
            )
            self._flagged_bytes_count = total_len << 5
            self._correct_bytes_count()

            @parameter
            if indexed:
                # TODO: update the index incrementally instead of rebuilding it.
                self._build_index()

    fn _strref_dangerous(self) -> StringRef:
        """
        Returns an inner pointer to the string as a StringRef.
        This functionality is extremely dangerous because Mojo eagerly releases
        strings. Using this requires the use of the _strref_keepalive() method
        to keep the underlying string alive long enough.
        """
        return StringRef(self.unsafe_ptr(), self.byte_length())

    fn _strref_keepalive(self):
        """
        A noop that keeps `self` alive through the call. This
        can be carefully used with `_strref_dangerous()` to wield inner pointers
        without the string getting deallocated early.
        """
        pass

    @always_inline
    fn as_bytes_slice(ref [_]self) -> Span[UInt8, __lifetime_of(self)]:
        """Returns a contiguous slice of the bytes owned by this string.

        Returns:
            A contiguous slice pointing to the bytes owned by this string.

        Notes:
            This does not include the trailing null terminator.
        """
        # Does NOT include the NUL terminator.
        return Span[UInt8, __lifetime_of(self)](
            unsafe_ptr=self.unsafe_ptr(), len=self.byte_length()
        )

    @always_inline
    fn as_string_slice(ref [_]self) -> StringSlice[__lifetime_of(self)]:
        """Returns a string slice of the data owned by this string.

        Returns:
            A string slice pointing to the data owned by this string.
        """
        return StringSlice(unsafe_from_utf8=self.as_bytes_slice())

    fn format_to(self, inout writer: Formatter):
        """
        Formats this string to the provided formatter.

        Args:
            writer: The formatter to write to.
        """
        writer.write_str(self.as_string_slice())

    fn capacity(self) -> Int:
        """Capacity of the string.

        Returns:
            How many bytes the string can hold.
        """
        if self._is_inline_string():
            return self.inline_capacity
        return int(self._capacity)

    fn _move_inline_to_heap(inout self, new_capacity: Int):
        """Moves a non-empty inline string into a new zeroed heap buffer.

        Empty strings are left inline: a heap string whose byte count is 0
        would be indistinguishable from an inline one.
        """
        var length = self.byte_length()
        if length == 0:
            return
        var index_size = 0

        @parameter
        if indexed:
            index_size = _index_size(new_capacity)
        var total_buffer_size = new_capacity + index_size
        var buffer = UnsafePointer[UInt8].alloc(total_buffer_size)
        memset_zero(buffer, total_buffer_size)
        # Copy before overwriting the fields: the inline bytes live inside `self`.
        memcpy(buffer.offset(index_size), self.unsafe_ptr(), length)
        self._flagged_bytes_count = length << 5
        self._correct_bytes_count()
        self._capacity = new_capacity
        self._pointer = buffer

        @parameter
        if indexed:
            self._build_index()

    fn reserve(inout self, new_capacity: Int):
        """Reserves the requested capacity.

        If the current capacity is greater or equal, this is a no-op.
        Otherwise a new heap buffer is allocated and the data is moved into it.
        A non-empty inline string is moved to the heap; an empty inline string
        is left untouched.

        Args:
            new_capacity: The new capacity.
        """
        if self.capacity() >= new_capacity:
            return
        if self._is_inline_string():
            # The fields of an inline string hold string bytes, not a
            # capacity and pointer, so it needs a dedicated path.
            self._move_inline_to_heap(new_capacity)
            return
        var current_index_size = 0
        var new_index_size = 0
        var needs_index_widening = False

        @parameter
        if indexed:
            current_index_size = _index_size(int(self._capacity))
            new_index_size = _index_size(new_capacity)
            needs_index_widening = _index_byte_width(
                int(self._capacity)
            ) != _index_byte_width(new_capacity)
        var new_total_size = new_capacity + new_index_size
        var current_p = self._pointer
        self._pointer = UnsafePointer[UInt8].alloc(new_total_size)
        # Zero the new region so unused index slots and the string tail are defined.
        memset_zero(self._pointer, new_total_size)
        if not needs_index_widening:
            # Same entry width: existing entries remain valid at the same slots.
            memcpy(self._pointer, current_p, current_index_size)
        memcpy(
            self._pointer.offset(new_index_size),
            current_p.offset(current_index_size),
            int(self._capacity),
        )
        self._capacity = new_capacity
        if needs_index_widening:
            # TODO: do actual widening instead of rebuild
            self._build_index()
        current_p.free()
fn _index_size(capacity: Int) -> Int:
    """Compute the index size in bytes for a string buffer of `capacity` bytes.

    One entry more than `_index_count` is reserved: the first entry stores
    the total codepoint count.
    """
    var entry_count = _index_count(capacity) + 1
    var entry_width = _index_byte_width(capacity)
    return entry_width * entry_count
fn _index_count(capacity: Int) -> Int:
    """Compute the upper bound of index entries needed: `capacity` divided by 32, rounded up.
    """
    var full_blocks = capacity >> 5
    var has_partial_block = (capacity & 31) != 0
    return full_blocks + 1 if has_partial_block else full_blocks
fn _index_byte_width(capacity: Int) -> Int:
    """Compute the byte width of one index entry for a buffer of `capacity` bytes.

    The result is always 1, 2, 4 or 8 so that it matches the uint8, uint16,
    uint32 or uint64 entry type the index is read and written with. The
    minimal byte count needed for `capacity` is rounded up to the next of
    those widths; for example, a capacity needing 17 bits gets 4-byte entries,
    not 3-byte ones, which would under-size the index region.
    """
    var bits_width = bit_width(capacity)
    if bits_width <= 8:
        return 1
    if bits_width <= 16:
        return 2
    if bits_width <= 32:
        return 4
    return 8
@always_inline
fn _num_bytes(value: UInt8) -> Int:
    """Return the byte length of the UTF-8 sequence whose lead byte is `value`.

    A byte with the high bit clear (ASCII) yields 1. Otherwise the result is
    the number of leading one bits of `value`; a continuation byte
    (0b10xxxxxx) therefore also yields 1.
    """
    var leading_ones = int(count_leading_zeros(~value))
    return leading_ones if leading_ones > 0 else 1
fn _roundup_to_32(value: Int) -> Int:
    """Round `value` up to a multiple of 32; multiples of 32 are returned unchanged.
    """
    var remainder = value & 31
    var aligned_down = value - remainder
    return aligned_down if remainder == 0 else aligned_down + 32
from crazy_string import *
from testing import *
def test_inline_string():
    """22 ASCII bytes with default parameters: byte length, length and indexing."""
    var plain = "hello world this is Me"
    var crazy = CrazyString(plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 22)
    assert_equal(len(crazy), 22)
    assert_equal(crazy, plain)
    assert_equal(crazy[0], "h")
    assert_equal(crazy[1], "e")
    assert_equal(crazy[20], "M")
    assert_equal(crazy[21], "e")
    assert_equal(crazy[-1], "e")
    assert_equal(crazy[-22], "h")
def test_inline_string_no_index():
    """Same checks as `test_inline_string`, with `indexed=False`."""
    var plain = "hello world this is Me"
    var crazy = CrazyString[indexed=False](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 22)
    assert_equal(len(crazy), 22)
    assert_equal(crazy, plain)
    assert_equal(crazy[0], "h")
    assert_equal(crazy[1], "e")
    assert_equal(crazy[20], "M")
    assert_equal(crazy[21], "e")
    assert_equal(crazy[-1], "e")
    assert_equal(crazy[-22], "h")
def test_short_inline_string():
    """Short string with a 4-byte codepoint: 11 bytes, 8 codepoints."""
    var plain = "hello 🔥!"
    var crazy = CrazyString(plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 11)
    assert_equal(len(crazy), 8)
    assert_equal(crazy, plain)
    assert_equal(crazy[0], "h")
    assert_equal(crazy[1], "e")
    assert_equal(crazy[6], "🔥")
    assert_equal(crazy[7], "!")
    assert_equal(crazy[-1], "!")
    assert_equal(crazy[-8], "h")
def test_short_inline_string_no_index():
    """Same checks as `test_short_inline_string`, with `indexed=False`.

    Includes the negative-index assertions of the indexed variant, which were
    missing here.
    """
    var text = "hello 🔥!"
    var cs = CrazyString[indexed=False](text.unsafe_ptr(), text.byte_length())
    assert_equal(cs.byte_length(), 11)
    assert_equal(len(cs), 8)
    assert_equal(cs, text)
    assert_equal(cs[0], "h")
    assert_equal(cs[1], "e")
    assert_equal(cs[6], "🔥")
    assert_equal(cs[7], "!")
    assert_equal(cs[-1], "!")
    assert_equal(cs[-8], "h")
def test_not_inline_string():
    """23 ASCII bytes: one more than the 22-byte inline capacity of the 24-byte layout."""
    var plain = "hello world this is Max"
    var crazy = CrazyString(plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 23)
    assert_equal(len(crazy), 23)
    assert_equal(crazy, plain)
    assert_equal(crazy[22], "x")
    assert_equal(crazy[-1], "x")
    assert_equal(crazy[-23], "h")
def test_not_inline_string_no_index():
    """Same checks as `test_not_inline_string`, with `indexed=False`."""
    var plain = "hello world this is Max"
    var crazy = CrazyString[indexed=False](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 23)
    assert_equal(len(crazy), 23)
    assert_equal(crazy, plain)
    assert_equal(crazy[22], "x")
    assert_equal(crazy[-1], "x")
    assert_equal(crazy[-23], "h")
def test_not_inline_string_becuase_of_dt():
    """22 bytes with `dt=uint32`: the smaller struct makes the string go to the heap."""
    var plain = "hello world this is Me"
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 22)
    assert_equal(len(crazy), 22)
    assert_equal(crazy, plain)
    assert_equal(crazy[21], "e")
    assert_equal(crazy[-1], "e")
    assert_equal(crazy[-22], "h")
def test_not_inline_string_becuase_of_dt_no_index():
    """Same checks as `test_not_inline_string_becuase_of_dt`, with `indexed=False`."""
    var plain = "hello world this is Me"
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 22)
    assert_equal(len(crazy), 22)
    assert_equal(crazy, plain)
    assert_equal(crazy[21], "e")
    assert_equal(crazy[-1], "e")
    assert_equal(crazy[-22], "h")
def test_ascii_string_at_32_byte_boundary():
    """Exactly 32 ASCII bytes, the smallest length that gets an index."""
    var plain = "hello world this is Me and Maxim"
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 32)
    assert_equal(len(crazy), 32)
    assert_equal(crazy, plain)
def test_ascii_string_at_32_byte_boundary_no_index():
    """Same checks as `test_ascii_string_at_32_byte_boundary`, with `indexed=False`."""
    var plain = "hello world this is Me and Maxim"
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 32)
    assert_equal(len(crazy), 32)
    assert_equal(crazy, plain)
def test_ascii_string_over_32_byte_boundary():
    """33 ASCII bytes, one past the 32-byte boundary."""
    var plain = "hello world this is Me and Maxim!"
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 33)
    assert_equal(len(crazy), 33)
    assert_equal(crazy, plain)
def test_ascii_string_over_32_byte_boundary_no_index():
    """Same checks as `test_ascii_string_over_32_byte_boundary`, with `indexed=False`."""
    var plain = "hello world this is Me and Maxim!"
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 33)
    assert_equal(len(crazy), 33)
    assert_equal(crazy, plain)
def test_non_ascii_string_at_32_byte_boundary_below_32_chars():
    """32 bytes but only 29 codepoints because of one 4-byte emoji."""
    var plain = "hello world this is Me and 🔥."
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 32)
    assert_equal(len(crazy), 29)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], ".")
def test_non_ascii_string_at_32_byte_boundary_below_32_chars_no_index():
    """Same checks as the indexed variant, with `indexed=False`."""
    var plain = "hello world this is Me and 🔥."
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 32)
    assert_equal(len(crazy), 29)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], ".")
def test_non_ascii_string_over_32_byte_boundary_at_32_chars():
    """41 bytes holding exactly 32 codepoints (three 4-byte emoji)."""
    var plain = "hello world this is Me and 🔥🔥🔥.."
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 41)
    assert_equal(len(crazy), 32)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], "🔥")
    assert_equal(crazy[29], "🔥")
    assert_equal(crazy[30], ".")
    assert_equal(crazy[31], ".")
def test_non_ascii_string_over_32_byte_boundary_at_32_chars_no_index():
    """Same checks as the indexed variant, with `indexed=False`."""
    var plain = "hello world this is Me and 🔥🔥🔥.."
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 41)
    assert_equal(len(crazy), 32)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], "🔥")
    assert_equal(crazy[29], "🔥")
    assert_equal(crazy[30], ".")
    assert_equal(crazy[31], ".")
def test_non_ascii_string_over_32_byte_boundary_over_32_chars():
    """42 bytes holding 33 codepoints, so the index has one offset slot in use."""
    var plain = "hello world this is Me and 🔥🔥🔥..."
    var crazy = CrazyString[DType.uint32](plain.unsafe_ptr(), plain.byte_length())
    assert_equal(crazy.byte_length(), 42)
    assert_equal(len(crazy), 33)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], "🔥")
    assert_equal(crazy[29], "🔥")
    assert_equal(crazy[30], ".")
    assert_equal(crazy[31], ".")
    assert_equal(crazy[32], ".")
def test_non_ascii_string_over_32_byte_boundary_over_32_chars_no_index():
    """Same checks as the indexed variant, with `indexed=False`."""
    var plain = "hello world this is Me and 🔥🔥🔥..."
    var crazy = CrazyString[DType.uint32, indexed=False](
        plain.unsafe_ptr(), plain.byte_length()
    )
    assert_equal(crazy.byte_length(), 42)
    assert_equal(len(crazy), 33)
    assert_equal(crazy, plain)
    assert_equal(crazy[2], "l")
    assert_equal(crazy[27], "🔥")
    assert_equal(crazy[28], "🔥")
    assert_equal(crazy[29], "🔥")
    assert_equal(crazy[30], ".")
    assert_equal(crazy[31], ".")
    assert_equal(crazy[32], ".")
def long_mixed_string():
    """354-byte accented text: length checks, then rebuild it codepoint by codepoint."""
    var lorem = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var crazy = CrazyString[DType.uint32](lorem.unsafe_ptr(), lorem.byte_length())
    assert_equal(crazy.byte_length(), 354)
    assert_equal(len(crazy), 304)
    assert_equal(crazy, lorem)
    var rebuilt = String("")
    for position in range(len(crazy)):
        rebuilt += str(crazy[position])
    assert_equal(lorem, rebuilt)
def long_mixed_string_no_index():
    """Same checks as `long_mixed_string`, with the index disabled."""
    var lorem = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var crazy = CrazyString[DType.uint32, False](
        lorem.unsafe_ptr(), lorem.byte_length()
    )
    assert_equal(crazy.byte_length(), 354)
    assert_equal(len(crazy), 304)
    assert_equal(crazy, lorem)
    var rebuilt = String("")
    for position in range(len(crazy)):
        rebuilt += str(crazy[position])
    assert_equal(lorem, rebuilt)
def test_from_literal():
    """Construction from and assignment of string literals."""
    var crazy: CrazyString = "hello"
    assert_equal(crazy, "hello")
    crazy = "hello 🔥"
    assert_equal(crazy, "hello 🔥")
def test_from_reference():
    """Construction from and assignment of `StringRef` values."""
    var crazy: CrazyString = StringRef("hello")
    assert_equal(crazy, "hello")
    crazy = StringRef("hello 🔥")
    assert_equal(crazy, "hello 🔥")
def test_iterator():
    """Iterating a heap string and an inline string reproduces their contents."""
    var crazy: CrazyString = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."
    var joined: String = ""
    for piece in crazy:
        joined += piece
    assert_equal(joined, crazy.as_string_slice())

    crazy = "hello 🔥"
    joined = ""
    for piece in crazy:
        joined += piece
    assert_equal(joined, crazy.as_string_slice())
def test_get_slice():
    """Forward and backward slices with various strides over accented text."""
    var lorem: CrazyString = "Lorem ipsûm dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêar vulputate fôrensibùs has, dicùnt cœpîœsàé për ïn. No sèd férri vîvendœ perpétûa. Hinc dîctà pôstea sît ut, sêa habeo affert rîdêns id, ùtinam ëqûidem eà vïm. Dicô nôstro mândâmus të pro, èst cétëro voluptatûm no. Nam vœcibus corrûmpit cù."

    # Forward slices.
    var piece = lorem[12:64]
    assert_equal(len(piece), 64 - 12)
    assert_equal(
        "dôlor sit amet, in ïdquè soleat ànîmâl vïm, eù verêa",
        piece.as_string_slice(),
    )
    piece = lorem[12:64:2]
    assert_equal("dlrstae,i du oetàîâ ï,e eê", piece.as_string_slice())
    piece = lorem[12:64:3]
    assert_equal("dos e q ltnâv,ùea", piece.as_string_slice())
    piece = lorem[12:64:7]
    assert_equal("di,qem,r", piece.as_string_slice())

    # Backward slices.
    piece = lorem[120:64:-1]
    assert_equal(
        "ès oN .nï rëp éàsœîpœc tnùcid ,sah sùbisnerôf etatupluv ",
        piece.as_string_slice(),
    )
    piece = lorem[120:64:-2]
    assert_equal("è N.ïrpésîœ ncd,a ùinrfeaulv", piece.as_string_slice())
    piece = lorem[120:64:-3]
    assert_equal("èo. pàîcni,hùsr apv", piece.as_string_slice())
    piece = lorem[120:64:-5]
    assert_equal("è ràœù,sn u ", piece.as_string_slice())
def test_iadd_inline_strings():
    """Appending where the result still fits inline."""
    var greeting: CrazyString = "hello"
    greeting += " Maxim 🔥"
    assert_equal(greeting, "hello Maxim 🔥")
def test_iadd_non_inline_strings_but_keep_in_capacity():
    """Appending to a heap string without exceeding its capacity."""
    var greeting: CrazyString = "hello my good old friend"
    greeting += " Maxim"
    assert_equal(greeting, "hello my good old friend Maxim")
def test_iadd_non_inline_strings_but_keep_over_capacity():
    """Appending to a heap string so that its buffer has to grow."""
    var greeting: CrazyString = "hello my good old friend"
    greeting += " Maxim 🔥. I think we need much more text now."
    var expected = "hello my good old friend Maxim 🔥. I think we need much more text now."
    assert_equal(greeting, expected)
def test_strided_store():
    """Scratch experiment with `strided_store` on a bitcast pointer.

    Stores the four bytes of `a` into the byte view of `b`'s UInt16 buffer,
    passing 2 as the stride argument, then prints both lists. It contains no
    assertions and does not exercise CrazyString.
    """
    var a = List[UInt8](1, 2, 3, 4)
    var b = List[UInt16](0, 0, 0, 0)
    # NOTE(review): which UInt16 values get printed for `b` depends on byte
    # order; presumably each byte lands in the low half only on little-endian.
    b.unsafe_ptr().bitcast[DType.uint8]().strided_store[width=4](a.unsafe_ptr().load[width=4](), 2)
    print(a.__str__())
    print(b.__str__())
def main():
    """Runs every test function of this module in sequence.

    Each test raises on its first failed assertion, which stops the run.
    `long_mixed_string` and `long_mixed_string_no_index` lack the `test_`
    prefix; NOTE(review): a runner that discovers tests by that prefix would
    presumably only reach them through this function.
    """
    # Prints only; contains no assertions.
    test_strided_store()
    test_inline_string()
    test_inline_string_no_index()
    test_short_inline_string()
    test_short_inline_string_no_index()
    test_not_inline_string()
    test_not_inline_string_no_index()
    test_not_inline_string_becuase_of_dt()
    test_not_inline_string_becuase_of_dt_no_index()
    test_ascii_string_at_32_byte_boundary()
    test_ascii_string_at_32_byte_boundary_no_index()
    test_ascii_string_over_32_byte_boundary()
    test_ascii_string_over_32_byte_boundary_no_index()
    test_non_ascii_string_at_32_byte_boundary_below_32_chars()
    test_non_ascii_string_at_32_byte_boundary_below_32_chars_no_index()
    test_non_ascii_string_over_32_byte_boundary_at_32_chars()
    test_non_ascii_string_over_32_byte_boundary_at_32_chars_no_index()
    test_non_ascii_string_over_32_byte_boundary_over_32_chars()
    test_non_ascii_string_over_32_byte_boundary_over_32_chars_no_index()
    long_mixed_string()
    long_mixed_string_no_index()
    test_from_literal()
    test_from_reference()
    test_iterator()
    test_get_slice()
    test_iadd_inline_strings()
    test_iadd_non_inline_strings_but_keep_in_capacity()
    test_iadd_non_inline_strings_but_keep_over_capacity()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment