다항식 평가를위한 생성 된 메소드

algebra c# expression-trees linq math

문제

나는 생성 된 다항식을 다루기위한 우아한 방법을 고안하려고 노력 중이다. 다음은이 질문에 대해 집중적으로 다룰 상황입니다.

  1. ordern 차 다항식을 생성 할 때의 파라미터이며, n : = order + 1입니다.
  2. i 는 0..n 범위의 정수 매개 변수입니다.
  3. 다항식은 x_j에 0을 가지고 있습니다. 여기서 j = 1..n 및 j.i (이 시점에서 StackOverflow에 새로운 기능이 필요하거나 현재 존재하고 있어야한다는 것을 명확히해야합니다)
  4. 다항식은 x_i에서 1로 평가됩니다.

이 특별한 코드 예제는 x_1 .. x_n을 생성하기 때문에 코드에서 어떻게 발견되는지 설명 할 것입니다. 점들은 x_j = j * elementSize / order 간격으로 균등하게 배열됩니다. 여기서 n = order + 1 입니다.

이 다항식 ¹을 평가하기 위해 Func<double, double> 을 생성합니다.

private static Func<double, double> GeneratePsi(double elementSize, int order, int i)
{
    if (order < 1)
        throw new ArgumentOutOfRangeException("order", "order must be greater than 0.");

    if (i < 0)
        throw new ArgumentOutOfRangeException("i", "i cannot be less than zero.");
    if (i > order)
        throw new ArgumentException("i", "i cannot be greater than order");

    ParameterExpression xp = Expression.Parameter(typeof(double), "x");

    // generate the terms of the factored polynomial in form (x_j - x)
    List<Expression> factors = new List<Expression>();
    for (int j = 0; j <= order; j++)
    {
        if (j == i)
            continue;

        double p = j * elementSize / order;
        factors.Add(Expression.Subtract(Expression.Constant(p), xp));
    }

    // evaluate the result at the point x_i to get scaleInv=1.0/scale.
    double xi = i * elementSize / order;
    double scaleInv = Enumerable.Range(0, order + 1).Aggregate(0.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize / order - xi)));

    /* generate an expression to evaluate
     *   (x_0 - x) * (x_1 - x) .. (x_n - x) / (x_i - x)
     * obviously the term (x_i - x) is cancelled in this result, but included here to make the result clear
     */
    Expression expr = factors.Skip(1).Aggregate(factors[0], Expression.Multiply);
    // multiplying by scale forces the condition f(x_i)=1
    expr = Expression.Multiply(Expression.Constant(1.0 / scaleInv), expr);

    Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp);
    return lambdaMethod.Compile();
}

문제 : 나는 또한 Ï ² = dÏ / dx를 평가할 필요가있다. 이것을하기 위해, 나는 Ï = scaleÃ- (x_0 - x) (x_1 - x) Ã - - à (x_n - x) / + Î ± _nÃ-x ^ (n-1) + .. + Î ± _1-x + Î ± _0. 이것은 (n-1) + (n-1) Ã-Î ± _nÃ-x ^ (n-2) + .. + 1Ã-Î ± _1을 제공합니다.

계산상의 이유로, 우리는 Math.Pow 를 호출하지 않고 최종 답을 다시 쓸 수 있습니다. Math.Pow 를 쓰면 ϲ² = xÃ- (xà - (..) - β_2) - β_1) - β_0이됩니다.

이 모든 "속임수"(모든 기본적인 대수)를 수행하려면 다음과 같은 방법이 필요합니다.

  1. ConstantExpressionParameterExpression 및 기본 수학 연산을 포함하는 인수 Expression 을 확장합니다 ( NodeType BinaryExpression 으로 설정 한 BinaryExpression 으로 끝남). 여기에 결과는 특별한 방식으로 처리 할 Math.PowMethodInfo 대한 InvocationExpression 요소를 포함 할 수 있습니다 전역.
  2. 그런 다음 지정된 ParameterExpression 과 관련하여 미분을 취합니다. 결과에서 Math.Pow 의 호출에 대한 오른쪽 매개 변수가 상수 2 Math.PowConstantExpression(2) 에 왼쪽 부분에 곱한 값으로 대체됩니다 ConstantExpression(2) Math.Pow(x,1) 의 호출은 다음과 Math.Pow(x,1) 제거됨). x에 대해 상수이기 때문에 0이되는 결과의 용어는 제거됩니다.
  3. 그런 다음 특정 ParameterExpression 의 인스턴스를 Math.Pow 호출의 왼쪽 매개 변수로 Math.Pow 합니다. 호출의 오른쪽이 값 1ConstantExpression 이되면 호출을 ParameterExpression 자체로 바꿉니다.

앞으로는 ParameterExpression 을 취하여 해당 매개 변수를 기반으로 평가되는 Expression 을 반환 할 방법을 원합니다. 그렇게하면 생성 된 함수를 집계 할 수 있습니다. 나는 아직 거기에 없다. 미래에는 LINQ 표현식을 상징적 인 수학으로 사용하기위한 일반 라이브러리를 공개하기를 희망합니다.

