]> git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_tokens.py
Reorder processing
[sqlconvert.git] / sqlconvert / process_tokens.py
1
2 from sqlobject.converters import sqlrepr
3 from sqlparse import parse
4 from sqlparse.compat import PY3
5 from sqlparse import tokens as T
6
7
8 def find_error(token_list):
9     """Find an error"""
10     for token in token_list.flatten():
11         if token.ttype is T.Error:
12             return True
13     return False
14
15
16 def is_newline_statement(statement):
17     for token in statement.tokens[:]:
18         if token.ttype is not T.Newline:
19             return False
20     return True
21
22
23 def escape_strings(token_list, dbname):
24     """Escape strings"""
25     for token in token_list.flatten():
26         if token.ttype is T.String.Single:
27             value = token.value[1:-1]  # unquote by removing apostrophes
28             value = sqlrepr(value, dbname)
29             token.normalized = token.value = value
30
31
32 if PY3:
33     xrange = range
34
35
36 class StatementGrouper(object):
37     """Collect lines and reparse until the last statement is complete"""
38
39     def __init__(self, encoding=None):
40         self.lines = []
41         self.statements = []
42         self.encoding = encoding
43
44     def process_line(self, line):
45         self.lines.append(line)
46         self.process_lines()
47
48     def process_lines(self):
49         statements = parse(''.join(self.lines), encoding=self.encoding)
50         last_stmt = statements[-1]
51         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
52             token = last_stmt.tokens[i]
53             if token.ttype in (T.Comment.Single, T.Comment.Multiline,
54                                T.Newline, T.Whitespace):
55                 continue
56             if token.ttype is T.Punctuation and token.value == ';':
57                 break  # The last statement is complete
58             # The last statement is still incomplete - wait for the next line
59             return
60         self.lines = []
61         self.statements = statements
62
63     def get_statements(self):
64         for stmt in self.statements:
65             yield stmt
66         self.statements = []
67         return
68
69     def close(self):
70         if not self.lines:
71             return
72         tokens = parse(''.join(self.lines), encoding=self.encoding)
73         for token in tokens:
74             if (token.ttype not in (T.Comment.Single, T.Comment.Multiline,
75                                     T.Newline, T.Whitespace)):
76                 raise ValueError("Incomplete SQL statement: %s" %
77                                  tokens)
78         self.lines = []
79         self.statements = []
80         return tokens