#!/bin/env python

from common import *
from treecomp import TreeComparer
import sys

import codecs

# Avoid problems with unicode chars when piping the output of this program:
# replace sys.stdout with a codecs stream writer that encodes everything
# written to it as UTF-8. This is a module-level side effect and takes place
# as soon as the module is loaded.
sys.stdout = codecs.getwriter('utf8')(sys.stdout)

class ManifestTree:
    """Pairs the manifest recorded for a tree with the tree's current state.

    Attributes set by the constructor:
      old_dict -- { fn -> md5 } as read from the manifest file
      new_dict -- { fn -> md5 } as returned by load_tree(path), with any
                  manifest filenames removed
      tc       -- TreeComparer(old_dict, new_dict)
    """
    def __init__(self, path):
        self.old_dict = {}
        # NOTE(review): `mff` is not defined in this file; presumably it is
        # provided by `from common import *` -- confirm.
        for md5, fn in read_md5sum(os.path.join(path, mff)):
            # Must be stored on the instance; a bare `old_dict` is an
            # undefined name here.
            self.old_dict[fn] = md5
        self.new_dict = load_tree(path)
        # Iterate over a snapshot of the keys: deleting entries from a dict
        # while iterating over it raises RuntimeError.
        for fn in list(self.new_dict):
            if is_manifest_fn(fn):
                del self.new_dict[fn]
        self.tc = TreeComparer(self.old_dict, self.new_dict)

def is_manifest_fn(fn):
    """Return True if `fn` starts with "manifest-" and ends with "md5"."""
    if not fn.startswith("manifest-"):
        return False
    return fn.endswith("md5")
            
def list_manifests(d):
    """Return the names of the entries directly in directory `d` that are
    manifest filenames according to is_manifest_fn()."""
    found = []
    for entry in os.listdir(d):
        if is_manifest_fn(entry):
            found.append(entry)
    return found

def get_manifest_fn(root):
    """Return the filename of the single manifest in `root`, or None if
    there is no manifest there.

    Having more than one manifest in `root` trips an assertion.
    """
    found = list_manifests(root)
    if found:
        assert len(found) == 1
        return found[0]
    return None

def write_manifest(manifest_contents, root):
    """Create the first manifest file in `root`.

    `manifest_contents` must be a unicode string and `root` must not already
    contain a manifest (both are asserted). The contents are encoded as utf8
    and stored as "manifest-<md5 of the encoded contents>.md5".
    """
    assert type(manifest_contents) == unicode
    assert not list_manifests(root)
    encoded = manifest_contents.encode("utf8")
    target = os.path.join(root, "manifest-%s.md5" % md5sum(encoded))
    create_file(target, encoded)

def replace_manifest(manifest_contents, root):
    """Replace the single existing manifest in `root` with a new one.

    `manifest_contents` is the new manifest text; it is encoded as utf8 and
    stored as "manifest-<md5 of the encoded contents>.md5". Exactly one
    manifest must already exist in `root` (asserted).

    The swap is done in steps so that the old manifest is only deleted once
    the new one is in place:
      1. write the new contents to "manifest-<md5>.new"
      2. rename the old manifest to "oldmanifest-..."
      3. rename the ".new" file to its final ".md5" name
      4. delete the "oldmanifest-..." backup
    Neither the ".new" name nor the "oldmanifest-" name is recognised by
    is_manifest_fn().
    """
    old_manifests = list_manifests(root)
    assert len(old_manifests) == 1
    # Encode and hash once; both filenames and the file body use the result.
    encoded = manifest_contents.encode("utf8")
    checksum = md5sum(encoded)
    new_manifest_tmp = os.path.join(root, "manifest-%s.new" % checksum)
    new_manifest = os.path.join(root, "manifest-%s.md5" % checksum)
    create_file(new_manifest_tmp, encoded)
    old_name = old_manifests[0]
    old_manifest = os.path.join(root, old_name)
    # Only the leading "manifest-" prefix is rewritten.
    old_manifest_bak = os.path.join(
        root, old_name.replace("manifest-", "oldmanifest-", 1))
    os.rename(old_manifest, old_manifest_bak)
    os.rename(new_manifest_tmp, new_manifest)
    os.unlink(old_manifest_bak)

def load_tree(path):
    """Return the current tree as a dict on the form { fn -> md5 }

    Entries returned by get_tree(path) that are not regular files, or whose
    name is a manifest filename, are left out.
    """
    # NOTE(review): entries are passed to os.path.isfile() as-is, so they are
    # presumably absolute or relative to the working directory -- confirm
    # against get_tree().
    checksums = {}
    for entry in get_tree(path):
        if os.path.isfile(entry) and not is_manifest_fn(entry):
            checksums[entry] = md5sum_file(entry)
    return checksums