수락 된 답변

.NET 4에서 ExpressionVisitor 형식을 사용하여 여러 가지 상징적 인 수학 기능의 기초를 썼습니다. 완벽하지는 않지만 실행 가능한 솔루션의 기초처럼 보입니다.

  • SymbolicExpand , SimplifyPartialDerivative 와 같은 메소드를 표시하는 public static 클래스입니다.
  • ExpandVisitor 는 표현을 확장하는 내부 도우미 유형입니다.
  • SimplifyVisitor 는 표현을 단순화하는 내부 도우미 유형입니다.
  • DerivativeVisitor 는 표현식의 파생물을 취하는 내부 도우미 유형입니다.
  • ListPrintVisitorExpression 을 Lisp 구문으로 접두사 표기법으로 변환하는 내부 도우미 유형입니다.

Symbolic

public static class Symbolic
{
    public static Expression Expand(Expression expression)
    {
        return new ExpandVisitor().Visit(expression);
    }

    public static Expression Simplify(Expression expression)
    {
        return new SimplifyVisitor().Visit(expression);
    }

    public static Expression PartialDerivative(Expression expression, ParameterExpression parameter)
    {
        bool totalDerivative = false;
        return new DerivativeVisitor(parameter, totalDerivative).Visit(expression);
    }

    public static string ToString(Expression expression)
    {
        ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression);
        return result.Value.ToString();
    }
}

ExpandVisitor 표현식 확장하기

internal class ExpandVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        var left = Visit(node.Left);
        var right = Visit(node.Right);

        if (node.NodeType == ExpressionType.Multiply)
        {
            Expression[] leftNodes = GetAddedNodes(left).ToArray();
            Expression[] rightNodes = GetAddedNodes(right).ToArray();
            var result =
                leftNodes
                .SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y)))
                .Aggregate((sum, term) => Expression.Add(sum, term));

            return result;
        }

        if (node.Left == left && node.Right == right)
            return node;

        return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion);
    }

    /// <summary>
    /// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns the
    /// the individual addends in the sum.
    /// </summary>
    private static IEnumerable<Expression> GetAddedNodes(Expression node)
    {
        BinaryExpression binary = node as BinaryExpression;
        if (binary != null)
        {
            switch (binary.NodeType)
            {
            case ExpressionType.Add:
                foreach (var n in GetAddedNodes(binary.Left))
                    yield return n;

                foreach (var n in GetAddedNodes(binary.Right))
                    yield return n;

                yield break;

            case ExpressionType.Subtract:
                foreach (var n in GetAddedNodes(binary.Left))
                    yield return n;

                foreach (var n in GetAddedNodes(binary.Right))
                    yield return Expression.Negate(n);

                yield break;

            default:
                break;
            }
        }

        yield return node;
    }
}

DerivativeVisitor 를 사용하여 파생물 가져 오기

internal class DerivativeVisitor : ExpressionVisitor
{
    private ParameterExpression _parameter;
    private bool _totalDerivative;

    public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative)
    {
        if (_totalDerivative)
            throw new NotImplementedException();

        _parameter = parameter;
        _totalDerivative = totalDerivative;
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        switch (node.NodeType)
        {
        case ExpressionType.Add:
        case ExpressionType.Subtract:
            return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right));

        case ExpressionType.Multiply:
            return Expression.Add(Expression.Multiply(node.Left, Visit(node.Right)), Expression.Multiply(Visit(node.Left), node.Right));

        case ExpressionType.Divide:
            return Expression.Divide(Expression.Subtract(Expression.Multiply(Visit(node.Left), node.Right), Expression.Multiply(node.Left, Visit(node.Right))), Expression.Power(node.Right, Expression.Constant(2)));

        case ExpressionType.Power:
            if (node.Right is ConstantExpression)
            {
                return Expression.Multiply(node.Right, Expression.Multiply(Visit(node.Left), Expression.Subtract(node.Right, Expression.Constant(1))));
            }
            else if (node.Left is ConstantExpression)
            {
                return Expression.Multiply(node, MathExpressions.Log(node.Left));
            }
            else
            {
                return Expression.Multiply(node, Expression.Add(
                    Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)),
                    Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left))
                    ));
            }

        default:
            throw new NotImplementedException();
        }
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        return MathExpressions.Zero;
    }

    protected override Expression VisitInvocation(InvocationExpression node)
    {
        MemberExpression memberExpression = node.Expression as MemberExpression;
        if (memberExpression != null)
        {
            var member = memberExpression.Member;
            if (member.DeclaringType != typeof(Math))
                throw new NotImplementedException();

            switch (member.Name)
            {
            case "Log":
                return Expression.Divide(Visit(node.Expression), node.Expression);

            case "Log10":
                return Expression.Divide(Visit(node.Expression), Expression.Multiply(Expression.Constant(Math.Log(10)), node.Expression));

            case "Exp":
            case "Sin":
            case "Cos":
            default:
                throw new NotImplementedException();
            }
        }

        throw new NotImplementedException();
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (node == _parameter)
            return MathExpressions.One;

        return MathExpressions.Zero;
    }
}

