#!/usr/bin/env python3
"""Huffman coding of a text into a ``BitStream`` and back again."""
from collections import Counter, namedtuple

from bitstream import BitStream

# One tree node. ``chars`` lists every symbol reachable below the node and
# ``weight`` is their combined count; leaves carry a single symbol and have
# ``left``/``right`` set to None.
Node = namedtuple('Node', 'chars weight left right')

# Bit values emitted when descending into the left / right child of a node.
left = False
right = True


def count_frequencies(text):
    """Return one leaf ``Node`` per distinct symbol of *text*.

    Each leaf holds the symbol (as a one-element list) and the number of
    times it occurs. Leaves are listed in order of first appearance.
    """
    return [Node([c], n, None, None) for c, n in Counter(text).items()]


def build_tree(freqs):
    """Combine the leaf nodes in *freqs* into a single Huffman tree.

    The two nodes with the smallest ``(weight, chars)`` key are merged
    repeatedly until only the root is left. The lighter node becomes the
    left child.

    Raises:
        IndexError: if *freqs* is empty (there is nothing to build from).
    """
    def sortkey(node):
        # Ties on weight are broken by the symbol list, which keeps the
        # resulting tree deterministic.
        return (node.weight, node.chars)

    nodes = sorted(freqs, key=sortkey)
    if not nodes:
        # Same exception type the bare ``nodes[0]`` lookup used to raise,
        # but with a message that explains the cause.
        raise IndexError('cannot build a Huffman tree from no frequencies')

    while len(nodes) > 1:
        a, b, *rest = nodes
        merged = Node(sorted(a.chars + b.chars), a.weight + b.weight, a, b)
        nodes = sorted([merged] + rest, key=sortkey)
    return nodes[0]


def encode(text, tree):
    """Return a new ``BitStream`` holding the code of every symbol of *text*.

    When *tree* is a lone leaf (one-symbol alphabet) the path to the symbol
    is empty, so one ``left`` bit is written per symbol instead; otherwise
    the length of the text could not be recovered by ``decode``.

    Raises:
        ValueError: if *text* contains a symbol that is not in *tree*.
    """
    bits = BitStream()
    codes = {}  # symbol -> bit list, filled on first use
    for c in text:
        if c not in codes:
            codes[c] = encode_char(c, tree) or [left]
        bits.write(codes[c])
    return bits


def encode_char(c, tree):
    """Return the list of branch bits leading from *tree* to the leaf for *c*.

    The list is empty when *tree* itself is the leaf holding *c*.

    Raises:
        ValueError: if *c* is not listed in *tree*, or is listed in a node
            but in neither of its children.
    """
    path = []
    node = tree
    while True:
        if len(node.chars) == 1 and c == node.chars[0]:  # leaf
            return path
        if c not in node.chars:
            raise ValueError(f'{c} not found in {node}')
        if c in node.left.chars:
            path.append(left)
            node = node.left
        elif c in node.right.chars:
            path.append(right)
            node = node.right
        else:
            raise ValueError(
                f'{c} not found in left/right branch of {node}')


def decode(bits, tree):
    """Return the text encoded in *bits* according to *tree*.

    Reading happens on ``bits.copy()``, not on the stream passed in.
    When *tree* is a lone leaf, every bit stands for one occurrence of its
    symbol (the convention used by ``encode``).

    Raises:
        ValueError: if the bits run out in the middle of a symbol.
    """
    bits = bits.copy()
    if len(tree.chars) == 1:
        # decode_next() consumes nothing at a leaf, so looping on it would
        # never terminate; count the bits instead.
        return tree.chars[0] * len(bits)

    chars = []
    while len(bits):
        char, bits = decode_next(bits, tree)
        chars.append(char)
    return ''.join(chars)


def decode_next(bits, tree):
    """Read one symbol from *bits* and return ``(symbol, bits)``.

    One bit is read per tree level until a leaf is reached. No bit is read
    when *tree* is already a leaf.

    Raises:
        ValueError: if *bits* is exhausted before a leaf is reached.
    """
    node = tree
    while len(node.chars) != 1:
        if not len(bits):
            raise ValueError('bits consumed prematurely')
        bit = bits.read(bool, 1)[0]
        # Compare by value: an identity test would take the wrong branch
        # for a falsy value that is not the ``False`` singleton.
        node = node.left if bit == left else node.right
    return node.chars[0], bits


def main():
    """Round-trip a sample text and print the encoded and decoded forms."""
    text = 'abracadabra'
    tree = build_tree(count_frequencies(text))
    encoded = encode(text, tree)
    decoded = decode(encoded, tree)
    print(f'Compressed "{text}" as {encoded} in {len(encoded)} bits.')
    # Show the decoded result, not the input, so a broken round-trip is
    # visible.
    print(f'Decompressed {encoded} as "{decoded}".')


if __name__ == '__main__':
    main()