#!/usr/bin/env python2.4

# $Id: svn-fixup-rename 8566 2005-08-20 21:06:19Z quarl $

# svn-fixup-rename: heuristically find moved files to fixup.

#   syntax: svn-fixup-rename [paths...]
#
#     Heuristically matches up files in the specified paths.
#
#     If no paths specified, use current directory.

# Algorithm:
#
#   0. Find missing and new files.
#
#   1. Compare the cross product of new_files and missing_files, and rate each
#      candidate pair according to whether the contents are the same and
#      whether the basenames are the same.  [For efficiency, only compare the
#      pairs where the md5sum or the basename is the same.]
#
#   2. Accept matches in order of descending score, skipping any pair whose
#      old or new file has already been matched.
#

## quarl 2005-08-08 initial version

import os
import sys
import subprocess
import re
import md5
import stat
import itertools
import math
# import heapq

# Scoring weights used by SVNFixupRename.rate_match().

# Multiplied by sqrt(file size / average file size) when the MD5 digests of
# the two files are equal.
CONTENTS_MATCH_WEIGHT = 100
# Multiplied by sqrt(len(basename)) when the two basenames are equal.
BASENAME_MATCH_WEIGHT = 10
# Flat bonus added when both files have size zero.
EMPTY_FILES_SCORE = 5

def get_file_xcontents(filename):
    '''Return the contents of FILENAME that are used for matching.

    For a symlink the result is the string 'link ' followed by the link
    target; for a regular file it is everything read from the file.

    Raises Exception if FILENAME is neither a symlink nor a regular file.
    '''
    if os.path.islink(filename):
        # compatible with Subversion 1.1's symlink contents (we don't check
        # the svn:special property, but this works well enough)
        return 'link '+os.readlink(filename)
    if os.path.isfile(filename):
        # Close the handle explicitly instead of leaving it to the garbage
        # collector; try/finally (not 'with') keeps this valid on Python 2.4.
        handle = open(filename)
        try:
            return handle.read()
        finally:
            handle.close()
    raise Exception("Path '%s' is not a file or symlink"%filename)

def get_file_md5(filename):
    '''Return the binary MD5 digest of FILENAME's contents.

    The contents are whatever get_file_xcontents() returns for FILENAME.
    '''
    contents = get_file_xcontents(filename)
    hasher = md5.md5(contents)
    return hasher.digest()

def get_file_size(filename):
    '''Return the size of FILENAME as reported by lstat().

    lstat() does not follow symlinks, so for a symlink this is the size of
    the link itself.
    '''
    info = os.lstat(filename)
    return info.st_size

def mean(numbers):
    '''Return the arithmetic mean of NUMBERS as a float.

    NUMBERS may be any iterable (it is traversed exactly once).  If it
    yields nothing, the string 'NaN' is returned instead of a number.
    '''
    # Count while summing rather than calling len(), so generators work too.
    total = 0
    seen = 0
    for value in numbers:
        seen += 1
        total += value
    if seen == 0:
        return 'NaN'
    return float(total) / seen

def uniquify(list):
    '''Return a new list holding the first occurrence of each item in LIST.

    Items are compared with ==, so they need not be hashable.  Every item is
    checked against all items kept so far: use only for very short lists.
    '''
    kept = []
    for candidate in list:
        if candidate in kept:
            continue
        kept.append(candidate)
    return kept

def script_directory():
    '''Return the absolute path of the directory containing sys.argv[0].'''
    script_path = os.path.abspath(sys.argv[0])
    return os.path.dirname(script_path)

def do_svn_rename(old_filename, new_filename):
    '''Run "svn-x-mv OLD_FILENAME NEW_FILENAME", warning on stderr on failure.

    The command runs with the script's own directory prepended to PATH.  A
    non-zero exit status only produces a warning message; it is not raised
    or returned.
    '''
    # Add directory of script to path, in case we're using sudo(8)
    env = os.environ.copy()
    # Fall back to os.defpath when PATH is unset (instead of dying with a
    # KeyError) and join with os.pathsep rather than a hard-coded ':'.
    env['PATH'] = os.pathsep.join([script_directory(),
                                   env.get('PATH', os.defpath)])
    exitcode = subprocess.call(['svn-x-mv', old_filename, new_filename],
                               env=env)
    if exitcode:
        sys.stderr.write("%s: warning: svn-x-mv %s %s failed\n"
                         % (sys.argv[0], old_filename, new_filename))

