Tambourine作業メモ

主にスキル習得のためにやった作業のメモ。他人には基本的に無用のものです。

ケント・ベックの「テスト駆動開発」の写経をRustでやってみる(17)

元のJavaのテストではfiveBucksはMoneyではなくExpressionだった。ここから派生する汎用化を15章では行っている。例えば、plusメソッドがExpressionを受け付けるようにしたり、Expressionにplusメソッドを追加したりだ。これは結局、何がしたいのかというと、今のままではこういうことが出来ないぞと言うことだ。

#[test]
fn test_continuous_adding() {
    let bank = Bank::new();
    let one = Money::dollar(1.);
    let two = Money::dollar(2.);
    let three = Money::dollar(3.);

    // 1 + (2 + 3)
    let result = bank.reduce(&one.add(&two.add(&three)), USD);
    // (1 + 2) + 3
    let result2 = bank.reduce(&(one.add(&two)).add(&three), USD);
    assert_eq!(Money::dollar(6.), result1.unwrap());
    assert_eq!(Money::dollar(6.), result2.unwrap());
}

つまり、Sumの足す数、足される数のどちらにもMoneyだけじゃなく、Sumも入れられる様にしたい。そのために、MoneyもSumもExpressionを実装して、Expressionを入れられる様にしたいぞというわけだ。一種の再帰的なデータ構造を取りたいわけ。

Javaなら単にSumがExpressionを保持するよっていうので構わないんだけども、Rustではなかなか難しい。というのも、Expressionはトレイトなので参照しか保持できないから。

ちょっとやってみよう。Sumの定義はこうだった。

pub struct Sum {
    augend: Money,
    addend: Money,
}

これを、Expressionのポインタを持つように変える。

pub struct Sum {
    augend: Box<Expression>,
    addend: Box<Expression>,
}

これは問題ない。

Sumの定義が変わったので、Money.add()の実装も変える必要がある。以前はこうだった。

impl Money {
    pub fn add(&self, other: &Money) -> Sum {
        Sum {
            augend: self.clone(),
            addend: other.clone(),
        }
    }

//...

}

otherをExpressionの参照に・・・しかし、Sumが渡ってきたときにはcloneできないし・・・単にBoxに詰めるとライフタイムがおかしい。

impl Money {
    pub fn add<T: Expression>(&self, other: &T) -> Sum {
        Sum {
            augend: Box::new(*self),
            addend: Box::new(*other),
        }
    }
}
error[E0310]: the parameter type `T` may not live long enough
  --> src/main.rs:59:21
   |
