Tambourine作業メモ

主にスキル習得のためにやった作業のメモ。他人には基本的に無用のものです。

ケント・ベックの「テスト駆動開発」の写経をRustでやってみる(15)

さて、もうちょっとコードをキレイにしてから先に進みたい。

今、通貨は文字列で定義されている。Java版だとStringだ。私のRustのコードだと、ここが&'static strになっている。

Rustの&strはヒープに確保された領域に書き込まれたUTF-8シーケンスに対する参照、というかポインタだ。そして、&'static strは文字列リテラル、つまりプログラム開始時に特別な領域に作られて、プログラム中で共有される文字列に対するポインタなので、プログラム中のどの"USD"も厳密に同じアドレスを指すポインタになる。

そして、メソッドの引数の型が&'static strしか受け付けないということは、そのように絶対に無くならない領域に対するポインタしか受け付けませんということになる。非常にRustらしくて面白い。面白いんだけど、&'static strは長いし、ずっと同じアドレスということはもうそれは文字列である必要は全然ないんで、ちゃんと列挙型にしてやるべきだろう。タイプミスもちゃんとコンパイルエラーになるし。

というわけで、サクッと定義する。

enum Currency {
    CHF,
    USD,
}

Rustの列挙型はいろんなことが出来て面白い。JavaでいうところのOptionalなんかも列挙型で実現されている。けど、まあ、この例は、ごく普通だ。

さて、これで通貨を扱っていたところを全部置き換える。するといろいろ怒られる。

まず、そのままだと大小比較が出来ない。

error[E0369]: binary operation `<` cannot be applied to type `Currency`
  --> src/main.rs:14:16
   |
