2 from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
3 Statement, Token, Values
4 from sqlparse import tokens as T
5 from .process_tokens import escape_strings, is_comment_or_space
def _is_directive_token(token):
    """Return True if *token* is a MySQL executable-comment directive.

    MySQL wraps version-conditional SQL in ``/*! ... */`` comments;
    sqlparse groups such text into a ``Comment`` whose first subtoken is
    the multiline comment itself.
    """
    if isinstance(token, Comment):
        subtokens = token.tokens
        if subtokens:  # a Comment group can be empty in degenerate parses
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False
def is_directive_statement(statement):
    """Return True if *statement* is nothing but a /*! directive */ plus ';'.

    The statement must start with a directive comment, end with a ';'
    punctuation token, and contain only whitespace/newlines in between.
    """
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first-level of *statement*.

    Mutates ``statement.tokens`` in place; only top-level directive
    comments are dropped, everything else is kept in order.
    """
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue  # drop the directive comment entirely
        new_tokens.append(token)
    statement.tokens = new_tokens
def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers.

    MySQL quotes identifiers with backticks; SQLite uses double quotes.
    All-lowercase names are left bare, anything else is double-quoted to
    preserve its case.  Mutates the tokens in place.
    """
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]  # strip MySQL backtick quoting
            if value.islower():
                token.normalized = token.value = value
            else:
                token.normalized = token.value = '"%s"' % value
def unescape_strings(token_list):
    """Unescape strings.

    Rewrites single-quoted string tokens in place, turning MySQL
    backslash escape sequences into the literal characters they denote.
    """
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = token.value
            # NOTE(review): the escape table was elided in the extracted
            # source; this list follows MySQL's documented string escapes.
            # Sequential replace is order-sensitive: '\\\\' must come
            # last so it does not create new apparent escapes. Confirm
            # against the original file.
            for orig, repl in (
                ("\\'", "'"),
                ('\\"', '"'),
                ('\\n', '\n'),
                ('\\r', '\r'),
                ('\\t', '\t'),
                ('\\0', '\0'),
                ('\\b', '\b'),
                ('\\Z', '\x1a'),
                ('\\\\', '\\'),
            ):
                value = value.replace(orig, repl)
            token.normalized = token.value = value
def get_DML_type(statement):
    """Return the DML keyword ('INSERT', 'UPDATE', ...) of *statement*.

    Skips leading comments and whitespace; raises ValueError when the
    first meaningful token is not a DML keyword.
    """
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        if (token.ttype is T.DML):
            return token.normalized
        break  # first real token is not DML — fall through to the error
    raise ValueError("Not a DML statement")
def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs.

    MySQL dumps emit ``INSERT INTO t VALUES (...), (...), ...;``; SQLite
    (pre-3.7.11) needs one INSERT per row.  Yields one ``Statement`` per
    value tuple, each sharing the ``INSERT INTO <table> VALUES `` prefix
    and the trailing ``;`` (plus any trailing comments).

    Raises ValueError on any token the state machine does not expect.

    NOTE(review): parts of this state machine were elided in the
    extracted source and have been reconstructed from the surviving
    branch skeleton — verify against the original file.
    """
    insert_tokens = []   # shared "INSERT INTO <table> VALUES " prefix
    values_tokens = []   # one Parenthesis per row
    end_tokens = []      # ';' and anything after it
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            # Before the ';' comments belong to the prefix, after it to
            # the tail — presumably; confirm against the original.
            if end_tokens:
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if isinstance(token, Values):
                # sqlparse grouped "VALUES (...), (...)" into one token:
                # harvest the row tuples and re-emit a bare VALUES keyword.
                for subtoken in token.tokens:
                    if isinstance(subtoken, Parenthesis):
                        values_tokens.append(subtoken)
                insert_tokens.append(Token(T.Keyword, 'VALUES'))
                insert_tokens.append(Token(T.Whitespace, ' '))
                expected = 'VALUES_OR_SEMICOLON'
                continue
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue  # separator between row tuples
                elif token.value == ';':
                    end_tokens.append(token)
                    continue
        # No branch consumed the token — the input is malformed.
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Insert newlines between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement
            # Insert newlines only between split statements but not after
            new_lines = []
        # The statement sets `parent` attribute of the every token to self
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
152 def process_statement(statement, dbname='sqlite'):
153 requote_names(statement)
154 unescape_strings(statement)
155 remove_directive_tokens(statement)
156 escape_strings(statement, dbname)
158 dml_type = get_DML_type(statement)
161 if dml_type == 'INSERT':
162 for statement in split_ext_insert(statement):