import listmixin # Uses recipe 440656
import array
# Public domain
class BitList(listmixin.ListMixin):
    """
    List of bits.

    The constructor takes any iterable (typically a list, or a string of
    '0'/'1' characters) and creates an object that acts like list().
    For string input every character other than '0' is stored as 1; for
    any other input each element is stored as 1 if it is truthy.

    This class is memory compact (uses 1 byte per every 8 elements):
    bit i lives in bit (i & 7) of byte (i >> 3) of self.data.  Bits of the
    last byte that lie at or beyond len(self) are kept at zero, so equal
    lists always produce the same to_binary() output.
    """

    def __init__(self, other=()):
        # Strings are detected by duck typing (str and unicode both have
        # .capitalize); for them only the character '0' is a zero bit.
        if hasattr(other, 'capitalize'):
            is_set = lambda item: item != '0'
        else:
            is_set = bool
        self.data = array.array('B')
        count = 0
        byte = 0
        # Iterate rather than index/slice, so iterables without len() or
        # slicing (e.g. generators) are accepted as well.
        for item in other:
            if is_set(item):
                byte |= 1 << (count & 7)
            count += 1
            if not count & 7:
                # A full byte has been assembled.
                self.data.append(byte)
                byte = 0
        if count & 7:
            # Flush the trailing partial byte (its unused bits stay zero).
            self.data.append(byte)
        self.length = count

    def _constructor(self, iterable):
        return BitList(iterable)

    def __len__(self):
        return self.length

    def _get_element(self, i):
        return (self.data[i >> 3] >> (i & 7)) & 1

    def _set_element(self, i, x):
        # Same zero convention as the constructor: falsy values and the
        # string '0' clear the bit, anything else sets it.
        index = i >> 3
        mask = 1 << (i & 7)
        if x and x != '0':
            self.data[index] |= mask
        else:
            self.data[index] &= 0xFF ^ mask

    def _clear_tail(self):
        """
        Zero the bits of the last byte that lie at or beyond len(self).
        """
        used = self.length & 7
        if used:
            self.data[-1] &= (1 << used) - 1

    def _resize_region(self, start, end, new_size):
        """
        Resize slice self[start:end] so that it has size new_size.

        Elements after the slice are shifted so that they follow the
        resized region.  When the region grows, its contents are left
        unspecified; the caller is expected to overwrite them.
        """
        old_size = end - start
        if new_size > old_size:
            delta = new_size - old_size
            self.length += delta
            add_bytes = (self.length + 7) // 8 - len(self.data)
            self.data.extend(array.array('B', [0] * add_bytes))
            # Shift the tail right, walking backwards so that no bit is
            # overwritten before it has been copied.
            for i in xrange(self.length - 1, start + new_size - 1, -1):
                self._set_element(i, self._get_element(i - delta))
        elif new_size < old_size:
            delta = old_size - new_size
            # Shift the tail left, walking forwards.
            for i in xrange(start + new_size, self.length - delta):
                self._set_element(i, self._get_element(i + delta))
            self.length -= delta
            del self.data[(self.length + 7) // 8:]
            # The left shift leaves stale bits past the new end inside the
            # last byte; clear them so to_binary() stays canonical.
            self._clear_tail()

    def __getstate__(self):
        return (self.to_binary(), len(self))

    def __setstate__(self, state):
        # Unpack inside the body: tuple parameters in the signature are a
        # syntax error in Python 3 (PEP 3113).
        data, length = state
        self.__init__()
        decoded = BitList.from_binary(data, length)
        # Adopt the decoded storage directly instead of copying bit by bit.
        self.data = decoded.data
        self.length = decoded.length

    def to_binary(self):
        """
        Return base256_binary_str.

        Byte k holds bits 8*k .. 8*k+7, least significant bit first; unused
        bits of the last byte are zero.
        """
        return self.data.tostring()

    def from_binary(base256_binary_str, num_bits):
        """
        Return new BitList from base256_binary_str and number of bits.

        base256_binary_str must contain exactly (num_bits + 7) // 8 bytes in
        the layout produced by to_binary().  Bits of the last byte at or
        beyond num_bits are cleared.

        Raises ValueError if num_bits is negative or the string length does
        not match num_bits.
        """
        num_bits = int(num_bits)
        if num_bits < 0 or len(base256_binary_str) != (num_bits + 7) // 8:
            raise ValueError('invalid length')
        ans = BitList()
        ans.length = num_bits
        ans.data = array.array('B')
        ans.data.fromstring(base256_binary_str)
        ans._clear_tail()
        return ans
    from_binary = staticmethod(from_binary)

    def set_bit(self, i, x):
        """
        Set bit i to x (extending to the right with zeros if needed).

        Negative i is not extended; it is passed to ordinary list indexing.
        """
        i = int(i)
        if i >= len(self):
            self.extend([0] * (i + 1 - len(self)))
        self[i] = x

    def get_bit(self, i):
        """
        Get bit i (or zero if i >= len(self)).

        Negative i is passed to ordinary list indexing.
        """
        i = int(i)
        if i >= len(self):
            return 0
        return self[i]