AOJ 1314 Matrix Calculator

問題

行列を扱える電卓（代入文の右辺の式を構文解析・評価し、代入された行列を出力するプログラム）を作れ。

制約条件

入力の各行は80文字以下。
すべての演算は mod 32768 で行う。

方針

再帰下降構文解析（expr → term → factor → primary）で式を評価するだけ。
デバッグがしんどかった……

ソースコード（`rep`・`vi`・`pb` などのマクロや `#include` を含むテンプレート部分は省略）

// Every arithmetic result is reduced modulo 2^15 = 32768.
const int mod = 1 << 15;
// A matrix is a vector of rows; a 1x1 matrix also acts as a scalar in operator*.
typedef vector<vi> M;
// Element-wise sum modulo `mod`. Both operands are expected to have the same
// dimensions; the column count is taken from the first row of `a`.
M operator+(M a, const M &b){
	for(size_t r = 0; r < a.size(); ++r){
		for(size_t c = 0; c < a[0].size(); ++c){
			a[r][c] += b[r][c];
			a[r][c] %= mod;
		}
	}
	return a;
}
// Element-wise difference modulo `mod`, kept in [0, mod) by adding `mod`
// before reducing. Operands are expected to have the same dimensions.
M operator-(M a, const M &b){
	for(size_t r = 0; r < a.size(); ++r){
		vi &dst = a[r];
		const vi &src = b[r];
		for(size_t c = 0; c < a[0].size(); ++c)
			dst[c] = (dst[c] - src[c] + mod) % mod;
	}
	return a;
}
// Product modulo `mod`. If either operand is 1x1 it is treated as a scalar,
// and every element of the other operand is multiplied by it. Otherwise this
// is the ordinary matrix product (a's column count must equal b's row count).
M operator*(M a, M b){
	const bool aIsScalar = a.size() == 1 && a[0].size() == 1;
	const bool bIsScalar = b.size() == 1 && b[0].size() == 1;
	if(aIsScalar || bIsScalar){
		// When both are 1x1, b is used as the scalar; the result is the same.
		M &mat = bIsScalar ? a : b;
		const int k = bIsScalar ? b[0][0] : a[0][0];
		for(size_t r = 0; r < mat.size(); ++r)
			for(size_t c = 0; c < mat[0].size(); ++c)
				mat[r][c] = (mat[r][c] * k) % mod;
		return mat;
	}

	const size_t rows = a.size();
	const size_t cols = b[0].size();
	M res(rows, vi(cols));
	for(size_t r = 0; r < rows; ++r){
		for(size_t c = 0; c < cols; ++c){
			// Operands are below mod (2^15), so each product fits in an int.
			int acc = 0;
			for(size_t t = 0; t < a[0].size(); ++t)
				acc = (acc + a[r][t] * b[t][c]) % mod;
			res[r][c] = acc;
		}
	}
	return res;
}
// Prints the matrix one row per line, with elements separated by single spaces.
ostream& operator<<(ostream &os, const M &m){
	for(size_t r = 0; r < m.size(); ++r){
		const size_t cols = m[0].size();
		for(size_t c = 0; c < cols; ++c){
			os << m[r][c];
			os << (c + 1 < cols ? " " : "\n");
		}
	}
	return os;
}

// Values assigned so far in the current dataset, keyed by one-letter variable name.
map<char, M> val;
// n: assignment count of the current dataset; p: read cursor into `in`;
// len: length of `in`, set by parse().
int n;
int p;
int len;
// in: text currently being parsed; token: last token read by gettoken().
string in;
string token;

// Recursive-descent parser routines (forward declarations).
M expr();
M term();
M factor();
M primary();
M matrix();
M row();
M digit();
// Reads the next token from `in` at position `p` into `token`: either a maximal
// run of decimal digits, or otherwise exactly one character.
void gettoken(){
	token.clear();
	if(!isdigit(in[p])){
		token += in[p++];
		return;
	}
	while(isdigit(in[p]))
		token += in[p++];
}

