1import compiler
2import dis
3import types
4
def extract_code_objects(co, recursive=False):
    """Return *co* followed by the code objects stored in its constants.

    By default only the direct children found in ``co.co_consts`` are
    collected (for a module this is typically the bodies of its top-level
    functions, classes and lambdas).  Pass ``recursive=True`` to descend
    into those children as well, so that methods and nested functions are
    included; the result is then in pre-order (every code object precedes
    the code objects nested inside it).

    :param co: a code object, e.g. the result of ``compile(src, fn, "exec")``.
    :param recursive: when true, also walk nested code objects.
    :returns: a new list whose first element is *co* itself.
    """
    found = [co]
    for const in co.co_consts:
        # CodeType cannot be subclassed, so isinstance() matches exactly the
        # objects the former ``type(const) == types.CodeType`` test matched.
        if isinstance(const, types.CodeType):
            if recursive:
                found.extend(extract_code_objects(const, recursive=True))
            else:
                found.append(const)
    return found
11
def compare(a, b):
    """Check code object *b* against the reference code object *a*.

    The names must agree, except when *a* is named "?" (presumably
    module-level code on older interpreters) or is a lambda; those names
    are not compared.  Any stack-size difference is reported on stdout.
    If *a* needs more stack than *b*, both are disassembled (*a* labelled
    "good code", *b* labelled "bad code") and the check fails.

    Explicit ``raise`` is used instead of ``assert`` so the checks still
    fire when Python runs with ``-O`` (which strips assert statements).

    :raises AssertionError: on a name mismatch, or when
        ``a.co_stacksize > b.co_stacksize``.
    """
    if not (a.co_name == "?" or a.co_name.startswith('<lambda')):
        if a.co_name != b.co_name:
            raise AssertionError((a, b))
    if a.co_stacksize == b.co_stacksize:
        return
    print("stack mismatch %s: %d vs. %d" % (a.co_name,
                                             a.co_stacksize,
                                             b.co_stacksize))
    if a.co_stacksize > b.co_stacksize:
        print("good code")
        dis.dis(a)
        print("bad code")
        dis.dis(b)
        raise AssertionError("%s: reference stack size %d exceeds %d"
                             % (a.co_name, a.co_stacksize, b.co_stacksize))
25
def main(files):
    """Compile each file with both compilers and compare their code objects.

    For every path in *files*, the path is printed.  The source is then
    compiled once with the builtin ``compile`` and once with the
    ``compiler`` package.  The code objects returned by
    ``extract_code_objects`` for each result are paired up in order and
    checked with ``compare``.  Files that the builtin compiler rejects
    with ``SyntaxError`` are reported as "skipped".

    :param files: iterable of source-file paths.
    """
    for path in files:
        print(path)
        # ``with`` closes the handle immediately instead of leaving it to GC.
        with open(path) as source_file:
            buf = source_file.read()
        try:
            co1 = compile(buf, path, "exec")
        except SyntaxError:
            print("skipped")
            continue
        co2 = compiler.compile(buf, path, "exec")
        # NOTE: zip() stops at the shorter list, so any extra code objects
        # produced by only one of the compilers are not compared.
        for a, b in zip(extract_code_objects(co1), extract_code_objects(co2)):
            compare(a, b)
40
if __name__ == "__main__":
    # Script entry point: every command-line argument is a file to check.
    import sys
    main(files=sys.argv[1:])
44