1use diesel::backend::Backend;
2use diesel::migration::{
3 Migration, MigrationConnection, MigrationSource, MigrationVersion, Result,
4};
5use diesel::prelude::*;
6use diesel::query_builder::QueryFragment;
7use diesel::query_dsl::methods;
8use diesel::serialize::ToSql;
9use diesel::sql_types::{Text, VarChar};
10use std::cell::RefCell;
11use std::collections::HashMap;
12use std::io::Write;
13
14use crate::errors::MigrationError;
15
16diesel::table! {
17 __diesel_schema_migrations (version) {
18 version -> VarChar,
19 run_on -> Timestamp,
20 }
21}
22
23pub trait MigrationHarness<DB: Backend> {
25 fn has_pending_migration<S: MigrationSource<DB>>(&mut self, source: S) -> Result<bool> {
27 self.pending_migrations(source).map(|p| !p.is_empty())
28 }
29
30 fn run_pending_migrations<S: MigrationSource<DB>>(
40 &mut self,
41 source: S,
42 ) -> Result<Vec<MigrationVersion<'_>>> {
43 let pending = self.pending_migrations(source)?;
44 self.run_migrations(&pending)
45 }
46
47 #[doc(hidden)]
51 fn run_migrations(
52 &mut self,
53 migrations: &[Box<dyn Migration<DB>>],
54 ) -> Result<Vec<MigrationVersion<'_>>> {
55 migrations.iter().map(|m| self.run_migration(m)).collect()
56 }
57
58 fn run_next_migration<S: MigrationSource<DB>>(
60 &mut self,
61 source: S,
62 ) -> Result<MigrationVersion<'_>> {
63 let pending_migrations = self.pending_migrations(source)?;
64 let next_migration = pending_migrations
65 .first()
66 .ok_or(MigrationError::NoMigrationRun)?;
67 self.run_migration(next_migration)
68 }
69
70 fn revert_all_migrations<S: MigrationSource<DB>>(
72 &mut self,
73 source: S,
74 ) -> Result<Vec<MigrationVersion<'_>>> {
75 let applied_versions = self.applied_migrations()?;
76 let mut migrations = source
77 .migrations()?
78 .into_iter()
79 .map(|m| (m.name().version().as_owned(), m))
80 .collect::<HashMap<_, _>>();
81
82 applied_versions
83 .into_iter()
84 .map(|version| {
85 let migration_to_revert = migrations
86 .remove(&version)
87 .ok_or(MigrationError::UnknownMigrationVersion(version))?;
88 self.revert_migration(&migration_to_revert)
89 })
90 .collect()
91 }
92
93 fn revert_last_migration<S: MigrationSource<DB>>(
98 &mut self,
99 source: S,
100 ) -> Result<MigrationVersion<'static>> {
101 let applied_versions = self.applied_migrations()?;
102 let migrations = source.migrations()?;
103 let last_migration_version = applied_versions
104 .first()
105 .ok_or(MigrationError::NoMigrationRun)?;
106 let migration_to_revert = migrations
107 .iter()
108 .find(|m| m.name().version() == *last_migration_version)
109 .ok_or_else(|| {
110 MigrationError::UnknownMigrationVersion(last_migration_version.as_owned())
111 })?;
112 self.revert_migration(migration_to_revert)
113 }
114
115 fn pending_migrations<S: MigrationSource<DB>>(
120 &mut self,
121 source: S,
122 ) -> Result<Vec<Box<dyn Migration<DB>>>> {
123 let applied_versions = self.applied_migrations()?;
124 let mut migrations = source
125 .migrations()?
126 .into_iter()
127 .map(|m| (m.name().version().as_owned(), m))
128 .collect::<HashMap<_, _>>();
129
130 for applied_version in applied_versions {
131 migrations.remove(&applied_version);
132 }
133
134 let mut migrations = migrations.into_values().collect::<Vec<_>>();
135
136 migrations.sort_unstable_by(|a, b| a.name().version().cmp(&b.name().version()));
137
138 Ok(migrations)
139 }
140
141 fn run_migration(&mut self, migration: &dyn Migration<DB>)
146 -> Result<MigrationVersion<'static>>;
147
148 fn revert_migration(
153 &mut self,
154 migration: &dyn Migration<DB>,
155 ) -> Result<MigrationVersion<'static>>;
156
157 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>>;
159}
160
161impl<C, DB> MigrationHarness<DB> for C
162where
163 DB: Backend + diesel::internal::migrations::DieselReserveSpecialization,
164 C: Connection<Backend = DB> + MigrationConnection + 'static,
165 __diesel_schema_migrations::table: methods::BoxedDsl<
166 'static,
167 DB,
168 Output = __diesel_schema_migrations::BoxedQuery<'static, DB>,
169 >,
170 __diesel_schema_migrations::BoxedQuery<'static, DB, VarChar>:
171 methods::LoadQuery<'static, C, MigrationVersion<'static>>,
172 diesel::internal::migrations::DefaultValues: QueryFragment<DB>,
173 str: ToSql<Text, DB>,
174{
175 fn run_migration(
176 &mut self,
177 migration: &dyn Migration<DB>,
178 ) -> Result<MigrationVersion<'static>> {
179 let apply_migration = |conn: &mut C| -> Result<()> {
180 migration.run(conn)?;
181 diesel::insert_into(__diesel_schema_migrations::table)
182 .values(
183 __diesel_schema_migrations::version.eq(migration.name().version().as_owned()),
184 )
185 .execute(conn)?;
186 Ok(())
187 };
188
189 if migration.metadata().run_in_transaction() {
190 self.transaction(apply_migration)?;
191 } else {
192 apply_migration(self)?;
193 }
194 Ok(migration.name().version().as_owned())
195 }
196
197 fn revert_migration(
198 &mut self,
199 migration: &dyn Migration<DB>,
200 ) -> Result<MigrationVersion<'static>> {
201 let revert_migration = |conn: &mut C| -> Result<()> {
202 migration.revert(conn)?;
203 diesel::delete(
204 __diesel_schema_migrations::table.find(migration.name().version().as_owned()),
205 )
206 .execute(conn)?;
207 Ok(())
208 };
209
210 if migration.metadata().run_in_transaction() {
211 self.transaction(revert_migration)?;
212 } else {
213 revert_migration(self)?;
214 }
215 Ok(migration.name().version().as_owned())
216 }
217
218 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
219 setup_database(self)?;
220 Ok(__diesel_schema_migrations::table
221 .into_boxed()
222 .select(__diesel_schema_migrations::version)
223 .order(__diesel_schema_migrations::version.desc())
224 .load(self)?)
225 }
226}
227
/// Wraps a migration harness and writes a line to `output` for each migration
/// it runs or reverts.
pub struct HarnessWithOutput<'a, C, W> {
    /// Destination for progress messages.
    output: RefCell<W>,
    /// The wrapped harness that performs the actual work.
    connection: &'a mut C,
}
234
235impl<'a, C, W> HarnessWithOutput<'a, C, W> {
236 pub fn new<DB>(harness: &'a mut C, output: W) -> Self
238 where
239 C: MigrationHarness<DB>,
240 DB: Backend,
241 W: Write,
242 {
243 Self {
244 connection: harness,
245 output: RefCell::new(output),
246 }
247 }
248}
249
250impl<'a, C> HarnessWithOutput<'a, C, std::io::Stdout> {
251 pub fn write_to_stdout<DB>(harness: &'a mut C) -> Self
253 where
254 C: MigrationHarness<DB>,
255 DB: Backend,
256 {
257 Self {
258 connection: harness,
259 output: RefCell::new(std::io::stdout()),
260 }
261 }
262}
263
264impl<C, W, DB> MigrationHarness<DB> for HarnessWithOutput<'_, C, W>
265where
266 W: Write,
267 C: MigrationHarness<DB>,
268 DB: Backend,
269{
270 fn run_migration(
271 &mut self,
272 migration: &dyn Migration<DB>,
273 ) -> Result<MigrationVersion<'static>> {
274 if migration.name().version() != MigrationVersion::from("00000000000000") {
275 let mut output = self.output.try_borrow_mut()?;
276 writeln!(output, "Running migration {}", migration.name())?;
277 }
278 self.connection.run_migration(migration)
279 }
280
281 fn revert_migration(
282 &mut self,
283 migration: &dyn Migration<DB>,
284 ) -> Result<MigrationVersion<'static>> {
285 if migration.name().version() != MigrationVersion::from("00000000000000") {
286 let mut output = self.output.try_borrow_mut()?;
287 writeln!(output, "Rolling back migration {}", migration.name())?;
288 }
289 self.connection.revert_migration(migration)
290 }
291
292 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
293 self.connection.applied_migrations()
294 }
295}
296
297fn setup_database<Conn: MigrationConnection>(conn: &mut Conn) -> QueryResult<usize> {
298 conn.setup()
299}