aboutsummaryrefslogtreecommitdiff
path: root/tracing/graph.py
blob: aa51b6a6b0379c608b4eda663743dbbad9fb608f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#! /usr/bin/env python

import sys
import yaml
import array

class Call(object):
    """One traced call: its function name, input byte strings and output.

    ``call`` must be a mapping with exactly one key, the function name,
    whose value can be turned into a dict of argument name -> sequence of
    byte values (integers 0..255). The ``"output"`` argument is required
    and stored in ``self.output``; arguments whose name ends in
    ``"_length"`` are dropped; every other argument is stored in
    ``self.inputs`` as a byte string.

    ``self.bind`` starts empty and maps an input name to the expression
    node (any object with an ``expr(stream, indent, level)`` method) that
    should be rendered instead of the literal input bytes.
    """

    def __init__(self, call):
        # One-element unpacking: raises ValueError unless ``call`` has
        # exactly one key.
        self.func, = call
        args = dict(call[self.func])
        self.output = self._to_bytes(args.pop("output"))
        self.inputs = {
            name: self._to_bytes(args[name])
            for name in args
            if not name.endswith("_length")
        }
        self.bind = {}

    @staticmethod
    def _to_bytes(values):
        """Pack a sequence of byte values into the native byte string."""
        packed = array.array("B", values)
        # array.tostring() was removed in Python 3.9, while tobytes() does
        # not exist on Python 2, so use whichever this interpreter offers.
        if hasattr(packed, "tobytes"):
            return packed.tobytes()
        return packed.tostring()

    def expr(self, stream, indent="  ", level=""):
        """Write this call as ``func(`` / ``name=<expr>,`` lines / ``)``.

        ``stream`` needs a ``write`` method. ``indent`` is the text added
        per nesting step and ``level`` the indentation already in effect.
        Inputs with an entry in ``self.bind`` are rendered through that
        node; all others are rendered as a ``Literal`` of their bytes.
        """
        stream.write(self.func + "(\n")
        for name, value in self.inputs.items():
            stream.write(level + indent + name + "=")
            # Only build a Literal when the input is not bound.
            if name in self.bind:
                node = self.bind[name]
            else:
                node = Literal(value)
            node.expr(stream, indent, level + indent)
            stream.write(",\n")
        stream.write(level + ")")


class Literal(bytes):
    """A raw byte string that renders as a double-quoted hex literal.

    ``bytes`` is the same type as ``str`` on Python 2, so existing
    behaviour there is unchanged; on Python 3 it lets the class hold raw
    bytes, which a ``str`` subclass cannot.
    """

    def expr(self, stream, indent, level):
        """Write the bytes to ``stream`` as lowercase hex inside quotes.

        ``indent`` and ``level`` are accepted so the signature matches the
        other expression nodes; they are not used.
        """
        # The "hex" codec of str.encode() exists only on Python 2; going
        # through bytearray yields integer byte values on both versions.
        digits = "".join("%02x" % byte for byte in bytearray(self))
        stream.write("\"" + digits + "\"")


class Slice(object):
    """Expression node: another node followed by a ``[start:end]`` suffix."""

    def __init__(self, thing, start, end):
        # ``thing`` is the wrapped node; it must provide
        # ``expr(stream, indent, level)``.
        self.thing, self.start, self.end = thing, start, end

    def expr(self, stream, indent="  ", level=""):
        """Render the wrapped node, then append its slice bounds."""
        bounds = (self.start, self.end)
        self.thing.expr(stream, indent, level)
        suffix = "[%d:%d]" % bounds
        stream.write(suffix)


class Concat(list):
    """A list of expression nodes rendered as one ``concat(...)`` call."""

    def expr(self, stream, indent="  ", level=""):
        """Write ``concat(``, one indented element per line, then ``)``.

        Each element is rendered through its own ``expr`` method at one
        extra step of indentation and is followed by ``,`` and a newline.
        """
        emit = stream.write
        inner = level + indent
        emit("concat(\n")
        for part in self:
            emit(inner)
            part.expr(stream, indent, inner)
            emit(",\n")
        emit(level + ")")


# Width, in bytes, of the windows used to index call outputs. Inputs are
# looked up by their first WINDOW bytes, so data shorter than this is never
# matched against an output.
WINDOW = 8

# NOTE(security): safe_load builds only plain YAML types (lists, dicts,
# ints, strings), which is all Call() reads. Plain yaml.load() without a
# Loader can construct arbitrary Python objects from the input and is
# rejected outright by PyYAML >= 6.
# "or []" lets an empty document on stdin produce no output instead of a
# TypeError.
calls = [Call(c) for c in yaml.safe_load(sys.stdin) or []]

# Maps every WINDOW-byte window of every call's output to the list of calls
# whose output contains that window. A call may appear more than once in a
# list when the same window repeats inside its output.
outputs = {}

for call in calls:
    # "+ 1" makes the last window part of the index as well. Without it the
    # final WINDOW bytes of each output were skipped, so an output exactly
    # WINDOW bytes long was never indexed at all.
    for start in range(len(call.output) - WINDOW + 1):
        outputs.setdefault(call.output[start:start + WINDOW], []).append(call)

for call in calls:
    for name, value in call.inputs.items():
        # Pass 1: the input is a whole output or a contiguous part of one.
        # When several candidates match, the last one examined wins.
        for bind in outputs.get(value[:WINDOW], ()):
            if value == bind.output:
                call.bind[name] = bind
            else:
                # rfind returns the highest matching offset, which is the
                # match an ascending position-by-position scan keeps last.
                start = bind.output.rfind(value)
                if start >= 0:
                    call.bind[name] = Slice(bind, start, start + len(value))
        if name not in call.bind:
            # Pass 2: the input is literal bytes interleaved with complete
            # outputs of other calls. ``i`` is the scan position and ``k``
            # the end of the last matched output, so value[k:i] is the
            # literal run not yet emitted.
            i = 0
            k = 0
            concat = Concat()
            while i < len(value):
                step = 1
                for bind in outputs.get(value[i:i + WINDOW], ()):
                    if value.startswith(bind.output, i):
                        if k != i:
                            concat.append(Literal(value[k:i]))
                        concat.append(bind)
                        step = len(bind.output)
                        k = i + step
                        break
                i += step
            # Bind only when at least one output was found; inputs with no
            # match stay unbound and are rendered as plain literals.
            if concat:
                if k != i:
                    concat.append(Literal(value[k:i]))
                call.bind[name] = concat

# Print one ``"<output hex>" = <expression>`` line group for every call
# whose function name starts with "h".
for call in calls:
    if call.func.startswith("h"):
        # str.encode("hex") exists only on Python 2; bytearray yields
        # integer byte values on both Python 2 and 3.
        digest = "".join("%02x" % byte for byte in bytearray(call.output))
        sys.stdout.write("\"" + digest + "\" = ")
        call.expr(sys.stdout)
        sys.stdout.write("\n")