]> git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_tokens.py
Rename mysql2sql -> sqlconvert
[sqlconvert.git] / sqlconvert / process_tokens.py
1
2 from sqlparse import parse
3 from sqlparse.compat import PY3
4 from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
5     Whitespace
6
7
8 def requote_names(token_list):
9     """Remove backticks, quote non-lowercase identifiers"""
10     for token in token_list.flatten():
11         if token.ttype is Name:
12             value = token.value
13             if (value[0] == "`") and (value[-1] == "`"):
14                 value = value[1:-1]
15             if value.islower():
16                 token.normalized = token.value = value
17             else:
18                 token.normalized = token.value = '"%s"' % value
19
20
21 def find_error(token_list):
22     """Find an error"""
23     for token in token_list.flatten():
24         if token.ttype is Error:
25             return True
26     return False
27
28
29 if PY3:
30     xrange = range
31
32
33 class StatementGrouper(object):
34     """Collect lines and reparse until the last statement is complete"""
35
36     def __init__(self, encoding=None):
37         self.lines = []
38         self.statements = []
39         self.encoding = encoding
40
41     def process_line(self, line):
42         self.lines.append(line)
43         self.process_lines()
44
45     def process_lines(self):
46         statements = parse(''.join(self.lines), encoding=self.encoding)
47         last_stmt = statements[-1]
48         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
49             token = last_stmt.tokens[i]
50             if token.ttype in (Comment.Single, Comment.Multiline,
51                                Newline, Whitespace):
52                 continue
53             if token.ttype is Punctuation and token.value == ';':
54                 break  # The last statement is complete
55             # The last statement is still incomplete - wait for the next line
56             return
57         self.lines = []
58         self.statements = statements
59
60     def get_statements(self):
61         for stmt in self.statements:
62             yield stmt
63         self.statements = []
64
65     def close(self):
66         if not self.lines:
67             return
68         tokens = parse(''.join(self.lines), encoding=self.encoding)
69         for token in tokens:
70             if (token.ttype not in (Comment.Single, Comment.Multiline,
71                                     Newline, Whitespace)):
72                 raise ValueError("Incomplete SQL statement: %s" %
73                                  tokens)
74         return tokens