// Parses a decimal integer literal and returns it as a 1x1 matrix, reducing
// modulo `mod` after every digit so the value never grows large.
M digit(){
	gettoken();
	int value = 0;
	for(size_t k = 0; k < token.size(); ++k)
		value = (value * 10 + (token[k] - '0')) % mod;
	return M(1, vi(1, value));
}
// Parses one matrix row: expressions separated by single spaces. Their values
// are concatenated horizontally (each is expected to have the same row count).
M row(){
	M res = expr();
	while(in[p] == ' '){
		++p; // skip the separating space
		M e = expr();
		if(res.empty()){
			res = e;
			continue;
		}
		for(size_t r = 0; r < e.size(); ++r)
			for(size_t c = 0; c < e[0].size(); ++c)
				res[r].push_back(e[r][c]);
	}
	return res;
}
// Parses a bracketed matrix literal "[row;row;...]" and stacks the rows
// vertically. The current character (the opening '[') is skipped without being
// checked; on return `p` is just past the closing ']'.
M matrix(){
	++p; // skip '['
	M res;
	for(;;){
		M r = row();
		res.insert(res.end(), r.begin(), r.end());
		if(in[p] == ']') break;
		++p; // skip the row separator (';')
	}
	++p; // skip ']'
	return res;
}
// primary: parses one operand (an integer literal, a one-letter variable,
// a parenthesized expression, or a matrix literal) and then applies any
// chain of postfix operators that follows it: transpose (') and
// indexing "(rows,cols)".
M primary(){
	M res;
	if(isdigit(in[p])) res = digit();
	// Variable lookup via map::operator[]: an unassigned name yields (and
	// inserts) an empty matrix.
	else if(isalpha(in[p])) res = val[in[p++]];
	else if(in[p] == '('){
		// Parenthesized sub-expression: skip '(' ... ')'.
		p++;
		res = expr();
		p++;
	}
	// Anything else is treated as a matrix literal; matrix() skips the
	// current character, which is expected to be '['.
	else res = matrix();
	
	// Postfix operators may be chained, e.g. A'(1,2)'.
	while(1){
		if(in[p] == '\''){
			// Transpose: element (i,j) moves to (j,i).
			M tmp(res[0].size(), vi(res.size()));
			rep(i, res.size()) rep(j, res[0].size())
				tmp[j][i] = res[i][j];
			swap(res, tmp);
			p++;
		}
		else{
			if(in[p] != '(') break;
			// Look ahead for a ',' at parenthesis depth 0 (relative to this
			// '(') to confirm an index list "(expr,expr)". Note the scan does
			// not stop at the matching ')'; if no such comma exists before the
			// end of the line, postfix processing stops.
			int i = p + 1, d = 0;
			for(; i < len; i++){
				if(in[i] == '(') d++;
				if(in[i] == ')') d--;
				if(d == 0 && in[i] == ',') break;
			}
			if(i >= len) break;
			// Skip '(', parse the row-index expression, skip ',', parse the
			// column-index expression, skip ')'.
			p++;
			M a = expr(); p++;
			M b = expr(); p++;
			// Build the submatrix: element (i,j) is
			// res[a[0][i]-1][b[0][j]-1], i.e. 1-based indices read from the
			// first row of a and of b (presumably 1xk row vectors — confirm
			// against the problem statement).
			M tmp(a[0].size(), vi(b[0].size()));
			rep(i, tmp.size()) rep(j, tmp[0].size())
				tmp[i][j] = res[a[0][i] - 1][b[0][j] - 1];
			swap(res, tmp);
		}
	}
	return res;
}
// Parses any number of leading unary '-' signs followed by a primary. An odd
// number of signs negates every element (modulo `mod`).
M factor(){
	int signs = 0;
	for(; in[p] == '-'; ++p)
		++signs;
	M res = primary();
	if(signs & 1)
		for(size_t r = 0; r < res.size(); ++r)
			for(size_t c = 0; c < res[0].size(); ++c)
				res[r][c] = (mod - res[r][c]) % mod;
	return res;
}
// Parses factor ('*' factor)* and multiplies left to right.
M term(){
	M res = factor();
	while(in[p] == '*'){
		++p; // skip '*'
		M rhs = factor();
		// An empty accumulator (e.g. from an unassigned variable) is replaced
		// rather than multiplied.
		if(res.empty()) res = rhs;
		else res = res * rhs;
	}
	return res;
}
// Parses term (('+' | '-') term)* and evaluates left to right.
M expr(){
	M res = term();
	while(in[p] == '+' || in[p] == '-'){
		const bool subtract = in[p++] == '-';
		M rhs = term();
		// An empty accumulator (e.g. from an unassigned variable) is replaced
		// rather than combined.
		if(res.empty()) res = rhs;
		else res = subtract ? res - rhs : res + rhs;
	}
	return res;
}

// Resets the cursor and evaluates the whole expression currently held in `in`.
M parse(){
	p = 0;
	len = static_cast<int>(in.size());
	return expr();
}
// Reads datasets until a line whose leading integer is 0. Each dataset is a
// count n followed by n assignment lines of the form "V=expr."; the value
// assigned in each line is printed, and each dataset ends with a "-----" line.
int main()
{
	while(getline(cin, in)){
		n = atoi(in.c_str());
		if(n == 0) break;
		val.clear();
		rep(i, n){
			// Stop cleanly if input ends in the middle of a dataset.
			if(!getline(cin, in)) return 0;
			// Tolerate CRLF line endings: a stray '\r' would otherwise be fed
			// to the parser as part of the expression.
			if(!in.empty() && in[in.size() - 1] == '\r') in.erase(in.size() - 1);
			// Malformed line without a "V=" prefix; the old substr(2, size-3)
			// would wrap around or throw here.
			if(in.size() < 2) continue;
			char v = in[0];
			// Drop the "V=" prefix and, when present, the terminating '.'.
			in = in.substr(2);
			if(!in.empty() && in[in.size() - 1] == '.') in.erase(in.size() - 1);
			val[v] = parse();
			cout << val[v];
		}
		// '\n' instead of endl: identical output without a per-dataset flush.
		cout << "-----" << '\n';
	}
	return 0;
}