1use diesel::associations::HasTable;
2use diesel::backend::Backend;
3use diesel::dsl;
4use diesel::migration::{
5 Migration, MigrationConnection, MigrationSource, MigrationVersion, Result,
6};
7use diesel::prelude::*;
8use diesel::query_builder::{DeleteStatement, InsertStatement, IntoUpdateTarget};
9use diesel::query_dsl::methods::ExecuteDsl;
10use diesel::query_dsl::LoadQuery;
11use diesel::serialize::ToSql;
12use diesel::sql_types::Text;
13use std::cell::RefCell;
14use std::collections::HashMap;
15use std::io::Write;
16
17use crate::errors::MigrationError;
18
// Bookkeeping table recording which migrations have been applied. The
// `MigrationHarness` impl for connections below inserts a row here after running
// a migration, deletes it after reverting one, and reads `version` to list the
// applied migrations.
diesel::table! {
    __diesel_schema_migrations (version) {
        // Primary key: the migration's version string.
        version -> VarChar,
        // NOTE(review): never written by this file; presumably populated by a
        // database default on insert — confirm against `MigrationConnection::setup`.
        run_on -> Timestamp,
    }
}
25
26pub trait MigrationHarness<DB: Backend> {
28 fn has_pending_migration<S: MigrationSource<DB>>(&mut self, source: S) -> Result<bool> {
30 self.pending_migrations(source).map(|p| !p.is_empty())
31 }
32
33 fn run_pending_migrations<S: MigrationSource<DB>>(
44 &mut self,
45 source: S,
46 ) -> Result<Vec<MigrationVersion>> {
47 let pending = self.pending_migrations(source)?;
48 self.run_migrations(&pending)
49 }
50
51 #[doc(hidden)]
55 fn run_migrations(
56 &mut self,
57 migrations: &[Box<dyn Migration<DB>>],
58 ) -> Result<Vec<MigrationVersion>> {
59 migrations.iter().map(|m| self.run_migration(m)).collect()
60 }
61
62 fn run_next_migration<S: MigrationSource<DB>>(
64 &mut self,
65 source: S,
66 ) -> Result<MigrationVersion> {
67 let pending_migrations = self.pending_migrations(source)?;
68 let next_migration = pending_migrations
69 .first()
70 .ok_or(MigrationError::NoMigrationRun)?;
71 self.run_migration(next_migration)
72 }
73
74 fn revert_all_migrations<S: MigrationSource<DB>>(
76 &mut self,
77 source: S,
78 ) -> Result<Vec<MigrationVersion>> {
79 let applied_versions = self.applied_migrations()?;
80 let mut migrations = source
81 .migrations()?
82 .into_iter()
83 .map(|m| (m.name().version().as_owned(), m))
84 .collect::<HashMap<_, _>>();
85
86 applied_versions
87 .into_iter()
88 .map(|version| {
89 let migration_to_revert = migrations
90 .remove(&version)
91 .ok_or(MigrationError::UnknownMigrationVersion(version))?;
92 self.revert_migration(&migration_to_revert)
93 })
94 .collect()
95 }
96
97 fn revert_last_migration<S: MigrationSource<DB>>(
102 &mut self,
103 source: S,
104 ) -> Result<MigrationVersion<'static>> {
105 let applied_versions = self.applied_migrations()?;
106 let migrations = source.migrations()?;
107 let last_migration_version = applied_versions
108 .first()
109 .ok_or(MigrationError::NoMigrationRun)?;
110 let migration_to_revert = migrations
111 .iter()
112 .find(|m| m.name().version() == *last_migration_version)
113 .ok_or_else(|| {
114 MigrationError::UnknownMigrationVersion(last_migration_version.as_owned())
115 })?;
116 self.revert_migration(migration_to_revert)
117 }
118
119 fn pending_migrations<S: MigrationSource<DB>>(
124 &mut self,
125 source: S,
126 ) -> Result<Vec<Box<dyn Migration<DB>>>> {
127 let applied_versions = self.applied_migrations()?;
128 let mut migrations = source
129 .migrations()?
130 .into_iter()
131 .map(|m| (m.name().version().as_owned(), m))
132 .collect::<HashMap<_, _>>();
133
134 for applied_version in applied_versions {
135 migrations.remove(&applied_version);
136 }
137
138 let mut migrations = migrations.into_values().collect::<Vec<_>>();
139
140 migrations.sort_unstable_by(|a, b| a.name().version().cmp(&b.name().version()));
141
142 Ok(migrations)
143 }
144
145 fn run_migration(&mut self, migration: &dyn Migration<DB>)
150 -> Result<MigrationVersion<'static>>;
151
152 fn revert_migration(
157 &mut self,
158 migration: &dyn Migration<DB>,
159 ) -> Result<MigrationVersion<'static>>;
160
161 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>>;
163}
164
165impl<'b, C, DB> MigrationHarness<DB> for C
166where
167 DB: Backend,
168 C: Connection<Backend = DB> + MigrationConnection + 'static,
169 dsl::Order<
170 dsl::Select<__diesel_schema_migrations::table, __diesel_schema_migrations::version>,
171 dsl::Desc<__diesel_schema_migrations::version>,
172 >: LoadQuery<'b, C, MigrationVersion<'static>>,
173 for<'a> InsertStatement<
174 __diesel_schema_migrations::table,
175 <dsl::Eq<__diesel_schema_migrations::version, MigrationVersion<'static>> as Insertable<
176 __diesel_schema_migrations::table,
177 >>::Values,
178 >: diesel::query_builder::QueryFragment<DB> + ExecuteDsl<C, DB>,
179 DeleteStatement<
180 <dsl::Find<
181 __diesel_schema_migrations::table,
182 MigrationVersion<'static>,
183 > as HasTable>::Table,
184 <dsl::Find<
185 __diesel_schema_migrations::table,
186 MigrationVersion<'static>,
187 > as IntoUpdateTarget>::WhereClause,
188 >: ExecuteDsl<C>,
189 str: ToSql<Text, DB>,
190{
191 fn run_migration(
192 &mut self,
193 migration: &dyn Migration<DB>,
194 ) -> Result<MigrationVersion<'static>> {
195 let apply_migration = |conn: &mut C| -> Result<()> {
196 migration.run(conn)?;
197 diesel::insert_into(__diesel_schema_migrations::table)
198 .values(__diesel_schema_migrations::version.eq(migration.name().version().as_owned())).execute(conn)?;
199 Ok(())
200 };
201
202 if migration.metadata().run_in_transaction() {
203 self.transaction(apply_migration)?;
204 } else {
205 apply_migration(self)?;
206 }
207 Ok(migration.name().version().as_owned())
208 }
209
210 fn revert_migration(
211 &mut self,
212 migration: &dyn Migration<DB>,
213 ) -> Result<MigrationVersion<'static>> {
214 let revert_migration = |conn: &mut C| -> Result<()> {
215 migration.revert(conn)?;
216 diesel::delete(__diesel_schema_migrations::table.find(migration.name().version().as_owned()))
217 .execute(conn)?;
218 Ok(())
219 };
220
221 if migration.metadata().run_in_transaction() {
222 self.transaction(revert_migration)?;
223 } else {
224 revert_migration(self)?;
225 }
226 Ok(migration.name().version().as_owned())
227 }
228
229 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
230 setup_database(self)?;
231 Ok(__diesel_schema_migrations::table
232 .select(__diesel_schema_migrations::version)
233 .order(__diesel_schema_migrations::version.desc())
234 .load(self)?)
235 }
236}
237
/// A [`MigrationHarness`] wrapper that writes a progress line to `output` before each
/// migration it runs or reverts (except the all-zero version `00000000000000`), and
/// then forwards the call to the wrapped harness.
pub struct HarnessWithOutput<'a, C, W> {
    // The wrapped harness that actually applies and reverts migrations.
    connection: &'a mut C,
    // Destination for the progress messages.
    output: RefCell<W>,
}
244
245impl<'a, C, W> HarnessWithOutput<'a, C, W> {
246 pub fn new<DB>(harness: &'a mut C, output: W) -> Self
248 where
249 C: MigrationHarness<DB>,
250 DB: Backend,
251 W: Write,
252 {
253 Self {
254 connection: harness,
255 output: RefCell::new(output),
256 }
257 }
258}
259
260impl<'a, C> HarnessWithOutput<'a, C, std::io::Stdout> {
261 pub fn write_to_stdout<DB>(harness: &'a mut C) -> Self
263 where
264 C: MigrationHarness<DB>,
265 DB: Backend,
266 {
267 Self {
268 connection: harness,
269 output: RefCell::new(std::io::stdout()),
270 }
271 }
272}
273
274impl<C, W, DB> MigrationHarness<DB> for HarnessWithOutput<'_, C, W>
275where
276 W: Write,
277 C: MigrationHarness<DB>,
278 DB: Backend,
279{
280 fn run_migration(
281 &mut self,
282 migration: &dyn Migration<DB>,
283 ) -> Result<MigrationVersion<'static>> {
284 if migration.name().version() != MigrationVersion::from("00000000000000") {
285 let mut output = self.output.try_borrow_mut()?;
286 writeln!(output, "Running migration {}", migration.name())?;
287 }
288 self.connection.run_migration(migration)
289 }
290
291 fn revert_migration(
292 &mut self,
293 migration: &dyn Migration<DB>,
294 ) -> Result<MigrationVersion<'static>> {
295 if migration.name().version() != MigrationVersion::from("00000000000000") {
296 let mut output = self.output.try_borrow_mut()?;
297 writeln!(output, "Rolling back migration {}", migration.name())?;
298 }
299 self.connection.revert_migration(migration)
300 }
301
302 fn applied_migrations(&mut self) -> Result<Vec<MigrationVersion<'static>>> {
303 self.connection.applied_migrations()
304 }
305}
306
307fn setup_database<Conn: MigrationConnection>(conn: &mut Conn) -> QueryResult<usize> {
308 conn.setup()
309}