Linear Regression in Artificial Intelligence

26 Mar 2024

Linear regression is the "hello world" of machine learning. It predicts a number based on input data by finding the best straight line through your data points.

Think of it like this: if you plot house sizes on the X axis and prices on the Y axis, linear regression draws the line that best fits those dots. Then you can feed it a new size and it predicts the price.

The math (kept simple)

The equation:

y = mx + b

  • y -- the thing you're predicting (price, temperature, sales)
  • x -- the input feature (size, time, ad spend)
  • m -- the slope. How much y changes when x changes by 1.
  • b -- the y-intercept. The value of y when x is 0, i.e. where the line crosses the Y axis.

The model's job: find values of m and b that minimize the gap between predicted values and actual values. That gap is called the "loss," and the most common measure is mean squared error -- the average of (predicted - actual) squared.
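To make the loss concrete, here is a minimal sketch in plain JavaScript that computes mean squared error for two candidate lines. It uses the same five points as the TensorFlow.js example later in this article; the function names predict and meanSquaredError are made up for illustration.

```javascript
// The five points from this article's example; they follow y = x + 1.
const xs = [1, 2, 3, 4, 5];
const ys = [2, 3, 4, 5, 6];

// Prediction of the line y = m * x + b for a single input (hypothetical helper).
function predict(m, b, x) {
  return m * x + b;
}

// Mean squared error: the average of (predicted - actual)^2 over all points.
function meanSquaredError(m, b, xs, ys) {
  let total = 0;
  for (let i = 0; i < xs.length; i++) {
    const error = predict(m, b, xs[i]) - ys[i];
    total += error * error;
  }
  return total / xs.length;
}

console.log(meanSquaredError(1, 1, xs, ys)); // 0: this line passes through every point
console.log(meanSquaredError(2, 0, xs, ys)); // 6: a worse line gives a larger loss
```

Training amounts to searching for the m and b that make this number as small as possible.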

Building it with TensorFlow.js

You don't need Python for this. TensorFlow.js runs machine learning directly in Node.js or the browser.

Here's a complete example for Node.js. It assumes the package has been installed with npm install @tensorflow/tfjs:

Js
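const tf = require('@tensorflow/tfjs');

// Training data: five points that follow y = x + 1.
const X = tf.tensor2d([[1], [2], [3], [4], [5]]);
const Y = tf.tensor2d([[2], [3], [4], [5], [6]]);

// A single dense unit with one input computes y = m * x + b.
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

// Explicit learning rate: large enough to converge, small enough to stay stable on this data.
model.compile({ loss: 'meanSquaredError', optimizer: tf.train.sgd(0.05) });

async function trainModel() {
  // verbose: 0 suppresses per-epoch progress logging.
  await model.fit(X, Y, { epochs: 500, verbose: 0 });
  console.log('Training complete.');

  // weights[0] is the kernel (slope), weights[1] is the bias (y-intercept).
  const weights = model.getWeights();
  const slope = weights[0].dataSync()[0];
  const yIntercept = weights[1].dataSync()[0];
  console.log(`Slope: ${slope}, Y-intercept: ${yIntercept}`);

  // Predict y for an input the model has not seen.
  const newX = tf.tensor2d([[6]]);
  const predictedY = model.predict(newX);
  predictedY.print();
}

trainModel();

// Output (exact digits vary between runs because the weights start out random):
// Training complete.
// Slope: ~1.0, Y-intercept: ~1.0
// Tensor [[~7.0]]

What's happening: the model trains for 500 epochs (passes through the data), adjusting m and b each time to reduce the error. After training, it predicts that when x=6, y should be approximately 7. Which makes sense -- our data follows y = x + 1.

A note on the settings: with the plain 'sgd' string (which uses a small default learning rate) and only 100 epochs, the slope gets close to 1 quickly but the intercept is typically still well short of 1 when training stops. Passing tf.train.sgd(0.05) and training for 500 epochs lets both values settle.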
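To show what "adjusting m and b each time" means, here is a minimal sketch of gradient descent on mean squared error in plain JavaScript, without TensorFlow.js. The starting values, learning rate, step count and the gradientStep helper are assumptions chosen for illustration; TensorFlow.js starts from random weights and its internals may differ.

```javascript
// Same training data as the TensorFlow.js example: y = x + 1.
const xs = [1, 2, 3, 4, 5];
const ys = [2, 3, 4, 5, 6];

// One full-batch gradient descent step on mean squared error (hypothetical helper).
function gradientStep(m, b, learningRate) {
  let gradM = 0;
  let gradB = 0;
  for (let i = 0; i < xs.length; i++) {
    const error = m * xs[i] + b - ys[i];
    gradM += (2 / xs.length) * error * xs[i]; // d(MSE)/dm
    gradB += (2 / xs.length) * error;         // d(MSE)/db
  }
  // Move each parameter a small step against its gradient.
  return { m: m - learningRate * gradM, b: b - learningRate * gradB };
}

// Illustrative settings, not values taken from TensorFlow.js.
let params = { m: 0, b: 0 };
for (let epoch = 0; epoch < 2000; epoch++) {
  params = gradientStep(params.m, params.b, 0.05);
}

console.log(params); // m and b both end up very close to 1
```

Each step nudges m and b in the direction that lowers the loss, which is the same idea the optimizer applies once per epoch in the example above.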

When to use linear regression

Good for: Predicting continuous values where the relationship between input and output is roughly linear. Sales forecasting, trend analysis, simple pricing models.

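Bad for: Non-linear relationships, classification problems, or anything where the input-output relationship has curves, jumps, or complex interactions. For those, you need different tools: for example, logistic regression for classification, or polynomial features, tree-based models, or neural networks for curved relationships.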
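A quick way to see the problem with curved data is to fit a line and inspect the residuals (actual minus predicted). The sketch below uses the closed-form least squares formula rather than the gradient descent training shown earlier, and the data is synthetic, generated from y = x^2. The fitLine helper is hypothetical.

```javascript
// Ordinary least squares fit of y = m * x + b using the closed-form solution.
function fitLine(xs, ys) {
  const n = xs.length;
  const meanX = xs.reduce((sum, x) => sum + x, 0) / n;
  const meanY = ys.reduce((sum, y) => sum + y, 0) / n;
  let sxy = 0;
  let sxx = 0;
  for (let i = 0; i < n; i++) {
    sxy += (xs[i] - meanX) * (ys[i] - meanY);
    sxx += (xs[i] - meanX) ** 2;
  }
  const m = sxy / sxx;
  return { m, b: meanY - m * meanX };
}

// Synthetic curved data generated from y = x^2 (not real measurements).
const xs = [1, 2, 3, 4, 5];
const ys = xs.map((x) => x * x);

const { m, b } = fitLine(xs, ys);
const residuals = xs.map((x, i) => ys[i] - (m * x + b));

console.log(m, b);      // 6 -7
console.log(residuals); // [ 2, -1, -2, -1, 2 ]
```

The residuals are positive at both ends and negative in the middle. A systematic pattern like that, rather than random scatter around zero, is a sign that a straight line is the wrong shape for the data.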

The benefit of linear regression: it's fast, interpretable, and hard to overfit. With only two parameters (m and b), the model has little room to memorize noise. You can explain it to a non-technical stakeholder by showing them the line.

The cost: it assumes linearity. Real-world data is rarely that clean. It's a starting point, not a destination.