package edu.princeton.cs.algorithms;

import edu.princeton.cs.introcs.StdOut;
import java.util.Iterator;

/* loaded from: input_file:edu/princeton/cs/algorithms/SparseVector.class */
/**
 * A fixed-dimension vector of {@code double} values that stores only its
 * nonzero entries, keyed by index, in a symbol table.
 *
 * <p>Entries that are absent from the table are treated as {@code 0.0}.
 * Writing {@code 0.0} to an index removes that index from the table, so
 * {@link #nnz()} counts exactly the entries whose last written value was
 * not equal to zero.
 */
public class SparseVector {
    /** Dimension of the vector; valid indices are {@code 0 .. dimension - 1}. */
    private final int dimension;

    /** Maps an index to its stored (nonzero) value. */
    private final ST<Integer, Double> st = new ST<>();

    /**
     * Creates a vector of the given dimension with no stored entries.
     *
     * <p>The dimension is not validated here; with a non-positive dimension
     * every call to {@link #put} or {@link #get} fails its bounds check.
     *
     * @param i the dimension of the vector
     */
    public SparseVector(int i) {
        this.dimension = i;
    }

    /**
     * Sets entry {@code i} to {@code d}.
     *
     * @param i the index to write
     * @param d the new value; a value equal to zero removes the entry
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less
     *         than the dimension
     */
    public void put(int i, double d) {
        checkIndex(i);
        if (d == 0.0d) {
            // Zeros are represented by absence so the table stays sparse.
            // Note that -0.0 also compares equal to 0.0 and is removed.
            this.st.delete(i);
        } else {
            this.st.put(i, d);
        }
    }

    /**
     * Returns entry {@code i}.
     *
     * @param i the index to read
     * @return the stored value, or {@code 0.0} if nothing is stored at {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less
     *         than the dimension
     */
    public double get(int i) {
        checkIndex(i);
        if (this.st.contains(i)) {
            return this.st.get(i);
        }
        return 0.0d;
    }

    /**
     * Returns the number of entries currently stored in the table.
     *
     * @return the number of stored (nonzero) entries
     */
    public int nnz() {
        return this.st.size();
    }

    /**
     * Returns the dimension this vector was created with.
     *
     * @return the dimension
     */
    public int size() {
        return this.dimension;
    }

    /**
     * Computes the dot product of this vector with another sparse vector.
     *
     * @param sparseVector the other vector
     * @return the sum of the products of entries stored in both vectors
     * @throws IllegalArgumentException if the dimensions differ
     */
    public double dot(SparseVector sparseVector) {
        if (this.dimension != sparseVector.dimension) {
            throw new IllegalArgumentException("Vector lengths disagree");
        }

        // Walk the table with fewer entries and probe the other one; on a tie
        // this vector's table is the one walked.
        ST<Integer, Double> walked = this.st;
        ST<Integer, Double> probed = sparseVector.st;
        if (walked.size() > probed.size()) {
            walked = sparseVector.st;
            probed = this.st;
        }

        double sum = 0.0d;
        for (int index : walked.keys()) {
            if (probed.contains(index)) {
                // Both keys are known to be present, so the tables are read
                // directly instead of going through the bounds-checked get().
                sum += walked.get(index) * probed.get(index);
            }
        }
        return sum;
    }

    /**
     * Computes the dot product of this vector with a dense array.
     *
     * <p>Only the positions stored in this vector are read from the array.
     * The array length is not compared against the dimension, so an array
     * shorter than a stored index causes an
     * {@link ArrayIndexOutOfBoundsException} from the array access.
     *
     * @param dArr the dense vector
     * @return the sum over stored indices of {@code dArr[index] * value}
     */
    public double dot(double[] dArr) {
        double sum = 0.0d;
        for (int index : this.st.keys()) {
            sum += dArr[index] * this.st.get(index);
        }
        return sum;
    }

    /**
     * Returns the square root of this vector's dot product with itself.
     *
     * @return the Euclidean norm
     */
    public double norm() {
        return Math.sqrt(dot(this));
    }

    /**
     * Returns a new vector whose stored entries are this vector's entries
     * multiplied by {@code d}. This vector is not modified.
     *
     * @param d the scale factor
     * @return the scaled vector, of the same dimension
     */
    public SparseVector scale(double d) {
        SparseVector scaled = new SparseVector(this.dimension);
        for (int index : this.st.keys()) {
            // put() drops products that come out equal to zero.
            scaled.put(index, d * this.st.get(index));
        }
        return scaled;
    }

    /**
     * Returns a new vector holding the entry-wise sum of this vector and
     * another. Neither operand is modified.
     *
     * @param sparseVector the vector to add
     * @return the sum, of the same dimension
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SparseVector plus(SparseVector sparseVector) {
        if (this.dimension != sparseVector.dimension) {
            throw new IllegalArgumentException("Vector lengths disagree");
        }
        SparseVector sum = new SparseVector(this.dimension);
        for (int index : this.st.keys()) {
            sum.put(index, this.st.get(index));
        }
        for (int index : sparseVector.st.keys()) {
            // sum.get() yields 0.0 where this vector had no entry; put() then
            // removes the index again if the two values cancel to zero.
            sum.put(index, sparseVector.st.get(index) + sum.get(index));
        }
        return sum;
    }

    /**
     * Returns the stored entries as {@code "(index, value) "} pairs, in the
     * order the table reports its keys. An empty vector yields {@code ""}.
     *
     * @return a string listing every stored entry
     */
    @Override
    public String toString() {
        // A builder avoids re-copying the accumulated text on every entry.
        StringBuilder text = new StringBuilder();
        for (int index : this.st.keys()) {
            text.append('(').append(index).append(", ").append(this.st.get(index)).append(") ");
        }
        return text.toString();
    }

    /** Rejects indices outside {@code 0 .. dimension - 1}. */
    private void checkIndex(int i) {
        if (i < 0 || i >= this.dimension) {
            throw new IndexOutOfBoundsException("Illegal index");
        }
    }

    /**
     * Small demonstration: builds two vectors of dimension 10 and prints
     * them, their dot product, and their sum.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        SparseVector a = new SparseVector(10);
        SparseVector b = new SparseVector(10);
        a.put(3, 0.5d);
        a.put(9, 0.75d);
        a.put(6, 0.11d);
        a.put(6, 0.0d); // overwriting with zero removes index 6 again
        b.put(3, 0.6d);
        b.put(4, 0.9d);
        StdOut.println("a = " + a);
        StdOut.println("b = " + b);
        StdOut.println("a dot b = " + a.dot(b));
        StdOut.println("a + b   = " + a.plus(b));
    }
}