SimplifyVisitor로 표현식 SimplifyVisitor

internal class SimplifyVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        var left = Visit(node.Left);
        var right = Visit(node.Right);

        ConstantExpression leftConstant = left as ConstantExpression;
        ConstantExpression rightConstant = right as ConstantExpression;
        if (leftConstant != null && rightConstant != null
            && (leftConstant.Value is double) && (rightConstant.Value is double))
        {
            double leftValue = (double)leftConstant.Value;
            double rightValue = (double)rightConstant.Value;

            switch (node.NodeType)
            {
            case ExpressionType.Add:
                return Expression.Constant(leftValue + rightValue);
            case ExpressionType.Subtract:
                return Expression.Constant(leftValue - rightValue);
            case ExpressionType.Multiply:
                return Expression.Constant(leftValue * rightValue);
            case ExpressionType.Divide:
                return Expression.Constant(leftValue / rightValue);
            default:
                throw new NotImplementedException();
            }
        }

        switch (node.NodeType)
        {
        case ExpressionType.Add:
            if (IsZero(left))
                return right;
            if (IsZero(right))
                return left;
            break;

        case ExpressionType.Subtract:
            if (IsZero(left))
                return Expression.Negate(right);
            if (IsZero(right))
                return left;
            break;

        case ExpressionType.Multiply:
            if (IsZero(left) || IsZero(right))
                return MathExpressions.Zero;
            if (IsOne(left))
                return right;
            if (IsOne(right))
                return left;
            break;

        case ExpressionType.Divide:
            if (IsZero(right))
                throw new DivideByZeroException();
            if (IsZero(left))
                return MathExpressions.Zero;
            if (IsOne(right))
                return left;
            break;

        default:
            throw new NotImplementedException();
        }

        return Expression.MakeBinary(node.NodeType, left, right);
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        var operand = Visit(node.Operand);

        ConstantExpression operandConstant = operand as ConstantExpression;
        if (operandConstant != null && (operandConstant.Value is double))
        {
            double operandValue = (double)operandConstant.Value;

            switch (node.NodeType)
            {
            case ExpressionType.Negate:
                if (operandValue == 0.0)
                    return MathExpressions.Zero;

                return Expression.Constant(-operandValue);

            default:
                throw new NotImplementedException();
            }
        }

        switch (node.NodeType)
        {
        case ExpressionType.Negate:
            if (operand.NodeType == ExpressionType.Negate)
            {
                return ((UnaryExpression)operand).Operand;
            }

            break;

        default:
            throw new NotImplementedException();
        }

        return Expression.MakeUnary(node.NodeType, operand, node.Type);
    }

    private static bool IsZero(Expression expression)
    {
        ConstantExpression constant = expression as ConstantExpression;
        if (constant != null)
        {
            if (constant.Value.Equals(0.0))
                return true;
        }

        return false;
    }

    private static bool IsOne(Expression expression)
    {
        ConstantExpression constant = expression as ConstantExpression;
        if (constant != null)
        {
            if (constant.Value.Equals(1.0))
                return true;
        }

        return false;
    }
}

ListPrintVisitor 로 표시 할 수식 서식 지정

internal class ListPrintVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        string op = null;

        switch (node.NodeType)
        {
        case ExpressionType.Add:
            op = "+";
            break;
        case ExpressionType.Subtract:
            op = "-";
            break;
        case ExpressionType.Multiply:
            op = "*";
            break;
        case ExpressionType.Divide:
            op = "/";
            break;
        default:
            throw new NotImplementedException();
        }

        var left = Visit(node.Left);
        var right = Visit(node.Right);
        string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value);
        return Expression.Constant(result);
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value is string)
            return node;

        return Expression.Constant(node.Value.ToString());
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        return Expression.Constant(node.Name);
    }
}

결과 테스트하기

[TestMethod]
public void BasicSymbolicTest()
{
    ParameterExpression x = Expression.Parameter(typeof(double), "x");
    Expression linear = Expression.Add(Expression.Constant(3.0), x);
    Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear));

    Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x));
    Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic));

    Expression expanded = Symbolic.Expand(quadratic);
    Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded));
    Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded)));

    Expression derivative = Symbolic.PartialDerivative(expanded, x);
    Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative));

    Expression simplified = Symbolic.Simplify(derivative);
    Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified));
}


아래 라이선스: CC-BY-SA with attribution
와 제휴하지 않음 Stack Overflow
아래 라이선스: CC-BY-SA with attribution
와 제휴하지 않음 Stack Overflow