def mkmanifest(tree):
    """Return the manifest text for `tree` as a unicode string.

    `tree` is a dict on the form { fn -> md5 }. The result has one line
    "<md5> *<fn>\\n" per entry, sorted by filename, and is the empty string
    for an empty tree.

    The result is NOT encoded; callers encode it (as utf8) before hashing it
    or writing it to disk.
    """
    # Formatting with a unicode template and joining on a unicode separator
    # always yields unicode, so no separate type assertion is needed. join()
    # also avoids rebuilding the string once per entry.
    return u"".join(u"%s *%s\n" % (md5, fn)
                    for fn, md5 in sorted(tree.items()))

def ugetcwd():
    """Return the current working directory as a unicode (text) string.

    If os.getcwd() returns a byte string it is decoded using the locale's
    preferred encoding; a value that already is text is returned unchanged.
    """
    # Imported here because this module has no module-level `import locale`;
    # an import local to some other function does not make the name visible
    # in this one.
    import locale
    cwd = os.getcwd()
    if isinstance(cwd, bytes):
        cwd = cwd.decode(locale.getpreferredencoding())
    return cwd

def cmd_init():
    """Handle the "init" command: create the first manifest for the current
    working directory from the files currently in it.

    write_manifest() asserts that no manifest exists there yet.
    """
    # The former function-local `import locale` was removed: it only bound
    # the name inside this function, nothing here uses it, and a local import
    # cannot supply the name to any other function.
    cwd = ugetcwd()
    tree = load_tree(cwd)
    manifest = mkmanifest(tree)
    write_manifest(manifest, cwd)

def cmd_accept():
    expected_md5 = None
    if len(sys.argv) > 2:
        assert len(sys.argv) == 3
        expected_md5 = sys.argv[2]

    tree = load_tree(ugetcwd())
    manifest = mkmanifest(tree)
    if expected_md5 and expected_md5 != md5sum(manifest.encode("utf8")):
        print "Current tree status does not match given checksum", expected_md5, md5sum(manifest)
        sys.exit(1)
    replace_manifest(manifest, ugetcwd())


def cmd_diff():
    old_dict = {}
    manifest_fn = get_manifest_fn(".")
    for md5, fn in read_md5sum(manifest_fn):
        old_dict[fn] = md5
    new_dict = load_tree(ugetcwd())
    print "# Comparing with manifest", manifest_fn
    number_of_unchanged_files, number_of_new_files, number_of_deleted_files = print_diff(old_dict, new_dict)
    #print number_of_unchanged_files, "unchanged files"
    print "# Set summary:", number_of_new_files, "new files,", number_of_deleted_files, "deleted files, ", len(set(new_dict.values())), "files in total"
    print "# Current tree manifest md5:", md5sum(mkmanifest(new_dict).encode("utf8"))

def print_diff(old_dict, new_dict):
    old_set = set(old_dict.values())
    new_set = set(new_dict.values())
    tc = TreeComparer(old_dict, new_dict)
    for fn in sorted(tc.all_changed_filenames()):
        old_md5 = old_dict.get(fn, None)
        new_md5 = new_dict.get(fn, None)
        set_change = " " 
        if tc.is_deleted(fn):
            assert old_md5 and not new_md5
            if old_md5 not in new_set:
                set_change = " -"
            print "D %s %s" % (set_change, fn)
        elif tc.is_new(fn):
            assert new_md5 and not old_md5
            if new_md5 not in old_set:
                set_change = "+ "
            print "N %s %s" % (set_change, fn)
        elif tc.is_modified(fn):
            assert old_md5 and new_md5
            a = " "
            r = " "
            if old_md5 not in new_set:
                # Old file overwritten and lost in new set
                r = "-"
            if new_md5 not in old_set:
                # A new file is introduced in the set
                a = "+"
            print "M %s %s" % (a+r, fn)
        else:
            assert False
    number_of_unchanged_files = len(old_set & new_set)
    number_of_new_files = len(new_set - old_set)
    number_of_deleted_files = len(old_set - new_set)
    return number_of_unchanged_files, number_of_new_files, number_of_deleted_files

def main():
    """Run the command named by the first command line argument.

    Supported commands are "diff", "init" and "accept". If the command is
    missing or unknown, a usage message is written to stderr and the process
    exits with status 1.
    """
    prog = sys.argv[0] if sys.argv else "manifest"
    usage = "Usage: %s diff|init|accept [expected-md5]" % prog
    if len(sys.argv) < 2:
        sys.exit(usage)
    # Explicit dispatch table instead of eval(): no code is ever built from
    # command line text, and validation does not depend on an assert (which
    # is stripped when running with -O).
    commands = {
        "diff": cmd_diff,
        "init": cmd_init,
        "accept": cmd_accept,
    }
    handler = commands.get(sys.argv[1])
    if handler is None:
        sys.exit(usage)
    handler()

# Only run the command line interface when executed as a script, so that the
# module can be imported (e.g. to use ManifestTree) without side effects.
if __name__ == "__main__":
    main()
