]> git.phdru.name Git - sqlconvert.git/blob - sqlconvert/process_mysql.py
Fix(process_mysql): Fix invalid escape sequence `\*`
[sqlconvert.git] / sqlconvert / process_mysql.py
1
2 from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
3     Statement, Token, Values
4 from sqlparse import tokens as T
5 from .process_tokens import escape_strings, is_comment_or_space
6
7
def _is_directive_token(token):
    """Tell whether *token* is a MySQL ``/*! ... */`` directive comment.

    Only a grouped ``Comment`` whose first sub-token is a multiline
    comment starting with ``/*!`` qualifies.
    """
    if not isinstance(token, Comment) or not token.tokens:
        return False
    first = token.tokens[0]
    return (first.ttype is T.Comment.Multiline
            and first.value.startswith('/*!'))
17
18
def is_directive_statement(statement):
    """Tell whether *statement* consists solely of a ``/*! ... */`` directive.

    The statement must start with a directive comment, end with a ``;``
    punctuation token, and contain nothing but whitespace/newline tokens
    in between.  A statement without any tokens is not a directive
    statement.
    """
    tokens = statement.tokens
    if not tokens:
        # Guard: indexing tokens[0] below would raise IndexError.
        return False
    if not _is_directive_token(tokens[0]):
        return False
    last = tokens[-1]
    if last.ttype is not T.Punctuation or last.value != ';':
        return False
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
29
30
def remove_directive_tokens(statement):
    """Drop ``/*! ... */`` directive comments from *statement* in place.

    Only the statement's own (top-level) tokens are inspected; tokens
    nested inside groups are left alone.
    """
    statement.tokens = [tok for tok in statement.tokens
                        if not _is_directive_token(tok)]
39
40
41 def requote_names(token_list):
42     """Remove backticks, quote non-lowercase identifiers"""
43     for token in token_list.flatten():
44         if token.ttype is T.Name:
45             value = token.value
46             if (value[0] == "`") and (value[-1] == "`"):
47                 value = value[1:-1]
48             if value.islower():
49                 token.normalized = token.value = value
50             else:
51                 token.normalized = token.value = '"%s"' % value
52
53
54 def unescape_strings(token_list):
55     """Unescape strings"""
56     for token in token_list.flatten():
57         if token.ttype is T.String.Single:
58             value = token.value
59             for orig, repl in (
60                 ('\\"', '"'),
61                 ("\\'", "'"),
62                 ("''", "'"),
63                 ('\\b', '\b'),
64                 ('\\n', '\n'),
65                 ('\\r', '\r'),
66                 ('\\t', '\t'),
67                 ('\\\032', '\032'),
68                 ('\\\\', '\\'),
69             ):
70                 value = value.replace(orig, repl)
71             token.normalized = token.value = value
72
73
def get_DML_type(statement):
    """Return the normalized DML keyword that opens *statement*.

    Leading comments and whitespace are skipped.  ValueError is raised
    when the first remaining token is not a DML keyword or when there is
    no such token at all.
    """
    first = next((tok for tok in statement.tokens
                  if not is_comment_or_space(tok)), None)
    if first is not None and first.ttype is T.DML:
        return first.normalized
    raise ValueError("Not a DML statement")
82
83
def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs

    The top-level tokens of *statement* are walked with a small state
    machine; ``expected`` names what is acceptable next:
    INSERT -> INTO -> TABLE_NAME -> VALUES -> VALUES_OR_SEMICOLON -> END.
    Comments and whitespace are accepted in any state: before the ``;``
    they are collected into the shared INSERT prefix, after it into the
    trailing tokens.

    One new ``Statement`` is yielded per parenthesized values tuple, made
    of the shared prefix, that tuple and the trailing tokens; every
    yielded statement except the last one also ends with a newline token.

    Raises ValueError on any token the state machine does not accept.
    As this is a generator, the error surfaces on the first iteration.
    """
    insert_tokens = []
    values_tokens = []
    end_tokens = []
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            if expected == 'END':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            # NOTE(review): Function presumably covers a table name followed
            # by a column list, `tbl (col, ...)` -- confirm with sqlparse.
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if isinstance(token, Values):
                # Grouped VALUES clause: keep only its parenthesized tuples
                # and re-create the VALUES keyword (plus a space) for the
                # shared prefix.
                for subtoken in token.tokens:
                    if isinstance(subtoken, Parenthesis):
                        values_tokens.append(subtoken)
                insert_tokens.append(Token(T.Keyword, 'VALUES'))
                insert_tokens.append(Token(T.Whitespace, ' '))
                expected = 'VALUES_OR_SEMICOLON'
                continue
            # Ungrouped VALUES keyword: the tuples follow as separate tokens.
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    # Separator between tuples; dropped.
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    expected = 'END'
                    continue
        # Reached only when no branch above accepted the token.
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Insert newlines between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement
            # Insert newlines only between split statements but not after
            new_lines = []
        # Statement() sets the `parent` attribute of every token to itself.
        # The prefix and trailing tokens are shared by all the yielded
        # statements, but we don't care.
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
150
151
def process_statement(statement, dbname='sqlite'):
    """Convert one parsed MySQL *statement* and yield the result.

    The statement is modified in place by requote_names, unescape_strings,
    remove_directive_tokens and then escape_strings(statement, dbname).
    An INSERT is afterwards split by split_ext_insert, yielding one
    statement per values tuple.  Anything else, including a statement
    whose DML type cannot be determined, is yielded as is.
    """
    for transform in (requote_names, unescape_strings,
                      remove_directive_tokens):
        transform(statement)
    escape_strings(statement, dbname)
    try:
        is_insert = get_DML_type(statement) == 'INSERT'
    except ValueError:
        is_insert = False
    if not is_insert:
        yield statement
        return
    for piece in split_ext_insert(statement):
        yield piece