from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
    Statement, Token
from sqlparse import tokens as T
from .process_tokens import escape_strings, is_comment_or_space


def _is_directive_token(token):
    """Return True if the token is a MySQL /*! ... */ directive comment"""
    if isinstance(token, Comment):
        subtokens = token.tokens
        if subtokens:
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False


def is_directive_statement(statement):
    """Return True if the whole statement is a single /*! ... */ directive"""
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
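
# Illustrative example: a directive-only statement, as found in mysqldump
# headers, looks like
#     /*!40101 SET NAMES utf8 */;
# is_directive_statement() returns True for it, so callers can skip the
# whole statement.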


def remove_directive_tokens(statement):
    """Remove /*! directives */ from the top level of the statement"""
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue
        new_tokens.append(token)
    statement.tokens = new_tokens


def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers"""
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]
            if value.islower():
                token.normalized = token.value = value
            else:
                token.normalized = token.value = '"%s"' % value
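
# Illustrative examples: requote_names() rewrites `Test` to "Test" and
# `test` to bare test; mixed-case names without backticks are quoted too.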


def unescape_strings(token_list):
    """Unescape MySQL escape sequences in string literals"""
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = token.value
            for orig, repl in (
                ('\\"', '"'),
                ("\\'", "'"),
                ("''", "'"),
                ('\\b', '\b'),
                ('\\n', '\n'),
                ('\\r', '\r'),
                ('\\t', '\t'),
                ('\\\032', '\032'),  # backslash + Ctrl-Z (chr 26)
                ('\\\\', '\\'),
            ):
                value = value.replace(orig, repl)
            token.normalized = token.value = value
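
# Illustrative example: unescape_strings() turns the MySQL literal
#     'O\'Brien'
# into
#     'O'Brien'
# (escape_strings() later re-escapes it for the target database).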


def get_DML_type(statement):
    """Return the DML keyword ('INSERT', 'SELECT'...) that starts the statement"""
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        if token.ttype is T.DML:
            return token.normalized
        break
    raise ValueError("Not a DML statement")


def split_ext_insert(statement):
    """Split an extended INSERT into multiple standard INSERTs"""
    insert_tokens = []
    values_tokens = []
    end_tokens = []
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            if expected == 'END':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    expected = 'END'
                    continue
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Insert newlines between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement
            # Insert newlines only between split statements,
            # not after the last one
            new_lines = []
        # The Statement constructor sets the `parent` attribute of every
        # token to itself, but we don't care
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
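
# Illustrative example: the extended INSERT
#     INSERT INTO t VALUES (1), (2);
# is yielded by split_ext_insert() as two standard statements:
#     INSERT INTO t VALUES (1);
#     INSERT INTO t VALUES (2);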


def process_statement(statement, dbname='sqlite'):
    """Convert one parsed MySQL statement in place; yield the result(s)

    Extended INSERTs are split, so more than one statement can be yielded.
    """
    requote_names(statement)
    unescape_strings(statement)
    remove_directive_tokens(statement)
    escape_strings(statement, dbname)
    try:
        is_insert = get_DML_type(statement) == 'INSERT'
    except ValueError:
        is_insert = False
    if is_insert:
        for statement in split_ext_insert(statement):
            yield statement
    else:
        yield statement
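

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library proper: parse a MySQL
    # statement with sqlparse and run it through the conversion pipeline.
    # The example statement and the 'sqlite' target are illustrative
    # assumptions (Python 3).
    import sqlparse
    test_sql = "SELECT * FROM `Test` WHERE name = 'O\\'Brien';"
    for parsed in sqlparse.parse(test_sql):
        for converted in process_statement(parsed, dbname='sqlite'):
            print(converted)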