import re

from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
    Statement, Token
from sqlparse import tokens as T
from .process_tokens import escape_strings, is_comment_or_space

def _is_directive_token(token):
    """Check whether the token is a MySQL /*!...*/ directive comment"""
    if isinstance(token, Comment):
        subtokens = token.tokens
        if subtokens:
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False

def is_directive_statement(statement):
    """Check whether the statement is nothing but a /*!...*/ directive"""
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
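
# A typical case: mysqldump headers consist of bare directive statements
# such as "/*!40101 SET NAMES utf8 */;".  An illustrative check, assuming
# the sqlparse package is available:
#
#     import sqlparse
#     stmt = sqlparse.parse('/*!40101 SET NAMES utf8 */;')[0]
#     assert is_directive_statement(stmt)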


def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first level of the statement"""
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue
        new_tokens.append(token)
    statement.tokens = new_tokens
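
# Directives can also appear inline at the top level of a statement, e.g.
# mysqldump writes "CREATE DATABASE /*!32312 IF NOT EXISTS*/ `test`;";
# remove_directive_tokens() drops the /*!...*/ token and keeps the rest.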


def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers"""
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]
            if value.islower():
                token.normalized = token.value = value
            else:
                token.normalized = token.value = '"%s"' % value
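
# All-lowercase names are left bare; names containing uppercase letters are
# double-quoted to preserve their case, e.g. (illustrative):
#
#     `mytable` -> mytable
#     `MyTable` -> "MyTable"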


def unescape_strings(token_list):
    """Unescape MySQL-escaped string literals"""
    escapes = {
        '\\"': '"',
        "\\'": "'",
        "''": "'",
        '\\b': '\b',
        '\\n': '\n',
        '\\r': '\r',
        '\\t': '\t',
        '\\\032': '\032',
        '\\\\': '\\',
    }
    # Replace all escape sequences in a single left-to-right pass so that
    # an escaped backslash is never re-read as the start of another escape
    # (sequential str.replace() calls would corrupt strings like '\\n').
    pattern = re.compile('|'.join(re.escape(esc) for esc in escapes))
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = pattern.sub(lambda m: escapes[m.group()], token.value)
            token.normalized = token.value = value
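
# MySQL escapes special characters with backslashes inside single-quoted
# literals; after this pass the token holds the raw characters, and
# escape_strings() from process_tokens is expected to re-escape them for
# the target quoting style.  E.g. (illustrative): 'O\'Hara' -> 'O'Hara'.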


def is_insert(statement):
    """Check whether the first meaningful token of the statement is INSERT"""
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        return (token.ttype is T.DML) and (token.normalized == 'INSERT')


def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs"""
    insert_tokens = []
    values_tokens = []
    end_tokens = []
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            if expected == 'END':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    expected = 'END'
                    continue
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Newlines are inserted between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement
            # Insert newlines only between split statements, not after the last
            new_lines = []
        # The Statement constructor sets the `parent` attribute of every
        # token to itself, but we don't care.
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
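
# For example, the extended statement
#
#     INSERT INTO t VALUES (1),(2);
#
# is yielded back as two standard statements:
#
#     INSERT INTO t VALUES (1);
#     INSERT INTO t VALUES (2);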


def process_statement(statement, quoting_style='sqlite'):
    """Convert a parsed MySQL statement in place; yield the result(s)"""
    requote_names(statement)
    unescape_strings(statement)
    remove_directive_tokens(statement)
    escape_strings(statement, quoting_style)
    if is_insert(statement):
        for statement in split_ext_insert(statement):
            yield statement
    else:
        yield statement
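

# A small demonstration of the full pipeline, assuming the sqlparse package
# is installed; run it as a module (python -m sqlconvert.process_mysql) so
# that the relative import above resolves.  The exact token layout may vary
# with the installed sqlparse version, so this is a sketch, not part of the
# module's public API.
if __name__ == '__main__':
    import sqlparse
    test_sql = 'INSERT INTO `Test` VALUES (1),(2);'
    for parsed in sqlparse.parse(test_sql):
        for result in process_statement(parsed):
            # Split INSERTs carry their own separating newline;
            # strip it before print() adds another one.
            print(str(result).strip())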