]> git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_tokens.py
Implement is_insert()
[sqlconvert.git] / sqlconvert / process_tokens.py
1
2 from sqlobject.converters import sqlrepr
3 from sqlparse import parse
4 from sqlparse.compat import PY3
5 from sqlparse import tokens as T
6
7
8 def find_error(token_list):
9     """Find an error"""
10     for token in token_list.flatten():
11         if token.ttype is T.Error:
12             return True
13     return False
14
15
16 def is_comment_or_space(token):
17     return token.ttype in (T.Comment.Single, T.Comment.Multiline,
18                            T.Newline, T.Whitespace)
19
20
21 def is_newline_statement(statement):
22     for token in statement.tokens[:]:
23         if token.ttype is not T.Newline:
24             return False
25     return True
26
27
28 def escape_strings(token_list, dbname):
29     """Escape strings"""
30     for token in token_list.flatten():
31         if token.ttype is T.String.Single:
32             value = token.value[1:-1]  # unquote by removing apostrophes
33             value = sqlrepr(value, dbname)
34             token.normalized = token.value = value
35
36
37 if PY3:
38     xrange = range
39
40
41 class StatementGrouper(object):
42     """Collect lines and reparse until the last statement is complete"""
43
44     def __init__(self, encoding=None):
45         self.lines = []
46         self.statements = []
47         self.encoding = encoding
48
49     def process_line(self, line):
50         self.lines.append(line)
51         self.process_lines()
52
53     def process_lines(self):
54         statements = parse(''.join(self.lines), encoding=self.encoding)
55         last_stmt = statements[-1]
56         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
57             token = last_stmt.tokens[i]
58             if is_comment_or_space(token):
59                 continue
60             if token.ttype is T.Punctuation and token.value == ';':
61                 break  # The last statement is complete
62             # The last statement is still incomplete - wait for the next line
63             return
64         self.lines = []
65         self.statements = statements
66
67     def get_statements(self):
68         for stmt in self.statements:
69             yield stmt
70         self.statements = []
71         return
72
73     def close(self):
74         if not self.lines:
75             return
76         tokens = parse(''.join(self.lines), encoding=self.encoding)
77         for token in tokens:
78             if not is_comment_or_space(token):
79                 raise ValueError("Incomplete SQL statement: %s" %
80                                  tokens)
81         self.lines = []
82         self.statements = []
83         return tokens