]> git.somenet.org - pub/jan/dst18.git/blob - ass1-jpa/src/test/java/dst/ass1/jpa/DatabaseGateway.java
hash password!
[pub/jan/dst18.git] / ass1-jpa / src / test / java / dst / ass1 / jpa / DatabaseGateway.java
1 package dst.ass1.jpa;
2
3 import java.sql.Connection;
4 import java.sql.ResultSet;
5 import java.sql.ResultSetMetaData;
6 import java.sql.SQLException;
7 import java.sql.Statement;
8 import java.util.ArrayList;
9 import java.util.HashMap;
10 import java.util.List;
11 import java.util.Map;
12 import java.util.Objects;
13 import java.util.stream.Collectors;
14
15 import javax.persistence.EntityManager;
16 import javax.persistence.metamodel.Type;
17
18 import org.hibernate.Session;
19 import org.hibernate.jdbc.ReturningWork;
20
21 // DO NOT MODIFY THIS CLASS.
22
23 /**
24  * Contains various methods for accessing the database underlying an EntityManager.
25  *
26  * Note that the caller is responsible for dealing with possible exceptions as well as doing the connection handling. A
27  * connection will not be closed even if a fatal error occurs. However, other SQL resources i.e.,
28  * {@link Statement Statements} and {@link ResultSet ResultSets} created within the methods, which are not returned to
29  * the caller, are closed before the method returns.
30  */
31 public class DatabaseGateway {
32
33     private final EntityManager em;
34
35     public DatabaseGateway(EntityManager em) {
36         this.em = em;
37     }
38
39     /**
40      * Returns a list of all table-names for the given database/connection.
41      *
42      * @return List of table names
43      */
44     public List<String> getTables() {
45         return getSession().doReturningWork(new CollectionWork<>("show tables", rs -> rs.getString(1)));
46     }
47
48     /**
49      * Returns a list of all column names in the given table.
50      *
51      * @param tableName the table
52      * @return a list of column names
53      */
54     public List<String> getColumns(String tableName) {
55         return getColumnsDefinitions(tableName).stream().map(m -> m.get("COLUMN_NAME")).collect(Collectors.toList());
56     }
57
58     public List<Map<String, String>> getColumnsDefinitions(String tableName) {
59         String sql = String.format("SELECT * FROM information_schema.columns "
60                 + "WHERE table_name='%s'", tableName.toUpperCase());
61
62         return getSession().doReturningWork(new QueryWork<List<Map<String, String>>>(sql) {
63             @Override
64             protected List<Map<String, String>> execute(ResultSet rs) throws SQLException {
65                 List<Map<String, String>> list = new ArrayList<>();
66                 while (rs.next()) {
67                     ResultSetMetaData meta = rs.getMetaData();
68                     Map<String, String> map = new HashMap<>();
69                     for (int i = 1; i <= meta.getColumnCount(); i++) {
70                         String key = meta.getColumnName(i);
71                         String value = rs.getString(key);
72                         map.put(key, value);
73                     }
74                     list.add(map);
75                 }
76                 return list;
77             }
78         });
79     }
80
81     /**
82      * Returns the java types of all managed entity types.
83      *
84      * @return a list of java types
85      */
86     public List<Class<?>> getManagedJavaTypes() {
87         return em.getMetamodel()
88                 .getManagedTypes().stream()
89                 .map(Type::getJavaType)
90                 .collect(Collectors.toList());
91     }
92
93     /**
94      * Checks if the named table can be accessed via the given EntityManager.
95      *
96      * @param tableName the name of the table to find
97      * @return {@code true} if the database schema contains a table with the given name, {@code false} otherwise
98      */
99     public boolean isTable(final String tableName) {
100         return getSession().doReturningWork(new QueryWork<Boolean>("show tables") {
101             @Override
102             public Boolean execute(ResultSet rs) throws SQLException {
103                 while (rs.next()) {
104                     String tbl = rs.getString(1);
105                     if (tbl.equalsIgnoreCase(tableName)) {
106                         return true;
107                     }
108                 }
109                 return false;
110             }
111
112         });
113     }
114
115     /**
116      * Checks whether a certain database table contains a column with the given
117      * name.
118      *
119      * @param tableName the name of the table to check
120      * @param column the name of the column to find
121      * @return {@code true} if the table contains the column, {@code false} otherwise
122      */
123     public boolean isColumnInTable(String tableName, String column) {
124         String sql = String.format(
125                 "SELECT * FROM information_schema.columns WHERE table_name='%s' and column_name='%s'",
126                 tableName.toUpperCase(), column.toUpperCase()
127         );
128
129         return getSession().doReturningWork(new HasAtLeastOneEntry(sql));
130     }
131
132     /**
133      * Checks whether a table contains a column of the given type and length.
134      *
135      * @param tableName the table to look for
136      * @param column the expected column name
137      * @param type the expected column type
138      * @param length the expected column length
139      * @return true if the information schema has at least one such column
140      */
141     public boolean isColumnInTableWithType(String tableName, String column, String type, String length) {
142         String sql = String.format("SELECT * FROM information_schema.columns "
143                         + "WHERE table_name='%s' and column_name='%s' and "
144                         + "type_name='%s' and character_maximum_length='%s'",
145                 tableName.toUpperCase(), column.toUpperCase(), type.toUpperCase(), length);
146
147         return getSession().doReturningWork(new HasAtLeastOneEntry(sql));
148     }
149
150     /**
151      * Checks whether a certain table contains an index for the given column
152      * name.
153      *
154      * @param tableName the name of the table to check
155      * @param indexName the name of the column the index is created for
156      * @param nonUnique {@code true} if the index is non unique, {@code false} otherwise
157      * @return {@code true} if the index exists, {@code false} otherwise
158      */
159     public boolean isIndex(String tableName, String indexName, boolean nonUnique) {
160
161         String sql = String.format(
162                 "SELECT * FROM information_schema.indexes WHERE table_name='%s' and column_name='%s' and non_unique='%s'",
163                 tableName.toUpperCase(), indexName.toUpperCase(), nonUnique ? "1" : "0"
164         );
165
166         return getSession().doReturningWork(new HasAtLeastOneEntry(sql));
167     }
168
169     public boolean isComposedIndex(String tableName, String columnName1, String columnName2) {
170         String indexName1 = getIndexName(tableName, columnName1);
171         String indexName2 = getIndexName(tableName, columnName2);
172
173         return Objects.nonNull(indexName1) && Objects.equals(indexName1, indexName2);
174     }
175
176     private String getIndexName(String tableName, String columnName) {
177         String sql = String.format(
178                 "SELECT index_name FROM information_schema.indexes WHERE table_name='%s' and column_name='%s'",
179                 tableName.toUpperCase(), columnName.toUpperCase()
180         );
181
182         return getSession().doReturningWork(new QueryWork<String>(sql) {
183             @Override
184             protected String execute(ResultSet rs) throws SQLException {
185                 return (rs.next()) ? rs.getString(1) : null;
186             }
187         });
188     }
189
190     /**
191      * Checks whether the given column of a certain table can contain {@code NULL} values.
192      *
193      * @param tableName the name of the table to check
194      * @param columnName the name of the column to check
195      * @return {@code true} if the column is nullable, {@code false} otherwise
196      */
197     public boolean isNullable(String tableName, String columnName) {
198         String sql = String.format(
199                 "SELECT * FROM information_schema.columns " +
200                         "WHERE table_name='%s' and column_name='%s' and IS_NULLABLE=true",
201                 tableName.toUpperCase(), columnName.toUpperCase()
202         );
203
204         return getSession().doReturningWork(new HasAtLeastOneEntry(sql));
205     }
206
207     /**
208      * Deletes all data from all tables that can be accessed via the given EntityManager.
209      */
210     public void truncateTables() {
211         List<String> tables = getTables();
212         tables.removeIf(t -> t.toLowerCase().startsWith("hibernate"));
213
214         getSession().doWork(connection -> {
215             try (Statement stmt = connection.createStatement()) {
216                 stmt.addBatch("SET FOREIGN_KEY_CHECKS=0");
217                 for (String table : tables) {
218                     stmt.addBatch("TRUNCATE TABLE " + table);
219                 }
220                 stmt.addBatch("SET FOREIGN_KEY_CHECKS=1");
221                 stmt.executeBatch();
222             }
223         });
224     }
225
226     public Session getSession() {
227         return em.unwrap(Session.class);
228     }
229
230     public interface StatementWork<T> extends ReturningWork<T> {
231
232         default T execute(Connection connection) throws SQLException {
233             try (Statement stmt = connection.createStatement()) {
234                 return execute(stmt);
235             }
236         }
237
238         T execute(Statement stmt) throws SQLException;
239     }
240
241     public static abstract class QueryWork<T> implements StatementWork<T> {
242         private final String sql;
243
244         public QueryWork(String sql) {
245             this.sql = sql;
246         }
247
248         @Override
249         public T execute(Statement stmt) throws SQLException {
250             try (ResultSet rs = stmt.executeQuery(sql)) {
251                 return execute(rs);
252             }
253         }
254
255         protected abstract T execute(ResultSet rs) throws SQLException;
256     }
257
258     public static class HasAtLeastOneEntry extends QueryWork<Boolean> {
259
260         public HasAtLeastOneEntry(String sql) {
261             super(sql);
262         }
263
264         @Override
265         protected Boolean execute(ResultSet rs) throws SQLException {
266             return rs.next();
267         }
268     }
269
270     public static class CollectionWork<T> extends QueryWork<List<T>> {
271
272         private final CheckedFunction<ResultSet, T, SQLException> extractor;
273
274         public CollectionWork(String sql, CheckedFunction<ResultSet, T, SQLException> extractor) {
275             super(sql);
276             this.extractor = extractor;
277         }
278
279         @Override
280         protected List<T> execute(ResultSet rs) throws SQLException {
281             List<T> list = new ArrayList<>();
282             while (rs.next()) {
283                 list.add(extractor.apply(rs));
284             }
285             return list;
286         }
287     }
288
289     @FunctionalInterface
290     public interface CheckedFunction<T, R, E extends Exception> {
291         R apply(T t) throws E;
292     }
293 }