class DictOfLists(dict):
    '''A dict whose values are lists that add() appends to.'''

    def add(self, key, value):
        '''Append VALUE to the list stored under KEY, creating it if absent.'''
        try:
            bucket = self[key]
        except KeyError:
            bucket = self[key] = []
        bucket.append(value)

class FileDescriptor:
    '''Filename, basename, MD5 digest and size of one file.

    Attributes set by the constructor:
      filename -- the path as given
      basename -- os.path.basename(filename)
      digest   -- result of get_file_md5() for the file
      size     -- result of get_file_size() for the file
      matched  -- False initially; flipped by whoever pairs this file up

    Instances compare by filename.
    '''
    def __init__(self, filename, use_svn_base):
        '''Describe FILENAME.

        If USE_SVN_BASE is true, FILENAME must not exist on disk: it is
        brought back with "svn revert" just long enough to read it and is
        then deleted again.

        Raises SystemExit if USE_SVN_BASE is true and the file already
        exists, or if the output of "svn revert" is not the expected
        "Reverted '<filename>'" line.
        '''
        self.filename = filename
        self.basename = os.path.basename(filename)
        self.matched = False
        if use_svn_base:
            # ffilename = os.path.join(os.path.dirname(filename), '.svn', 'text-base',
            #                          os.path.basename(filename) + '.svn-base')

            # The above method of using the text-base doesn't work because the
            # text-base doesn't have keyword substitutions.
            if os.path.exists(filename):
                raise SystemExit("Somehow file '%s' exists?!"%filename)
            # temporarily revert the file, so we can get its contents.
            revert = subprocess.Popen(['svn', 'revert', filename], stdout=subprocess.PIPE).communicate()[0]
            if revert != "Reverted '%s'\n"%filename:
                raise SystemExit("Revert file '%s' failed" %filename)
        try:
            self.digest = get_file_md5(filename)
            self.size = get_file_size(filename)
        finally:
            # Remove the temporarily reverted file even when reading it
            # failed, so an error here does not leave the working copy
            # changed.
            if use_svn_base:
                os.unlink(filename)

    def __cmp__(self, other):
        '''Order descriptors by filename.'''
        return cmp(self.filename, other.filename)

class FileList:
    '''A list of FileDescriptors, also indexed by basename and by digest.

    Attributes:
      files   -- descriptors in the order they were added
      bases   -- DictOfLists mapping basename -> descriptors
      digests -- DictOfLists mapping digest -> descriptors

    Iterating over the object, or taking len() of it, uses `files`.
    '''
    def __init__(self):
        self.files = []
        self.bases = DictOfLists()
        self.digests = DictOfLists()

    def add(self, filename, use_svn_base=False):
        '''Create a FileDescriptor for FILENAME and index it.

        Directories are skipped with a message on stdout.  USE_SVN_BASE is
        passed straight through to FileDescriptor.
        '''
        if os.path.isdir(filename):
            sys.stdout.write("Skipping directory %s\n" % filename)
            return
        fd = FileDescriptor(filename, use_svn_base)
        self.files.append(fd)
        self.bases.add(fd.basename, fd)
        self.digests.add(fd.digest, fd)

    def lookup_by_basename(self, filename):
        '''Return a list of the descriptors whose basename equals FILENAME's.

        Only the basename of FILENAME is used.  The result is a new list
        (empty when nothing matches), so callers may modify it freely.
        '''
        basename = os.path.basename(filename)
        # The original computed the basename and then fell off the end,
        # always returning None.
        return list(self.bases.get(basename, []))

    def __iter__(self):
        return iter(self.files)

    def __len__(self):
        return len(self.files)

