# sqlconvert/process_mysql.py — from git.phdru.name, sqlconvert.git
# (commit subject: "Split extended INSERTs")

import re

from sqlparse.sql import Comment, Function, Identifier, Parenthesis, Statement
from sqlparse import tokens as T

from .process_tokens import escape_strings, is_comment_or_space
5
6
def _is_directive_token(token):
    """Report whether *token* is a MySQL ``/*! ... */`` directive comment."""
    if not isinstance(token, Comment):
        return False
    children = token.tokens
    if not children:
        return False
    first = children[0]
    # A directive is a multiline comment whose text opens with '/*!'.
    return (first.ttype is T.Comment.Multiline
            and first.value.startswith('/*!'))
16
17
def is_directive_statement(statement):
    """Report whether *statement* is nothing but a directive and its ';'.

    True only when the first token is a MySQL ``/*! ... */`` directive
    comment, the last token is the ``;`` terminator, and every token in
    between is newline/whitespace.
    """
    tokens = statement.tokens
    if not tokens:
        # Robustness: an empty statement cannot be a directive
        # (the original would raise IndexError on tokens[0]).
        return False
    if not _is_directive_token(tokens[0]):
        return False
    last = tokens[-1]
    if last.ttype is not T.Punctuation or last.value != ';':
        return False
    return all(token.ttype in (T.Newline, T.Whitespace)
               for token in tokens[1:-1])
28
29
def remove_directive_tokens(statement):
    """Remove ``/*! directives */`` from the first level of *statement*.

    Mutates ``statement.tokens`` in place; directives nested deeper than
    the first level are left untouched.
    """
    # List comprehension replaces the manual filter-and-append loop.
    statement.tokens = [token for token in statement.tokens
                        if not _is_directive_token(token)]
38
39
def requote_names(token_list):
    """Strip MySQL backticks; double-quote non-lowercase identifiers."""
    for token in token_list.flatten():
        if token.ttype is not T.Name:
            continue
        name = token.value
        # Drop a surrounding pair of backticks, if present.
        if name.startswith('`') and name.endswith('`'):
            name = name[1:-1]
        # Lowercase names need no quoting; anything else is quoted so the
        # target database preserves its case.
        if not name.islower():
            name = '"%s"' % name
        token.normalized = token.value = name
51
52
# Mapping of MySQL escape sequences to the characters they denote.
# '\032' is Ctrl-Z (MySQL's \Z escape).
_MYSQL_UNESCAPES = {
    '\\"': '"',
    "\\'": "'",
    "''": "'",
    '\\b': '\b',
    '\\n': '\n',
    '\\r': '\r',
    '\\t': '\t',
    '\\\032': '\032',
    '\\\\': '\\',
}

# One regex that matches any single escape sequence; \032 inside the
# character class is the octal escape for Ctrl-Z.
_MYSQL_UNESCAPE_RE = re.compile(r"''|\\[\"'bnrt\032\\]")


def unescape_strings(token_list):
    """Unescape MySQL-escaped string literals in *token_list*, in place.

    Uses a single-pass regex substitution.  The previous chain of
    sequential ``str.replace`` calls was order-dependent and corrupted
    input such as backslash-backslash-'n' (an escaped backslash followed
    by a literal 'n'): the '\\n' replacement fired on the second
    backslash + 'n' before the '\\\\' replacement could consume the
    escaped backslash, yielding backslash + newline instead of
    backslash + 'n'.  Scanning left-to-right consumes each escape
    exactly once.
    """
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = _MYSQL_UNESCAPE_RE.sub(
                lambda match: _MYSQL_UNESCAPES[match.group(0)],
                token.value)
            token.normalized = token.value = value
71
72
def is_insert(statement):
    """Report whether *statement* is an INSERT statement.

    Leading comments and whitespace are skipped; the first significant
    token decides.  Returns False for statements containing nothing but
    comments/whitespace (the original fell off the loop and returned an
    implicit None, which was falsy but unclear).
    """
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        return (token.ttype is T.DML) and (token.normalized == 'INSERT')
    return False
78
79
def split_ext_insert(statement):
    """Split an extended INSERT into multiple standard INSERTs.

    MySQL dumps emit ``INSERT INTO t VALUES (...), (...), ...;``.
    This generator yields one ``Statement`` per ``(...)`` group, each
    composed of the shared ``INSERT INTO t VALUES`` prefix, the group,
    and the final ``;`` (when the original statement had one).

    Raises ValueError when a token does not match the expected grammar.
    """
    insert_tokens = []   # shared prefix: INSERT INTO <table> VALUES
    values_tokens = []   # one Parenthesis token per value group
    last_token = None    # the terminating ';', if present
    # Simple state machine; `expected` names the token we must see next.
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            # Comments/whitespace are folded into the shared prefix
            # regardless of state, so spacing survives in the output.
            insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            # A table name parses as Identifier, or Function when
            # followed directly by a column list in parentheses.
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    # Separator between value groups — dropped.
                    continue
                elif token.value == ';':
                    last_token = token
                    break
        # Any token that fell through the state checks is a syntax error.
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    for values in values_tokens:
        # The Statement constructor sets the `parent` attribute of every
        # token to itself, but we don't care.
        vl = [values]
        if last_token:
            vl.append(last_token)
        statement = Statement(insert_tokens + vl)
        yield statement
131
132
def process_statement(statement, quoting_style='sqlite'):
    """Convert one MySQL statement in place and yield the result(s).

    Requotes identifiers, normalizes string escaping for the target
    *quoting_style*, and drops /*! directives */.  Extended INSERTs are
    split into one statement per value group; anything else is yielded
    unchanged (after the in-place conversions).
    """
    requote_names(statement)
    unescape_strings(statement)
    remove_directive_tokens(statement)
    escape_strings(statement, quoting_style)
    if not is_insert(statement):
        yield statement
        return
    for piece in split_ext_insert(statement):
        yield piece