]> git.phdru.name Git - sqlconvert.git/blob - mysql2sql/process_tokens.py
Collect lines and reparse until the last statement is complete
[sqlconvert.git] / mysql2sql / process_tokens.py
1
2 from sqlparse import parse
3 from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
4     Whitespace
5
6
7 def requote_names(token_list):
8     """Remove backticks, quote non-lowercase identifiers"""
9     for token in token_list.flatten():
10         if token.ttype is Name:
11             value = token.value
12             if (value[0] == "`") and (value[-1] == "`"):
13                 value = value[1:-1]
14             if value.islower():
15                 token.normalized = token.value = value
16             else:
17                 token.normalized = token.value = '"%s"' % value
18
19
20 def find_error(token_list):
21     """Find an error"""
22     for token in token_list.flatten():
23         if token.ttype is Error:
24             return True
25     return False
26
27
28 class StatementGrouper(object):
29     """Collect lines and reparse until the last statement is complete"""
30
31     def __init__(self):
32         self.lines = []
33         self.statements = []
34
35     def process_line(self, line):
36         self.lines.append(line)
37         self.process_lines()
38
39     def process_lines(self):
40         statements = parse('\n'.join(self.lines))
41         last_stmt = statements[-1]
42         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
43             token = last_stmt.tokens[i]
44             if token.ttype in (Comment.Single, Comment.Multiline,
45                                Newline, Whitespace):
46                 continue
47             if token.ttype is Punctuation and token.value == ';':
48                 break  # The last statement is complete
49             # The last statement is still incomplete - wait for the next line
50             return
51         self.lines = []
52         self.statements = statements
53
54     def get_statements(self):
55         for stmt in self.statements:
56             yield stmt
57         self.statements = []
58
59     def close(self):
60         if not self.lines:
61             return
62         tokens = parse('\n'.join(self.lines))
63         for token in tokens:
64             if (token.ttype not in (Comment.Single, Comment.Multiline,
65                                     Newline, Whitespace)):
66                 raise ValueError("Incomplete SQL statement: %s" %
67                                  tokens)
68         return tokens