56 |     pub fn add<T: Expression>(&self, other: &T) -> Sum {
   |                -- help: consider adding an explicit lifetime bound `T: 'static`...
...
59 |             addend: Box::new(*other),
   |                     ^^^^^^^^^^^^^^^^
   |
note: ...so that the type `T` will meet its required lifetime bounds
  --> src/main.rs:59:21
   |
59 |             addend: Box::new(*other),
   |                     ^^^^^^^^^^^^^^^^

何か、非常にやりづらい・・・。

このような再帰型のデータ構造を作るときには、Rustにはお決まりのパターンがある。Enumを使うのだ。

Sumには、Moneyが格納されるかも知れないし、別のSumが格納されるかもしれない。Sumを処理する場合には、どちらが格納されているかによって、処理を変えなくてはいけない。

これをJavaではポリモーフィズムで表現している。Cであれば、ポインタを格納してキャストして使用する。Rustでは、Enumを使うのだ。

つまり、ExpressionとはMoneyかSumのことだとするならば、そう定義してしまおう。

pub enum Expression {
    Value(Money),
    Operation(Box<Sum>),
}

ここで、Enumの名前やメンバーの名前をどうするかは大分悩んだ。しっくりこない。いろいろと例を見ないとなんとも言えないかなあという気がする。

とにかく、MoneyかSumのポインタ(Sum自体はサイズが決まらないから、ポインタしか扱えないのだ)を取れると。

Sumの定義は素直にこれを使って行う。

pub struct Sum {
    augend: Expression,
    addend: Expression,
}

これを返すMoney.add()を考える。しかし、Moneyの実体を格納するんだとすれば、add()がSumを返すと、そこに元のMoneyは移動したことになる。なんだかよくわからない気持ちになったので、インスタンスメソッドはやめてMoney::add()のような関数に変えよう。しかし、引数にはExpressionを格納するのだから、Money::add()ではなくて、Expression::add()が適切だろう。

impl Expression {
    pub fn add(one: Expression, other: Expression) -> Sum {
        Sum {
            augend: one,
            addend: other,
        }
    }
}

渡された2つのExpressionをそのままSumにつめて返しているだけ。シンプルだ。

さあ、今回のクライマックスはSum.reduce()だ。ExpressionがMoneyならそれでよし。Sumが入ってたら再帰的にreduceしてやる必要がある。こんな感じ。

impl Exchangable for Sum {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str> {
        let augend_reduce = match self.augend {
            Expression::Value(ref m) => m.reduce(to, bank)?.amount,
            Expression::Operation(ref s) => s.reduce(to, bank)?.amount,
        };
        let addend_reduce = match self.addend {
            Expression::Value(ref m) => m.reduce(to, bank)?.amount,
            Expression::Operation(ref s) => s.reduce(to, bank)?.amount,
        };
        Ok(Money {
            amount: augend_reduce + addend_reduce,
            currency: to,
        })
    }
}

パターンマッチがあるので見慣れないと辛いけど、やっていることは非常にシンプルだね。

これで、修正はOK。では、テストを直していきながらこれで良かったのか確認していこう。

足し算は、以前はこうだった

let five = Money::dollar(5.);
let sum = five.add(&five)

こうなった。

let five =Money::dollar(5.);
let sum = Expression::add(Expression::Value(five), Expression::Value(five));

これはエラーになる。 fiveを2つ詰めようと思うと、もともと1つなのでどこかで増やさないといけない。移動しちゃってるからね。かねてからMoneyはValue ObjectなのでCopyトレイトを実装してもいいだろうと思っていたので、このタイミングで自動実装を追加することにした。

それにしても、長い。うん、Money::add()も作ろう。

impl Money {
    // ...

    pub fn add(one: Money, other: Money) -> Sum {
        Expression::add(Expression::Value(one), Expression::Value(other))
    }
}

引数の型の違いによるオーバーロードが出来ないのは、ちょっと不便に感じることもあるね。

というわけで、これで良くなった。

let five = Money::dollar(5.);
let sum = Money::add(five, five)

ただし、本来やりたかったテストである1 + (2 + 3)はこうなってしまった。

#[test]
fn test_continuous_adding() {
    let bank = Bank::new();
    let one = Money::dollar(1.);
    let two = Money::dollar(2.);
    let three = Money::dollar(3.);

    // 1 + (2 + 3)
    let result1 = bank.reduce(
        &Expression::add(
            Expression::Value(one),
            Expression::Operation(Box::new(Money::add(two, three))),
        ),
        USD,
    );
    // (1 + 2) + 3
    let result2 = bank.reduce(
        &Expression::add(
            Expression::Operation(Box::new(Money::add(one, two))),
            Expression::Value(three),
        ),
        USD,
    );
    assert_eq!(Money::dollar(6.), result1.unwrap());
    assert_eq!(Money::dollar(6.), result2.unwrap());
}

ちょっと見栄えはよろしくない・・・。まあ、テストは通ったのでよしとしよう。

最後に、この段階でのソースコード全行を載せておく。

use std::collections::HashMap;

fn main() {}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum Currency {
    CHF,
    USD,
}

impl Currency {
    fn get_pair(one: Currency, another: Currency) -> (Currency, Currency) {
        if one < another {
            (one, another)
        } else {
            (another, one)
        }
    }
}

use self::Currency::*;

pub trait Exchangable {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str>;
}

pub enum Expression {
    Value(Money),
    Operation(Box<Sum>),
}

impl Expression {
    pub fn add(one: Expression, other: Expression) -> Sum {
        Sum {
            augend: one,
            addend: other,
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct Money {
    amount: f64,
    currency: Currency,
}

impl Money {
    pub fn dollar(f: f64) -> Money {
        Money {
            amount: f,
            currency: USD,
        }
    }

    pub fn franc(f: f64) -> Money {
        Money {
            amount: f,
            currency: CHF,
        }
    }

    pub fn times(&self, multiplier: f64) -> Money {
        Money {
            amount: self.amount * multiplier,
            currency: self.currency,
        }
    }

    pub fn add(one: Money, other: Money) -> Sum {
        Expression::add(Expression::Value(one), Expression::Value(other))
    }
}

impl PartialEq for Money {
    fn eq(&self, other: &Money) -> bool {
        if self.currency != other.currency {
            false
        } else {
            self.amount == other.amount
        }
    }
}

impl Exchangable for Money {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str> {
        let rate = bank.get_rate(self.currency, to);
        match rate {
            Some(r) => Ok(Money {
                amount: self.amount / r,
                currency: to,
            }),
            None => Err("Never set rate these currencies"),
        }
    }
}

pub struct Sum {
    augend: Expression,
    addend: Expression,
}
impl Exchangable for Sum {
    fn reduce(&self, to: Currency, bank: &Bank) -> Result<Money, &'static str> {
        let augend_reduce = match self.augend {
            Expression::Value(ref m) => m.reduce(to, bank)?.amount,
            Expression::Operation(ref s) => s.reduce(to, bank)?.amount,
        };
        let addend_reduce = match self.addend {
            Expression::Value(ref m) => m.reduce(to, bank)?.amount,
            Expression::Operation(ref s) => s.reduce(to, bank)?.amount,
        };
        Ok(Money {
            amount: augend_reduce + addend_reduce,
            currency: to,
        })
    }
}

pub struct Bank {
    rate: HashMap<(Currency, Currency), f64>,
}

impl Bank {
    pub fn new() -> Bank {
        Bank {
            rate: HashMap::new(),
        }
    }

    pub fn add_rate(&mut self, from: Currency, to: Currency, r: f64) {
        let pair = Currency::get_pair(from, to);
        self.rate
            .insert(pair, if pair.0 == from { r } else { 1.0 / r });
    }

    pub fn get_rate(&self, from: Currency, to: Currency) -> Option<f64> {
        if from == to {
            return Some(1.);
        }

        let pair = Currency::get_pair(from, to);
        self.rate
            .get(&pair)
            .map(|i| if pair.0 == from { *i } else { 1.0 / *i })
    }

    pub fn reduce<T: Exchangable>(&self, source: &T, to: Currency) -> Result<Money, &str> {
        source.reduce(to, self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiplication() {
        let five_d = Money::dollar(5.);
        assert_eq!(Money::dollar(10.), five_d.times(2.));
        assert_eq!(Money::dollar(15.), five_d.times(3.));
    }

    #[test]
    fn test_equality() {
        assert_eq!(Money::dollar(5.), Money::dollar(5.));
        assert_ne!(Money::dollar(5.), Money::dollar(6.));
        assert_eq!(Money::franc(7.), Money::franc(7.));
        assert_ne!(Money::franc(7.), Money::franc(6.));

        // 通貨が違うので、一致しない。
        assert_ne!(Money::franc(5.), Money::dollar(5.));
    }

    #[test]
    fn test_currency() {
        assert_eq!(USD, Money::dollar(1.).currency);
        assert_eq!(CHF, Money::franc(1.).currency);
    }

    #[test]
    fn test_simple_addition() {
        let five = Money::dollar(5.);
        let sum = Money::add(five, five);
        let bank = Bank::new();
        let reduced: Money = bank.reduce(&sum, USD).unwrap();
        assert_eq!(Money::dollar(10.), reduced);
    }

    #[test]
    fn test_add_returns_sum() {
        let five = Money::dollar(5.);
        let sum = Money::add(five, five);
        assert_eq!(
            five,
            match sum.augend {
                Expression::Value(m) => m,
                _ => panic!("It is not correct"),
            }
        );
        assert_eq!(
            five,
            match sum.addend {
                Expression::Value(m) => m,
                _ => panic!("It is not correct"),
            }
        );
    }

    #[test]
    fn test_reduce_sum() {
        let sum = Sum {
            augend: Expression::Value(Money::dollar(3.)),
            addend: Expression::Value(Money::dollar(4.)),
        };
        let bank = Bank::new();
        let result = bank.reduce(&sum, USD).unwrap();
        assert_eq!(Money::dollar(7.), result);
    }

    #[test]
    fn test_reduce_money_different_currency() {
        let mut bank = Bank::new();
        bank.add_rate(CHF, USD, 2.0);
        let result = bank.reduce(&Money::franc(2.0), USD).unwrap();
        assert_eq!(Money::dollar(1.0), result);

        // ドルをドルに変更する
        let result2 = bank.reduce(&Money::dollar(2.0), USD).unwrap();
        assert_eq!(Money::dollar(2.), result2);
    }

    #[test]
    fn test_get_rate() {
        let mut bank = Bank::new();

        // 未登録の場合にはNoneが返る
        assert!(match bank.get_rate(CHF, USD) {
            None => true,
            _ => false,
        });

        // 登録したらそれが取れる
        bank.add_rate(CHF, USD, 2.0);
        assert_eq!(2.0, bank.get_rate(CHF, USD).unwrap());

        // 逆順に取得したら、逆数が取れる
        assert_eq!(0.5, bank.get_rate(USD, CHF).unwrap());
    }

    #[test]
    fn test_get_pair() {
        assert_eq!((CHF, USD), Currency::get_pair(CHF, USD));
        assert_eq!((CHF, USD), Currency::get_pair(USD, CHF));
    }

    #[test]
    fn test_add_rate() {
        let mut bank = Bank::new();

        // 初回登録
        bank.add_rate(CHF, USD, 2.0);
        assert_eq!(2.0, bank.rate[&(CHF, USD)]);

        // 上書き登録すると書き換えられる
        bank.add_rate(CHF, USD, 4.0);
        assert_eq!(4.0, bank.rate[&(CHF, USD)]);

        // 逆数でも登録できる
        bank.add_rate(USD, CHF, 4.0);
        assert_eq!(0.25, bank.rate[&(CHF, USD)]);
    }

    #[test]
    fn test_mixed_addition() {
        let five_bucks = Money::dollar(5.);
        let ten_francs = Money::franc(10.);
        let mut bank = Bank::new();
        bank.add_rate(CHF, USD, 2.0);
        let result = bank.reduce(&Money::add(five_bucks, ten_francs), USD);
        assert_eq!(Money::dollar(10.), result.unwrap());
    }

    #[test]
    fn test_continuous_adding() {
        let bank = Bank::new();
        let one = Money::dollar(1.);
        let two = Money::dollar(2.);
        let three = Money::dollar(3.);

        // 1 + (2 + 3)
        let result1 = bank.reduce(
            &Expression::add(
                Expression::Value(one),
                Expression::Operation(Box::new(Money::add(two, three))),
            ),
            USD,
        );
        // (1 + 2) + 3
        let result2 = bank.reduce(
            &Expression::add(
                Expression::Operation(Box::new(Money::add(one, two))),
                Expression::Value(three),
            ),
            USD,
        );
        assert_eq!(Money::dollar(6.), result1.unwrap());
        assert_eq!(Money::dollar(6.), result2.unwrap());
    }
}