14 |         if one < another {
   |            --- ^ ------- Currency
   |            |
   |            Currency
   |
   = note: an implementation of `std::cmp::PartialOrd` might be missing for `Currency`

Debugトレイトも実装されてないので、MoneyのDebugの自動実装がエラーになる。

error[E0277]: `Currency` doesn't implement `std::fmt::Debug`
  --> src/main.rs:31:5
   |
31 |     currency: Currency,
   |     ^^^^^^^^^^^^^^^^^^ `Currency` cannot be formatted using `{:?}`
   |
   = help: the trait `std::fmt::Debug` is not implemented for `Currency`
   = note: add `#[derive(Debug)]` or manually implement `std::fmt::Debug`
   = note: required because of the requirements on the impl of `std::fmt::Debug` for `&Currency`
   = note: required for the cast to the object type `dyn std::fmt::Debug`

Cloneも実装されていないので、MoneyのCloneの自動実装もエラーになる。

error[E0277]: the trait bound `Currency: std::clone::Clone` is not satisfied
  --> src/main.rs:31:5
   |
31 |     currency: Currency,
   |     ^^^^^^^^^^^^^^^^^^ the trait `std::clone::Clone` is not implemented for `Currency`
   |
   = note: required by `std::clone::Clone::clone`

==での比較も出来ない。

error[E0369]: binary operation `!=` cannot be applied to type `Currency`
  --> src/main.rs:66:26
   |
66 |         if self.currency != other.currency {
   |            ------------- ^^ -------------- Currency
   |            |
   |            Currency
   |
   = note: an implementation of `std::cmp::PartialEq` might be missing for `Currency`

EqとHashが実装されていないので、ハッシュのキーに出来ない

error[E0277]: the trait bound `Currency: std::cmp::Eq` is not satisfied
   --> src/main.rs:107:19
    |
107 |             rate: HashMap::new(),
    |                   ^^^^^^^^^^^^ the trait `std::cmp::Eq` is not implemented for `Currency`
    |
    = note: required because of the requirements on the impl of `std::cmp::Eq` for `(Currency, Currency)`
    = note: required by `std::collections::HashMap::<K, V>::new`


error[E0277]: the trait bound `Currency: std::hash::Hash` is not satisfied
   --> src/main.rs:107:19
    |
107 |             rate: HashMap::new(),
    |                   ^^^^^^^^^^^^ the trait `std::hash::Hash` is not implemented for `Currency`
    |
    = note: required because of the requirements on the impl of `std::hash::Hash` for `(Currency, Currency)`
    = note: required by `std::collections::HashMap::<K, V>::new`

はいはい、すいませんね・・・ということで、まとめて自動実装をお願いする。

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum Currency {
    CHF,
    USD,
}

これでOK。ラクチン。

さらに、enumにはメソッドも定義できる。「通貨レートの表示の際にFromとToをどちらにするのか」は金融機関ではなく、通貨の定義で決まっているべきなので、Bank::get_pairはCurrencyに移動しよう。

impl Currency {
    fn get_pair(one: Currency, another: Currency) -> (Currency, Currency) {
        if one < another {
            (one, another)
        } else {
            (another, one)
        }
    }
}

これで、通貨は文字列では無く、こんな感じで指定できるようになった。

bank.add_rate(Currency::CHF, Currency::USD, 2.0);

見づらい。useしてやれば修飾しなくても良くなる。

use self::Currency::*;

これでOKになる。素晴らしい。まあ、やるべきでないときも多いと思うけど。

bank.add_rate(CHF, USD, 2.0);

これでだいぶ見やすくなった。最後に今の時点のコードを全部上げておく。次の章は15章、いよいよ「5ドル+10フラン=10ドル」ができるようになるときが来る、らしい。

use std::collections::HashMap;

fn main() {}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum Currency {
    CHF,
    USD,
}

impl Currency {
    fn get_pair(one: Currency, another: Currency) -> (Currency, Currency) {
        if one < another {
            (one, another)
        } else {
            (another, one)
        }
    }
}

use self::Currency::*;

pub trait Expression {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str>;
}

#[derive(Debug, Clone)]
pub struct Money {
    amount: f64,
    currency: Currency,
}

impl Money {
    pub fn dollar(f: f64) -> Money {
        Money {
            amount: f,
            currency: USD,
        }
    }

    pub fn franc(f: f64) -> Money {
        Money {
            amount: f,
            currency: CHF,
        }
    }

    pub fn times(&self, multiplier: f64) -> Money {
        Money {
            amount: self.amount * multiplier,
            currency: self.currency,
        }
    }

    pub fn add(&self, other: &Money) -> Sum {
        Sum {
            augend: self.clone(),
            addend: other.clone(),
        }
    }
}

impl PartialEq for Money {
    fn eq(&self, other: &Money) -> bool {
        if self.currency != other.currency {
            false
        } else {
            self.amount == other.amount
        }
    }
}

impl Expression for Money {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str> {
        let rate = bank.get_rate(self.currency, to);
        match rate {
            Some(r) => Ok(Money {
                amount: self.amount / r,
                currency: to,
            }),
            None => Err("Never set rate these currencies"),
        }
    }
}

pub struct Sum {
    augend: Money,
    addend: Money,
}
impl Expression for Sum {
    fn reduce(&self, to: Currency, _: &Bank) -> Result<Money, &'static str> {
        Ok(Money {
            amount: self.augend.amount + self.addend.amount,
            currency: to,
        })
    }
}

pub struct Bank {
    rate: HashMap<(Currency, Currency), f64>,
}

impl Bank {
    pub fn new() -> Bank {
        Bank {
            rate: HashMap::new(),
        }
    }

    pub fn add_rate(&mut self, from: Currency, to: Currency, r: f64) {
        let pair = Currency::get_pair(from, to);
        self.rate
            .insert(pair, if pair.0 == from { r } else { 1.0 / r });
    }

    pub fn get_rate(&self, from: Currency, to: Currency) -> Option<f64> {
        if from == to {
            return Some(1.);
        }

        let pair = Currency::get_pair(from, to);
        self.rate
            .get(&pair)
            .map(|i| if pair.0 == from { *i } else { 1.0 / *i })
    }

    pub fn reduce<T: Expression>(&self, source: &T, to: Currency) -> Result<Money, &str> {
        source.reduce(to, self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiplication() {
        let five_d = Money::dollar(5.);
        assert_eq!(Money::dollar(10.), five_d.times(2.));
        assert_eq!(Money::dollar(15.), five_d.times(3.));
    }

    #[test]
    fn test_equality() {
        assert_eq!(Money::dollar(5.), Money::dollar(5.));
        assert_ne!(Money::dollar(5.), Money::dollar(6.));
        assert_eq!(Money::franc(7.), Money::franc(7.));
        assert_ne!(Money::franc(7.), Money::franc(6.));

        // 通貨が違うので、一致しない。
        assert_ne!(Money::franc(5.), Money::dollar(5.));
    }

    #[test]
    fn test_currency() {
        assert_eq!(USD, Money::dollar(1.).currency);
        assert_eq!(CHF, Money::franc(1.).currency);
    }

    #[test]
    fn test_simple_addition() {
        let five = Money::dollar(5.);
        let sum = five.add(&five);
        let bank = Bank::new();
        let reduced: Money = bank.reduce(&sum, USD).unwrap();
        assert_eq!(Money::dollar(10.), reduced);
    }

    #[test]
    fn test_add_returns_sum() {
        let five = Money::dollar(5.);
        let sum = five.add(&five);
        assert_eq!(five, sum.augend);
        assert_eq!(five, sum.addend);
    }

    #[test]
    fn test_reduce_sum() {
        let sum = Sum {
            augend: Money::dollar(3.),
            addend: Money::dollar(4.),
        };
        let bank = Bank::new();
        let result = bank.reduce(&sum, USD).unwrap();
        assert_eq!(Money::dollar(7.), result);
    }

    #[test]
    fn test_reduce_money_different_currency() {
        let mut bank = Bank::new();
        bank.add_rate(CHF, USD, 2.0);
        let result = bank.reduce(&Money::franc(2.0), USD).unwrap();
        assert_eq!(Money::dollar(1.0), result);

        // ドルをドルに変更する
        let result2 = bank.reduce(&Money::dollar(2.0), USD).unwrap();
        assert_eq!(Money::dollar(2.), result2);
    }

    #[test]
    fn test_get_rate() {
        let mut bank = Bank::new();

        // 未登録の場合にはNoneが返る
        assert!(match bank.get_rate(CHF, USD) {
            None => true,
            _ => false,
        });

        // 登録したらそれが取れる
        bank.add_rate(CHF, USD, 2.0);
        assert_eq!(2.0, bank.get_rate(CHF, USD).unwrap());

        // 逆順に取得したら、逆数が取れる
        assert_eq!(0.5, bank.get_rate(USD, CHF).unwrap());
    }

    #[test]
    fn test_get_pair() {
        assert_eq!((CHF, USD), Currency::get_pair(CHF, USD));
        assert_eq!((CHF, USD), Currency::get_pair(USD, CHF));
    }

    #[test]
    fn test_add_rate() {
        let mut bank = Bank::new();

        // 初回登録
        bank.add_rate(CHF, USD, 2.0);
        assert_eq!(2.0, bank.rate[&(CHF, USD)]);

        // 上書き登録すると書き換えられる
        bank.add_rate(CHF, USD, 4.0);
        assert_eq!(4.0, bank.rate[&(CHF, USD)]);

        // 逆数でも登録できる
        bank.add_rate(USD, CHF, 4.0);
        assert_eq!(0.25, bank.rate[&(CHF, USD)]);
    }
}