]> git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_tokens.py
Build(GHActions): Use `checkout@v4` instead of outdated `v2`
[sqlconvert.git] / sqlconvert / process_tokens.py
1
2 from sqlparse.sql import Comment
3 from sqlobject.converters import sqlrepr
4 from sqlparse import parse
5 from sqlparse import tokens as T
6
7 try:
8     xrange
9 except NameError:
10     xrange = range
11
12
13 def find_error(token_list):
14     """Find an error"""
15     for token in token_list.flatten():
16         if token.ttype is T.Error:
17             return True
18     return False
19
20
21 def is_comment_or_space(token):
22     return isinstance(token, Comment) or \
23         token.ttype in (T.Comment, T.Comment.Single, T.Comment.Multiline,
24                         T.Newline, T.Whitespace)
25
26
27 def is_newline_statement(statement):
28     for token in statement.tokens[:]:
29         if token.ttype is not T.Newline:
30             return False
31     return True
32
33
34 def escape_strings(token_list, dbname):
35     """Escape strings"""
36     for token in token_list.flatten():
37         if token.ttype is T.String.Single:
38             value = token.value[1:-1]  # unquote by removing apostrophes
39             value = sqlrepr(value, dbname)
40             token.normalized = token.value = value
41
42
43 class StatementGrouper(object):
44     """Collect lines and reparse until the last statement is complete"""
45
46     def __init__(self, encoding=None):
47         self.lines = []
48         self.statements = []
49         self.encoding = encoding
50
51     def process_line(self, line):
52         self.lines.append(line)
53         self.process_lines()
54
55     def process_lines(self):
56         statements = parse(''.join(self.lines), encoding=self.encoding)
57         if not statements:
58             return
59         last_stmt = statements[-1]
60         for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
61             token = last_stmt.tokens[i]
62             if is_comment_or_space(token):
63                 continue
64             if token.ttype is T.Punctuation and token.value == ';':
65                 break  # The last statement is complete
66             # The last statement is still incomplete - wait for the next line
67             return
68         self.lines = []
69         self.statements = statements
70
71     def get_statements(self):
72         for stmt in self.statements:
73             yield stmt
74         self.statements = []
75         return
76
77     def close(self):
78         if not self.lines:
79             return
80         tokens = parse(''.join(self.lines), encoding=self.encoding)
81         for token in tokens:
82             if not is_comment_or_space(token):
83                 raise ValueError("Incomplete SQL statement: %s" %
84                                  tokens)
85         self.lines = []
86         self.statements = []
87         return tokens