class SVNFixupRename:
    def find_files(self, paths):
        pipe = subprocess.Popen(["svn", "status"]+paths, stdout=subprocess.PIPE)
        svn_status = pipe.communicate()[0]
        exitcode = pipe.returncode
        if exitcode != 0:
            raise SystemExit("%s: Failed to get status of %s" %(sys.argv[0], ' '.join(paths)))
        for line in svn_status.split('\n'):
            m = re.match('[?]......(.*)', line)
            if m:
                self.new_files.add(m.group(1))
                continue
            m = re.match('[!]......(.*)', line)
            if m:
                self.missing_files.add(m.group(1), use_svn_base=True)

    def calculate_averages(self):
        self.average_file_size = mean(
            fd.size for fd in itertools.chain(self.new_files, self.missing_files))

        self.average_basename_length = mean(
            len(fd.basename) for fd in itertools.chain(self.new_files, self.missing_files))

    def find_matches(self):
        for new_file in self.new_files:
            self.consider_matches(new_file,
                                  (self.missing_files.bases.get(new_file.basename,[])+
                                   self.missing_files.digests.get(new_file.digest,[])))

    def consider_matches(self, new_file, potential_match_files):
        potential_match_files = filter(None, potential_match_files)
        potential_match_files = uniquify(potential_match_files)
        for file in potential_match_files:
            self.potential_matches.append((self.rate_match(file, new_file), file, new_file))

    def rate_match(self, old_file, new_file):
        score = 0
        if old_file.size == 0 and new_file.size == 0:
            score += EMPTY_FILES_SCORE
        if old_file.digest == new_file.digest:
            score += CONTENTS_MATCH_WEIGHT * math.sqrt(old_file.size / self.average_file_size)
        if old_file.basename == new_file.basename:
            score += BASENAME_MATCH_WEIGHT * math.sqrt(len(old_file.basename))
        return score

    def make_matches(self):
        for (score, old_file, new_file) in sorted(self.potential_matches,reverse=True):
            if not old_file.matched and not new_file.matched:
                self.matches.append((score, old_file, new_file))
                old_file.matched = True
                new_file.matched = True

    def show_matches(self):
        if self.matches:
            print "Matched pairs:"
            count = 0
            for (score, old_file, new_file) in self.matches:
                count += 1
                print "    %d.  [score: %d] %s -> %s" %(count, score, old_file.filename, new_file.filename)

        if len(self.missing_files) > len(self.matches):
            print "Unmatched missing files:"
            for file in sorted(self.missing_files.files):
                if not file.matched:
                    print "      %s" % file.filename

        if len(self.new_files) > len(self.matches):
            print "Unmatched new files:"
            for file in sorted(self.new_files.files):
                if not file.matched:
                    print "      %s" % file.filename

        print

    def query_continue(self):
        if not self.missing_files:
            raise SystemExit("No missing files!")
        if not self.new_files:
            raise SystemExit("No new files!")
        if not self.matches:
            raise SystemExit("No matches!")

        print "Continue? [Y/n]",
        try:
            if raw_input().strip().lower()[:1] == 'n':
                raise SystemExit("Aborted.")
        except KeyboardInterrupt:
            raise SystemExit("Aborted.")

    def apply_matches(self):
        for (score, old_file, new_file) in self.matches:
            do_svn_rename(old_file.filename, new_file.filename)

    def __init__(self):
        self.new_files = FileList()
        self.missing_files = FileList()
        self.potential_matches = []
        self.matches = []

    def main(self, args):
        self.find_files(args)
        self.calculate_averages()
        self.find_matches()
        self.make_matches()
        self.show_matches()
        self.query_continue()
        self.apply_matches()


# Run only when executed as a script, so that importing this file (e.g. to
# reuse or test the helpers) does not start querying and modifying a working
# copy.  With no command-line paths, operate on the current directory.
if __name__ == '__main__':
    SVNFixupRename().main(sys.argv[1:] or ['.'])

